Skip to content

Commit

Permalink
increase app access max http request size to 70MiB (#40242)
Browse files Browse the repository at this point in the history
* increase app access max http request size to 70MB

* move request body limiting into handlers

* shim original behavior with custom reader wrappers

* speed up aws handler tests
  • Loading branch information
GavinFrazar authored Apr 19, 2024
1 parent ad6797d commit 6e9bc6a
Show file tree
Hide file tree
Showing 11 changed files with 214 additions and 63 deletions.
11 changes: 10 additions & 1 deletion lib/srv/app/aws/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ type SignerHandlerConfig struct {
*awsutils.SigningService
// Clock is used to override time in tests.
Clock clockwork.Clock
// MaxHTTPRequestBodySize is the limit on how big a request body can be.
MaxHTTPRequestBodySize int64
}

// CheckAndSetDefaults validates the AwsSignerHandlerConfig.
Expand All @@ -81,6 +83,12 @@ func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error {
if cfg.Clock == nil {
cfg.Clock = clockwork.NewRealClock()
}

// Limit HTTP request body size to 70MB, which matches AWS Lambda function
// zip file upload limit (50MB) after accounting for base64 encoding bloat.
if cfg.MaxHTTPRequestBodySize == 0 {
cfg.MaxHTTPRequestBodySize = 70 << 20
}
return nil
}

Expand Down Expand Up @@ -117,6 +125,7 @@ func (s *signerHandler) formatForwardResponseError(rw http.ResponseWriter, r *ht

// ServeHTTP handles incoming requests by signing them and then forwarding them to the proper AWS API.
func (s *signerHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
req.Body = utils.MaxBytesReader(w, req.Body, s.MaxHTTPRequestBodySize)
if err := s.serveHTTP(w, req); err != nil {
s.formatForwardResponseError(w, req, err)
return
Expand Down Expand Up @@ -228,7 +237,7 @@ func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.Resolved
}
outReq.Body = http.NoBody
if r.Body != nil {
outReq.Body = io.NopCloser(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize))
outReq.Body = r.Body
}
// need to rewrite the host header as well. The oxy forwarder will do this for us,
// since we use the PassHostHeader(false) option, but if host is a signed header
Expand Down
71 changes: 70 additions & 1 deletion lib/srv/app/aws/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand All @@ -36,6 +37,7 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/lambda"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -101,6 +103,39 @@ func dynamoRequestWithTransport(url string, provider client.ConfigProvider, tran
return err
}

// dont make tests generate huge requests just to test limiting the request
// size. Use a 1MB limit instead of the actual 70MB limit.
const maxTestHTTPRequestBodySize = 1 << 20

func maxSizeExceededRequest(url string, provider client.ConfigProvider, _ string) error {
// fake an upload that's too large
payload := strings.Repeat("x", maxTestHTTPRequestBodySize)
return lambdaRequestWithPayload(url, provider, payload)
}

func lambdaRequest(url string, provider client.ConfigProvider, awsHost string) error {
// fake a zip file with 70% of the max limit. Lambda will base64 encode it,
// which bloats it up, and our proxy should still handle it.
const size = (maxTestHTTPRequestBodySize * 7) / 10
payload := strings.Repeat("x", size)
return lambdaRequestWithPayload(url, provider, payload)
}

func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payload string) error {
lambdaClient := lambda.New(provider, &aws.Config{
Endpoint: &url,
MaxRetries: aws.Int(0),
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
})
_, err := lambdaClient.UpdateFunctionCode(&lambda.UpdateFunctionCodeInput{
FunctionName: aws.String("fakeFunc"),
ZipFile: []byte(payload),
})
return err
}

