From abdfe70fb988959aacbae7d27e5f29304ed6056d Mon Sep 17 00:00:00 2001 From: Mariano Barrios Date: Thu, 9 May 2024 13:04:58 +0200 Subject: [PATCH] Migrate to Java: NonBlockingLoops --- .../tlschannel/MultiNonBlockingTest.java | 27 +- .../scala/tlschannel/NonBlockingTest.java | 8 +- .../tlschannel/NullMultiNonBlockingTest.java | 10 +- .../tlschannel/helpers/NonBlockingLoops.java | 281 ++++++++++++++++++ .../tlschannel/helpers/NonBlockingLoops.scala | 189 ------------ .../tlschannel/helpers/TestJavaUtil.java | 13 + .../scala/tlschannel/helpers/TestUtil.scala | 9 - 7 files changed, 318 insertions(+), 219 deletions(-) create mode 100644 src/test/scala/tlschannel/helpers/NonBlockingLoops.java delete mode 100644 src/test/scala/tlschannel/helpers/NonBlockingLoops.scala diff --git a/src/test/scala/tlschannel/MultiNonBlockingTest.java b/src/test/scala/tlschannel/MultiNonBlockingTest.java index 65a075c2..63efc9d5 100644 --- a/src/test/scala/tlschannel/MultiNonBlockingTest.java +++ b/src/test/scala/tlschannel/MultiNonBlockingTest.java @@ -2,12 +2,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.List; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import scala.Option; -import scala.collection.immutable.Seq; +import scala.jdk.CollectionConverters; import tlschannel.helpers.NonBlockingLoops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -25,10 +26,11 @@ public class MultiNonBlockingTest { @Test public void testTaskLoop() { System.out.println("testTasksInExecutorWithRenegotiation():"); - Seq pairs = factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null)); + List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( + Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null))) + .asJava(); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, false); - assertEquals(0, report.asyncTasksRun()); + assertEquals(0, report.asyncTasksRun); report.print(); } @@ -36,8 +38,9 @@ public void testTaskLoop() { @Test public void testTasksInExecutor() { System.out.println("testTasksInExecutorWithRenegotiation():"); - Seq pairs = factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null)); + List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( + Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null))) + .asJava(); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, false); report.print(); } @@ -46,10 +49,11 @@ public void testTasksInExecutor() { @Test public void testTasksInLoopWithRenegotiation() { System.out.println("testTasksInExecutorWithRenegotiation():"); - Seq pairs = factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null)); + List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( + Option.apply(null), totalConnections, Option.apply(null), true, false, Option.apply(null))) + .asJava(); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true); - assertEquals(0, report.asyncTasksRun()); + assertEquals(0, report.asyncTasksRun); report.print(); } @@ -57,8 +61,9 @@ public void testTasksInLoopWithRenegotiation() { @Test public void testTasksInExecutorWithRenegotiation() { System.out.println("testTasksInExecutorWithRenegotiation():"); - Seq pairs = factory.nioNioN( - Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null)); + List pairs = CollectionConverters.SeqHasAsJava(factory.nioNioN( + Option.apply(null), totalConnections, Option.apply(null), false, false, Option.apply(null))) + .asJava(); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, true); report.print(); } diff --git a/src/test/scala/tlschannel/NonBlockingTest.java b/src/test/scala/tlschannel/NonBlockingTest.java index a4f58d22..ebf95982 100644 --- a/src/test/scala/tlschannel/NonBlockingTest.java +++ b/src/test/scala/tlschannel/NonBlockingTest.java @@ -8,7 +8,6 @@ import org.junit.jupiter.api.TestInstance.Lifecycle; import scala.Option; import scala.Some; -import scala.jdk.javaapi.CollectionConverters; import tlschannel.helpers.NonBlockingLoops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -46,11 +45,8 @@ public Collection testSelectorLoop() { false, Option.apply(null)); - NonBlockingLoops.Report report = NonBlockingLoops.loop( - CollectionConverters.asScala(Collections.singletonList(socketPair)) - .toSeq(), - dataSize, - true); + NonBlockingLoops.Report report = + NonBlockingLoops.loop(Collections.singletonList(socketPair), dataSize, true); System.out.printf("%5d -eng-> %5d -net-> %5d -eng-> %5d\n", size1, size2, size1, size2); report.print(); })); diff --git a/src/test/scala/tlschannel/NullMultiNonBlockingTest.java b/src/test/scala/tlschannel/NullMultiNonBlockingTest.java index 0db9349f..1cef9716 100644 --- a/src/test/scala/tlschannel/NullMultiNonBlockingTest.java +++ b/src/test/scala/tlschannel/NullMultiNonBlockingTest.java @@ -1,12 +1,13 @@ package tlschannel; +import java.util.List; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.TestInstance.Lifecycle; import scala.Option; -import scala.collection.immutable.Seq; +import scala.jdk.CollectionConverters; import tlschannel.helpers.NonBlockingLoops; import tlschannel.helpers.SocketGroups.SocketPair; import tlschannel.helpers.SocketPairFactory; @@ -25,10 +26,11 @@ public class NullMultiNonBlockingTest { @Test public void testRunTasksInNonBlockingLoop() { - Seq pairs = - factory.nioNioN(null, totalConnections, Option.apply(null), true, false, Option.apply(null)); + List pairs = CollectionConverters.SeqHasAsJava( + factory.nioNioN(null, totalConnections, Option.apply(null), true, false, Option.apply(null))) + .asJava(); NonBlockingLoops.Report report = NonBlockingLoops.loop(pairs, dataSize, false); - Assertions.assertEquals(0, report.asyncTasksRun()); + Assertions.assertEquals(0, report.asyncTasksRun); } @AfterAll diff --git a/src/test/scala/tlschannel/helpers/NonBlockingLoops.java b/src/test/scala/tlschannel/helpers/NonBlockingLoops.java new file mode 100644 index 00000000..6b55a133 --- /dev/null +++ b/src/test/scala/tlschannel/helpers/NonBlockingLoops.java @@ -0,0 +1,281 @@ +package tlschannel.helpers; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.SplittableRandom; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.LongAdder; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import scala.util.Random; +import tlschannel.NeedsReadException; +import tlschannel.NeedsTaskException; +import tlschannel.NeedsWriteException; +import tlschannel.helpers.SocketGroups.SocketGroup; +import tlschannel.helpers.SocketGroups.SocketPair; + +public class NonBlockingLoops { + + interface Endpoint { + SelectionKey key(); + + int remaining(); + } + + public static class WriterEndpoint implements Endpoint { + + private final SocketGroup socketGroup; + private SelectionKey key; + private final SplittableRandom random = new SplittableRandom(Loops.seed); + private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize); + private int remaining; + + public WriterEndpoint(SocketGroup socketGroup, SelectionKey key, int remaining) { + this.socketGroup = socketGroup; + this.key = key; + this.remaining = remaining; + buffer.flip(); + } + + @Override + public SelectionKey key() { + return key; + } + + @Override + public int remaining() { + return remaining; + } + } + + public static class ReaderEndpoint implements Endpoint { + private final SocketGroup socketGroup; + private SelectionKey key; + private final ByteBuffer buffer = ByteBuffer.allocate(Loops.bufferSize); + private final MessageDigest digest; + private int remaining; + + public ReaderEndpoint(SocketGroup socketGroup, SelectionKey key, int remaining) { + this.socketGroup = socketGroup; + this.key = key; + this.remaining = remaining; + try { + this.digest = MessageDigest.getInstance(Loops.hashAlgorithm); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + + @Override + public SelectionKey key() { + return key; + } + + @Override + public int remaining() { + return remaining; + } + } + + public static class Report { + public final int selectorCycles; + public final int needReadCount; + public final int needWriteCount; + public final int renegotiationCount; + public final int asyncTasksRun; + public final Duration totalAsyncTaskRunningTime; + + public Report( + int selectorCycles, + int needReadCount, + int needWriteCount, + int renegotiationCount, + int asyncTasksRun, + Duration totalAsyncTaskRunningTime) { + this.selectorCycles = selectorCycles; + this.needReadCount = needReadCount; + this.needWriteCount = needWriteCount; + this.renegotiationCount = renegotiationCount; + this.asyncTasksRun = asyncTasksRun; + this.totalAsyncTaskRunningTime = totalAsyncTaskRunningTime; + } + + public void print() { + System.out.printf("Selector cycles:%s\n", selectorCycles); + System.out.printf("NeedRead count: %s\n", needReadCount); + System.out.printf("NeedWrite count: %s\n", needWriteCount); + System.out.printf("Renegotiation count: %s\n", renegotiationCount); + System.out.printf("Asynchronous tasks run: %s\n", asyncTasksRun); + System.out.printf("Total asynchronous task running time: %s ms\n", totalAsyncTaskRunningTime.toMillis()); + } + } + + public static Report loop(List socketPairs, int dataSize, boolean renegotiate) { + try { + int totalConnections = socketPairs.size(); + Selector selector = Selector.open(); + ExecutorService executor = + Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors() - 1); + + ConcurrentLinkedQueue readyTaskSockets = new ConcurrentLinkedQueue<>(); + + List writers = socketPairs.stream() + .map(pair -> { + try { + pair.client.plain.configureBlocking(false); + WriterEndpoint clientEndpoint = new WriterEndpoint(pair.client, null, dataSize); + clientEndpoint.key = + pair.client.plain.register(selector, SelectionKey.OP_WRITE, clientEndpoint); + return clientEndpoint; + } catch (IOException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + + List readers = socketPairs.stream() + .map(pair -> { + try { + pair.server.plain.configureBlocking(false); + ReaderEndpoint serverEndpoint = new ReaderEndpoint(pair.server, null, dataSize); + serverEndpoint.key = + pair.server.plain.register(selector, SelectionKey.OP_READ, serverEndpoint); + return serverEndpoint; + } catch (IOException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + + // var allEndpoints = writers ++ readers; + + int taskCount = 0; + int needReadCount = 0; + int needWriteCount = 0; + int selectorCycles = 0; + int renegotiationCount = 0; + int maxRenegotiations = renegotiate ? totalConnections * 2 * 20 : 0; + + Random random = new Random(); + + LongAdder totalTaskTimeNanos = new LongAdder(); + + byte[] dataHash = Loops.expectedBytesHash.apply(dataSize); + + while (readers.stream().anyMatch(r -> r.remaining > 0) + || writers.stream().anyMatch(r -> r.remaining > 0)) { + selectorCycles += 1; + selector.select(); // block + + for (Endpoint endpoint : Stream.concat( + getSelectedEndpoints(selector), + TestJavaUtil.removeAndCollect(readyTaskSockets.iterator())) + .collect(Collectors.toList())) { + try { + if (endpoint instanceof WriterEndpoint) { + WriterEndpoint writer = (WriterEndpoint) endpoint; + // rewriting do-while loop in a way compatible with Scala 23 + do { + if (renegotiationCount < maxRenegotiations) { + if (random.nextBoolean()) { + renegotiationCount += 1; + writer.socketGroup.tls.renegotiate(); + } + } + if (!writer.buffer.hasRemaining()) { + TestUtil.nextBytes(writer.random, writer.buffer.array()); + writer.buffer.position(0); + writer.buffer.limit(Math.min(writer.buffer.capacity(), writer.remaining)); + } + int oldPosition = writer.buffer.position(); + try { + int c = writer.socketGroup.external.write(writer.buffer); + assertTrue(c >= 0); // the necessity of blocking is communicated with exceptions + } finally { + int bytesWriten = writer.buffer.position() - oldPosition; + writer.remaining -= bytesWriten; + } + + } while (writer.remaining > 0); + + } else if (endpoint instanceof ReaderEndpoint) { + ReaderEndpoint reader = (ReaderEndpoint) endpoint; + // rewriting do-while loop in a way compatible with Scala 23 + do { + reader.buffer.clear(); + int c = reader.socketGroup.external.read(reader.buffer); + assertTrue(c > 0); // the necessity of blocking is communicated with exceptions + reader.digest.update(reader.buffer.array(), 0, c); + reader.remaining -= c; + } while (reader.remaining > 0); + } else { + throw new IllegalArgumentException(); + } + } catch (NeedsWriteException e) { + needWriteCount += 1; + endpoint.key().interestOps(SelectionKey.OP_WRITE); + } catch (NeedsReadException e) { + needReadCount += 1; + endpoint.key().interestOps(SelectionKey.OP_READ); + } catch (NeedsTaskException e) { + Runnable r = () -> { + long start = System.nanoTime(); + e.getTask().run(); + Duration elapsed = Duration.ofNanos(System.nanoTime() - start); + selector.wakeup(); + readyTaskSockets.add(endpoint); + totalTaskTimeNanos.add(elapsed.toNanos()); + }; + executor.submit(r); + taskCount += 1; + } + } + } + + for (SocketPair socketPair : socketPairs) { + socketPair.client.external.close(); + socketPair.server.external.close(); + SocketPairFactory.checkDeallocation(socketPair); + } + + for (ReaderEndpoint reader : readers) { + assertArrayEquals(reader.digest.digest(), dataHash); + } + + return new Report( + selectorCycles, + needReadCount, + needWriteCount, + renegotiationCount, + taskCount, + Duration.ofNanos(totalTaskTimeNanos.longValue())); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static Stream getSelectedEndpoints(Selector selector) { + List builder = new ArrayList<>(); + Iterator it = selector.selectedKeys().iterator(); + while (it.hasNext()) { + SelectionKey key = it.next(); + key.interestOps(0); // delete all operations + builder.add((Endpoint) key.attachment()); + it.remove(); + } + return builder.stream(); + } +} diff --git a/src/test/scala/tlschannel/helpers/NonBlockingLoops.scala b/src/test/scala/tlschannel/helpers/NonBlockingLoops.scala deleted file mode 100644 index 00610ab9..00000000 --- a/src/test/scala/tlschannel/helpers/NonBlockingLoops.scala +++ /dev/null @@ -1,189 +0,0 @@ -package tlschannel.helpers - -import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertTrue} -import tlschannel.NeedsWriteException -import tlschannel.NeedsReadException -import tlschannel.NeedsTaskException -import tlschannel.helpers.SocketGroups.{SocketGroup, SocketPair} - -import java.util.concurrent.atomic.LongAdder -import scala.util.Random -import java.util.concurrent.ConcurrentLinkedQueue -import java.nio.channels.Selector -import java.util.concurrent.Executors -import java.nio.ByteBuffer -import java.nio.channels.SelectionKey -import java.security.MessageDigest -import java.time.Duration -import java.util.SplittableRandom - -object NonBlockingLoops { - - trait Endpoint { - def key: SelectionKey - def remaining: Int - } - - case class WriterEndpoint(socketGroup: SocketGroup, var key: SelectionKey, var remaining: Int) extends Endpoint { - val random = new SplittableRandom(Loops.seed) - val buffer = ByteBuffer.allocate(Loops.bufferSize) - buffer.flip() - } - - case class ReaderEndpoint(socketGroup: SocketGroup, var key: SelectionKey, var remaining: Int) extends Endpoint { - val buffer = ByteBuffer.allocate(Loops.bufferSize) - val digest = MessageDigest.getInstance(Loops.hashAlgorithm) - } - - case class Report( - selectorCycles: Int, - needReadCount: Int, - needWriteCount: Int, - renegotiationCount: Int, - asyncTasksRun: Int, - totalAsyncTaskRunningTime: Duration - ) { - - def print() = { - println(s"Selector cycles: $selectorCycles") - println(s"NeedRead count: $needReadCount") - println(s"NeedWrite count: $needWriteCount") - println(s"Renegotiation count: $renegotiationCount") - println(s"Asynchronous tasks run: $asyncTasksRun") - println(s"Total asynchronous task running time: ${totalAsyncTaskRunningTime.toMillis} ms") - } - } - - def loop(socketPairs: Seq[SocketPair], dataSize: Int, renegotiate: Boolean): Report = { - - val totalConnections = socketPairs.size - val selector = Selector.open() - val executor = Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors - 1) - - val readyTaskSockets = new ConcurrentLinkedQueue[Endpoint] - - val endpoints = for (pair <- socketPairs) yield { - pair.client.plain.configureBlocking(false) - pair.server.plain.configureBlocking(false) - - val clientEndpoint = WriterEndpoint(pair.client, key = null, remaining = dataSize) - val serverEndpoint = ReaderEndpoint(pair.server, key = null, remaining = dataSize) - - clientEndpoint.key = pair.client.plain.register(selector, SelectionKey.OP_WRITE, clientEndpoint) - serverEndpoint.key = pair.server.plain.register(selector, SelectionKey.OP_READ, serverEndpoint) - (clientEndpoint, serverEndpoint) - } - - val (writers, readers) = endpoints.unzip - val allEndpoints = writers ++ readers - - var taskCount = 0 - var needReadCount = 0 - var needWriteCount = 0 - var selectorCycles = 0 - var renegotiationCount = 0 - val maxRenegotiations = if (renegotiate) totalConnections * 2 * 20 else 0 - - val random = new Random - - val totalTaskTimeNanos = new LongAdder - - val dataHash = Loops.expectedBytesHash(dataSize) - - while (allEndpoints.exists(_.remaining > 0)) { - selectorCycles += 1 - selector.select() // block - - for (endpoint <- getSelectedEndpoints(selector) ++ TestUtil.removeAndCollect(readyTaskSockets.iterator())) { - try { - endpoint match { - case writer: WriterEndpoint => - // rewriting do-while loop in a way compatible with Scala 23 - while { - if (renegotiationCount < maxRenegotiations) { - if (random.nextBoolean()) { - renegotiationCount += 1 - writer.socketGroup.tls.renegotiate() - } - } - if (!writer.buffer.hasRemaining) { - TestUtil.nextBytes(writer.random, writer.buffer.array()) - writer.buffer.position(0) - writer.buffer.limit(math.min(writer.buffer.capacity, writer.remaining)) - } - val oldPosition = writer.buffer.position() - try { - val c = writer.socketGroup.external.write(writer.buffer) - assertTrue(c >= 0) // the necessity of blocking is communicated with exceptions - } finally { - val bytesWriten = writer.buffer.position() - oldPosition - writer.remaining -= bytesWriten - } - writer.remaining > 0 - } do () - case reader: ReaderEndpoint => - // rewriting do-while loop in a way compatible with Scala 23 - while { - reader.buffer.clear() - val c = reader.socketGroup.external.read(reader.buffer) - assertTrue(c > 0) // the necessity of blocking is communicated with exceptions - reader.digest.update(reader.buffer.array, 0, c) - reader.remaining -= c - reader.remaining > 0 - } do () - } - } catch { - case e: NeedsWriteException => - needWriteCount += 1 - endpoint.key.interestOps(SelectionKey.OP_WRITE) - case e: NeedsReadException => - needReadCount += 1 - endpoint.key.interestOps(SelectionKey.OP_READ) - case e: NeedsTaskException => - val r: Runnable = { () => - val start = System.nanoTime() - e.getTask.run() - val elapsed = Duration.ofNanos(System.nanoTime() - start) - selector.wakeup() - readyTaskSockets.add(endpoint) - totalTaskTimeNanos.add(elapsed.toNanos) - } - executor.submit(r) - taskCount += 1 - } - } - } - - for (socketPair <- socketPairs) { - socketPair.client.external.close() - socketPair.server.external.close() - SocketPairFactory.checkDeallocation(socketPair) - } - - for (reader <- readers) { - assertArrayEquals(reader.digest.digest(), dataHash) - } - - Report( - selectorCycles, - needReadCount, - needWriteCount, - renegotiationCount, - taskCount, - Duration.ofNanos(totalTaskTimeNanos.longValue()) - ) - } - - def getSelectedEndpoints(selector: Selector): Seq[Endpoint] = { - val builder = Seq.newBuilder[Endpoint] - val it = selector.selectedKeys().iterator() - while (it.hasNext) { - val key = it.next() - key.interestOps(0) // delete all operations - builder += key.attachment.asInstanceOf[Endpoint] - it.remove() - } - builder.result() - } - -} diff --git a/src/test/scala/tlschannel/helpers/TestJavaUtil.java b/src/test/scala/tlschannel/helpers/TestJavaUtil.java index 16a3a284..8fd53a37 100644 --- a/src/test/scala/tlschannel/helpers/TestJavaUtil.java +++ b/src/test/scala/tlschannel/helpers/TestJavaUtil.java @@ -1,12 +1,25 @@ package tlschannel.helpers; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.Stream; public class TestJavaUtil { + public static Stream removeAndCollect(Iterator iterator) { + List builder = new ArrayList<>(); + while (iterator.hasNext()) { + builder.add(iterator.next()); + iterator.remove(); + } + return builder.stream(); + } + @FunctionalInterface public interface ExceptionalRunnable { void run() throws Exception; diff --git a/src/test/scala/tlschannel/helpers/TestUtil.scala b/src/test/scala/tlschannel/helpers/TestUtil.scala index a615af8f..211c3b10 100644 --- a/src/test/scala/tlschannel/helpers/TestUtil.scala +++ b/src/test/scala/tlschannel/helpers/TestUtil.scala @@ -7,15 +7,6 @@ object TestUtil { val logger = Logger.getLogger(TestUtil.getClass.getName) - def removeAndCollect[A](iterator: java.util.Iterator[A]): Seq[A] = { - val builder = Seq.newBuilder[A] - while (iterator.hasNext) { - builder += iterator.next() - iterator.remove() - } - builder.result() - } - def nextBytes(random: SplittableRandom, bytes: Array[Byte]): Unit = { nextBytes(random, bytes, bytes.length) }