Skip to content

Commit

Permalink
Migrate to java InteroperabilityUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
marianobarrios committed Apr 21, 2024
1 parent ae1bb49 commit 2c49608
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 55 deletions.
58 changes: 3 additions & 55 deletions src/test/scala/tlschannel/InteroperabilityTest.scala
Original file line number Diff line number Diff line change
@@ -1,35 +1,30 @@
package tlschannel

import scala.util.Random
import java.net.Socket
import java.nio.channels.ByteChannel
import javax.net.ssl.SSLSocket
import java.nio.ByteBuffer
import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals, assertNotEquals, assertTrue}
import org.junit.jupiter.api.{Test, TestInstance}
import org.junit.jupiter.api.TestInstance.Lifecycle
import tlschannel.helpers.TestUtil
import tlschannel.helpers.SslContextFactory
import tlschannel.helpers.SocketPairFactory
import tlschannel.util.InteroperabilityUtils._

@TestInstance(Lifecycle.PER_CLASS)
class InteroperabilityTest {

import InteroperabilityTest._

val sslContextFactory = new SslContextFactory
val factory = new SocketPairFactory(sslContextFactory.defaultContext, SslContextFactory.certificateCommonName)

def oldNio() = {
val (client, server) = factory.oldNio(None)
val clientPair = (new SSLSocketWriter(client), new SocketReader(client))
val serverPair = (new TlsSocketChannelWriter(server.tls), new ByteChannelReader(server.tls))
val serverPair = (new TlsChannelWriter(server.tls), new ByteChannelReader(server.tls))
(clientPair, serverPair)
}

def nioOld() = {
val (client, server) = factory.nioOld()
val clientPair = (new TlsSocketChannelWriter(client.tls), new ByteChannelReader(client.tls))
val clientPair = (new TlsChannelWriter(client.tls), new ByteChannelReader(client.tls))
val serverPair = (new SSLSocketWriter(server), new SocketReader(server))
(clientPair, serverPair)
}
Expand Down Expand Up @@ -165,50 +160,3 @@ class InteroperabilityTest {
}

}

object InteroperabilityTest {

trait Reader {
def read(array: Array[Byte], offset: Int, length: Int): Int
def close(): Unit
}

class SocketReader(socket: Socket) extends Reader {
private val is = socket.getInputStream
def read(array: Array[Byte], offset: Int, length: Int) = is.read(array, offset, length)
def close() = socket.close()
}

class ByteChannelReader(socket: ByteChannel) extends Reader {
def read(array: Array[Byte], offset: Int, length: Int) = socket.read(ByteBuffer.wrap(array, offset, length))
def close() = socket.close()
}

trait Writer {
def renegotiate(): Unit
def write(array: Array[Byte], offset: Int, length: Int): Unit
def close(): Unit
}

class SSLSocketWriter(socket: SSLSocket) extends Writer {
private val os = socket.getOutputStream
def write(array: Array[Byte], offset: Int, length: Int) = os.write(array, offset, length)
def renegotiate() = socket.startHandshake()
def close() = socket.close()
}

class TlsSocketChannelWriter(val socket: TlsChannel) extends Writer {

def write(array: Array[Byte], offset: Int, length: Int) = {
val buffer = ByteBuffer.wrap(array, offset, length)
while (buffer.remaining() > 0) {
val c = socket.write(buffer)
assertNotEquals(0, c, "blocking write cannot return 0")
}
}

def renegotiate(): Unit = socket.renegotiate()
def close() = socket.close()
}

}
122 changes: 122 additions & 0 deletions src/test/scala/tlschannel/util/InteroperabilityUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package tlschannel.util;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import javax.net.ssl.SSLSocket;
import org.junit.jupiter.api.Assertions;
import tlschannel.TlsChannel;

public class InteroperabilityUtils {

public interface Reader {
int read(byte[] array, int offset, int length) throws IOException;

void close() throws IOException;
}

public interface Writer {
void renegotiate() throws IOException;

void write(byte[] array, int offset, int length) throws IOException;

void close() throws IOException;
}

public static class SocketReader implements Reader {

private final Socket socket;
private final InputStream is;

public SocketReader(Socket socket) throws IOException {
this.socket = socket;
this.is = socket.getInputStream();
}

@Override
public int read(byte[] array, int offset, int length) throws IOException {
return is.read(array, offset, length);
}

@Override
public void close() throws IOException {
socket.close();
}
}

public static class ByteChannelReader implements Reader {

private final ByteChannel socket;

public ByteChannelReader(ByteChannel socket) {
this.socket = socket;
}

@Override
public int read(byte[] array, int offset, int length) throws IOException {
return socket.read(ByteBuffer.wrap(array, offset, length));
}

@Override
public void close() throws IOException {
socket.close();
}
}

public static class SSLSocketWriter implements Writer {

private final SSLSocket socket;
private final OutputStream os;

public SSLSocketWriter(SSLSocket socket) throws IOException {
this.socket = socket;
this.os = socket.getOutputStream();
}

@Override
public void write(byte[] array, int offset, int length) throws IOException {
os.write(array, offset, length);
}

@Override
public void renegotiate() throws IOException {
socket.startHandshake();
}

@Override
public void close() throws IOException {
socket.close();
}
}

public static class TlsChannelWriter implements Writer {

private final TlsChannel socket;

public TlsChannelWriter(TlsChannel socket) {
this.socket = socket;
}

@Override
public void write(byte[] array, int offset, int length) throws IOException {
ByteBuffer buffer = ByteBuffer.wrap(array, offset, length);
while (buffer.remaining() > 0) {
int c = socket.write(buffer);
Assertions.assertNotEquals(0, c, "blocking write cannot return 0");
}
}

@Override
public void renegotiate() throws IOException {
socket.renegotiate();
}

@Override
public void close() throws IOException {
socket.close();
}
}
}

0 comments on commit 2c49608

Please sign in to comment.