diff --git a/api/common.go b/api/common.go index dec5d2d..53541c4 100644 --- a/api/common.go +++ b/api/common.go @@ -3,55 +3,32 @@ package api import ( "encoding/base64" "fmt" - "net/http" "github.com/pagefaultgames/pokerogue-server/api/account" "github.com/pagefaultgames/pokerogue-server/api/daily" "github.com/pagefaultgames/pokerogue-server/db" ) -func Init(mux *http.ServeMux) { +func Init() { scheduleStatRefresh() daily.Init() - - // account - mux.HandleFunc("GET /api/account/info", handleAccountInfo) - mux.HandleFunc("POST /api/account/register", handleAccountRegister) - mux.HandleFunc("POST /api/account/login", handleAccountLogin) - mux.HandleFunc("GET /api/account/logout", handleAccountLogout) - - // game - mux.HandleFunc("GET /api/game/playercount", handleGamePlayerCount) - mux.HandleFunc("GET /api/game/titlestats", handleGameTitleStats) - mux.HandleFunc("GET /api/game/classicsessioncount", handleGameClassicSessionCount) - - // savedata - mux.HandleFunc("GET /api/savedata/get", handleSaveData) - mux.HandleFunc("POST /api/savedata/update", handleSaveData) - mux.HandleFunc("GET /api/savedata/delete", handleSaveData) - mux.HandleFunc("POST /api/savedata/clear", handleSaveData) - - // daily - mux.HandleFunc("GET /api/daily/seed", handleDailySeed) - mux.HandleFunc("GET /api/daily/rankings", handleDailyRankings) - mux.HandleFunc("GET /api/daily/rankingpagecount", handleDailyRankingPageCount) } -func getUsernameFromRequest(r *http.Request) (string, error) { - if r.Header.Get("Authorization") == "" { +func usernameFromTokenHeader(token string) (string, error) { + if token == "" { return "", fmt.Errorf("missing token") } - token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) + decoded, err := base64.StdEncoding.DecodeString(token) if err != nil { return "", fmt.Errorf("failed to decode token: %s", err) } - if len(token) != account.TokenSize { + if len(decoded) != account.TokenSize { return "", fmt.Errorf("invalid token length: got %d, expected %d", len(token), account.TokenSize) } - username, err := db.FetchUsernameFromToken(token) + username, err := db.FetchUsernameFromToken(decoded) if err != nil { return "", fmt.Errorf("failed to validate token: %s", err) } @@ -59,21 +36,21 @@ func getUsernameFromRequest(r *http.Request) (string, error) { return username, nil } -func getUUIDFromRequest(r *http.Request) ([]byte, error) { - if r.Header.Get("Authorization") == "" { +func uuidFromTokenHeader(token string) ([]byte, error) { + if token == "" { return nil, fmt.Errorf("missing token") } - token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) + decoded, err := base64.StdEncoding.DecodeString(token) if err != nil { return nil, fmt.Errorf("failed to decode token: %s", err) } - if len(token) != account.TokenSize { + if len(decoded) != account.TokenSize { return nil, fmt.Errorf("invalid token length: got %d, expected %d", len(token), account.TokenSize) } - uuid, err := db.FetchUUIDFromToken(token) + uuid, err := db.FetchUUIDFromToken(decoded) if err != nil { return nil, fmt.Errorf("failed to validate token: %s", err) } diff --git a/api/endpoints.go b/api/endpoints.go index 8f0d97e..6298b4e 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -12,6 +12,7 @@ import ( "github.com/pagefaultgames/pokerogue-server/api/daily" "github.com/pagefaultgames/pokerogue-server/api/savedata" "github.com/pagefaultgames/pokerogue-server/defs" + "github.com/valyala/fasthttp" ) /* @@ -20,146 +21,135 @@ import ( Handlers should not return serialized JSON, instead return the struct itself. */ -func handleAccountInfo(w http.ResponseWriter, r *http.Request) { - username, err := getUsernameFromRequest(r) +func HandleAccountInfo(ctx *fasthttp.RequestCtx) { + username, err := usernameFromTokenHeader(string(ctx.Request.Header.Peek("Authorization"))) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(ctx, err, http.StatusBadRequest) return } - uuid, err := getUUIDFromRequest(r) // lazy + uuid, err := uuidFromTokenHeader(string(ctx.Request.Header.Peek("Authorization"))) // lazy if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(ctx, err, http.StatusBadRequest) return } response, err := account.Info(username, uuid) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(ctx, err, http.StatusInternalServerError) return } - err = json.NewEncoder(w).Encode(response) + err = json.NewEncoder(ctx.Response.BodyWriter()).Encode(response) if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + httpError(ctx, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) return } } -func handleAccountRegister(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() +func HandleAccountRegister(ctx *fasthttp.RequestCtx) { + err := account.Register(string(ctx.PostArgs().Peek("username")), string(ctx.PostArgs().Peek("password"))) if err != nil { - httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) + httpError(ctx, err, http.StatusInternalServerError) return } - err = account.Register(r.Form.Get("username"), r.Form.Get("password")) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusOK) + ctx.SetStatusCode(http.StatusOK) } -func handleAccountLogin(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() +func HandleAccountLogin(ctx *fasthttp.RequestCtx) { + response, err := account.Login(string(ctx.PostArgs().Peek("username")), string(ctx.PostArgs().Peek("password"))) if err != nil { - httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) + httpError(ctx, err, http.StatusInternalServerError) return } - response, err := account.Login(r.Form.Get("username"), r.Form.Get("password")) + err = json.NewEncoder(ctx.Response.BodyWriter()).Encode(response) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } - - err = json.NewEncoder(w).Encode(response) - if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + httpError(ctx, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) return } } -func handleAccountLogout(w http.ResponseWriter, r *http.Request) { - token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) +func HandleAccountLogout(ctx *fasthttp.RequestCtx) { + token, err := base64.StdEncoding.DecodeString(string(ctx.Request.Header.Peek("Authorization"))) if err != nil { - httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest) + httpError(ctx, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest) return } err = account.Logout(token) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(ctx, err, http.StatusInternalServerError) return } - w.WriteHeader(http.StatusOK) + ctx.SetStatusCode(http.StatusOK) } -func handleGamePlayerCount(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(strconv.Itoa(playerCount))) +func HandleGamePlayerCount(ctx *fasthttp.RequestCtx) { + ctx.SetBody([]byte(strconv.Itoa(playerCount))) } -func handleGameTitleStats(w http.ResponseWriter, r *http.Request) { - err := json.NewEncoder(w).Encode(defs.TitleStats{ +func HandleGameTitleStats(ctx *fasthttp.RequestCtx) { + err := json.NewEncoder(ctx.Response.BodyWriter()).Encode(defs.TitleStats{ PlayerCount: playerCount, BattleCount: battleCount, }) if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + httpError(ctx, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) return } } -func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(strconv.Itoa(classicSessionCount))) +func HandleGameClassicSessionCount(ctx *fasthttp.RequestCtx) { + ctx.SetBody([]byte(strconv.Itoa(classicSessionCount))) } -func handleSaveData(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) +func HandleSaveData(ctx *fasthttp.RequestCtx) { + uuid, err := uuidFromTokenHeader(string(ctx.Request.Header.Peek("Authorization"))) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(ctx, err, http.StatusBadRequest) return } datatype := -1 - if r.URL.Query().Has("datatype") { - datatype, err = strconv.Atoi(r.URL.Query().Get("datatype")) + + if ctx.QueryArgs().Has("datatype") { + datatype, err = strconv.Atoi(string(ctx.QueryArgs().Peek("datatype"))) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(ctx, err, http.StatusBadRequest) return } } var slot int - if r.URL.Query().Has("slot") { - slot, err = strconv.Atoi(r.URL.Query().Get("slot")) + if ctx.QueryArgs().Has("slot") { + slot, err = strconv.Atoi(string(ctx.QueryArgs().Peek("slot"))) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(ctx, err, http.StatusBadRequest) return } } var save any // /savedata/get and /savedata/delete specify datatype, but don't expect data in body - if r.URL.Path != "/api/savedata/get" && r.URL.Path != "/api/savedata/delete" { + if string(ctx.Path()) != "/api/savedata/get" && string(ctx.Path()) != "/api/savedata/delete" { if datatype == 0 { var system defs.SystemSaveData - err = json.NewDecoder(r.Body).Decode(&system) + err = json.Unmarshal(ctx.Request.Body(), &system) if err != nil { - httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + httpError(ctx, fmt.Errorf("failed to unmarshal request body: %s", err), http.StatusBadRequest) return } save = system // /savedata/clear doesn't specify datatype, it is assumed to be 1 (session) - } else if datatype == 1 || r.URL.Path == "/api/savedata/clear" { + } else if datatype == 1 || string(ctx.Path()) == "/api/savedata/clear" { var session defs.SessionSaveData - err = json.NewDecoder(r.Body).Decode(&session) + err = json.Unmarshal(ctx.Request.Body(), &session) if err != nil { - httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + httpError(ctx, fmt.Errorf("failed to unmarshal request body: %s", err), http.StatusBadRequest) return } @@ -167,7 +157,7 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { } } - switch r.URL.Path { + switch string(ctx.Path()) { case "/api/savedata/get": save, err = savedata.Get(uuid, datatype, slot) case "/api/savedata/update": @@ -177,7 +167,7 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { case "/api/savedata/clear": s, ok := save.(defs.SessionSaveData) if !ok { - httpError(w, r, fmt.Errorf("save data is not type SessionSaveData"), http.StatusBadRequest) + httpError(ctx, fmt.Errorf("save data is not type SessionSaveData"), http.StatusBadRequest) return } @@ -185,84 +175,86 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { save, err = savedata.Clear(uuid, slot, daily.Seed(), s) } if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(ctx, err, http.StatusInternalServerError) return } - if save == nil || r.URL.Path == "/api/savedata/update" { - w.WriteHeader(http.StatusOK) + if save == nil || string(ctx.Path()) == "/api/savedata/update" { + ctx.SetStatusCode(http.StatusOK) return } - err = json.NewEncoder(w).Encode(save) + err = json.NewEncoder(ctx.Response.BodyWriter()).Encode(save) if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + httpError(ctx, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) return } } -func handleDailySeed(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(daily.Seed())) +func HandleDailySeed(ctx *fasthttp.RequestCtx) { + ctx.Response.SetBody([]byte(daily.Seed())) } -func handleDailyRankings(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) +func HandleDailyRankings(ctx *fasthttp.RequestCtx) { + uuid, err := uuidFromTokenHeader(string(ctx.Request.Header.Peek("Authorization"))) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(ctx, err, http.StatusBadRequest) return } var category int - if r.URL.Query().Has("category") { - category, err = strconv.Atoi(r.URL.Query().Get("category")) + + if ctx.QueryArgs().Has("category") { + category, err = strconv.Atoi(string(ctx.QueryArgs().Peek("category"))) if err != nil { - httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) + httpError(ctx, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) return } } page := 1 - if r.URL.Query().Has("page") { - page, err = strconv.Atoi(r.URL.Query().Get("page")) + if ctx.QueryArgs().Has("page") { + page, err = strconv.Atoi(string(ctx.QueryArgs().Peek("page"))) if err != nil { - httpError(w, r, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest) + httpError(ctx, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest) return } } rankings, err := daily.Rankings(uuid, category, page) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(ctx, err, http.StatusInternalServerError) return } - err = json.NewEncoder(w).Encode(rankings) + err = json.NewEncoder(ctx.Response.BodyWriter()).Encode(rankings) if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + httpError(ctx, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) return } } -func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) { +func HandleDailyRankingPageCount(ctx *fasthttp.RequestCtx) { var category int - if r.URL.Query().Has("category") { + if ctx.QueryArgs().Has("category") { var err error - category, err = strconv.Atoi(r.URL.Query().Get("category")) + category, err = strconv.Atoi(string(ctx.QueryArgs().Peek("category"))) if err != nil { - httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) + httpError(ctx, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) return } } count, err := daily.RankingPageCount(category) if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + httpError(ctx, err, http.StatusInternalServerError) + return } - w.Write([]byte(strconv.Itoa(count))) + ctx.SetBody([]byte(strconv.Itoa(count))) } -func httpError(w http.ResponseWriter, r *http.Request, err error, code int) { - log.Printf("%s: %s\n", r.URL.Path, err) - http.Error(w, err.Error(), code) +func httpError(ctx *fasthttp.RequestCtx, err error, code int) { + log.Printf("%s: %s\n", ctx.Path(), err) + ctx.Error(err.Error(), code) } diff --git a/go.mod b/go.mod index 6675c3e..f90384b 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,14 @@ go 1.22 require ( github.com/go-sql-driver/mysql v1.7.1 - github.com/klauspost/compress v1.17.4 + github.com/klauspost/compress v1.17.6 github.com/robfig/cron/v3 v3.0.1 - golang.org/x/crypto v0.16.0 + github.com/valyala/fasthttp v1.52.0 + golang.org/x/crypto v0.19.0 ) -require golang.org/x/sys v0.15.0 // indirect +require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + golang.org/x/sys v0.17.0 // indirect +) diff --git a/go.sum b/go.sum index 88cd836..cb45b59 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,16 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= +github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0= +github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/pokerogue-server.go b/pokerogue-server.go index a147f63..ffb201b 100644 --- a/pokerogue-server.go +++ b/pokerogue-server.go @@ -4,12 +4,15 @@ import ( "encoding/gob" "flag" "log" - "net/http" + "strings" "github.com/pagefaultgames/pokerogue-server/api" "github.com/pagefaultgames/pokerogue-server/db" + "github.com/valyala/fasthttp" ) +var serveStaticContent fasthttp.RequestHandler + func main() { // flag stuff addr := flag.String("addr", "0.0.0.0:80", "network address for api to listen on") @@ -36,18 +39,52 @@ func main() { } // start web server - mux := http.NewServeMux() - - api.Init(mux) - - mux.Handle("/", http.FileServer(http.Dir(*wwwpath))) + serveStaticContent = fasthttp.FSHandler(*wwwpath, 0) + + api.Init() if *tlscert != "" && *tlskey != "" { - err = http.ListenAndServeTLS(*addr, *tlscert, *tlskey, mux) + err = fasthttp.ListenAndServeTLS(*addr, *tlscert, *tlskey, serve) } else { - err = http.ListenAndServe(*addr, mux) + err = fasthttp.ListenAndServe(*addr, serve) } if err != nil { log.Fatalf("failed to create http server or server errored: %s", err) } } + +func serve(ctx *fasthttp.RequestCtx) { + if strings.HasPrefix(string(ctx.Path()), "/api") { + switch string(ctx.Path()) { + case "/api/account/info": + api.HandleAccountInfo(ctx) + case "/api/account/register": + api.HandleAccountRegister(ctx) + case "/api/account/login": + api.HandleAccountLogin(ctx) + case "/api/account/logout": + api.HandleAccountLogout(ctx) + + case "/api/game/playercount": + api.HandleGamePlayerCount(ctx) + case "/api/game/titlestats": + api.HandleGameTitleStats(ctx) + case "/api/game/classicsessioncount": + api.HandleGameClassicSessionCount(ctx) + + case "/api/savedata/get", "/api/savedata/update", "/api/savedata/delete", "/api/savedata/clear": + api.HandleSaveData(ctx) + + case "/api/daily/seed": + api.HandleDailySeed(ctx) + case "/api/daily/rankings": + api.HandleDailyRankings(ctx) + case "/api/daily/rankingpagecount": + api.HandleDailyRankingPageCount(ctx) + } + + return + } + + serveStaticContent(ctx) +}