Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[devicetrust] fix: handle server errors in bi-directional streams #44677

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions api/utils/stream_consumer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package utils

import (
"errors"
"io"

"github.com/gravitational/trace"
)

// ClientReceiver is an interface for receiving messages from a gRPC stream.
type ClientReceiver[T any] interface {
tigrato marked this conversation as resolved.
Show resolved Hide resolved
// Recv reads the next message from the stream.
Recv() (T, error)
}

// ConsumeStreamToErrorIfEOF reads from the gRPC bi-directional stream until an error is encountered if
// the sendErr is io.EOF. If the sendErr is not io.EOF, it is returned immediately because the error is not
// from the server - it is a client error and server did not send any response yet which will cause Recv to block.
// This function should be used when the client encounters an error while sending a message to the stream
// and wants to surface the server's error.
// gRPC never returns the server's error when calling Send function, instead client has to call Recv to get the error.
// It might need to call Recv multiple times to get the error if client's buffer has other messages from server.
codingllama marked this conversation as resolved.
Show resolved Hide resolved
func ConsumeStreamToErrorIfEOF[T any](sendErr error, stream ClientReceiver[T]) error {
tigrato marked this conversation as resolved.
Show resolved Hide resolved
// If the error is not EOF, return it immediately.
tigrato marked this conversation as resolved.
Show resolved Hide resolved
if !errors.Is(sendErr, io.EOF) {
return sendErr
}
for {
_, err := stream.Recv()
if err != nil {
return trace.Wrap(err)
}
}
}
101 changes: 101 additions & 0 deletions api/utils/stream_consumer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package utils

import (
"errors"
"io"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

func TestConsumeStreamToErrorIfEOF(t *testing.T) {
type args struct {
sendErr error
stream ClientReceiver[string]
}
type testCase struct {
name string
args args
assertErr require.ErrorAssertionFunc
}
tests := []testCase{
{
name: "send error is nil", /* this is a special case to avoid locks */
args: args{
sendErr: nil,
stream: &fakeClientReceiver[string]{},
},
assertErr: func(t require.TestingT, err error, i ...any) {
require.NoError(t, err)
},
},
{
name: "send error is not EOF",
args: args{
sendErr: errors.New("fake send error"),
stream: &fakeClientReceiver[string]{},
},
assertErr: func(t require.TestingT, err error, i ...any) {
require.ErrorContains(t, err, "fake send error")
},
},
{
name: "send error is EOF and stream returns err immediately",
args: args{
sendErr: trace.Wrap(io.EOF),
stream: &fakeClientReceiver[string]{numberOfCallsToErr: 1},
},
assertErr: func(t require.TestingT, err error, i ...any) {
require.ErrorContains(t, err, "fake error")
},
},
{
name: "send error is EOF and stream returns err after 10 calls",
args: args{
sendErr: trace.Wrap(io.EOF),
stream: &fakeClientReceiver[string]{numberOfCallsToErr: 10},
},
assertErr: func(t require.TestingT, err error, i ...any) {
require.ErrorContains(t, err, "fake error")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ConsumeStreamToErrorIfEOF(tt.args.sendErr, tt.args.stream)
tt.assertErr(t, err)
})
}
}

type fakeClientReceiver[T any] struct {
numberOfCallsToErr int
}

func (f *fakeClientReceiver[T]) Recv() (T, error) {
var v T
f.numberOfCallsToErr--
if f.numberOfCallsToErr == 0 {
return v, errors.New("fake error")
}
return v, nil
}
17 changes: 14 additions & 3 deletions lib/devicetrust/authn/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ import (
"github.com/gravitational/trace"

devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/devicetrust"
"github.com/gravitational/teleport/lib/devicetrust/native"
)

// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until
// it finds the first error.
var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*devicepb.AuthenticateDeviceResponse]

// Ceremony is the device authentication ceremony.
// It takes the client role of
// [devicepb.DeviceTrustServiceClient.AuthenticateDevice]
Expand Down Expand Up @@ -154,7 +159,7 @@ func (c *Ceremony) run(
Init: init,
},
}); err != nil {
return nil, trace.Wrap(devicetrust.HandleUnimplemented(err))
return nil, trace.Wrap(devicetrust.HandleUnimplemented(consumeStreamToErrorIfEOFFunc(err, stream)))
codingllama marked this conversation as resolved.
Show resolved Hide resolved
}
resp, err := stream.Recv()
if err != nil {
Expand Down Expand Up @@ -203,7 +208,10 @@ func (c *Ceremony) authenticateDeviceMacOS(
},
},
})
return trace.Wrap(err)
if err != nil {
return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream))
}
return nil
}

func (c *Ceremony) authenticateDeviceTPM(
Expand All @@ -223,5 +231,8 @@ func (c *Ceremony) authenticateDeviceTPM(
TpmChallengeResponse: challengeResponse,
},
})
return trace.Wrap(err)
if err != nil {
return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream))
}
return nil
}
17 changes: 14 additions & 3 deletions lib/devicetrust/enroll/enroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ import (
log "github.com/sirupsen/logrus"

devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/devicetrust"
"github.com/gravitational/teleport/lib/devicetrust/native"
)

// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until
// it finds the first error.
var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*devicepb.EnrollDeviceResponse]

