Skip to content

Commit

Permalink
Merge pull request #32 from StephanHCB/roundtripper
Browse files Browse the repository at this point in the history
Roundtripper implementationw
  • Loading branch information
StephanHCB authored Sep 13, 2024
2 parents 993c0d2 + 012cde9 commit 94ba915
Show file tree
Hide file tree
Showing 8 changed files with 290 additions and 46 deletions.
45 changes: 45 additions & 0 deletions implementation/capture/roundtripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package aurestcapture

import (
"fmt"
aurestclientapi "github.com/StephanHCB/go-autumn-restclient/api"
"io"
"net/http"
"strings"
)

func NewRoundTripper(wrapped aurestclientapi.Client) http.RoundTripper {
return &RequestCaptureImpl{Wrapped: wrapped}
}

func (c *RequestCaptureImpl) RoundTrip(req *http.Request) (*http.Response, error) {
requestStr := fmt.Sprintf("%s %s %v", req.Method, req.URL.String(), req.Body)
c.recording = append(c.recording, requestStr)

var bodyDto *[]byte
parsedResponse := aurestclientapi.ParsedResponse{
Body: &bodyDto,
}

err := c.Wrapped.Perform(req.Context(), req.Method, req.URL.String(), req.Body, &parsedResponse)

newReader := strings.NewReader(string(**(parsedResponse.Body.(**[]byte))))
readCloser := io.NopCloser(newReader)

return &http.Response{
Status: "",
StatusCode: parsedResponse.Status,
Proto: "",
ProtoMajor: 0,
ProtoMinor: 0,
Header: parsedResponse.Header,
Body: readCloser,
ContentLength: 0,
TransferEncoding: nil,
Close: false,
Uncompressed: false,
Trailer: nil,
Request: nil,
TLS: nil,
}, err
}
42 changes: 18 additions & 24 deletions implementation/httpclient/httpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
aurestnontripping "github.com/StephanHCB/go-autumn-restclient/implementation/errors/nontrippingerror"
"github.com/go-http-utils/headers"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -38,37 +37,32 @@ type HttpClientImpl struct {
// If len(customCACert) is 0, the default CA certificates are used, but if you specify it, they are excluded to ensure
// only your certs are accepted.
func New(timeout time.Duration, customCACert []byte, requestManipulator aurestclientapi.RequestManipulatorCallback) (aurestclientapi.Client, error) {
httpTransport := createHttpTransport(customCACert)

return &HttpClientImpl{
HttpClient: &http.Client{
Transport: httpTransport,
Timeout: timeout,
},
RequestManipulator: requestManipulator,
Now: time.Now,
RequestMetricsCallback: doNothingMetricsCallback,
ResponseMetricsCallback: doNothingMetricsCallback,
}, nil
}

func createHttpTransport(customCACert []byte) http.RoundTripper {
if len(customCACert) != 0 {
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(customCACert)

transport := &http.Transport{
return &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: caCertPool,
},
}

return &HttpClientImpl{
HttpClient: &http.Client{
Transport: transport,
Timeout: timeout,
},
RequestManipulator: requestManipulator,
Now: time.Now,
RequestMetricsCallback: doNothingMetricsCallback,
ResponseMetricsCallback: doNothingMetricsCallback,
}, nil
} else {
return &HttpClientImpl{
HttpClient: &http.Client{
Timeout: timeout,
},
RequestManipulator: requestManipulator,
Now: time.Now,
RequestMetricsCallback: doNothingMetricsCallback,
ResponseMetricsCallback: doNothingMetricsCallback,
}, nil
}
return http.DefaultTransport
}

