From 4a672763a3d1cffbf3d26b8d8c76c9388d6bc9e0 Mon Sep 17 00:00:00 2001 From: Mariano Barrios Date: Mon, 15 Apr 2024 01:08:08 +0200 Subject: [PATCH] Migrate to Java: ConcurrentTest --- src/test/scala/tlschannel/ConcurrentTest.java | 148 ++++++++++++++++++ .../scala/tlschannel/ConcurrentTest.scala | 106 ------------- 2 files changed, 148 insertions(+), 106 deletions(-) create mode 100644 src/test/scala/tlschannel/ConcurrentTest.java delete mode 100644 src/test/scala/tlschannel/ConcurrentTest.scala diff --git a/src/test/scala/tlschannel/ConcurrentTest.java b/src/test/scala/tlschannel/ConcurrentTest.java new file mode 100644 index 00000000..393e7be1 --- /dev/null +++ b/src/test/scala/tlschannel/ConcurrentTest.java @@ -0,0 +1,148 @@ +package tlschannel; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Logger; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import scala.Option; +import tlschannel.helpers.*; + +@TestInstance(Lifecycle.PER_CLASS) +public class ConcurrentTest { + + private static final Logger logger = Logger.getLogger(ConcurrentTest.class.getName()); + + private final SslContextFactory sslContextFactory = new SslContextFactory(); + private final SocketPairFactory factory = new SocketPairFactory(sslContextFactory.defaultContext()); + private final int dataSize = 250_000_000; + private final int bufferSize = 2000; + + /** Test several parties writing concurrently + */ + // write-side thread safety + @Test + public void testWriteSide() throws IOException { + SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); + Thread clientWriterThread1 = + new Thread(() -> writerLoop(dataSize, 'a', socketPair.client()), "client-writer-1"); + Thread clientWriterThread2 = + new Thread(() -> writerLoop(dataSize, 'b', socketPair.client()), "client-writer-2"); + Thread clientWriterThread3 = + new Thread(() -> writerLoop(dataSize, 'c', socketPair.client()), "client-writer-3"); + Thread clientWriterThread4 = + new Thread(() -> writerLoop(dataSize, 'd', socketPair.client()), "client-writer-4"); + Thread serverReaderThread = new Thread(() -> readerLoop(dataSize * 4, socketPair.server()), "server-reader"); + Stream.of( + serverReaderThread, + clientWriterThread1, + clientWriterThread2, + clientWriterThread3, + clientWriterThread4) + .forEach(t -> t.start()); + Stream.of(clientWriterThread1, clientWriterThread2, clientWriterThread3, clientWriterThread4) + .forEach(t -> joinInterruptible(t)); + socketPair.client().external().close(); + joinInterruptible(serverReaderThread); + SocketPairFactory.checkDeallocation(socketPair); + } + + // read-size thread-safety + @Test + public void testReadSide() throws IOException { + SocketPair socketPair = factory.nioNio(Option.apply(null), Option.apply(null), true, false, Option.apply(null)); + Thread clientWriterThread = new Thread(() -> writerLoop(dataSize, 'a', socketPair.client()), "client-writer"); + AtomicLong totalRead = new AtomicLong(); + Thread serverReaderThread1 = + new Thread(() -> readerLoopUntilEof(socketPair.server(), totalRead), "server-reader-1"); + Thread serverReaderThread2 = + new Thread(() -> readerLoopUntilEof(socketPair.server(), totalRead), "server-reader-2"); + Stream.of(serverReaderThread1, serverReaderThread2, clientWriterThread).forEach(t -> t.start()); + joinInterruptible(clientWriterThread); + socketPair.client().external().close(); + Stream.of(serverReaderThread1, serverReaderThread2).forEach(t -> joinInterruptible(t)); + SocketPairFactory.checkDeallocation(socketPair); + assertEquals(dataSize, totalRead.get()); + } + + private void writerLoop(int size, char ch, SocketGroup socketGroup) { + TestUtil.cannotFail(() -> { + try { + logger.fine(() -> String.format("Starting writer loop, size: %s", size)); + int bytesRemaining = size; + byte[] bufferArray = new byte[bufferSize]; + Arrays.fill(bufferArray, (byte) ch); + while (bytesRemaining > 0) { + ByteBuffer buffer = ByteBuffer.wrap(bufferArray, 0, Math.min(bufferSize, bytesRemaining)); + while (buffer.hasRemaining()) { + int c = socketGroup.external().write(buffer); + assertTrue(c > 0, "blocking write must return a positive number"); + bytesRemaining -= c; + assertTrue(bytesRemaining >= 0); + } + } + logger.fine("Finalizing writer loop"); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private void readerLoop(int size, SocketGroup socketGroup) { + TestUtil.cannotFail(() -> { + try { + logger.fine(() -> String.format("Starting reader loop, size: %s", size)); + byte[] readArray = new byte[bufferSize]; + int bytesRemaining = size; + while (bytesRemaining > 0) { + ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, Math.min(bufferSize, bytesRemaining)); + int c = socketGroup.external().read(readBuffer); + assertTrue(c > 0, "blocking read must return a positive number"); + bytesRemaining -= c; + assertTrue(bytesRemaining >= 0); + } + logger.fine("Finalizing reader loop"); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private void readerLoopUntilEof(SocketGroup socketGroup, AtomicLong accumulator) { + TestUtil.cannotFail(() -> { + try { + logger.fine("Starting reader loop"); + byte[] readArray = new byte[bufferSize]; + while (true) { + ByteBuffer readBuffer = ByteBuffer.wrap(readArray, 0, bufferSize); + int c = socketGroup.external().read(readBuffer); + if (c == -1) { + logger.fine("Finalizing reader loop"); + return null; + } + assertTrue(c > 0, "blocking read must return a positive number"); + accumulator.addAndGet(c); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + private static void joinInterruptible(Thread t) { + try { + t.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/test/scala/tlschannel/ConcurrentTest.scala b/src/test/scala/tlschannel/ConcurrentTest.scala deleted file mode 100644 index 51a815ec..00000000 --- a/src/test/scala/tlschannel/ConcurrentTest.scala +++ /dev/null @@ -1,106 +0,0 @@ -package tlschannel - -import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicLong -import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue, fail} -import org.junit.jupiter.api.{Test, TestInstance} -import org.junit.jupiter.api.TestInstance.Lifecycle -import tlschannel.helpers.{SocketGroup, SocketPairFactory, SslContextFactory, TestUtil} - -import java.util.logging.Logger -import scala.util.control.Breaks.{break, breakable} - -@TestInstance(Lifecycle.PER_CLASS) -class ConcurrentTest { - - val logger = Logger.getLogger(classOf[ConcurrentTest].getName) - - val sslContextFactory = new SslContextFactory - val factory = new SocketPairFactory(sslContextFactory.defaultContext) - val dataSize = 250_000_000 - val bufferSize = 2000 - - /** Test several parties writing concurrently - */ - // write-side thread safety - @Test - def testWriteSide(): Unit = { - val socketPair = factory.nioNio() - val clientWriterThread1 = new Thread(() => writerLoop(dataSize, 'a', socketPair.client), "client-writer-1") - val clientWriterThread2 = new Thread(() => writerLoop(dataSize, 'b', socketPair.client), "client-writer-2") - val clientWriterThread3 = new Thread(() => writerLoop(dataSize, 'c', socketPair.client), "client-writer-3") - val clientWriterThread4 = new Thread(() => writerLoop(dataSize, 'd', socketPair.client), "client-writer-4") - val serverReaderThread = new Thread(() => readerLoop(dataSize * 4, socketPair.server), "server-reader") - Seq(serverReaderThread, clientWriterThread1, clientWriterThread2, clientWriterThread3, clientWriterThread4) - .foreach(_.start()) - Seq(clientWriterThread1, clientWriterThread2, clientWriterThread3, clientWriterThread4).foreach(_.join()) - socketPair.client.external.close() - serverReaderThread.join() - SocketPairFactory.checkDeallocation(socketPair) - } - - // read-size thread-safety - @Test - def testReadSide(): Unit = { - val socketPair = factory.nioNio() - val clientWriterThread = new Thread(() => writerLoop(dataSize, 'a', socketPair.client), "client-writer") - val totalRead = new AtomicLong - val serverReaderThread1 = new Thread(() => readerLoopUntilEof(socketPair.server, totalRead), "server-reader-1") - val serverReaderThread2 = new Thread(() => readerLoopUntilEof(socketPair.server, totalRead), "server-reader-2") - Seq(serverReaderThread1, serverReaderThread2, clientWriterThread).foreach(_.start()) - clientWriterThread.join() - socketPair.client.external.close() - Seq(serverReaderThread1, serverReaderThread2).foreach(_.join()) - SocketPairFactory.checkDeallocation(socketPair) - assertEquals(dataSize, totalRead.get()) - } - - private def writerLoop(size: Int, char: Char, socketGroup: SocketGroup): Unit = TestUtil.cannotFail { - logger.fine(() => s"Starting writer loop, size: $size") - var bytesRemaining = size - val bufferArray = Array.fill[Byte](bufferSize)(char.toByte) - while (bytesRemaining > 0) { - val buffer = ByteBuffer.wrap(bufferArray, 0, math.min(bufferSize, bytesRemaining)) - while (buffer.hasRemaining) { - val c = socketGroup.external.write(buffer) - assertTrue(c > 0, "blocking write must return a positive number") - bytesRemaining -= c.toInt - assertTrue(bytesRemaining >= 0) - } - } - logger.fine("Finalizing writer loop") - } - - private def readerLoop(size: Int, socketGroup: SocketGroup): Unit = TestUtil.cannotFail { - logger.fine(() => s"Starting reader loop. Size: $size") - val readArray = Array.ofDim[Byte](bufferSize) - var bytesRemaining = size - while (bytesRemaining > 0) { - val readBuffer = ByteBuffer.wrap(readArray, 0, math.min(bufferSize, bytesRemaining)) - val c = socketGroup.external.read(readBuffer) - assertTrue(c > 0, "blocking read must return a positive number") - bytesRemaining -= c - assertTrue(bytesRemaining >= 0) - } - logger.fine("Finalizing reader loop") - } - - private def readerLoopUntilEof(socketGroup: SocketGroup, accumulator: AtomicLong): Unit = TestUtil.cannotFail { - breakable { - logger.fine("Starting reader loop.") - val readArray = Array.ofDim[Byte](bufferSize) - while (true) { - val readBuffer = ByteBuffer.wrap(readArray, 0, bufferSize) - val c = socketGroup.external.read(readBuffer) - if (c == -1) { - logger.fine("Finalizing reader loop") - break() - } - assertTrue(c > 0, "blocking read must return a positive number") - accumulator.addAndGet(c) - } - fail() - } - } - -}