From 66f07f88a94210b8273957f950b61b8d2165cfaa Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Wed, 23 Oct 2024 15:16:51 +0100 Subject: [PATCH] Rejig JoinServerGRPCServer to invoke auth.Server --- lib/auth/auth_with_roles.go | 49 ++++--------------------------- lib/auth/grpcserver.go | 7 +---- lib/auth/join_azure.go | 17 +++++++++-- lib/auth/join_azure_test.go | 2 +- lib/auth/join_iam.go | 17 +++++++++-- lib/auth/join_iam_test.go | 2 +- lib/auth/join_tpm.go | 2 +- lib/joinserver/joinserver.go | 22 +++++--------- lib/joinserver/joinserver_test.go | 13 ++++++-- lib/web/apiserver.go | 2 ++ 10 files changed, 59 insertions(+), 74 deletions(-) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 4401479798d17..15c186f54e0c2 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -573,9 +573,12 @@ func (a *ServerWithRoles) GetClusterCACert( return a.authServer.GetClusterCACert(ctx) } +// Deprecated: This method only exists to service the RegisterUsingToken HTTP +// RPC, which has been replaced by an RPC on the JoinServiceServer. +// JoinServiceServer directly invokes auth.Server and performs its own checks +// on metadata. +// TODO(strideynet): DELETE IN V18.0.0 func (a *ServerWithRoles) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) { - // TODO(strideynet): In v18.0.0, this logic can be moved into - // JoinServiceGRPCServer. isProxy := a.hasBuiltinRole(types.RoleProxy) // We do not trust remote addr in the request unless it's coming from the Proxy. @@ -617,48 +620,6 @@ func (a *ServerWithRoles) RegisterUsingToken(ctx context.Context, req *types.Reg return a.authServer.RegisterUsingToken(ctx, req) } -// RegisterUsingIAMMethod registers the caller using the IAM join method and -// returns signed certs to join the cluster. -// -// See (*Server).RegisterUsingIAMMethod for further documentation. -// -// This wrapper does not do any extra authz checks, as the register method has -// its own authz mechanism. -func (a *ServerWithRoles) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) { - certs, err := a.authServer.RegisterUsingIAMMethod(ctx, challengeResponse) - return certs, trace.Wrap(err) -} - -// RegisterUsingAzureMethod registers the caller using the Azure join method and -// returns signed certs to join the cluster. -// -// See (*Server).RegisterUsingAzureMethod for further documentation. -// -// This wrapper does not do any extra authz checks, as the register method has -// its own authz mechanism. -func (a *ServerWithRoles) RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error) { - certs, err := a.authServer.RegisterUsingAzureMethod(ctx, challengeResponse) - return certs, trace.Wrap(err) -} - -// RegisterUsingTPMMethod registers the caller using the TPM join method and -// returns signed certs to join the cluster. -// -// See (*Server).RegisterUsingTPMMethod for further documentation. -// -// This wrapper does not do any extra authz checks, as the register method has -// its own authz mechanism. -func (a *ServerWithRoles) RegisterUsingTPMMethod( - ctx context.Context, - initReq *proto.RegisterUsingTPMMethodInitialRequest, - solveChallenge client.RegisterTPMChallengeResponseFunc, -) (*proto.Certs, error) { - certs, err := a.authServer.registerUsingTPMMethod( - ctx, initReq, solveChallenge, - ) - return certs, trace.Wrap(err) -} - // GenerateHostCerts generates new host certificates (signed // by the host certificate authority) for a node. func (a *ServerWithRoles) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) { diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 40529c26a1775..56e057fe83098 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5164,12 +5164,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { } trustv1pb.RegisterTrustServiceServer(server, trust) - // create server with no-op role to pass to JoinService server - serverWithNopRole, err := serverWithNopRole(cfg) - if err != nil { - return nil, trace.Wrap(err) - } - joinServiceServer := joinserver.NewJoinServiceGRPCServer(serverWithNopRole, false) + joinServiceServer := joinserver.NewJoinServiceGRPCServer(cfg.AuthServer) authpb.RegisterJoinServiceServer(server, joinServiceServer) integrationServiceServer, err := integrationv1.NewService(&integrationv1.ServiceConfig{ diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index e044d4e810a69..9cee8259dd4c2 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -351,13 +351,13 @@ func generateAzureChallenge() (string, error) { return challenge, trace.Wrap(err) } -// RegisterUsingAzureMethod registers the caller using the Azure join method +// RegisterUsingAzureMethodWithOpts registers the caller using the Azure join method // and returns signed certs to join the cluster. // // The caller must provide a ChallengeResponseFunc which returns a // *proto.RegisterUsingAzureMethodRequest with a signed attested data document // including the challenge as a nonce. -func (a *Server) RegisterUsingAzureMethod( +func (a *Server) RegisterUsingAzureMethodWithOpts( ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc, opts ...azureRegisterOption, @@ -422,6 +422,19 @@ func (a *Server) RegisterUsingAzureMethod( return certs, trace.Wrap(err) } +// RegisterUsingAzureMethod registers the caller using the Azure join method +// and returns signed certs to join the cluster. +// +// The caller must provide a ChallengeResponseFunc which returns a +// *proto.RegisterUsingAzureMethodRequest with a signed attested data document +// including the challenge as a nonce. +func (a *Server) RegisterUsingAzureMethod( + ctx context.Context, + challengeResponse client.RegisterAzureChallengeResponseFunc, +) (certs *proto.Certs, err error) { + return a.RegisterUsingAzureMethodWithOpts(ctx, challengeResponse) +} + // fixAzureSigningAlgorithm fixes a mismatch between the object IDs of the // hashing algorithm sent by Azure vs the ones expected by the pkcs7 library. // Specifically, Azure (incorrectly?) sends a [digest encryption algorithm] diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index 5fe5d487fcf3a..1e8af282de7ef 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -419,7 +419,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { vmClient := &mockAzureVMClient{vm: vmResult} - _, err = a.RegisterUsingAzureMethod(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { + _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { cfg := &azureChallengeResponseConfig{Challenge: challenge} for _, opt := range tc.challengeResponseOptions { opt(cfg) diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index a43fbe70fd920..ba2209105c7c0 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -328,13 +328,13 @@ func withFips(fips bool) iamRegisterOption { } } -// RegisterUsingIAMMethod registers the caller using the IAM join method and +// RegisterUsingIAMMethodWithOpts registers the caller using the IAM join method and // returns signed certs to join the cluster. // // The caller must provide a ChallengeResponseFunc which returns a // *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request // including the challenge as a signed header. -func (a *Server) RegisterUsingIAMMethod( +func (a *Server) RegisterUsingIAMMethodWithOpts( ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc, opts ...iamRegisterOption, @@ -388,3 +388,16 @@ func (a *Server) RegisterUsingIAMMethod( certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil) return certs, trace.Wrap(err, "generating certs") } + +// RegisterUsingIAMMethod registers the caller using the IAM join method and +// returns signed certs to join the cluster. +// +// The caller must provide a ChallengeResponseFunc which returns a +// *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request +// including the challenge as a signed header. +func (a *Server) RegisterUsingIAMMethod( + ctx context.Context, + challengeResponse client.RegisterIAMChallengeResponseFunc, +) (certs *proto.Certs, err error) { + return a.RegisterUsingIAMMethodWithOpts(ctx, challengeResponse) +} diff --git a/lib/auth/join_iam_test.go b/lib/auth/join_iam_test.go index 5ba3d6ae76c27..a64399d6e7136 100644 --- a/lib/auth/join_iam_test.go +++ b/lib/auth/join_iam_test.go @@ -620,7 +620,7 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) { require.NoError(t, a.DeleteToken(ctx, token.GetName())) }() - _, err = a.RegisterUsingIAMMethod(context.Background(), func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { + _, err = a.RegisterUsingIAMMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { templateInput := defaultIdentityRequestTemplateInput(challenge) for _, opt := range tc.challengeResponseOptions { opt(&templateInput) diff --git a/lib/auth/join_tpm.go b/lib/auth/join_tpm.go index 4304628b9d07c..05bf9e3c35a54 100644 --- a/lib/auth/join_tpm.go +++ b/lib/auth/join_tpm.go @@ -33,7 +33,7 @@ import ( "github.com/gravitational/teleport/lib/tpm" ) -func (a *Server) registerUsingTPMMethod( +func (a *Server) RegisterUsingTPMMethod( ctx context.Context, initReq *proto.RegisterUsingTPMMethodInitialRequest, solveChallenge client.RegisterTPMChallengeResponseFunc, diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go index 4707f49dd7b6d..2b753e871fe19 100644 --- a/lib/joinserver/joinserver.go +++ b/lib/joinserver/joinserver.go @@ -68,13 +68,12 @@ type joinServiceClient interface { type JoinServiceGRPCServer struct { proto.UnimplementedJoinServiceServer - isProxy bool joinServiceClient joinServiceClient clock clockwork.Clock } // NewJoinServiceGRPCServer returns a new JoinServiceGRPCServer. -func NewJoinServiceGRPCServer(joinServiceClient joinServiceClient, isProxy bool) *JoinServiceGRPCServer { +func NewJoinServiceGRPCServer(joinServiceClient joinServiceClient) *JoinServiceGRPCServer { return &JoinServiceGRPCServer{ joinServiceClient: joinServiceClient, clock: clockwork.NewRealClock(), @@ -384,23 +383,16 @@ func (s *JoinServiceGRPCServer) registerUsingTPMMethod( // RegisterUsingToken allows nodes and proxies to join the cluster using // legacy join methods which do not yet have their own RPC. -// On the Auth server, this method will call the auth.ServerWithRoles's +// On the Auth server, this method will call the auth.Server's // RegisterUsingToken method. When running on the Proxy, this method will -// forward the request to the Auth server. +// forward the request to the Auth server's JoinServiceServer. func (s *JoinServiceGRPCServer) RegisterUsingToken( ctx context.Context, req *types.RegisterUsingTokenRequest, ) (*proto.Certs, error) { - // We only want to set bot params/client id if we are the proxy, because - // ServerWithRoles currently handles this for us on the auth server. If the - // auth server join service also did this, we would emit concerning warnings - // about potential fakes. - // TODO(strideynet): In v18.0.0, we can move the handling up ServerWithRoles - // and into here once the legacy HTTP endpoint is gone. - if s.isProxy { - if err := setClientRemoteAddr(ctx, req); err != nil { - return nil, trace.Wrap(err, "setting client address") - } - setBotParameters(ctx, req) + if err := setClientRemoteAddr(ctx, req); err != nil { + return nil, trace.Wrap(err, "setting client address") } + setBotParameters(ctx, req) + return s.joinServiceClient.RegisterUsingToken(ctx, req) } diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go index d9a089f45adb7..79865f146c9c3 100644 --- a/lib/joinserver/joinserver_test.go +++ b/lib/joinserver/joinserver_test.go @@ -49,6 +49,7 @@ type mockJoinServiceClient struct { gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest + gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest } func (c *mockJoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) { @@ -85,6 +86,14 @@ func (c *mockJoinServiceClient) RegisterUsingTPMMethod( return c.returnCerts, c.returnError } +func (c *mockJoinServiceClient) RegisterUsingToken( + ctx context.Context, + req *types.RegisterUsingTokenRequest, +) (*proto.Certs, error) { + c.gotRegisterUsingTokenReq = req + return c.returnCerts, c.returnError +} + func ConnectionCountingStreamInterceptor(count *atomic.Int32) grpc.StreamServerInterceptor { return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { count.Add(1) @@ -133,7 +142,7 @@ func newTestPack(t *testing.T) *testPack { // create the first instance of JoinServiceGRPCServer wrapping the mock auth // server, to imitate the JoinServiceGRPCServer which runs on Auth authGRPCServer, authGRPCListener := newGRPCServer(t, grpc.ChainStreamInterceptor(ConnectionCountingStreamInterceptor(streamConnectionCount))) - authServer := NewJoinServiceGRPCServer(mockAuthServer, false) + authServer := NewJoinServiceGRPCServer(mockAuthServer) proto.RegisterJoinServiceServer(authGRPCServer, authServer) // create a client to the "auth" gRPC service @@ -144,7 +153,7 @@ func newTestPack(t *testing.T) *testPack { // create a second instance of JoinServiceGRPCServer wrapping the "auth" // gRPC client, to imitate the JoinServiceGRPCServer which runs on Proxy proxyGRPCServer, proxyGRPCListener := newGRPCServer(t, grpc.ChainStreamInterceptor(ConnectionCountingStreamInterceptor(streamConnectionCount))) - proxyServer := NewJoinServiceGRPCServer(authJoinServiceClient, true) + proxyServer := NewJoinServiceGRPCServer(authJoinServiceClient) proto.RegisterJoinServiceServer(proxyGRPCServer, proxyServer) // create a client to the "proxy" gRPC service diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 8e1e4ba43f691..d2957b56e2ec0 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -4067,6 +4067,8 @@ type eventsListGetResponse struct { // hostCredentials sends a registration token and metadata to the Auth Server // and gets back SSH and TLS certificates. +// TODO(strideynet): DELETE IN V18.0.0 +// Deprecated: Use the RegisterUsingToken RPC instead. func (h *Handler) hostCredentials(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { var req types.RegisterUsingTokenRequest if err := httplib.ReadJSON(r, &req); err != nil {