Skip to content

Commit

Permalink
Rejig JoinServerGRPCServer to invoke auth.Server
Browse files Browse the repository at this point in the history
  • Loading branch information
strideynet committed Oct 23, 2024
1 parent 946fb1c commit 7fdaed1
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 74 deletions.
49 changes: 5 additions & 44 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,12 @@ func (a *ServerWithRoles) GetClusterCACert(
return a.authServer.GetClusterCACert(ctx)
}

// Deprecated: This method only exists to service the RegisterUsingToken HTTP
// RPC, which has been replaced by an RPC on the JoinServiceServer.
// JoinServiceServer directly invokes auth.Server and performs its own checks
// on metadata.
// TODO(strideynet): DELETE IN V18.0.0
func (a *ServerWithRoles) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
// TODO(strideynet): In v18.0.0, this logic can be moved into
// JoinServiceGRPCServer.
isProxy := a.hasBuiltinRole(types.RoleProxy)

// We do not trust remote addr in the request unless it's coming from the Proxy.
Expand Down Expand Up @@ -617,48 +620,6 @@ func (a *ServerWithRoles) RegisterUsingToken(ctx context.Context, req *types.Reg
return a.authServer.RegisterUsingToken(ctx, req)
}

// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
// See (*Server).RegisterUsingIAMMethod for further documentation.
//
// This wrapper does not do any extra authz checks, as the register method has
// its own authz mechanism.
func (a *ServerWithRoles) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) {
certs, err := a.authServer.RegisterUsingIAMMethod(ctx, challengeResponse)
return certs, trace.Wrap(err)
}

// RegisterUsingAzureMethod registers the caller using the Azure join method and
// returns signed certs to join the cluster.
//
// See (*Server).RegisterUsingAzureMethod for further documentation.
//
// This wrapper does not do any extra authz checks, as the register method has
// its own authz mechanism.
func (a *ServerWithRoles) RegisterUsingAzureMethod(ctx context.Context, challengeResponse client.RegisterAzureChallengeResponseFunc) (*proto.Certs, error) {
certs, err := a.authServer.RegisterUsingAzureMethod(ctx, challengeResponse)
return certs, trace.Wrap(err)
}

// RegisterUsingTPMMethod registers the caller using the TPM join method and
// returns signed certs to join the cluster.
//
// See (*Server).RegisterUsingTPMMethod for further documentation.
//
// This wrapper does not do any extra authz checks, as the register method has
// its own authz mechanism.
func (a *ServerWithRoles) RegisterUsingTPMMethod(
ctx context.Context,
initReq *proto.RegisterUsingTPMMethodInitialRequest,
solveChallenge client.RegisterTPMChallengeResponseFunc,
) (*proto.Certs, error) {
certs, err := a.authServer.registerUsingTPMMethod(
ctx, initReq, solveChallenge,
)
return certs, trace.Wrap(err)
}

// GenerateHostCerts generates new host certificates (signed
// by the host certificate authority) for a node.
func (a *ServerWithRoles) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequest) (*proto.Certs, error) {
Expand Down
7 changes: 1 addition & 6 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5164,12 +5164,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
}
trustv1pb.RegisterTrustServiceServer(server, trust)

// create server with no-op role to pass to JoinService server
serverWithNopRole, err := serverWithNopRole(cfg)
if err != nil {
return nil, trace.Wrap(err)
}
joinServiceServer := joinserver.NewJoinServiceGRPCServer(serverWithNopRole, false)
joinServiceServer := joinserver.NewJoinServiceGRPCServer(cfg.AuthServer)
authpb.RegisterJoinServiceServer(server, joinServiceServer)

