From d1c0e090e2454f7fa85fe035d9fa74bf0444a3da Mon Sep 17 00:00:00 2001 From: Arkadiusz Noster Date: Sat, 23 Mar 2024 23:43:58 +0100 Subject: [PATCH] feat: draft implementation of request validation --- README.md | 1 - client.go | 7 +++---- naming_scheme.go | 34 ++---------------------------- record_transport.go | 7 +++++-- replay_transport.go | 21 +++++++++++++++---- request_data.go | 50 ++++++++++++++++++++++++++++++++++++++++++++ request_validator.go | 24 +++++++++++++++++++++ 7 files changed, 101 insertions(+), 43 deletions(-) create mode 100644 request_data.go create mode 100644 request_validator.go diff --git a/README.md b/README.md index 3f60339..aa34e42 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ # hypert Opinionated go package for rapid testing of real HTTP APIs integrations. -Zero deps. [![build-img]][build-url] [![pkg-img]][pkg-url] diff --git a/client.go b/client.go index 4631b80..26241d6 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,6 @@ package hypert import ( - "fmt" "net/http" "path" "runtime" @@ -86,7 +85,7 @@ func TestClient(t *testing.T, recordModeOn bool, opts ...Option) *http.Client { t.Logf("hypert: using sequential naming scheme in %s directory", requestsDir) scheme, err := NewSequentialNamingScheme(requestsDir) if err != nil { - t.Fatal(fmt.Errorf("failed to create naming scheme: %w", err)) + t.Fatalf("failed to create naming scheme: %s", err.Error()) } cfg.namingScheme = scheme } @@ -100,10 +99,10 @@ func TestClient(t *testing.T, recordModeOn bool, opts ...Option) *http.Client { var transport http.RoundTripper if cfg.isRecordMode { t.Log("hypert: record request mode - requests will be stored") - transport = newRecordTransport(cfg.parentHTTPClient.Transport, cfg.namingScheme, cfg.requestSanitizer) + transport = newRecordTransport(t, cfg.parentHTTPClient.Transport, cfg.namingScheme, cfg.requestSanitizer) } else { t.Log("hypert: replay request mode - requests will be read from previously stored files.") - transport = newReplayTransport(cfg.namingScheme) + transport = newReplayTransport(t, cfg.namingScheme, ComposedRequestValidator()) // todo: add default validators } cfg.parentHTTPClient.Transport = transport return cfg.parentHTTPClient diff --git a/naming_scheme.go b/naming_scheme.go index 6b5c7f1..e94ad34 100644 --- a/naming_scheme.go +++ b/naming_scheme.go @@ -2,42 +2,12 @@ package hypert import ( "fmt" - "net/http" - "net/url" "os" "path" "strconv" "sync" ) -// RequestMeta is some data related to the request, that can be used to create filename in the NamingScheme's FileNames method implementations. -// The fields are cloned from request's URL and their modification will not affect actual request's values. -type RequestMeta struct { - Header http.Header - URL *url.URL -} - -func cloneURL(u *url.URL) *url.URL { - if u == nil { // this shouldn't actually happen, unless there is very weird injected clients' transport setup - return nil - } - var userInfo *url.Userinfo - if u.User != nil { - userInfoCopy := *u.User - userInfo = &userInfoCopy - } - uCopy := *u - uCopy.User = userInfo - return &uCopy -} - -func requestMetaFromRequest(req *http.Request) RequestMeta { - return RequestMeta{ - Header: req.Header.Clone(), - URL: cloneURL(req.URL), - } -} - // NamingScheme defines an interface that is used by hypert's test client to store or retrieve files with HTTP requests. // // FileNames returns a pair of filenames that request and response should be stored in, when Record Mode is active, and retrieved from when Replay Mode is active. @@ -48,7 +18,7 @@ func requestMetaFromRequest(req *http.Request) RequestMeta { // // This method should be safe for concurrent use. This requirement can be skipped, if you are the user of the package, and know, that all invocations would be sequential. type NamingScheme interface { - FileNames(RequestMeta) (reqFile, respFile string) + FileNames(RequestData) (reqFile, respFile string) } // SequentialNamingScheme should be initialized using NewSequentialNamingScheme function. @@ -76,7 +46,7 @@ func NewSequentialNamingScheme(dir string) (*SequentialNamingScheme, error) { }, nil } -func (s *SequentialNamingScheme) FileNames(_ RequestMeta) (reqFile, respFile string) { +func (s *SequentialNamingScheme) FileNames(_ RequestData) (reqFile, respFile string) { s.requestIndexMx.Lock() requestIndex := strconv.Itoa(s.requestIndex) defer func() { diff --git a/record_transport.go b/record_transport.go index ff160f6..a65e79e 100644 --- a/record_transport.go +++ b/record_transport.go @@ -7,16 +7,19 @@ import ( "io" "net/http" "os" + "testing" ) type recordTransport struct { httpTransport http.RoundTripper namingScheme NamingScheme sanitizer RequestSanitizer + t *testing.T } -func newRecordTransport(httpTransport http.RoundTripper, namingScheme NamingScheme, sanitizer RequestSanitizer) *recordTransport { +func newRecordTransport(t *testing.T, httpTransport http.RoundTripper, namingScheme NamingScheme, sanitizer RequestSanitizer) *recordTransport { return &recordTransport{ + t: t, httpTransport: httpTransport, namingScheme: namingScheme, sanitizer: sanitizer, @@ -28,7 +31,7 @@ func (d *recordTransport) RoundTrip(req *http.Request) (*http.Response, error) { d.httpTransport = http.DefaultTransport } - reqFile, respFile := d.namingScheme.FileNames(requestMetaFromRequest(req)) + reqFile, respFile := d.namingScheme.FileNames(requestDataFromRequest(d.t, req)) req, err := d.dumpReqToFile(reqFile, req) if err != nil { return nil, err diff --git a/replay_transport.go b/replay_transport.go index c6310a6..92b00ee 100644 --- a/replay_transport.go +++ b/replay_transport.go @@ -6,21 +6,34 @@ import ( "io" "net/http" "os" + "testing" ) type replayTransport struct { - scheme NamingScheme + t *testing.T + scheme NamingScheme + validator RequestValidator } -func newReplayTransport(scheme NamingScheme) *replayTransport { - return &replayTransport{scheme: scheme} +func newReplayTransport(t *testing.T, scheme NamingScheme, validator RequestValidator) *replayTransport { + return &replayTransport{ + t: t, + scheme: scheme, + validator: validator, + } } func (d *replayTransport) RoundTrip(req *http.Request) (*http.Response, error) { - _, respFile := d.scheme.FileNames(requestMetaFromRequest(req)) + reqFile, respFile := d.scheme.FileNames(requestDataFromRequest(d.t, req)) + d.readReqFromFile(reqFile) + return d.readRespFromFile(respFile, req) } +func (d *replayTransport) readReqFromFile(name string) (*http.Response, error) { + return nil, nil +} + func (d *replayTransport) readRespFromFile(name string, req *http.Request) (*http.Response, error) { f, err := os.OpenFile(name, os.O_RDONLY, 000) if err != nil { diff --git a/request_data.go b/request_data.go new file mode 100644 index 0000000..4de1c1e --- /dev/null +++ b/request_data.go @@ -0,0 +1,50 @@ +package hypert + +import ( + "bytes" + "io" + "net/http" + "net/url" + "testing" +) + +// RequestData is some data related to the request, that can be used to create filename in the NamingScheme's FileNames method implementations or during request validation. +// The fields are cloned from request's fields and their modification will not affect actual request's values. +type RequestData struct { + Header http.Header + URL *url.URL + BodyBytes []byte +} + +func cloneURL(u *url.URL) *url.URL { + if u == nil { // this shouldn't actually happen, unless there is very weird injected clients' transport setup + return nil + } + var userInfo *url.Userinfo + if u.User != nil { + userInfoCopy := *u.User + userInfo = &userInfoCopy + } + uCopy := *u + uCopy.User = userInfo + return &uCopy +} + +func requestDataFromRequest(t *testing.T, req *http.Request) RequestData { + if req.Body == nil { + req.Body = http.NoBody + } + var originalReqBody bytes.Buffer + teeReader := io.TeeReader(req.Body, &originalReqBody) + req.Body = io.NopCloser(&originalReqBody) + gotBodyBytes, err := io.ReadAll(teeReader) + if err != nil { + t.Fatal("hypert: got error when reading request body") + } + + return RequestData{ + Header: req.Header.Clone(), + URL: cloneURL(req.URL), + BodyBytes: gotBodyBytes, + } +} diff --git a/request_validator.go b/request_validator.go new file mode 100644 index 0000000..d541ea5 --- /dev/null +++ b/request_validator.go @@ -0,0 +1,24 @@ +package hypert + +import ( + "testing" +) + +// RequestValidator does assertions, that allows to make assertions on request that was caught by TestClient in the replay mode. +type RequestValidator interface { + Validate(t *testing.T, recorded RequestData, got RequestData) +} + +type RequestValidatorFunc func(t *testing.T, recorded RequestData, got RequestData) + +func (f RequestValidatorFunc) Validate(t *testing.T, recorded RequestData, got RequestData) { + f(t, recorded, got) +} + +func ComposedRequestValidator(validators ...RequestValidator) RequestValidator { + return RequestValidatorFunc(func(t *testing.T, recorded RequestData, got RequestData) { + for _, validator := range validators { + validator.Validate(t, recorded, got) + } + }) +}