diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 078c822357..7002171595 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -34,6 +34,7 @@ import org.opensearch.security.ssl.util.ExceptionUtils; import org.opensearch.security.ssl.util.SSLRequestHelper; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.SerializationFormat; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportChannel; @@ -92,7 +93,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task threadContext.putTransient( ConfigConstants.USE_JDK_SERIALIZATION, - channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) + SerializationFormat.determineFormat(channel.getVersion()) == SerializationFormat.JDK ); if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { diff --git a/src/main/java/org/opensearch/security/support/ConfigConstants.java b/src/main/java/org/opensearch/security/support/ConfigConstants.java index 9e68288d41..956dff1165 100644 --- a/src/main/java/org/opensearch/security/support/ConfigConstants.java +++ b/src/main/java/org/opensearch/security/support/ConfigConstants.java @@ -35,7 +35,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import org.opensearch.Version; import org.opensearch.common.settings.Settings; import org.opensearch.security.auditlog.impl.AuditCategory; @@ -332,7 +331,6 @@ public enum RolesMappingResolution { public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = ""; public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization"; - public static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0; // On-behalf-of endpoints settings // CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings diff --git a/src/main/java/org/opensearch/security/support/SerializationFormat.java b/src/main/java/org/opensearch/security/support/SerializationFormat.java new file mode 100644 index 0000000000..210a5cf6a5 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/SerializationFormat.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import org.opensearch.Version; + +public enum SerializationFormat { + /** Uses Java's native serialization system */ + JDK, + /** Uses a custom serializer built ontop of OpenSearch 2.11 */ + CustomSerializer_2_11; + + private static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0; + private static final Version CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION = Version.V_2_14_0; + + /** + * Determines the format of serialization that should be used from a version identifier + */ + public static SerializationFormat determineFormat(final Version version) { + if (version.onOrAfter(FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) + && version.before(CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION)) { + return SerializationFormat.CustomSerializer_2_11; + } + return SerializationFormat.JDK; + } +} diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index f791cd013a..f55d9ac338 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -62,6 +62,7 @@ import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.HeaderHelper; +import org.opensearch.security.support.SerializationFormat; import org.opensearch.security.user.User; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport.Connection; @@ -150,7 +151,8 @@ public void sendRequestDecorate( final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS); final boolean isDebugEnabled = log.isDebugEnabled(); - final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); + + final var serializationFormat = SerializationFormat.determineFormat(connection.getVersion()); final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode()); try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { @@ -228,17 +230,20 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL ); } - if (useJDKSerialization) { - Map jdkSerializedHeaders = new HashMap<>(); - HeaderHelper.getAllSerializedHeaderNames() - .stream() - .filter(k -> headerMap.get(k) != null) - .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); - headerMap.putAll(jdkSerializedHeaders); + try { + if (serializationFormat == SerializationFormat.JDK) { + Map jdkSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames() + .stream() + .filter(k -> headerMap.get(k) != null) + .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); + headerMap.putAll(jdkSerializedHeaders); + } + getThreadContext().putHeader(headerMap); + } catch (IllegalArgumentException iae) { + log.debug("Failed to add headers information onto on thread context", iae); } - getThreadContext().putHeader(headerMap); - ensureCorrectHeaders( remoteAddress0, user0, @@ -246,7 +251,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL injectedUserString, injectedRolesString, isSameNodeRequest, - useJDKSerialization + serializationFormat ); if (actionTraceEnabled.get()) { @@ -275,7 +280,7 @@ private void ensureCorrectHeaders( final String injectedUserString, final String injectedRolesString, final boolean isSameNodeRequest, - final boolean useJDKSerialization + final SerializationFormat format ) { // keep original address @@ -313,6 +318,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADE getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserString); } } else { + final var useJDKSerialization = format == SerializationFormat.JDK; if (transportAddress != null) { getThreadContext().putHeader( ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index 3bc81aaebc..32d96767d8 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -11,12 +11,17 @@ package org.opensearch.security.support; import java.io.Serializable; +import java.util.HashMap; +import java.util.stream.IntStream; import org.junit.Assert; import org.junit.Test; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; import static org.opensearch.security.support.Base64Helper.deserializeObject; import static org.opensearch.security.support.Base64Helper.serializeObject; +import static org.junit.Assert.assertThat; public class Base64HelperTest { @@ -48,4 +53,22 @@ public void testEnsureJDKSerialized() { Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized)); Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized)); } + + @Test + public void testDuplicatedItemSizes() { + var largeObject = new HashMap(); + var hm = new HashMap<>(); + IntStream.range(0, 100).forEach(i -> { hm.put("c" + i, "cvalue" + i); }); + IntStream.range(0, 100).forEach(i -> { largeObject.put("b" + i, hm); }); + + final var jdkSerialized = Base64Helper.serializeObject(largeObject, true); + final var customSerialized = Base64Helper.serializeObject(largeObject, false); + final var customSerializedOnlyHashMap = Base64Helper.serializeObject(hm, false); + + assertThat(jdkSerialized.length(), equalTo(3832)); + // The custom serializer is ~50x larger than the jdk serialized version + assertThat(customSerialized.length(), equalTo(184792)); + // Show that the majority of the size of the custom serialized large object is the map duplicated ~100 times + assertThat((double) customSerializedOnlyHashMap.length(), closeTo(customSerialized.length() / 100, 70d)); + } } diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index 4b3636a000..8d902ed498 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -12,8 +12,9 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; @@ -120,7 +121,7 @@ public class SecurityInterceptorTests { private AsyncSender sender; private AsyncSender serializedSender; - private AsyncSender nullSender; + private AtomicReference senderLatch = new AtomicReference<>(new CountDownLatch(1)); @Before public void setup() { @@ -208,6 +209,7 @@ public void sendRequest( ) { String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, true)); + senderLatch.get().countDown(); } }; @@ -222,6 +224,7 @@ public void sendRequest( ) { User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser, user); + senderLatch.get().countDown(); } }; @@ -249,17 +252,16 @@ final void completableRequestDecorate( TransportResponseHandler handler, DiscoveryNode localNode ) { + securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode); + verifyOriginalContext(user); + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } - ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor(); - - singleThreadExecutor.execute(() -> { - try { - securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode); - verifyOriginalContext(user); - } finally { - singleThreadExecutor.shutdown(); - } - }); + // Reset the latch so another request can be processed + senderLatch.set(new CountDownLatch(1)); } @Test diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java index ba791c2494..d096510495 100644 --- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -92,6 +92,16 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } @Test @@ -111,6 +121,16 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } @Test