Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to java InteroperabilityUtils #212

Merged
merged 1 commit into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 3 additions & 55 deletions src/test/scala/tlschannel/InteroperabilityTest.scala
Original file line number Diff line number Diff line change
@@ -1,35 +1,30 @@
package tlschannel

import scala.util.Random
import java.net.Socket
import java.nio.channels.ByteChannel
import javax.net.ssl.SSLSocket
import java.nio.ByteBuffer
import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals, assertNotEquals, assertTrue}
import org.junit.jupiter.api.{Test, TestInstance}
import org.junit.jupiter.api.TestInstance.Lifecycle
import tlschannel.helpers.TestUtil
import tlschannel.helpers.SslContextFactory
import tlschannel.helpers.SocketPairFactory
import tlschannel.util.InteroperabilityUtils._

@TestInstance(Lifecycle.PER_CLASS)
class InteroperabilityTest {

import InteroperabilityTest._

val sslContextFactory = new SslContextFactory
val factory = new SocketPairFactory(sslContextFactory.defaultContext, SslContextFactory.certificateCommonName)

def oldNio() = {
val (client, server) = factory.oldNio(None)
val clientPair = (new SSLSocketWriter(client), new SocketReader(client))
val serverPair = (new TlsSocketChannelWriter(server.tls), new ByteChannelReader(server.tls))
val serverPair = (new TlsChannelWriter(server.tls), new ByteChannelReader(server.tls))
(clientPair, serverPair)
}

def nioOld() = {
val (client, server) = factory.nioOld()
val clientPair = (new TlsSocketChannelWriter(client.tls), new ByteChannelReader(client.tls))
val clientPair = (new TlsChannelWriter(client.tls), new ByteChannelReader(client.tls))
val serverPair = (new SSLSocketWriter(server), new SocketReader(server))
(clientPair, serverPair)
}
Expand Down Expand Up @@ -165,50 +160,3 @@ class InteroperabilityTest {
}

}

object InteroperabilityTest {

trait Reader {
def read(array: Array[Byte], offset: Int, length: Int): Int
def close(): Unit
}

class SocketReader(socket: Socket) extends Reader {
private val is = socket.getInputStream
def read(array: Array[Byte], offset: Int, length: Int) = is.read(array, offset, length)
def close() = socket.close()
}

class ByteChannelReader(socket: ByteChannel) extends Reader {
def read(array: Array[Byte], offset: Int, length: Int) = socket.read(ByteBuffer.wrap(array, offset, length))
def close() = socket.close()
}

trait Writer {
def renegotiate(): Unit
def write(array: Array[Byte], offset: Int, length: Int): Unit
def close(): Unit
}

class SSLSocketWriter(socket: SSLSocket) extends Writer {
private val os = socket.getOutputStream
def write(array: Array[Byte], offset: Int, length: Int) = os.write(array, offset, length)
def renegotiate() = socket.startHandshake()
def close() = socket.close()
}

class TlsSocketChannelWriter(val socket: TlsChannel) extends Writer {

def write(array: Array[Byte], offset: Int, length: Int) = {
val buffer = ByteBuffer.wrap(array, offset, length)
while (buffer.remaining() > 0) {
val c = socket.write(buffer)
assertNotEquals(0, c, "blocking write cannot return 0")
}
}

def renegotiate(): Unit = socket.renegotiate()
def close() = socket.close()
}

}
122 changes: 122 additions & 0 deletions src/test/scala/tlschannel/util/InteroperabilityUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package tlschannel.util;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import javax.net.ssl.SSLSocket;
import org.junit.jupiter.api.Assertions;
import tlschannel.TlsChannel;

public class InteroperabilityUtils {

public interface Reader {
int read(byte[] array, int offset, int length) throws IOException;

void close() throws IOException;
}

public interface Writer {
void renegotiate() throws IOException;

void write(byte[] array, int offset, int length) throws IOException;

void close() throws IOException;
}

public static class SocketReader implements Reader {

private final Socket socket;
private final InputStream is;

public SocketReader(Socket socket) throws IOException {
this.socket = socket;
this.is = socket.getInputStream();
}

@Override
public int read(byte[] array, int offset, int length) throws IOException {
return is.read(array, offset, length);
}

@Override
public void close() throws IOException {
socket.close();
}
}

public static class ByteChannelReader implements Reader {

private final ByteChannel socket;

public ByteChannelReader(ByteChannel socket) {
this.socket = socket;
}

@Override
public int read(byte[] array, int offset, int length) throws IOException {
return socket.read(ByteBuffer.wrap(array, offset, length));
}

@Override
public void close() throws IOException {
socket.close();
}
}

public static class SSLSocketWriter implements Writer {

private final SSLSocket socket;
private final OutputStream os;

public SSLSocketWriter(SSLSocket socket) throws IOException {
this.socket = socket;
this.os = socket.getOutputStream();
}

@Override
public void write(byte[] array, int offset, int length) throws IOException {
os.write(array, offset, length);
}

@Override
public void renegotiate() throws IOException {
socket.startHandshake();
}

@Override
public void close() throws IOException {
socket.close();
}
}

public static class TlsChannelWriter implements Writer {

private final TlsChannel socket;

public TlsChannelWriter(TlsChannel socket) {
this.socket = socket;
}

@Override
public void write(byte[] array, int offset, int length) throws IOException {
ByteBuffer buffer = ByteBuffer.wrap(array, offset, length);
while (buffer.remaining() > 0) {
int c = socket.write(buffer);
Assertions.assertNotEquals(0, c, "blocking write cannot return 0");
}
}

@Override
public void renegotiate() throws IOException {
socket.renegotiate();
}

@Override
public void close() throws IOException {
socket.close();
}
}
}