diff --git a/elton.go b/elton.go index 58e1179..26267eb 100644 --- a/elton.go +++ b/elton.go @@ -696,3 +696,29 @@ func Compose(handlerList ...Handler) Handler { return c.Next() } } + +// copy from io.ReadAll +// ReadAll reads from r until an error or EOF and returns the data it read. +// A successful call returns err == nil, not err == EOF. Because ReadAll is +// defined to read from src until EOF, it does not treat an EOF from Read +// as an error to be reported. +func ReadAllInitCap(r io.Reader, initCap int) ([]byte, error) { + if initCap <= 0 { + initCap = 512 + } + b := make([]byte, 0, initCap) + for { + if len(b) == cap(b) { + // Add more capacity (let append pick how much). + b = append(b, 0)[:len(b)] + } + n, err := r.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if err == io.EOF { + err = nil + } + return b, err + } + } +} diff --git a/elton_test.go b/elton_test.go index 33bd62c..a3c439a 100644 --- a/elton_test.go +++ b/elton_test.go @@ -565,6 +565,42 @@ func TestGracefulClose(t *testing.T) { }) } +func TestReadAllInitCap(t *testing.T) { + assert := assert.New(t) + + buf := &bytes.Buffer{} + for i := 0; i < 1024*1024; i++ { + buf.Write([]byte("hello world!")) + } + result := buf.Bytes() + + data, err := ReadAllInitCap(buf, 1024*100) + assert.Nil(err) + assert.Equal(result, data) + + data, err = ReadAllInitCap(bytes.NewBufferString("hello world!"), 1024*100) + assert.Nil(err) + assert.Equal([]byte("hello world!"), data) +} + +func BenchmarkReadAllInitCap(b *testing.B) { + buf := &bytes.Buffer{} + for i := 0; i < 1024*1024; i++ { + buf.Write([]byte("hello world!")) + } + result := buf.Bytes() + size := buf.Len() + for i := 0; i < b.N; i++ { + data, err := ReadAllInitCap(bytes.NewBuffer(result), 1024*1024) + if err != nil { + panic(err) + } + if len(data) != size { + panic(errors.New("data is invalid")) + } + } +} + // https://stackoverflow.com/questions/50120427/fail-unit-tests-if-coverage-is-below-certain-percentage func TestMain(m *testing.M) { // call flag.Parse() here if TestMain uses flags diff --git a/middleware/body_parser.go b/middleware/body_parser.go index d19466d..0a44eb4 100644 --- a/middleware/body_parser.go +++ b/middleware/body_parser.go @@ -30,6 +30,7 @@ import ( "io/ioutil" "net/http" "net/url" + "strconv" "strings" "github.com/vicanso/elton" @@ -301,7 +302,13 @@ func NewBodyParser(config BodyParserConfig) elton.Handler { } defer r.Close() var body []byte - body, err := ioutil.ReadAll(r) + initCapSize := 0 + contentLength := c.GetRequestHeader(elton.HeaderContentLength) + // 如果请求头有指定了content length,则根据content length来分配[]byte大小 + if contentLength != "" { + initCapSize, _ = strconv.Atoi(contentLength) + } + body, err := elton.ReadAllInitCap(r, initCapSize) if err != nil { if hes.IsError(err) { return err diff --git a/middleware/body_parser_test.go b/middleware/body_parser_test.go index 6ee5fc5..4e5ffa8 100644 --- a/middleware/body_parser_test.go +++ b/middleware/body_parser_test.go @@ -31,6 +31,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "strconv" "strings" "testing" @@ -355,6 +356,7 @@ func TestBodyParserMiddleware(t *testing.T) { body := `{"name": "tree.xie"}` req := httptest.NewRequest("POST", "https://aslant.site/", strings.NewReader(body)) req.Header.Set(elton.HeaderContentType, "application/json") + req.Header.Set(elton.HeaderContentLength, strconv.Itoa(len(body))) c := elton.NewContext(nil, req) c.Next = next return c