Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use GRPC interceptors for bearer token #6063

Merged
merged 15 commits into from
Oct 25, 2024
18 changes: 14 additions & 4 deletions cmd/query/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,23 @@ func createGRPCServer(querySvc *querysvc.QueryService, metricsQuerySvc querysvc.

grpcOpts = append(grpcOpts, grpc.Creds(creds))
}

unaryInterceptors := []grpc.UnaryServerInterceptor{
bearertoken.NewUnaryServerInterceptor(),
}
streamInterceptors := []grpc.StreamServerInterceptor{
bearertoken.NewStreamServerInterceptor(),
}
if tm.Enabled {
grpcOpts = append(grpcOpts,
grpc.StreamInterceptor(tenancy.NewGuardingStreamInterceptor(tm)),
grpc.UnaryInterceptor(tenancy.NewGuardingUnaryInterceptor(tm)),
)
unaryInterceptors = append(unaryInterceptors, tenancy.NewGuardingUnaryInterceptor(tm))
streamInterceptors = append(streamInterceptors, tenancy.NewGuardingStreamInterceptor(tm))
}

grpcOpts = append(grpcOpts,
grpc.ChainUnaryInterceptor(unaryInterceptors...),
grpc.ChainStreamInterceptor(streamInterceptors...),
)

server := grpc.NewServer(grpcOpts...)
reflection.Register(server)

Expand Down
19 changes: 15 additions & 4 deletions cmd/remote-storage/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"google.golang.org/grpc/reflection"

"github.com/jaegertracing/jaeger/cmd/query/app/querysvc"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/pkg/telemetery"
"github.com/jaegertracing/jaeger/pkg/tenancy"
"github.com/jaegertracing/jaeger/plugin/storage/grpc/shared"
Expand Down Expand Up @@ -96,13 +97,23 @@ func createGRPCServer(opts *Options, tm *tenancy.Manager, handler *shared.GRPCHa
creds := credentials.NewTLS(tlsCfg)
grpcOpts = append(grpcOpts, grpc.Creds(creds))
}

unaryInterceptors := []grpc.UnaryServerInterceptor{
bearertoken.NewUnaryServerInterceptor(),
}
streamInterceptors := []grpc.StreamServerInterceptor{
bearertoken.NewStreamServerInterceptor(),
}
if tm.Enabled {
grpcOpts = append(grpcOpts,
grpc.StreamInterceptor(tenancy.NewGuardingStreamInterceptor(tm)),
grpc.UnaryInterceptor(tenancy.NewGuardingUnaryInterceptor(tm)),
)
unaryInterceptors = append(unaryInterceptors, tenancy.NewGuardingUnaryInterceptor(tm))
streamInterceptors = append(streamInterceptors, tenancy.NewGuardingStreamInterceptor(tm))
}

grpcOpts = append(grpcOpts,
grpc.ChainUnaryInterceptor(unaryInterceptors...),
grpc.ChainStreamInterceptor(streamInterceptors...),
)

server := grpc.NewServer(grpcOpts...)
yurishkuro marked this conversation as resolved.
Show resolved Hide resolved
healthServer := health.NewServer()
reflection.Register(server)
Expand Down
141 changes: 141 additions & 0 deletions pkg/bearertoken/grpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright (c) 2024 The Jaeger Authors.
// SPDX-License-Identifier: Apache-2.0

package bearertoken

import (
"context"
"fmt"

"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

const Key = "bearer.token"

type tokenatedServerStream struct {
grpc.ServerStream
context context.Context
yurishkuro marked this conversation as resolved.
Show resolved Hide resolved
}

// getValidBearerToken attempts to retrieve the bearer token from the context.
// It does not return an error if the token is missing.
func getValidBearerToken(ctx context.Context, bearerHeader string) (string, error) {
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
bearerToken, ok := GetBearerToken(ctx)
if ok && bearerToken != "" {
return bearerToken, nil
}

Check warning on line 27 in pkg/bearertoken/grpc.go

View check run for this annotation

Codecov / codecov/patch

pkg/bearertoken/grpc.go#L26-L27

Added lines #L26 - L27 were not covered by tests

md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", nil
}

Check warning on line 32 in pkg/bearertoken/grpc.go

View check run for this annotation

Codecov / codecov/patch

pkg/bearertoken/grpc.go#L31-L32

Added lines #L31 - L32 were not covered by tests

tokens := md.Get(bearerHeader)
if len(tokens) < 1 {
return "", nil
}
if len(tokens) > 1 {
return "", fmt.Errorf("malformed token: multiple tokens found")
}
return tokens[0], nil

Check warning on line 41 in pkg/bearertoken/grpc.go

View check run for this annotation

Codecov / codecov/patch

pkg/bearertoken/grpc.go#L38-L41

Added lines #L38 - L41 were not covered by tests
}

