diff --git a/lib/srv/db/common/errors.go b/lib/srv/db/common/errors.go index 6b9fdca498fe8..94c657d6a2897 100644 --- a/lib/srv/db/common/errors.go +++ b/lib/srv/db/common/errors.go @@ -17,6 +17,7 @@ limitations under the License. package common import ( + "errors" "fmt" "net/http" "strings" @@ -95,11 +96,17 @@ func convertPostgresError(err *pgconn.PgError) error { func convertMySQLError(err *mysql.MyError) error { switch err.Code { case mysql.ER_ACCESS_DENIED_ERROR, mysql.ER_DBACCESS_DENIED_ERROR: - return trace.AccessDenied(err.Error()) + return trace.AccessDenied(fmtEscape(err)) } return err // Return unmodified. } +// fmtEscape escapes "%" in the original error message to prevent fmt from +// thinking some args are missing. +func fmtEscape(err error) string { + return strings.ReplaceAll(err.Error(), "%", "%%") +} + // causer defines an interface for errors wrapped by the "errors" package. type causer interface { Cause() error @@ -131,14 +138,15 @@ func ConvertConnectError(err error, sessionCtx *Session) error { return trace.AccessDenied("Database cannot validate client certificate generated by database service: %v.", err) } - err = ConvertError(err) + orgErr := err + err = ConvertError(orgErr) if trace.IsAccessDenied(err) { switch sessionCtx.Database.GetType() { case types.DatabaseTypeElastiCache: return createElastiCacheRedisAccessDeniedError(err, sessionCtx) case types.DatabaseTypeRDS: - return createRDSAccessDeniedError(err, sessionCtx) + return createRDSAccessDeniedError(err, orgErr, sessionCtx) case types.DatabaseTypeRDSProxy: return createRDSProxyAccessDeniedError(err, sessionCtx) case types.DatabaseTypeAzure: @@ -175,9 +183,20 @@ take a few minutes to propagate): } } +func isRDSMySQLIAMAuthError(err error) bool { + if causer, ok := err.(causer); ok { + return isRDSMySQLIAMAuthError(causer.Cause()) + } + var mysqlError *mysql.MyError + if !errors.As(trace.Unwrap(err), &mysqlError) { + return false + } + return mysqlError.Code == mysql.ER_ACCESS_DENIED_ERROR +} + // createRDSAccessDeniedError creates an error with help message to setup IAM // auth for RDS. -func createRDSAccessDeniedError(err error, sessionCtx *Session) error { +func createRDSAccessDeniedError(err, orgErr error, sessionCtx *Session) error { policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(sessionCtx.Database) if getPolicyErr != nil { policy = fmt.Sprintf("failed to generate IAM policy: %v", getPolicyErr) @@ -185,6 +204,11 @@ func createRDSAccessDeniedError(err error, sessionCtx *Session) error { switch sessionCtx.Database.GetProtocol() { case defaults.ProtocolMySQL: + // Not all access denied errors are IAM Auth errors, so check again. + if !isRDSMySQLIAMAuthError(orgErr) { + return trace.Wrap(err) + } + return trace.AccessDenied(`Could not connect to database: %v diff --git a/lib/srv/db/common/errors_test.go b/lib/srv/db/common/errors_test.go new file mode 100644 index 0000000000000..49a6b40b4f6ac --- /dev/null +++ b/lib/srv/db/common/errors_test.go @@ -0,0 +1,48 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "testing" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func Test_isRDSMySQLIAMAuthError(t *testing.T) { + iamAuthError := &mysql.MyError{ + Code: mysql.ER_ACCESS_DENIED_ERROR, + Message: "Access denied for user 'alice'@'10.0.0.197' (using password: YES)", + State: "28000", + } + require.True(t, isRDSMySQLIAMAuthError(iamAuthError)) + + dbAccessError := &mysql.MyError{ + Code: mysql.ER_DBACCESS_DENIED_ERROR, + Message: "Access denied for user 'alice'@'%' to database 'db-no-access'", + State: "42000", + } + noDBError := &mysql.MyError{ + Code: mysql.ER_BAD_DB_ERROR, + Message: "Unknown database 'db-not-exist'", + State: "42000", + } + require.False(t, isRDSMySQLIAMAuthError(dbAccessError)) + require.False(t, isRDSMySQLIAMAuthError(noDBError)) + require.False(t, isRDSMySQLIAMAuthError(trace.AccessDenied("access denied"))) +}