Skip to content

Commit

Permalink
Properly handle offramp for 2.11 custom serializater
Browse files Browse the repository at this point in the history
Signed-off-by: Peter Nied <[email protected]>
  • Loading branch information
peternied committed Apr 24, 2024
1 parent 76b0804 commit 6743fa1
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.security.ssl.util.ExceptionUtils;
import org.opensearch.security.ssl.util.SSLRequestHelper;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.SerializationFormat;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
Expand Down Expand Up @@ -92,7 +93,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task

threadContext.putTransient(
ConfigConstants.USE_JDK_SERIALIZATION,
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
SerializationFormat.determineFormat(channel.getVersion()) == SerializationFormat.JDK
);

if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ public enum RolesMappingResolution {
public static final String TENANCY_GLOBAL_TENANT_NAME = "global";
public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = "";

public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization";

// On-behalf-of endpoints settings
// CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings
public static final String EXTENSIONS_BWC_PLUGIN_MODE = "bwcPluginMode";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.support;

import org.opensearch.Version;

public enum SerializationFormat {
/** Uses Java's native serialization system */
JDK,
/** Uses a custom serializer built ontop of OpenSearch 2.11 */
CustomSerializer_2_11;

private static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0;
private static final Version CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION = Version.V_2_14_0;

/**
* Determines the format of serialization that should be used from a version identifier
*/
public static SerializationFormat determineFormat(final Version version) {
if (version.onOrAfter(FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
&& version.before(CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION)) {
return SerializationFormat.CustomSerializer_2_11;
}
return SerializationFormat.JDK;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.opensearch.security.support.Base64Helper;
import org.opensearch.security.support.ConfigConstants;
import org.opensearch.security.support.HeaderHelper;
import org.opensearch.security.support.SerializationFormat;
import org.opensearch.security.user.User;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Transport.Connection;
Expand Down Expand Up @@ -152,7 +153,7 @@ public <T extends TransportResponse> void sendRequestDecorate(

final boolean isDebugEnabled = log.isDebugEnabled();

final var serializationFormat = shouldUseJdkSerialization(connection);
final var serializationFormat = SerializationFormat.determineFormat(connection.getVersion());
final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());

try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
Expand Down Expand Up @@ -230,17 +231,20 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
);
}

if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
try {
if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
}
getThreadContext().putHeader(headerMap);
} catch (IllegalArgumentException iae) {
log.debug("Failed to add headers information onto on thread context", iae);
}

getThreadContext().putHeader(headerMap);

ensureCorrectHeaders(
remoteAddress0,
user0,
Expand Down Expand Up @@ -270,24 +274,6 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
}
}

private static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization";
private static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0;
private static final Version CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION = Version.V_2_14_0;

private SerializationFormat shouldUseJdkSerialization(final Connection connection) {
var version = connection.getVersion();
if (version.after(FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
&& version.before(CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION)) {
return SerializationFormat.CustomSerializer_2_11;
}
return SerializationFormat.JDK;
}

private enum SerializationFormat {
JDK,
CustomSerializer_2_11
}

private void ensureCorrectHeaders(
final Object remoteAdr,
final User origUser,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@
package org.opensearch.security.support;

import java.io.Serializable;
import java.util.HashMap;
import java.util.stream.IntStream;

import org.junit.Assert;
import org.junit.Test;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.opensearch.security.support.Base64Helper.deserializeObject;
import static org.opensearch.security.support.Base64Helper.serializeObject;
import static org.junit.Assert.assertThat;

public class Base64HelperTest {

Expand Down Expand Up @@ -48,4 +53,24 @@ public void testEnsureJDKSerialized() {
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized));
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized));
}

@Test
public void testDuplicatedItemSizes() {
var largeObject = new HashMap<String, Object>();
var hm = new HashMap<>();
IntStream.range(0, 100).forEach(i -> { hm.put("c" + i, "cvalue" + i); });
IntStream.range(0, 100).forEach(i -> {
largeObject.put("b" + i, hm);
});

final var jdkSerialized = Base64Helper.serializeObject(largeObject, true);
final var customSerialized = Base64Helper.serializeObject(largeObject, false);
final var customSerializedOnlyHashMap = Base64Helper.serializeObject(hm, false);

assertThat(jdkSerialized.length(), equalTo(3832));
// The custom serializer is ~50x larger than the jdk serialized version
assertThat(customSerialized.length(), equalTo(184792));
// Show that the majority of the size of the custom serialized large object is the map duplicated ~100 times
assertThat( (double)customSerializedOnlyHashMap.length(), closeTo(customSerialized.length() / 100, 70d));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -120,7 +123,7 @@ public class SecurityInterceptorTests {

private AsyncSender sender;
private AsyncSender serializedSender;
private AsyncSender nullSender;
private AtomicReference<CountDownLatch> senderLatch = new AtomicReference<>(new CountDownLatch(1));

@Before
public void setup() {
Expand Down Expand Up @@ -208,6 +211,7 @@ public <T extends TransportResponse> void sendRequest(
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, true));
senderLatch.get().countDown();
}
};

Expand All @@ -222,6 +226,7 @@ public <T extends TransportResponse> void sendRequest(
) {
User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER);
assertEquals(transientUser, user);
senderLatch.get().countDown();
}
};

Expand Down Expand Up @@ -249,17 +254,16 @@ final void completableRequestDecorate(
TransportResponseHandler handler,
DiscoveryNode localNode
) {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
verifyOriginalContext(user);
try {
senderLatch.get().await(1, TimeUnit.SECONDS);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}

ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor();

singleThreadExecutor.execute(() -> {
try {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
verifyOriginalContext(user);
} finally {
singleThreadExecutor.shutdown();
}
});
// Reset the latch so another request can be processed
senderLatch.set(new CountDownLatch(1));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,19 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
Expand All @@ -118,9 +128,19 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
Expand Down

0 comments on commit 6743fa1

Please sign in to comment.