func assumeRoleRequest(requestDuration time.Duration) makeRequest {
return func(url string, provider client.ConfigProvider, _ string) error {
stsClient := sts.New(provider, &aws.Config{
Expand Down Expand Up @@ -296,6 +331,37 @@ func TestAWSSignerHandler(t *testing.T) {
require.NoError,
},
},
{
name: "Lambda access",
app: consoleApp,
awsClientSession: session.Must(session.NewSession(&aws.Config{
Credentials: staticAWSCredentialsForClient,
Region: aws.String("us-east-1"),
})),
request: lambdaRequest,
wantHost: "lambda.us-east-1.amazonaws.com",
wantAuthCredKeyID: "AKIDl",
wantAuthCredService: "lambda",
wantAuthCredRegion: "us-east-1",
wantEventType: &events.AppSessionRequest{},
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
name: "Request exceeding max size",
app: consoleApp,
awsClientSession: session.Must(session.NewSession(&aws.Config{
Credentials: staticAWSCredentialsForClient,
Region: aws.String("us-east-1"),
})),
request: maxSizeExceededRequest,
errAssertionFns: []require.ErrorAssertionFunc{
// TODO(gavin): change this to [http.StatusRequestEntityTooLarge]
// after updating [trace.ErrorToCode].
hasStatusCode(http.StatusTooManyRequests),
},
},
{
name: "AssumeRole success (shorter identity duration)",
app: consoleApp,
Expand Down Expand Up @@ -353,7 +419,9 @@ func TestAWSSignerHandler(t *testing.T) {
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
fakeClock := clockwork.NewFakeClock()
mockAwsHandler := func(w http.ResponseWriter, r *http.Request) {
// check that we got what the test case expects first.
Expand Down Expand Up @@ -539,7 +607,8 @@ func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Applic
return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String())
},
},
Clock: clock,
Clock: clock,
MaxHTTPRequestBodySize: maxTestHTTPRequestBodySize,
})
require.NoError(t, err)
mux := http.NewServeMux()
Expand Down
3 changes: 3 additions & 0 deletions lib/srv/app/azure/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ func newAzureHandler(ctx context.Context, config HandlerConfig) (*handler, error

// RoundTrip handles incoming requests and forwards them to the proper API.
func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Body != nil {
req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize)
}
if err := s.serveHTTP(w, req); err != nil {
s.formatForwardResponseError(w, req, err)
return
Expand Down
3 changes: 3 additions & 0 deletions lib/srv/app/gcp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ func newGCPHandler(ctx context.Context, config HandlerConfig) (*handler, error)

// RoundTrip handles incoming requests and forwards them to the proper API.
func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Body != nil {
req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize)
}
if err := s.serveHTTP(w, req); err != nil {
s.formatForwardResponseError(w, req, err)
return
Expand Down
50 changes: 32 additions & 18 deletions lib/srv/db/clickhouse/engine_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/andybalholm/brotli"
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/utils"
Expand All @@ -59,31 +60,44 @@ func (e *Engine) handleHTTPConnection(ctx context.Context, sessionCtx *common.Se
if err != nil {
return trace.Wrap(err)
}
query, err := getQuery(req)
if err != nil {
if err := e.handleRequest(req, sessionCtx, tr); err != nil {
return trace.Wrap(err)
}
}
}

queryEvent := common.Query{
Query: query,
Parameters: []string{fmt.Sprintf("url=%s", req.URL.String())},
}
func (e *Engine) handleRequest(req *http.Request, sessionCtx *common.Session, tr *http.Transport) error {
if req.Body != nil {
// we have to close the request body since [http.Server] didn't serve it
// up for us.
defer req.Body.Close()
req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize))
}
query, err := getQuery(req)
if err != nil {
return trace.Wrap(err)
}

e.Audit.OnQuery(e.Context, sessionCtx, queryEvent)
queryEvent := common.Query{
Query: query,
Parameters: []string{fmt.Sprintf("url=%s", req.URL.String())},
}

if err := e.handleRequest(req, sessionCtx); err != nil {
return trace.Wrap(err)
}
e.Audit.OnQuery(e.Context, sessionCtx, queryEvent)

resp, err := tr.RoundTrip(req)
if err != nil {
return trace.Wrap(err)
}
if err := e.rewriteRequest(req, sessionCtx); err != nil {
return trace.Wrap(err)
}

if err := e.writeResp(resp); err != nil {
return trace.Wrap(err)
}
resp, err := tr.RoundTrip(req)
if err != nil {
return trace.Wrap(err)
}

if err := e.writeResp(resp); err != nil {
return trace.Wrap(err)
}
return nil
}

