Skip to content

Commit

Permalink
agent/http: refactor proxy for simpler error handling (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
roobre authored Aug 7, 2023
1 parent 5af6e37 commit 72c4e48
Showing 1 changed file with 66 additions and 48 deletions.
114 changes: 66 additions & 48 deletions pkg/agent/protocol/http/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,6 @@ func NewProxy(c ProxyConfig, d Disruption) (protocol.Proxy, error) {
}, nil
}

// contains verifies if a list of strings contains the given string
func contains(list []string, target string) bool {
for _, element := range list {
if element == target {
return true
}
}
return false
}

// httpClient defines the method for executing HTTP requests. It is used to allow mocking
// the client in tests
type httpClient interface {
Expand All @@ -97,56 +87,84 @@ type httpHandler struct {
client httpClient
}

func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
var statusCode int
headers := http.Header{}
body := io.NopCloser(strings.NewReader(h.disruption.ErrorBody))

excluded := contains(h.disruption.Excluded, req.URL.Path)

if !excluded && h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate {
// force error code
statusCode = int(h.disruption.ErrorCode)
} else {
req.Host = h.upstreamURL.Host
req.URL.Host = h.upstreamURL.Host
req.URL.Scheme = h.upstreamURL.Scheme
req.RequestURI = ""
originServerResponse, srvErr := h.client.Do(req)
if srvErr != nil {
rw.WriteHeader(http.StatusInternalServerError)
_, _ = fmt.Fprint(rw, srvErr)
return
// isExcluded checks whether a request should be proxied through without any kind of modification whatsoever.
func (h *httpHandler) isExcluded(r *http.Request) bool {
for _, excluded := range h.disruption.Excluded {
if strings.EqualFold(r.URL.Path, excluded) {
return true
}
}

headers = originServerResponse.Header
statusCode = originServerResponse.StatusCode
body = originServerResponse.Body
return false
}

defer func() {
_ = originServerResponse.Body.Close()
}()
}
// forward forwards a request to the upstream URL.
// Request is performed immediately, but response won't be sent before the duration specified in delay.
func (h *httpHandler) forward(rw http.ResponseWriter, req *http.Request, delay time.Duration) {
timer := time.After(delay)

if !excluded && h.disruption.AverageDelay > 0 {
delay := int64(h.disruption.AverageDelay)
if h.disruption.DelayVariation > 0 {
variation := int64(h.disruption.DelayVariation)
delay = delay + variation - 2*rand.Int63n(variation)
}
time.Sleep(time.Duration(delay))
upstreamReq := req.Clone(context.Background())
upstreamReq.Host = h.upstreamURL.Host
upstreamReq.URL.Host = h.upstreamURL.Host
upstreamReq.URL.Scheme = h.upstreamURL.Scheme
upstreamReq.RequestURI = "" // It is an error to set this field in an HTTP client request.

response, err := h.client.Do(req)
<-timer
if err != nil {
rw.WriteHeader(http.StatusBadGateway)
_, _ = fmt.Fprint(rw, err)
return
}

// return response to the client
for key, values := range headers {
defer func() {
// Fully consume and then close upstream response body.
_, _ = io.Copy(io.Discard, response.Body)
_ = response.Body.Close()
}()

// Mirror headers.
for key, values := range response.Header {
for _, value := range values {
rw.Header().Add(key, value)
}
}
rw.WriteHeader(statusCode)

// Mirror status code.
rw.WriteHeader(response.StatusCode)

// ignore errors writing body, nothing to do.
_, _ = io.Copy(rw, body)
_, _ = io.Copy(rw, response.Body)
}

// injectError waits sleeps the duration specified in delay and then writes the configured error downstream.
func (h *httpHandler) injectError(rw http.ResponseWriter, delay time.Duration) {
time.Sleep(delay)

rw.WriteHeader(int(h.disruption.ErrorCode))
_, _ = rw.Write([]byte(h.disruption.ErrorBody))
}

func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if h.isExcluded(req) {
//nolint:contextcheck // Unclear which context the linter requires us to propagate here.
h.forward(rw, req, 0)
return
}

delay := h.disruption.AverageDelay
if h.disruption.DelayVariation > 0 {
variation := int64(h.disruption.DelayVariation)
delay += time.Duration(variation - 2*rand.Int63n(variation))
}

if h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate {
h.injectError(rw, delay)
return
}

//nolint:contextcheck // Unclear which context the linter requires us to propagate here.
h.forward(rw, req, delay)
}

// Start starts the execution of the proxy
Expand Down

0 comments on commit 72c4e48

Please sign in to comment.