From 76015742a9872b5a6ec04b018dc1c1319a8b0d8c Mon Sep 17 00:00:00 2001 From: Taylor Jasko Date: Tue, 19 Nov 2024 14:10:43 -0600 Subject: [PATCH] fix: updating Konnect auth logic to properly handle geo rewrites As Konnect's authorization service is not deployed across all regions, the global endpoint is used for certain API calls. When the organization info is fetched, the hostname within the provided base URL is then rewritten to the relevant global endpoint. This rewrite logic was unable to handle any number of geos, relying on hard coded logic instead for specific geos. This logic has been updated to allow callers of this library to talk to multiple geos; this includes new geos that may not exist as of this writing. --- pkg/konnect/login_service.go | 12 ++++--- pkg/konnect/login_service_test.go | 55 +++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/pkg/konnect/login_service.go b/pkg/konnect/login_service.go index ce40fb5..d211599 100644 --- a/pkg/konnect/login_service.go +++ b/pkg/konnect/login_service.go @@ -45,11 +45,16 @@ func (s *AuthService) Login(ctx context.Context, email, return authResponse, nil } +// getGlobalEndpoint returns the global endpoint for a given base Konnect URL. +func getGlobalEndpoint(baseURL string) string { + parts := strings.Split(baseURL, "api.konghq") + return baseEndpointUS + parts[len(parts)-1] +} + // getGlobalAuthEndpoint returns the global auth endpoint // given a base Konnect URL. func getGlobalAuthEndpoint(baseURL string) string { - parts := strings.Split(baseURL, "api.konghq") - return baseEndpointUS + parts[len(parts)-1] + authEndpointV2 + return getGlobalEndpoint(baseURL) + authEndpointV2 } func createAuthRequest(baseURL, email, password string) (*http.Request, error) { @@ -129,8 +134,7 @@ func (s *AuthService) LoginV2(ctx context.Context, email, func (s *AuthService) OrgUserInfo(ctx context.Context) (*OrgUserInfo, error) { // replace geo-specific endpoint with global one for retrieving org info client := *s.client - client.baseURL = strings.Replace(s.client.baseURL, "eu.", "global.", 1) - client.baseURL = strings.Replace(client.baseURL, "au.", "global.", 1) + client.baseURL = getGlobalEndpoint(client.baseURL) req, err := client.NewRequest(http.MethodGet, "/v2/organizations/me", nil, nil) if err != nil { diff --git a/pkg/konnect/login_service_test.go b/pkg/konnect/login_service_test.go index 89f384f..99765fc 100644 --- a/pkg/konnect/login_service_test.go +++ b/pkg/konnect/login_service_test.go @@ -1,11 +1,26 @@ package konnect import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +type mockRoundTripper struct{ mockHost string } + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Host = req.URL.Host + req.URL.Scheme = "http" + req.URL.Host = m.mockHost + + return (&http.Client{}).Do(req) +} + func TestGetGlobalAuthEndpoint(t *testing.T) { tests := []struct { baseURL string @@ -40,3 +55,43 @@ func TestGetGlobalAuthEndpoint(t *testing.T) { assert.Equal(t, tt.expected, getGlobalAuthEndpoint(tt.baseURL)) } } + +func TestAuthService_OrgUserInfo(t *testing.T) { + expectedResp := OrgUserInfo{Name: "test-org", OrgID: "1234"} + + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.Host, "global.api.konghq.com") + + if r.URL.Path == "/v2/organizations/me" && r.Method == http.MethodGet { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + resp, err := json.Marshal(expectedResp) + require.NoError(t, err) + + _, err = w.Write(resp) + require.NoError(t, err) + + return + } + + http.NotFound(w, r) + })) + defer mockServer.Close() + + authService := &AuthService{ + client: &Client{ + baseURL: "https://some-geo.api.konghq.com", + client: &http.Client{ + Transport: &mockRoundTripper{ + mockHost: mockServer.Listener.Addr().String(), + }, + }, + }, + } + + info, err := authService.OrgUserInfo(context.Background()) + require.NoError(t, err) + assert.Equal(t, expectedResp.Name, info.Name) + assert.Equal(t, expectedResp.OrgID, info.OrgID) +}