Skip to content

Commit

Permalink
feat: add compress middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
vicanso committed Dec 28, 2018
1 parent 7a53df9 commit 9686b8d
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 60 deletions.
2 changes: 2 additions & 0 deletions df.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ const (
HeaderIfModifiedSince = "If-Modified-Since"
// HeaderIfNoneMatch if none match
HeaderIfNoneMatch = "If-None-Match"
// HeaderAcceptEncoding accept encoding
HeaderAcceptEncoding = "Accept-Encoding"

// MinRedirectCode min redirect code
MinRedirectCode = 300
Expand Down
83 changes: 83 additions & 0 deletions middleware/compress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package middleware

import (
"regexp"
"strings"

"github.com/vicanso/cod"
)

var (
defaultCompressRegexp = regexp.MustCompile("text|javascript|json")
)

const (
defaultCompresMinLength = 1024
gzipCompress = "gzip"
)

type (
// CompressConfig compress config
CompressConfig struct {
// Level 压缩率级别
Level int
// MinLength 最小压缩长度
MinLength int
// Checker 校验数据是否可压缩
Checker *regexp.Regexp
Skipper Skipper
}
)

// NewCompresss create a new compress middleware
func NewCompresss(config CompressConfig) cod.Handler {
minLength := config.MinLength
if minLength == 0 {
minLength = defaultCompresMinLength
}
skiper := config.Skipper
if skiper == nil {
skiper = DefaultSkipper
}
checker := config.Checker
if checker == nil {
checker = defaultCompressRegexp
}
return func(c *cod.Context) (err error) {
if skiper(c) {
return c.Next()
}
err = c.Next()
if err != nil {
return
}
respHeader := c.Headers
encoding := respHeader.Get(cod.HeaderContentEncoding)
// encoding 不为空,已做处理,无需要压缩
if encoding != "" {
return
}
contentType := respHeader.Get(cod.HeaderContentType)
buf := c.BodyBytes
// 如果数据长度少于最小压缩长度或数据类型为非可压缩,则返回
if len(buf) < minLength || !checker.MatchString(contentType) {
return
}

acceptEncoding := c.Header(cod.HeaderAcceptEncoding)
// 如果请求端不接受gzip,则返回
if !strings.Contains(acceptEncoding, gzipCompress) {
return
}

gzipBuf, e := doGzip(buf, config.Level)
// 如果压缩成功,则使用压缩数据
// 失败则忽略
if e == nil {
c.SetHeader(cod.HeaderContentEncoding, gzipCompress)
c.BodyBytes = gzipBuf
}

return
}
}
49 changes: 49 additions & 0 deletions middleware/compress_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package middleware

import (
"math/rand"
"net/http/httptest"
"testing"
"time"

"github.com/vicanso/cod"
)

var letterRunes = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")

// randomString get random string
func randomString(n int) string {
b := make([]rune, n)
rand.Seed(time.Now().UnixNano())
for i := range b {
b[i] = letterRunes[rand.Intn(len(letterRunes))]
}
return string(b)
}

func TestCompress(t *testing.T) {
fn := NewCompresss(CompressConfig{
Level: 1,
MinLength: 1,
})

req := httptest.NewRequest("GET", "/users/me", nil)
req.Header.Set(cod.HeaderAcceptEncoding, "gzip")
resp := httptest.NewRecorder()
c := cod.NewContext(resp, req)
c.Headers.Set(cod.HeaderContentType, "text/html")
c.BodyBytes = []byte("<html><body>" + randomString(8192) + "</body></html>")
originalSize := len(c.BodyBytes)
done := false
c.Next = func() error {
done = true
return nil
}
err := fn(c)
if err != nil || !done {
t.Fatalf("compress middleware fail, %v", err)
}
if len(c.BodyBytes) >= originalSize {
t.Fatalf("compress fail")
}
}
18 changes: 18 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package middleware

import (
"bytes"
"compress/gzip"

"github.com/vicanso/cod"
)

Expand All @@ -13,3 +16,18 @@ type (
func DefaultSkipper(c *cod.Context) bool {
return c.Committed
}

// doGzip 对数据压缩
func doGzip(buf []byte, level int) ([]byte, error) {
var b bytes.Buffer
if level <= 0 {
level = gzip.DefaultCompression
}
w, _ := gzip.NewWriterLevel(&b, level)
_, err := w.Write(buf)
if err != nil {
return nil, err
}
w.Close()
return b.Bytes(), nil
}
3 changes: 3 additions & 0 deletions middleware/responder.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ func NewResponder(config ResponderConfig) cod.Handler {
return c.Next()
}
e := c.Next()
if c.BodyBytes != nil {
return e
}
var err *errors.HTTPError
if e != nil {
// 如果出错,尝试转换为HTTPError
Expand Down
52 changes: 2 additions & 50 deletions middleware/static_serve.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package middleware

import (
"bytes"
"compress/gzip"
"io/ioutil"
"mime"
"net/http"
Expand All @@ -22,7 +20,7 @@ type (
Get(string) ([]byte, error)
Stat(string) os.FileInfo
}
// StaticServeConfig static servce config
// StaticServeConfig static serve config
StaticServeConfig struct {
Path string
Mount string
Expand All @@ -33,8 +31,6 @@ type (
DisableETag bool
DisableLastModified bool
NotFoundNext bool
Gzip bool
CompressMinLength int
Skipper Skipper
}
// FS file system
Expand All @@ -51,12 +47,6 @@ var (
errNotAllowQueryString = getStaticServeError("static serve not allow query string", http.StatusBadRequest)
errNotFound = getStaticServeError("static file not found", http.StatusNotFound)
errOutOfPath = getStaticServeError("out of path", http.StatusBadRequest)

defaultCompressTypes = []string{
"text",
"javascript",
"json",
}
)

func (fs *FS) outOfPath(file string) bool {
Expand Down Expand Up @@ -107,34 +97,6 @@ func getStaticServeError(message string, statusCode int) *errors.HTTPError {
}
}

// doGzip 对数据压缩
func doGzip(buf []byte, level int) ([]byte, error) {
var b bytes.Buffer
if level <= 0 {
level = gzip.DefaultCompression
}
w, _ := gzip.NewWriterLevel(&b, level)
_, err := w.Write(buf)
if err != nil {
return nil, err
}
w.Close()
return b.Bytes(), nil
}

func isCompressable(contentType string) bool {
compressable := false
for _, v := range defaultCompressTypes {
if compressable {
break
}
if strings.Contains(contentType, v) {
compressable = true
}
}
return compressable
}

// NewStaticServe create a static serve middleware
func NewStaticServe(staticFile StaticFile, config StaticServeConfig) cod.Handler {
if config.Path == "" {
Expand Down Expand Up @@ -208,17 +170,7 @@ func NewStaticServe(staticFile StaticFile, config StaticServeConfig) cod.Handler
c.SetHeader(cod.HeaderLastModified, lmd)
}
}
if config.Gzip &&
len(buf) >= config.CompressMinLength &&
isCompressable(contentType) {
gzipBuf, e := doGzip(buf, 0)
// 如果压缩成功,则使用压缩数据
// 失败则忽略
if e == nil {
buf = gzipBuf
c.SetHeader(cod.HeaderContentEncoding, "gzip")
}
}

for k, v := range config.Header {
c.SetHeader(k, v)
}
Expand Down
15 changes: 5 additions & 10 deletions middleware/static_serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,8 @@ func TestStaticServe(t *testing.T) {

t.Run("not compresss", func(t *testing.T) {
fn := NewStaticServe(staticFile, StaticServeConfig{
Path: staticPath,
Mount: "/static",
Gzip: true,
CompressMinLength: 1,
Path: staticPath,
Mount: "/static",
})
req := httptest.NewRequest("GET", "/static/banner.jpg", nil)
res := httptest.NewRecorder()
Expand All @@ -167,10 +165,8 @@ func TestStaticServe(t *testing.T) {

t.Run("get index.html", func(t *testing.T) {
fn := NewStaticServe(staticFile, StaticServeConfig{
Path: staticPath,
Mount: "/static",
Gzip: true,
CompressMinLength: 1,
Path: staticPath,
Mount: "/static",
})
req := httptest.NewRequest("GET", "/static/index.html?a=1", nil)
res := httptest.NewRecorder()
Expand All @@ -185,11 +181,10 @@ func TestStaticServe(t *testing.T) {
h := c.Headers
if h.Get(cod.HeaderETag) != `"10-FKjW3bSjaJvr_QYzQcHNFRn-rxc="` ||
h.Get(cod.HeaderLastModified) == "" ||
h.Get(cod.HeaderContentEncoding) != "gzip" ||
h.Get("Content-Type") != "text/html; charset=utf-8" {
t.Fatalf("set header fail")
}
if len(c.Body.([]byte)) != 37 {
if len(c.Body.([]byte)) != 16 {
t.Fatalf("response body fail")
}
})
Expand Down

0 comments on commit 9686b8d

Please sign in to comment.