diff --git a/server/src/main/java/org/opensearch/common/StreamContext.java b/server/src/main/java/org/opensearch/common/StreamContext.java index b163ba65dc7db..47a3d2b8571ea 100644 --- a/server/src/main/java/org/opensearch/common/StreamContext.java +++ b/server/src/main/java/org/opensearch/common/StreamContext.java @@ -57,6 +57,8 @@ protected StreamContext(StreamContext streamContext) { /** * Vendor plugins can use this method to create new streams only when they are required for processing * New streams won't be created till this method is called with the specific partNumber + * It is the responsibility of caller to ensure that stream is properly closed after consumption + * otherwise it can leak resources. * * @param partNumber The index of the part * @return A stream reference to the part requested diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java index 5808f51f01efc..2047c99d9e13b 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainer.java @@ -19,6 +19,7 @@ import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.blobstore.transfer.stream.OffsetRangeInputStream; +import org.opensearch.common.blobstore.transfer.stream.RateLimitingOffsetRangeInputStream; import org.opensearch.common.blobstore.transfer.stream.ResettableCheckedInputStream; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.common.util.ByteUtils; @@ -27,6 +28,8 @@ import java.io.IOException; import java.io.InputStream; import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import java.util.zip.CRC32; import com.jcraft.jzlib.JZlib; @@ -43,7 +46,7 @@ public class RemoteTransferContainer implements Closeable { private long lastPartSize; private final long contentLength; - private final SetOnce inputStreams = new SetOnce<>(); + private final SetOnce[]> checksumSuppliers = new SetOnce<>(); private final String fileName; private final String remoteFileName; private final boolean failTransferIfFileExists; @@ -51,6 +54,7 @@ public class RemoteTransferContainer implements Closeable { private final long expectedChecksum; private final OffsetRangeInputStreamSupplier offsetRangeInputStreamSupplier; private final boolean isRemoteDataIntegritySupported; + private final AtomicBoolean readBlock = new AtomicBoolean(); private static final Logger log = LogManager.getLogger(RemoteTransferContainer.class); @@ -120,23 +124,24 @@ StreamContext supplyStreamContext(long partSize) { } } + @SuppressWarnings({ "unchecked" }) private StreamContext openMultipartStreams(long partSize) throws IOException { - if (inputStreams.get() != null) { + if (checksumSuppliers.get() != null) { throw new IOException("Multi-part streams are already created."); } this.partSize = partSize; this.lastPartSize = (contentLength % partSize) != 0 ? contentLength % partSize : partSize; this.numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1); - InputStream[] streams = new InputStream[numberOfParts]; - inputStreams.set(streams); + Supplier[] suppliers = new Supplier[numberOfParts]; + checksumSuppliers.set(suppliers); return new StreamContext(getTransferPartStreamSupplier(), partSize, lastPartSize, numberOfParts); } private CheckedTriFunction getTransferPartStreamSupplier() { return ((partNo, size, position) -> { - assert inputStreams.get() != null : "expected inputStreams to be initialised"; + assert checksumSuppliers.get() != null : "expected container to be initialised"; return getMultipartStreamSupplier(partNo, size, position).get(); }); } @@ -160,10 +165,21 @@ private LocalStreamSupplier getMultipartStreamSupplier( return () -> { try { OffsetRangeInputStream offsetRangeInputStream = offsetRangeInputStreamSupplier.get(size, position); - InputStream inputStream = !isRemoteDataIntegrityCheckPossible() - ? new ResettableCheckedInputStream(offsetRangeInputStream, fileName) - : offsetRangeInputStream; - Objects.requireNonNull(inputStreams.get())[streamIdx] = inputStream; + if (offsetRangeInputStream instanceof RateLimitingOffsetRangeInputStream) { + RateLimitingOffsetRangeInputStream rangeIndexInputStream = (RateLimitingOffsetRangeInputStream) offsetRangeInputStream; + rangeIndexInputStream.setReadBlock(readBlock); + } + InputStream inputStream; + if (isRemoteDataIntegrityCheckPossible() == false) { + ResettableCheckedInputStream resettableCheckedInputStream = new ResettableCheckedInputStream( + offsetRangeInputStream, + fileName + ); + Objects.requireNonNull(checksumSuppliers.get())[streamIdx] = resettableCheckedInputStream::getChecksum; + inputStream = resettableCheckedInputStream; + } else { + inputStream = offsetRangeInputStream; + } return new InputStreamContainer(inputStream, size, position); } catch (IOException e) { @@ -205,20 +221,14 @@ public long getContentLength() { return contentLength; } - private long getInputStreamChecksum(InputStream inputStream) { - assert inputStream instanceof ResettableCheckedInputStream - : "expected passed inputStream to be instance of ResettableCheckedInputStream"; - return ((ResettableCheckedInputStream) inputStream).getChecksum(); - } - private long getActualChecksum() { - InputStream[] currentInputStreams = Objects.requireNonNull(inputStreams.get()); - long checksum = getInputStreamChecksum(currentInputStreams[0]); - for (int checkSumIdx = 1; checkSumIdx < Objects.requireNonNull(inputStreams.get()).length - 1; checkSumIdx++) { - checksum = JZlib.crc32_combine(checksum, getInputStreamChecksum(currentInputStreams[checkSumIdx]), partSize); + Supplier[] ckSumSuppliers = Objects.requireNonNull(checksumSuppliers.get()); + long checksum = ckSumSuppliers[0].get(); + for (int checkSumIdx = 1; checkSumIdx < ckSumSuppliers.length - 1; checkSumIdx++) { + checksum = JZlib.crc32_combine(checksum, ckSumSuppliers[checkSumIdx].get(), partSize); } if (numberOfParts > 1) { - checksum = JZlib.crc32_combine(checksum, getInputStreamChecksum(currentInputStreams[numberOfParts - 1]), lastPartSize); + checksum = JZlib.crc32_combine(checksum, ckSumSuppliers[numberOfParts - 1].get(), lastPartSize); } return checksum; @@ -226,27 +236,8 @@ private long getActualChecksum() { @Override public void close() throws IOException { - if (inputStreams.get() == null) { - log.warn("Input streams cannot be closed since they are not yet set for multi stream upload"); - return; - } - - boolean closeStreamException = false; - for (InputStream is : Objects.requireNonNull(inputStreams.get())) { - try { - if (is != null) { - is.close(); - } - } catch (IOException ex) { - closeStreamException = true; - // Attempting to close all streams first before throwing exception. - log.error("Multipart stream failed to close ", ex); - } - } - - if (closeStreamException) { - throw new IOException("Closure of some of the multi-part streams failed."); - } + // Setting a read block on all streams ever created by the container. + readBlock.set(true); } /** diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java index 7518f9ac569b9..520c838ba8a81 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeIndexInputStream.java @@ -8,10 +8,16 @@ package org.opensearch.common.blobstore.transfer.stream; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.store.IndexInput; +import org.opensearch.common.concurrent.RefCountedReleasable; import org.opensearch.common.lucene.store.InputStreamIndexInput; +import org.opensearch.common.util.concurrent.RunOnce; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; /** * OffsetRangeIndexInputStream extends InputStream to read from a specified offset using IndexInput @@ -19,9 +25,12 @@ * @opensearch.internal */ public class OffsetRangeIndexInputStream extends OffsetRangeInputStream { - + private static final Logger logger = LogManager.getLogger(OffsetRangeIndexInputStream.class); private final InputStreamIndexInput inputStreamIndexInput; private final IndexInput indexInput; + private AtomicBoolean readBlock; + private final OffsetRangeRefCount offsetRangeRefCount; + private final RunOnce closeOnce; /** * Construct a new OffsetRangeIndexInputStream object @@ -35,16 +44,68 @@ public OffsetRangeIndexInputStream(IndexInput indexInput, long size, long positi indexInput.seek(position); this.indexInput = indexInput; this.inputStreamIndexInput = new InputStreamIndexInput(indexInput, size); + ClosingStreams closingStreams = new ClosingStreams(inputStreamIndexInput, indexInput); + offsetRangeRefCount = new OffsetRangeRefCount(closingStreams); + closeOnce = new RunOnce(offsetRangeRefCount::decRef); + } + + @Override + public void setReadBlock(AtomicBoolean readBlock) { + this.readBlock = readBlock; } @Override public int read(byte[] b, int off, int len) throws IOException { - return inputStreamIndexInput.read(b, off, len); + // There are two levels of check to ensure that we don't read an already closed stream and + // to not close the stream if it is already being read. + // 1. First check is a coarse-grained check outside reference check which allows us to fail fast if read + // was invoked after the stream was closed. We need a separate atomic boolean closed because we don't want a + // future read to succeed when #close has been invoked even if there are on-going reads. On-going reads would + // hold reference and since ref count will not be 0 even after close was invoked, future reads will go through + // without a check on closed. Also, we do need to set closed externally. It is shared across all streams of the + // file. Check on closed in this class makes sure that no other stream allows subsequent reads. closed is + // being set to true in RemoteTransferContainer#close which is invoked when we are done processing all + // parts/file. Processing completes when either all parts are completed successfully or if either of the parts + // failed. In successful case, subsequent read will anyway not go through since all streams would have been + // consumed fully but in case of failure, SDK can continue to invoke read and this would be a wasted compute + // and IO. + // 2. In second check, a tryIncRef is invoked which tries to increment reference under lock and fails if ref + // is already closed. If reference is successfully obtained by the stream then stream will not be closed. + // Ref counting ensures that stream isn't closed in between reads. + // + // All these protection mechanisms are required in order to prevent invalid access to streams happening + // from the new S3 async SDK. + ensureReadable(); + try (OffsetRangeRefCount ignored = getStreamReference()) { + return inputStreamIndexInput.read(b, off, len); + } + } + + private OffsetRangeRefCount getStreamReference() { + boolean successIncrement = offsetRangeRefCount.tryIncRef(); + if (successIncrement == false) { + throw alreadyClosed("OffsetRangeIndexInputStream is already unreferenced."); + } + return offsetRangeRefCount; + } + + private void ensureReadable() { + if (readBlock != null && readBlock.get() == true) { + logger.debug("Read attempted on a stream which was read blocked!"); + throw alreadyClosed("Read blocked stream."); + } + } + + AlreadyClosedException alreadyClosed(String msg) { + return new AlreadyClosedException(msg + this); } @Override public int read() throws IOException { - return inputStreamIndexInput.read(); + ensureReadable(); + try (OffsetRangeRefCount ignored = getStreamReference()) { + return inputStreamIndexInput.read(); + } } @Override @@ -67,9 +128,42 @@ public long getFilePointer() throws IOException { return indexInput.getFilePointer(); } + @Override + public String toString() { + return "OffsetRangeIndexInputStream{" + "indexInput=" + indexInput + ", readBlock=" + readBlock + '}'; + } + + private static class ClosingStreams { + private final InputStreamIndexInput inputStreamIndexInput; + private final IndexInput indexInput; + + public ClosingStreams(InputStreamIndexInput inputStreamIndexInput, IndexInput indexInput) { + this.inputStreamIndexInput = inputStreamIndexInput; + this.indexInput = indexInput; + } + } + + private static class OffsetRangeRefCount extends RefCountedReleasable { + private static final Logger logger = LogManager.getLogger(OffsetRangeRefCount.class); + + public OffsetRangeRefCount(ClosingStreams ref) { + super("OffsetRangeRefCount", ref, () -> { + try { + ref.inputStreamIndexInput.close(); + } catch (IOException ex) { + logger.error("Failed to close indexStreamIndexInput", ex); + } + try { + ref.indexInput.close(); + } catch (IOException ex) { + logger.error("Failed to close indexInput", ex); + } + }); + } + } + @Override public void close() throws IOException { - inputStreamIndexInput.close(); - indexInput.close(); + closeOnce.run(); } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java index e8b889db1f3b0..eacb972586a5a 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/OffsetRangeInputStream.java @@ -10,6 +10,7 @@ import java.io.IOException; import java.io.InputStream; +import java.util.concurrent.atomic.AtomicBoolean; /** * OffsetRangeInputStream is an abstract class that extends from {@link InputStream} @@ -19,4 +20,8 @@ */ public abstract class OffsetRangeInputStream extends InputStream { public abstract long getFilePointer() throws IOException; + + public void setReadBlock(AtomicBoolean readBlock) { + // Nothing + } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java index b455999bbed0c..4a511ca1ac155 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java +++ b/server/src/main/java/org/opensearch/common/blobstore/transfer/stream/RateLimitingOffsetRangeInputStream.java @@ -12,6 +12,7 @@ import org.opensearch.common.StreamLimiter; import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; /** @@ -40,6 +41,10 @@ public RateLimitingOffsetRangeInputStream( this.delegate = delegate; } + public void setReadBlock(AtomicBoolean readBlock) { + delegate.setReadBlock(readBlock); + } + @Override public int read() throws IOException { int b = delegate.read(); diff --git a/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java b/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java index a33e5f453d1e1..074f659850c7b 100644 --- a/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java +++ b/server/src/test/java/org/opensearch/common/blobstore/transfer/RemoteTransferContainerTests.java @@ -8,21 +8,36 @@ package org.opensearch.common.blobstore.transfer; +import org.apache.lucene.store.AlreadyClosedException; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.RateLimiter; import org.opensearch.common.StreamContext; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.blobstore.transfer.stream.OffsetRangeFileInputStream; +import org.opensearch.common.blobstore.transfer.stream.OffsetRangeIndexInputStream; import org.opensearch.common.blobstore.transfer.stream.OffsetRangeInputStream; +import org.opensearch.common.blobstore.transfer.stream.RateLimitingOffsetRangeInputStream; import org.opensearch.common.blobstore.transfer.stream.ResettableCheckedInputStream; import org.opensearch.common.io.InputStreamContainer; import org.opensearch.test.OpenSearchTestCase; import org.junit.Before; import java.io.IOException; +import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; +import java.util.Arrays; import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import org.mockito.Mockito; +import org.mockito.stubbing.Answer; public class RemoteTransferContainerTests extends OpenSearchTestCase { @@ -92,25 +107,37 @@ private void testSupplyStreamContext( int partCount = streamContext.getNumberOfParts(); assertEquals(expectedPartCount, partCount); Thread[] threads = new Thread[partCount]; + InputStream[] streams = new InputStream[partCount]; long totalContentLength = remoteTransferContainer.getContentLength(); assert partSize * (partCount - 1) + lastPartSize == totalContentLength : "part sizes and last part size don't add up to total content length"; logger.info("partSize: {}, lastPartSize: {}, partCount: {}", partSize, lastPartSize, streamContext.getNumberOfParts()); - for (int partIdx = 0; partIdx < partCount; partIdx++) { - int finalPartIdx = partIdx; - long expectedPartSize = (partIdx == partCount - 1) ? lastPartSize : partSize; - threads[partIdx] = new Thread(() -> { + try { + for (int partIdx = 0; partIdx < partCount; partIdx++) { + int finalPartIdx = partIdx; + long expectedPartSize = (partIdx == partCount - 1) ? lastPartSize : partSize; + threads[partIdx] = new Thread(() -> { + try { + InputStreamContainer inputStreamContainer = streamContext.provideStream(finalPartIdx); + streams[finalPartIdx] = inputStreamContainer.getInputStream(); + assertEquals(expectedPartSize, inputStreamContainer.getContentLength()); + } catch (IOException e) { + fail("IOException during stream creation"); + } + }); + threads[partIdx].start(); + } + for (int i = 0; i < partCount; i++) { + threads[i].join(); + } + } finally { + Arrays.stream(streams).forEach(stream -> { try { - InputStreamContainer inputStreamContainer = streamContext.provideStream(finalPartIdx); - assertEquals(expectedPartSize, inputStreamContainer.getContentLength()); + stream.close(); } catch (IOException e) { - fail("IOException during stream creation"); + throw new RuntimeException(e); } }); - threads[partIdx].start(); - } - for (int i = 0; i < partCount; i++) { - threads[i].join(); } } @@ -182,6 +209,7 @@ public OffsetRangeInputStream get(long size, long position) throws IOException { } private void testTypeOfProvidedStreams(boolean isRemoteDataIntegritySupported) throws IOException { + InputStream inputStream = null; try ( RemoteTransferContainer remoteTransferContainer = new RemoteTransferContainer( testFile.getFileName().toString(), @@ -201,12 +229,132 @@ public OffsetRangeInputStream get(long size, long position) throws IOException { ) { StreamContext streamContext = remoteTransferContainer.supplyStreamContext(16); InputStreamContainer inputStreamContainer = streamContext.provideStream(0); + inputStream = inputStreamContainer.getInputStream(); if (shouldOffsetInputStreamsBeChecked(isRemoteDataIntegritySupported)) { assertTrue(inputStreamContainer.getInputStream() instanceof ResettableCheckedInputStream); } else { assertTrue(inputStreamContainer.getInputStream() instanceof OffsetRangeInputStream); } assertThrows(RuntimeException.class, () -> remoteTransferContainer.supplyStreamContext(16)); + } finally { + if (inputStream != null) { + inputStream.close(); + } + } + } + + public void testCloseDuringOngoingReadOnStream() throws IOException, InterruptedException { + Supplier rateLimiterSupplier = Mockito.mock(Supplier.class); + Mockito.when(rateLimiterSupplier.get()).thenReturn(null); + CountDownLatch readInvokedLatch = new CountDownLatch(1); + AtomicBoolean readAfterClose = new AtomicBoolean(); + CountDownLatch streamClosed = new CountDownLatch(1); + AtomicBoolean indexInputClosed = new AtomicBoolean(); + AtomicInteger closedCount = new AtomicInteger(); + try ( + RemoteTransferContainer remoteTransferContainer = new RemoteTransferContainer( + testFile.getFileName().toString(), + testFile.getFileName().toString(), + TEST_FILE_SIZE_BYTES, + true, + WritePriority.NORMAL, + new RemoteTransferContainer.OffsetRangeInputStreamSupplier() { + @Override + public OffsetRangeInputStream get(long size, long position) throws IOException { + IndexInput indexInput = Mockito.mock(IndexInput.class); + Mockito.doAnswer(invocation -> { + indexInputClosed.set(true); + closedCount.incrementAndGet(); + return null; + }).when(indexInput).close(); + Mockito.when(indexInput.getFilePointer()).thenAnswer((Answer) invocation -> { + if (readAfterClose.get() == false) { + return 0L; + } + readInvokedLatch.countDown(); + boolean closedSuccess = streamClosed.await(30, TimeUnit.SECONDS); + assertTrue(closedSuccess); + assertFalse(indexInputClosed.get()); + return 0L; + }); + + OffsetRangeIndexInputStream offsetRangeIndexInputStream = new OffsetRangeIndexInputStream( + indexInput, + size, + position + ); + return new RateLimitingOffsetRangeInputStream(offsetRangeIndexInputStream, rateLimiterSupplier, null); + } + }, + 0, + true + ) + ) { + StreamContext streamContext = remoteTransferContainer.supplyStreamContext(16); + InputStreamContainer inputStreamContainer = streamContext.provideStream(0); + assertTrue(inputStreamContainer.getInputStream() instanceof RateLimitingOffsetRangeInputStream); + CountDownLatch latch = new CountDownLatch(1); + new Thread(() -> { + try { + readAfterClose.set(true); + inputStreamContainer.getInputStream().readAllBytes(); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + latch.countDown(); + } + }).start(); + boolean successReadWait = readInvokedLatch.await(30, TimeUnit.SECONDS); + assertTrue(successReadWait); + // Closing stream here. Test Multiple invocations of close. Shouldn't throw any exception + inputStreamContainer.getInputStream().close(); + inputStreamContainer.getInputStream().close(); + inputStreamContainer.getInputStream().close(); + streamClosed.countDown(); + boolean processed = latch.await(30, TimeUnit.SECONDS); + assertTrue(processed); + assertTrue(readAfterClose.get()); + assertTrue(indexInputClosed.get()); + + // Test Multiple invocations of close. Close count should always be 1. + inputStreamContainer.getInputStream().close(); + inputStreamContainer.getInputStream().close(); + inputStreamContainer.getInputStream().close(); + assertEquals(1, closedCount.get()); + + } + } + + public void testReadAccessWhenStreamClosed() throws IOException { + Supplier rateLimiterSupplier = Mockito.mock(Supplier.class); + Mockito.when(rateLimiterSupplier.get()).thenReturn(null); + try ( + RemoteTransferContainer remoteTransferContainer = new RemoteTransferContainer( + testFile.getFileName().toString(), + testFile.getFileName().toString(), + TEST_FILE_SIZE_BYTES, + true, + WritePriority.NORMAL, + new RemoteTransferContainer.OffsetRangeInputStreamSupplier() { + @Override + public OffsetRangeInputStream get(long size, long position) throws IOException { + IndexInput indexInput = Mockito.mock(IndexInput.class); + OffsetRangeIndexInputStream offsetRangeIndexInputStream = new OffsetRangeIndexInputStream( + indexInput, + size, + position + ); + return new RateLimitingOffsetRangeInputStream(offsetRangeIndexInputStream, rateLimiterSupplier, null); + } + }, + 0, + true + ) + ) { + StreamContext streamContext = remoteTransferContainer.supplyStreamContext(16); + InputStreamContainer inputStreamContainer = streamContext.provideStream(0); + inputStreamContainer.getInputStream().close(); + assertThrows(AlreadyClosedException.class, () -> inputStreamContainer.getInputStream().readAllBytes()); } }