Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for conditional Transient header propagation #11490

Merged
merged 10 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix template setting override for replication type ([#11417](https://github.com/opensearch-project/OpenSearch/pull/11417))
- Fix Automatic addition of protocol broken in #11512 ([#11609](https://github.com/opensearch-project/OpenSearch/pull/11609))
- Fix issue when calling Delete PIT endpoint and no PITs exist ([#11711](https://github.com/opensearch-project/OpenSearch/pull/11711))
- Fix tracing context propagation for local transport instrumentation ([#11490](https://github.com/opensearch-project/OpenSearch/pull/11490))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public StoredContext stashContext() {
);
}

final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders, context.isSystemContext);
if (!transientHeaders.isEmpty()) {
threadContextStruct = threadContextStruct.putTransient(transientHeaders);
}
Expand All @@ -182,7 +182,7 @@ public StoredContext stashContext() {
public Writeable captureAsWriteable() {
final ThreadContextStruct context = threadLocal.get();
return out -> {
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
};
}
Expand Down Expand Up @@ -245,7 +245,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);

boolean transientHeadersModified = false;
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders, originalContext.isSystemContext);
if (!transientHeaders.isEmpty()) {
newTransientHeaders.putAll(transientHeaders);
transientHeadersModified = true;
Expand Down Expand Up @@ -322,7 +322,7 @@ public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
@Override
public void writeTo(StreamOutput out) throws IOException {
final ThreadContextStruct context = threadLocal.get();
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
}

Expand Down Expand Up @@ -534,7 +534,7 @@ boolean isDefaultContext() {
* by the system itself rather than by a user action.
*/
public void markAsSystemContext() {
threadLocal.set(threadLocal.get().setSystemContext());
threadLocal.set(threadLocal.get().setSystemContext(propagators));
}

/**
Expand Down Expand Up @@ -573,15 +573,15 @@ public static Map<String, String> buildDefaultHeaders(Settings settings) {
}
}

private Map<String, Object> propagateTransients(Map<String, Object> source) {
private Map<String, Object> propagateTransients(Map<String, Object> source, boolean isSystemContext) {
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(source)));
propagators.forEach(p -> transients.putAll(p.transients(source, isSystemContext)));
return transients;
}

private Map<String, String> propagateHeaders(Map<String, Object> source) {
private Map<String, String> propagateHeaders(Map<String, Object> source, boolean isSystemContext) {
final Map<String, String> headers = new HashMap<>();
propagators.forEach(p -> headers.putAll(p.headers(source)));
propagators.forEach(p -> headers.putAll(p.headers(source, isSystemContext)));
return headers;
}

Expand All @@ -603,11 +603,13 @@ private static final class ThreadContextStruct {
// saving current warning headers' size not to recalculate the size with every new warning header
private final long warningHeadersSize;

private ThreadContextStruct setSystemContext() {
private ThreadContextStruct setSystemContext(final List<ThreadContextStatePropagator> propagators) {
if (isSystemContext) {
return this;
}
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true);
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(transientHeaders, true)));
return new ThreadContextStruct(requestHeaders, responseHeaders, transients, persistentHeaders, true);
}

private ThreadContextStruct(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,39 @@
public interface ThreadContextStatePropagator {
/**
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
* @param source current context transient headers
*
* @param source current context transient headers
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
Map<String, Object> transients(Map<String, Object> source);
reta marked this conversation as resolved.
Show resolved Hide resolved

/**
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
*
* @param source current context transient headers
* @param isSystemContext if the propagation is for system context.
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
default Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
return transients(source);
};

/**
* Returns the list of request headers that needs to be propagated from current context to request.
* @param source current context headers
*
* @param source current context headers
* @return the list of request headers that needs to be propagated from current context to request
*/
Map<String, String> headers(Map<String, Object> source);

/**
* Returns the list of request headers that needs to be propagated from current context to request.
*
* @param source current context headers
* @param isSystemContext if the propagation is for system context.
* @return the list of request headers that needs to be propagated from current context to request
*/
default Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return headers(source);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextStatePropagator;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -52,17 +53,24 @@ public void put(String key, Span span) {
@Override
public Map<String, Object> transients(Map<String, Object> source) {
final Map<String, Object> transients = new HashMap<>();

if (source.containsKey(CURRENT_SPAN)) {
final SpanReference current = (SpanReference) source.get(CURRENT_SPAN);
if (current != null) {
transients.put(CURRENT_SPAN, new SpanReference(current.getSpan()));
}
}

return transients;
}

@Override
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
reta marked this conversation as resolved.
Show resolved Hide resolved
if (isSystemContext == true) {
return Collections.emptyMap();
} else {
return transients(source);
}
}

@Override
public Map<String, String> headers(Map<String, Object> source) {
final Map<String, String> headers = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -868,19 +868,10 @@ public final <T extends TransportResponse> void sendRequest(
final TransportRequestOptions options,
final TransportResponseHandler<T> handler
) {
if (connection == localNodeConnection) {
// See please https://github.com/opensearch-project/OpenSearch/issues/10291
sendRequestAsync(connection, action, request, options, handler);
} else {
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(
handler,
span,
tracer
);
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
}
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(handler, span, tracer);
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,92 @@ public void testMarkAsSystemContext() throws IOException {
assertFalse(threadContext.isSystemContext());
}

public void testSystemContextWithPropagator() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
threadContext.registerThreadContextStatePropagator(createDummyPropagator(("test_transient_propagation_key")));
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", 1);
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("bar", threadContext.getHeader("foo"));
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
threadContext.markAsSystemContext();
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

public void testSerializeSystemContext() throws IOException {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
threadContext.registerThreadContextStatePropagator(createDummyPropagator(("test_transient_propagation_key")));
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", "test");
BytesStreamOutput out = new BytesStreamOutput();
BytesStreamOutput outFromSystemContext = new BytesStreamOutput();
threadContext.writeTo(out);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.markAsSystemContext();
threadContext.writeTo(outFromSystemContext);
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(outFromSystemContext.bytes().streamInput());
assertNull(threadContext.getHeader("test_transient_propagation_key"));
}
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(out.bytes().streamInput());
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("test", threadContext.getHeader("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

private ThreadContextStatePropagator createDummyPropagator(final String key) {
Gaganjuneja marked this conversation as resolved.
Show resolved Hide resolved
return new ThreadContextStatePropagator() {

@Override
public Map<String, Object> transients(Map<String, Object> source) {
Map<String, Object> transients = new HashMap<>();
if (source.containsKey(key)) {
transients.put(key, source.get(key));
}
return transients;
}

@Override
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
if (isSystemContext == true) {
return Collections.emptyMap();
} else {
return transients(source);
}
}

@Override
public Map<String, String> headers(Map<String, Object> source) {
Map<String, String> headers = new HashMap<>();
if (source.containsKey(key)) {
headers.put(key, (String) source.get(key));
}
return headers;
}

@Override
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
if (isSystemContext == true) {
return Collections.emptyMap();
} else {
return headers(source);
}
}
};
}

public void testPutHeaders() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.tasks;

import org.opensearch.test.OpenSearchTestCase;

import java.util.HashMap;
import java.util.Map;

import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

public class TaskThreadContextStatePropagatorTests extends OpenSearchTestCase {
private final TaskThreadContextStatePropagator taskThreadContextStatePropagator = new TaskThreadContextStatePropagator();

public void testTransient() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, false);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}

public void testTransientForSystemContext() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, true);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,20 @@ public void run() {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}

public void testSpanNotPropagatedToChildSystemThreadContext() {
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));

try (SpanScope scope = tracer.withSpanInScope(span)) {
try (StoredContext ignored = threadContext.stashContext()) {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
threadContext.markAsSystemContext();
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}

assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}
Loading