Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for Discord OAuth2 #25

Merged
merged 8 commits into from
Jul 27, 2024
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ FROM scratch
WORKDIR /app

COPY --from=builder /src/rogueserver .
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/

EXPOSE 8001

Expand Down
104 changes: 104 additions & 0 deletions api/account/discord.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
Copyright (C) 2024 Pagefault Games

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package account

import (
"encoding/json"
"errors"
"net/http"
"net/url"
"os"
)

func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
gameUrl := os.Getenv("GAME_URL")
if code == "" {
defer http.Redirect(w, r, gameUrl, http.StatusSeeOther)
return "", errors.New("code is empty")
}

discordId, err := RetrieveDiscordId(code)
if err != nil {
defer http.Redirect(w, r, gameUrl, http.StatusSeeOther)
return "", err
f-fsantos marked this conversation as resolved.
Show resolved Hide resolved
}

return discordId, nil
}

func RetrieveDiscordId(code string) (string, error) {
token, err := http.PostForm("https://discord.com/api/oauth2/token", url.Values{
"client_id": {os.Getenv("DISCORD_CLIENT_ID")},
"client_secret": {os.Getenv("DISCORD_CLIENT_SECRET")},
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {os.Getenv("DISCORD_CALLBACK_URL")},
"scope": {"identify"},
})

if err != nil {
return "", err
}

// extract access_token from token
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}

var tokenResponse TokenResponse
err = json.NewDecoder(token.Body).Decode(&tokenResponse)
if err != nil {
return "", err
}

access_token := tokenResponse.AccessToken
if access_token == "" {
return "", errors.New("access token is empty")
}

client := &http.Client{}
f-fsantos marked this conversation as resolved.
Show resolved Hide resolved
req, err := http.NewRequest("GET", "https://discord.com/api/users/@me", nil)
if err != nil {
return "", err
}

req.Header.Set("Authorization", "Bearer "+access_token)
resp, err := client.Do(req)
if err != nil {
return "", err
}

defer resp.Body.Close()

type User struct {
Id string `json:"id"`
}

var user User
err = json.NewDecoder(resp.Body).Decode(&user)
if err != nil {
return "", err
}

return user.Id, nil
}
97 changes: 97 additions & 0 deletions api/account/google.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright (C) 2024 Pagefault Games

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package account

import (
"encoding/json"
"errors"
"net/http"
"net/url"
"os"

"github.com/golang-jwt/jwt/v5"
)

func HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
gameUrl := os.Getenv("GAME_URL")
if code == "" {
defer http.Redirect(w, r, gameUrl, http.StatusSeeOther)
return "", errors.New("code is empty")
}

googleId, err := RetrieveGoogleId(code)
if err != nil {
defer http.Redirect(w, r, gameUrl, http.StatusSeeOther)
return "", err
}

return googleId, nil
}

func RetrieveGoogleId(code string) (string, error) {
token, err := http.PostForm("https://oauth2.googleapis.com/token", url.Values{
"client_id": {os.Getenv("GOOGLE_CLIENT_ID")},
"client_secret": {os.Getenv("GOOGLE_CLIENT_SECRET")},
"code": {code},
"grant_type": {"authorization_code"},
"redirect_uri": {os.Getenv("GOOGLE_CALLBACK_URL")},
})

if err != nil {
return "", err
}
defer token.Body.Close()
type TokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
IdToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
}
var tokenResponse TokenResponse
err = json.NewDecoder(token.Body).Decode(&tokenResponse)
if err != nil {
return "", err
}

userId, err := parseJWTWithoutValidation(tokenResponse.IdToken)
if err != nil {
return "", err
}

return userId, nil
}

func parseJWTWithoutValidation(idToken string) (string, error) {
parser := jwt.NewParser()

// Use ParseUnverified to parse the token without validation
parsedJwt, _, err := parser.ParseUnverified(idToken, jwt.MapClaims{})
if err != nil {
return "", err
}

claims, ok := parsedJwt.Claims.(jwt.MapClaims)
if !ok {
return "", errors.New("invalid token claims")
}

return claims.GetSubject()
}
13 changes: 10 additions & 3 deletions api/account/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@ import (

type InfoResponse struct {
Username string `json:"username"`
DiscordId string `json:"discordId"`
GoogleId string `json:"googleId"`
LastSessionSlot int `json:"lastSessionSlot"`
}

// /account/info - get account info
func Info(username string, uuid []byte) (InfoResponse, error) {
func Info(username string, discordId string, googleId string, uuid []byte) (InfoResponse, error) {
slot, _ := db.GetLatestSessionSaveDataSlot(uuid)

return InfoResponse{Username: username, LastSessionSlot: slot}, nil
response := InfoResponse{
Username: username,
LastSessionSlot: slot,
DiscordId: discordId,
GoogleId: googleId,
}
return response, nil
}
21 changes: 14 additions & 7 deletions api/account/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,25 @@ func Login(username, password string) (LoginResponse, error) {
return response, fmt.Errorf("password doesn't match")
}

token := make([]byte, TokenSize)
_, err = rand.Read(token)
response.Token, err = GenerateTokenForUsername(username)

if err != nil {
return response, fmt.Errorf("failed to generate token: %s", err)
}

err = db.AddAccountSession(username, token)
return response, nil
}

func GenerateTokenForUsername(username string) (string, error) {
token := make([]byte, TokenSize)
_, err := rand.Read(token)
if err != nil {
return response, fmt.Errorf("failed to add account session")
return "", fmt.Errorf("failed to generate token: %s", err)
}

response.Token = base64.StdEncoding.EncodeToString(token)

return response, nil
err = db.AddAccountSession(username, token)
if err != nil {
return "", fmt.Errorf("failed to add account session")
}
return base64.StdEncoding.EncodeToString(token), nil
}
3 changes: 3 additions & 0 deletions api/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ func Init(mux *http.ServeMux) error {
mux.HandleFunc("GET /daily/rankings", handleDailyRankings)
mux.HandleFunc("GET /daily/rankingpagecount", handleDailyRankingPageCount)

// auth
mux.HandleFunc("/auth/{provider}/callback", handleProviderCallback)
mux.HandleFunc("/auth/{provider}/logout", handleProviderLogout)
return nil
}

Expand Down
Loading
Loading