// Ceremony is the device enrollment ceremony.
// It takes the client role of
// [devicepb.DeviceTrustServiceClient.EnrollDevice].
Expand Down Expand Up @@ -183,7 +188,7 @@ func (c *Ceremony) Run(ctx context.Context, devicesClient devicepb.DeviceTrustSe
Init: init,
},
}); err != nil {
return nil, trace.Wrap(devicetrust.HandleUnimplemented(err))
return nil, trace.Wrap(devicetrust.HandleUnimplemented(consumeStreamToErrorIfEOFFunc(err, stream)))
}
resp, err := stream.Recv()
if err != nil {
Expand Down Expand Up @@ -236,7 +241,10 @@ func (c *Ceremony) enrollDeviceMacOS(stream devicepb.DeviceTrustService_EnrollDe
},
},
})
return trace.Wrap(err)
if err != nil {
return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream))
}
return nil
}

func (c *Ceremony) enrollDeviceTPM(ctx context.Context, stream devicepb.DeviceTrustService_EnrollDeviceClient, resp *devicepb.EnrollDeviceResponse, debug bool) error {
Expand All @@ -259,5 +267,8 @@ func (c *Ceremony) enrollDeviceTPM(ctx context.Context, stream devicepb.DeviceTr
TpmChallengeResponse: challengeResponse,
},
})
return trace.Wrap(err)
if err != nil {
return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream))
}
return nil
}
12 changes: 12 additions & 0 deletions lib/secretsscanner/reporter/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ type env struct {

type opts struct {
device *device
preAssertError error
preReconcileError error
}

Expand All @@ -68,6 +69,12 @@ func withPreReconcileError(err error) option {
}
}

func withPreAssertError(err error) option {
return func(o *opts) {
o.preAssertError = err
}
}

func setup(t *testing.T, ops ...option) env {
t.Helper()

Expand All @@ -91,6 +98,7 @@ func setup(t *testing.T, ops ...option) env {

svc := newServiceFake(dtFakeSvc.Service)
svc.preReconcileError = o.preReconcileError
svc.preAssertError = o.preAssertError

tlsConfig, err := fixtures.LocalTLSConfig()
require.NoError(t, err)
Expand Down Expand Up @@ -130,9 +138,13 @@ type serviceFake struct {
privateKeysReported []*accessgraphsecretsv1pb.PrivateKey
deviceTrustSvc *dttestenv.FakeDeviceService
preReconcileError error
preAssertError error
}

func (s *serviceFake) ReportSecrets(in accessgraphsecretsv1pb.SecretsScannerService_ReportSecretsServer) error {
if s.preAssertError != nil {
return s.preAssertError
}
// Step 1. Assert the device.
if _, err := s.deviceTrustSvc.AssertDevice(in.Context(), streamAdapter{stream: in}); err != nil {
return trace.Wrap(err)
Expand Down
7 changes: 6 additions & 1 deletion lib/secretsscanner/reporter/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,15 @@ import (

accessgraphsecretsv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessgraph/v1"
devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
apiutils "github.com/gravitational/teleport/api/utils"
dtassert "github.com/gravitational/teleport/lib/devicetrust/assert"
secretsscannerclient "github.com/gravitational/teleport/lib/secretsscanner/client"
)

// consumeStreamToErrorIfEOFFunc consumes all messages from the stream until
// it finds the first error.
var consumeStreamToErrorIfEOFFunc = apiutils.ConsumeStreamToErrorIfEOF[*accessgraphsecretsv1pb.ReportSecretsResponse]

// AssertCeremonyBuilderFunc is a function that builds the device authentication ceremony.
type AssertCeremonyBuilderFunc func() (*dtassert.Ceremony, error)

Expand Down Expand Up @@ -139,7 +144,7 @@ func (r *Reporter) reportPrivateKeys(stream accessgraphsecretsv1pb.SecretsScanne
},
},
}); err != nil {
return trace.Wrap(err, "failed to send private keys")
return trace.Wrap(consumeStreamToErrorIfEOFFunc(err, stream), "failed to send private keys")
}
}
return nil
Expand Down
10 changes: 10 additions & 0 deletions lib/secretsscanner/reporter/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func TestReporter(t *testing.T) {
tests := []struct {
name string
preReconcileError error
preAssertError error
assertErr require.ErrorAssertionFunc
report []*accessgraphsecretsv1pb.PrivateKey
want []*accessgraphsecretsv1pb.PrivateKey
Expand All @@ -62,6 +63,14 @@ func TestReporter(t *testing.T) {
want: newPrivateKeys(t, deviceID),
assertErr: require.NoError,
},
{
name: "pre-assert error",
preAssertError: errors.New("pre-assert error"),
report: newPrivateKeys(t, deviceID),
assertErr: func(t require.TestingT, err error, _ ...any) {
require.ErrorContains(t, err, "pre-assert error")
},
},
{
name: "pre-reconcile error",
preReconcileError: errors.New("pre-reconcile error"),
Expand All @@ -80,6 +89,7 @@ func TestReporter(t *testing.T) {
t,
withDevice(deviceID, device),
withPreReconcileError(tt.preReconcileError),
withPreAssertError(tt.preAssertError),
)

ctx := context.Background()
Expand Down
Loading