diff --git a/libs/telemetry/src/test/java/org/opensearch/telemetry/tracing/DefaultTracerTests.java b/libs/telemetry/src/test/java/org/opensearch/telemetry/tracing/DefaultTracerTests.java index 5205bdfc8a031..48b72e1f673fe 100644 --- a/libs/telemetry/src/test/java/org/opensearch/telemetry/tracing/DefaultTracerTests.java +++ b/libs/telemetry/src/test/java/org/opensearch/telemetry/tracing/DefaultTracerTests.java @@ -19,8 +19,8 @@ import org.opensearch.threadpool.ThreadPool; import java.io.IOException; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicReference; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -252,12 +252,6 @@ public void testEndSpanByClosingSpanScopeMultiple() { public void testSpanAcrossThreads() { TracingTelemetry tracingTelemetry = new MockTracingTelemetry(); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - AtomicReference currentSpanRefThread1 = new AtomicReference<>(); - AtomicReference currentSpanRefThread2 = new AtomicReference<>(); - AtomicReference currentSpanRefAfterEndThread2 = new AtomicReference<>(); - - AtomicReference spanRef = new AtomicReference<>(); - AtomicReference spanT2Ref = new AtomicReference<>(); ThreadContextBasedTracerContextStorage spanTracerStorage = new ThreadContextBasedTracerContextStorage( threadContext, @@ -265,29 +259,26 @@ public void testSpanAcrossThreads() { ); DefaultTracer defaultTracer = new DefaultTracer(tracingTelemetry, spanTracerStorage); - executorService.execute(() -> { + CompletableFuture asyncTask = CompletableFuture.runAsync(() -> { // create a span Span span = defaultTracer.startSpan(new SpanCreationContext("span_name_t_1", Attributes.EMPTY)); SpanScope spanScope = defaultTracer.withSpanInScope(span); - spanRef.set(span); - executorService.execute(() -> { + CompletableFuture asyncTask1 = CompletableFuture.runAsync(() -> { Span spanT2 = defaultTracer.startSpan(new SpanCreationContext("span_name_t_2", Attributes.EMPTY)); SpanScope spanScopeT2 = defaultTracer.withSpanInScope(spanT2); - spanT2Ref.set(spanT2); - - currentSpanRefThread2.set(defaultTracer.getCurrentSpan().getSpan()); + assertEquals(spanT2, defaultTracer.getCurrentSpan().getSpan()); - spanT2.endSpan(); spanScopeT2.close(); - currentSpanRefAfterEndThread2.set(getCurrentSpanFromContext(defaultTracer)); - }); + spanT2.endSpan(); + assertEquals(null, defaultTracer.getCurrentSpan()); + }, executorService); + asyncTask1.join(); spanScope.close(); - currentSpanRefThread1.set(getCurrentSpanFromContext(defaultTracer)); - }); - assertEquals(spanT2Ref.get(), currentSpanRefThread2.get()); - assertEquals(spanRef.get(), currentSpanRefAfterEndThread2.get()); - assertEquals(null, currentSpanRefThread1.get()); + span.endSpan(); + assertEquals(null, defaultTracer.getCurrentSpan()); + }, executorService); + asyncTask.join(); } public void testSpanCloseOnThread2() { @@ -297,27 +288,27 @@ public void testSpanCloseOnThread2() { threadContext, tracingTelemetry ); - AtomicReference currentSpanRefThread1 = new AtomicReference<>(); - AtomicReference currentSpanRefThread2 = new AtomicReference<>(); DefaultTracer defaultTracer = new DefaultTracer(tracingTelemetry, spanTracerStorage); final Span span = defaultTracer.startSpan(new SpanCreationContext("span_name_t1", Attributes.EMPTY)); try (SpanScope spanScope = defaultTracer.withSpanInScope(span)) { - executorService.execute(() -> async(new ActionListener() { + CompletableFuture asyncTask = CompletableFuture.runAsync(() -> async(new ActionListener() { @Override public void onResponse(Boolean response) { - span.endSpan(); - currentSpanRefThread2.set(defaultTracer.getCurrentSpan()); + try (SpanScope s = defaultTracer.withSpanInScope(span)) { + assertEquals(span, defaultTracer.getCurrentSpan().getSpan()); + } finally { + span.endSpan(); + } } @Override public void onFailure(Exception e) { } - })); - currentSpanRefThread1.set(defaultTracer.getCurrentSpan()); + }), executorService); + assertEquals(span, defaultTracer.getCurrentSpan().getSpan()); + asyncTask.join(); } - assertEquals(null, currentSpanRefThread2.get()); - assertEquals(span, currentSpanRefThread1.get().getSpan()); assertEquals(null, defaultTracer.getCurrentSpan()); } @@ -337,13 +328,6 @@ private void async(ActionListener actionListener) { public void testSpanAcrossThreadsMultipleSpans() { TracingTelemetry tracingTelemetry = new MockTracingTelemetry(); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); - AtomicReference currentSpanRefThread1 = new AtomicReference<>(); - AtomicReference currentSpanRefThread2 = new AtomicReference<>(); - AtomicReference currentSpanRefAfterEndThread2 = new AtomicReference<>(); - - AtomicReference parentSpanRef = new AtomicReference<>(); - AtomicReference spanRef = new AtomicReference<>(); - AtomicReference spanT2Ref = new AtomicReference<>(); ThreadContextBasedTracerContextStorage spanTracerStorage = new ThreadContextBasedTracerContextStorage( threadContext, @@ -351,43 +335,38 @@ public void testSpanAcrossThreadsMultipleSpans() { ); DefaultTracer defaultTracer = new DefaultTracer(tracingTelemetry, spanTracerStorage); - executorService.execute(() -> { + CompletableFuture asyncTask = CompletableFuture.runAsync(() -> { // create a parent span Span parentSpan = defaultTracer.startSpan(new SpanCreationContext("p_span_name", Attributes.EMPTY)); SpanScope parentSpanScope = defaultTracer.withSpanInScope(parentSpan); - parentSpanRef.set(parentSpan); // create a span Span span = defaultTracer.startSpan(new SpanCreationContext("span_name_t_1", Attributes.EMPTY)); SpanScope spanScope = defaultTracer.withSpanInScope(span); - spanRef.set(span); - executorService.execute(() -> { + CompletableFuture asyncTask1 = CompletableFuture.runAsync(() -> { Span spanT2 = defaultTracer.startSpan(new SpanCreationContext("span_name_t_2", Attributes.EMPTY)); SpanScope spanScopeT2 = defaultTracer.withSpanInScope(spanT2); - Span spanT21 = defaultTracer.startSpan(new SpanCreationContext("span_name_t_2", Attributes.EMPTY)); - SpanScope spanScopeT21 = defaultTracer.withSpanInScope(spanT2); - spanT2Ref.set(spanT21); - currentSpanRefThread2.set(defaultTracer.getCurrentSpan().getSpan()); - - spanT21.endSpan(); + SpanScope spanScopeT21 = defaultTracer.withSpanInScope(spanT21); + assertEquals(spanT21, defaultTracer.getCurrentSpan().getSpan()); spanScopeT21.close(); + spanT21.endSpan(); - spanT2.endSpan(); spanScopeT2.close(); - currentSpanRefAfterEndThread2.set(getCurrentSpanFromContext(defaultTracer)); - }); + spanT2.endSpan(); + + assertEquals(null, defaultTracer.getCurrentSpan()); + }, executorService); + + asyncTask1.join(); + spanScope.close(); + span.endSpan(); parentSpanScope.close(); - currentSpanRefThread1.set(getCurrentSpanFromContext(defaultTracer)); - }); - assertEquals(spanT2Ref.get(), currentSpanRefThread2.get()); - assertEquals(spanRef.get(), currentSpanRefAfterEndThread2.get()); - assertEquals(null, currentSpanRefThread1.get()); - } - - private static Span getCurrentSpanFromContext(DefaultTracer defaultTracer) { - return defaultTracer.getCurrentSpan() != null ? defaultTracer.getCurrentSpan().getSpan() : null; + parentSpan.endSpan(); + assertEquals(null, defaultTracer.getCurrentSpan()); + }, executorService); + asyncTask.join(); } public void testClose() throws IOException {