Skip to content

Commit

Permalink
pgctx: add key context (#25)
Browse files Browse the repository at this point in the history
* pgctx: add key context

* adjust api

* add test

* refactor
  • Loading branch information
acoshift authored Dec 1, 2022
1 parent 81a2fc8 commit 687443f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 8 deletions.
33 changes: 25 additions & 8 deletions pgctx/pgctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,35 @@ type Queryer interface {
PrepareContext(context.Context, string) (*sql.Stmt, error)
}

func NewKeyContext(ctx context.Context, key any, db DB) context.Context {
return context.WithValue(ctx, ctxKeyDB{key}, db)
}

// NewContext creates new context
func NewContext(ctx context.Context, db DB) context.Context {
ctx = context.WithValue(ctx, ctxKeyDB{}, db)
ctx = context.WithValue(ctx, ctxKeyQueryer{}, db)
return ctx
return NewKeyContext(ctx, nil, db)
}

// Middleware injects db into request's context
func Middleware(db DB) func(h http.Handler) http.Handler {
func KeyMiddleware(key any, db DB) func(h http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(NewContext(r.Context(), db))
r = r.WithContext(NewKeyContext(r.Context(), key, db))
h.ServeHTTP(w, r)
})
}
}

// Middleware injects db into request's context
func Middleware(db DB) func(h http.Handler) http.Handler {
return KeyMiddleware(nil, db)
}

// With creates new empty key context with db from keyed context
func With(ctx context.Context, key any) context.Context {
db := ctx.Value(ctxKeyDB{key})
return context.WithValue(ctx, ctxKeyDB{}, db)
}

type wrapTx struct {
*sql.Tx
onCommitted []func(ctx context.Context)
Expand Down Expand Up @@ -102,12 +114,17 @@ func Committed(ctx context.Context, f func(ctx context.Context)) {
}

type (
ctxKeyDB struct{}
ctxKeyDB struct{
key any
}
ctxKeyQueryer struct{}
)

func q(ctx context.Context) Queryer {
return ctx.Value(ctxKeyQueryer{}).(Queryer)
if q, ok := ctx.Value(ctxKeyQueryer{}).(Queryer); ok {
return q
}
return ctx.Value(ctxKeyDB{}).(Queryer)
}

// QueryRow calls db.QueryRowContext
Expand Down
41 changes: 41 additions & 0 deletions pgctx/pgctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ func TestNewContext(t *testing.T) {
})
}

type testKey1 struct{}

func TestNewKeyContext(t *testing.T) {
t.Parallel()

assert.NotPanics(t, func() {
db, _, err := sqlmock.New()
assert.NoError(t, err)
ctx := pgctx.NewKeyContext(context.Background(), testKey1{}, db)
assert.NotNil(t, ctx)
})
}

func TestMiddleware(t *testing.T) {
t.Parallel()

Expand All @@ -55,6 +68,34 @@ func TestMiddleware(t *testing.T) {
assert.True(t, called)
}

func TestKeyMiddleware(t *testing.T) {
t.Parallel()

db, _, err := sqlmock.New()
assert.NoError(t, err)

called := false
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
pgctx.KeyMiddleware(testKey1{}, db)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
ctx := r.Context()
assert.NotPanics(t, func() {
pgctx.QueryRow(pgctx.With(ctx, testKey1{}), "select 1")
})
assert.NotPanics(t, func() {
pgctx.Query(pgctx.With(ctx, testKey1{}), "select 1")
})
assert.NotPanics(t, func() {
pgctx.Exec(pgctx.With(ctx, testKey1{}), "select 1")
})
assert.Panics(t, func() {
pgctx.QueryRow(ctx, "select 1")
})
})).ServeHTTP(w, r)
assert.True(t, called)
}

func TestRunInTx(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 687443f

Please sign in to comment.