Skip to content

Commit

Permalink
user sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
mbund committed Oct 25, 2024
1 parent 8a17589 commit 2ad2b6b
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 11 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module main
go 1.23.2

require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/k0kubun/pp/v3 v3.2.0
modernc.org/sqlite v1.33.1
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkp
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
Expand Down
163 changes: 155 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"crypto/tls"
"database/sql"
"embed"
Expand All @@ -14,21 +15,150 @@ import (
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/k0kubun/pp/v3"
_ "modernc.org/sqlite"
)

func hello(s AuthProvider) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
fmt.Println("/hello")
type Router struct {
db *sql.DB
authProvider AuthProvider
jwtSecret []byte
}

const AUTH_COOKIE_NAME string = "csc-auth"

func (r *Router) signin(w http.ResponseWriter, req *http.Request) {
attributes := r.authProvider.attributesFromContext(req.Context())

pp.Println(attributes)
now := time.Now()

token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"idm_id": attributes.IDMUID,
"iat": now.Unix(),
"exp": now.AddDate(1, 0, 0).Unix(),
})

signedTokenString, err := token.SignedString(r.jwtSecret)

if err != nil {
log.Fatalln("Failed to sign JWT:", err)
}

http.SetCookie(w, &http.Cookie{
Name: AUTH_COOKIE_NAME,
HttpOnly: true,
Value: signedTokenString,
MaxAge: 365 * 24 * 60 * 60, // 1 year
SameSite: http.SameSiteLaxMode,
Path: "/",
})

nameNum := strings.TrimSuffix(attributes.Email, "@osu.edu")
nameNum = strings.TrimSuffix(nameNum, "@buckeyemail.osu.edu")

attributes := s.attributesFromContext(r.Context())
pp.Println(attributes)
student := false
alum := false
employee := false
faculty := false

fmt.Fprintf(w, "Bye, %s!", attributes.GivenName)
for _, affiliation := range attributes.Affiliations {
if affiliation == "[email protected]" {
student = true
} else if affiliation == "[email protected]" {
alum = true
} else if affiliation == "[email protected]" {
employee = true
} else if affiliation == "[email protected]" {
faculty = true
}
}

r.db.Exec(`

Check failure on line 78 in main.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `r.db.Exec` is not checked (errcheck)
INSERT OR REPLACE INTO users (idm_id, buck_id, name_num, display_name, student, alum, employee, faculty)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
`, attributes.IDMUID, attributes.BuckID, nameNum, attributes.DisplayName, student, alum, employee, faculty)

redirect := req.URL.Query().Get("redirect")
if redirect != "" {
http.Redirect(w, req, redirect, http.StatusTemporaryRedirect)
return
}

fmt.Fprintf(w, "Hello, %s!", attributes.GivenName)
}

func getUserIDFromContext(ctx context.Context) (string, bool) {
userId, ok := ctx.Value(CONTEXT_USER_ID_KEY).(string)

return userId, ok
}

func (r *Router) hello(w http.ResponseWriter, req *http.Request) {
userId, hasUserId := getUserIDFromContext(req.Context())

if hasUserId {
row := r.db.QueryRow("SELECT display_name FROM users WHERE idm_id = ?", userId)
var displayName string
row.Scan(&displayName)

Check failure on line 104 in main.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `row.Scan` is not checked (errcheck)
fmt.Fprintf(w, "Hello, %s!", displayName)
} else {
fmt.Fprintln(w, "Hello, unknown user!")
}
}

type contextUserIdType int

const CONTEXT_USER_ID_KEY contextUserIdType = iota

func (r *Router) InjectJwtMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
cookie, err := req.Cookie(AUTH_COOKIE_NAME)
if err != nil {
handler.ServeHTTP(w, req)
return
}

token, err := jwt.Parse(cookie.Value, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

return r.jwtSecret, nil
})
if err != nil {
log.Println(err)
http.Redirect(w, req, fmt.Sprintf("/signin?redirect=%v", req.URL.Path), http.StatusTemporaryRedirect)
return
}

claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
log.Println("Invalid token", token)
http.Redirect(w, req, fmt.Sprintf("/signin?redirect=%v", req.URL.Path), http.StatusTemporaryRedirect)
return
}

idm_id := claims["idm_id"].(string)

req = req.WithContext(context.WithValue(req.Context(), CONTEXT_USER_ID_KEY, idm_id))
handler.ServeHTTP(w, req)
})
}

func (r *Router) EnforceJwtMiddleware(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
_, hasUserId := getUserIDFromContext(req.Context())
if !hasUserId {
http.Redirect(w, req, fmt.Sprintf("/signin?redirect=%v", req.URL.Path), http.StatusTemporaryRedirect)
return
}

handler.ServeHTTP(w, req)
})
}

//go:embed migrations/*
var migrations embed.FS

Expand Down Expand Up @@ -57,7 +187,6 @@ func main() {
log.Fatalln("Failed to read", entry.Name(), err)
}
sql := string(data)
fmt.Println(sql)
_, err = db.Exec(sql)
if err != nil {
log.Fatalln("Failed to run", entry.Name(), err)
Expand All @@ -83,7 +212,25 @@ func main() {
authProvider, _ = samlAuthProvider(mux, rootURL, &keyPair)
}

mux.Handle("/hello", authProvider.requireAuth(http.HandlerFunc(hello(authProvider))))
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
if authEnvironment != "" && authEnvironment != "saml" {
log.Fatalln("JWT_SECRET not set")
}

log.Println("DEFAULTING JWT_SECRET TO `secret` DO NOT RUN IN PRODUCTION")
jwtSecret = "secret"
}

router := &Router{
db: db,
authProvider: authProvider,
jwtSecret: []byte(jwtSecret),
}

mux.Handle("/hello", router.InjectJwtMiddleware(router.EnforceJwtMiddleware(http.HandlerFunc(router.hello))))
// mux.Handle("/hello", router.InjectJwtMiddleware(http.HandlerFunc(router.hello)))
mux.Handle("/signin", authProvider.requireAuth(http.HandlerFunc(router.signin)))
mux.Handle("/logout", authProvider.requireAuth(http.HandlerFunc(authProvider.globalLogout)))

if authEnvironment == "saml" {
Expand Down
10 changes: 7 additions & 3 deletions migrations/001.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ CREATE TABLE IF NOT EXISTS users (
name_num TEXT NOT NULL,
display_name TEXT NOT NULL,
last_login INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
-- Student, allumni, or employee
affiliation TEXT NOT NULL

-- 0 or 1 depending on if the user has the affiliation
student INTEGER NOT NULL,
alum INTEGER NOT NULL,
employee INTEGER NOT NULL,
faculty INTEGER NOT NULL
);

CREATE INDEX IF NOT EXISTS users_discord_id ON users (discord_id);
CREATE INDEX IF NOT EXISTS users_buck_id ON users (buck_id);
CREATE INDEX IF NOT EXISTS users_affiliation ON users (affiliation);
CREATE INDEX IF NOT EXISTS users_student ON users (student);

CREATE TABLE IF NOT EXISTS attendance_records (
user_id INTEGER NOT NULL,
Expand Down
2 changes: 2 additions & 0 deletions tls_certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func getTlsCert() (*tls.Certificate, error) {
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})

os.MkdirAll("tlskeys", 0700)

Check failure on line 67 in tls_certs.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `os.MkdirAll` is not checked (errcheck)

if err := os.WriteFile("tlskeys/auth-test-osucyber-club-selfsigned-key.pem", privateKeyPEM, 0600); err != nil {
return nil, err
}
Expand Down

0 comments on commit 2ad2b6b

Please sign in to comment.