From 6743fa1c76fd17d2aaff6449cb6d386e0de57b2e Mon Sep 17 00:00:00 2001 From: Peter Nied Date: Wed, 24 Apr 2024 21:05:24 +0000 Subject: [PATCH] Properly handle offramp for 2.11 custom serializater Signed-off-by: Peter Nied --- .../transport/SecuritySSLRequestHandler.java | 3 +- .../security/support/ConfigConstants.java | 2 + .../security/support/SerializationFormat.java | 35 ++++++++++++++++ .../transport/SecurityInterceptor.java | 42 +++++++------------ .../security/support/Base64HelperTest.java | 25 +++++++++++ .../transport/SecurityInterceptorTests.java | 26 +++++++----- .../SecuritySSLRequestHandlerTests.java | 24 ++++++++++- 7 files changed, 115 insertions(+), 42 deletions(-) create mode 100644 src/main/java/org/opensearch/security/support/SerializationFormat.java diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 078c822357..7002171595 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -34,6 +34,7 @@ import org.opensearch.security.ssl.util.ExceptionUtils; import org.opensearch.security.ssl.util.SSLRequestHelper; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.SerializationFormat; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportChannel; @@ -92,7 +93,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task threadContext.putTransient( ConfigConstants.USE_JDK_SERIALIZATION, - channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) + SerializationFormat.determineFormat(channel.getVersion()) == SerializationFormat.JDK ); if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { diff --git a/src/main/java/org/opensearch/security/support/ConfigConstants.java b/src/main/java/org/opensearch/security/support/ConfigConstants.java index 54f8131aba..9c671a80f9 100644 --- a/src/main/java/org/opensearch/security/support/ConfigConstants.java +++ b/src/main/java/org/opensearch/security/support/ConfigConstants.java @@ -332,6 +332,8 @@ public enum RolesMappingResolution { public static final String TENANCY_GLOBAL_TENANT_NAME = "global"; public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = ""; + public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization"; + // On-behalf-of endpoints settings // CS-SUPPRESS-SINGLE: RegexpSingleline get Extensions Settings public static final String EXTENSIONS_BWC_PLUGIN_MODE = "bwcPluginMode"; diff --git a/src/main/java/org/opensearch/security/support/SerializationFormat.java b/src/main/java/org/opensearch/security/support/SerializationFormat.java new file mode 100644 index 0000000000..210a5cf6a5 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/SerializationFormat.java @@ -0,0 +1,35 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.support; + +import org.opensearch.Version; + +public enum SerializationFormat { + /** Uses Java's native serialization system */ + JDK, + /** Uses a custom serializer built ontop of OpenSearch 2.11 */ + CustomSerializer_2_11; + + private static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0; + private static final Version CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION = Version.V_2_14_0; + + /** + * Determines the format of serialization that should be used from a version identifier + */ + public static SerializationFormat determineFormat(final Version version) { + if (version.onOrAfter(FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) + && version.before(CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION)) { + return SerializationFormat.CustomSerializer_2_11; + } + return SerializationFormat.JDK; + } +} diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index bca738c10e..a3a73d265c 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -63,6 +63,7 @@ import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.HeaderHelper; +import org.opensearch.security.support.SerializationFormat; import org.opensearch.security.user.User; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport.Connection; @@ -152,7 +153,7 @@ public void sendRequestDecorate( final boolean isDebugEnabled = log.isDebugEnabled(); - final var serializationFormat = shouldUseJdkSerialization(connection); + final var serializationFormat = SerializationFormat.determineFormat(connection.getVersion()); final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode()); try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { @@ -230,17 +231,20 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL ); } - if (serializationFormat == SerializationFormat.JDK) { - Map jdkSerializedHeaders = new HashMap<>(); - HeaderHelper.getAllSerializedHeaderNames() - .stream() - .filter(k -> headerMap.get(k) != null) - .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); - headerMap.putAll(jdkSerializedHeaders); + try { + if (serializationFormat == SerializationFormat.JDK) { + Map jdkSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames() + .stream() + .filter(k -> headerMap.get(k) != null) + .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); + headerMap.putAll(jdkSerializedHeaders); + } + getThreadContext().putHeader(headerMap); + } catch (IllegalArgumentException iae) { + log.debug("Failed to add headers information onto on thread context", iae); } - getThreadContext().putHeader(headerMap); - ensureCorrectHeaders( remoteAddress0, user0, @@ -270,24 +274,6 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL } } - private static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization"; - private static final Version FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION = Version.V_2_11_0; - private static final Version CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION = Version.V_2_14_0; - - private SerializationFormat shouldUseJdkSerialization(final Connection connection) { - var version = connection.getVersion(); - if (version.after(FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) - && version.before(CUSTOM_SERIALIZATION_NO_LONGER_SUPPORTED_OS_VERSION)) { - return SerializationFormat.CustomSerializer_2_11; - } - return SerializationFormat.JDK; - } - - private enum SerializationFormat { - JDK, - CustomSerializer_2_11 - } - private void ensureCorrectHeaders( final Object remoteAdr, final User origUser, diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index 3bc81aaebc..01bcc9ab90 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -11,12 +11,17 @@ package org.opensearch.security.support; import java.io.Serializable; +import java.util.HashMap; +import java.util.stream.IntStream; import org.junit.Assert; import org.junit.Test; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.closeTo; import static org.opensearch.security.support.Base64Helper.deserializeObject; import static org.opensearch.security.support.Base64Helper.serializeObject; +import static org.junit.Assert.assertThat; public class Base64HelperTest { @@ -48,4 +53,24 @@ public void testEnsureJDKSerialized() { Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized)); Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized)); } + + @Test + public void testDuplicatedItemSizes() { + var largeObject = new HashMap(); + var hm = new HashMap<>(); + IntStream.range(0, 100).forEach(i -> { hm.put("c" + i, "cvalue" + i); }); + IntStream.range(0, 100).forEach(i -> { + largeObject.put("b" + i, hm); + }); + + final var jdkSerialized = Base64Helper.serializeObject(largeObject, true); + final var customSerialized = Base64Helper.serializeObject(largeObject, false); + final var customSerializedOnlyHashMap = Base64Helper.serializeObject(hm, false); + + assertThat(jdkSerialized.length(), equalTo(3832)); + // The custom serializer is ~50x larger than the jdk serialized version + assertThat(customSerialized.length(), equalTo(184792)); + // Show that the majority of the size of the custom serialized large object is the map duplicated ~100 times + assertThat( (double)customSerializedOnlyHashMap.length(), closeTo(customSerialized.length() / 100, 70d)); + } } diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index 4b3636a000..35eb30cadd 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -12,8 +12,11 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; @@ -120,7 +123,7 @@ public class SecurityInterceptorTests { private AsyncSender sender; private AsyncSender serializedSender; - private AsyncSender nullSender; + private AtomicReference senderLatch = new AtomicReference<>(new CountDownLatch(1)); @Before public void setup() { @@ -208,6 +211,7 @@ public void sendRequest( ) { String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, true)); + senderLatch.get().countDown(); } }; @@ -222,6 +226,7 @@ public void sendRequest( ) { User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser, user); + senderLatch.get().countDown(); } }; @@ -249,17 +254,16 @@ final void completableRequestDecorate( TransportResponseHandler handler, DiscoveryNode localNode ) { + securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode); + verifyOriginalContext(user); + try { + senderLatch.get().await(1, TimeUnit.SECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } - ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor(); - - singleThreadExecutor.execute(() -> { - try { - securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode); - verifyOriginalContext(user); - } finally { - singleThreadExecutor.shutdown(); - } - }); + // Reset the latch so another request can be processed + senderLatch.set(new CountDownLatch(1)); } @Test diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java index 2d10b6f84f..c63c8d26ae 100644 --- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -94,9 +94,19 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); threadPool.getThreadContext().stashContext(); - when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0); Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } @Test @@ -118,9 +128,19 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); threadPool.getThreadContext().stashContext(); - when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + when(transportChannel.getVersion()).thenReturn(Version.V_2_13_0); Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_14_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } @Test