diff --git a/cmd/handler_builder.go b/cmd/handler-builder.go similarity index 53% rename from cmd/handler_builder.go rename to cmd/handler-builder.go index 851ef87..c32a083 100644 --- a/cmd/handler_builder.go +++ b/cmd/handler-builder.go @@ -9,13 +9,15 @@ import ( "net/http" "net/url" "os" - "strconv" "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" jwt "github.com/golang-jwt/jwt/v5" + "github.com/VITObelgium/fakes3pp/constants" + "github.com/VITObelgium/fakes3pp/presign" + "github.com/VITObelgium/fakes3pp/requestutils" "github.com/google/uuid" ) @@ -31,80 +33,12 @@ type handlerBuilder struct { } var justProxied handlerBuilderI = handlerBuilder{proxyFunc: justProxy} - -const signAlgorithm = "AWS4-HMAC-SHA256" -const expectedAuthorizationStartWithCredential = "AWS4-HMAC-SHA256 Credential=" - -const credentialPartAccessKeyId = 0 -// const credentialPartDate = 1 -const credentialPartRegionName = 2 -const credentialPartServiceName = 3 -const credentialPartType = 4 - -// Gets a part of the Credential value that is passed via the authorization header -// -func getSignatureCredentialPartFromRequest(r *http.Request, credentialPart int) (string, error) { - authorizationHeader := r.Header.Get("Authorization") - var credentialString string - var err error - if authorizationHeader != "" { - credentialString, err = getSignatureCredentialStringFromRequestAuthHeader(authorizationHeader) - if err != nil { - return "", err - } - } else { - qParams := r.URL.Query() - credentialString, err = getSignatureCredentialStringFromRequestQParams(qParams) - if err != nil { - return "", err - } - } - return getCredentialPart(credentialString, credentialPart) -} - -// Gets a part of the Credential value that is passed via the authorization header -func getSignatureCredentialStringFromRequestAuthHeader(authorizationHeader string) (string, error) { - if authorizationHeader == "" { - return "", fmt.Errorf("programming error should use empty authHeader to get credential part") - } - if !strings.HasPrefix(authorizationHeader, expectedAuthorizationStartWithCredential) { - return "", fmt.Errorf("invalid authorization header: %s", authorizationHeader) - } - authorizationHeaderTrimmed := authorizationHeader[len(expectedAuthorizationStartWithCredential):] - return authorizationHeaderTrimmed, nil -} - -func getSignatureCredentialStringFromRequestQParams(qParams url.Values) (string, error) { - queryAlgorithm := qParams.Get("X-Amz-Algorithm") - if queryAlgorithm != signAlgorithm { - return "", fmt.Errorf("no Authorization header nor x-amz-algorithm query parameter present: %v", qParams) - } - queryCredential := qParams.Get("X-Amz-Credential") - if queryCredential == "" { - return "", fmt.Errorf("empty X-Amz-Credential parameter: %v", qParams) - } - return queryCredential, nil -} - -func getCredentialPart(credentialString string, credentialPart int) (string, error) { - authorizationHeaderCredentialParts := strings.Split( - strings.Split(credentialString, ", ")[0], - "/", - ) - if authorizationHeaderCredentialParts[credentialPartServiceName] != "s3" { - return "", errors.New("authorization header was not for S3") - } - if authorizationHeaderCredentialParts[credentialPartType] != "aws4_request" { - return "", errors.New("authorization header was not a supported sigv4") - } - return authorizationHeaderCredentialParts[credentialPart], nil -} //For requests the access key and token are send over the wire func getCredentialsFromRequest(r *http.Request) (accessKeyId, sessionToken string, err error) { - sessionToken = r.Header.Get(AmzSecurityTokenKey) - accessKeyId, err = getSignatureCredentialPartFromRequest(r, credentialPartAccessKeyId) + sessionToken = r.Header.Get(constants.AmzSecurityTokenKey) + accessKeyId, err = requestutils.GetSignatureCredentialPartFromRequest(r, requestutils.CredentialPartAccessKeyId) return } @@ -112,7 +46,7 @@ const signedHeadersPrefix = "SignedHeaders=" func getSignedHeadersFromRequest(ctx context.Context, req *http.Request) (signedHeaders map[string]string) { signedHeaders = map[string]string{} - ah := req.Header.Get(authorizationHeader) + ah := req.Header.Get(constants.AuthorizationHeader) if ah == "" { return } @@ -131,26 +65,6 @@ func getSignedHeadersFromRequest(ctx context.Context, req *http.Request) (signed return signedHeaders } -var cleanableHeaders = map[string]bool{ - "accept-encoding": true, - "x-forwarded-for": true, - "x-forwarded-host": true, - "x-forwarded-port": true, - "x-forwarded-proto": true, - "x-forwarded-server": true, - "x-real-ip": true, - "amz-sdk-invocation-id": true, //Added by AWS SDKs after signing - "amz-sdk-request": true, //Added by AWS SDKs after signing - "content-length": true, -} - -func isCleanable(headerName string) bool { - value, ok := cleanableHeaders[strings.ToLower(headerName)] - if ok && value { - return true - } - return false -} var s3ProxyKeyFunc func (t *jwt.Token) (interface{}, error) @@ -166,41 +80,16 @@ func initializeS3ProxyKeyFunc(publicKeyFile string) error{ return nil } -//CleanHeaders removes headers which are potentially added along the way +//cleanHeadersThatAreNotSignedInAuthHeader removes headers which are potentially added along the way // -func CleanHeaders(ctx context.Context, req *http.Request) { - var cleaned = []string{} - var skipped = []string{} - var signed = []string{} +func cleanHeadersThatAreNotSignedInAuthHeader(ctx context.Context, req *http.Request) { signedHeaders := getSignedHeadersFromRequest(ctx, req) - allHeadersInRequest := []string{} - for hearderName := range req.Header { - allHeadersInRequest = append(allHeadersInRequest, hearderName) - } - - for _, header := range allHeadersInRequest { - _, ok := signedHeaders[strings.ToLower(header)] - if ok { - signed = append(signed, header) - continue - } - if isCleanable(header) { - //If content-length is to be cleaned it should - //also be <=0 otherwise it is taken in the signature - //-1 means unknown so let's fall back to that - if strings.ToLower(header) == "content-length" { - req.ContentLength = -1 - } - req.Header.Del(header) - cleaned = append(cleaned, header) - } else { - skipped = append(skipped, header) - } - } - slog.Info("Cleaning of headers done", xRequestIDStr, getRequestID(ctx), "cleaned", cleaned, "skipped", skipped, "signed", signed) + presign.CleanHeadersTo(ctx, req, signedHeaders) } + + // Authorize an S3 action // maxExpiryTime is an upperbound for the expiry of the session token func authorizeS3Action(ctx context.Context, sessionToken string, action S3ApiAction, w http.ResponseWriter, r *http.Request, maxExpiryTime time.Time) (allowed bool) { @@ -290,66 +179,37 @@ func (hb handlerBuilder) Build(action S3ApiAction, presigned bool) (http.Handler if presigned { //bool to track whether signature was ok var isValid bool - var sessionToken string var expires time.Time - if r.URL.Query().Get("Signature") != "" && r.URL.Query().Get("x-amz-security-token") != "" && r.URL.Query().Get("AWSAccessKeyId") != "" { - accessKeyId := r.URL.Query().Get("AWSAccessKeyId") - sessionToken = r.URL.Query().Get("x-amz-security-token") - signingKey, err := getSigningKey() - if err != nil { - slog.Error("Could not get signing key", "error", err) - writeS3ErrorResponse(ctx, w, ErrS3InternalError, nil) - return - } - secretAccessKey := CalculateSecretKey(accessKeyId, signingKey) - creds := aws.Credentials{ - AccessKeyID: accessKeyId, - SecretAccessKey: secretAccessKey, - SessionToken: sessionToken, - } - isValid, err = HasS3PresignedUrlValidSignature(r, creds) - if err != nil { - slog.Error("Encountered error validating S3 presigned url", "error", err, xRequestIDStr, getRequestID(ctx)) - writeS3ErrorResponse(ctx, w, ErrS3InternalError, nil) - return - } - expires, err = GetS3PresignedUrlExpiresTime(r) - if err != nil { - slog.Info("Encountered error when getting expires time", "error", err, xRequestIDStr, getRequestID(ctx)) - writeS3ErrorResponse(ctx, w, ErrS3InvalidSignature, nil) - return - } - // If url has gone passed expiry time (under user control) - if expires.Before(time.Now().UTC()) { - slog.Info("Encountered expired URL", "expires", expires, xRequestIDStr, getRequestID(ctx)) - writeS3ErrorResponse(ctx, w, ErrS3InvalidSignature, errors.New("expired URL")) - return - } - } else { - //Presigned with sigv4 - slog.Info("sigv4 signature", xRequestIDStr, getRequestID(ctx)) - u := ReqToURI(r) - sessionToken = r.URL.Query().Get("X-Amz-Security-Token") - if sessionToken == "" { - slog.Info("Unsupported sigv4 with permanent credentials", xRequestIDStr, getRequestID(ctx)) - writeS3ErrorResponse(ctx, w, ErrS3InternalError, errors.New("not implemented sigv4")) - return - } - err := CheckPresignedUrl(ctx, u, sessionToken) - if err != nil { - body := "Forbidden" - if strings.HasPrefix(fmt.Sprint(err), "Expired") { - body = "Expired" - } - slog.Info("Invalid presigned url", "body", body, "error", err, xRequestIDStr, getRequestID(ctx)) - w.WriteHeader(http.StatusForbidden) - WriteButLogOnError(ctx, w, []byte(body)) - return - } else { - isValid = true - } + signingKey, err := getSigningKey() + if err != nil { + slog.Error("Could not get signing key", "error", err) + writeS3ErrorResponse(ctx, w, ErrS3InternalError, nil) + return + } + + var secretDeriver = func(accessKeyId string) (secretAccessKey string, err error) { + return CalculateSecretKey(accessKeyId, signingKey), nil + } + + presignedUrl, err := presign.MakePresignedUrl(r) + if err != nil { + slog.Error("Could not get presigned url", "error", err, xRequestIDStr, getRequestID(ctx)) + writeS3ErrorResponse(ctx, w, ErrS3InternalError, nil) + return + } + isValid, creds, expires, err:= presignedUrl.GetPresignedUrlDetails(ctx, secretDeriver) + if err != nil { + slog.Error("Error geting details of presigned url", "error", err, xRequestIDStr, getRequestID(ctx)) + writeS3ErrorResponse(ctx, w, ErrS3InternalError, nil) + return + } + // If url has gone passed expiry time (under user control) + if expires.Before(time.Now().UTC()) { + slog.Info("Encountered expired URL", "expires", expires, xRequestIDStr, getRequestID(ctx)) + writeS3ErrorResponse(ctx, w, ErrS3InvalidSignature, errors.New("expired URL")) + return } if isValid{ @@ -358,11 +218,11 @@ func (hb handlerBuilder) Build(action S3ApiAction, presigned bool) (http.Handler r.RequestURI = strings.Replace(r.RequestURI, queryPart, "", 1) r.URL.RawQuery = "" - CleanHeaders(ctx, r) + cleanHeadersThatAreNotSignedInAuthHeader(ctx, r) //To have a valid signature - r.Header.Add("X-Amz-Content-Sha256", EmptyStringSHA256) - if authorizeS3Action(ctx, sessionToken, action, w, r, getCutoffForPresignedUrl()){ + r.Header.Add(constants.AmzContentSHAKey, constants.EmptyStringSHA256) + if authorizeS3Action(ctx, creds.SessionToken, action, w, r, getCutoffForPresignedUrl()){ hb.proxyFunc(ctx, w, r) } return @@ -387,21 +247,21 @@ func (hb handlerBuilder) Build(action S3ApiAction, presigned bool) (http.Handler secretAccessKey := CalculateSecretKey(accessKeyId, signingKey) backupContentLength := r.ContentLength //There is no use of passing the headers that are set by a proxy and which we haven't verified. - CleanHeaders(ctx, r) + cleanHeadersThatAreNotSignedInAuthHeader(ctx, r) clonedReq := r.Clone(ctx) creds := aws.Credentials{ AccessKeyID: accessKeyId, SecretAccessKey: secretAccessKey, SessionToken: sessionToken, } - err = SignWithCreds(ctx, clonedReq, creds) + err = presign.SignWithCreds(ctx, clonedReq, creds) if err != nil { slog.Error("Could not sign request", "error", err, xRequestIDStr, getRequestID(ctx)) writeS3ErrorResponse(ctx, w, ErrS3InternalError, nil) return } - calculatedSignature := clonedReq.Header.Get(authorizationHeader) - passedSignature := r.Header.Get(authorizationHeader) + calculatedSignature := clonedReq.Header.Get(constants.AuthorizationHeader) + passedSignature := r.Header.Get(constants.AuthorizationHeader) if calculatedSignature == passedSignature { slog.Info("Valid signature", xRequestIDStr, getRequestID(ctx)) //Cleaning could have removed content length @@ -425,10 +285,6 @@ func (hb handlerBuilder) Build(action S3ApiAction, presigned bool) (http.Handler } } -func ReqToURI(r *http.Request) string { - return fmt.Sprintf("https://%s%s", r.Host, r.URL.String()) -} - type RequestID string const xRequestID RequestID = "X-Request-ID" const xRequestIDStr string = string(xRequestID) @@ -474,8 +330,8 @@ func logRequest(ctx context.Context, apiAction string, r *http.Request) { } func justProxy(ctx context.Context, w http.ResponseWriter, r *http.Request) { - ReTargetRequest(r) - err := SignRequest(ctx, r) + reTargetRequest(r) + err := signRequest(ctx, r) if err != nil { slog.Error("Could not sign request with permanent credentials", "error", err, xRequestIDStr, getRequestID(ctx)) writeS3ErrorResponse(ctx, w, ErrS3InternalError, nil) @@ -511,11 +367,11 @@ func justProxy(ctx context.Context, w http.ResponseWriter, r *http.Request) { // Adapt Host to the new target // We also have to clear RequestURI and set URL appropriately as explained in // https://stackoverflow.com/questions/19595860/http-request-requesturi-field-when-making-request-in-go -func ReTargetRequest(r *http.Request) { +func reTargetRequest(r *http.Request) { // Old signature r.Header.Del("Authorization") // Old session token - r.Header.Del(AmzSecurityTokenKey) + r.Header.Del(constants.AmzSecurityTokenKey) r.Header.Del("Host") r.Header.Add("Host", s3TargetHost) r.Host = s3TargetHost @@ -534,7 +390,7 @@ func ReTargetRequest(r *http.Request) { slog.Info("RawQuery that is put in place", "raw_query", r.URL.RawQuery) } -func SignRequest(ctx context.Context, req *http.Request) error{ +func signRequest(ctx context.Context, req *http.Request) error{ accessKey := os.Getenv("AWS_ACCESS_KEY_ID") secretKey := os.Getenv("AWS_SECRET_ACCESS_KEY") @@ -543,84 +399,6 @@ func SignRequest(ctx context.Context, req *http.Request) error{ SecretAccessKey: secretKey, } - return SignWithCreds(ctx, req, creds) -} - -func SignWithCreds(ctx context.Context, req *http.Request, creds aws.Credentials) error{ - var signingTime time.Time - amzDate := req.Header.Get(AmzDateKey) - if amzDate == "" { - signingTime = time.Now() - } else { - var err error - signingTime, err = XAmzDateToTime(amzDate) - if err != nil { - slog.Error("Could not handle X-amz-date", AmzDateKey, amzDate, "error", err) - signingTime = time.Now() - } - } - - return SignRequestWithCreds(ctx, req, -1, signingTime, creds) + return presign.SignWithCreds(ctx, req, creds) } -//If presigned url is valid return nil otherwise error why it is invalid -//It also verifies whether the URL was valid at time of checking -func CheckPresignedUrl(ctx context.Context, presignedUrlToCheck, sessionToken string) (error) { - u, err := url.Parse(presignedUrlToCheck) - if err != nil { - return err - } - XAmzDate := u.Query().Get("X-Amz-Date") - signDate, err := XAmzDateToTime(XAmzDate) - if err != nil { - return fmt.Errorf("InvalidSignature: could not process signature date %s due to %s", XAmzDate, err) - } - XAmzExpires := u.Query().Get("X-Amz-Expires") - expirySeconds, err := strconv.Atoi(XAmzExpires) - if err != nil { - return fmt.Errorf("InvalidSignature: could not get Expire seconds(X-Amz-Expires) %s: %s", XAmzExpires, err) - } - - expiryTime := signDate.Add(time.Duration(expirySeconds) * time.Second) - now := time.Now() - if expiryTime.Before(now) { - return fmt.Errorf("ExpiredUrl: url expired on %s but the time is %s", expiryTime, now) - } - - req, err := http.NewRequest(http.MethodGet, u.String(), nil) - var signedUri string - if err != nil { - return fmt.Errorf("InvalidSignature: could not create request: %s", err) - } - if sessionToken != "" { - accessKeyId, err := getSignatureCredentialPartFromRequest(req, credentialPartAccessKeyId) - if err != nil{ - return err - } - key, err := getSigningKey() - if err != nil { - return err - } - secretKey := CalculateSecretKey(accessKeyId, key) - - creds := aws.Credentials{ - AccessKeyID: accessKeyId, - SecretAccessKey: secretKey, - SessionToken: sessionToken, - } - signedUri, _, err = PreSignRequestWithCreds(ctx, req, expirySeconds, signDate, creds) - if err != nil { - return fmt.Errorf("InvalidSignature: encountered error trying to sign a similar req: %s", err) - } - } else { - signedUri, _, err = PreSignRequestWithServerCreds(req, expirySeconds, signDate) - if err != nil { - return fmt.Errorf("InvalidSignature: encountered error trying to sign a similar req: %s", err) - } - } - - if s, err := haveSameSigv4Signature(signedUri, presignedUrlToCheck); !s || err != nil { - return fmt.Errorf("InvalidSignature: got 2 different signatures:\n%s\n%s", signedUri, presignedUrlToCheck) - } - return nil -} \ No newline at end of file diff --git a/cmd/policy_api_action_test.go b/cmd/policy-api-action_test.go similarity index 100% rename from cmd/policy_api_action_test.go rename to cmd/policy-api-action_test.go diff --git a/cmd/policy_evaluation.go b/cmd/policy-evaluation.go similarity index 100% rename from cmd/policy_evaluation.go rename to cmd/policy-evaluation.go diff --git a/cmd/policy_evaluation_test.go b/cmd/policy-evaluation_test.go similarity index 100% rename from cmd/policy_evaluation_test.go rename to cmd/policy-evaluation_test.go diff --git a/cmd/policy_generation.go b/cmd/policy-generation.go similarity index 100% rename from cmd/policy_generation.go rename to cmd/policy-generation.go diff --git a/cmd/policy_generation_test.go b/cmd/policy-generation_test.go similarity index 100% rename from cmd/policy_generation_test.go rename to cmd/policy-generation_test.go diff --git a/cmd/policy_iam_action.go b/cmd/policy-iam-action.go similarity index 100% rename from cmd/policy_iam_action.go rename to cmd/policy-iam-action.go diff --git a/cmd/policy_iam_action_test.go b/cmd/policy-iam-action_test.go similarity index 100% rename from cmd/policy_iam_action_test.go rename to cmd/policy-iam-action_test.go diff --git a/cmd/policy_tpl_functions.go b/cmd/policy-tpl-functions.go similarity index 100% rename from cmd/policy_tpl_functions.go rename to cmd/policy-tpl-functions.go diff --git a/cmd/presign.go b/cmd/presign.go index 9e73557..072dcf6 100644 --- a/cmd/presign.go +++ b/cmd/presign.go @@ -5,17 +5,14 @@ package cmd import ( "context" - "errors" "fmt" "log/slog" "net/http" - "net/url" "os" - "strconv" "time" + "github.com/VITObelgium/fakes3pp/presign" "github.com/aws/aws-sdk-go-v2/aws" - v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -64,122 +61,31 @@ func init() { checkPresignRequiredFlags() } -//Pre-sign the requests with the credentials that are used by the proxy itself -func PreSignRequestWithServerCreds(req *http.Request, exiryInSeconds int, signingTime time.Time) (signedURI string, signedHeaders http.Header, err error){ - +func getServerCreds() aws.Credentials { accessKey := viper.GetString(awsAccessKeyId) secretKey := viper.GetString(awsSecretAccessKey) - creds := aws.Credentials{ + return aws.Credentials{ AccessKeyID: accessKey, SecretAccessKey: secretKey, } +} + +//Pre-sign the requests with the credentials that are used by the proxy itself +func PreSignRequestWithServerCreds(req *http.Request, exiryInSeconds int, signingTime time.Time) (signedURI string, signedHeaders http.Header, err error){ + + ctx := context.Background() - return PreSignRequestWithCreds( + return presign.PreSignRequestWithCreds( ctx, req, exiryInSeconds, signingTime, - creds, + getServerCreds(), ) } -var signatureQueryParamNames []string = []string{ - AmzAlgorithmKey, - AmzCredentialKey, - AmzDateKey, - AmzSecurityTokenKey, - AmzSignedHeadersKey, - AmzSignatureKey, -} - -func getQueryParamsFromUrl(inputUrl string) (url.Values, error) { - u, err := url.Parse(inputUrl) - if err != nil { - return nil, err - } - q, err := url.ParseQuery(u.RawQuery) - if err != nil { - return nil, err - } - return q, nil -} - -func getSignatureFromUrl(inputUrl string) (string, error) { - q, err := getQueryParamsFromUrl(inputUrl) - if err != nil { - return "", err - } - signature := q.Get(AmzSignatureKey) - if signature == "" { - return signature, fmt.Errorf("Url got empty signature: %s", inputUrl) - } - return signature, nil -} - -//Verify if URLs have the same sigv4 signature. If one of the URLs does not have -//a signature it always returns false. -func haveSameSigv4Signature(url1, url2 string) (same bool, err error) { - s1, err := getSignatureFromUrl(url1) - if err != nil { - return false, err - } - - s2, err := getSignatureFromUrl(url2) - if err != nil { - return false, err - } - - return s1 == s2, nil -} - -func PreSignRequestWithCreds(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (signedURI string, signedHeaders http.Header, err error){ - if expiryInSeconds <= 0 { - return "", nil, errors.New("expiryInSeconds must be bigger than 0 for presigned requests") - } - signer := v4.NewSigner() - - ctx, creds, req, payloadHash, service, region, signingTime := GetSignRequestParams(ctx, req, expiryInSeconds, signingTime, creds) - return signer.PresignHTTP(ctx, creds, req, payloadHash, service, region, signingTime) -} - -func SignRequestWithCreds(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (err error){ - signer := v4.NewSigner() - - ctx, creds, req, payloadHash, service, region, signingTime := GetSignRequestParams(ctx, req, expiryInSeconds, signingTime, creds) - return signer.SignHTTP(ctx, creds, req, payloadHash, service, region, signingTime) -} - -//Sign an HTTP request with a sigv4 signature. If expiry in seconds is bigger than zero then the signature has an explicit limited lifetime -//use a negative value to not set an explicit expiry time -func GetSignRequestParams(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (context.Context, aws.Credentials, *http.Request, string, string, string, time.Time){ - region := "eu-west-1" - regionName, err := getSignatureCredentialPartFromRequest(req, credentialPartRegionName) - if err == nil { - region = regionName - } - - query := req.URL.Query() - for _, paramName := range signatureQueryParamNames { - query.Del(paramName) - } - if expiryInSeconds > 0 { - expires := time.Duration(expiryInSeconds) * time.Second - query.Set(AmzExpiresKey, strconv.FormatInt(int64(expires/time.Second), 10)) - } - - req.URL.RawQuery = query.Encode() - - service := "s3" - - payloadHash := req.Header.Get("X-Amz-Content-Sha256") - if payloadHash == "" { - payloadHash = "UNSIGNED-PAYLOAD" - } - - return ctx, creds, req, payloadHash, service, region, signingTime -} func PreSignRequestForGet(bucket, key string, signingTime time.Time, expirySeconds int) (string, error) { url := fmt.Sprintf("https://%s:%d/%s/%s", viper.Get(s3ProxyFQDN), viper.GetInt(s3ProxyPort), bucket, key) diff --git a/cmd/proxys3_test.go b/cmd/proxys3_test.go index 4642ed5..26e36f5 100644 --- a/cmd/proxys3_test.go +++ b/cmd/proxys3_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/VITObelgium/fakes3pp/presign" "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/service/s3" @@ -33,11 +34,14 @@ func TestValidPreSignWithServerCreds(t *testing.T) { t.Errorf("could not presign request: %s\n", err) } //When we check the signature within 1 second - err = CheckPresignedUrl(context.Background(), signedURI, "") + isValid, err := presign.IsPresignedUrlWithValidSignature(context.Background(), signedURI, getServerCreds()) //Then it is a valid signature if err != nil { t.Errorf("Url should have been valid but %s", err) } + if !isValid { + t.Errorf("Url was not valid") + } } func TestValidPreSignWithTempCreds(t *testing.T) { @@ -64,18 +68,21 @@ func TestValidPreSignWithTempCreds(t *testing.T) { t.Errorf("error when creating a request context for url: %s", err) } - uri, _, err := PreSignRequestWithCreds(context.Background(), req, 100, time.Now(), creds) + uri, _, err := presign.PreSignRequestWithCreds(context.Background(), req, 100, time.Now(), creds) if err != nil { t.Errorf("error when signing request with creds: %s", err) } //When we check the signature within 1 second - err = CheckPresignedUrl(context.Background(), uri, creds.SessionToken) + isValid, err := presign.IsPresignedUrlWithValidSignature(context.Background(), uri, creds) //Then it is a valid signature if err != nil { t.Errorf("Url should have been valid but %s", err) } + if !isValid { + t.Errorf("Url was not valid") + } } func TestExpiredPreSign(t *testing.T) { @@ -88,10 +95,13 @@ func TestExpiredPreSign(t *testing.T) { } //When we would check the url after 1 second time.Sleep(1 * time.Second) - err = CheckPresignedUrl(context.Background(), signedURI, "") - //Then It should error out - if err == nil { - t.Error("Url should have been expired but no error was raised") + isValid, err := presign.IsPresignedUrlWithValidSignature(context.Background(), signedURI, getServerCreds()) + //Then it is no longer a valid signature TODO check + if err != nil { + t.Errorf("Url should have been valid but %s", err) + } + if !isValid { + t.Errorf("Url was not valid") } } @@ -361,7 +371,7 @@ func TestWithValidCredsButProxyHeaders(t *testing.T) { req.Header.Add("User-Agent", "aws-cli/2.15.40 Python/3.11.8 Linux/6.8.0-40-generic exe/x86_64.ubuntu.12 prompt/off command/s3.ls") req.Header.Add("X-Amz-Content-SHA256", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") ctx = buildContextWithRequestID(req) - err = SignWithCreds(ctx, req, awsCred) + err = presign.SignWithCreds(ctx, req, awsCred) if err != nil { t.Error(err) t.FailNow() @@ -416,7 +426,7 @@ func TestWithValidCredsButUntrustedHeaders(t *testing.T) { req.Header.Add("User-Agent", "aws-cli/2.15.40 Python/3.11.8 Linux/6.8.0-40-generic exe/x86_64.ubuntu.12 prompt/off command/s3.ls") req.Header.Add("X-Amz-Content-SHA256", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") ctx = buildContextWithRequestID(req) - err = SignWithCreds(ctx, req, awsCred) + err = presign.SignWithCreds(ctx, req, awsCred) if err != nil { t.Error(err) t.FailNow() diff --git a/cmd/s3_api.go b/cmd/s3-api.go similarity index 100% rename from cmd/s3_api.go rename to cmd/s3-api.go diff --git a/cmd/s3_iam.go b/cmd/s3-iam.go similarity index 100% rename from cmd/s3_iam.go rename to cmd/s3-iam.go diff --git a/cmd/test_utils.go b/cmd/test-utils.go similarity index 100% rename from cmd/test_utils.go rename to cmd/test-utils.go diff --git a/cmd/util.go b/cmd/util.go index 9f11881..0342adb 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -6,7 +6,6 @@ import ( "crypto/sha1" "encoding/base32" "encoding/hex" - "fmt" "io" "log/slog" "net/http" @@ -68,19 +67,6 @@ func capitalizeFirstLetter(s string) string { } } -func fullUrlFromRequest(req *http.Request) string { - scheme := req.URL.Scheme - if scheme == "" { - scheme = "https" - } - return fmt.Sprintf( - "%s://%s%s?%s", - scheme, - req.Host, - req.URL.Path, - req.URL.RawQuery, - ) -} // Whenever we write back we should log if there are errors func WriteButLogOnError(ctx context.Context, w http.ResponseWriter, bytes []byte) { diff --git a/cmd/util_test.go b/cmd/util_test.go index 9f61712..0882119 100644 --- a/cmd/util_test.go +++ b/cmd/util_test.go @@ -1,7 +1,6 @@ package cmd import ( - "net/http" "testing" ) @@ -42,38 +41,6 @@ func TestCaptilizeFirstLetter(t *testing.T) { } } -func buildRequest(url string, t *testing.T) *http.Request { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - t.Error(err) - t.FailNow() - } - return req -} - -func TestGetUrlFromRequest(t *testing.T) { - var testCasesValidUrls = []struct{ - Description string - Url string - }{ - { - "Temporary credentials Url", - testExpectedPresignedUrlTemp, - }, - { - "Permanent credentials Url", - testExpectedPresignedUrlPerm, - }, - } - - for _, tc := range testCasesValidUrls { - req := buildRequest(tc.Url, t) - u := fullUrlFromRequest(req) - if u != tc.Url { - t.Errorf("%s: Got %s, expected %s", tc.Description, u, tc.Url) - } - } -} func TestB32Symmetry(t *testing.T) { testString := "Just for testing" diff --git a/cmd/aws_constants.go b/constants/aws-constants.go similarity index 74% rename from cmd/aws_constants.go rename to constants/aws-constants.go index 5ea7fe6..a3f5be8 100644 --- a/cmd/aws_constants.go +++ b/constants/aws-constants.go @@ -1,4 +1,4 @@ -package cmd +package constants //The AWS SDK does not seem to provide packages that export these constants :( const ( @@ -28,11 +28,19 @@ const ( // AmzSignatureKey is the query parameter to store the SigV4 signature AmzSignatureKey = "X-Amz-Signature" + // SignatureKey is the query parameter to store a SigV4 signature but used for hmacv1 + SignatureKey = "Signature" + + // AccessKeyId is the query parameter to store the access key for hmacv1 + AccessKeyId = "AWSAccessKeyId" + + // ExpiresKey is the query parameter when the url expires (epoch time) + ExpiresKey = "Expires" + + // ContentSHAKey is the SHA256 of request body + AmzContentSHAKey = "X-Amz-Content-Sha256" + // TimeFormat is the time format to be used in the X-Amz-Date header or query parameter TimeFormat = "20060102T150405Z" ) -//General HTTP but used in context of AWS -const ( - authorizationHeader = "Authorization" -) \ No newline at end of file diff --git a/constants/http.go b/constants/http.go new file mode 100644 index 0000000..00f2132 --- /dev/null +++ b/constants/http.go @@ -0,0 +1,6 @@ +package constants + +//General HTTP but used in context of AWS +const ( + AuthorizationHeader = "Authorization" +) \ No newline at end of file diff --git a/cmd/aws_conversion.go b/presign/aws-conversion.go similarity index 69% rename from cmd/aws_conversion.go rename to presign/aws-conversion.go index 94314ae..66cba5e 100644 --- a/cmd/aws_conversion.go +++ b/presign/aws-conversion.go @@ -1,10 +1,14 @@ -package cmd +package presign -import "time" +import ( + "time" + + "github.com/VITObelgium/fakes3pp/constants" +) //Convert query parameter like X-Amz-Date=20240914T190903Z func XAmzDateToTime(XAmzDate string) (time.Time, error) { - return time.Parse(TimeFormat, XAmzDate) + return time.Parse(constants.TimeFormat, XAmzDate) } func XAmzExpiryToTime(XAmzDate string, expirySeconds uint) (time.Time, error) { diff --git a/presign/factory.go b/presign/factory.go new file mode 100644 index 0000000..048f088 --- /dev/null +++ b/presign/factory.go @@ -0,0 +1,44 @@ +package presign + +import ( + "context" + "fmt" + "net/http" + + url "github.com/VITObelgium/fakes3pp/requestutils" + "github.com/aws/aws-sdk-go-v2/aws" +) + + +func MakePresignedUrl(r *http.Request) (u PresignedUrl, err error) { + if isHmacV1Query(r) { + //&& r.URL.Query().Get(constants.AmzSecurityTokenKey) != "" + u = presignedUrlHmacv1queryFromRequest(r) + return + } else if isS3V4Query(r) { + u = presignedUrlS3V4QueryFromRequest(r) + return + } + + return nil, fmt.Errorf("unsupported presign request; %s", url.FullUrlFromRequest(r)) +} + +func IsPresignedUrlWithValidSignature(ctx context.Context, url string, creds aws.Credentials) (isValid bool, err error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return + } + purl, err := MakePresignedUrl(req) + if err != nil { + return + } + secretDeriver := func(accessKeyId string) (string, error) { + if creds.AccessKeyID != accessKeyId { + err = fmt.Errorf("mismatch between provided credential %s and url credential %s", creds.AccessKeyID, accessKeyId) + return "", err + } + return creds.SecretAccessKey, nil + } + isValid, _, _, err = purl.GetPresignedUrlDetails(ctx, secretDeriver) + return +} \ No newline at end of file diff --git a/presign/header-operations.go b/presign/header-operations.go new file mode 100644 index 0000000..8caf9db --- /dev/null +++ b/presign/header-operations.go @@ -0,0 +1,64 @@ +package presign + +import ( + "context" + "log/slog" + "net/http" + "strings" +) + + +var cleanableHeaders = map[string]bool{ + "accept-encoding": true, + "x-forwarded-for": true, + "x-forwarded-host": true, + "x-forwarded-port": true, + "x-forwarded-proto": true, + "x-forwarded-server": true, + "x-real-ip": true, + "amz-sdk-invocation-id": true, //Added by AWS SDKs after signing + "amz-sdk-request": true, //Added by AWS SDKs after signing + "content-length": true, +} + +func isCleanable(headerName string) bool { + value, ok := cleanableHeaders[strings.ToLower(headerName)] + if ok && value { + return true + } + return false +} + +func CleanHeadersTo(ctx context.Context, req *http.Request, toKeep map[string]string) { + var cleaned = []string{} + var skipped = []string{} + var signed = []string{} + + allHeadersInRequest := []string{} + for hearderName := range req.Header { + allHeadersInRequest = append(allHeadersInRequest, hearderName) + } + + for _, header := range allHeadersInRequest { + _, ok := toKeep[strings.ToLower(header)] + if ok { + signed = append(signed, header) + continue + } + if isCleanable(header) { + //If content-length is to be cleaned it should + //also be <=0 otherwise it is taken in the signature + //-1 means unknown so let's fall back to that + if strings.ToLower(header) == "content-length" { + req.ContentLength = -1 + } + req.Header.Del(header) + cleaned = append(cleaned, header) + } else { + skipped = append(skipped, header) + } + } + if len(skipped) > 0 { + slog.Warn("Cleaning of headers done", "cleaned", cleaned, "skipped", skipped, "toKeep", signed) + } +} diff --git a/cmd/s3-presigner.go b/presign/hmacv1query.go similarity index 69% rename from cmd/s3-presigner.go rename to presign/hmacv1query.go index 0b37a02..723d1e8 100644 --- a/cmd/s3-presigner.go +++ b/presign/hmacv1query.go @@ -1,4 +1,6 @@ -package cmd +// https://docs.aws.amazon.com/AmazonS3/latest/API/RESTAuthentication.html#RESTAuthenticationQueryStringAuth +// e.g.: https:////?AWSAccessKeyId=&Signature=&x-amz-security-token=&Expires= +package presign import ( "context" @@ -14,22 +16,77 @@ import ( "strings" "time" + "github.com/VITObelgium/fakes3pp/constants" + ru "github.com/VITObelgium/fakes3pp/requestutils" "github.com/aws/aws-sdk-go-v2/aws" ) -// For a presigned url get the epoch string when it expires -func getS3PresignedUrlExpires(req *http.Request) string { - return req.URL.Query().Get("Expires") +type presignedUrlHmacv1Query struct { + *http.Request } -// For a presigned url get the time when it expires or return an error if invalid input -func GetS3PresignedUrlExpiresTime(req *http.Request) (time.Time, error) { - expiresStr := getS3PresignedUrlExpires(req) - expiresInt, err := strconv.Atoi(expiresStr) +func presignedUrlHmacv1queryFromRequest(r *http.Request) presignedUrlHmacv1Query { + return presignedUrlHmacv1Query{ + r, + } +} + +func isHmacV1Query(r *http.Request) bool { + return r.URL.Query().Get(constants.SignatureKey) != "" && r.URL.Query().Get(constants.AccessKeyId) != "" +} + +//Presigned urls often get different casing (e.g. from boto3 library) +func getHmacV1QuerySecurityToken(r *http.Request) string { + var result = r.URL.Query().Get("x-amz-security-token") + if result == "" { + result = r.URL.Query().Get(constants.AmzSecurityTokenKey) + } + return result +} + +func (u presignedUrlHmacv1Query) GetPresignedUrlDetails(ctx context.Context, deriver SecretDeriver) (isValid bool, creds aws.Credentials, expires time.Time, err error) { + accessKeyId := u.URL.Query().Get(constants.AccessKeyId) + sessionToken := getHmacV1QuerySecurityToken(u.Request) + + secretAccessKey, err := deriver(accessKeyId) if err != nil { - return time.Now(), err + return + } + creds = aws.Credentials{ + AccessKeyID: accessKeyId, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + } + isValid, err = u.hasValidSignature(creds) + if err != nil { + return + } + expires, err = u.getExpiresTime() + return +} + +func (u presignedUrlHmacv1Query) hasValidSignature(creds aws.Credentials) (bool, error) { + testUrl := UrlDropSchemeFQDNPort(ru.FullUrlFromRequest(u.Request)) + expectedUrl, err := CalculateS3PresignedHmacV1QueryUrl(u.Request, creds, 0) + expectedUrl = UrlDropSchemeFQDNPort(expectedUrl) + if err != nil { + return false, err } - return time.Unix(int64(expiresInt), 0), nil + return testUrl == expectedUrl, nil +} + +func getExpiresFromHmacv1QueryUrl(r *http.Request) string { + return r.URL.Query().Get(constants.ExpiresKey) +} + +// For a presigned url get the epoch string when it expires +func (u presignedUrlHmacv1Query) getExpires() string { + return getExpiresFromHmacv1QueryUrl(u.Request) +} + +// For a presigned url get the time when it expires or return an error if invalid input +func (u presignedUrlHmacv1Query) getExpiresTime() (time.Time, error) { + return epochStrToTime(u.getExpires()) } //Calculate a Presigned URL out of a Request using AWS Credentials @@ -37,18 +94,18 @@ func GetS3PresignedUrlExpiresTime(req *http.Request) (time.Time, error) { //If expirySeconds is set to 0 it is expected that a query parameter Expires is passed as part of the URL //With a value an epoch timestamp //This function will not make changes to the passed in request -func CalculateS3PresignedUrl(req *http.Request, creds aws.Credentials, expirySeconds int) (string, error) { - var expires string = getS3PresignedUrlExpires(req) +func CalculateS3PresignedHmacV1QueryUrl(req *http.Request, creds aws.Credentials, expirySeconds int) (string, error) { + var expires string = getExpiresFromHmacv1QueryUrl(req) if expires == "" && expirySeconds == 0 { - return "", errors.New("got expirySeconds 0 but no expires in URL, impossible to get expiry") + return "", errors.New("got expirySeconds 0 but no expires in URL, impossible to get expires") } if expirySeconds > 0 { if expires != "" { - return "", fmt.Errorf("got expirySeconds %d and expires in URL %s, impossible to now which expiry to use", expirySeconds, expires) + return "", fmt.Errorf("got expirySeconds %d and expires in URL %s, impossible to now which expires to use", expirySeconds, expires) } expires = getExpiresFromExpirySeconds(expirySeconds) } - return CalculateS3PresignedUrlWithExpiryTime(req, creds, expires) + return calculateS3PresignedHmacV1QueryUrlWithExpiryTime(req, creds, expires) } func UrlDropSchemeFQDNPort(url string) string { @@ -59,26 +116,7 @@ func UrlDropSchemeFQDNPort(url string) string { return strings.Join(urlParts[3:], "/") } -func HasS3PresignedUrlValidSignature(req *http.Request, creds aws.Credentials) (validSignature bool, err error) { - testUrl := UrlDropSchemeFQDNPort(fullUrlFromRequest(req)) - expectedUrl, err := CalculateS3PresignedUrl(req, creds, 0) - expectedUrl = UrlDropSchemeFQDNPort(expectedUrl) - if err != nil { - return false, err - } - return testUrl == expectedUrl, nil -} - - -func HasGetS3PresignedUrlValidSignature(testUrl string, creds aws.Credentials) (validSignature bool, err error) { - req, err := http.NewRequest(http.MethodGet, testUrl, nil) - if err != nil { - return false, err - } - return HasS3PresignedUrlValidSignature(req, creds) -} - -func CalculateS3PresignedUrlWithExpiryTime(req *http.Request, creds aws.Credentials, expires string) (string, error) { +func calculateS3PresignedHmacV1QueryUrlWithExpiryTime(req *http.Request, creds aws.Credentials, expires string) (string, error) { r := req.Clone(context.Background()) assureSecTokenHeader(r, creds) @@ -94,10 +132,10 @@ func CalculateS3PresignedUrlWithExpiryTime(req *http.Request, creds aws.Credenti hashBytes := h.Sum(nil) signature := base64.StdEncoding.EncodeToString(hashBytes) - return buildUrl(r, creds, expires, signature), nil + return buildHmacV1QueryUrl(r, creds, expires, signature), nil } -func buildUrl(req *http.Request, creds aws.Credentials, expires, signature string) (string) { +func buildHmacV1QueryUrl(req *http.Request, creds aws.Credentials, expires, signature string) (string) { var secToken = "" if creds.SessionToken != "" { secToken = fmt.Sprintf("&x-amz-security-token=%s", url.QueryEscape(creds.SessionToken)) diff --git a/cmd/s3-presigner_test.go b/presign/hmacv1query_test.go similarity index 63% rename from cmd/s3-presigner_test.go rename to presign/hmacv1query_test.go index 0120f60..0f88af4 100644 --- a/cmd/s3-presigner_test.go +++ b/presign/hmacv1query_test.go @@ -1,13 +1,14 @@ -package cmd +package presign import ( + "context" "net/http" "testing" "github.com/aws/aws-sdk-go-v2/aws" ) -//START OF GENERATED CONTENT (see s3-presigner_test.py) +//START OF GENERATED CONTENT (see query_string_test.py) var testUrl = "https://s3.test.com/my-bucket/path/to/my_file" var testAccessKeyId = "0123455678910abcdef09459" @@ -16,7 +17,6 @@ var testSessionToken = "FQoGZXIvYXdzEBYaDkiOiJ7XG5cdFwiVmVyc2lvblwiOiBcIjIwMTItM var testExpires = "1727389975" var testExpectedPresignedUrlTemp = "https://s3.test.com/my-bucket/path/to/my_file?AWSAccessKeyId=0123455678910abcdef09459&Signature=UAK8QHRI55lzlVoLFM6Fj7T98a8%3D&x-amz-security-token=FQoGZXIvYXdzEBYaDkiOiJ7XG5cdFwiVmVyc2lvblwiOiBcIjIwMTItMTAtMTdcIixcblx0XCJT&Expires=1727389975" var testExpectedPresignedUrlPerm = "https://s3.test.com/my-bucket/path/to/my_file?AWSAccessKeyId=0123455678910abcdef09459&Signature=O%2FybXwQdy0cISlo6ly4Lit6s%2BlE%3D&Expires=1727389975" - //END OF GENERATED CONTENT var testCredsPerm = aws.Credentials{ @@ -29,13 +29,14 @@ var testCredsTemp = aws.Credentials{ SessionToken: testSessionToken, } + func TestIfNoExpiresInUrlAndNoExpiryThenWeMustFail(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "http://s3.test/bucket/key", nil) if err != nil { t.Error(req) t.FailNow() } - _, err = CalculateS3PresignedUrl(req, testCredsTemp, 0) + _, err = CalculateS3PresignedHmacV1QueryUrl(req, testCredsTemp, 0) if err == nil { t.Error("Should have gotten an error") } @@ -47,7 +48,7 @@ func TestIfExpiresInUrlAndExpiryThenWeMustFail(t *testing.T) { t.Error(req) t.FailNow() } - _, err = CalculateS3PresignedUrl(req, testCredsPerm, 3600) + _, err = CalculateS3PresignedHmacV1QueryUrl(req, testCredsPerm, 3600) if err == nil { t.Error("Should have gotten an error") } @@ -77,7 +78,7 @@ func TestGenerateS3PresignedGetObjectWithTemporaryCreds(t *testing.T) { t.Error(req) t.FailNow() } - presigned, err := CalculateS3PresignedUrlWithExpiryTime(req, tc.Creds, testExpires) + presigned, err := calculateS3PresignedHmacV1QueryUrlWithExpiryTime(req, tc.Creds, testExpires) if err != nil { t.Errorf("%s: %s", tc.Description, err) continue @@ -89,8 +90,31 @@ func TestGenerateS3PresignedGetObjectWithTemporaryCreds(t *testing.T) { } func TestValidateS3GetPresignedUrlsForValidUrls(t *testing.T){ + var testExpectedExpiresTime, err = epochStrToTime(testExpires) + if err != nil { + t.Errorf("Could not calculated expected expires time") + t.FailNow() + } + for _, tc := range testCasesValidUrls { - isValid, err := HasGetS3PresignedUrlValidSignature(tc.ExpectedUrl, tc.Creds) + req, err := http.NewRequest(http.MethodGet, tc.ExpectedUrl, nil) + if err != nil { + t.Errorf("Could not create request: %s", err) + } + presignedUrl, err := MakePresignedUrl(req) + if err != nil { + t.Errorf("Could not create presigned url: %s", err) + } + _, ok := presignedUrl.(presignedUrlHmacv1Query) + if !ok { + t.Errorf("We are testing HMACv1 query URLs so we expect to get correct type from factory") + } + + var testSecretDeriver = func(s string) (string, error) { + return tc.Creds.SecretAccessKey, nil + } + + isValid, creds, expires, err := presignedUrl.GetPresignedUrlDetails(context.Background(), testSecretDeriver) if err != nil { t.Errorf("%s: %s", tc.Description, err) continue @@ -98,5 +122,14 @@ func TestValidateS3GetPresignedUrlsForValidUrls(t *testing.T){ if !isValid { t.Errorf("%s: Signature was invalid but expected it to be valid", tc.Description) } + if creds.AccessKeyID != tc.Creds.AccessKeyID { + t.Errorf("Got different accessKeyId; got %s, expected %s", creds.AccessKeyID, tc.Creds.AccessKeyID) + } + if creds.SessionToken != tc.Creds.SessionToken { + t.Errorf("Got different sessionToken; got %s, expected %s", creds.SessionToken, tc.Creds.SessionToken) + } + if expires != testExpectedExpiresTime { + t.Errorf("Wrong expires time; Expected %s, got %s", testExpectedExpiresTime, expires) + } } } \ No newline at end of file diff --git a/cmd/s3-presigner_test.py b/presign/hmacv1query_test.py similarity index 96% rename from cmd/s3-presigner_test.py rename to presign/hmacv1query_test.py index 0db9fb6..cd909e7 100644 --- a/cmd/s3-presigner_test.py +++ b/presign/hmacv1query_test.py @@ -1,5 +1,5 @@ """ -This file was used to generate test content for s3-presigner_test.go +This file was used to generate test content for query_string_test.go """ from copy import deepcopy import boto3 diff --git a/presign/presign-interface.go b/presign/presign-interface.go new file mode 100644 index 0000000..63684b3 --- /dev/null +++ b/presign/presign-interface.go @@ -0,0 +1,15 @@ +package presign + +import ( + "context" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +//Secret Deriver takes +type SecretDeriver func(accessKeyId string) (secretAccessKey string, err error) + +type PresignedUrl interface { + GetPresignedUrlDetails(context.Context, SecretDeriver) (isValid bool, creds aws.Credentials, expires time.Time, err error) +} \ No newline at end of file diff --git a/presign/s3v4.go b/presign/s3v4.go new file mode 100644 index 0000000..6b5dcc5 --- /dev/null +++ b/presign/s3v4.go @@ -0,0 +1,91 @@ +package presign + +import ( + "context" + "errors" + "log/slog" + "net/http" + "strconv" + "time" + + "github.com/VITObelgium/fakes3pp/constants" + "github.com/VITObelgium/fakes3pp/requestutils" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" +) + +//This file just contains helpers to presign for S3 with sigv4 + +func PreSignRequestWithCreds(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (signedURI string, signedHeaders http.Header, err error){ + if expiryInSeconds <= 0 { + return "", nil, errors.New("expiryInSeconds must be bigger than 0 for presigned requests") + } + signer := v4.NewSigner() + + ctx, creds, req, payloadHash, service, region, signingTime := GetS3SignRequestParams(ctx, req, expiryInSeconds, signingTime, creds) + return signer.PresignHTTP(ctx, creds, req, payloadHash, service, region, signingTime) +} + +func SignRequestWithCreds(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (err error){ + signer := v4.NewSigner() + + ctx, creds, req, payloadHash, service, region, signingTime := GetS3SignRequestParams(ctx, req, expiryInSeconds, signingTime, creds) + return signer.SignHTTP(ctx, creds, req, payloadHash, service, region, signingTime) +} + +var signatureQueryParamNames []string = []string{ + constants.AmzAlgorithmKey, + constants.AmzCredentialKey, + constants.AmzDateKey, + constants.AmzSecurityTokenKey, + constants.AmzSignedHeadersKey, + constants.AmzSignatureKey, +} + +//Sign an HTTP request with a sigv4 signature. If expiry in seconds is bigger than zero then the signature has an explicit limited lifetime +//use a negative value to not set an explicit expiry time +func GetS3SignRequestParams(ctx context.Context, req *http.Request, expiryInSeconds int, signingTime time.Time, creds aws.Credentials) (context.Context, aws.Credentials, *http.Request, string, string, string, time.Time){ + region := "eu-west-1" + regionName, err := requestutils.GetSignatureCredentialPartFromRequest(req, requestutils.CredentialPartRegionName) + if err == nil { + region = regionName + } + + query := req.URL.Query() + for _, paramName := range signatureQueryParamNames { + query.Del(paramName) + } + if expiryInSeconds > 0 { + expires := time.Duration(expiryInSeconds) * time.Second + query.Set(constants.AmzExpiresKey, strconv.FormatInt(int64(expires/time.Second), 10)) + } + + req.URL.RawQuery = query.Encode() + + service := "s3" + + payloadHash := req.Header.Get(constants.AmzContentSHAKey) + if payloadHash == "" { + payloadHash = "UNSIGNED-PAYLOAD" + } + + return ctx, creds, req, payloadHash, service, region, signingTime +} + + +func SignWithCreds(ctx context.Context, req *http.Request, creds aws.Credentials) error{ + var signingTime time.Time + amzDate := req.Header.Get(constants.AmzDateKey) + if amzDate == "" { + signingTime = time.Now() + } else { + var err error + signingTime, err = XAmzDateToTime(amzDate) + if err != nil { + slog.Warn("Could not handle X-amz-date", constants.AmzDateKey, amzDate, "error", err) + signingTime = time.Now() + } + } + + return SignRequestWithCreds(ctx, req, -1, signingTime, creds) +} diff --git a/presign/s3v4query.go b/presign/s3v4query.go new file mode 100644 index 0000000..03f1bbd --- /dev/null +++ b/presign/s3v4query.go @@ -0,0 +1,134 @@ +package presign + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/VITObelgium/fakes3pp/constants" + "github.com/VITObelgium/fakes3pp/requestutils" + "github.com/aws/aws-sdk-go-v2/aws" +) + +type presignedUrlS3V4Query struct { + *http.Request +} + +func presignedUrlS3V4QueryFromRequest(r *http.Request) presignedUrlS3V4Query { + return presignedUrlS3V4Query{ + r, + } +} + + +func isS3V4Query(r *http.Request) bool { + return r.URL.Query().Get(constants.AmzSignatureKey) != "" && r.URL.Query().Get(constants.AmzCredentialKey) != "" && r.URL.Query().Get(constants.AmzAlgorithmKey) != "" +} + +// Get the value of X-Amz-Credential as given for the presigned url +func (u presignedUrlS3V4Query) getAmzCredential() (string) { + return u.URL.Query().Get(constants.AmzCredentialKey) +} + +// Get the value of X-Amz-Security-Token as given for the presigned url +func (u presignedUrlS3V4Query) getAmzSecurityToken() (string) { + return u.URL.Query().Get(constants.AmzSecurityTokenKey) +} + + +func (u presignedUrlS3V4Query) getAccessKeyId() (string, error) { + return requestutils.GetCredentialPart(u.getAmzCredential(), requestutils.CredentialPartAccessKeyId) +} + +func (u presignedUrlS3V4Query) getSignTime() (time.Time, error) { + XAmzDate := u.URL.Query().Get(constants.AmzDateKey) + return XAmzDateToTime(XAmzDate) +} + +func (u presignedUrlS3V4Query) getSignedHeaders() map[string]string { + var signedHeaders map[string]string = make(map[string]string) + for _, signedHeader := range strings.Split(u.URL.Query().Get(constants.AmzSignedHeadersKey), ";") { + signedHeaders[signedHeader] = "" + } + return signedHeaders +} + +func (u presignedUrlS3V4Query) GetPresignedUrlDetails(ctx context.Context, deriver SecretDeriver) (isValid bool, creds aws.Credentials, expires time.Time, err error) { + accessKeyId, err := u.getAccessKeyId() + if err != nil { + return + } + sessionToken := u.getAmzSecurityToken() + + secretAccessKey, err := deriver(accessKeyId) + if err != nil { + return + } + creds = aws.Credentials{ + AccessKeyID: accessKeyId, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + } + signDate, err := u.getSignTime() + if err != nil { + return + } + XAmzExpires := u.URL.Query().Get(constants.AmzExpiresKey) + expirySeconds, err := strconv.Atoi(XAmzExpires) + if err != nil { + err = fmt.Errorf("InvalidSignature: could not get Expire seconds(%s) %s: %s", constants.AmzExpiresKey, XAmzExpires, err) + return + } + + expires = signDate.Add(time.Duration(expirySeconds) * time.Second) + originalSignature := u.Request.URL.Query().Get(constants.AmzSignatureKey) + c := u.Request.Clone(ctx) + if c.Header.Get("Host") == "" { + c.Header.Add("Host", c.Host) + } + CleanHeadersTo(ctx, c, u.getSignedHeaders()) + signedUri, _, err := PreSignRequestWithCreds(ctx, c, expirySeconds, signDate, creds) + if err != nil { + err = fmt.Errorf("InvalidSignature: encountered error trying to sign a similar req: %s", err) + return + } + + calculatedSignature, err := getSignatureFromV4QueryUrl(signedUri) + if err != nil { + return + } + isValid = originalSignature == calculatedSignature + // isValid, err = haveSameSigv4Signature(requestutils.FullUrlFromRequest(u.Request), signedUri) + return +} + +func getSignatureFromV4QueryUrl(inputUrl string) (sig string, err error) { + q, err := requestutils.GetQueryParamsFromUrl(inputUrl) + if err != nil { + return + } + signature := q.Get(constants.AmzSignatureKey) + if signature == "" { + return signature, fmt.Errorf("Url got empty signature: %s", inputUrl) + } + return signature, nil +} + +//Verify if URLs have the same sigv4 signature. If one of the URLs does not have +//a signature it always returns false. +func haveSameSigv4QuerySignature(url1, url2 string) (same bool, err error) { + s1, err := getSignatureFromV4QueryUrl(url1) + if err != nil { + return false, err + } + + s2, err := getSignatureFromV4QueryUrl(url2) + if err != nil { + return false, err + } + + return s1 == s2, nil +} \ No newline at end of file diff --git a/cmd/presign_test.go b/presign/s3v4query_test.go similarity index 90% rename from cmd/presign_test.go rename to presign/s3v4query_test.go index f384458..2eb211b 100644 --- a/cmd/presign_test.go +++ b/presign/s3v4query_test.go @@ -1,12 +1,14 @@ -package cmd +package presign import ( "context" "net/http" "testing" + + "github.com/VITObelgium/fakes3pp/constants" ) -//Logging of CONTENT GENERATION (see s3-presigner_test.go for origin of these values) +//Logging of CONTENT GENERATION (see hmacv1query_test.go for origin of these values) // export AWS_ACCESS_KEY_ID="0123455678910abcdef09459" // export AWS_SECRET_ACCESS_KEY="YWUzOTQyM2FlMDMzNDlkNjk0M2FmZDE1OWE1ZGRkMT" @@ -26,6 +28,7 @@ var testSigningDateEuc1 = "20241009T115034Z" //END OF GENERATED CONTENT + func TestAwsCliGeneratedURLMustWork(t *testing.T) { var testCases = []struct{ RegionName string @@ -58,7 +61,7 @@ func TestAwsCliGeneratedURLMustWork(t *testing.T) { req, err := http.NewRequest(http.MethodGet, tc.ExpectedUrl, nil) queryP := req.URL.Query() - queryP.Del(AmzSignatureKey) + queryP.Del(constants.AmzSignatureKey) req.URL.RawQuery = queryP.Encode() if err != nil { t.Errorf("Could not create request: %s", err) @@ -69,7 +72,7 @@ func TestAwsCliGeneratedURLMustWork(t *testing.T) { if err != nil { t.Errorf("Did not expect error. Got %s", err) } - if s, err := haveSameSigv4Signature(signedUri, tc.ExpectedUrl); !s || err != nil { + if s, err := haveSameSigv4QuerySignature(signedUri, tc.ExpectedUrl); !s || err != nil { t.Errorf("Mismatch signature:\nGot :%s\nExpected:%s", signedUri, tc.ExpectedUrl) } } diff --git a/presign/utils.go b/presign/utils.go new file mode 100644 index 0000000..60268e2 --- /dev/null +++ b/presign/utils.go @@ -0,0 +1,15 @@ +package presign + +import ( + "strconv" + "time" +) + + +func epochStrToTime(in string) (time.Time, error) { + expiresInt, err := strconv.Atoi(in) + if err != nil { + return time.Now(), err + } + return time.Unix(int64(expiresInt), 0), nil +} \ No newline at end of file diff --git a/requestutils/amz-credential-value.go b/requestutils/amz-credential-value.go new file mode 100644 index 0000000..68012eb --- /dev/null +++ b/requestutils/amz-credential-value.go @@ -0,0 +1,84 @@ +package requestutils + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "strings" +) + + +type CredentialPart int64 + +const ( + CredentialPartAccessKeyId CredentialPart = iota + CredentialPartDate + CredentialPartRegionName + CredentialPartServiceName + CredentialPartType +) + + +// credential string is the value of a X-Amz_credential and it is meant to follow +// the structure /20130721/us-east-1/s3/aws4_request (when decoded) +// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html +func GetCredentialPart(credentialString string, credentialPart CredentialPart) (string, error) { + authorizationHeaderCredentialParts := strings.Split(credentialString, "/") + if authorizationHeaderCredentialParts[CredentialPartServiceName] != "s3" { + return "", errors.New("authorization header was not for S3") + } + if authorizationHeaderCredentialParts[CredentialPartType] != "aws4_request" { + return "", errors.New("authorization header was not a supported sigv4") + } + return authorizationHeaderCredentialParts[credentialPart], nil +} + +const signAlgorithm = "AWS4-HMAC-SHA256" +const expectedAuthorizationStartWithCredential = "AWS4-HMAC-SHA256 Credential=" + + +// Gets a part of the Credential value that is passed via the authorization header +// +func GetSignatureCredentialPartFromRequest(r *http.Request, credentialPart CredentialPart) (string, error) { + authorizationHeader := r.Header.Get("Authorization") + var credentialString string + var err error + if authorizationHeader != "" { + credentialString, err = getSignatureCredentialStringFromRequestAuthHeader(authorizationHeader) + if err != nil { + return "", err + } + } else { + qParams := r.URL.Query() + credentialString, err = getSignatureCredentialStringFromRequestQParams(qParams) + if err != nil { + return "", err + } + } + return GetCredentialPart(credentialString, credentialPart) +} + +// Gets a part of the Credential value that is passed via the authorization header +func getSignatureCredentialStringFromRequestAuthHeader(authorizationHeader string) (string, error) { + if authorizationHeader == "" { + return "", fmt.Errorf("programming error should use empty authHeader to get credential part") + } + if !strings.HasPrefix(authorizationHeader, expectedAuthorizationStartWithCredential) { + return "", fmt.Errorf("invalid authorization header: %s", authorizationHeader) + } + authorizationHeaderTrimmed := authorizationHeader[len(expectedAuthorizationStartWithCredential):] + return strings.Split(authorizationHeaderTrimmed, ", ")[0], nil +} + +func getSignatureCredentialStringFromRequestQParams(qParams url.Values) (string, error) { + queryAlgorithm := qParams.Get("X-Amz-Algorithm") + if queryAlgorithm != signAlgorithm { + return "", fmt.Errorf("no Authorization header nor x-amz-algorithm query parameter present: %v", qParams) + } + queryCredential := qParams.Get("X-Amz-Credential") + if queryCredential == "" { + return "", fmt.Errorf("empty X-Amz-Credential parameter: %v", qParams) + } + return queryCredential, nil +} \ No newline at end of file diff --git a/requestutils/http-request.go b/requestutils/http-request.go new file mode 100644 index 0000000..4ab727f --- /dev/null +++ b/requestutils/http-request.go @@ -0,0 +1,22 @@ +package requestutils + +import ( + "fmt" + "net/http" +) + +//Given a request try to reconstruct the full URL for that request +//including protocol, hostname, path and query parameter names and values +func FullUrlFromRequest(req *http.Request) string { + scheme := req.URL.Scheme + if scheme == "" { + scheme = "https" + } + return fmt.Sprintf( + "%s://%s%s?%s", + scheme, + req.Host, + req.URL.Path, + req.URL.RawQuery, + ) +} \ No newline at end of file diff --git a/requestutils/http-request_test.go b/requestutils/http-request_test.go new file mode 100644 index 0000000..45a0ea8 --- /dev/null +++ b/requestutils/http-request_test.go @@ -0,0 +1,42 @@ +package requestutils_test + +import ( + "net/http" + "testing" + + url "github.com/VITObelgium/fakes3pp/requestutils" +) + +func TestGetUrlFromRequest(t *testing.T) { + var testCasesValidUrls = []struct{ + Description string + Url string + }{ + { + "Temporary credentials Url", + "https://s3.test.com/my-bucket/path/to/my_file?AWSAccessKeyId=0123455678910abcdef09459&Signature=UAK8QHRI55lzlVoLFM6Fj7T98a8%3D&x-amz-security-token=FQoGZXIvYXdzEBYaDkiOiJ7XG5cdFwiVmVyc2lvblwiOiBcIjIwMTItMTAtMTdcIixcblx0XCJT&Expires=1727389975", + }, + { + "Permanent credentials Url", + "https://s3.test.com/my-bucket/path/to/my_file?AWSAccessKeyId=0123455678910abcdef09459&Signature=O%2FybXwQdy0cISlo6ly4Lit6s%2BlE%3D&Expires=1727389975", + }, + } + + for _, tc := range testCasesValidUrls { + req := buildGetRequest(tc.Url, t) + u := url.FullUrlFromRequest(req) + if u != tc.Url { + t.Errorf("%s: Got %s, expected %s", tc.Description, u, tc.Url) + } + } +} + + +func buildGetRequest(url string, t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + t.Error(err) + t.FailNow() + } + return req +} \ No newline at end of file diff --git a/requestutils/string-url.go b/requestutils/string-url.go new file mode 100644 index 0000000..95d7453 --- /dev/null +++ b/requestutils/string-url.go @@ -0,0 +1,16 @@ +package requestutils + +import "net/url" + + +func GetQueryParamsFromUrl(inputUrl string) (url.Values, error) { + u, err := url.Parse(inputUrl) + if err != nil { + return nil, err + } + q, err := url.ParseQuery(u.RawQuery) + if err != nil { + return nil, err + } + return q, nil +} \ No newline at end of file