diff --git a/.taskfiles/test.yml b/.taskfiles/test.yml index a8e8468c9cd..d6933c4b372 100644 --- a/.taskfiles/test.yml +++ b/.taskfiles/test.yml @@ -50,8 +50,9 @@ tasks: - rm -rf coverage && mkdir -p coverage - for: { var: packages, as: package } cmd: |- - gotestsum --no-color=false --hide-summary=skipped --raw-command \ - go test -p 1 -parallel 1 -json {{.testArgs}} {{.args}} -count=1 -v \ + gotestsum --no-color=false --hide-summary=skipped \ + --jsonfile coverage/{{.product}}-{{.package | replace "." "gateway" | replace "/" "-"}}.json \ + --raw-command go test -p 1 -parallel 1 -json {{.testArgs}} {{.args}} -count=1 -v \ -coverprofile=coverage/{{.package | replace "." "gateway" | replace "/" "-"}}.cov {{.package}} | head -n -2 integration-combined: @@ -86,8 +87,8 @@ tasks: vars: count: '{{ .count | default "10" }}' cmds: - - gotestsum --hide-summary=skipped --junitfile=unit-tests.xml --raw-command cat coverage/{{.product}}-all.json - - echo "Slowest {{.count}} tests:" && cat coverage/{{.product}}-all.json | gotestsum tool slowest | head -n {{.count}} | sed -e 's|{{.package}}/||g' + - gotestsum --hide-summary=skipped --junitfile=unit-tests.xml --raw-command cat coverage/*.json + - echo "Slowest {{.count}} tests:" && cat coverage/*.json | gotestsum tool slowest | head -n {{.count}} | sed -e 's|{{.package}}/||g' clean: desc: "Clean test outputs" diff --git a/coprocess/Taskfile.yml b/coprocess/Taskfile.yml index 5668ff2d8a1..af69e2cb617 100644 --- a/coprocess/Taskfile.yml +++ b/coprocess/Taskfile.yml @@ -16,10 +16,13 @@ tasks: test: desc: "Run tests" deps: [ services:up ] + env: + GOTESTSUM_FORMAT: testname cmds: - defer: { task: services:down } - go fmt ./... - - go test -count=1 ./... + - go test -p 1 -parallel 1 -count=1 -json ./... | gotestsum --format testname --hide-summary=all + # lint target is run from CI lint: diff --git a/coprocess/grpc/coprocess_grpc_test.go b/coprocess/grpc/coprocess_grpc_test.go index a97fb7c41b8..7fb7848d111 100644 --- a/coprocess/grpc/coprocess_grpc_test.go +++ b/coprocess/grpc/coprocess_grpc_test.go @@ -8,31 +8,23 @@ import ( "io/ioutil" "math/rand" "mime/multipart" - "net" "net/http" "os" "strconv" "strings" "testing" - "time" "github.com/TykTechnologies/tyk/header" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" - "github.com/TykTechnologies/tyk/apidef" - "github.com/TykTechnologies/tyk/config" - "github.com/TykTechnologies/tyk/coprocess" "github.com/TykTechnologies/tyk/gateway" "github.com/TykTechnologies/tyk/test" "github.com/TykTechnologies/tyk/user" ) const ( - grpcListenAddr = ":9999" - grpcListenPath = "tcp://127.0.0.1:9999" grpcTestMaxSize = 100000000 grpcAuthority = "localhost" @@ -40,132 +32,6 @@ const ( testHeaderValue = "testvalue" ) -type dispatcher struct{} - -func (d *dispatcher) grpcError(object *coprocess.Object, errorMsg string) (*coprocess.Object, error) { - object.Request.ReturnOverrides.ResponseError = errorMsg - object.Request.ReturnOverrides.ResponseCode = 400 - return object, nil -} - -func (d *dispatcher) Dispatch(ctx context.Context, object *coprocess.Object) (*coprocess.Object, error) { - switch object.HookName { - case "testPreHook1": - object.Request.SetHeaders = map[string]string{ - testHeaderName: testHeaderValue, - } - case "testPreHook2": - contentType, found := object.Request.Headers["Content-Type"] - if !found { - return d.grpcError(object, "Content Type field not found") - } - if strings.Contains(contentType, "json") { - if len(object.Request.Body) == 0 { - return d.grpcError(object, "Body field is empty") - } - if len(object.Request.RawBody) == 0 { - return d.grpcError(object, "Raw body field is empty") - } - if strings.Compare(object.Request.Body, string(object.Request.Body)) != 0 { - return d.grpcError(object, "Raw body and body fields don't match") - } - } else if strings.Contains(contentType, "multipart") { - if len(object.Request.Body) != 0 { - return d.grpcError(object, "Body field isn't empty") - } - if len(object.Request.RawBody) == 0 { - return d.grpcError(object, "Raw body field is empty") - } - } else { - return d.grpcError(object, "Request content type should be either JSON or multipart") - } - case "testPostHook1": - testKeyValue, ok := object.Session.Metadata["testkey"] - if !ok { - return d.grpcError(object, "'testkey' not found in session metadata") - } - jsonObject := make(map[string]string) - if err := json.Unmarshal([]byte(testKeyValue), &jsonObject); err != nil { - return d.grpcError(object, "couldn't decode 'testkey' nested value") - } - nestedKeyValue, ok := jsonObject["nestedkey"] - if !ok { - return d.grpcError(object, "'nestedkey' not found in JSON object") - } - if nestedKeyValue != "nestedvalue" { - return d.grpcError(object, "'nestedvalue' value doesn't match") - } - testKey2Value, ok := object.Session.Metadata["testkey2"] - if !ok { - return d.grpcError(object, "'testkey' not found in session metadata") - } - if testKey2Value != "testvalue" { - return d.grpcError(object, "'testkey2' value doesn't match") - } - - // Check for compatibility (object.Metadata should contain the same keys as object.Session.Metadata) - for k, v := range object.Metadata { - sessionKeyValue, ok := object.Session.Metadata[k] - if !ok { - return d.grpcError(object, k+" not found in object.Session.Metadata") - } - if strings.Compare(sessionKeyValue, v) != 0 { - return d.grpcError(object, k+" doesn't match value in object.Session.Metadata") - } - } - case "testResponseHook": - object.Response.RawBody = []byte("newbody") - case "testConfigDataResponseHook": - if _, ok := object.Spec["config_data"]; ok { - object.Response.Headers["x-config-data"] = "true" - object.Response.MultivalueHeaders = append(object.Response.MultivalueHeaders, &coprocess.Header{ - Key: "x-config-data", - Values: []string{"true"}, - }) - } else { - object.Response.Headers["x-config-data"] = "false" - object.Response.MultivalueHeaders = append(object.Response.MultivalueHeaders, &coprocess.Header{ - Key: "x-config-data", - Values: []string{"false"}, - }) - } - case "testAuthHook1": - req := object.Request - token := req.Headers["Authorization"] - if object.Metadata == nil { - object.Metadata = map[string]string{} - } - object.Metadata["token"] = token - if token != "abc" { - return d.grpcError(object, "invalid token") - } - - session := coprocess.SessionState{ - Rate: 100, - IdExtractorDeadline: time.Now().Add(2 * time.Second).Unix(), - Metadata: map[string]string{ - "sessionMetaKey": "customAuthSessionMetaValue", - }, - } - - object.Session = &session - } - return object, nil -} - -func (d *dispatcher) DispatchEvent(ctx context.Context, event *coprocess.Event) (*coprocess.EventReply, error) { - return &coprocess.EventReply{}, nil -} - -func newTestGRPCServer() (s *grpc.Server) { - s = grpc.NewServer( - grpc.MaxRecvMsgSize(grpcTestMaxSize), - grpc.MaxSendMsgSize(grpcTestMaxSize), - ) - coprocess.RegisterDispatcherServer(s, &dispatcher{}) - return s -} - func loadTestGRPCAPIs(s *gateway.Test) { s.Gw.BuildAndLoadAPI(func(spec *gateway.APISpec) { spec.APIID = "1" @@ -454,38 +320,13 @@ func loadTestGRPCAPIs(s *gateway.Test) { ) } -func startTykWithGRPC() (*gateway.Test, *grpc.Server) { - // Setup the gRPC server: - listener, _ := net.Listen("tcp", grpcListenAddr) - grpcServer := newTestGRPCServer() - go grpcServer.Serve(listener) - - // Setup Tyk: - cfg := config.CoProcessConfig{ - EnableCoProcess: true, - CoProcessGRPCServer: grpcListenPath, - GRPCRecvMaxSize: grpcTestMaxSize, - GRPCSendMaxSize: grpcTestMaxSize, - GRPCAuthority: grpcAuthority, - } - ts := gateway.StartTest(nil, gateway.TestConfig{ - CoprocessConfig: cfg, - EnableTestDNSMock: false, - }) - - // Load test APIs: - loadTestGRPCAPIs(ts) - return ts, grpcServer -} - func TestMain(m *testing.M) { os.Exit(gateway.InitTestMain(context.Background(), m)) } func TestGRPCDispatch(t *testing.T) { - ts, grpcServer := startTykWithGRPC() - defer ts.Close() - defer grpcServer.Stop() + ts, cleanupFn := startTestServices(t) + t.Cleanup(cleanupFn) keyID := gateway.CreateSession(ts.Gw, func(s *user.SessionState) { s.MetaData = map[string]interface{}{ @@ -604,19 +445,18 @@ func TestGRPCDispatch(t *testing.T) { } func BenchmarkGRPCDispatch(b *testing.B) { - ts, grpcServer := startTykWithGRPC() - defer ts.Close() - defer grpcServer.Stop() + ts, cleanupFn := startTestServices(b) + b.Cleanup(cleanupFn) keyID := gateway.CreateSession(ts.Gw) headers := map[string]string{"authorization": keyID} b.Run("Pre Hook with SetHeaders", func(b *testing.B) { - path := "/grpc-test-api/" + basepath := "/grpc-test-api/" b.ReportAllocs() for i := 0; i < b.N; i++ { ts.Run(b, test.TestCase{ - Path: path, + Path: basepath, Method: http.MethodGet, Code: http.StatusOK, Headers: headers, @@ -632,15 +472,14 @@ func randStringBytes(n int) string { } func TestGRPCIgnore(t *testing.T) { - ts, grpcServer := startTykWithGRPC() - defer ts.Close() - defer grpcServer.Stop() + ts, cleanupFn := startTestServices(t) + t.Cleanup(cleanupFn) - path := "/grpc-test-api-ignore/" + basepath := "/grpc-test-api-ignore/" // no header ts.Run(t, test.TestCase{ - Path: path + "something", + Path: basepath + "something", Method: http.MethodGet, Code: http.StatusBadRequest, BodyMatchFunc: func(b []byte) bool { @@ -649,7 +488,7 @@ func TestGRPCIgnore(t *testing.T) { }) ts.Run(t, test.TestCase{ - Path: path + "anything", + Path: basepath + "anything", Method: http.MethodGet, Code: http.StatusOK, }) @@ -657,14 +496,14 @@ func TestGRPCIgnore(t *testing.T) { // bad header headers := map[string]string{"authorization": "bad"} ts.Run(t, test.TestCase{ - Path: path + "something", + Path: basepath + "something", Method: http.MethodGet, Code: http.StatusForbidden, Headers: headers, }) ts.Run(t, test.TestCase{ - Path: path + "anything", + Path: basepath + "anything", Method: http.MethodGet, Code: http.StatusOK, Headers: headers, @@ -672,43 +511,40 @@ func TestGRPCIgnore(t *testing.T) { } func TestGRPCAuthHook(t *testing.T) { - ts, grpcServer := startTykWithGRPC() - defer ts.Close() - defer grpcServer.Stop() + ts, cleanupFn := startTestServices(t) + t.Cleanup(cleanupFn) t.Run("id extractor enabled", func(t *testing.T) { - path := "/grpc-auth-hook-test-api-1/" - baseMW := &gateway.BaseMiddleware{ - Gw: ts.Gw, - Spec: &gateway.APISpec{ - APIDefinition: &apidef.APIDefinition{ - OrgID: "default", - }, - }} + basepath := "/grpc-auth-hook-test-api-1/" + spec := &gateway.APISpec{ + APIDefinition: &apidef.APIDefinition{ + OrgID: "default", + }, + } + baseMW := gateway.NewBaseMiddleware(ts.Gw, spec, nil, nil) baseExtractor := gateway.BaseExtractor{ BaseMiddleware: baseMW, } expectedSessionID := baseExtractor.GenerateSessionID("abc", baseMW) _, _ = ts.Run(t, []test.TestCase{ - {Method: http.MethodGet, Path: path, Headers: map[string]string{"Authorization": "abc"}, Code: http.StatusOK}, + {Method: http.MethodGet, Path: basepath, Headers: map[string]string{"Authorization": "abc"}, Code: http.StatusOK}, {Method: http.MethodGet, Path: fmt.Sprintf("/tyk/keys/%s", expectedSessionID), AdminAuth: true, Code: http.StatusOK}, }...) }) // won't extract id and a session with sessionID as token is created t.Run("id extractor disabled", func(t *testing.T) { - path := "/grpc-auth-hook-test-api-2/" + basepath := "/grpc-auth-hook-test-api-2/" _, _ = ts.Run(t, []test.TestCase{ - {Method: http.MethodGet, Path: path, Headers: map[string]string{"Authorization": "abc"}, Code: http.StatusOK}, + {Method: http.MethodGet, Path: basepath, Headers: map[string]string{"Authorization": "abc"}, Code: http.StatusOK}, {Method: http.MethodGet, Path: "/tyk/keys/abc", AdminAuth: true, Code: http.StatusOK}, }...) }) } func TestGRPC_MultiAuthentication(t *testing.T) { - ts, grpcServer := startTykWithGRPC() - defer ts.Close() - defer grpcServer.Stop() + ts, cleanupFn := startTestServices(t) + t.Cleanup(cleanupFn) const ( apiID = "my-api-id" @@ -780,23 +616,22 @@ func TestGRPC_MultiAuthentication(t *testing.T) { } func TestGRPCConfigData(t *testing.T) { - ts, grpcServer := startTykWithGRPC() - defer ts.Close() - defer grpcServer.Stop() + ts, cleanupFn := startTestServices(t) + t.Cleanup(cleanupFn) t.Run("config data disabled", func(t *testing.T) { - path := "/grpc-config-data-1/" + basepath := "/grpc-config-data-1/" _, _ = ts.Run(t, []test.TestCase{ - {Method: http.MethodGet, Path: path, Code: http.StatusOK, + {Method: http.MethodGet, Path: basepath, Code: http.StatusOK, HeadersMatch: map[string]string{"x-config-data": "true"}, }, }...) }) t.Run("config data disabled", func(t *testing.T) { - path := "/grpc-config-data-2/" + basepath := "/grpc-config-data-2/" _, _ = ts.Run(t, []test.TestCase{ - {Method: http.MethodGet, Path: path, Code: http.StatusOK, + {Method: http.MethodGet, Path: basepath, Code: http.StatusOK, HeadersMatch: map[string]string{"x-config-data": "false"}, }, }...) diff --git a/coprocess/grpc/dispatcher_test.go b/coprocess/grpc/dispatcher_test.go new file mode 100644 index 00000000000..6573070c247 --- /dev/null +++ b/coprocess/grpc/dispatcher_test.go @@ -0,0 +1,127 @@ +package grpc + +import ( + "context" + "encoding/json" + "strings" + "time" + + "github.com/TykTechnologies/tyk/coprocess" +) + +type dispatcher struct{} + +func (d *dispatcher) grpcError(object *coprocess.Object, errorMsg string) (*coprocess.Object, error) { + object.Request.ReturnOverrides.ResponseError = errorMsg + object.Request.ReturnOverrides.ResponseCode = 400 + return object, nil +} + +func (d *dispatcher) Dispatch(_ context.Context, object *coprocess.Object) (*coprocess.Object, error) { + switch object.HookName { + case "testPreHook1": + object.Request.SetHeaders = map[string]string{ + testHeaderName: testHeaderValue, + } + case "testPreHook2": + contentType, found := object.Request.Headers["Content-Type"] + if !found { + return d.grpcError(object, "Content Type field not found") + } + if strings.Contains(contentType, "json") { + if len(object.Request.Body) == 0 { + return d.grpcError(object, "Body field is empty") + } + if len(object.Request.RawBody) == 0 { + return d.grpcError(object, "Raw body field is empty") + } + if strings.Compare(object.Request.Body, string(object.Request.Body)) != 0 { + return d.grpcError(object, "Raw body and body fields don't match") + } + } else if strings.Contains(contentType, "multipart") { + if len(object.Request.Body) != 0 { + return d.grpcError(object, "Body field isn't empty") + } + if len(object.Request.RawBody) == 0 { + return d.grpcError(object, "Raw body field is empty") + } + } else { + return d.grpcError(object, "Request content type should be either JSON or multipart") + } + case "testPostHook1": + testKeyValue, ok := object.Session.Metadata["testkey"] + if !ok { + return d.grpcError(object, "'testkey' not found in session metadata") + } + jsonObject := make(map[string]string) + if err := json.Unmarshal([]byte(testKeyValue), &jsonObject); err != nil { + return d.grpcError(object, "couldn't decode 'testkey' nested value") + } + nestedKeyValue, ok := jsonObject["nestedkey"] + if !ok { + return d.grpcError(object, "'nestedkey' not found in JSON object") + } + if nestedKeyValue != "nestedvalue" { + return d.grpcError(object, "'nestedvalue' value doesn't match") + } + testKey2Value, ok := object.Session.Metadata["testkey2"] + if !ok { + return d.grpcError(object, "'testkey' not found in session metadata") + } + if testKey2Value != "testvalue" { + return d.grpcError(object, "'testkey2' value doesn't match") + } + + // Check for compatibility (object.Metadata should contain the same keys as object.Session.Metadata) + for k, v := range object.Metadata { + sessionKeyValue, ok := object.Session.Metadata[k] + if !ok { + return d.grpcError(object, k+" not found in object.Session.Metadata") + } + if strings.Compare(sessionKeyValue, v) != 0 { + return d.grpcError(object, k+" doesn't match value in object.Session.Metadata") + } + } + case "testResponseHook": + object.Response.RawBody = []byte("newbody") + case "testConfigDataResponseHook": + if _, ok := object.Spec["config_data"]; ok { + object.Response.Headers["x-config-data"] = "true" + object.Response.MultivalueHeaders = append(object.Response.MultivalueHeaders, &coprocess.Header{ + Key: "x-config-data", + Values: []string{"true"}, + }) + } else { + object.Response.Headers["x-config-data"] = "false" + object.Response.MultivalueHeaders = append(object.Response.MultivalueHeaders, &coprocess.Header{ + Key: "x-config-data", + Values: []string{"false"}, + }) + } + case "testAuthHook1": + req := object.Request + token := req.Headers["Authorization"] + if object.Metadata == nil { + object.Metadata = map[string]string{} + } + object.Metadata["token"] = token + if token != "abc" { + return d.grpcError(object, "invalid token") + } + + session := coprocess.SessionState{ + Rate: 100, + IdExtractorDeadline: time.Now().Add(2 * time.Second).Unix(), + Metadata: map[string]string{ + "sessionMetaKey": "customAuthSessionMetaValue", + }, + } + + object.Session = &session + } + return object, nil +} + +func (d *dispatcher) DispatchEvent(_ context.Context, _ *coprocess.Event) (*coprocess.EventReply, error) { + return &coprocess.EventReply{}, nil +} diff --git a/coprocess/grpc/services_test.go b/coprocess/grpc/services_test.go new file mode 100644 index 00000000000..1fefd75cfc2 --- /dev/null +++ b/coprocess/grpc/services_test.go @@ -0,0 +1,69 @@ +package grpc + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + + "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/coprocess" + "github.com/TykTechnologies/tyk/gateway" +) + +func newTestGRPCServer() (s *grpc.Server) { + s = grpc.NewServer( + grpc.MaxRecvMsgSize(grpcTestMaxSize), + grpc.MaxSendMsgSize(grpcTestMaxSize), + ) + coprocess.RegisterDispatcherServer(s, &dispatcher{}) + return s +} + +func startTestServices(tb testing.TB) (*gateway.Test, func()) { + tb.Helper() + + listener, err := net.Listen("tcp", ":0") + require.NoError(tb, err) + + grpcServer := newTestGRPCServer() + go func() { + err := grpcServer.Serve(listener) + require.NoError(tb, err) + }() + + conf := config.CoProcessConfig{ + EnableCoProcess: true, + CoProcessGRPCServer: grpcServerAddress(listener), + GRPCRecvMaxSize: grpcTestMaxSize, + GRPCSendMaxSize: grpcTestMaxSize, + GRPCAuthority: grpcAuthority, + } + + ts := gateway.StartTest(nil, gateway.TestConfig{ + CoprocessConfig: conf, + }) + // Load test APIs: + loadTestGRPCAPIs(ts) + + shutdown := stopTestServices(ts, grpcServer, listener) + + tb.Logf("Started with conf.CoProcessGRPCServer %q", conf.CoProcessGRPCServer) + + return ts, shutdown +} + +func grpcServerAddress(l net.Listener) string { + addr := l.Addr() + target := addr.String() + return addr.Network() + "://" + target +} + +func stopTestServices(ts *gateway.Test, grpcServer *grpc.Server, listener net.Listener) func() { + return func() { + ts.Close() + grpcServer.Stop() + listener.Close() + } +} diff --git a/dnscache/storage_test.go b/dnscache/storage_test.go index 74dca0673dc..f49de3ca8e4 100644 --- a/dnscache/storage_test.go +++ b/dnscache/storage_test.go @@ -151,6 +151,8 @@ func TestStorageFetchItem(t *testing.T) { } func TestStorageRecordExpiration(t *testing.T) { + t.Skip() // Slow test, bad practices with time.Sleep. + var ( expiration = 2000 checkInterval = 1500 diff --git a/gateway/api_loader.go b/gateway/api_loader.go index bb63cb1cdff..276e3af5396 100644 --- a/gateway/api_loader.go +++ b/gateway/api_loader.go @@ -275,16 +275,7 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, // Create the response processors, pass all the loaded custom middleware response functions: gw.createResponseMiddlewareChain(spec, mwResponseFuncs) - baseMid := &BaseMiddleware{Spec: spec, Proxy: proxy, logger: logger, Gw: gw} - - for _, v := range baseMid.Spec.VersionData.Versions { - if len(v.ExtendedPaths.CircuitBreaker) > 0 { - baseMid.Spec.CircuitBreakerEnabled = true - } - if len(v.ExtendedPaths.HardTimeouts) > 0 { - baseMid.Spec.EnforcedTimeoutEnabled = true - } - } + baseMid := NewBaseMiddleware(gw, spec, proxy, logger) keyPrefix := "cache-" + spec.APIID cacheStore := storage.RedisCluster{KeyPrefix: keyPrefix, IsCache: true, ConnectionHandler: gw.StorageConnectionHandler} @@ -299,14 +290,14 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, logger.Info("Checking security policy: Open") } - gw.mwAppendEnabled(&chainArray, &VersionCheck{BaseMiddleware: baseMid}) + gw.mwAppendEnabled(&chainArray, &VersionCheck{BaseMiddleware: baseMid.Copy()}) for _, obj := range mwPreFuncs { if mwDriver == apidef.GoPluginDriver { gw.mwAppendEnabled( &chainArray, &GoPluginMiddleware{ - BaseMiddleware: baseMid, + BaseMiddleware: baseMid.Copy(), Path: obj.Path, SymbolName: obj.Name, APILevel: true, @@ -314,44 +305,44 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, ) } else if mwDriver != apidef.OttoDriver { coprocessLog.Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver) - gw.mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Pre, obj.Name, mwDriver, obj.RawBodyOnly, nil}) + gw.mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid.Copy(), coprocess.HookType_Pre, obj.Name, mwDriver, obj.RawBodyOnly, nil}) } else { - chainArray = append(chainArray, gw.createDynamicMiddleware(obj.Name, true, obj.RequireSession, baseMid)) + chainArray = append(chainArray, gw.createDynamicMiddleware(obj.Name, true, obj.RequireSession, baseMid.Copy())) } } - gw.mwAppendEnabled(&chainArray, &RateCheckMW{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &IPBlackListMiddleware{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &CertificateCheckMW{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid, mon: Monitor{Gw: gw}}) - gw.mwAppendEnabled(&chainArray, &RequestSizeLimitMiddleware{baseMid}) - gw.mwAppendEnabled(&chainArray, &MiddlewareContextVars{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &TrackEndpointMiddleware{baseMid}) + gw.mwAppendEnabled(&chainArray, &RateCheckMW{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &IPWhiteListMiddleware{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &IPBlackListMiddleware{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &CertificateCheckMW{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &OrganizationMonitor{BaseMiddleware: baseMid.Copy(), mon: Monitor{Gw: gw}}) + gw.mwAppendEnabled(&chainArray, &RequestSizeLimitMiddleware{baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &MiddlewareContextVars{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &TrackEndpointMiddleware{baseMid.Copy()}) if !spec.UseKeylessAccess { // Select the keying method to use for setting session states - if gw.mwAppendEnabled(&authArray, &Oauth2KeyExists{baseMid}) { + if gw.mwAppendEnabled(&authArray, &Oauth2KeyExists{baseMid.Copy()}) { logger.Info("Checking security policy: OAuth") } - if gw.mwAppendEnabled(&authArray, &ExternalOAuthMiddleware{baseMid}) { + if gw.mwAppendEnabled(&authArray, &ExternalOAuthMiddleware{baseMid.Copy()}) { logger.Info("Checking security policy: External OAuth") } - if gw.mwAppendEnabled(&authArray, &BasicAuthKeyIsValid{baseMid, nil, nil}) { + if gw.mwAppendEnabled(&authArray, &BasicAuthKeyIsValid{baseMid.Copy(), nil, nil}) { logger.Info("Checking security policy: Basic") } - if gw.mwAppendEnabled(&authArray, &HTTPSignatureValidationMiddleware{BaseMiddleware: baseMid}) { + if gw.mwAppendEnabled(&authArray, &HTTPSignatureValidationMiddleware{BaseMiddleware: baseMid.Copy()}) { logger.Info("Checking security policy: HMAC") } - if gw.mwAppendEnabled(&authArray, &JWTMiddleware{baseMid}) { + if gw.mwAppendEnabled(&authArray, &JWTMiddleware{baseMid.Copy()}) { logger.Info("Checking security policy: JWT") } - if gw.mwAppendEnabled(&authArray, &OpenIDMW{BaseMiddleware: baseMid}) { + if gw.mwAppendEnabled(&authArray, &OpenIDMW{BaseMiddleware: baseMid.Copy()}) { logger.Info("Checking security policy: OpenID") } @@ -362,7 +353,7 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, case apidef.OttoDriver: logger.Info("----> Checking security policy: JS Plugin") authArray = append(authArray, gw.createMiddleware(&DynamicMiddleware{ - BaseMiddleware: baseMid, + BaseMiddleware: baseMid.Copy(), MiddlewareClassName: mwAuthCheckFunc.Name, Pre: true, Auth: true, @@ -371,7 +362,7 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, gw.mwAppendEnabled( &authArray, &GoPluginMiddleware{ - BaseMiddleware: baseMid, + BaseMiddleware: baseMid.Copy(), Path: mwAuthCheckFunc.Path, SymbolName: mwAuthCheckFunc.Name, APILevel: true, @@ -380,14 +371,14 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, default: coprocessLog.Debug("Registering coprocess middleware, hook name: ", mwAuthCheckFunc.Name, "hook type: CustomKeyCheck", ", driver: ", mwDriver) - newExtractor(spec, baseMid) - gw.mwAppendEnabled(&authArray, &CoProcessMiddleware{baseMid, coprocess.HookType_CustomKeyCheck, mwAuthCheckFunc.Name, mwDriver, mwAuthCheckFunc.RawBodyOnly, nil}) + newExtractor(spec, baseMid.Copy()) + gw.mwAppendEnabled(&authArray, &CoProcessMiddleware{baseMid.Copy(), coprocess.HookType_CustomKeyCheck, mwAuthCheckFunc.Name, mwDriver, mwAuthCheckFunc.RawBodyOnly, nil}) } } if spec.UseStandardAuth || len(authArray) == 0 { logger.Info("Checking security policy: Token") - authArray = append(authArray, gw.createMiddleware(&AuthKey{baseMid})) + authArray = append(authArray, gw.createMiddleware(&AuthKey{baseMid.Copy()})) } chainArray = append(chainArray, authArray...) @@ -405,7 +396,7 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, gw.mwAppendEnabled( &chainArray, &GoPluginMiddleware{ - BaseMiddleware: baseMid, + BaseMiddleware: baseMid.Copy(), Path: obj.Path, SymbolName: obj.Name, APILevel: true, @@ -413,46 +404,47 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, ) } else { coprocessLog.Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Pre", ", driver: ", mwDriver) - gw.mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_PostKeyAuth, obj.Name, mwDriver, obj.RawBodyOnly, nil}) + gw.mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid.Copy(), coprocess.HookType_PostKeyAuth, obj.Name, mwDriver, obj.RawBodyOnly, nil}) } } - gw.mwAppendEnabled(&chainArray, &StripAuth{baseMid}) - gw.mwAppendEnabled(&chainArray, &KeyExpired{baseMid}) - gw.mwAppendEnabled(&chainArray, &AccessRightsCheck{baseMid}) - gw.mwAppendEnabled(&chainArray, &GranularAccessMiddleware{baseMid}) - gw.mwAppendEnabled(&chainArray, &RateLimitAndQuotaCheck{baseMid}) + gw.mwAppendEnabled(&chainArray, &StripAuth{baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &KeyExpired{baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &AccessRightsCheck{baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &GranularAccessMiddleware{baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &RateLimitAndQuotaCheck{baseMid.Copy()}) } - gw.mwAppendEnabled(&chainArray, &RateLimitForAPI{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &GraphQLMiddleware{BaseMiddleware: baseMid}) + gw.mwAppendEnabled(&chainArray, &RateLimitForAPI{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &GraphQLMiddleware{BaseMiddleware: baseMid.Copy()}) + if !spec.UseKeylessAccess { - gw.mwAppendEnabled(&chainArray, &GraphQLComplexityMiddleware{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &GraphQLGranularAccessMiddleware{BaseMiddleware: baseMid}) + gw.mwAppendEnabled(&chainArray, &GraphQLComplexityMiddleware{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &GraphQLGranularAccessMiddleware{BaseMiddleware: baseMid.Copy()}) } - gw.mwAppendEnabled(&chainArray, &ValidateJSON{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &ValidateRequest{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &PersistGraphQLOperationMiddleware{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &TransformMiddleware{baseMid}) - gw.mwAppendEnabled(&chainArray, &TransformJQMiddleware{baseMid}) - gw.mwAppendEnabled(&chainArray, &TransformHeaders{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &URLRewriteMiddleware{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &TransformMethod{BaseMiddleware: baseMid}) + gw.mwAppendEnabled(&chainArray, &ValidateJSON{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &ValidateRequest{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &PersistGraphQLOperationMiddleware{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &TransformMiddleware{baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &TransformJQMiddleware{baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &TransformHeaders{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &URLRewriteMiddleware{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &TransformMethod{BaseMiddleware: baseMid.Copy()}) // Earliest we can respond with cache get 200 ok - gw.mwAppendEnabled(&chainArray, &RedisCacheMiddleware{BaseMiddleware: baseMid, store: &cacheStore}) + gw.mwAppendEnabled(&chainArray, &RedisCacheMiddleware{BaseMiddleware: baseMid.Copy(), store: &cacheStore}) - gw.mwAppendEnabled(&chainArray, &VirtualEndpoint{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &RequestSigning{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&chainArray, &GoPluginMiddleware{BaseMiddleware: baseMid}) + gw.mwAppendEnabled(&chainArray, &VirtualEndpoint{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &RequestSigning{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&chainArray, &GoPluginMiddleware{BaseMiddleware: baseMid.Copy()}) for _, obj := range mwPostFuncs { if mwDriver == apidef.GoPluginDriver { gw.mwAppendEnabled( &chainArray, &GoPluginMiddleware{ - BaseMiddleware: baseMid, + BaseMiddleware: baseMid.Copy(), Path: obj.Path, SymbolName: obj.Name, APILevel: true, @@ -460,23 +452,23 @@ func (gw *Gateway) processSpec(spec *APISpec, apisByListen map[string]int, ) } else if mwDriver != apidef.OttoDriver { coprocessLog.Debug("Registering coprocess middleware, hook name: ", obj.Name, "hook type: Post", ", driver: ", mwDriver) - gw.mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid, coprocess.HookType_Post, obj.Name, mwDriver, obj.RawBodyOnly, nil}) + gw.mwAppendEnabled(&chainArray, &CoProcessMiddleware{baseMid.Copy(), coprocess.HookType_Post, obj.Name, mwDriver, obj.RawBodyOnly, nil}) } else { - chainArray = append(chainArray, gw.createDynamicMiddleware(obj.Name, false, obj.RequireSession, baseMid)) + chainArray = append(chainArray, gw.createDynamicMiddleware(obj.Name, false, obj.RequireSession, baseMid.Copy())) } } - chain = alice.New(chainArray...).Then(&DummyProxyHandler{SH: SuccessHandler{baseMid}, Gw: gw}) + chain = alice.New(chainArray...).Then(&DummyProxyHandler{SH: SuccessHandler{baseMid.Copy()}, Gw: gw}) if !spec.UseKeylessAccess { var simpleArray []alice.Constructor - gw.mwAppendEnabled(&simpleArray, &IPWhiteListMiddleware{baseMid}) - gw.mwAppendEnabled(&simpleArray, &IPBlackListMiddleware{BaseMiddleware: baseMid}) - gw.mwAppendEnabled(&simpleArray, &OrganizationMonitor{BaseMiddleware: baseMid, mon: Monitor{Gw: gw}}) - gw.mwAppendEnabled(&simpleArray, &VersionCheck{BaseMiddleware: baseMid}) + gw.mwAppendEnabled(&simpleArray, &IPWhiteListMiddleware{baseMid.Copy()}) + gw.mwAppendEnabled(&simpleArray, &IPBlackListMiddleware{BaseMiddleware: baseMid.Copy()}) + gw.mwAppendEnabled(&simpleArray, &OrganizationMonitor{BaseMiddleware: baseMid.Copy(), mon: Monitor{Gw: gw}}) + gw.mwAppendEnabled(&simpleArray, &VersionCheck{BaseMiddleware: baseMid.Copy()}) simpleArray = append(simpleArray, authArray...) - gw.mwAppendEnabled(&simpleArray, &KeyExpired{baseMid}) - gw.mwAppendEnabled(&simpleArray, &AccessRightsCheck{baseMid}) + gw.mwAppendEnabled(&simpleArray, &KeyExpired{baseMid.Copy()}) + gw.mwAppendEnabled(&simpleArray, &AccessRightsCheck{baseMid.Copy()}) rateLimitPath := path.Join(spec.Proxy.ListenPath, rateLimitEndpoint) logger.Debug("Rate limit endpoint is: ", rateLimitPath) diff --git a/gateway/coprocess.go b/gateway/coprocess.go index 36663ed8f45..ff41724ebae 100644 --- a/gateway/coprocess.go +++ b/gateway/coprocess.go @@ -50,7 +50,7 @@ func CreateCoProcessMiddleware(hookName string, hookType coprocess.HookType, mwD HookType: hookType, HookName: hookName, MiddlewareDriver: mwDriver, - successHandler: &SuccessHandler{baseMid}, + successHandler: &SuccessHandler{baseMid.Copy()}, } return baseMid.Gw.createMiddleware(dMiddleware) @@ -308,7 +308,7 @@ func (m *CoProcessMiddleware) EnabledForSpec() bool { log.WithFields(logrus.Fields{ "prefix": "coprocess", }).Debug("Enabling CP middleware.") - m.successHandler = &SuccessHandler{m.BaseMiddleware} + m.successHandler = &SuccessHandler{m.BaseMiddleware.Copy()} return true } @@ -547,7 +547,7 @@ func (h *CustomMiddlewareResponseHook) Name() string { } func (h *CustomMiddlewareResponseHook) HandleError(rw http.ResponseWriter, req *http.Request) { - handler := ErrorHandler{h.mw.BaseMiddleware} + handler := ErrorHandler{h.mw.BaseMiddleware.Copy()} handler.HandleError(rw, req, "Middleware error", http.StatusInternalServerError, true) } diff --git a/gateway/coprocess_id_extractor.go b/gateway/coprocess_id_extractor.go index 59c8c51e723..a9178c29200 100644 --- a/gateway/coprocess_id_extractor.go +++ b/gateway/coprocess_id_extractor.go @@ -273,7 +273,7 @@ func newExtractor(referenceSpec *APISpec, mw *BaseMiddleware) { baseExtractor := BaseExtractor{ Config: &referenceSpec.CustomMiddleware.IdExtractor, IDExtractorConfig: idExtractorConfig, - BaseMiddleware: mw, + BaseMiddleware: mw, // Already a Copy from api_loader.go Spec: referenceSpec, } diff --git a/gateway/coprocess_test.go b/gateway/coprocess_test.go index 4f3c190e178..f861edf3223 100644 --- a/gateway/coprocess_test.go +++ b/gateway/coprocess_test.go @@ -244,12 +244,7 @@ func equalHeaders(h1, h2 []*coprocess.Header) bool { } func TestCoProcessMiddlewareName(t *testing.T) { - // Initialize the CoProcessMiddleware - m := &CoProcessMiddleware{BaseMiddleware: &BaseMiddleware{}} + m := &CoProcessMiddleware{} - // Get the name using the method - name := m.Name() - - // Check that the returned name is "CoProcessMiddleware" - require.Equal(t, "CoProcessMiddleware", name, "Name method did not return the expected value") + require.Equal(t, "CoProcessMiddleware", m.Name(), "Name method did not return the expected value") } diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index 2238f65e13c..a00246d137d 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -1294,7 +1294,7 @@ func TestCacheEtag(t *testing.T) { } func TestOldCachePlugin(t *testing.T) { - test.Exclusive(t) // Test uses cache-* while other tests delete it. + t.Skip() // DeleteScanMatch interferes with other tests. api := BuildAPI(func(spec *APISpec) { spec.Proxy.ListenPath = "/" diff --git a/gateway/middleware.go b/gateway/middleware.go index ec8d6cf265e..78242f81ecd 100644 --- a/gateway/middleware.go +++ b/gateway/middleware.go @@ -10,6 +10,7 @@ import ( "io/ioutil" "net/http" "strconv" + "sync" "time" "github.com/TykTechnologies/tyk/internal/cache" @@ -44,11 +45,9 @@ var ( ) type TykMiddleware interface { - Init() Base() *BaseMiddleware - SetName(string) - SetRequestLogger(*http.Request) + Init() Logger() *logrus.Entry Config() (interface{}, error) ProcessRequest(w http.ResponseWriter, r *http.Request, conf interface{}) (error, int) // Handles request @@ -68,7 +67,9 @@ func (tr TraceMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, defer span.Finish() setContext(r, ctx) return tr.TykMiddleware.ProcessRequest(w, r, conf) - } else if baseMw := tr.Base(); baseMw != nil { + } + + if baseMw := tr.Base(); baseMw != nil { cfg := baseMw.Gw.GetConfig() if cfg.OpenTelemetry.Enabled { otel.AddTraceID(r.Context(), w) @@ -101,7 +102,7 @@ func (tr TraceMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, func (gw *Gateway) createDynamicMiddleware(name string, isPre, useSession bool, baseMid *BaseMiddleware) func(http.Handler) http.Handler { dMiddleware := &DynamicMiddleware{ - BaseMiddleware: baseMid, + BaseMiddleware: baseMid, // already a Copy from api_loader. MiddlewareClassName: name, Pre: isPre, UseSession: useSession, @@ -117,7 +118,7 @@ func (gw *Gateway) createMiddleware(actualMW TykMiddleware) func(http.Handler) h } // construct a new instance mw.Init() - mw.SetName(mw.Name()) + mw.Base().SetName(mw.Name()) mw.Logger().Debug("Init") // Pull the configuration @@ -128,7 +129,7 @@ func (gw *Gateway) createMiddleware(actualMW TykMiddleware) func(http.Handler) h return func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mw.SetRequestLogger(r) + logger := mw.Base().SetRequestLogger(r) if gw.GetConfig().NewRelic.AppName != "" { if txn, ok := w.(newrelic.Transaction); ok { @@ -154,7 +155,7 @@ func (gw *Gateway) createMiddleware(actualMW TykMiddleware) func(http.Handler) h } startTime := time.Now() - mw.Logger().WithField("ts", startTime.UnixNano()).Debug("Started") + logger.WithField("ts", startTime.UnixNano()).WithField("mw", mw.Name()).Debug("Started") if mw.Base().Spec.CORS.OptionsPassthrough && r.Method == "OPTIONS" { h.ServeHTTP(w, r) @@ -181,7 +182,7 @@ func (gw *Gateway) createMiddleware(actualMW TykMiddleware) func(http.Handler) h job.TimingKv(eventName+".exec_time", finishTime.Nanoseconds(), meta) } - mw.Logger().WithError(err).WithField("code", errCode).WithField("ns", finishTime.Nanoseconds()).Debug("Finished") + logger.WithError(err).WithField("code", errCode).WithField("ns", finishTime.Nanoseconds()).Debug("Finished") return } @@ -192,7 +193,7 @@ func (gw *Gateway) createMiddleware(actualMW TykMiddleware) func(http.Handler) h job.TimingKv(eventName+".exec_time", finishTime.Nanoseconds(), meta) } - mw.Logger().WithField("code", errCode).WithField("ns", finishTime.Nanoseconds()).Debug("Finished") + logger.WithField("code", errCode).WithField("ns", finishTime.Nanoseconds()).Debug("Finished") mw.Base().UpdateRequestSession(r) // Special code, bypasses all other execution @@ -233,30 +234,88 @@ func (gw *Gateway) mwList(mws ...TykMiddleware) []alice.Constructor { // BaseMiddleware wraps up the ApiSpec and Proxy objects to be included in a // middleware handler, this can probably be handled better. type BaseMiddleware struct { - Spec *APISpec - Proxy ReturningHttpHandler - logger *logrus.Entry - Gw *Gateway `json:"-"` + Spec *APISpec + Proxy ReturningHttpHandler + Gw *Gateway `json:"-"` + + loggerMu sync.Mutex + logger *logrus.Entry +} + +// NewBaseMiddleware creates a new *BaseMiddleware. +// The passed logrus.Entry is duplicated. +// BaseMiddleware keeps the pointer to *Gateway and *APISpec, as well as Proxy. +// The logger duplication is used so that basemiddleware copies can be created for different middleware. +func NewBaseMiddleware(gw *Gateway, spec *APISpec, proxy ReturningHttpHandler, logger *logrus.Entry) *BaseMiddleware { + if logger == nil { + logger = logrus.NewEntry(log) + } + baseMid := &BaseMiddleware{ + Spec: spec, + Proxy: proxy, + logger: logger.Dup(), + Gw: gw, + } + + for _, v := range baseMid.Spec.VersionData.Versions { + if len(v.ExtendedPaths.CircuitBreaker) > 0 { + baseMid.Spec.CircuitBreakerEnabled = true + } + if len(v.ExtendedPaths.HardTimeouts) > 0 { + baseMid.Spec.EnforcedTimeoutEnabled = true + } + } + + return baseMid } -func (t BaseMiddleware) Base() *BaseMiddleware { - return &t +// Copy provides a new BaseMiddleware with it's own logger scope (copy). +// The Spec, Proxy and Gw values are not copied. +func (m *BaseMiddleware) Copy() *BaseMiddleware { + return &BaseMiddleware{ + logger: m.logger.Dup(), + Spec: m.Spec, + Proxy: m.Proxy, + Gw: m.Gw, + } } -func (t *BaseMiddleware) Logger() (logger *logrus.Entry) { +// Base serves to provide the full BaseMiddleware API. It's part of the TykMiddleware interface. +// It escapes to a wider API surface than TykMiddleware, used by middlewares, etc. +func (t *BaseMiddleware) Base() *BaseMiddleware { + return t +} + +func (t *BaseMiddleware) SetName(name string) { + t.loggerMu.Lock() + defer t.loggerMu.Unlock() + if t.logger == nil { t.logger = logrus.NewEntry(log) } + t.logger = t.logger.WithField("mw", name) +} + +// Logger is used by middleware process functions. +func (t *BaseMiddleware) Logger() (logger *logrus.Entry) { + t.loggerMu.Lock() + defer t.loggerMu.Unlock() + if t.logger == nil { + t.logger = logrus.NewEntry(log) + } return t.logger } -func (t *BaseMiddleware) SetName(name string) { - t.logger = t.Logger().WithField("mw", name) -} +func (t *BaseMiddleware) SetRequestLogger(r *http.Request) *logrus.Entry { + t.loggerMu.Lock() + defer t.loggerMu.Unlock() -func (t *BaseMiddleware) SetRequestLogger(r *http.Request) { - t.logger = t.Gw.getLogEntryForRequest(t.Logger(), r, ctxGetAuthToken(r), nil) + if t.logger == nil { + t.logger = logrus.NewEntry(log) + } + t.logger = t.Gw.getLogEntryForRequest(t.logger, r, ctxGetAuthToken(r), nil) + return t.logger } func (t *BaseMiddleware) Init() {} diff --git a/gateway/mw_go_plugin.go b/gateway/mw_go_plugin.go index c1ed1b3dfdf..a885b9fb575 100644 --- a/gateway/mw_go_plugin.go +++ b/gateway/mw_go_plugin.go @@ -165,7 +165,7 @@ func (m *GoPluginMiddleware) loadPlugin() bool { } // to record 2XX hits in analytics - m.successHandler = &SuccessHandler{BaseMiddleware: m.BaseMiddleware} + m.successHandler = &SuccessHandler{BaseMiddleware: m.BaseMiddleware.Copy()} return true } @@ -189,7 +189,7 @@ func (m *GoPluginMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Reque if pluginMw, found := m.goPluginFromRequest(r); found { logger = pluginMw.logger handler = pluginMw.handler - successHandler = &SuccessHandler{BaseMiddleware: m.BaseMiddleware} + successHandler = &SuccessHandler{BaseMiddleware: m.BaseMiddleware.Copy()} } else { return nil, http.StatusOK // next middleware } diff --git a/gateway/testutil.go b/gateway/testutil.go index 5f980c8f2ec..c916d68d75f 100644 --- a/gateway/testutil.go +++ b/gateway/testutil.go @@ -1727,9 +1727,13 @@ func BuildAPI(apiGens ...func(spec *APISpec)) (specs []*APISpec) { } func (gw *Gateway) LoadAPI(specs ...*APISpec) (out []*APISpec) { + var err error gwConf := gw.GetConfig() oldPath := gwConf.AppPath - gwConf.AppPath, _ = ioutil.TempDir("", "apps") + gwConf.AppPath, err = ioutil.TempDir("", "apps") + if err != nil { + log.WithError(err).Errorf("loadapi: failed to create temp dir") + } gw.SetConfig(gwConf, true) defer func() { globalConf := gw.GetConfig()