// NewStreamServerInterceptor creates a new stream interceptor that injects the bearer token into the context if available.
func NewStreamServerInterceptor() grpc.StreamServerInterceptor {
return streamServerInterceptor
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
}

func streamServerInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
bearerToken, err := getValidBearerToken(ss.Context(), Key)
if err != nil {
return err
}

Check warning on line 53 in pkg/bearertoken/grpc.go

View check run for this annotation

Codecov / codecov/patch

pkg/bearertoken/grpc.go#L52-L53

Added lines #L52 - L53 were not covered by tests
// Upgrade the bearer token to be part of the context.
return handler(srv, &tokenatedServerStream{
ServerStream: ss,
context: ContextWithBearerToken(ss.Context(), bearerToken),
})
}

// NewUnaryServerInterceptor creates a new unary interceptor that injects the bearer token into the context if available.
func NewUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
bearerToken, err := getValidBearerToken(ctx, Key)
if err != nil {
return nil, err
}

Check warning on line 67 in pkg/bearertoken/grpc.go

View check run for this annotation

Codecov / codecov/patch

pkg/bearertoken/grpc.go#L66-L67

Added lines #L66 - L67 were not covered by tests

return handler(ContextWithBearerToken(ctx, bearerToken), req)
}
}

// NewUnaryClientInterceptor injects the bearer token header into gRPC request metadata.
func NewUnaryClientInterceptor() grpc.UnaryClientInterceptor {
return grpc.UnaryClientInterceptor(func(
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
ctx context.Context,
method string,
req, reply any,
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
var token string
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
if md, ok := metadata.FromIncomingContext(ctx); ok {
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
tokens := md.Get(Key)
if len(tokens) > 1 {
return fmt.Errorf("malformed token: multiple tokens found")
}
if len(tokens) == 1 && tokens[0] != "" {
token = tokens[0]
}
}

if token == "" {
bearerToken, ok := GetBearerToken(ctx)
if ok && bearerToken != "" {
token = bearerToken
}
}

if token != "" {
ctx = metadata.AppendToOutgoingContext(ctx, Key, token)
}
return invoker(ctx, method, req, reply, cc, opts...)
})
}

// NewStreamClientInterceptor injects the bearer token header into gRPC request metadata.
func NewStreamClientInterceptor() grpc.StreamClientInterceptor {
return grpc.StreamClientInterceptor(func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
streamer grpc.Streamer,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
var token string
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
if md, ok := metadata.FromIncomingContext(ctx); ok {
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
tokens := md.Get(Key)
if len(tokens) > 1 {
return nil, fmt.Errorf("malformed token: multiple tokens found")
}
if len(tokens) == 1 && tokens[0] != "" {
token = tokens[0]
}
}

if token == "" {
bearerToken, ok := GetBearerToken(ctx)
if ok && bearerToken != "" {
token = bearerToken
}
}

if token != "" {
ctx = metadata.AppendToOutgoingContext(ctx, Key, token)
}
return streamer(ctx, desc, cc, method, opts...)
})
}
152 changes: 152 additions & 0 deletions pkg/bearertoken/grpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright (c) 2024 The Jaeger Authors.
// SPDX-License-Identifier: Apache-2.0

package bearertoken

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)

