Skip to content

Commit

Permalink
feat: add middleware for tracing GRPC stream client (#364)
Browse files Browse the repository at this point in the history
GPRC middleware was missing helper for helping with creating child
span with correct metadata. The spans were using default, service
name `grpc.client`. Using this middle should fix such clients.
  • Loading branch information
arunpoudel authored Jun 4, 2024
1 parent a2a7b22 commit 9e74751
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 7 deletions.
15 changes: 15 additions & 0 deletions middleware/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ func UnaryClientInterceptor(options ...Option) grpc.UnaryClientInterceptor {
return ddGrpc.UnaryClientInterceptor(opts...)
}

// StreamClientInterceptor create a client-interceptor to automatically create child-spans, and append to gRPC metadata.
func StreamClientInterceptor(options ...Option) grpc.StreamClientInterceptor {
if internal.IsDatadogDisabled() {
return noOpStreamClientInterceptor()
}
opts := convertOptions(options...)
return ddGrpc.StreamClientInterceptor(opts...)
}

func noOpStreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return streamer(ctx, desc, cc, method, opts...)
}
}

func noOpUnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
return invoker(ctx, method, req, reply, cc, opts...)
Expand Down
70 changes: 63 additions & 7 deletions middleware/grpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,33 @@ type testServer struct {
ddParentID string
}

func (s *testServer) EmptyCall(ctx context.Context, _ *testgrpc.Empty) (*testgrpc.Empty, error) {
func (s *testServer) hydrateTraceData(ctx context.Context) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Error(codes.InvalidArgument, "not metadata in request")
return status.Error(codes.InvalidArgument, "not metadata in request")
}
s.traceparent = strings.Join(md.Get("traceparent"), "")
s.tracestate = strings.Join(md.Get("tracestate"), "")
s.ddParentID = strings.Join(md.Get("x-datadog-parent-id"), "")
s.ddTraceID = strings.Join(md.Get("x-datadog-trace-id"), "")

return nil
}

func (s *testServer) EmptyCall(ctx context.Context, _ *testgrpc.Empty) (*testgrpc.Empty, error) {
if err := s.hydrateTraceData(ctx); err != nil {
return nil, err
}
s.traceparent = strings.Join(md.Get("Traceparent"), "")
s.tracestate = strings.Join(md.Get("Tracestate"), "")
s.ddTraceID = strings.Join(md.Get("X-Datadog-Trace-Id"), "")
s.ddParentID = strings.Join(md.Get("X-Datadog-Parent-Id"), "")
return new(testgrpc.Empty), nil
}

func (s *testServer) StreamingOutputCall(_ *testgrpc.StreamingOutputCallRequest, streamingServer grpc.ServerStreamingServer[testgrpc.StreamingOutputCallResponse]) error {
if err := s.hydrateTraceData(streamingServer.Context()); err != nil {
return err
}
return nil
}

const bufSize = 1024 * 1024

func TestTraceUnaryClientInterceptor(t *testing.T) {
Expand Down Expand Up @@ -93,7 +108,6 @@ func TestTraceUnaryClientInterceptorW3C(t *testing.T) {
tracer.Start()

server := &testServer{}

listener := bufconn.Listen(bufSize)
grpcServer := grpc.NewServer()
testgrpc.RegisterTestServiceServer(grpcServer, server)
Expand Down Expand Up @@ -142,3 +156,45 @@ func TestTraceUnaryClientInterceptorW3C(t *testing.T) {
}
assert.True(t, found, "Did not find Datadog's list-member in w3c tracestate")
}

func TestStreamClientInterceptor(t *testing.T) {
testhelpers.ConfigureDatadog(t)

// Start Datadog tracer, so that we don't create NoopSpans.
testTracer := mocktracer.Start()

server := &testServer{}

listener := bufconn.Listen(bufSize)
grpcServer := grpc.NewServer()
testgrpc.RegisterTestServiceServer(grpcServer, server)
errCh := make(chan error, 1)
go func() {
errCh <- grpcServer.Serve(listener)
}()

conn, err := grpc.NewClient("dns:///localhost",
grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) { return listener.Dial() }),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithStreamInterceptor(StreamClientInterceptor()),
)
require.NoError(t, err)

client := testgrpc.NewTestServiceClient(conn)
span, spanCtx := tracer.StartSpanFromContext(context.Background(), "grpc.request", tracer.ResourceName("/helloworld"))
defer span.Finish()
c, err := client.StreamingOutputCall(spanCtx, &testgrpc.StreamingOutputCallRequest{})
require.NoError(t, err)

c.Recv()

testTracer.Stop()

spans := testTracer.FinishedSpans()
require.Equal(t, 2, len(spans))
assert.Equal(t, strconv.Itoa(int(span.Context().TraceID())), server.ddTraceID)
for _, finishedSpan := range spans {
assert.Equal(t, strconv.Itoa(int(finishedSpan.TraceID())), server.ddTraceID)
assert.Equal(t, strconv.Itoa(int(finishedSpan.ParentID())), server.ddParentID)
}
}

0 comments on commit 9e74751

Please sign in to comment.