From 4d211970dc790ef35721ec3bc765ae335583da92 Mon Sep 17 00:00:00 2001 From: vicanso Date: Sun, 7 Apr 2019 10:58:57 +0800 Subject: [PATCH] feat: clear some header field in error handler --- cod.go | 10 ++++++++++ cod_test.go | 55 ++++++++++++++++++++++++++++++++++++++++------------- df.go | 2 ++ 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/cod.go b/cod.go index 7959546..ba2a674 100644 --- a/cod.go +++ b/cod.go @@ -401,10 +401,20 @@ func (d *Cod) NotFound(resp http.ResponseWriter, req *http.Request) { // Error error handle func (d *Cod) Error(c *Context, err error) { + // 出错时清除部分响应头 + for _, key := range []string{ + HeaderETag, + HeaderLastModified, + HeaderContentEncoding, + HeaderContentLength, + } { + c.SetHeader(key, "") + } if d.ErrorHandler != nil { d.ErrorHandler(c, err) return } + resp := c.Response he, ok := err.(*hes.Error) if ok { diff --git a/cod_test.go b/cod_test.go index 8386a39..99ab0d2 100644 --- a/cod_test.go +++ b/cod_test.go @@ -2,6 +2,7 @@ package cod import ( "bytes" + "errors" "net" "net/http" "net/http/httptest" @@ -336,21 +337,49 @@ func TestHandle(t *testing.T) { } func TestErrorHandler(t *testing.T) { - d := New() - d.GET("/", func(c *Context) error { - return hes.New("abc") + t.Run("remove header", func(t *testing.T) { + d := New() + resp := httptest.NewRecorder() + c := NewContext(resp, nil) + keys := []string{ + HeaderETag, + HeaderLastModified, + HeaderContentEncoding, + HeaderContentLength, + } + for _, key := range keys { + c.SetHeader(key, "a") + } + d.Error(c, errors.New("abcd")) + for _, key := range keys { + value := c.GetHeader(key) + if value != "" { + t.Fatalf("default error handler should remove some header files") + } + } + if resp.Code != http.StatusInternalServerError || + resp.Body.String() != "abcd" { + t.Fatalf("error response fail") + } }) - done := false - d.ErrorHandler = func(c *Context, err error) { - done = true - } - req := httptest.NewRequest("GET", "/", nil) - resp := httptest.NewRecorder() - d.ServeHTTP(resp, req) - if !done { - t.Fatalf("custom error handler is not called") - } + t.Run("custom error handler", func(t *testing.T) { + d := New() + d.GET("/", func(c *Context) error { + return hes.New("abc") + }) + + done := false + d.ErrorHandler = func(c *Context, err error) { + done = true + } + req := httptest.NewRequest("GET", "/", nil) + resp := httptest.NewRecorder() + d.ServeHTTP(resp, req) + if !done { + t.Fatalf("custom error handler is not called") + } + }) } func TestNotFoundHandler(t *testing.T) { diff --git a/df.go b/df.go index 294e612..0fd7046 100644 --- a/df.go +++ b/df.go @@ -78,6 +78,8 @@ const ( HeaderLastModified = "Last-Modified" // HeaderContentEncoding content encoding HeaderContentEncoding = "Content-Encoding" + // HeaderContentLength content length + HeaderContentLength = "Content-Length" // HeaderIfModifiedSince if modified since HeaderIfModifiedSince = "If-Modified-Since" // HeaderIfNoneMatch if none match