Skip to content

Commit

Permalink
Tracing handler cleans up trace state after creating a root span (#134)
Browse files Browse the repository at this point in the history
Otherwise multiple root spans will be created.
  • Loading branch information
Carter Kozak authored and schlosna committed Oct 24, 2018
1 parent 425f127 commit 51f51b8
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@

package com.palantir.tritium.tracing;

import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.palantir.tracing.Tracers;
import com.palantir.tritium.api.functions.BooleanSupplier;
import com.palantir.tritium.event.AbstractInvocationEventHandler;
import com.palantir.tritium.event.DefaultInvocationContext;
import com.palantir.tritium.event.InstrumentationProperties;
import com.palantir.tritium.event.InvocationContext;
import com.palantir.tritium.event.InvocationEventHandler;
import java.lang.reflect.Method;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

public final class TracingInvocationEventHandler extends AbstractInvocationEventHandler<InvocationContext> {

Expand All @@ -52,6 +56,7 @@ public TracingInvocationEventHandler(String component) {
* @param component component name
* @return tracing event handler
*/
@SuppressWarnings("unchecked")
public static InvocationEventHandler<InvocationContext> create(String component) {
if (RemotingCompatibleTracingInvocationEventHandler.requiresRemotingFallback()) {
return RemotingCompatibleTracingInvocationEventHandler.create(component);
Expand All @@ -61,8 +66,9 @@ public static InvocationEventHandler<InvocationContext> create(String component)
}

@Override
public InvocationContext preInvocation(Object instance, Method method, Object[] args) {
InvocationContext context = DefaultInvocationContext.of(instance, method, args);
public TracingInvocationContext preInvocation(Object instance, Method method, Object[] args) {
boolean rootSpan = MDC.get(Tracers.TRACE_ID_KEY) == null;
TracingInvocationContext context = TracingInvocationContext.of(instance, method, args, rootSpan);
String operationName = getOperationName(method);
com.palantir.tracing.Tracer.startSpan(operationName);
return context;
Expand All @@ -74,20 +80,23 @@ private String getOperationName(Method method) {

@Override
public void onSuccess(@Nullable InvocationContext context, @Nullable Object result) {
debugIfNullContext(context);
// Context is null if no span was created, in which case the existing span should not be completed
if (context != null) {
com.palantir.tracing.Tracer.fastCompleteSpan();
}
complete(context);
}

@Override
public void onFailure(@Nullable InvocationContext context, @Nonnull Throwable cause) {
debugIfNullContext(context);
// TODO(davids): add Error event
complete(context);
}

private static void complete(@Nullable InvocationContext context) {
debugIfNullContext(context);
// Context is null if no span was created, in which case the existing span should not be completed
if (context != null) {
com.palantir.tracing.Tracer.fastCompleteSpan();
if (context instanceof TracingInvocationContext && ((TracingInvocationContext) context).isRootSpan()) {
com.palantir.tracing.Tracer.getAndClearTrace();
}
}
}

Expand All @@ -101,4 +110,68 @@ static BooleanSupplier getEnabledSupplier(String component) {
return InstrumentationProperties.getSystemPropertySupplier(component);
}

static final class TracingInvocationContext implements InvocationContext {

private static final Object[] NO_ARGS = {};

private final long startTimeNanos;
private final Object instance;
private final Method method;
private final Object[] args;
private final boolean rootSpan;

private TracingInvocationContext(
long startTimeNanos, Object instance, Method method, @Nullable Object[] args, boolean rootSpan) {
this.startTimeNanos = startTimeNanos;
this.instance = instance;
this.method = method;
this.args = toNonNullClone(args);
this.rootSpan = rootSpan;
}

private static Object[] toNonNullClone(@Nullable Object[] args) {
return args == null ? NO_ARGS : args.clone();
}

public static TracingInvocationContext of(
Object instance, Method method, @Nullable Object[] args, boolean rootSpan) {
return new TracingInvocationContext(
System.nanoTime(),
checkNotNull(instance, "instance"),
checkNotNull(method, "method"),
args,
rootSpan);
}

@Override
public long getStartTimeNanos() {
return startTimeNanos;
}

@Override
public Object getInstance() {
return instance;
}

@Override
public Method getMethod() {
return method;
}

@Override
public Object[] getArgs() {
return args;
}

boolean isRootSpan() {
return rootSpan;
}

@Override
@SuppressWarnings("DesignForExtension")
public String toString() {
return "TracingInvocationContext [startTimeNanos=" + startTimeNanos + ", instance=" + instance + ", method="
+ method + ", args=" + Arrays.toString(args) + ", rootSpan=" + rootSpan + "]";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.common.util.concurrent.MoreExecutors;
import com.palantir.tracing.AsyncSlf4jSpanObserver;
import com.palantir.tracing.Tracer;
import com.palantir.tracing.Tracers;
import com.palantir.tracing.api.Span;
import com.palantir.tracing.api.SpanObserver;
import com.palantir.tritium.event.InvocationContext;
Expand All @@ -41,6 +42,7 @@
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.slf4j.MDC;

@RunWith(MockitoJUnitRunner.class)
public class TracingInvocationEventHandlerTest {
Expand All @@ -56,6 +58,7 @@ public class TracingInvocationEventHandlerTest {

@Before
public void before() throws Exception {
Tracer.getAndClearTrace();
executor = MoreExecutors.newDirectExecutorService();
handler = TracingInvocationEventHandler.create("testComponent");
assertThat(handler).isInstanceOf(TracingInvocationEventHandler.class);
Expand All @@ -74,6 +77,7 @@ public void after() {
Tracer.unsubscribe("mock");
Tracer.unsubscribe("slf4j");
executor.shutdownNow();
Tracer.getAndClearTrace();
}

@Test
Expand All @@ -87,32 +91,39 @@ public void testPreInvocation() {
assertThat(context.getArgs()).isEqualTo(args);
assertThat(context.getStartTimeNanos()).isGreaterThan(startNanoseconds);
assertThat(context.getStartTimeNanos()).isLessThan(System.nanoTime());
assertThat(MDC.get(Tracers.TRACE_ID_KEY)).isNotNull();
}

@Test
public void testSuccess() {
InvocationContext context = handler.preInvocation(instance, method, args);

assertThat(MDC.get(Tracers.TRACE_ID_KEY)).isNotNull();

handler.onSuccess(context, null);

ArgumentCaptor<Span> spanCaptor = ArgumentCaptor.forClass(Span.class);
verify(mockSpanObserver, times(1)).consume(spanCaptor.capture());

Span span = spanCaptor.getValue();
assertThat(span.getDurationNanoSeconds()).isGreaterThan(0L);
assertThat(MDC.get(Tracers.TRACE_ID_KEY)).isNull();
}

@Test
public void testFailure() {
InvocationContext context = handler.preInvocation(instance, method, args);

assertThat(MDC.get(Tracers.TRACE_ID_KEY)).isNotNull();

handler.onFailure(context, new RuntimeException("unexpected"));

ArgumentCaptor<Span> spanCaptor = ArgumentCaptor.forClass(Span.class);
verify(mockSpanObserver, times(1)).consume(spanCaptor.capture());

Span span = spanCaptor.getValue();
assertThat(span.getDurationNanoSeconds()).isGreaterThan(0L);
assertThat(MDC.get(Tracers.TRACE_ID_KEY)).isNull();
}

@Test
Expand Down

0 comments on commit 51f51b8

Please sign in to comment.