Skip to content

Commit

Permalink
feat: draft implementation of request validation
Browse files Browse the repository at this point in the history
  • Loading branch information
areknoster committed Mar 23, 2024
1 parent d020eec commit d1c0e09
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 43 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# hypert
Opinionated go package for rapid testing of real HTTP APIs integrations.
Zero deps.

[![build-img]][build-url]
[![pkg-img]][pkg-url]
Expand Down
7 changes: 3 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package hypert

import (
"fmt"
"net/http"
"path"
"runtime"
Expand Down Expand Up @@ -86,7 +85,7 @@ func TestClient(t *testing.T, recordModeOn bool, opts ...Option) *http.Client {
t.Logf("hypert: using sequential naming scheme in %s directory", requestsDir)
scheme, err := NewSequentialNamingScheme(requestsDir)
if err != nil {
t.Fatal(fmt.Errorf("failed to create naming scheme: %w", err))
t.Fatalf("failed to create naming scheme: %s", err.Error())
}
cfg.namingScheme = scheme
}
Expand All @@ -100,10 +99,10 @@ func TestClient(t *testing.T, recordModeOn bool, opts ...Option) *http.Client {
var transport http.RoundTripper
if cfg.isRecordMode {
t.Log("hypert: record request mode - requests will be stored")
transport = newRecordTransport(cfg.parentHTTPClient.Transport, cfg.namingScheme, cfg.requestSanitizer)
transport = newRecordTransport(t, cfg.parentHTTPClient.Transport, cfg.namingScheme, cfg.requestSanitizer)
} else {
t.Log("hypert: replay request mode - requests will be read from previously stored files.")
transport = newReplayTransport(cfg.namingScheme)
transport = newReplayTransport(t, cfg.namingScheme, ComposedRequestValidator()) // todo: add default validators
}
cfg.parentHTTPClient.Transport = transport
return cfg.parentHTTPClient
Expand Down
34 changes: 2 additions & 32 deletions naming_scheme.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,12 @@ package hypert

import (
"fmt"
"net/http"
"net/url"
"os"
"path"
"strconv"
"sync"
)

// RequestMeta is some data related to the request, that can be used to create filename in the NamingScheme's FileNames method implementations.
// The fields are cloned from request's URL and their modification will not affect actual request's values.
type RequestMeta struct {
Header http.Header
URL *url.URL
}

func cloneURL(u *url.URL) *url.URL {
if u == nil { // this shouldn't actually happen, unless there is very weird injected clients' transport setup
return nil
}
var userInfo *url.Userinfo
if u.User != nil {
userInfoCopy := *u.User
userInfo = &userInfoCopy
}
uCopy := *u
uCopy.User = userInfo
return &uCopy
}

func requestMetaFromRequest(req *http.Request) RequestMeta {
return RequestMeta{
Header: req.Header.Clone(),
URL: cloneURL(req.URL),
}
}

// NamingScheme defines an interface that is used by hypert's test client to store or retrieve files with HTTP requests.
//
// FileNames returns a pair of filenames that request and response should be stored in, when Record Mode is active, and retrieved from when Replay Mode is active.
Expand All @@ -48,7 +18,7 @@ func requestMetaFromRequest(req *http.Request) RequestMeta {
//
// This method should be safe for concurrent use. This requirement can be skipped, if you are the user of the package, and know, that all invocations would be sequential.
type NamingScheme interface {
FileNames(RequestMeta) (reqFile, respFile string)
FileNames(RequestData) (reqFile, respFile string)
}

// SequentialNamingScheme should be initialized using NewSequentialNamingScheme function.
Expand Down Expand Up @@ -76,7 +46,7 @@ func NewSequentialNamingScheme(dir string) (*SequentialNamingScheme, error) {
}, nil
}

func (s *SequentialNamingScheme) FileNames(_ RequestMeta) (reqFile, respFile string) {
func (s *SequentialNamingScheme) FileNames(_ RequestData) (reqFile, respFile string) {
s.requestIndexMx.Lock()
requestIndex := strconv.Itoa(s.requestIndex)
defer func() {
Expand Down
7 changes: 5 additions & 2 deletions record_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@ import (
"io"
"net/http"
"os"
"testing"
)

type recordTransport struct {
httpTransport http.RoundTripper
namingScheme NamingScheme
sanitizer RequestSanitizer
t *testing.T
}

func newRecordTransport(httpTransport http.RoundTripper, namingScheme NamingScheme, sanitizer RequestSanitizer) *recordTransport {
func newRecordTransport(t *testing.T, httpTransport http.RoundTripper, namingScheme NamingScheme, sanitizer RequestSanitizer) *recordTransport {
return &recordTransport{
t: t,
httpTransport: httpTransport,
namingScheme: namingScheme,
sanitizer: sanitizer,
Expand All @@ -28,7 +31,7 @@ func (d *recordTransport) RoundTrip(req *http.Request) (*http.Response, error) {
d.httpTransport = http.DefaultTransport
}

reqFile, respFile := d.namingScheme.FileNames(requestMetaFromRequest(req))
reqFile, respFile := d.namingScheme.FileNames(requestDataFromRequest(d.t, req))
req, err := d.dumpReqToFile(reqFile, req)
if err != nil {
return nil, err
Expand Down
21 changes: 17 additions & 4 deletions replay_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,34 @@ import (
"io"
"net/http"
"os"
"testing"
)

type replayTransport struct {
scheme NamingScheme
t *testing.T
scheme NamingScheme
validator RequestValidator
}

func newReplayTransport(scheme NamingScheme) *replayTransport {
return &replayTransport{scheme: scheme}
func newReplayTransport(t *testing.T, scheme NamingScheme, validator RequestValidator) *replayTransport {
return &replayTransport{
t: t,
scheme: scheme,
validator: validator,
}
}

func (d *replayTransport) RoundTrip(req *http.Request) (*http.Response, error) {
_, respFile := d.scheme.FileNames(requestMetaFromRequest(req))
reqFile, respFile := d.scheme.FileNames(requestDataFromRequest(d.t, req))
d.readReqFromFile(reqFile)

return d.readRespFromFile(respFile, req)
}

func (d *replayTransport) readReqFromFile(name string) (*http.Response, error) {
return nil, nil
}

func (d *replayTransport) readRespFromFile(name string, req *http.Request) (*http.Response, error) {
f, err := os.OpenFile(name, os.O_RDONLY, 000)
if err != nil {
Expand Down
50 changes: 50 additions & 0 deletions request_data.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package hypert

import (
"bytes"
"io"
"net/http"
"net/url"
"testing"
)

// RequestData is some data related to the request, that can be used to create filename in the NamingScheme's FileNames method implementations or during request validation.
// The fields are cloned from request's fields and their modification will not affect actual request's values.
type RequestData struct {
Header http.Header
URL *url.URL
BodyBytes []byte
}

func cloneURL(u *url.URL) *url.URL {
if u == nil { // this shouldn't actually happen, unless there is very weird injected clients' transport setup
return nil
}
var userInfo *url.Userinfo
if u.User != nil {
userInfoCopy := *u.User
userInfo = &userInfoCopy
}
uCopy := *u
uCopy.User = userInfo
return &uCopy
}

func requestDataFromRequest(t *testing.T, req *http.Request) RequestData {
if req.Body == nil {
req.Body = http.NoBody
}
var originalReqBody bytes.Buffer
teeReader := io.TeeReader(req.Body, &originalReqBody)
req.Body = io.NopCloser(&originalReqBody)
gotBodyBytes, err := io.ReadAll(teeReader)
if err != nil {
t.Fatal("hypert: got error when reading request body")
}

return RequestData{
Header: req.Header.Clone(),
URL: cloneURL(req.URL),
BodyBytes: gotBodyBytes,
}
}
24 changes: 24 additions & 0 deletions request_validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package hypert

import (
"testing"
)

// RequestValidator does assertions, that allows to make assertions on request that was caught by TestClient in the replay mode.
type RequestValidator interface {
Validate(t *testing.T, recorded RequestData, got RequestData)
}

type RequestValidatorFunc func(t *testing.T, recorded RequestData, got RequestData)

func (f RequestValidatorFunc) Validate(t *testing.T, recorded RequestData, got RequestData) {
f(t, recorded, got)
}

func ComposedRequestValidator(validators ...RequestValidator) RequestValidator {
return RequestValidatorFunc(func(t *testing.T, recorded RequestData, got RequestData) {
for _, validator := range validators {
validator.Validate(t, recorded, got)
}
})
}

0 comments on commit d1c0e09

Please sign in to comment.