integrationServiceServer, err := integrationv1.NewService(&integrationv1.ServiceConfig{
Expand Down
17 changes: 15 additions & 2 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,13 @@ func generateAzureChallenge() (string, error) {
return challenge, trace.Wrap(err)
}

// RegisterUsingAzureMethod registers the caller using the Azure join method
// RegisterUsingAzureMethodWithOpts registers the caller using the Azure join method
// and returns signed certs to join the cluster.
//
// The caller must provide a ChallengeResponseFunc which returns a
// *proto.RegisterUsingAzureMethodRequest with a signed attested data document
// including the challenge as a nonce.
func (a *Server) RegisterUsingAzureMethod(
func (a *Server) RegisterUsingAzureMethodWithOpts(
ctx context.Context,
challengeResponse client.RegisterAzureChallengeResponseFunc,
opts ...azureRegisterOption,
Expand Down Expand Up @@ -422,6 +422,19 @@ func (a *Server) RegisterUsingAzureMethod(
return certs, trace.Wrap(err)
}

// RegisterUsingAzureMethod registers the caller using the Azure join method
// and returns signed certs to join the cluster.
//
// The caller must provide a ChallengeResponseFunc which returns a
// *proto.RegisterUsingAzureMethodRequest with a signed attested data document
// including the challenge as a nonce.
func (a *Server) RegisterUsingAzureMethod(
ctx context.Context,
challengeResponse client.RegisterAzureChallengeResponseFunc,
) (certs *proto.Certs, err error) {
return a.RegisterUsingAzureMethodWithOpts(ctx, challengeResponse)
}

// fixAzureSigningAlgorithm fixes a mismatch between the object IDs of the
// hashing algorithm sent by Azure vs the ones expected by the pkcs7 library.
// Specifically, Azure (incorrectly?) sends a [digest encryption algorithm]
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/join_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {

vmClient := &mockAzureVMClient{vm: vmResult}

_, err = a.RegisterUsingAzureMethod(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
_, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
cfg := &azureChallengeResponseConfig{Challenge: challenge}
for _, opt := range tc.challengeResponseOptions {
opt(cfg)
Expand Down
17 changes: 15 additions & 2 deletions lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,13 @@ func withFips(fips bool) iamRegisterOption {
}
}

// RegisterUsingIAMMethod registers the caller using the IAM join method and
// RegisterUsingIAMMethodWithOpts registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
// The caller must provide a ChallengeResponseFunc which returns a
// *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request
// including the challenge as a signed header.
func (a *Server) RegisterUsingIAMMethod(
func (a *Server) RegisterUsingIAMMethodWithOpts(
ctx context.Context,
challengeResponse client.RegisterIAMChallengeResponseFunc,
opts ...iamRegisterOption,
Expand Down Expand Up @@ -388,3 +388,16 @@ func (a *Server) RegisterUsingIAMMethod(
certs, err = a.generateCerts(ctx, provisionToken, req.RegisterUsingTokenRequest, nil)
return certs, trace.Wrap(err, "generating certs")
}

// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
// The caller must provide a ChallengeResponseFunc which returns a
// *types.RegisterUsingTokenRequest with a signed sts:GetCallerIdentity request
// including the challenge as a signed header.
func (a *Server) RegisterUsingIAMMethod(
ctx context.Context,
challengeResponse client.RegisterIAMChallengeResponseFunc,
) (certs *proto.Certs, err error) {
return a.RegisterUsingIAMMethodWithOpts(ctx, challengeResponse)
}
2 changes: 1 addition & 1 deletion lib/auth/join_iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) {
require.NoError(t, a.DeleteToken(ctx, token.GetName()))
}()

_, err = a.RegisterUsingIAMMethod(context.Background(), func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
_, err = a.RegisterUsingIAMMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
templateInput := defaultIdentityRequestTemplateInput(challenge)
for _, opt := range tc.challengeResponseOptions {
opt(&templateInput)
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/join_tpm.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
"github.com/gravitational/teleport/lib/tpm"
)

func (a *Server) registerUsingTPMMethod(
func (a *Server) RegisterUsingTPMMethod(
ctx context.Context,
initReq *proto.RegisterUsingTPMMethodInitialRequest,
solveChallenge client.RegisterTPMChallengeResponseFunc,
Expand Down
22 changes: 7 additions & 15 deletions lib/joinserver/joinserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,12 @@ type joinServiceClient interface {
type JoinServiceGRPCServer struct {
proto.UnimplementedJoinServiceServer

isProxy bool
joinServiceClient joinServiceClient
clock clockwork.Clock
}

// NewJoinServiceGRPCServer returns a new JoinServiceGRPCServer.
func NewJoinServiceGRPCServer(joinServiceClient joinServiceClient, isProxy bool) *JoinServiceGRPCServer {
func NewJoinServiceGRPCServer(joinServiceClient joinServiceClient) *JoinServiceGRPCServer {
return &JoinServiceGRPCServer{
joinServiceClient: joinServiceClient,
clock: clockwork.NewRealClock(),
Expand Down Expand Up @@ -384,23 +383,16 @@ func (s *JoinServiceGRPCServer) registerUsingTPMMethod(

// RegisterUsingToken allows nodes and proxies to join the cluster using
// legacy join methods which do not yet have their own RPC.
// On the Auth server, this method will call the auth.ServerWithRoles's
// On the Auth server, this method will call the auth.Server's
// RegisterUsingToken method. When running on the Proxy, this method will
// forward the request to the Auth server.
// forward the request to the Auth server's JoinServiceServer.
func (s *JoinServiceGRPCServer) RegisterUsingToken(
ctx context.Context, req *types.RegisterUsingTokenRequest,
) (*proto.Certs, error) {
// We only want to set bot params/client id if we are the proxy, because
// ServerWithRoles currently handles this for us on the auth server. If the
// auth server join service also did this, we would emit concerning warnings
// about potential fakes.
// TODO(strideynet): In v18.0.0, we can move the handling up ServerWithRoles
// and into here once the legacy HTTP endpoint is gone.
if s.isProxy {
if err := setClientRemoteAddr(ctx, req); err != nil {
return nil, trace.Wrap(err, "setting client address")
}
setBotParameters(ctx, req)
if err := setClientRemoteAddr(ctx, req); err != nil {
return nil, trace.Wrap(err, "setting client address")
}
setBotParameters(ctx, req)

return s.joinServiceClient.RegisterUsingToken(ctx, req)
}
13 changes: 11 additions & 2 deletions lib/joinserver/joinserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type mockJoinServiceClient struct {
gotAzureChallengeResponse *proto.RegisterUsingAzureMethodRequest
gotTPMChallengeResponse *proto.RegisterUsingTPMMethodChallengeResponse
gotTPMInitReq *proto.RegisterUsingTPMMethodInitialRequest
gotRegisterUsingTokenReq *types.RegisterUsingTokenRequest
}

func (c *mockJoinServiceClient) RegisterUsingIAMMethod(ctx context.Context, challengeResponse client.RegisterIAMChallengeResponseFunc) (*proto.Certs, error) {
Expand Down Expand Up @@ -85,6 +86,14 @@ func (c *mockJoinServiceClient) RegisterUsingTPMMethod(
return c.returnCerts, c.returnError
}

func (c *mockJoinServiceClient) RegisterUsingToken(
ctx context.Context,
req *types.RegisterUsingTokenRequest,
) (*proto.Certs, error) {
c.gotRegisterUsingTokenReq = req
return c.returnCerts, c.returnError
}

func ConnectionCountingStreamInterceptor(count *atomic.Int32) grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
count.Add(1)
Expand Down Expand Up @@ -133,7 +142,7 @@ func newTestPack(t *testing.T) *testPack {
// create the first instance of JoinServiceGRPCServer wrapping the mock auth
// server, to imitate the JoinServiceGRPCServer which runs on Auth
authGRPCServer, authGRPCListener := newGRPCServer(t, grpc.ChainStreamInterceptor(ConnectionCountingStreamInterceptor(streamConnectionCount)))
authServer := NewJoinServiceGRPCServer(mockAuthServer, false)
authServer := NewJoinServiceGRPCServer(mockAuthServer)
proto.RegisterJoinServiceServer(authGRPCServer, authServer)

// create a client to the "auth" gRPC service
Expand All @@ -144,7 +153,7 @@ func newTestPack(t *testing.T) *testPack {
// create a second instance of JoinServiceGRPCServer wrapping the "auth"
// gRPC client, to imitate the JoinServiceGRPCServer which runs on Proxy
proxyGRPCServer, proxyGRPCListener := newGRPCServer(t, grpc.ChainStreamInterceptor(ConnectionCountingStreamInterceptor(streamConnectionCount)))
proxyServer := NewJoinServiceGRPCServer(authJoinServiceClient, true)
proxyServer := NewJoinServiceGRPCServer(authJoinServiceClient)
proto.RegisterJoinServiceServer(proxyGRPCServer, proxyServer)

// create a client to the "proxy" gRPC service
Expand Down
2 changes: 2 additions & 0 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4067,6 +4067,8 @@ type eventsListGetResponse struct {

// hostCredentials sends a registration token and metadata to the Auth Server
// and gets back SSH and TLS certificates.
// TODO(strideynet): DELETE IN V18.0.0
// Deprecated: Use the RegisterUsingToken RPC instead.
func (h *Handler) hostCredentials(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
var req types.RegisterUsingTokenRequest
if err := httplib.ReadJSON(r, &req); err != nil {
Expand Down

0 comments on commit 7fdaed1

Please sign in to comment.