diff --git a/src/test/scala/tlschannel/InteroperabilityTest.scala b/src/test/scala/tlschannel/InteroperabilityTest.scala index 95f6ebb7..eb84c12a 100644 --- a/src/test/scala/tlschannel/InteroperabilityTest.scala +++ b/src/test/scala/tlschannel/InteroperabilityTest.scala @@ -1,35 +1,30 @@ package tlschannel import scala.util.Random -import java.net.Socket -import java.nio.channels.ByteChannel -import javax.net.ssl.SSLSocket -import java.nio.ByteBuffer import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals, assertNotEquals, assertTrue} import org.junit.jupiter.api.{Test, TestInstance} import org.junit.jupiter.api.TestInstance.Lifecycle import tlschannel.helpers.TestUtil import tlschannel.helpers.SslContextFactory import tlschannel.helpers.SocketPairFactory +import tlschannel.util.InteroperabilityUtils._ @TestInstance(Lifecycle.PER_CLASS) class InteroperabilityTest { - import InteroperabilityTest._ - val sslContextFactory = new SslContextFactory val factory = new SocketPairFactory(sslContextFactory.defaultContext, SslContextFactory.certificateCommonName) def oldNio() = { val (client, server) = factory.oldNio(None) val clientPair = (new SSLSocketWriter(client), new SocketReader(client)) - val serverPair = (new TlsSocketChannelWriter(server.tls), new ByteChannelReader(server.tls)) + val serverPair = (new TlsChannelWriter(server.tls), new ByteChannelReader(server.tls)) (clientPair, serverPair) } def nioOld() = { val (client, server) = factory.nioOld() - val clientPair = (new TlsSocketChannelWriter(client.tls), new ByteChannelReader(client.tls)) + val clientPair = (new TlsChannelWriter(client.tls), new ByteChannelReader(client.tls)) val serverPair = (new SSLSocketWriter(server), new SocketReader(server)) (clientPair, serverPair) } @@ -165,50 +160,3 @@ class InteroperabilityTest { } } - -object InteroperabilityTest { - - trait Reader { - def read(array: Array[Byte], offset: Int, length: Int): Int - def close(): Unit - } - - class SocketReader(socket: Socket) extends Reader { - private val is = socket.getInputStream - def read(array: Array[Byte], offset: Int, length: Int) = is.read(array, offset, length) - def close() = socket.close() - } - - class ByteChannelReader(socket: ByteChannel) extends Reader { - def read(array: Array[Byte], offset: Int, length: Int) = socket.read(ByteBuffer.wrap(array, offset, length)) - def close() = socket.close() - } - - trait Writer { - def renegotiate(): Unit - def write(array: Array[Byte], offset: Int, length: Int): Unit - def close(): Unit - } - - class SSLSocketWriter(socket: SSLSocket) extends Writer { - private val os = socket.getOutputStream - def write(array: Array[Byte], offset: Int, length: Int) = os.write(array, offset, length) - def renegotiate() = socket.startHandshake() - def close() = socket.close() - } - - class TlsSocketChannelWriter(val socket: TlsChannel) extends Writer { - - def write(array: Array[Byte], offset: Int, length: Int) = { - val buffer = ByteBuffer.wrap(array, offset, length) - while (buffer.remaining() > 0) { - val c = socket.write(buffer) - assertNotEquals(0, c, "blocking write cannot return 0") - } - } - - def renegotiate(): Unit = socket.renegotiate() - def close() = socket.close() - } - -} diff --git a/src/test/scala/tlschannel/util/InteroperabilityUtils.java b/src/test/scala/tlschannel/util/InteroperabilityUtils.java new file mode 100644 index 00000000..af0d8733 --- /dev/null +++ b/src/test/scala/tlschannel/util/InteroperabilityUtils.java @@ -0,0 +1,122 @@ +package tlschannel.util; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import javax.net.ssl.SSLSocket; +import org.junit.jupiter.api.Assertions; +import tlschannel.TlsChannel; + +public class InteroperabilityUtils { + + public interface Reader { + int read(byte[] array, int offset, int length) throws IOException; + + void close() throws IOException; + } + + public interface Writer { + void renegotiate() throws IOException; + + void write(byte[] array, int offset, int length) throws IOException; + + void close() throws IOException; + } + + public static class SocketReader implements Reader { + + private final Socket socket; + private final InputStream is; + + public SocketReader(Socket socket) throws IOException { + this.socket = socket; + this.is = socket.getInputStream(); + } + + @Override + public int read(byte[] array, int offset, int length) throws IOException { + return is.read(array, offset, length); + } + + @Override + public void close() throws IOException { + socket.close(); + } + } + + public static class ByteChannelReader implements Reader { + + private final ByteChannel socket; + + public ByteChannelReader(ByteChannel socket) { + this.socket = socket; + } + + @Override + public int read(byte[] array, int offset, int length) throws IOException { + return socket.read(ByteBuffer.wrap(array, offset, length)); + } + + @Override + public void close() throws IOException { + socket.close(); + } + } + + public static class SSLSocketWriter implements Writer { + + private final SSLSocket socket; + private final OutputStream os; + + public SSLSocketWriter(SSLSocket socket) throws IOException { + this.socket = socket; + this.os = socket.getOutputStream(); + } + + @Override + public void write(byte[] array, int offset, int length) throws IOException { + os.write(array, offset, length); + } + + @Override + public void renegotiate() throws IOException { + socket.startHandshake(); + } + + @Override + public void close() throws IOException { + socket.close(); + } + } + + public static class TlsChannelWriter implements Writer { + + private final TlsChannel socket; + + public TlsChannelWriter(TlsChannel socket) { + this.socket = socket; + } + + @Override + public void write(byte[] array, int offset, int length) throws IOException { + ByteBuffer buffer = ByteBuffer.wrap(array, offset, length); + while (buffer.remaining() > 0) { + int c = socket.write(buffer); + Assertions.assertNotEquals(0, c, "blocking write cannot return 0"); + } + } + + @Override + public void renegotiate() throws IOException { + socket.renegotiate(); + } + + @Override + public void close() throws IOException { + socket.close(); + } + } +}