diff --git a/systemapi/middleware.go b/systemapi/middleware.go index d852ca3..10d535e 100644 --- a/systemapi/middleware.go +++ b/systemapi/middleware.go @@ -10,20 +10,23 @@ import ( func BasicAuth(realm string, getCreds func() map[string]string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Loading credentials dynamically because they can be updated at runtime creds := getCreds() + // If no credentials are set, just pass through (unauthenticated) if len(creds) == 0 { - // if no credentials are set, just pass through next.ServeHTTP(w, r) return } + // Load credentials from request user, pass, ok := r.BasicAuth() if !ok { basicAuthFailed(w, realm) return } + // Compare to allowed credentials credPass, credUserOk := creds[user] if !credUserOk || subtle.ConstantTimeCompare([]byte(pass), []byte(credPass)) != 1 { basicAuthFailed(w, realm) diff --git a/systemapi/server_test.go b/systemapi/server_test.go index 25dd6b6..e0cbf51 100644 --- a/systemapi/server_test.go +++ b/systemapi/server_test.go @@ -27,37 +27,34 @@ func getTestConfig() *HTTPServerConfig { } } -func TestGeneralHandlers(t *testing.T) { - // Create the config - cfg := getTestConfig() +func execRequest(t *testing.T, router http.Handler, method, url string, body io.Reader) *httptest.ResponseRecorder { + t.Helper() + req, err := http.NewRequest(method, url, body) + require.NoError(t, err) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + return rr +} +func TestGeneralHandlers(t *testing.T) { // Instantiate the server - srv, err := NewServer(cfg) + srv, err := NewServer(getTestConfig()) require.NoError(t, err) router := srv.getRouter() // Test /livez - req, err := http.NewRequest(http.MethodGet, "/livez", nil) - require.NoError(t, err) - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + rr := execRequest(t, router, http.MethodGet, "/livez", nil) require.Equal(t, http.StatusOK, rr.Code) // Test /api/v1/events - req, err = http.NewRequest(http.MethodGet, "/api/v1/events", nil) - require.NoError(t, err) - rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + rr = execRequest(t, router, http.MethodGet, "/api/v1/events", nil) require.Equal(t, http.StatusOK, rr.Code) body, err := io.ReadAll(rr.Body) require.NoError(t, err) require.Equal(t, "[]\n", string(body)) // Add an event - req, err = http.NewRequest(http.MethodGet, "/api/v1/new_event?message=foo", nil) - require.NoError(t, err) - rr = httptest.NewRecorder() - router.ServeHTTP(rr, req) + rr = execRequest(t, router, http.MethodGet, "/api/v1/new_event?message=foo", nil) require.Equal(t, http.StatusOK, rr.Code) require.Len(t, srv.events, 1) } @@ -79,6 +76,7 @@ func TestBasicAuth(t *testing.T) { require.NoError(t, err) router := srv.getRouter() + // Helper to get /livez with and without basic auth getLiveZ := func(basicAuthUser, basicAuthPass string) int { req, err := http.NewRequest(http.MethodGet, "/livez", nil) if basicAuthUser != "" { @@ -94,13 +92,10 @@ func TestBasicAuth(t *testing.T) { require.Equal(t, http.StatusOK, getLiveZ("", "")) // Set a basic auth secret - req, err := http.NewRequest(http.MethodPost, "/api/v1/set-basic-auth", bytes.NewReader(basicAuthSecret)) - require.NoError(t, err) - rr := httptest.NewRecorder() - router.ServeHTTP(rr, req) + rr := execRequest(t, router, http.MethodPost, "/api/v1/set-basic-auth", bytes.NewReader(basicAuthSecret)) require.Equal(t, http.StatusOK, rr.Code) - // Verify secretFromFile was written to file + // Ensure secretFromFile was written to file secretFromFile, err := os.ReadFile(cfg.Config.General.BasicAuthSecretPath) require.NoError(t, err) require.Equal(t, basicAuthSecret, secretFromFile)