// Instrument adds instrumentation to a http client.
Expand Down Expand Up @@ -134,7 +128,7 @@ func (c *HttpClientImpl) Perform(ctx context.Context, method string, requestUrl
response.Header = responseInternal.Header
response.Status = responseInternal.StatusCode

responseBody, err := ioutil.ReadAll(responseInternal.Body)
responseBody, err := io.ReadAll(responseInternal.Body)
if err != nil {
_ = responseInternal.Body.Close()
c.ResponseMetricsCallback(ctx, method, requestUrl, response.Status, err, c.Now().Sub(response.Time), 0)
Expand Down
60 changes: 60 additions & 0 deletions implementation/httpclient/roundtripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package auresthttpclient

import (
aurestclientapi "github.com/StephanHCB/go-autumn-restclient/api"
"net/http"
"time"
)

type AuRestHttpClient struct {
*http.Client

// Now is exposed so tests can fixate the time by overwriting this field
Now func() time.Time
}

func NewHttpClient(timeout time.Duration, customCACert []byte,
requestManipulator aurestclientapi.RequestManipulatorCallback, customHttpTransport *http.RoundTripper) (*AuRestHttpClient, error) {

var httpTransport http.RoundTripper
if customHttpTransport == nil {
httpTransport = &HttpClientRoundTripper{
wrapped: createHttpTransport(customCACert),
RequestManipulator: requestManipulator,
RequestMetricsCallback: doNothingMetricsCallback,
ResponseMetricsCallback: doNothingMetricsCallback,
}
} else {
httpTransport = *customHttpTransport
}

return &AuRestHttpClient{
Client: &http.Client{
Transport: httpTransport,
Timeout: timeout,
},
Now: time.Now,
}, nil
}

type HttpClientRoundTripper struct {
wrapped http.RoundTripper

RequestManipulator aurestclientapi.RequestManipulatorCallback
RequestMetricsCallback aurestclientapi.MetricsCallbackFunction
ResponseMetricsCallback aurestclientapi.MetricsCallbackFunction
}

func (c *HttpClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if c.RequestManipulator != nil {
c.RequestManipulator(req.Context(), req)
}

c.RequestMetricsCallback(req.Context(), req.Method, req.URL.String(), 0, nil, 0, int(req.ContentLength))

response, err := c.wrapped.RoundTrip(req)

c.ResponseMetricsCallback(req.Context(), req.Method, req.URL.String(), 0, nil, 0, int(req.ContentLength))

return response, err
}
19 changes: 13 additions & 6 deletions implementation/playback/playback.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ type PlaybackOptions struct {
//
// You can optionally add a PlaybackOptions instance to your call. The ... is really just so it's an optional argument.
func New(recorderPath string, additionalOptions ...PlaybackOptions) aurestclientapi.Client {
recorderRewritePath, filenameCandidates, nowFunc := initRecorderPathAndFilenameFunc(additionalOptions)

return &PlaybackImpl{
RecorderPath: recorderPath,
RecorderRewritePath: recorderRewritePath,
ConstructFilenameCandidates: filenameCandidates,
Now: nowFunc,
}
}

func initRecorderPathAndFilenameFunc(additionalOptions []PlaybackOptions) (string, []aurestrecorder.ConstructFilenameFunction, func() time.Time) {
filenameCandidates := []aurestrecorder.ConstructFilenameFunction{
aurestrecorder.ConstructFilenameV3WithBody,
aurestrecorder.ConstructFilenameWithBody,
Expand All @@ -50,12 +61,8 @@ func New(recorderPath string, additionalOptions ...PlaybackOptions) aurestclient
nowFunc = o.NowFunc
}
}
return &PlaybackImpl{
RecorderPath: recorderPath,
RecorderRewritePath: os.Getenv(PlaybackRewritePathEnvVariable),
ConstructFilenameCandidates: filenameCandidates,
Now: nowFunc,
}
recorderRewritePath := os.Getenv(PlaybackRewritePathEnvVariable)
return recorderRewritePath, filenameCandidates, nowFunc
}

func (c *PlaybackImpl) Perform(ctx context.Context, method string, requestUrl string, requestBody interface{}, response *aurestclientapi.ParsedResponse) error {
Expand Down
30 changes: 21 additions & 9 deletions implementation/recorder/recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ type RecorderOptions struct {
//
// You can optionally add a RecorderOptions instance to your call. The ... is really just so it's an optional argument.
func New(wrapped aurestclientapi.Client, additionalOptions ...RecorderOptions) aurestclientapi.Client {
recorderPath, filenameFunc := initRecorderPathAndFilenameFunc(additionalOptions)
return &RecorderImpl{
Wrapped: wrapped,
RecorderPath: recorderPath,
ConstructFilenameFunc: filenameFunc,
}
}

func initRecorderPathAndFilenameFunc(additionalOptions []RecorderOptions) (string, ConstructFilenameFunction) {
recorderPath := os.Getenv(RecorderPathEnvVariable)
if recorderPath != "" {
if !strings.HasSuffix(recorderPath, "/") {
Expand All @@ -48,11 +57,7 @@ func New(wrapped aurestclientapi.Client, additionalOptions ...RecorderOptions) a
filenameFunc = o.ConstructFilenameFunc
}
}
return &RecorderImpl{
Wrapped: wrapped,
RecorderPath: recorderPath,
ConstructFilenameFunc: filenameFunc,
}
return recorderPath, filenameFunc
}

type RecorderData struct {
Expand All @@ -65,8 +70,16 @@ type RecorderData struct {

func (c *RecorderImpl) Perform(ctx context.Context, method string, requestUrl string, requestBody interface{}, response *aurestclientapi.ParsedResponse) error {
responseErr := c.Wrapped.Perform(ctx, method, requestUrl, requestBody, response)
if c.RecorderPath != "" {
filename, err := c.ConstructFilenameFunc(method, requestUrl, requestBody)

recordResponseData(method, requestUrl, requestBody, response, responseErr, c.RecorderPath, c.ConstructFilenameFunc)
return responseErr
}

func recordResponseData(method string, requestUrl string, requestBody interface{},
response *aurestclientapi.ParsedResponse, responseErr error,
recorderPath string, constructFilenameFunc ConstructFilenameFunction) {
if recorderPath != "" {
filename, err := constructFilenameFunc(method, requestUrl, requestBody)
if err == nil {
recording := RecorderData{
Method: method,
Expand All @@ -78,11 +91,10 @@ func (c *RecorderImpl) Perform(ctx context.Context, method string, requestUrl st

jsonRecording, err := json.MarshalIndent(&recording, "", " ")
if err == nil {
_ = os.WriteFile(c.RecorderPath+filename, jsonRecording, 0644)
_ = os.WriteFile(recorderPath+filename, jsonRecording, 0644)
}
}
}
return responseErr
}

func ConstructFilename(method string, requestUrl string) (string, error) {
Expand Down
63 changes: 63 additions & 0 deletions implementation/recorder/roundtripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package aurestrecorder

import (
"bytes"
aurestclientapi "github.com/StephanHCB/go-autumn-restclient/api"
"io"
"net/http"
"strings"
"time"
)

type RecorderRoundTripper struct {
wrapped http.RoundTripper
recorderPath string
constructFilenameFunc ConstructFilenameFunction
}

func NewRecorderRoundTripper(wrapped http.RoundTripper, additionalOptions ...RecorderOptions) *RecorderRoundTripper {
recorderPath, filenameFunc := initRecorderPathAndFilenameFunc(additionalOptions)
return &RecorderRoundTripper{
wrapped: wrapped,
recorderPath: recorderPath,
constructFilenameFunc: filenameFunc,
}
}

func (c *RecorderRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
response, err := c.wrapped.RoundTrip(req)

if response != nil && c.recorderPath != "" {
parsedResponse := aurestclientapi.ParsedResponse{
Body: string(readBodyAndReset(response)),
Status: response.StatusCode,
Header: response.Header,
Time: time.Now(),
}

var requestBodyString string
var requestBody io.ReadCloser
if req.Body != nil {
requestBody, _ = req.GetBody()
requestBodyString = readBody(requestBody)
}
recordResponseData(req.Method, req.URL.String(), requestBodyString, &parsedResponse, err, c.recorderPath, c.constructFilenameFunc)
}
return response, err
}

func readBodyAndReset(res *http.Response) []byte {
bodyBytes, _ := io.ReadAll(res.Body)
//reset the response body to the original unread state
res.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
return bodyBytes
}

func readBody(requestBody io.ReadCloser) string {
if requestBody != nil {
buf := new(strings.Builder)
_, _ = io.Copy(buf, requestBody)
return buf.String()
}
return ""
}
24 changes: 17 additions & 7 deletions implementation/requestlogging/requestlogging.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,28 @@ func New(wrapped aurestclientapi.Client) aurestclientapi.Client {
}

func (c *RequestLoggingImpl) Perform(ctx context.Context, method string, requestUrl string, requestBody interface{}, response *aurestclientapi.ParsedResponse) error {
c.Options.BeforeRequest(ctx).Printf("downstream %s %s...", method, requestUrl)
before := time.Now()
startTime := logRequest(ctx, method, requestUrl, &c.Options)

err := c.Wrapped.Perform(ctx, method, requestUrl, requestBody, response)
millis := time.Now().Sub(before).Milliseconds()

logResponse(ctx, method, requestUrl, response.Status, err, startTime, &c.Options)
return err
}

func logRequest(ctx context.Context, method string, requestUrl string, opts *RequestLoggingOptions) time.Time {
opts.BeforeRequest(ctx).Printf("downstream %s %s...", method, requestUrl)
return time.Now()
}

func logResponse(ctx context.Context, method string, requestUrl string, responseStatusCode int, err error, startTime time.Time, opts *RequestLoggingOptions) {
reqDuration := time.Now().Sub(startTime).Milliseconds()
if err != nil {
if aurestnontripping.Is(err) {
c.Options.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms) (nontripping)", method, requestUrl, response.Status, millis)
opts.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms) (nontripping)", method, requestUrl, responseStatusCode, reqDuration)
} else {
c.Options.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms)", method, requestUrl, response.Status, millis)
opts.Failure(ctx).WithErr(err).Printf("downstream %s %s -> %d FAILED (%d ms)", method, requestUrl, responseStatusCode, reqDuration)
}
} else {
c.Options.Success(ctx).Printf("downstream %s %s -> %d OK (%d ms)", method, requestUrl, response.Status, millis)
opts.Success(ctx).Printf("downstream %s %s -> %d OK (%d ms)", method, requestUrl, responseStatusCode, reqDuration)
}
return err
}
Loading

0 comments on commit 94ba915

Please sign in to comment.