diff --git a/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java b/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java new file mode 100644 index 0000000000000..db0ddfc3f0cd0 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/BaseInboundMessage.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +import org.opensearch.common.annotation.ExperimentalApi; + +/** + * Base class for inbound data as a message. + * Different implementations are used for different protocols. + * + * @opensearch.internal + */ +@ExperimentalApi +public interface BaseInboundMessage { + + /** + * The protocol used to encode this message + */ + static String NATIVE_PROTOCOL = "native"; + static String PROTOBUF_PROTOCOL = "protobuf"; + + /** + * @return the protocol used to encode this message + */ + public String getProtocol(); + + /** + * Set the protocol used to encode this message + */ + public void setProtocol(); +} diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index a8315c3cae4e0..815e78ec46571 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -32,35 +32,12 @@ package org.opensearch.transport; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.lucene.util.BytesRef; -import org.opensearch.Version; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.AbstractRunnable; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.common.io.stream.ByteBufferStreamInput; -import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.transport.TransportResponse; -import org.opensearch.telemetry.tracing.Span; -import org.opensearch.telemetry.tracing.SpanBuilder; -import org.opensearch.telemetry.tracing.SpanScope; import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; import org.opensearch.threadpool.ThreadPool; -import java.io.EOFException; import java.io.IOException; -import java.net.InetSocketAddress; -import java.nio.ByteBuffer; -import java.util.Collection; -import java.util.Collections; -import java.util.Map; -import java.util.stream.Collectors; /** * Handler for inbound data @@ -69,8 +46,6 @@ */ public class InboundHandler { - private static final Logger logger = LogManager.getLogger(InboundHandler.class); - private final ThreadPool threadPool; private final OutboundHandler outboundHandler; private final NamedWriteableRegistry namedWriteableRegistry; @@ -117,377 +92,32 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) { this.slowLogThresholdMs = slowLogThreshold.getMillis(); } - void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception { + void inboundMessage(TcpChannel channel, BaseInboundMessage message) throws Exception { final long startTime = threadPool.relativeTimeInMillis(); channel.getChannelStats().markAccessed(startTime); - TransportLogger.logInboundMessage(channel, message); - if (message.isPing()) { - keepAlive.receiveKeepAlive(channel); - } else { - messageReceived(channel, message, startTime); - } - } - - // Empty stream constant to avoid instantiating a new stream for empty messages. - private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES)); - - private void messageReceived(TcpChannel channel, InboundMessage message, long startTime) throws IOException { - final InetSocketAddress remoteAddress = channel.getRemoteAddress(); - final Header header = message.getHeader(); - assert header.needsToReadVariableHeader() == false; - ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext existing = threadContext.stashContext()) { - // Place the context with the headers from the message - threadContext.setHeaders(header.getHeaders()); - threadContext.putTransient("_remote_address", remoteAddress); - if (header.isRequest()) { - handleRequest(channel, header, message); - } else { - // Responses do not support short circuiting currently - assert message.isShortCircuit() == false; - final TransportResponseHandler handler; - long requestId = header.getRequestId(); - if (header.isHandshake()) { - handler = handshaker.removeHandlerForHandshake(requestId); - } else { - TransportResponseHandler theHandler = responseHandlers.onResponseReceived( - requestId, - messageListener - ); - if (theHandler == null && header.isError()) { - handler = handshaker.removeHandlerForHandshake(requestId); - } else { - handler = theHandler; - } - } - // ignore if its null, the service logs it - if (handler != null) { - final StreamInput streamInput; - if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) { - streamInput = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(streamInput, header.getVersion()); - if (header.isError()) { - handlerResponseError(requestId, streamInput, handler); - } else { - handleResponse(requestId, remoteAddress, streamInput, handler); - } - } else { - assert header.isError() == false; - handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler); - } - } - - } - } finally { - final long took = threadPool.relativeTimeInMillis() - startTime; - final long logThreshold = slowLogThresholdMs; - if (logThreshold > 0 && took > logThreshold) { - logger.warn( - "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]", - message, - took, - logThreshold - ); - } - } + messageReceivedFromPipeline(channel, message, startTime); } - private Map> extractHeaders(Map headers) { - return headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> Collections.singleton(e.getValue()))); - } - - private void handleRequest(TcpChannel channel, Header header, InboundMessage message) throws IOException { - final String action = header.getActionName(); - final long requestId = header.getRequestId(); - final Version version = header.getVersion(); - final Map> headers = extractHeaders(header.getHeaders().v1()); - Span span = tracer.startSpan(SpanBuilder.from(action, channel), headers); - try (SpanScope spanScope = tracer.withSpanInScope(span)) { - if (header.isHandshake()) { - messageListener.onRequestReceived(requestId, action); - // Cannot short circuit handshakes - assert message.isShortCircuit() == false; - final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(stream, header.getVersion()); - final TcpTransportChannel transportChannel = new TcpTransportChannel( - outboundHandler, - channel, - action, - requestId, - version, - header.getFeatures(), - header.isCompressed(), - header.isHandshake(), - message.takeBreakerReleaseControl() - ); - TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); - try { - handshaker.handleHandshake(traceableTransportChannel, requestId, stream); - } catch (Exception e) { - if (Version.CURRENT.isCompatible(header.getVersion())) { - sendErrorResponse(action, traceableTransportChannel, e); - } else { - logger.warn( - new ParameterizedMessage( - "could not send error response to handshake received on [{}] using wire format version [{}], closing channel", - channel, - header.getVersion() - ), - e - ); - channel.close(); - } - } + private void messageReceivedFromPipeline(TcpChannel channel, BaseInboundMessage message, long startTime) throws IOException { + if ((message.getProtocol()).equals(BaseInboundMessage.PROTOBUF_PROTOCOL)) { + // protobuf messages logic can be added here + } else { + InboundMessage inboundMessage = (InboundMessage) message; + TransportLogger.logInboundMessage(channel, inboundMessage); + if (inboundMessage.isPing()) { + keepAlive.receiveKeepAlive(channel); } else { - final TcpTransportChannel transportChannel = new TcpTransportChannel( + NativeMessageHandler nativeInboundHandler = new NativeMessageHandler( + threadPool, outboundHandler, - channel, - action, - requestId, - version, - header.getFeatures(), - header.isCompressed(), - header.isHandshake(), - message.takeBreakerReleaseControl() - ); - TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); - try { - messageListener.onRequestReceived(requestId, action); - if (message.isShortCircuit()) { - sendErrorResponse(action, traceableTransportChannel, message.getException()); - } else { - final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(stream, header.getVersion()); - final RequestHandlerRegistry reg = requestHandlers.getHandler(action); - assert reg != null; - - final T request = newRequest(requestId, action, stream, reg); - request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); - checkStreamIsFullyConsumed(requestId, action, stream); - - final String executor = reg.getExecutor(); - if (ThreadPool.Names.SAME.equals(executor)) { - try { - reg.processMessageReceived(request, traceableTransportChannel); - } catch (Exception e) { - sendErrorResponse(reg.getAction(), traceableTransportChannel, e); - } - } else { - threadPool.executor(executor).execute(new RequestHandler<>(reg, request, traceableTransportChannel)); - } - } - } catch (Exception e) { - sendErrorResponse(action, traceableTransportChannel, e); - } - } - } catch (Exception e) { - span.setError(e); - span.endSpan(); - throw e; - } - } - - /** - * Creates new request instance out of input stream. Throws IllegalStateException if the end of - * the stream was reached before the request is fully deserialized from the stream. - * @param transport request type - * @param requestId request identifier - * @param action action name - * @param stream stream - * @param reg request handler registry - * @return new request instance - * @throws IOException IOException - * @throws IllegalStateException IllegalStateException - */ - private T newRequest( - final long requestId, - final String action, - final StreamInput stream, - final RequestHandlerRegistry reg - ) throws IOException { - try { - return reg.newRequest(stream); - } catch (final EOFException e) { - // Another favor of (de)serialization issues is when stream contains less bytes than - // the request handler needs to deserialize the payload. - throw new IllegalStateException( - "Message fully read (request) but more data is expected for requestId [" - + requestId - + "], action [" - + action - + "]; resetting", - e - ); - } - } - - /** - * Checks if the stream is fully consumed and throws the exceptions if that is not the case. - * @param requestId request identifier - * @param action action name - * @param stream stream - * @throws IOException IOException - */ - private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException { - // in case we throw an exception, i.e. when the limit is hit, we don't want to verify - final int nextByte = stream.read(); - - // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker - if (nextByte != -1) { - throw new IllegalStateException( - "Message not fully read (request) for requestId [" - + requestId - + "], action [" - + action - + "], available [" - + stream.available() - + "]; resetting" - ); - } - } - - /** - * Checks if the stream is fully consumed and throws the exceptions if that is not the case. - * @param requestId request identifier - * @param handler response handler - * @param stream stream - * @param error "true" if response represents error, "false" otherwise - * @throws IOException IOException - */ - private void checkStreamIsFullyConsumed( - final long requestId, - final TransportResponseHandler handler, - final StreamInput stream, - final boolean error - ) throws IOException { - if (stream != EMPTY_STREAM_INPUT) { - // Check the entire message has been read - final int nextByte = stream.read(); - // calling read() is useful to make sure the message is fully read, even if there is an EOS marker - if (nextByte != -1) { - throw new IllegalStateException( - "Message not fully read (response) for requestId [" - + requestId - + "], handler [" - + handler - + "], error [" - + error - + "]; resetting" + namedWriteableRegistry, + handshaker, + requestHandlers, + responseHandlers, + tracer ); + nativeInboundHandler.messageReceived(channel, inboundMessage, startTime, slowLogThresholdMs, messageListener); } } } - - private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) { - try { - transportChannel.sendResponse(e); - } catch (Exception inner) { - inner.addSuppressed(e); - logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName), inner); - } - } - - private void handleResponse( - final long requestId, - InetSocketAddress remoteAddress, - final StreamInput stream, - final TransportResponseHandler handler - ) { - final T response; - try { - response = handler.read(stream); - response.remoteAddress(new TransportAddress(remoteAddress)); - checkStreamIsFullyConsumed(requestId, handler, stream, false); - } catch (Exception e) { - final Exception serializationException = new TransportSerializationException( - "Failed to deserialize response from handler [" + handler + "]", - e - ); - logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException); - handleException(handler, serializationException); - return; - } - final String executor = handler.executor(); - if (ThreadPool.Names.SAME.equals(executor)) { - doHandleResponse(handler, response); - } else { - threadPool.executor(executor).execute(() -> doHandleResponse(handler, response)); - } - } - - private void doHandleResponse(TransportResponseHandler handler, T response) { - try { - handler.handleResponse(response); - } catch (Exception e) { - handleException(handler, new ResponseHandlerFailureTransportException(e)); - } - } - - private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler handler) { - Exception error; - try { - error = stream.readException(); - checkStreamIsFullyConsumed(requestId, handler, stream, true); - } catch (Exception e) { - error = new TransportSerializationException( - "Failed to deserialize exception response from stream for handler [" + handler + "]", - e - ); - } - handleException(handler, error); - } - - private void handleException(final TransportResponseHandler handler, Throwable error) { - if (!(error instanceof RemoteTransportException)) { - error = new RemoteTransportException(error.getMessage(), error); - } - final RemoteTransportException rtx = (RemoteTransportException) error; - threadPool.executor(handler.executor()).execute(() -> { - try { - handler.handleException(rtx); - } catch (Exception e) { - logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e); - } - }); - } - - private StreamInput namedWriteableStream(StreamInput delegate) { - return new NamedWriteableAwareStreamInput(delegate, namedWriteableRegistry); - } - - static void assertRemoteVersion(StreamInput in, Version version) { - assert version.equals(in.getVersion()) : "Stream version [" + in.getVersion() + "] does not match version [" + version + "]"; - } - - /** - * Internal request handler - * - * @opensearch.internal - */ - private static class RequestHandler extends AbstractRunnable { - private final RequestHandlerRegistry reg; - private final T request; - private final TransportChannel transportChannel; - - RequestHandler(RequestHandlerRegistry reg, T request, TransportChannel transportChannel) { - this.reg = reg; - this.request = request; - this.transportChannel = transportChannel; - } - - @Override - protected void doRun() throws Exception { - reg.processMessageReceived(request, transportChannel); - } - - @Override - public boolean isForceExecution() { - return reg.isForceExecution(); - } - - @Override - public void onFailure(Exception e) { - sendErrorResponse(reg.getAction(), transportChannel, e); - } - } } diff --git a/server/src/main/java/org/opensearch/transport/InboundMessage.java b/server/src/main/java/org/opensearch/transport/InboundMessage.java index 71c4d6973505d..ce1cb4173a9f0 100644 --- a/server/src/main/java/org/opensearch/transport/InboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/InboundMessage.java @@ -47,7 +47,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public class InboundMessage implements Releasable { +public class InboundMessage implements Releasable, BaseInboundMessage { private final Header header; private final ReleasableBytesReference content; @@ -55,6 +55,7 @@ public class InboundMessage implements Releasable { private final boolean isPing; private Releasable breakerRelease; private StreamInput streamInput; + private String protocol; public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { this.header = header; @@ -133,4 +134,14 @@ public void close() { public String toString() { return "InboundMessage{" + header + "}"; } + + @Override + public String getProtocol() { + return NATIVE_PROTOCOL; + } + + @Override + public void setProtocol() { + this.protocol = NATIVE_PROTOCOL; + } } diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index dd4690e5e6abf..c4f1a57cd857d 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -62,10 +62,11 @@ public class InboundPipeline implements Releasable { private final StatsTracker statsTracker; private final InboundDecoder decoder; private final InboundAggregator aggregator; - private final BiConsumer messageHandler; + private final BiConsumer messageHandler; private Exception uncaughtException; private final ArrayDeque pending = new ArrayDeque<>(2); private boolean isClosed = false; + private Version version; public InboundPipeline( Version version, @@ -74,14 +75,15 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, Supplier circuitBreaker, Function> registryFunction, - BiConsumer messageHandler + BiConsumer messageHandler ) { this( statsTracker, relativeTimeInMillis, new InboundDecoder(version, recycler), new InboundAggregator(circuitBreaker, registryFunction), - messageHandler + messageHandler, + version ); } @@ -90,13 +92,15 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, InboundDecoder decoder, InboundAggregator aggregator, - BiConsumer messageHandler + BiConsumer messageHandler, + Version version ) { this.relativeTimeInMillis = relativeTimeInMillis; this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; this.messageHandler = messageHandler; + this.version = version; } @Override @@ -124,37 +128,42 @@ public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference statsTracker.markBytesRead(reference.length()); pending.add(reference.retain()); - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + String incomingMessageProtocol = TcpTransport.determineTransportProtocol(reference); + if (incomingMessageProtocol.equals(BaseInboundMessage.PROTOBUF_PROTOCOL) && this.version.onOrAfter(Version.V_3_0_0)) { + // protobuf messages logic can be added here + } else { + final ArrayList fragments = fragmentList.get(); + boolean continueHandling = true; + + while (continueHandling && isClosed == false) { + boolean continueDecoding = true; + while (continueDecoding && pending.isEmpty() == false) { + try (ReleasableBytesReference toDecode = getPendingBytes()) { + final int bytesDecoded = decoder.decode(toDecode, fragments::add); + if (bytesDecoded != 0) { + releasePendingBytes(bytesDecoded); + if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + continueDecoding = false; + } + } else { continueDecoding = false; } - } else { - continueDecoding = false; } } - } - if (fragments.isEmpty()) { - continueHandling = false; - } else { - try { - forwardFragments(channel, fragments); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); + if (fragments.isEmpty()) { + continueHandling = false; + } else { + try { + forwardFragments(channel, fragments); + } finally { + for (Object fragment : fragments) { + if (fragment instanceof ReleasableBytesReference) { + ((ReleasableBytesReference) fragment).close(); + } } + fragments.clear(); } - fragments.clear(); } } } diff --git a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java new file mode 100644 index 0000000000000..27ee3c0873bc9 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java @@ -0,0 +1,474 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.BytesRef; +import org.opensearch.Version; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.io.stream.ByteBufferStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.transport.TransportResponse; +import org.opensearch.telemetry.tracing.Span; +import org.opensearch.telemetry.tracing.SpanBuilder; +import org.opensearch.telemetry.tracing.SpanScope; +import org.opensearch.telemetry.tracing.Tracer; +import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; +import org.opensearch.threadpool.ThreadPool; + +import java.io.EOFException; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Native handler for inbound data + * + * @opensearch.internal + */ +public class NativeMessageHandler { + + private static final Logger logger = LogManager.getLogger(NativeMessageHandler.class); + + private final ThreadPool threadPool; + private final OutboundHandler outboundHandler; + private final NamedWriteableRegistry namedWriteableRegistry; + private final TransportHandshaker handshaker; + private final Transport.ResponseHandlers responseHandlers; + private final Transport.RequestHandlers requestHandlers; + + private final Tracer tracer; + + NativeMessageHandler( + ThreadPool threadPool, + OutboundHandler outboundHandler, + NamedWriteableRegistry namedWriteableRegistry, + TransportHandshaker handshaker, + Transport.RequestHandlers requestHandlers, + Transport.ResponseHandlers responseHandlers, + Tracer tracer + ) { + this.threadPool = threadPool; + this.outboundHandler = outboundHandler; + this.namedWriteableRegistry = namedWriteableRegistry; + this.handshaker = handshaker; + this.requestHandlers = requestHandlers; + this.responseHandlers = responseHandlers; + this.tracer = tracer; + } + + // Empty stream constant to avoid instantiating a new stream for empty messages. + private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES)); + + public void messageReceived( + TcpChannel channel, + InboundMessage message, + long startTime, + long slowLogThresholdMs, + TransportMessageListener messageListener + ) throws IOException { + final InetSocketAddress remoteAddress = channel.getRemoteAddress(); + final Header header = message.getHeader(); + assert header.needsToReadVariableHeader() == false; + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + // Place the context with the headers from the message + threadContext.setHeaders(header.getHeaders()); + threadContext.putTransient("_remote_address", remoteAddress); + if (header.isRequest()) { + handleRequest(channel, header, message, messageListener); + } else { + // Responses do not support short circuiting currently + assert message.isShortCircuit() == false; + final TransportResponseHandler handler; + long requestId = header.getRequestId(); + if (header.isHandshake()) { + handler = handshaker.removeHandlerForHandshake(requestId); + } else { + TransportResponseHandler theHandler = responseHandlers.onResponseReceived( + requestId, + messageListener + ); + if (theHandler == null && header.isError()) { + handler = handshaker.removeHandlerForHandshake(requestId); + } else { + handler = theHandler; + } + } + // ignore if its null, the service logs it + if (handler != null) { + final StreamInput streamInput; + if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) { + streamInput = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(streamInput, header.getVersion()); + if (header.isError()) { + handlerResponseError(requestId, streamInput, handler); + } else { + handleResponse(requestId, remoteAddress, streamInput, handler); + } + } else { + assert header.isError() == false; + handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler); + } + } + + } + } finally { + final long took = threadPool.relativeTimeInMillis() - startTime; + final long logThreshold = slowLogThresholdMs; + if (logThreshold > 0 && took > logThreshold) { + logger.warn( + "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]", + message, + took, + logThreshold + ); + } + } + } + + private Map> extractHeaders(Map headers) { + return headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> Collections.singleton(e.getValue()))); + } + + private void handleRequest( + TcpChannel channel, + Header header, + InboundMessage message, + TransportMessageListener messageListener + ) throws IOException { + final String action = header.getActionName(); + final long requestId = header.getRequestId(); + final Version version = header.getVersion(); + final Map> headers = extractHeaders(header.getHeaders().v1()); + Span span = tracer.startSpan(SpanBuilder.from(action, channel), headers); + try (SpanScope spanScope = tracer.withSpanInScope(span)) { + if (header.isHandshake()) { + messageListener.onRequestReceived(requestId, action); + // Cannot short circuit handshakes + assert message.isShortCircuit() == false; + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final TcpTransportChannel transportChannel = new TcpTransportChannel( + outboundHandler, + channel, + action, + requestId, + version, + header.getFeatures(), + header.isCompressed(), + header.isHandshake(), + message.takeBreakerReleaseControl() + ); + TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); + try { + handshaker.handleHandshake(traceableTransportChannel, requestId, stream); + } catch (Exception e) { + if (Version.CURRENT.isCompatible(header.getVersion())) { + sendErrorResponse(action, traceableTransportChannel, e); + } else { + logger.warn( + new ParameterizedMessage( + "could not send error response to handshake received on [{}] using wire format version [{}], closing channel", + channel, + header.getVersion() + ), + e + ); + channel.close(); + } + } + } else { + final TcpTransportChannel transportChannel = new TcpTransportChannel( + outboundHandler, + channel, + action, + requestId, + version, + header.getFeatures(), + header.isCompressed(), + header.isHandshake(), + message.takeBreakerReleaseControl() + ); + TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); + try { + messageListener.onRequestReceived(requestId, action); + if (message.isShortCircuit()) { + sendErrorResponse(action, traceableTransportChannel, message.getException()); + } else { + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final RequestHandlerRegistry reg = requestHandlers.getHandler(action); + assert reg != null; + + final T request = newRequest(requestId, action, stream, reg); + request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); + checkStreamIsFullyConsumed(requestId, action, stream); + + final String executor = reg.getExecutor(); + if (ThreadPool.Names.SAME.equals(executor)) { + try { + reg.processMessageReceived(request, traceableTransportChannel); + } catch (Exception e) { + sendErrorResponse(reg.getAction(), traceableTransportChannel, e); + } + } else { + threadPool.executor(executor).execute(new RequestHandler<>(reg, request, traceableTransportChannel)); + } + } + } catch (Exception e) { + sendErrorResponse(action, traceableTransportChannel, e); + } + } + } catch (Exception e) { + span.setError(e); + span.endSpan(); + throw e; + } + } + + /** + * Creates new request instance out of input stream. Throws IllegalStateException if the end of + * the stream was reached before the request is fully deserialized from the stream. + * @param transport request type + * @param requestId request identifier + * @param action action name + * @param stream stream + * @param reg request handler registry + * @return new request instance + * @throws IOException IOException + * @throws IllegalStateException IllegalStateException + */ + private T newRequest( + final long requestId, + final String action, + final StreamInput stream, + final RequestHandlerRegistry reg + ) throws IOException { + try { + return reg.newRequest(stream); + } catch (final EOFException e) { + // Another favor of (de)serialization issues is when stream contains less bytes than + // the request handler needs to deserialize the payload. + throw new IllegalStateException( + "Message fully read (request) but more data is expected for requestId [" + + requestId + + "], action [" + + action + + "]; resetting", + e + ); + } + } + + /** + * Checks if the stream is fully consumed and throws the exceptions if that is not the case. + * @param requestId request identifier + * @param action action name + * @param stream stream + * @throws IOException IOException + */ + private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException { + // in case we throw an exception, i.e. when the limit is hit, we don't want to verify + final int nextByte = stream.read(); + + // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker + if (nextByte != -1) { + throw new IllegalStateException( + "Message not fully read (request) for requestId [" + + requestId + + "], action [" + + action + + "], available [" + + stream.available() + + "]; resetting" + ); + } + } + + /** + * Checks if the stream is fully consumed and throws the exceptions if that is not the case. + * @param requestId request identifier + * @param handler response handler + * @param stream stream + * @param error "true" if response represents error, "false" otherwise + * @throws IOException IOException + */ + private void checkStreamIsFullyConsumed( + final long requestId, + final TransportResponseHandler handler, + final StreamInput stream, + final boolean error + ) throws IOException { + if (stream != EMPTY_STREAM_INPUT) { + // Check the entire message has been read + final int nextByte = stream.read(); + // calling read() is useful to make sure the message is fully read, even if there is an EOS marker + if (nextByte != -1) { + throw new IllegalStateException( + "Message not fully read (response) for requestId [" + + requestId + + "], handler [" + + handler + + "], error [" + + error + + "]; resetting" + ); + } + } + } + + private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) { + try { + transportChannel.sendResponse(e); + } catch (Exception inner) { + inner.addSuppressed(e); + logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName), inner); + } + } + + private void handleResponse( + final long requestId, + InetSocketAddress remoteAddress, + final StreamInput stream, + final TransportResponseHandler handler + ) { + final T response; + try { + response = handler.read(stream); + response.remoteAddress(new TransportAddress(remoteAddress)); + checkStreamIsFullyConsumed(requestId, handler, stream, false); + } catch (Exception e) { + final Exception serializationException = new TransportSerializationException( + "Failed to deserialize response from handler [" + handler + "]", + e + ); + logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException); + handleException(handler, serializationException); + return; + } + final String executor = handler.executor(); + if (ThreadPool.Names.SAME.equals(executor)) { + doHandleResponse(handler, response); + } else { + threadPool.executor(executor).execute(() -> doHandleResponse(handler, response)); + } + } + + private void doHandleResponse(TransportResponseHandler handler, T response) { + try { + handler.handleResponse(response); + } catch (Exception e) { + handleException(handler, new ResponseHandlerFailureTransportException(e)); + } + } + + private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler handler) { + Exception error; + try { + error = stream.readException(); + checkStreamIsFullyConsumed(requestId, handler, stream, true); + } catch (Exception e) { + error = new TransportSerializationException( + "Failed to deserialize exception response from stream for handler [" + handler + "]", + e + ); + } + handleException(handler, error); + } + + private void handleException(final TransportResponseHandler handler, Throwable error) { + if (!(error instanceof RemoteTransportException)) { + error = new RemoteTransportException(error.getMessage(), error); + } + final RemoteTransportException rtx = (RemoteTransportException) error; + threadPool.executor(handler.executor()).execute(() -> { + try { + handler.handleException(rtx); + } catch (Exception e) { + logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e); + } + }); + } + + private StreamInput namedWriteableStream(StreamInput delegate) { + return new NamedWriteableAwareStreamInput(delegate, namedWriteableRegistry); + } + + static void assertRemoteVersion(StreamInput in, Version version) { + assert version.equals(in.getVersion()) : "Stream version [" + in.getVersion() + "] does not match version [" + version + "]"; + } + + /** + * Internal request handler + * + * @opensearch.internal + */ + private static class RequestHandler extends AbstractRunnable { + private final RequestHandlerRegistry reg; + private final T request; + private final TransportChannel transportChannel; + + RequestHandler(RequestHandlerRegistry reg, T request, TransportChannel transportChannel) { + this.reg = reg; + this.request = request; + this.transportChannel = transportChannel; + } + + @Override + protected void doRun() throws Exception { + reg.processMessageReceived(request, transportChannel); + } + + @Override + public boolean isForceExecution() { + return reg.isForceExecution(); + } + + @Override + public void onFailure(Exception e) { + sendErrorResponse(reg.getAction(), transportChannel, e); + } + } + +} diff --git a/server/src/main/java/org/opensearch/transport/TcpHeader.java b/server/src/main/java/org/opensearch/transport/TcpHeader.java index 78353a9a80403..1dd6e89ca9bee 100644 --- a/server/src/main/java/org/opensearch/transport/TcpHeader.java +++ b/server/src/main/java/org/opensearch/transport/TcpHeader.java @@ -73,6 +73,7 @@ public static int headerSize(Version version) { } private static final byte[] PREFIX = { (byte) 'E', (byte) 'S' }; + private static final byte[] PROTOBUF_PREFIX = { (byte) 'O', (byte) 'S', (byte) 'P' }; public static void writeHeader( StreamOutput output, @@ -91,4 +92,8 @@ public static void writeHeader( assert variableHeaderSize != -1 : "Variable header size not set"; output.writeInt(variableHeaderSize); } + + public static void writeHeaderForProtobuf(StreamOutput output) throws IOException { + output.writeBytes(PROTOBUF_PREFIX); + } } diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index 7d45152089f37..517f0c34a938e 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -767,7 +767,7 @@ protected void serverAcceptedChannel(TcpChannel channel) { * @param channel the channel the message is from * @param message the message */ - public void inboundMessage(TcpChannel channel, InboundMessage message) { + public void inboundMessage(TcpChannel channel, BaseInboundMessage message) { try { inboundHandler.inboundMessage(channel, message); } catch (Exception e) { @@ -794,6 +794,14 @@ public static int readMessageLength(BytesReference networkBytes) throws IOExcept } } + public static String determineTransportProtocol(BytesReference headerBuffer) { + if (headerBuffer.get(0) == 'O' && headerBuffer.get(1) == 'S' && headerBuffer.get(2) == 'P') { + return BaseInboundMessage.PROTOBUF_PROTOCOL; + } else { + return BaseInboundMessage.NATIVE_PROTOCOL; + } + } + private static int readHeaderBuffer(BytesReference headerBuffer) throws IOException { if (headerBuffer.get(0) != 'E' || headerBuffer.get(1) != 'S') { if (appearsToBeHTTPRequest(headerBuffer)) { diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index e002297911788..0d171e17e70e1 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -275,11 +275,11 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws // response so we must just close the connection on an error. To avoid the failure disappearing into a black hole we at least log // it. - try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(InboundHandler.class))) { + try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(NativeMessageHandler.class))) { mockAppender.addExpectation( new MockLogAppender.SeenEventExpectation( "expected message", - InboundHandler.class.getCanonicalName(), + NativeMessageHandler.class.getCanonicalName(), Level.WARN, "could not send error response to handshake" ) @@ -308,11 +308,11 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws } public void testLogsSlowInboundProcessing() throws Exception { - try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(InboundHandler.class))) { + try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(NativeMessageHandler.class))) { mockAppender.addExpectation( new MockLogAppender.SeenEventExpectation( "expected message", - InboundHandler.class.getCanonicalName(), + NativeMessageHandler.class.getCanonicalName(), Level.WARN, "handling inbound transport message " ) diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index ae4b537223394..fe933625be6fa 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -72,24 +72,25 @@ public void testPipelineHandling() throws IOException { final List> expected = new ArrayList<>(); final List> actual = new ArrayList<>(); final List toRelease = new ArrayList<>(); - final BiConsumer messageHandler = (c, m) -> { + final BiConsumer messageHandler = (c, m) -> { try { - final Header header = m.getHeader(); + InboundMessage message = (InboundMessage) m; + final Header header = message.getHeader(); final MessageData actualData; final Version version = header.getVersion(); final boolean isRequest = header.isRequest(); final long requestId = header.getRequestId(); final boolean isCompressed = header.isCompressed(); - if (m.isShortCircuit()) { + if (message.isShortCircuit()) { actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), null); } else if (isRequest) { - final TestRequest request = new TestRequest(m.openOrGetStreamInput()); + final TestRequest request = new TestRequest(message.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), request.value); } else { - final TestResponse response = new TestResponse(m.openOrGetStreamInput()); + final TestResponse response = new TestResponse(message.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, null, response.value); } - actual.add(new Tuple<>(actualData, m.getException())); + actual.add(new Tuple<>(actualData, message.getException())); } catch (IOException e) { throw new AssertionError(e); } @@ -104,7 +105,14 @@ public void testPipelineHandling() throws IOException { final TestCircuitBreaker circuitBreaker = new TestCircuitBreaker(); circuitBreaker.startBreaking(); final InboundAggregator aggregator = new InboundAggregator(() -> circuitBreaker, canTripBreaker); - final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + final InboundPipeline pipeline = new InboundPipeline( + statsTracker, + millisSupplier, + decoder, + aggregator, + messageHandler, + Version.CURRENT + ); final FakeTcpChannel channel = new FakeTcpChannel(); final int iterations = randomIntBetween(5, 10); @@ -214,13 +222,20 @@ public void testPipelineHandling() throws IOException { } public void testDecodeExceptionIsPropagated() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final Supplier breaker = () -> new NoopCircuitBreaker("test"); final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); - final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + final InboundPipeline pipeline = new InboundPipeline( + statsTracker, + millisSupplier, + decoder, + aggregator, + messageHandler, + Version.CURRENT + ); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; @@ -268,13 +283,20 @@ public void testDecodeExceptionIsPropagated() throws IOException { } public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); final Supplier breaker = () -> new NoopCircuitBreaker("test"); final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); - final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + final InboundPipeline pipeline = new InboundPipeline( + statsTracker, + millisSupplier, + decoder, + aggregator, + messageHandler, + Version.CURRENT + ); try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { String actionName = "actionName"; diff --git a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java index ff99435f765d8..846a93ec09c8c 100644 --- a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java @@ -97,12 +97,13 @@ public void setUp() throws Exception { final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { - Streams.copy(m.openOrGetStreamInput(), streamOutput); - message.set(new Tuple<>(m.getHeader(), streamOutput.bytes())); + InboundMessage m1 = (InboundMessage) m; + Streams.copy(m1.openOrGetStreamInput(), streamOutput); + message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); } catch (IOException e) { throw new AssertionError(e); } - }); + }, Version.CURRENT); } @After