diff --git a/gzip_test.go b/gzip_test.go index 91e840e..047031f 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -247,3 +247,75 @@ func TestDecompressGzipWithIncorrectData(t *testing.T) { assert.Equal(t, http.StatusBadRequest, w.Code) } + +func TestDecompressOnly(t *testing.T) { + buf := &bytes.Buffer{} + gz, _ := gzip.NewWriterLevel(buf, gzip.DefaultCompression) + if _, err := gz.Write([]byte(testResponse)); err != nil { + gz.Close() + t.Fatal(err) + } + gz.Close() + + req, _ := http.NewRequestWithContext(context.Background(), "POST", "/", buf) + req.Header.Add("Content-Encoding", "gzip") + + router := gin.New() + router.Use(Gzip(NoCompression, WithDecompressOnly(true), WithDecompressFn(DefaultDecompressHandle))) + router.POST("/", func(c *gin.Context) { + if v := c.Request.Header.Get("Content-Encoding"); v != "" { + t.Errorf("unexpected `Content-Encoding`: %s header", v) + } + if v := c.Request.Header.Get("Content-Length"); v != "" { + t.Errorf("unexpected `Content-Length`: %s header", v) + } + data, err := c.GetRawData() + if err != nil { + t.Fatal(err) + } + c.Data(200, "text/plain", data) + }) + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "", w.Header().Get("Content-Encoding")) + assert.Equal(t, "", w.Header().Get("Vary")) + assert.Equal(t, testResponse, w.Body.String()) + assert.Equal(t, "", w.Header().Get("Content-Length")) +} + +func TestGzipWithDecompressOnly(t *testing.T) { + buf := &bytes.Buffer{} + gz, _ := gzip.NewWriterLevel(buf, gzip.DefaultCompression) + if _, err := gz.Write([]byte(testResponse)); err != nil { + gz.Close() + t.Fatal(err) + } + gz.Close() + + req, _ := http.NewRequestWithContext(context.Background(), "POST", "/", buf) + req.Header.Add("Content-Encoding", "gzip") + req.Header.Add("Accept-Encoding", "gzip") + + r := gin.New() + r.Use(Gzip(NoCompression, WithDecompressOnly(true), WithDecompressFn(DefaultDecompressHandle))) + r.POST("/", func(c *gin.Context) { + assert.Equal(t, c.Request.Header.Get("Content-Encoding"), "") + assert.Equal(t, c.Request.Header.Get("Content-Length"), "") + body, err := c.GetRawData() + if err != nil { + t.Fatal(err) + } + assert.Equal(t, testResponse, string(body)) + c.String(200, testResponse) + }) + + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + assert.Equal(t, w.Code, 200) + assert.Equal(t, w.Header().Get("Content-Encoding"), "") + assert.Equal(t, w.Body.String(), testResponse) +} diff --git a/handler.go b/handler.go index fccfa93..12d3051 100644 --- a/handler.go +++ b/handler.go @@ -41,7 +41,7 @@ func (g *gzipHandler) Handle(c *gin.Context) { fn(c) } - if !g.shouldCompress(c.Request) { + if g.DecompressOnly || !g.shouldCompress(c.Request) { return } diff --git a/options.go b/options.go index 6b3bc3f..42555e5 100644 --- a/options.go +++ b/options.go @@ -23,6 +23,7 @@ type Options struct { ExcludedPaths ExcludedPaths ExcludedPathesRegexs ExcludedPathesRegexs DecompressFn func(c *gin.Context) + DecompressOnly bool } type Option func(*Options) @@ -51,6 +52,13 @@ func WithDecompressFn(decompressFn func(c *gin.Context)) Option { } } +// disable compression, only decompress incoming request +func WithDecompressOnly(decompressOnly bool) Option { + return func(o *Options) { + o.DecompressOnly = decompressOnly + } +} + // Using map for better lookup performance type ExcludedExtensions map[string]bool