From 1d57fabca51f8ba6737a1e0cdf04acdea2693cc1 Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Fri, 26 Jul 2024 11:58:12 +0100 Subject: [PATCH 1/3] [devicetrust] fix: handle server errors in bi-directional streams Fix an edge case in the devicetrust gRPC bi-directional stream handler where the stream could terminate with an error after the last client `Recv` call but before the next `Send` call. Previously, this scenario caused the `Send` method to return an `io.EOF` error, which indicated that the connection was terminated but did not reveal any errors returned by the server. This update ensures that the client continues to consume the stream until an error is returned, allowing for proper error handling and more robust stream management. Signed-off-by: Tiago Silva --- api/utils/stream_consumer.go | 34 +++++++++ api/utils/stream_consumer_test.go | 83 ++++++++++++++++++++++ lib/devicetrust/authn/authn.go | 17 ++++- lib/devicetrust/enroll/enroll.go | 17 ++++- lib/secretsscanner/reporter/env_test.go | 12 ++++ lib/secretsscanner/reporter/report.go | 7 +- lib/secretsscanner/reporter/report_test.go | 10 +++ 7 files changed, 173 insertions(+), 7 deletions(-) create mode 100644 api/utils/stream_consumer.go create mode 100644 api/utils/stream_consumer_test.go diff --git a/api/utils/stream_consumer.go b/api/utils/stream_consumer.go new file mode 100644 index 0000000000000..c0f48087f5218 --- /dev/null +++ b/api/utils/stream_consumer.go @@ -0,0 +1,34 @@ +package utils + +import ( + "errors" + "io" + + "github.com/gravitational/trace" +) + +// ClientReceiver is an interface for receiving messages from a gRPC stream. +type ClientReceiver[T any] interface { + // Recv reads the next message from the stream. + Recv() (T, error) +} + +// ConsumeStreamToErrorIfEOF reads from the gRPC bi-directional stream until an error is encountered if +// the sendErr is io.EOF. If the sendErr is not io.EOF, it is returned immediately because the error is not +// from the server - it is a client error and server did not send any response yet which will cause Recv to block. +// This function should be used when the client encounters an error while sending a message to the stream +// and wants to surface the server's error. +// gRPC never returns the server's error when calling Send function, instead client has to call Recv to get the error. +// It might need to call Recv multiple times to get the error if client's buffer has other messages from server. +func ConsumeStreamToErrorIfEOF[T any](sendErr error, stream ClientReceiver[T]) error { + // If the error is not EOF, return it immediately. + if !errors.Is(sendErr, io.EOF) { + return sendErr + } + for { + _, err := stream.Recv() + if err != nil { + return trace.Wrap(err) + } + } +} diff --git a/api/utils/stream_consumer_test.go b/api/utils/stream_consumer_test.go new file mode 100644 index 0000000000000..e8da65bea9b3d --- /dev/null +++ b/api/utils/stream_consumer_test.go @@ -0,0 +1,83 @@ +package utils + +import ( + "errors" + "io" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestConsumeStreamToErrorIfEOF(t *testing.T) { + type args struct { + sendErr error + stream ClientReceiver[string] + } + type testCase struct { + name string + args args + assertErr require.ErrorAssertionFunc + } + tests := []testCase{ + { + name: "send error is nil", /* this is a special case to avoid locks */ + args: args{ + sendErr: nil, + stream: &fakeClientReceiver[string]{}, + }, + assertErr: func(t require.TestingT, err error, i ...any) { + require.NoError(t, err) + }, + }, + { + name: "send error is not EOF", + args: args{ + sendErr: errors.New("fake send error"), + stream: &fakeClientReceiver[string]{}, + }, + assertErr: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "fake send error") + }, + }, + { + name: "send error is EOF and stream returns err immediately", + args: args{ + sendErr: trace.Wrap(io.EOF), + stream: &fakeClientReceiver[string]{numberOfCallsToErr: 1}, + }, + assertErr: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "fake error") + }, + }, + { + name: "send error is EOF and stream returns err after 10 calls", + args: args{ + sendErr: trace.Wrap(io.EOF), + stream: &fakeClientReceiver[string]{numberOfCallsToErr: 10}, + }, + assertErr: func(t require.TestingT, err error, i ...any) { + require.ErrorContains(t, err, "fake error") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ConsumeStreamToErrorIfEOF(tt.args.sendErr, tt.args.stream) + tt.assertErr(t, err) + }) + } +} + +type fakeClientReceiver[T any] struct { + numberOfCallsToErr int +} + +func (f *fakeClientReceiver[T]) Recv() (T, error) { + var v T + f.numberOfCallsToErr-- + if f.numberOfCallsToErr == 0 { + return v, errors.New("fake error") + } + return v, nil +} diff --git a/lib/devicetrust/authn/authn.go b/lib/devicetrust/authn/authn.go index ebb1140d743a7..b125b1f11b894 100644 --- a/lib/devicetrust/authn/authn.go +++ b/lib/devicetrust/authn/authn.go @@ -24,10 +24,15 @@ import ( "github.com/gravitational/trace" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/devicetrust" "github.com/gravitational/teleport/lib/devicetrust/native" ) +// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until +// it finds the first error. +var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*devicepb.AuthenticateDeviceResponse] + // Ceremony is the device authentication ceremony. // It takes the client role of // [devicepb.DeviceTrustServiceClient.AuthenticateDevice] @@ -154,7 +159,7 @@ func (c *Ceremony) run( Init: init, }, }); err != nil { - return nil, trace.Wrap(devicetrust.HandleUnimplemented(err)) + return nil, trace.Wrap(devicetrust.HandleUnimplemented(consumeStreamToErrorIfEOFFunc(err, stream))) } resp, err := stream.Recv() if err != nil { @@ -203,7 +208,10 @@ func (c *Ceremony) authenticateDeviceMacOS( }, }, }) - return trace.Wrap(err) + if err != nil { + return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + } + return nil } func (c *Ceremony) authenticateDeviceTPM( @@ -223,5 +231,8 @@ func (c *Ceremony) authenticateDeviceTPM( TpmChallengeResponse: challengeResponse, }, }) - return trace.Wrap(err) + if err != nil { + return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + } + return nil } diff --git a/lib/devicetrust/enroll/enroll.go b/lib/devicetrust/enroll/enroll.go index 0e7175a99f25e..c022d93981350 100644 --- a/lib/devicetrust/enroll/enroll.go +++ b/lib/devicetrust/enroll/enroll.go @@ -26,10 +26,15 @@ import ( log "github.com/sirupsen/logrus" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/devicetrust" "github.com/gravitational/teleport/lib/devicetrust/native" ) +// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until +// it finds the first error. +var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*devicepb.EnrollDeviceResponse] + // Ceremony is the device enrollment ceremony. // It takes the client role of // [devicepb.DeviceTrustServiceClient.EnrollDevice]. @@ -183,7 +188,7 @@ func (c *Ceremony) Run(ctx context.Context, devicesClient devicepb.DeviceTrustSe Init: init, }, }); err != nil { - return nil, trace.Wrap(devicetrust.HandleUnimplemented(err)) + return nil, trace.Wrap(devicetrust.HandleUnimplemented(consumeStreamToErrorIfEOFFunc(err, stream))) } resp, err := stream.Recv() if err != nil { @@ -236,7 +241,10 @@ func (c *Ceremony) enrollDeviceMacOS(stream devicepb.DeviceTrustService_EnrollDe }, }, }) - return trace.Wrap(err) + if err != nil { + return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + } + return nil } func (c *Ceremony) enrollDeviceTPM(ctx context.Context, stream devicepb.DeviceTrustService_EnrollDeviceClient, resp *devicepb.EnrollDeviceResponse, debug bool) error { @@ -259,5 +267,8 @@ func (c *Ceremony) enrollDeviceTPM(ctx context.Context, stream devicepb.DeviceTr TpmChallengeResponse: challengeResponse, }, }) - return trace.Wrap(err) + if err != nil { + return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + } + return nil } diff --git a/lib/secretsscanner/reporter/env_test.go b/lib/secretsscanner/reporter/env_test.go index 04ffebb5c6711..734260c3c81af 100644 --- a/lib/secretsscanner/reporter/env_test.go +++ b/lib/secretsscanner/reporter/env_test.go @@ -43,6 +43,7 @@ type env struct { type opts struct { device *device + preAssertError error preReconcileError error } @@ -68,6 +69,12 @@ func withPreReconcileError(err error) option { } } +func withPreAssertError(err error) option { + return func(o *opts) { + o.preAssertError = err + } +} + func setup(t *testing.T, ops ...option) env { t.Helper() @@ -91,6 +98,7 @@ func setup(t *testing.T, ops ...option) env { svc := newServiceFake(dtFakeSvc.Service) svc.preReconcileError = o.preReconcileError + svc.preAssertError = o.preAssertError tlsConfig, err := fixtures.LocalTLSConfig() require.NoError(t, err) @@ -130,9 +138,13 @@ type serviceFake struct { privateKeysReported []*accessgraphsecretsv1pb.PrivateKey deviceTrustSvc *dttestenv.FakeDeviceService preReconcileError error + preAssertError error } func (s *serviceFake) ReportSecrets(in accessgraphsecretsv1pb.SecretsScannerService_ReportSecretsServer) error { + if s.preAssertError != nil { + return s.preAssertError + } // Step 1. Assert the device. if _, err := s.deviceTrustSvc.AssertDevice(in.Context(), streamAdapter{stream: in}); err != nil { return trace.Wrap(err) diff --git a/lib/secretsscanner/reporter/report.go b/lib/secretsscanner/reporter/report.go index 22dd8c22b9d3b..4ca50d93fdd09 100644 --- a/lib/secretsscanner/reporter/report.go +++ b/lib/secretsscanner/reporter/report.go @@ -28,10 +28,15 @@ import ( accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" + apiutils "github.com/gravitational/teleport/api/utils" dtassert "github.com/gravitational/teleport/lib/devicetrust/assert" secretsscannerclient "github.com/gravitational/teleport/lib/secretsscanner/client" ) +// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until +// it finds the first error. +var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*accessgraphsecretsv1pb.ReportSecretsResponse] + // AssertCeremonyBuilderFunc is a function that builds the device authentication ceremony. type AssertCeremonyBuilderFunc func() (*dtassert.Ceremony, error) @@ -139,7 +144,7 @@ func (r *Reporter) reportPrivateKeys(stream accessgraphsecretsv1pb.SecretsScanne }, }, }); err != nil { - return trace.Wrap(err, "failed to send private keys") + return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream), "failed to send private keys") } } return nil diff --git a/lib/secretsscanner/reporter/report_test.go b/lib/secretsscanner/reporter/report_test.go index 76c2d661cd116..afc63df273899 100644 --- a/lib/secretsscanner/reporter/report_test.go +++ b/lib/secretsscanner/reporter/report_test.go @@ -52,6 +52,7 @@ func TestReporter(t *testing.T) { tests := []struct { name string preReconcileError error + preAssertError error assertErr require.ErrorAssertionFunc report []*accessgraphsecretsv1pb.PrivateKey want []*accessgraphsecretsv1pb.PrivateKey @@ -62,6 +63,14 @@ func TestReporter(t *testing.T) { want: newPrivateKeys(t, deviceID), assertErr: require.NoError, }, + { + name: "pre-assert error", + preAssertError: errors.New("pre-assert error"), + report: newPrivateKeys(t, deviceID), + assertErr: func(t require.TestingT, err error, _ ...any) { + require.ErrorContains(t, err, "pre-assert error") + }, + }, { name: "pre-reconcile error", preReconcileError: errors.New("pre-reconcile error"), @@ -80,6 +89,7 @@ func TestReporter(t *testing.T) { t, withDevice(deviceID, device), withPreReconcileError(tt.preReconcileError), + withPreAssertError(tt.preAssertError), ) ctx := context.Background() From 7988cbc87684f6e3da0e1baf844f7350bddee28b Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Fri, 26 Jul 2024 15:47:58 +0100 Subject: [PATCH 2/3] add license headers --- api/utils/stream_consumer.go | 18 ++++++++++++++++++ api/utils/stream_consumer_test.go | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/api/utils/stream_consumer.go b/api/utils/stream_consumer.go index c0f48087f5218..4049672b222cd 100644 --- a/api/utils/stream_consumer.go +++ b/api/utils/stream_consumer.go @@ -1,3 +1,21 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + package utils import ( diff --git a/api/utils/stream_consumer_test.go b/api/utils/stream_consumer_test.go index e8da65bea9b3d..d822a966c2b67 100644 --- a/api/utils/stream_consumer_test.go +++ b/api/utils/stream_consumer_test.go @@ -1,3 +1,21 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + package utils import ( From cf739f5b3d890afdf5f5edb7ad1845dc045fa14d Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Mon, 29 Jul 2024 15:01:12 +0100 Subject: [PATCH 3/3] drop helper and local skip io.EOF errors --- api/utils/stream_consumer.go | 52 ------------- api/utils/stream_consumer_test.go | 101 -------------------------- lib/devicetrust/authn/authn.go | 28 ++++--- lib/devicetrust/enroll/enroll.go | 28 ++++--- lib/secretsscanner/reporter/report.go | 12 ++- 5 files changed, 39 insertions(+), 182 deletions(-) delete mode 100644 api/utils/stream_consumer.go delete mode 100644 api/utils/stream_consumer_test.go diff --git a/api/utils/stream_consumer.go b/api/utils/stream_consumer.go deleted file mode 100644 index 4049672b222cd..0000000000000 --- a/api/utils/stream_consumer.go +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Teleport - * Copyright (C) 2024 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package utils - -import ( - "errors" - "io" - - "github.com/gravitational/trace" -) - -// ClientReceiver is an interface for receiving messages from a gRPC stream. -type ClientReceiver[T any] interface { - // Recv reads the next message from the stream. - Recv() (T, error) -} - -// ConsumeStreamToErrorIfEOF reads from the gRPC bi-directional stream until an error is encountered if -// the sendErr is io.EOF. If the sendErr is not io.EOF, it is returned immediately because the error is not -// from the server - it is a client error and server did not send any response yet which will cause Recv to block. -// This function should be used when the client encounters an error while sending a message to the stream -// and wants to surface the server's error. -// gRPC never returns the server's error when calling Send function, instead client has to call Recv to get the error. -// It might need to call Recv multiple times to get the error if client's buffer has other messages from server. -func ConsumeStreamToErrorIfEOF[T any](sendErr error, stream ClientReceiver[T]) error { - // If the error is not EOF, return it immediately. - if !errors.Is(sendErr, io.EOF) { - return sendErr - } - for { - _, err := stream.Recv() - if err != nil { - return trace.Wrap(err) - } - } -} diff --git a/api/utils/stream_consumer_test.go b/api/utils/stream_consumer_test.go deleted file mode 100644 index d822a966c2b67..0000000000000 --- a/api/utils/stream_consumer_test.go +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Teleport - * Copyright (C) 2024 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package utils - -import ( - "errors" - "io" - "testing" - - "github.com/gravitational/trace" - "github.com/stretchr/testify/require" -) - -func TestConsumeStreamToErrorIfEOF(t *testing.T) { - type args struct { - sendErr error - stream ClientReceiver[string] - } - type testCase struct { - name string - args args - assertErr require.ErrorAssertionFunc - } - tests := []testCase{ - { - name: "send error is nil", /* this is a special case to avoid locks */ - args: args{ - sendErr: nil, - stream: &fakeClientReceiver[string]{}, - }, - assertErr: func(t require.TestingT, err error, i ...any) { - require.NoError(t, err) - }, - }, - { - name: "send error is not EOF", - args: args{ - sendErr: errors.New("fake send error"), - stream: &fakeClientReceiver[string]{}, - }, - assertErr: func(t require.TestingT, err error, i ...any) { - require.ErrorContains(t, err, "fake send error") - }, - }, - { - name: "send error is EOF and stream returns err immediately", - args: args{ - sendErr: trace.Wrap(io.EOF), - stream: &fakeClientReceiver[string]{numberOfCallsToErr: 1}, - }, - assertErr: func(t require.TestingT, err error, i ...any) { - require.ErrorContains(t, err, "fake error") - }, - }, - { - name: "send error is EOF and stream returns err after 10 calls", - args: args{ - sendErr: trace.Wrap(io.EOF), - stream: &fakeClientReceiver[string]{numberOfCallsToErr: 10}, - }, - assertErr: func(t require.TestingT, err error, i ...any) { - require.ErrorContains(t, err, "fake error") - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := ConsumeStreamToErrorIfEOF(tt.args.sendErr, tt.args.stream) - tt.assertErr(t, err) - }) - } -} - -type fakeClientReceiver[T any] struct { - numberOfCallsToErr int -} - -func (f *fakeClientReceiver[T]) Recv() (T, error) { - var v T - f.numberOfCallsToErr-- - if f.numberOfCallsToErr == 0 { - return v, errors.New("fake error") - } - return v, nil -} diff --git a/lib/devicetrust/authn/authn.go b/lib/devicetrust/authn/authn.go index b125b1f11b894..80a875cb92123 100644 --- a/lib/devicetrust/authn/authn.go +++ b/lib/devicetrust/authn/authn.go @@ -20,19 +20,16 @@ package authn import ( "context" + "errors" + "io" "github.com/gravitational/trace" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" - apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/devicetrust" "github.com/gravitational/teleport/lib/devicetrust/native" ) -// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until -// it finds the first error. -var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*devicepb.AuthenticateDeviceResponse] - // Ceremony is the device authentication ceremony. // It takes the client role of // [devicepb.DeviceTrustServiceClient.AuthenticateDevice] @@ -158,8 +155,11 @@ func (c *Ceremony) run( Payload: &devicepb.AuthenticateDeviceRequest_Init{ Init: init, }, - }); err != nil { - return nil, trace.Wrap(devicetrust.HandleUnimplemented(consumeStreamToErrorIfEOFFunc(err, stream))) + }); err != nil && !errors.Is(err, io.EOF) { + // [io.EOF] indicates that the server has closed the stream. + // The client should handle the underlying error on the subsequent Recv call. + // All other errors are client-side errors and should be returned. + return nil, trace.Wrap(devicetrust.HandleUnimplemented(err)) } resp, err := stream.Recv() if err != nil { @@ -208,8 +208,11 @@ func (c *Ceremony) authenticateDeviceMacOS( }, }, }) - if err != nil { - return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + if err != nil && !errors.Is(err, io.EOF) { + // [io.EOF] indicates that the server has closed the stream. + // The client should handle the underlying error on the subsequent Recv call. + // All other errors are client-side errors and should be returned. + return trace.Wrap(err) } return nil } @@ -231,8 +234,11 @@ func (c *Ceremony) authenticateDeviceTPM( TpmChallengeResponse: challengeResponse, }, }) - if err != nil { - return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + if err != nil && !errors.Is(err, io.EOF) { + // [io.EOF] indicates that the server has closed the stream. + // The client should handle the underlying error on the subsequent Recv call. + // All other errors are client-side errors and should be returned. + return trace.Wrap(err) } return nil } diff --git a/lib/devicetrust/enroll/enroll.go b/lib/devicetrust/enroll/enroll.go index c022d93981350..dd733aa39c0d2 100644 --- a/lib/devicetrust/enroll/enroll.go +++ b/lib/devicetrust/enroll/enroll.go @@ -20,21 +20,18 @@ package enroll import ( "context" + "errors" + "io" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" log "github.com/sirupsen/logrus" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" - apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/devicetrust" "github.com/gravitational/teleport/lib/devicetrust/native" ) -// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until -// it finds the first error. -var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*devicepb.EnrollDeviceResponse] - // Ceremony is the device enrollment ceremony. // It takes the client role of // [devicepb.DeviceTrustServiceClient.EnrollDevice]. @@ -187,8 +184,11 @@ func (c *Ceremony) Run(ctx context.Context, devicesClient devicepb.DeviceTrustSe Payload: &devicepb.EnrollDeviceRequest_Init{ Init: init, }, - }); err != nil { - return nil, trace.Wrap(devicetrust.HandleUnimplemented(consumeStreamToErrorIfEOFFunc(err, stream))) + }); err != nil && !errors.Is(err, io.EOF) { + // [io.EOF] indicates that the server has closed the stream. + // The client should handle the underlying error on the subsequent Recv call. + // All other errors are client-side errors and should be returned. + return nil, trace.Wrap(devicetrust.HandleUnimplemented(err)) } resp, err := stream.Recv() if err != nil { @@ -241,8 +241,11 @@ func (c *Ceremony) enrollDeviceMacOS(stream devicepb.DeviceTrustService_EnrollDe }, }, }) - if err != nil { - return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + if err != nil && !errors.Is(err, io.EOF) { + // [io.EOF] indicates that the server has closed the stream. + // The client should handle the underlying error on the subsequent Recv call. + // All other errors are client-side errors and should be returned. + return trace.Wrap(err) } return nil } @@ -267,8 +270,11 @@ func (c *Ceremony) enrollDeviceTPM(ctx context.Context, stream devicepb.DeviceTr TpmChallengeResponse: challengeResponse, }, }) - if err != nil { - return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream)) + // [io.EOF] indicates that the server has closed the stream. + // The client should handle the underlying error on the subsequent Recv call. + // All other errors are client-side errors and should be returned. + if err != nil && !errors.Is(err, io.EOF) { + return trace.Wrap(err) } return nil } diff --git a/lib/secretsscanner/reporter/report.go b/lib/secretsscanner/reporter/report.go index 4ca50d93fdd09..09e0bfe0ba9e3 100644 --- a/lib/secretsscanner/reporter/report.go +++ b/lib/secretsscanner/reporter/report.go @@ -28,15 +28,10 @@ import ( accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" - apiutils "github.com/gravitational/teleport/api/utils" dtassert "github.com/gravitational/teleport/lib/devicetrust/assert" secretsscannerclient "github.com/gravitational/teleport/lib/secretsscanner/client" ) -// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until -// it finds the first error. -var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*accessgraphsecretsv1pb.ReportSecretsResponse] - // AssertCeremonyBuilderFunc is a function that builds the device authentication ceremony. type AssertCeremonyBuilderFunc func() (*dtassert.Ceremony, error) @@ -143,8 +138,11 @@ func (r *Reporter) reportPrivateKeys(stream accessgraphsecretsv1pb.SecretsScanne Keys: privateKeys[start:end], }, }, - }); err != nil { - return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream), "failed to send private keys") + }); err != nil && !errors.Is(err, io.EOF) { + // [io.EOF] indicates that the server has closed the stream. + // The client should handle the underlying error on the subsequent Recv call. + // All other errors are client-side errors and should be returned. + return trace.Wrap(err, "failed to send private keys") } } return nil