diff --git a/pgctx/pgctx.go b/pgctx/pgctx.go index 111f0b4..880e6d9 100644 --- a/pgctx/pgctx.go +++ b/pgctx/pgctx.go @@ -22,23 +22,35 @@ type Queryer interface { PrepareContext(context.Context, string) (*sql.Stmt, error) } +func NewKeyContext(ctx context.Context, key any, db DB) context.Context { + return context.WithValue(ctx, ctxKeyDB{key}, db) +} + // NewContext creates new context func NewContext(ctx context.Context, db DB) context.Context { - ctx = context.WithValue(ctx, ctxKeyDB{}, db) - ctx = context.WithValue(ctx, ctxKeyQueryer{}, db) - return ctx + return NewKeyContext(ctx, nil, db) } -// Middleware injects db into request's context -func Middleware(db DB) func(h http.Handler) http.Handler { +func KeyMiddleware(key any, db DB) func(h http.Handler) http.Handler { return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r = r.WithContext(NewContext(r.Context(), db)) + r = r.WithContext(NewKeyContext(r.Context(), key, db)) h.ServeHTTP(w, r) }) } } +// Middleware injects db into request's context +func Middleware(db DB) func(h http.Handler) http.Handler { + return KeyMiddleware(nil, db) +} + +// With creates new empty key context with db from keyed context +func With(ctx context.Context, key any) context.Context { + db := ctx.Value(ctxKeyDB{key}) + return context.WithValue(ctx, ctxKeyDB{}, db) +} + type wrapTx struct { *sql.Tx onCommitted []func(ctx context.Context) @@ -102,12 +114,17 @@ func Committed(ctx context.Context, f func(ctx context.Context)) { } type ( - ctxKeyDB struct{} + ctxKeyDB struct{ + key any + } ctxKeyQueryer struct{} ) func q(ctx context.Context) Queryer { - return ctx.Value(ctxKeyQueryer{}).(Queryer) + if q, ok := ctx.Value(ctxKeyQueryer{}).(Queryer); ok { + return q + } + return ctx.Value(ctxKeyDB{}).(Queryer) } // QueryRow calls db.QueryRowContext diff --git a/pgctx/pgctx_test.go b/pgctx/pgctx_test.go index 34c92dd..cf76dfe 100644 --- a/pgctx/pgctx_test.go +++ b/pgctx/pgctx_test.go @@ -30,6 +30,19 @@ func TestNewContext(t *testing.T) { }) } +type testKey1 struct{} + +func TestNewKeyContext(t *testing.T) { + t.Parallel() + + assert.NotPanics(t, func() { + db, _, err := sqlmock.New() + assert.NoError(t, err) + ctx := pgctx.NewKeyContext(context.Background(), testKey1{}, db) + assert.NotNil(t, ctx) + }) +} + func TestMiddleware(t *testing.T) { t.Parallel() @@ -55,6 +68,34 @@ func TestMiddleware(t *testing.T) { assert.True(t, called) } +func TestKeyMiddleware(t *testing.T) { + t.Parallel() + + db, _, err := sqlmock.New() + assert.NoError(t, err) + + called := false + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + pgctx.KeyMiddleware(testKey1{}, db)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + ctx := r.Context() + assert.NotPanics(t, func() { + pgctx.QueryRow(pgctx.With(ctx, testKey1{}), "select 1") + }) + assert.NotPanics(t, func() { + pgctx.Query(pgctx.With(ctx, testKey1{}), "select 1") + }) + assert.NotPanics(t, func() { + pgctx.Exec(pgctx.With(ctx, testKey1{}), "select 1") + }) + assert.Panics(t, func() { + pgctx.QueryRow(ctx, "select 1") + }) + })).ServeHTTP(w, r) + assert.True(t, called) +} + func TestRunInTx(t *testing.T) { t.Parallel()