diff --git a/pkg/agent/protocol/http/proxy.go b/pkg/agent/protocol/http/proxy.go index b6efd080..8e3ebe15 100644 --- a/pkg/agent/protocol/http/proxy.go +++ b/pkg/agent/protocol/http/proxy.go @@ -98,73 +98,95 @@ func (h *httpHandler) isExcluded(r *http.Request) bool { return false } -// forward forwards a request to the upstream URL. -// Request is performed immediately, but response won't be sent before the duration specified in delay. -func (h *httpHandler) forward(rw http.ResponseWriter, req *http.Request, delay time.Duration) { - timer := time.After(delay) - +// forward forwards a request to the upstream URL and returns a function that +// copies the response to a ResponseWriter +func (h *httpHandler) forward(req *http.Request) func(rw http.ResponseWriter) { upstreamReq := req.Clone(context.Background()) upstreamReq.Host = h.upstreamURL.Host upstreamReq.URL.Host = h.upstreamURL.Host upstreamReq.URL.Scheme = h.upstreamURL.Scheme upstreamReq.RequestURI = "" // It is an error to set this field in an HTTP client request. + //nolint:bodyclose // it is closed in the returned functions response, err := h.client.Do(req) - <-timer + + // return a function that writes the upstream error if err != nil { - rw.WriteHeader(http.StatusBadGateway) - _, _ = fmt.Fprint(rw, err) - return - } - - defer func() { - // Fully consume and then close upstream response body. - _, _ = io.Copy(io.Discard, response.Body) - _ = response.Body.Close() - }() + return func(rw http.ResponseWriter) { + rw.WriteHeader(http.StatusBadGateway) + _, _ = fmt.Fprint(rw, err) - // Mirror headers. - for key, values := range response.Header { - for _, value := range values { - rw.Header().Add(key, value) + // Fully consume and then close upstream response body. + _, _ = io.Copy(io.Discard, response.Body) + _ = response.Body.Close() } } - // Mirror status code. - rw.WriteHeader(response.StatusCode) + // return a function that copies upstream response + return func(rw http.ResponseWriter) { + // Mirror headers. + for key, values := range response.Header { + for _, value := range values { + rw.Header().Add(key, value) + } + } - // ignore errors writing body, nothing to do. - _, _ = io.Copy(rw, response.Body) -} + // Mirror status code. + rw.WriteHeader(response.StatusCode) -// injectError waits sleeps the duration specified in delay and then writes the configured error downstream. -func (h *httpHandler) injectError(rw http.ResponseWriter, delay time.Duration) { - time.Sleep(delay) + // ignore errors writing body, nothing to do. + _, _ = io.Copy(rw, response.Body) + _ = response.Body.Close() + } +} +// injectError writes the configured error to a ResponseWriter +func (h *httpHandler) injectError(rw http.ResponseWriter) { rw.WriteHeader(int(h.disruption.ErrorCode)) _, _ = rw.Write([]byte(h.disruption.ErrorBody)) } -func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if h.isExcluded(req) { - //nolint:contextcheck // Unclear which context the linter requires us to propagate here. - h.forward(rw, req, 0) - return - } - +func (h *httpHandler) delay() time.Duration { delay := h.disruption.AverageDelay if h.disruption.DelayVariation > 0 { variation := int64(h.disruption.DelayVariation) delay += time.Duration(variation - 2*rand.Int63n(variation)) } - if h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate { - h.injectError(rw, delay) + return delay +} + +func (h *httpHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // if excluded, forward request and return response immediately + if h.isExcluded(req) { + //nolint:contextcheck // Unclear which context the linter requires us to propagate here. + h.forward(req)(rw) return } - //nolint:contextcheck // Unclear which context the linter requires us to propagate here. - h.forward(rw, req, delay) + // writer is used to write the response + var writer func(rw http.ResponseWriter) + + // forward request + done := make(chan struct{}) + go func() { + if h.disruption.ErrorRate > 0 && rand.Float32() <= h.disruption.ErrorRate { + writer = h.injectError + } else { + writer = h.forward(req) + } + + done <- struct{}{} + }() + + // wait for delay + <-time.After(h.delay()) + + // wait for upstream request + <-done + + // return response + writer(rw) } // Start starts the execution of the proxy