diff --git a/component/http/middleware.go b/component/http/middleware.go index 8a0a84f36..f9aaad250 100644 --- a/component/http/middleware.go +++ b/component/http/middleware.go @@ -73,6 +73,11 @@ func (w *responseWriter) Header() http.Header { // Write to the internal responseWriter and sets the status if not set already. func (w *responseWriter) Write(d []byte) (int, error) { + if !w.statusHeaderWritten { + w.status = http.StatusOK + w.statusHeaderWritten = true + } + value, err := w.writer.Write(d) if err != nil { return value, err @@ -82,11 +87,6 @@ func (w *responseWriter) Write(d []byte) (int, error) { w.responsePayload.Write(d) } - if !w.statusHeaderWritten { - w.status = http.StatusOK - w.statusHeaderWritten = true - } - return value, err } diff --git a/component/http/middleware_test.go b/component/http/middleware_test.go index 03cc91ff6..796ed6d40 100644 --- a/component/http/middleware_test.go +++ b/component/http/middleware_test.go @@ -688,3 +688,45 @@ func TestIsConnectionReset(t *testing.T) { }) } } + +type failWriter struct { +} + +func (fw *failWriter) Header() http.Header { + return http.Header{} +} + +func (fw *failWriter) Write([]byte) (int, error) { + return 0, fmt.Errorf("foo") +} + +func (fw *failWriter) WriteHeader(statusCode int) { + +} + +func TestSetResponseWriterStatusOnResponseFailWrite(t *testing.T) { + failWriter := &failWriter{} + failDynamicCompressionResponseWriter := &dynamicCompressionResponseWriter{failWriter, "", nil, 0, deflateLevel} + + tests := []struct { + Name string + ResponseWriter *responseWriter + }{ + { + Name: "Failing responseWriter with http.ResponseWriter", + ResponseWriter: newResponseWriter(failWriter, false), + }, + { + Name: "Failing responseWriter with http.ResponseWriter", + ResponseWriter: newResponseWriter(failDynamicCompressionResponseWriter, false), + }, + } + + for _, test := range tests { + t.Run(test.Name, func(t *testing.T) { + _, err := test.ResponseWriter.Write([]byte(`"foo":"bar"`)) + assert.Error(t, err) + assert.Equal(t, http.StatusOK, test.ResponseWriter.status) + }) + } +}