func TestBearerTokenInterceptors(t *testing.T) {
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
tests := []struct {
name string
ctx context.Context
expectedErr string
expectedMD metadata.MD
}{
{
name: "no token in context",
ctx: context.Background(),
expectedErr: "",
expectedMD: nil, // Expecting no metadata
},
{
name: "token in context",
ctx: ContextWithBearerToken(context.Background(), "test-token"),
expectedErr: "",
expectedMD: metadata.MD{Key: []string{"test-token"}},
},
{
name: "multiple tokens in metadata",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{Key: []string{"token1", "token2"}}),
expectedErr: "malformed token: multiple tokens found",
},
{
name: "valid token in metadata",
ctx: metadata.NewIncomingContext(context.Background(), metadata.MD{Key: []string{"valid-token"}}),
expectedErr: "",
expectedMD: metadata.MD{Key: []string{"valid-token"}}, // Valid token setup
},
}
yurishkuro marked this conversation as resolved.
Show resolved Hide resolved

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Unary interceptor test
unaryInterceptor := NewUnaryClientInterceptor()
unaryInvoker := func(ctx context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
md, ok := metadata.FromOutgoingContext(ctx)
if test.expectedMD == nil {
require.False(t, ok) // There should be no metadata in this case
} else {
require.True(t, ok)
assert.Equal(t, test.expectedMD, md)
}
return nil
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
}
err := unaryInterceptor(test.ctx, "method", nil, nil, nil, unaryInvoker)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedErr)
}

// Stream interceptor test
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
streamInterceptor := NewStreamClientInterceptor()
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
streamInvoker := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
md, ok := metadata.FromOutgoingContext(ctx)
if test.expectedMD == nil {
require.False(t, ok) // There should be no metadata in this case
} else {
require.True(t, ok)
assert.Equal(t, test.expectedMD, md)
}
return nil, nil
}
_, err = streamInterceptor(test.ctx, &grpc.StreamDesc{}, nil, "method", streamInvoker)
if test.expectedErr == "" {
require.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedErr)
}
})
}
}

func TestClientUnaryInterceptorWithBearerToken(t *testing.T) {
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved
interceptor := NewUnaryClientInterceptor()

// Mock invoker
invoker := func(ctx context.Context, _ string, _ any, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
assert.Equal(t, "test-token", md[Key][0])
return nil
}

// Context with token
ctx := ContextWithBearerToken(context.Background(), "test-token")

err := interceptor(ctx, "method", nil, nil, nil, invoker)
require.NoError(t, err)
}

func TestClientStreamInterceptorWithBearerToken(t *testing.T) {
interceptor := NewStreamClientInterceptor()

// Mock streamer
streamer := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
md, ok := metadata.FromOutgoingContext(ctx)
require.True(t, ok)
assert.Equal(t, "test-token", md[Key][0])
return nil, nil
}

// Context with token
ctx := ContextWithBearerToken(context.Background(), "test-token")

_, err := interceptor(ctx, &grpc.StreamDesc{}, nil, "method", streamer)
require.NoError(t, err)
}

func TestMalformedToken(t *testing.T) {
// Context with multiple tokens
ctx := metadata.NewIncomingContext(context.Background(), metadata.MD{
Key: []string{"token1", "token2"},
})

// Unary interceptor
unaryInterceptor := NewUnaryClientInterceptor()
unaryInvoker := func(_ context.Context, _ string, _, _ any, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
return nil
}
err := unaryInterceptor(ctx, "method", nil, nil, nil, unaryInvoker)
require.Error(t, err)
assert.Contains(t, err.Error(), "malformed token: multiple tokens found")
chahatsagarmain marked this conversation as resolved.
Show resolved Hide resolved

// Stream interceptor
streamInterceptor := NewStreamClientInterceptor()
streamInvoker := func(_ context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, nil
}
_, err = streamInterceptor(ctx, &grpc.StreamDesc{}, nil, "method", streamInvoker)
require.Error(t, err)
assert.Contains(t, err.Error(), "malformed token: multiple tokens found")
}
4 changes: 4 additions & 0 deletions plugin/storage/grpc/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"go.uber.org/zap"
"google.golang.org/grpc"

"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/pkg/metrics"
"github.com/jaegertracing/jaeger/pkg/tenancy"
"github.com/jaegertracing/jaeger/plugin"
Expand Down Expand Up @@ -127,6 +128,9 @@ func (f *Factory) newRemoteStorage(telset component.TelemetrySettings, newClient
opts = append(opts, grpc.WithStreamInterceptor(tenancy.NewClientStreamInterceptor(tenancyMgr)))
}

opts = append(opts, grpc.WithUnaryInterceptor(bearertoken.NewUnaryClientInterceptor()))
opts = append(opts, grpc.WithStreamInterceptor(bearertoken.NewStreamClientInterceptor()))

remoteConn, err := newClient(opts...)
if err != nil {
return nil, fmt.Errorf("error creating remote storage client: %w", err)
Expand Down
Loading