diff --git a/.gitignore b/.gitignore index 35bf59b..deca83f 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,7 @@ gopayloader-darwin-amd64 gopayloader-linux-amd64 gopayloader-windows-amd64.exe *.tar.gz -build \ No newline at end of file +build +go.work +main2.go +go.work.sum diff --git a/README.md b/README.md index 7521477..052b96b 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,10 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/domsolutions/gopayloader)](https://goreportcard.com/report/github.com/domsolutions/gopayloader) [![GoDoc](https://godoc.org/github.com/domsolutions/gopayloader?status.svg)](http://godoc.org/github.com/domsolutions/gopayloader) -Gopayloader is an HTTP/S benchmarking tool. Inspired by [bombardier](https://github.com/codesenberg/bombardier/) it also uses [fasthttp](https://github.com/valyala/fasthttp) which allows for fast creation and sending of requests due to low allocations and lots of other improvements. But with -added improvement of also supporting fashttp for HTTP/2. +Gopayloader is an HTTP/S benchmarking tool. Inspired by [bombardier](https://github.com/codesenberg/bombardier/) it also uses [fasthttp](https://github.com/valyala/fasthttp) which allows for fast creation and sending of requests due to low allocations and lots of other improvements. It uses this client by default, a different client can be used with `--client` flag. -Supports all HTTP versions, using [quic-go](https://github.com/quic-go/quic-go) for HTTP/3 client with `--client nethttp-3`. For HTTP/2 can use fasthttp with `--client fasthttp-2` or standard core golang `net/http` with `--client nethttp` +Supports all HTTP versions, using [quic-go](https://github.com/quic-go/quic-go) for HTTP/3 client with `--client nethttp3`. For HTTP/2 can use with `--client nethttp2`. By default uses fasthttp HTTP/1.1 client. Supports ability to generate custom JWTs to send in headers with payload (only limited by HDD size). This can be useful if the service being tested is JWT authenticated. Each JWT generated will be unique as contains a unique `jti` in claims i.e. @@ -68,8 +67,6 @@ achieved mean RPS of **53,098** To list all available flags run; ```shell -./gopayloader run --help - Load test HTTP/S server - supports HTTP/1.1 HTTP/2 HTTP/3 Usage: @@ -78,10 +75,10 @@ Usage: Flags: -b, --body string request body --body-file string read request body from file - --client string fasthttp-1 for fast http/1.1 requests - fasthttp-2 for fast http/2 requests - nethttp for standard net/http requests supporting http/1.1 http/2 - nethttp-3 for standard net/http requests supporting http/3 using quic-go (default "fasthttp-1") + --client string fasthttp for fast http/1.1 requests + nethttp for standard net/http requests using http/1.1 + nethttp2 for standard net/http requests using http/2 + nethttp3 for standard net/http requests supporting http/3 using quic-go (default "fasthttp") -c, --connections uint Number of simultaneous connections (default 1) -k, --disable-keep-alive Disable keep-alive connections -H, --headers strings headers to send in request, can have multiple i.e -H 'content-type:application/json' -H' connection:close' @@ -97,6 +94,7 @@ Flags: -m, --method string request method (default "GET") --mtls-cert string mTLS cert path --mtls-key string mTLS cert private key path + --parallel Sends reqs in parallel per connection with HTTP/2 --read-timeout duration Read timeout (default 5s) -r, --requests int Number of requests --skip-verify Skip verify SSL cert signer diff --git a/cmd/payloader/run.go b/cmd/payloader/run.go index 0f363be..812b12d 100644 --- a/cmd/payloader/run.go +++ b/cmd/payloader/run.go @@ -33,6 +33,7 @@ const ( argBody = "body" argBodyFile = "body-file" argClient = "client" + argParallel = "parallel" ) var ( @@ -60,6 +61,7 @@ var ( headers *[]string body string bodyFile string + parallel bool ) var runCmd = &cobra.Command{ @@ -98,7 +100,8 @@ var runCmd = &cobra.Command{ *headers, body, bodyFile, - client) + client, + parallel) }, } @@ -106,11 +109,12 @@ func init() { runCmd.Flags().Int64VarP(&reqs, argRequests, "r", 0, "Number of requests") runCmd.Flags().UintVarP(&conns, argConnections, "c", 1, "Number of simultaneous connections") runCmd.Flags().BoolVarP(&disableKeepAlive, argKeepAlive, "k", false, "Disable keep-alive connections") + runCmd.Flags().BoolVar(¶llel, argParallel, false, "Sends reqs in parallel per connection with HTTP/2") runCmd.Flags().BoolVar(&skipVerify, argVerifySigner, false, "Skip verify SSL cert signer") runCmd.Flags().DurationVarP(&duration, argTime, "t", 0, "Execution time window, if used with -r will uniformly distribute reqs within time window, without -r reqs are unlimited") - runCmd.Flags().DurationVar(&readTimeout, argReadTimeout, 5*time.Second, "Read timeout") - runCmd.Flags().DurationVar(&writeTimeout, argWriteTimeout, 5*time.Second, "Write timeout") + runCmd.Flags().DurationVar(&readTimeout, argReadTimeout, 10*time.Second, "Read timeout") + runCmd.Flags().DurationVar(&writeTimeout, argWriteTimeout, 10*time.Second, "Write timeout") runCmd.Flags().StringVarP(&method, argMethod, "m", "GET", "request method") runCmd.Flags().StringVarP(&body, argBody, "b", "", "request body") runCmd.Flags().StringVar(&bodyFile, argBodyFile, "", "read request body from file") @@ -121,8 +125,8 @@ func init() { runCmd.Flags().StringVar(&mTLSKey, argMTLSKey, "", "mTLS cert private key path") runCmd.Flags().StringVar(&client, argClient, worker.HttpClientFastHTTP1, worker.HttpClientFastHTTP1+` for fast http/1.1 requests -`+worker.HttpClientFastHTTP2+` for fast http/2 requests -`+worker.HttpClientNetHTTP+` for standard net/http requests supporting http/1.1 http/2 +`+worker.HttpClientNetHTTP+` for standard net/http requests using http/1.1 +`+worker.HttpClientNetHTTP2+` for standard net/http requests using http/2 `+worker.HttpClientNetHTTP3+` for standard net/http requests supporting http/3 using quic-go`) runCmd.Flags().StringVar(&jwtKID, argJWTKid, "", "JWT KID") diff --git a/cmd/payloader/test-server.go b/cmd/payloader/test-server.go index 5d91a13..f71038f 100644 --- a/cmd/payloader/test-server.go +++ b/cmd/payloader/test-server.go @@ -2,13 +2,17 @@ package payloader import ( "bufio" + "context" "crypto/tls" "errors" + "github.com/domsolutions/http2" "github.com/quic-go/quic-go" httpv3server "github.com/quic-go/quic-go/http3" "github.com/spf13/cobra" "github.com/valyala/fasthttp" + golanghttp2 "golang.org/x/net/http2" "log" + "net" "net/http" "os" "os/signal" @@ -23,6 +27,7 @@ var ( port int responseSize int fasthttp1 bool + fasthttp2 bool nethttp2 bool httpv3 bool debug bool @@ -31,6 +36,8 @@ var ( var ( serverCert string privateKey string + crt []byte + key []byte ) func init() { @@ -43,12 +50,13 @@ func init() { } func tlsConfig() *tls.Config { - crt, err := os.ReadFile(serverCert) + var err error + crt, err = os.ReadFile(serverCert) if err != nil { log.Fatal(err) } - key, err := os.ReadFile(privateKey) + key, err = os.ReadFile(privateKey) if err != nil { log.Fatal(err) } @@ -103,20 +111,85 @@ var runServerCmd = &cobra.Command{ select { case <-c: log.Println("User cancelled, shutting down") - server.Shutdown() case err := <-errs: log.Printf("Got error from server; %v \n", err) } + server.Shutdown() + return nil + } + + if fasthttp2 { + var err error + + server := fasthttp.Server{ + ErrorHandler: func(c *fasthttp.RequestCtx, err error) { + log.Println(err) + c.WriteString(err.Error()) + }, + Handler: func(c *fasthttp.RequestCtx) { + _, err = c.WriteString(response) + if err != nil { + log.Println(err) + } + if debug { + log.Printf("%s\n", c.Request.Header.String()) + log.Printf("%s\n", c.Request.Body()) + } + }, + } + + tlsConfig() + err = server.AppendCertEmbed(crt, key) + if err != nil { + log.Fatalln(err) + } + + http2.ConfigureServer(&server, http2.ServerConfig{ + Debug: debug, + }) + + errs := make(chan error) + go func() { + if err := server.ListenAndServeTLSEmbed(addr, crt, key); err != nil { + log.Println(err) + errs <- err + } + }() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + select { + case <-c: + log.Println("User cancelled, shutting down") + case err := <-errs: + log.Printf("Got error from server; %v \n", err) + } + + server.Shutdown() return nil } if nethttp2 { server := &http.Server{ Addr: addr, - ReadTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, WriteTimeout: 10 * time.Second, TLSConfig: tlsConfig(), + ConnState: func(c net.Conn, s http.ConnState) { + if !debug { + return + } + switch s { + case http.StateNew: + log.Println("NEW conn") + case http.StateClosed: + log.Println("CLOSED conn") + case http.StateHijacked: + log.Println("HIJACKED conn") + } + }, } var err error @@ -126,13 +199,35 @@ var runServerCmd = &cobra.Command{ log.Println(err) } if debug { - log.Printf("%+v\n", r.Header.Get("Some-Jwt")) + log.Printf("%+v\n", r.Header) + log.Printf("%+v\n", r.Body) } }) - if err := server.ListenAndServeTLS("", ""); err != nil { - log.Fatal(err) + err = golanghttp2.ConfigureServer(server, &golanghttp2.Server{}) + if err != nil { + return err } + + errs := make(chan error) + go func() { + if err := server.ListenAndServeTLS(serverCert, privateKey); err != nil { + errs <- err + } + }() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + select { + case <-c: + log.Println("User cancelled, shutting down") + case err := <-errs: + log.Printf("Got error from server; %v \n", err) + } + + server.Shutdown(context.Background()) + return nil } if httpv3 { @@ -172,6 +267,7 @@ func init() { runServerCmd.Flags().IntVarP(&port, "port", "p", 8080, "Port") runServerCmd.Flags().IntVarP(&responseSize, "response-size", "s", 10, "Response size") runServerCmd.Flags().BoolVar(&fasthttp1, "fasthttp-1", false, "Fasthttp HTTP/1.1 server") + runServerCmd.Flags().BoolVar(&fasthttp2, "fasthttp-2", false, "Fasthttp HTTP/2 server") runServerCmd.Flags().BoolVar(&nethttp2, "netHTTP-2", false, "net/http HTTP/2 server") runServerCmd.Flags().BoolVar(&httpv3, "http-3", false, "HTTP/3 server") runServerCmd.Flags().BoolVarP(&debug, "verbose", "v", false, "print logs") diff --git a/config/config.go b/config/config.go index 69e12c6..da73b94 100644 --- a/config/config.go +++ b/config/config.go @@ -2,14 +2,15 @@ package config import ( "context" + "encoding/json" "errors" "fmt" + "github.com/domsolutions/gopayloader/pkgs/payloader/worker" "net/url" "os" "regexp" "strings" "time" - "encoding/json" ) type Config struct { @@ -40,9 +41,10 @@ type Config struct { Body string BodyFile string Client string + Parallel bool } -func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename string, headers []string, body, bodyFile string, client string) *Config { +func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename string, headers []string, body, bodyFile string, client string, parallel bool) *Config { return &Config{ Ctx: ctx, ReqURI: reqURI, @@ -70,6 +72,7 @@ func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKee Body: body, BodyFile: bodyFile, Client: client, + Parallel: parallel, } } @@ -197,6 +200,10 @@ func (c *Config) Validate() error { } } + if c.Parallel && c.Client != worker.HttpClientNetHTTP2 { + return fmt.Errorf("can only run parallel with %s client", worker.HttpClientNetHTTP2) + } + if c.VerboseTicker == 0 { return errors.New("ticker value can't be zero") } diff --git a/go.mod b/go.mod index 81876c5..e2b3b12 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/quic-go/quic-go v0.40.0 github.com/spf13/cobra v1.8.0 github.com/valyala/fasthttp v1.51.0 + golang.org/x/net v0.17.0 golang.org/x/text v0.14.0 ) @@ -39,7 +40,6 @@ require ( golang.org/x/crypto v0.14.0 // indirect golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df // indirect golang.org/x/mod v0.11.0 // indirect - golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/term v0.13.0 // indirect golang.org/x/tools v0.10.0 // indirect diff --git a/pkgs/http-clients/definitions.go b/pkgs/http-clients/definitions.go index 75cffd9..03aefcf 100644 --- a/pkgs/http-clients/definitions.go +++ b/pkgs/http-clients/definitions.go @@ -23,6 +23,7 @@ type GoPayLoaderClient interface { NewReq(method, url string) (Request, error) NewResponse() Response CloseConns() + HTTP2() bool } type Config struct { @@ -49,6 +50,7 @@ type Config struct { HTTPV3 bool ReqStats chan<- time.Duration Client string + Parallel bool } func (c *Config) ReqLimitedOnly() bool { diff --git a/pkgs/http-clients/fasthttp/fasthttp.go b/pkgs/http-clients/fasthttp/fasthttp.go index 26112c4..8627b95 100644 --- a/pkgs/http-clients/fasthttp/fasthttp.go +++ b/pkgs/http-clients/fasthttp/fasthttp.go @@ -3,7 +3,6 @@ package fasthttp import ( "crypto/tls" "github.com/domsolutions/gopayloader/pkgs/http-clients" - "github.com/domsolutions/http2" "github.com/valyala/fasthttp" "net" "net/url" @@ -11,6 +10,7 @@ import ( type Client struct { client *fasthttp.HostClient + http2 bool } type Req struct { @@ -59,15 +59,21 @@ func (fh *Client) Do(req http_clients.Request, resp http_clients.Response) error return fh.client.Do(req.(*Req).req, resp.(*Resp).resp) } +func (c *Client) HTTP2() bool { + return c.http2 +} + func (c *Client) CloseConns() { c.client.CloseIdleConnections() } func (fh *Client) NewResponse() http_clients.Response { + // TODO: buffer pool return &Resp{resp: &fasthttp.Response{}} } func (fh *Client) NewReq(method, url string) (http_clients.Request, error) { + // TODO: buffer pool r := &fasthttp.Request{} r.SetRequestURI(url) r.Header.SetMethodBytes([]byte(method)) @@ -107,20 +113,5 @@ func GetFastHTTPClient1(config *http_clients.Config) (http_clients.GoPayLoaderCl }, } - return &Client{client: client}, nil -} - -func GetFastHTTPClient2(config *http_clients.Config) (http_clients.GoPayLoaderClient, error) { - client, err := GetFastHTTPClient1(config) - if err != nil { - return nil, err - } - - if err := http2.ConfigureClient(client.(*Client).client, http2.ClientOpts{ - //MaxResponseTime: config.ReadTimeout + config.WriteTimeout, - }); err != nil { - return nil, err - } - - return &Client{client: client.(*Client).client}, nil + return &Client{client: client, http2: false}, nil } diff --git a/pkgs/http-clients/nethttp/nethttp.go b/pkgs/http-clients/nethttp/nethttp.go index b60a3fc..65de71b 100644 --- a/pkgs/http-clients/nethttp/nethttp.go +++ b/pkgs/http-clients/nethttp/nethttp.go @@ -5,12 +5,15 @@ import ( "crypto/tls" "github.com/domsolutions/gopayloader/pkgs/http-clients" "github.com/quic-go/quic-go/http3" + "golang.org/x/net/http2" "io" + "log" "net/http" ) type Client struct { client *http.Client + http2 bool } type Req struct { @@ -26,6 +29,10 @@ func (r *Resp) StatusCode() int { } func (r *Resp) Close() { + // need to read conn before closing otherwise conn not freed + if _, err := io.Copy(io.Discard, r.resp.Body); err != nil { + log.Printf("Failed to read response body and discard %v \n", err) + } r.resp.Body.Close() } @@ -80,6 +87,10 @@ func (c *Client) CloseConns() { c.client.CloseIdleConnections() } +func (c *Client) HTTP2() bool { + return c.http2 +} + func (c *Client) NewResponse() http_clients.Response { return &Resp{ resp: &http.Response{}, @@ -91,6 +102,7 @@ func (c *Client) NewReq(method, url string) (http_clients.Request, error) { if err != nil { return nil, err } + req.Header.Set("Connection", "Keep-Alive") return &Req{ req: req, @@ -110,14 +122,39 @@ func GetNetHTTPClient(config *http_clients.Config) (http_clients.GoPayLoaderClie tlsConfig.Certificates = []tls.Certificate{cert} } - return &Client{client: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: tlsConfig, - MaxConnsPerHost: 1, - MaxIdleConns: 1, - }, - Timeout: config.ReadTimeout + config.WriteTimeout, - }}, nil + return &Client{ + http2: false, + client: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + MaxConnsPerHost: 1, + }, + Timeout: config.ReadTimeout + config.WriteTimeout, + }}, nil +} + +func GetNetHTTP2Client(config *http_clients.Config) (http_clients.GoPayLoaderClient, error) { + tlsConfig := &tls.Config{ + InsecureSkipVerify: config.SkipVerify, + } + + if config.MTLSCert != "" && config.MTLSKey != "" { + cert, err := tls.LoadX509KeyPair(config.MTLSCert, config.MTLSKey) + if err != nil { + return nil, err + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + return &Client{ + http2: true, + client: &http.Client{ + Transport: &http2.Transport{ + TLSClientConfig: tlsConfig, + StrictMaxConcurrentStreams: true, + }, + Timeout: config.ReadTimeout + config.WriteTimeout, + }}, nil } func GetNetHTTP3Client(config *http_clients.Config) (http_clients.GoPayLoaderClient, error) { @@ -141,6 +178,7 @@ func GetNetHTTP3Client(config *http_clients.Config) (http_clients.GoPayLoaderCli } return &Client{ + http2: false, client: &http.Client{ Transport: roundTripper, }, diff --git a/pkgs/payloader/output/cli/cli.go b/pkgs/payloader/output/cli/cli.go index 7df1ee9..5e275b8 100644 --- a/pkgs/payloader/output/cli/cli.go +++ b/pkgs/payloader/output/cli/cli.go @@ -62,7 +62,7 @@ func displayRespSize(resp payloader.ByteSize, t table.Writer) { t.AppendSeparator() } -func displayErrors(errors map[string]uint, t table.Writer) { +func displayErrors(errors map[string]uint64, t table.Writer) { rows := make([]table.Row, 0) for err, count := range errors { rows = append(rows, table.Row{"Error; " + err, count}) diff --git a/pkgs/payloader/payloader-results.go b/pkgs/payloader/payloader-results.go index 4ad163c..bb8c1d8 100644 --- a/pkgs/payloader/payloader-results.go +++ b/pkgs/payloader/payloader-results.go @@ -10,7 +10,7 @@ func (p *PayLoader) ComputeResults(workers []worker.Worker, results *GoPayloader results.Start = p.startTime results.End = p.stopTime results.Total = p.stopTime.Sub(p.startTime) - results.Errors = make(map[string]uint) + results.Errors = make(map[string]uint64) results.Responses = make(map[worker.ResponseCode]int64) pterm.Debug.Println("Calculating response code statistics") @@ -20,21 +20,16 @@ func (p *PayLoader) ComputeResults(workers []worker.Worker, results *GoPayloader results.CompletedReqs += stats.CompletedReqs results.FailedReqs += stats.FailedReqs - for err, count := range stats.Errors { - if _, ok := results.Errors[err]; ok { - results.Errors[err] += count - } else { - results.Errors[err] = count - } - } + stats.Errors.Range(func(key, value any) bool { + results.Errors[key.(string)] += value.(uint64) + return true + }) + + stats.Responses.Range(func(key, value any) bool { + results.Responses[key.(worker.ResponseCode)] += value.(int64) + return true + }) - for code, val := range stats.Responses { - if _, ok := results.Responses[code]; ok { - results.Responses[code] += val - } else { - results.Responses[code] = val - } - } } if results.CompletedReqs > 0 { diff --git a/pkgs/payloader/payloader.go b/pkgs/payloader/payloader.go index 275438b..04cf8d0 100644 --- a/pkgs/payloader/payloader.go +++ b/pkgs/payloader/payloader.go @@ -48,7 +48,7 @@ type GoPayloaderResults struct { RPS RPS Latency Latency Responses map[worker.ResponseCode]int64 - Errors map[string]uint + Errors map[string]uint64 ReqByteSize ByteSize RespByteSize ByteSize } @@ -156,6 +156,7 @@ func (p *PayLoader) handleReqs() (*GoPayloaderResults, error) { var conn uint for conn = 0; conn < p.config.Conns; conn++ { + c := &http_clients.Config{ ReqURI: p.config.ReqURI, DisableKeepAlive: p.config.DisableKeepAlive, @@ -176,6 +177,7 @@ func (p *PayLoader) handleReqs() (*GoPayloaderResults, error) { BodyFile: p.config.BodyFile, ReqStats: reqStats, Client: p.config.Client, + Parallel: p.config.Parallel, } // evenly distribute remainder reqs diff --git a/pkgs/payloader/payloader_test.go b/pkgs/payloader/payloader_test.go index f4be43f..7d2b13c 100644 --- a/pkgs/payloader/payloader_test.go +++ b/pkgs/payloader/payloader_test.go @@ -9,6 +9,7 @@ import ( "github.com/quic-go/quic-go" httpv3server "github.com/quic-go/quic-go/http3" "github.com/valyala/fasthttp" + golanghttp2 "golang.org/x/net/http2" "log" "net/http" "os" @@ -21,9 +22,14 @@ import ( var ( testServerHTTP3 httpv3server.Server testFastHTTP fasthttp.Server + crtPath string + keyPath string ) func init() { + crtPath = filepath.Join("..", "..", "test", "server.crt") + keyPath = filepath.Join("..", "..", "test", "server.key") + go testStartHTTP1Server("localhost:8888") go testStartHTTP2Server("localhost:8889") go testStartHTTP3Server("localhost:8890") @@ -37,12 +43,12 @@ func init() { } func tlsConfig() *tls.Config { - crt, err := os.ReadFile(filepath.Join("..", "..", "test", "server.crt")) + crt, err := os.ReadFile(crtPath) if err != nil { log.Fatal(err) } - key, err := os.ReadFile(filepath.Join("..", "..", "test", "server.key")) + key, err := os.ReadFile(keyPath) if err != nil { log.Fatal(err) } @@ -111,31 +117,36 @@ func testStartHTTP2Server(addr string) { } }) - if err := server.ListenAndServeTLS("", ""); err != nil { + err = golanghttp2.ConfigureServer(server, &golanghttp2.Server{}) + if err != nil { + panic(err) + } + + if err := server.ListenAndServeTLS(crtPath, keyPath); err != nil { log.Println(err) } } func TestPayLoader_RunFastHTTP1NonSSL(t *testing.T) { - testPayLoader_Run(t, "http://localhost:8888", "fasthttp-1", func() { + testPayLoader_Run(t, "http://localhost:8888", "fasthttp", func() { testFastHTTP.Shutdown() }) } func TestPayLoader_RunFastHTTP1SSL(t *testing.T) { - testPayLoader_Run(t, "https://localhost:8889", "fasthttp-1", nil) + testPayLoader_Run(t, "https://localhost:8889", "fasthttp", nil) } -func TestPayLoader_RunNetHTTP1SSL(t *testing.T) { - testPayLoader_Run(t, "https://localhost:8889", "nethttp", nil) +func TestPayLoader_RunNetHTT21SSL(t *testing.T) { + testPayLoader_Run(t, "https://localhost:8889", "nethttp2", nil) } -func TestPayLoader_RunFastHTTP2SSL(t *testing.T) { - testPayLoader_Run(t, "https://localhost:8889", "fasthttp-2", nil) +func TestPayLoader_RunNetHTTP1SSL(t *testing.T) { + testPayLoader_Run(t, "https://localhost:8889", "nethttp", nil) } func TestPayLoader_RunNetHTTP3(t *testing.T) { - testPayLoader_Run(t, "https://localhost:8890", "nethttp-3", func() { + testPayLoader_Run(t, "https://localhost:8890", "nethttp3", func() { testServerHTTP3.Close() }) } @@ -144,13 +155,16 @@ func testPayLoader_Run(t *testing.T, addr, client string, cleanup func()) { type fields struct { config *config.Config } - tests := []struct { + + type tcase struct { name string fields fields want *GoPayloaderResults wantErr error check func(t *testing.T) - }{ + } + + tests := []tcase{ { name: "GET 10 connections for 210 requests", fields: fields{config: &config.Config{ @@ -405,6 +419,33 @@ func testPayLoader_Run(t *testing.T, addr, client string, cleanup func()) { }, } + if client == "nethttp2" { + tests = append(tests, tcase{ + name: "PARALLEL - GET 10 connections for 210 requests", + fields: fields{config: &config.Config{ + Parallel: true, + Ctx: context.Background(), + ReqURI: addr, + ReqTarget: 210, + Conns: 10, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + Method: "GET", + Client: client, + VerboseTicker: time.Second, + SkipVerify: true, + }}, + want: &GoPayloaderResults{ + CompletedReqs: 210, + FailedReqs: 0, + Responses: map[worker.ResponseCode]int64{ + 200: 210, + }, + Errors: nil, + }, + }) + } + if cleanup != nil { t.Cleanup(cleanup) } diff --git a/pkgs/payloader/worker/generate.go b/pkgs/payloader/worker/generate.go index a3002fb..bef15d2 100644 --- a/pkgs/payloader/worker/generate.go +++ b/pkgs/payloader/worker/generate.go @@ -7,13 +7,14 @@ import ( "github.com/domsolutions/gopayloader/pkgs/http-clients/nethttp" "os" "strings" + "sync" ) const ( HttpClientNetHTTP = "nethttp" - HttpClientNetHTTP3 = "nethttp-3" - HttpClientFastHTTP1 = "fasthttp-1" - HttpClientFastHTTP2 = "fasthttp-2" + HttpClientNetHTTP2 = "nethttp2" + HttpClientNetHTTP3 = "nethttp3" + HttpClientFastHTTP1 = "fasthttp" ) type TotalRequestsComplete int64 @@ -23,43 +24,37 @@ type ResponseCode int type Stats struct { CompletedReqs int64 FailedReqs int64 - Responses map[ResponseCode]int64 - Errors map[string]uint + Responses *sync.Map + Errors *sync.Map } func NewWorker(config *http_clients.Config) (Worker, error) { - client, err := getClient(config) - if err != nil { - return nil, err - } - - resp := client.NewResponse() - req, err := getReq(client, config) + client, err := http(config) if err != nil { return nil, err } if config.ReqLimitedOnly() { if config.JwtStreamReceiver != nil { - w := &WorkerFixedReqs{baseConfig(config, client, req, resp)} + w := &WorkerFixedReqs{baseConfig(config, client)} w.middleware = jwtMiddleware return w, nil } - return &WorkerFixedReqs{baseConfig(config, client, req, resp)}, nil + return &WorkerFixedReqs{baseConfig(config, client)}, nil } if config.UnlimitedReqs() { - return &WorkerFixedTime{baseConfig(config, client, req, resp)}, nil + return &WorkerFixedTime{baseConfig(config, client)}, nil } - w := &WorkerFixedTimeRequests{baseConfig(config, client, req, resp)} + w := &WorkerFixedTimeRequests{baseConfig(config, client)} if config.JwtStreamReceiver != nil { w.middleware = jwtMiddleware } return w, nil } -func getReq(client http_clients.GoPayLoaderClient, config *http_clients.Config) (http_clients.Request, error) { +func newReq(client http_clients.GoPayLoaderClient, config *http_clients.Config) (http_clients.Request, error) { req, err := client.NewReq(config.Method, config.ReqURI) if err != nil { return nil, err @@ -86,40 +81,45 @@ func getReq(client http_clients.GoPayLoaderClient, config *http_clients.Config) } req.SetBody(bb) } + return req, nil } -func jwtMiddleware(w *WorkerBase) { +func jwtMiddleware(w *WorkerBase, req http_clients.Request) { select { case jwt := <-w.config.JwtStreamReceiver: - w.req.SetHeader(w.config.JWTHeader, jwt) + req.SetHeader(w.config.JWTHeader, jwt) } } -func baseConfig(config *http_clients.Config, client http_clients.GoPayLoaderClient, req http_clients.Request, resp http_clients.Response) *WorkerBase { +func baseConfig(config *http_clients.Config, client http_clients.GoPayLoaderClient) *WorkerBase { return &WorkerBase{ - config: config, - req: req, - resp: resp, - client: client, - reqStats: config.ReqStats, + config: config, + client: client, + parallel: config.Parallel, + parallelWg: &sync.WaitGroup{}, + reqStats: config.ReqStats, + method: config.Method, + url: config.ReqURI, stats: Stats{ - Responses: make(map[ResponseCode]int64), - Errors: make(map[string]uint), + Responses: &sync.Map{}, + Errors: &sync.Map{}, }, + statsSuccessLock: &sync.Mutex{}, + statsErrorLock: &sync.Mutex{}, } } -func getClient(config *http_clients.Config) (http_clients.GoPayLoaderClient, error) { +func http(config *http_clients.Config) (http_clients.GoPayLoaderClient, error) { switch config.Client { case HttpClientNetHTTP: return nethttp.GetNetHTTPClient(config) + case HttpClientNetHTTP2: + return nethttp.GetNetHTTP2Client(config) case HttpClientNetHTTP3: return nethttp.GetNetHTTP3Client(config) case HttpClientFastHTTP1: return fasthttp.GetFastHTTPClient1(config) - case HttpClientFastHTTP2: - return fasthttp.GetFastHTTPClient2(config) } return nil, fmt.Errorf("client %s not recognised", config.Client) } diff --git a/pkgs/payloader/worker/worker-fixed-reqs.go b/pkgs/payloader/worker/worker-fixed-reqs.go index ee6f61b..3b9d735 100644 --- a/pkgs/payloader/worker/worker-fixed-reqs.go +++ b/pkgs/payloader/worker/worker-fixed-reqs.go @@ -12,6 +12,8 @@ func (w *WorkerFixedReqs) Run(wg *sync.WaitGroup) { defer wg.Done() defer w.client.CloseConns() + w.config.StartTrigger.Wait() + var i int64 for i = 0; i < w.config.ReqTarget; i++ { select { @@ -22,4 +24,8 @@ func (w *WorkerFixedReqs) Run(wg *sync.WaitGroup) { w.run() } } + + if w.parallel { + w.parallelWg.Wait() + } } diff --git a/pkgs/payloader/worker/worker-fixed-time-requests.go b/pkgs/payloader/worker/worker-fixed-time-requests.go index 602d7a1..dba9b0a 100644 --- a/pkgs/payloader/worker/worker-fixed-time-requests.go +++ b/pkgs/payloader/worker/worker-fixed-time-requests.go @@ -26,10 +26,14 @@ func (w *WorkerFixedTimeRequests) Run(wg *sync.WaitGroup) { return case <-deadline.Done(): // required reqs were not completed in time period, finish reqs - if w.config.ReqTarget != w.stats.CompletedReqs+w.stats.FailedReqs { + if w.config.ReqTarget != w.CompletedReqs.Load()+w.FailedReqs.Load() { w.run() continue } + + if w.parallel { + w.parallelWg.Wait() + } return case <-newReq.C: w.run() diff --git a/pkgs/payloader/worker/worker-fixed-time.go b/pkgs/payloader/worker/worker-fixed-time.go index 1ae05b8..b08f8e9 100644 --- a/pkgs/payloader/worker/worker-fixed-time.go +++ b/pkgs/payloader/worker/worker-fixed-time.go @@ -22,6 +22,9 @@ func (w *WorkerFixedTime) Run(wg *sync.WaitGroup) { // user cancelled return case <-ticker.C: + if w.parallel { + w.parallelWg.Wait() + } return default: w.run() diff --git a/pkgs/payloader/worker/worker.go b/pkgs/payloader/worker/worker.go index d408bd2..d52c2aa 100644 --- a/pkgs/payloader/worker/worker.go +++ b/pkgs/payloader/worker/worker.go @@ -3,6 +3,7 @@ package worker import ( http_clients "github.com/domsolutions/gopayloader/pkgs/http-clients" "sync" + "sync/atomic" "time" ) @@ -14,38 +15,63 @@ type Worker interface { } type WorkerBase struct { - config *http_clients.Config - client http_clients.GoPayLoaderClient - stats Stats - req http_clients.Request - resp http_clients.Response - middleware func(w *WorkerBase) - reqStats chan<- time.Duration + statsSuccessLock *sync.Mutex + statsErrorLock *sync.Mutex + config *http_clients.Config + client http_clients.GoPayLoaderClient + stats Stats + middleware func(w *WorkerBase, req http_clients.Request) + reqStats chan<- time.Duration + parallel bool + method string + url string + reqSize int64 + respSize int64 + parallelWg *sync.WaitGroup + CompletedReqs atomic.Int64 + FailedReqs atomic.Int64 } func (w *WorkerBase) ReqSize() int64 { - return w.req.Size() + return w.reqSize } func (w *WorkerBase) RespSize() int64 { - if w.resp == nil { - return 0 + return w.respSize +} + +func (w *WorkerBase) updateErrStats(err error) { + w.statsErrorLock.Lock() + defer w.statsErrorLock.Unlock() + + val, ok := w.stats.Errors.Load(err.Error()) + if ok { + w.stats.Errors.Store(err.Error(), val.(uint64)+1) + } else { + w.stats.Errors.Store(err.Error(), uint64(1)) } - return w.resp.Size() + + w.FailedReqs.Add(1) } func (w *WorkerBase) run() { + if w.parallel { + w.parallelWg.Add(1) + go func() { + defer w.parallelWg.Done() + + err := w.process() + if err != nil { + w.updateErrStats(err) + } + }() + return + } + err := w.process() if err != nil { - if _, ok := w.stats.Errors[err.Error()]; ok { - w.stats.Errors[err.Error()]++ - } else { - w.stats.Errors[err.Error()] = 1 - } - w.stats.FailedReqs++ - return + w.updateErrStats(err) } - w.stats.CompletedReqs++ } func (w *WorkerBase) process() error { @@ -53,36 +79,60 @@ func (w *WorkerBase) process() error { var end int64 var err error + req, err := newReq(w.client, w.config) + if err != nil { + return err + } + + resp := w.client.NewResponse() + defer func() { if err == nil { w.reqStats <- time.Duration(end - begin) - } - if w.resp != nil { // this frees up the connection to be used by other requests - w.resp.Close() + resp.Close() } }() if w.middleware != nil { - w.middleware(w) + w.middleware(w, req) } - if err = w.client.Do(w.req, w.resp); err != nil { + if err = w.client.Do(req, resp); err != nil { end = time.Now().UnixNano() return err } end = time.Now().UnixNano() - status := w.resp.StatusCode() - _, ok := w.stats.Responses[(ResponseCode(status))] + w.updateRespStats(req, resp) + return nil +} + +func (w *WorkerBase) updateRespStats(req http_clients.Request, resp http_clients.Response) { + w.statsSuccessLock.Lock() + defer w.statsSuccessLock.Unlock() + + if w.reqSize == 0 { + w.reqSize = req.Size() + } + + if w.respSize == 0 { + w.respSize = resp.Size() + } + + w.CompletedReqs.Add(1) + + val, ok := w.stats.Responses.Load(ResponseCode(resp.StatusCode())) if ok { - w.stats.Responses[(ResponseCode(status))]++ - return nil + w.stats.Responses.Store(ResponseCode(resp.StatusCode()), val.(int64)+1) + return } - w.stats.Responses[(ResponseCode(status))] = 1 - return nil + + w.stats.Responses.Store(ResponseCode(resp.StatusCode()), int64(1)) } func (w *WorkerBase) Stats() Stats { + w.stats.FailedReqs = w.FailedReqs.Load() + w.stats.CompletedReqs = w.CompletedReqs.Load() return w.stats } diff --git a/wrapper/wrapper.go b/wrapper/wrapper.go index 31a211d..eaa190b 100644 --- a/wrapper/wrapper.go +++ b/wrapper/wrapper.go @@ -15,7 +15,7 @@ import ( "github.com/domsolutions/gopayloader/pkgs/payloader" ) -func RunGoPayLoader(reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename string, headers []string, body, bodyFile string, client string) error { +func RunGoPayLoader(reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename string, headers []string, body, bodyFile string, client string, parallel bool) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -33,7 +33,7 @@ func RunGoPayLoader(reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, req method, verbose, ticker, - jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename, headers, body, bodyFile, client) + jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename, headers, body, bodyFile, client, parallel) if err := conf.Validate(); err != nil { return err }