diff --git a/implementation/capture/roundtripper.go b/implementation/capture/roundtripper.go new file mode 100644 index 0000000..2dce3cc --- /dev/null +++ b/implementation/capture/roundtripper.go @@ -0,0 +1,45 @@ +package aurestcapture + +import ( + "fmt" + aurestclientapi "github.com/StephanHCB/go-autumn-restclient/api" + "io" + "net/http" + "strings" +) + +func NewRoundTripper(wrapped aurestclientapi.Client) http.RoundTripper { + return &RequestCaptureImpl{Wrapped: wrapped} +} + +func (c *RequestCaptureImpl) RoundTrip(req *http.Request) (*http.Response, error) { + requestStr := fmt.Sprintf("%s %s %v", req.Method, req.URL.String(), req.Body) + c.recording = append(c.recording, requestStr) + + var bodyDto *[]byte + parsedResponse := aurestclientapi.ParsedResponse{ + Body: &bodyDto, + } + + err := c.Wrapped.Perform(req.Context(), req.Method, req.URL.String(), req.Body, &parsedResponse) + + newReader := strings.NewReader(string(**(parsedResponse.Body.(**[]byte)))) + readCloser := io.NopCloser(newReader) + + return &http.Response{ + Status: "", + StatusCode: parsedResponse.Status, + Proto: "", + ProtoMajor: 0, + ProtoMinor: 0, + Header: parsedResponse.Header, + Body: readCloser, + ContentLength: 0, + TransferEncoding: nil, + Close: false, + Uncompressed: false, + Trailer: nil, + Request: nil, + TLS: nil, + }, err +} diff --git a/implementation/httpclient/httpclient.go b/implementation/httpclient/httpclient.go index 7991ae6..eae053c 100644 --- a/implementation/httpclient/httpclient.go +++ b/implementation/httpclient/httpclient.go @@ -10,7 +10,6 @@ import ( aurestnontripping "github.com/StephanHCB/go-autumn-restclient/implementation/errors/nontrippingerror" "github.com/go-http-utils/headers" "io" - "io/ioutil" "net/http" "net/url" "strings" @@ -38,37 +37,32 @@ type HttpClientImpl struct { // If len(customCACert) is 0, the default CA certificates are used, but if you specify it, they are excluded to ensure // only your certs are accepted. func New(timeout time.Duration, customCACert []byte, requestManipulator aurestclientapi.RequestManipulatorCallback) (aurestclientapi.Client, error) { + httpTransport := createHttpTransport(customCACert) + + return &HttpClientImpl{ + HttpClient: &http.Client{ + Transport: httpTransport, + Timeout: timeout, + }, + RequestManipulator: requestManipulator, + Now: time.Now, + RequestMetricsCallback: doNothingMetricsCallback, + ResponseMetricsCallback: doNothingMetricsCallback, + }, nil +} + +func createHttpTransport(customCACert []byte) http.RoundTripper { if len(customCACert) != 0 { caCertPool := x509.NewCertPool() caCertPool.AppendCertsFromPEM(customCACert) - transport := &http.Transport{ + return &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: caCertPool, }, } - - return &HttpClientImpl{ - HttpClient: &http.Client{ - Transport: transport, - Timeout: timeout, - }, - RequestManipulator: requestManipulator, - Now: time.Now, - RequestMetricsCallback: doNothingMetricsCallback, - ResponseMetricsCallback: doNothingMetricsCallback, - }, nil - } else { - return &HttpClientImpl{ - HttpClient: &http.Client{ - Timeout: timeout, - }, - RequestManipulator: requestManipulator, - Now: time.Now, - RequestMetricsCallback: doNothingMetricsCallback, - ResponseMetricsCallback: doNothingMetricsCallback, - }, nil } + return http.DefaultTransport } // Instrument adds instrumentation to a http client. @@ -134,7 +128,7 @@ func (c *HttpClientImpl) Perform(ctx context.Context, method string, requestUrl response.Header = responseInternal.Header response.Status = responseInternal.StatusCode - responseBody, err := ioutil.ReadAll(responseInternal.Body) + responseBody, err := io.ReadAll(responseInternal.Body) if err != nil { _ = responseInternal.Body.Close() c.ResponseMetricsCallback(ctx, method, requestUrl, response.Status, err, c.Now().Sub(response.Time), 0) diff --git a/implementation/httpclient/roundtripper.go b/implementation/httpclient/roundtripper.go new file mode 100644 index 0000000..7fa00ab --- /dev/null +++ b/implementation/httpclient/roundtripper.go @@ -0,0 +1,60 @@ +package auresthttpclient + +import ( + aurestclientapi "github.com/StephanHCB/go-autumn-restclient/api" + "net/http" + "time" +) + +type AuRestHttpClient struct { + *http.Client + + // Now is exposed so tests can fixate the time by overwriting this field + Now func() time.Time +} + +func NewHttpClient(timeout time.Duration, customCACert []byte, + requestManipulator aurestclientapi.RequestManipulatorCallback, customHttpTransport *http.RoundTripper) (*AuRestHttpClient, error) { + + var httpTransport http.RoundTripper + if customHttpTransport == nil { + httpTransport = &HttpClientRoundTripper{ + wrapped: createHttpTransport(customCACert), + RequestManipulator: requestManipulator, + RequestMetricsCallback: doNothingMetricsCallback, + ResponseMetricsCallback: doNothingMetricsCallback, + } + } else { + httpTransport = *customHttpTransport + } + + return &AuRestHttpClient{ + Client: &http.Client{ + Transport: httpTransport, + Timeout: timeout, + }, + Now: time.Now, + }, nil +} + +type HttpClientRoundTripper struct { + wrapped http.RoundTripper + + RequestManipulator aurestclientapi.RequestManipulatorCallback + RequestMetricsCallback aurestclientapi.MetricsCallbackFunction + ResponseMetricsCallback aurestclientapi.MetricsCallbackFunction +} + +func (c *HttpClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if c.RequestManipulator != nil { + c.RequestManipulator(req.Context(), req) + } + + c.RequestMetricsCallback(req.Context(), req.Method, req.URL.String(), 0, nil, 0, int(req.ContentLength)) + + response, err := c.wrapped.RoundTrip(req) + + c.ResponseMetricsCallback(req.Context(), req.Method, req.URL.String(), 0, nil, 0, int(req.ContentLength)) + + return response, err +} diff --git a/implementation/playback/playback.go b/implementation/playback/playback.go index 357890f..ddbfddf 100644 --- a/implementation/playback/playback.go +++ b/implementation/playback/playback.go @@ -36,6 +36,17 @@ type PlaybackOptions struct { // // You can optionally add a PlaybackOptions instance to your call. The ... is really just so it's an optional argument. func New(recorderPath string, additionalOptions ...PlaybackOptions) aurestclientapi.Client { + recorderRewritePath, filenameCandidates, nowFunc := initRecorderPathAndFilenameFunc(additionalOptions) + + return &PlaybackImpl{ + RecorderPath: recorderPath, + RecorderRewritePath: recorderRewritePath, + ConstructFilenameCandidates: filenameCandidates, + Now: nowFunc, + } +} + +func initRecorderPathAndFilenameFunc(additionalOptions []PlaybackOptions) (string, []aurestrecorder.ConstructFilenameFunction, func() time.Time) { filenameCandidates := []aurestrecorder.ConstructFilenameFunction{ aurestrecorder.ConstructFilenameV3WithBody, aurestrecorder.ConstructFilenameWithBody, @@ -50,12 +61,8 @@ func New(recorderPath string, additionalOptions ...PlaybackOptions) aurestclient nowFunc = o.NowFunc } } - return &PlaybackImpl{ - RecorderPath: recorderPath, - RecorderRewritePath: os.Getenv(PlaybackRewritePathEnvVariable), - ConstructFilenameCandidates: filenameCandidates, - Now: nowFunc, - } + recorderRewritePath := os.Getenv(PlaybackRewritePathEnvVariable) + return recorderRewritePath, filenameCandidates, nowFunc } func (c *PlaybackImpl) Perform(ctx context.Context, method string, requestUrl string, requestBody interface{}, response *aurestclientapi.ParsedResponse) error { diff --git a/implementation/recorder/recorder.go b/implementation/recorder/recorder.go index fe76bd3..ddea867 100644 --- a/implementation/recorder/recorder.go +++ b/implementation/recorder/recorder.go @@ -36,6 +36,15 @@ type RecorderOptions struct { // // You can optionally add a RecorderOptions instance to your call. The ... is really just so it's an optional argument. func New(wrapped aurestclientapi.Client, additionalOptions ...RecorderOptions) aurestclientapi.Client { + recorderPath, filenameFunc := initRecorderPathAndFilenameFunc(additionalOptions) + return &RecorderImpl{ + Wrapped: wrapped, + RecorderPath: recorderPath, + ConstructFilenameFunc: filenameFunc, + } +} + +func initRecorderPathAndFilenameFunc(additionalOptions []RecorderOptions) (string, ConstructFilenameFunction) { recorderPath := os.Getenv(RecorderPathEnvVariable) if recorderPath != "" { if !strings.HasSuffix(recorderPath, "/") { @@ -48,11 +57,7 @@ func New(wrapped aurestclientapi.Client, additionalOptions ...RecorderOptions) a filenameFunc = o.ConstructFilenameFunc } } - return &RecorderImpl{ - Wrapped: wrapped, - RecorderPath: recorderPath, - ConstructFilenameFunc: filenameFunc, - } + return recorderPath, filenameFunc } type RecorderData struct { @@ -65,8 +70,16 @@ type RecorderData struct { func (c *RecorderImpl) Perform(ctx context.Context, method string, requestUrl string, requestBody interface{}, response *aurestclientapi.ParsedResponse) error { responseErr := c.Wrapped.Perform(ctx, method, requestUrl, requestBody, response) - if c.RecorderPath != "" { - filename, err := c.ConstructFilenameFunc(method, requestUrl, requestBody) + + recordResponseData(method, requestUrl, requestBody, response, responseErr, c.RecorderPath, c.ConstructFilenameFunc) + return responseErr +} + +func recordResponseData(method string, requestUrl string, requestBody interface{}, + response *aurestclientapi.ParsedResponse, responseErr error, + recorderPath string, constructFilenameFunc ConstructFilenameFunction) { + if recorderPath != "" { + filename, err := constructFilenameFunc(method, requestUrl, requestBody) if err == nil { recording := RecorderData{ Method: method, @@ -78,11 +91,10 @@ func (c *RecorderImpl) Perform(ctx context.Context, method string, requestUrl st jsonRecording, err := json.MarshalIndent(&recording, "", " ") if err == nil { - _ = os.WriteFile(c.RecorderPath+filename, jsonRecording, 0644) + _ = os.WriteFile(recorderPath+filename, jsonRecording, 0644) } } } - return responseErr } func ConstructFilename(method string, requestUrl string) (string, error) { diff --git a/implementation/recorder/roundtripper.go b/implementation/recorder/roundtripper.go new file mode 100644 index 0000000..a87ba46 --- /dev/null +++ b/implementation/recorder/roundtripper.go @@ -0,0 +1,63 @@ +package aurestrecorder + +import ( + "bytes" + aurestclientapi "github.com/StephanHCB/go-autumn-restclient/api" + "io" + "net/http" + "strings" + "time" +) + +type RecorderRoundTripper struct { + wrapped http.RoundTripper + recorderPath string + constructFilenameFunc ConstructFilenameFunction +} + +func NewRecorderRoundTripper(wrapped http.RoundTripper, additionalOptions ...RecorderOptions) *RecorderRoundTripper { + recorderPath, filenameFunc := initRecorderPathAndFilenameFunc(additionalOptions) + return &RecorderRoundTripper{ + wrapped: wrapped, + recorderPath: recorderPath, + constructFilenameFunc: filenameFunc, + } +} + +func (c *RecorderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + response, err := c.wrapped.RoundTrip(req) + + if response != nil && c.recorderPath != "" { + parsedResponse := aurestclientapi.ParsedResponse{ + Body: string(readBodyAndReset(response)), + Status: response.StatusCode, + Header: response.Header, + Time: time.Now(), + } + + var requestBodyString string + var requestBody io.ReadCloser + if req.Body != nil { + requestBody, _ = req.GetBody() + requestBodyString = readBody(requestBody) + } + recordResponseData(req.Method, req.URL.String(), requestBodyString, &parsedResponse, err, c.recorderPath, c.constructFilenameFunc) + } + return response, err +} + +func readBodyAndReset(res *http.Response) []byte { + bodyBytes, _ := io.ReadAll(res.Body) + //reset the response body to the original unread state + res.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + return bodyBytes +} + +func readBody(requestBody io.ReadCloser) string { + if requestBody != nil { + buf := new(strings.Builder) + _, _ = io.Copy(buf, requestBody) + return buf.String() + } + return "" +} diff --git a/implementation/requestlogging/requestlogging.go b/implementation/requestlogging/requestlogging.go index de1951c..853383f 100644 --- a/implementation/requestlogging/requestlogging.go +++ b/implementation/requestlogging/requestlogging.go @@ -71,18 +71,28 @@ func New(wrapped aurestclientapi.Client) aurestclientapi.Client { } func (c *RequestLoggingImpl) Perform(ctx context.Context, method string, requestUrl string, requestBody interface{}, response *aurestclientapi.ParsedResponse) error { - c.Options.BeforeRequest(ctx).Printf("downstream %s %s...", method, requestUrl) - before := time.Now() + startTime := logRequest(ctx, method, requestUrl, &c.Options) + err := c.Wrapped.Perform(ctx, method, requestUrl, requestBody, response) - millis := time.Now().Sub(before).Milliseconds() + + logResponse(ctx, method, requestUrl, response.Status, err, startTime, &c.Options) + return err +} + +func logRequest(ctx context.Context, method string, requestUrl string, opts *RequestLoggingOptions) time.Time { + opts.BeforeRequest(ctx).Printf("downstream %s %s...", method, requestUrl) + return time.Now() +} + +func logResponse(ctx context.Context, method string, requestUrl string, responseStatusCode int, err error, startTime time.Time, opts *RequestLoggingOptions) { + reqDuration := time.Now().Sub(startTime).Milliseconds() if err != nil { if aurestnontripping.Is(err) { - c.Options.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms) (nontripping)", method, requestUrl, response.Status, millis) + opts.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms) (nontripping)", method, requestUrl, responseStatusCode, reqDuration) } else { - c.Options.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms)", method, requestUrl, response.Status, millis) + opts.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms)", method, requestUrl, responseStatusCode, reqDuration) } } else { - c.Options.Success(ctx).Printf("downstream %s %s -> %d OK (%d ms)", method, requestUrl, response.Status, millis) + opts.Success(ctx).Printf("downstream %s %s -> %d OK (%d ms)", method, requestUrl, responseStatusCode, reqDuration) } - return err } diff --git a/implementation/requestlogging/roundtripper.go b/implementation/requestlogging/roundtripper.go new file mode 100644 index 0000000..81d554f --- /dev/null +++ b/implementation/requestlogging/roundtripper.go @@ -0,0 +1,53 @@ +package aurestlogging + +import ( + "net/http" +) + +type LoggingRoundTripper struct { + wrapped http.RoundTripper + Options RequestLoggingOptions +} + +func NewLoggingRoundTripper(wrapped http.RoundTripper) *LoggingRoundTripper { + return NewLoggingRoundTripperWithOpts(wrapped, defaultOpts()) +} + +func NewLoggingRoundTripperWithOpts(wrapped http.RoundTripper, opts RequestLoggingOptions) *LoggingRoundTripper { + instance := &LoggingRoundTripper{ + wrapped: wrapped, + Options: defaultOpts(), + } + if opts.BeforeRequest != nil { + instance.Options.BeforeRequest = opts.BeforeRequest + } + if opts.Success != nil { + instance.Options.Success = opts.Success + } + if opts.Failure != nil { + instance.Options.Failure = opts.Failure + } + return instance +} + +func defaultOpts() RequestLoggingOptions { + return RequestLoggingOptions{ + BeforeRequest: Debug, + Success: Info, + Failure: Warn, + } +} + +func (c *LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + startTime := logRequest(req.Context(), req.Method, req.URL.String(), &c.Options) + + response, err := c.wrapped.RoundTrip(req) + + statusCode := 0 + if response != nil { + statusCode = response.StatusCode + } + logResponse(req.Context(), req.Method, req.URL.String(), statusCode, err, startTime, &c.Options) + + return response, err +}