func handleCompression(body []byte, compression string) ([]byte, error) {
Expand Down Expand Up @@ -155,7 +169,7 @@ func (e *Engine) writeResp(resp *http.Response) error {
return nil
}

func (e *Engine) handleRequest(req *http.Request, sessionCtx *common.Session) error {
func (e *Engine) rewriteRequest(req *http.Request, sessionCtx *common.Session) error {
uri, err := url.Parse(sessionCtx.Database.GetURI())
if err != nil {
return trace.Wrap(err)
Expand Down
2 changes: 2 additions & 0 deletions lib/srv/db/dynamodb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/gravitational/trace"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
apiaws "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/cloud"
Expand Down Expand Up @@ -180,6 +181,7 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws.
if req.Body != nil {
// make sure we close the incoming request's body. ignore any close error.
defer req.Body.Close()
req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize))
}

re, err := e.resolveEndpoint(req)
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/db/elasticsearch/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/gravitational/trace"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -156,6 +157,11 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
func (e *Engine) process(ctx context.Context, sessionCtx *common.Session, req *http.Request, client *http.Client, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) error {
msgFromClient.Inc()

if req.Body != nil {
// make sure we close the incoming request's body. ignore any close error.
defer req.Body.Close()
req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize))
}
payload, err := utils.GetAndReplaceRequestBody(req)
if err != nil {
return trace.Wrap(err)
Expand Down
11 changes: 10 additions & 1 deletion lib/srv/db/opensearch/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/gravitational/trace"
"github.com/prometheus/client_golang/prometheus"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/cloud"
Expand Down Expand Up @@ -190,6 +191,11 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error
func (e *Engine) process(ctx context.Context, tr *http.Transport, signer *libaws.SigningService, req *http.Request, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) error {
msgFromClient.Inc()

if req.Body != nil {
// make sure we close the incoming request's body. ignore any close error.
defer req.Body.Close()
req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize))
}
reqCopy, payload, err := e.rewriteRequest(ctx, req)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -332,11 +338,14 @@ func (e *Engine) emitAuditEvent(req *http.Request, body []byte, statusCode uint3

// sendResponse sends the response back to the OpenSearch client.
func (e *Engine) sendResponse(serverResponse *http.Response) error {
if serverResponse.Body != nil {
defer serverResponse.Body.Close()
serverResponse.Body = io.NopCloser(io.LimitReader(serverResponse.Body, teleport.MaxHTTPResponseSize))
}
payload, err := utils.GetAndReplaceResponseBody(serverResponse)
if err != nil {
return trace.Wrap(err)
}

// serverResponse may be HTTP2 response, but we should reply with HTTP 1.1
clientResponse := &http.Response{
ProtoMajor: 1,
Expand Down
10 changes: 5 additions & 5 deletions lib/srv/db/snowflake/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,24 @@ func readRequestBody(req *http.Request) ([]byte, error) {
return nil, trace.Wrap(err)
}

return maybeReadGzip(&req.Header, body)
return maybeReadGzip(&req.Header, body, teleport.MaxHTTPRequestSize)
}

func readResponseBody(resp *http.Response) ([]byte, error) {
defer resp.Body.Close()

body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPRequestSize)
body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)
if err != nil {
return nil, trace.Wrap(err)
}

return maybeReadGzip(&resp.Header, body)
return maybeReadGzip(&resp.Header, body, teleport.MaxHTTPResponseSize)
}

// maybeReadGzip checks if the body is gzip encoded and returns decoded version.
// To determine gzip encoding the beginning of body message is being checked
// instead of HTTP header and the second one was less reliable during testing.
func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) {
func maybeReadGzip(headers *http.Header, body []byte, limit int64) ([]byte, error) {
gzipMagic := []byte{0x1f, 0x8b, 0x08}

// Check if the body is gzip encoded. Alternative here could check
Expand All @@ -108,7 +108,7 @@ func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) {
}
defer bodyGZ.Close()

body, err = utils.ReadAtMost(bodyGZ, teleport.MaxHTTPRequestSize)
body, err = utils.ReadAtMost(bodyGZ, limit)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
Loading

0 comments on commit 6e9bc6a

Please sign in to comment.