From 07e36f81cd365614c5ea747f66d5f4c3183ec595 Mon Sep 17 00:00:00 2001 From: Paul Bakker Date: Mon, 8 Apr 2024 18:51:15 -0700 Subject: [PATCH] Make ContextSnapshotFactory nullable in VirtualThreadTaskExecutor to support frameworks that wrap the whole task executor for context propagation (#1878) --- .../VirtualThreadTaskExecutor.java | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/graphql-dgs/src/main/java21/com.netflix.graphql.dgs.internal.VirtualThreadTaskExecutor/VirtualThreadTaskExecutor.java b/graphql-dgs/src/main/java21/com.netflix.graphql.dgs.internal.VirtualThreadTaskExecutor/VirtualThreadTaskExecutor.java index cbde733e7..408670080 100644 --- a/graphql-dgs/src/main/java21/com.netflix.graphql.dgs.internal.VirtualThreadTaskExecutor/VirtualThreadTaskExecutor.java +++ b/graphql-dgs/src/main/java21/com.netflix.graphql.dgs.internal.VirtualThreadTaskExecutor/VirtualThreadTaskExecutor.java @@ -16,9 +16,11 @@ package com.netflix.graphql.dgs.internal; +import io.micrometer.context.ContextSnapshot; import io.micrometer.context.ContextSnapshotFactory; import org.jetbrains.annotations.NotNull; import org.springframework.core.task.AsyncTaskExecutor; +import org.springframework.lang.Nullable; import java.util.concurrent.Callable; import java.util.concurrent.Future; @@ -32,18 +34,24 @@ @SuppressWarnings("unused") public class VirtualThreadTaskExecutor implements AsyncTaskExecutor { private final ThreadFactory threadFactory; + + @Nullable private final ContextSnapshotFactory contextSnapshotFactory; - public VirtualThreadTaskExecutor(ContextSnapshotFactory contextSnapshotFactory) { + public VirtualThreadTaskExecutor(@Nullable ContextSnapshotFactory contextSnapshotFactory) { this.contextSnapshotFactory = contextSnapshotFactory; this.threadFactory = Thread.ofVirtual().name("dgs-virtual-thread-", 0).factory(); } @Override public void execute(@NotNull Runnable task) { - var contextSnapshot = contextSnapshotFactory.captureAll(); - var wrapped = contextSnapshot.wrap(task); - threadFactory.newThread(wrapped).start(); + if (contextSnapshotFactory != null) { + ContextSnapshot contextSnapshot = contextSnapshotFactory.captureAll(); + var wrapped = contextSnapshot.wrap(task); + threadFactory.newThread(wrapped).start(); + } else { + threadFactory.newThread(task).start(); + } } @Override