From 1a3744efd9f2dc0a50dcd7c0a9ddd7f240530926 Mon Sep 17 00:00:00 2001 From: LatvianModder Date: Thu, 15 Aug 2024 11:43:27 +0300 Subject: [PATCH] Crude WebSocket implementation --- .../latvian/apps/tinyserver/HTTPServer.java | 121 +++++++++--------- .../latvian/apps/tinyserver/StatusCode.java | 4 + .../apps/tinyserver/http/HTTPRequest.java | 23 +++- .../http/response/HTTPResponseBuilder.java | 17 +++ .../apps/tinyserver/ws/EmptyWSHandler.java | 4 + .../dev/latvian/apps/tinyserver/ws/Frame.java | 121 ++++++++++++++++++ .../latvian/apps/tinyserver/ws/Opcode.java | 31 +++++ .../latvian/apps/tinyserver/ws/RXThread.java | 57 +++++++++ .../latvian/apps/tinyserver/ws/TXThread.java | 80 ++++++++++++ .../apps/tinyserver/ws/WSCloseStatus.java | 18 +++ .../apps/tinyserver/ws/WSEndpointHandler.java | 54 ++++++++ .../latvian/apps/tinyserver/ws/WSHandler.java | 19 ++- .../latvian/apps/tinyserver/ws/WSPayload.java | 4 - .../apps/tinyserver/ws/WSResponse.java | 17 +++ .../latvian/apps/tinyserver/ws/WSSession.java | 44 +++++-- .../apps/tinyserver/test/TestWSSession.java | 22 ++++ .../apps/tinyserver/test/TinyServerTest.java | 14 +- 17 files changed, 558 insertions(+), 92 deletions(-) create mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/Frame.java create mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/Opcode.java create mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/RXThread.java create mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/TXThread.java create mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/WSCloseStatus.java create mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/WSEndpointHandler.java delete mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/WSPayload.java create mode 100644 src/main/java/dev/latvian/apps/tinyserver/ws/WSResponse.java create mode 100644 src/test/java/dev/latvian/apps/tinyserver/test/TestWSSession.java diff --git a/src/main/java/dev/latvian/apps/tinyserver/HTTPServer.java b/src/main/java/dev/latvian/apps/tinyserver/HTTPServer.java index 7f90399..a570efc 100644 --- a/src/main/java/dev/latvian/apps/tinyserver/HTTPServer.java +++ b/src/main/java/dev/latvian/apps/tinyserver/HTTPServer.java @@ -5,19 +5,18 @@ import dev.latvian.apps.tinyserver.http.HTTPMethod; import dev.latvian.apps.tinyserver.http.HTTPPathHandler; import dev.latvian.apps.tinyserver.http.HTTPRequest; -import dev.latvian.apps.tinyserver.http.response.HTTPResponse; import dev.latvian.apps.tinyserver.http.response.HTTPResponseBuilder; import dev.latvian.apps.tinyserver.http.response.HTTPStatus; +import dev.latvian.apps.tinyserver.ws.WSEndpointHandler; import dev.latvian.apps.tinyserver.ws.WSHandler; import dev.latvian.apps.tinyserver.ws.WSSession; import dev.latvian.apps.tinyserver.ws.WSSessionFactory; import org.jetbrains.annotations.Nullable; +import java.io.BufferedInputStream; import java.io.BufferedOutputStream; -import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; -import java.io.InputStreamReader; import java.io.OutputStream; import java.net.InetAddress; import java.net.ServerSocket; @@ -30,7 +29,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; -import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Supplier; import java.util.stream.Collectors; @@ -44,6 +43,7 @@ public class HTTPServer implements Runnable, ServerRegi private int port = 8080; private int maxPortShift = 0; private boolean daemon = false; + private int bufferSize = 8192; public HTTPServer(Supplier requestFactory) { this.requestFactory = requestFactory; @@ -71,6 +71,10 @@ public void setDaemon(boolean daemon) { this.daemon = daemon; } + public void setBufferSize(int bufferSize) { + this.bufferSize = bufferSize; + } + public int start() { if (serverSocket != null) { throw new IllegalStateException("Server is already running"); @@ -123,21 +127,9 @@ public void http(HTTPMethod method, String path, HTTPHandler handler) { } } - private record WSEndpointHandler>(WSSessionFactory factory) implements WSHandler, HTTPHandler { - @Override - public Map sessions() { - return Map.of(); - } - - @Override - public HTTPResponse handle(REQ req) { - return HTTPStatus.NOT_IMPLEMENTED; - } - } - @Override public > WSHandler ws(String path, WSSessionFactory factory) { - var handler = new WSEndpointHandler<>(factory); + var handler = new WSEndpointHandler<>(factory, new ConcurrentHashMap<>(), daemon); get(path, handler); return handler; } @@ -153,16 +145,33 @@ public void run() { } } + private String readLine(InputStream in) throws IOException { + var sb = new StringBuilder(); + int b; + + while ((b = in.read()) != -1) { + if (b == '\n') { + break; + } + + if (b != '\r') { + sb.append((char) b); + } + } + + return sb.toString(); + } + private void handleClient(Socket socket) { InputStream in = null; OutputStream out = null; + WSSession upgradedToWebSocket = null; try { - in = socket.getInputStream(); - var reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); - var firstLineStr = reader.readLine(); + in = new BufferedInputStream(socket.getInputStream(), bufferSize); + var firstLineStr = readLine(in); - if (firstLineStr == null || !firstLineStr.toLowerCase().endsWith(" http/1.1")) { + if (!firstLineStr.toLowerCase().endsWith(" http/1.1")) { return; } @@ -212,9 +221,9 @@ private void handleClient(Socket socket) { var headers = new HashMap(); while (true) { - var line = reader.readLine(); + var line = readLine(in); - if (line == null || line.isBlank()) { + if (line.isBlank()) { break; } @@ -262,7 +271,7 @@ private void handleClient(Socket socket) { var builder = createBuilder(req, null); builder.setStatus(HTTPStatus.NO_CONTENT); builder.setHeader("Allow", allowed.stream().map(HTTPMethod::name).collect(Collectors.joining(","))); - out = new BufferedOutputStream(socket.getOutputStream()); + out = new BufferedOutputStream(socket.getOutputStream(), bufferSize); builder.write(out, writeBody); out.flush(); } else if (method == HTTPMethod.TRACE) { @@ -276,7 +285,7 @@ private void handleClient(Socket socket) { var handler = rootHandlers.get(method); if (handler != null) { - req.init(new String[0], CompiledPath.EMPTY, headers, query, in); + req.init(this, new String[0], CompiledPath.EMPTY, headers, query, in); builder = createBuilder(req, handler.handler()); } } else { @@ -287,14 +296,14 @@ private void handleClient(Socket socket) { var h = hl.staticHandlers().get(path); if (h != null) { - req.init(pathParts, h.path(), headers, query, in); + req.init(this, pathParts, h.path(), headers, query, in); builder = createBuilder(req, h.handler()); } else { for (var dynamicHandler : hl.dynamicHandlers()) { var matches = dynamicHandler.path().matches(pathParts); if (matches != null) { - req.init(matches, dynamicHandler.path(), headers, query, in); + req.init(this, matches, dynamicHandler.path(), headers, query, in); builder = createBuilder(req, dynamicHandler.handler()); break; } @@ -308,53 +317,43 @@ private void handleClient(Socket socket) { builder.setStatus(HTTPStatus.NOT_FOUND); } - System.out.println("Request: " + method.name() + " /" + path); - System.out.println("- Query:"); - - for (var e : query.entrySet()) { - System.out.println(" " + e.getKey() + ": " + e.getValue()); - } - - System.out.println("- Variables:"); - - for (var e : req.variables().entrySet()) { - System.out.println(" " + e.getKey() + ": " + e.getValue()); - } + out = new BufferedOutputStream(socket.getOutputStream(), bufferSize); + builder.write(out, writeBody); + out.flush(); - System.out.println("- Headers:"); + upgradedToWebSocket = (WSSession) builder.wsSession(); - for (var e : headers.entrySet()) { - System.out.println(" " + e.getKey() + ": " + e.getValue()); + if (upgradedToWebSocket != null) { + upgradedToWebSocket.start(socket, in, out); + upgradedToWebSocket.onOpen(req); } - - out = new BufferedOutputStream(socket.getOutputStream()); - builder.write(out, writeBody); - out.flush(); } } } catch (Exception ex) { ex.printStackTrace(); } - try { - if (in != null) { - in.close(); + if (upgradedToWebSocket == null) { + try { + if (in != null) { + in.close(); + } + } catch (Exception ignored) { } - } catch (Exception ignored) { - } - try { - if (out != null) { - out.close(); + try { + if (out != null) { + out.close(); + } + } catch (Exception ignored) { } - } catch (Exception ignored) { - } - try { - if (socket != null) { - socket.close(); + try { + if (socket != null) { + socket.close(); + } + } catch (Exception ignored) { } - } catch (Exception ignored) { } } @@ -369,7 +368,7 @@ public HTTPResponseBuilder createBuilder(REQ req, @Nullable HTTPHandler han if (handler != null) { try { - handler.handle(req).build(builder); + builder.setResponse(handler.handle(req)); } catch (Exception ex) { builder.setStatus(HTTPStatus.INTERNAL_ERROR); handlePayloadError(builder, ex); diff --git a/src/main/java/dev/latvian/apps/tinyserver/StatusCode.java b/src/main/java/dev/latvian/apps/tinyserver/StatusCode.java index c71258c..fe0d34c 100644 --- a/src/main/java/dev/latvian/apps/tinyserver/StatusCode.java +++ b/src/main/java/dev/latvian/apps/tinyserver/StatusCode.java @@ -1,4 +1,8 @@ package dev.latvian.apps.tinyserver; public record StatusCode(int code, String message) { + @Override + public String toString() { + return code + " " + message; + } } diff --git a/src/main/java/dev/latvian/apps/tinyserver/http/HTTPRequest.java b/src/main/java/dev/latvian/apps/tinyserver/http/HTTPRequest.java index 56b642c..a1c4ccc 100644 --- a/src/main/java/dev/latvian/apps/tinyserver/http/HTTPRequest.java +++ b/src/main/java/dev/latvian/apps/tinyserver/http/HTTPRequest.java @@ -1,21 +1,25 @@ package dev.latvian.apps.tinyserver.http; import dev.latvian.apps.tinyserver.CompiledPath; +import dev.latvian.apps.tinyserver.HTTPServer; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.HashMap; import java.util.Map; public class HTTPRequest { + private HTTPServer server; private String[] path = new String[0]; private Map variables = Map.of(); private Map query = Map.of(); private Map headers = Map.of(); private InputStream bodyStream = null; - public void init(String[] path, CompiledPath compiledPath, Map headers, Map query, InputStream bodyStream) { + public void init(HTTPServer server, String[] path, CompiledPath compiledPath, Map headers, Map query, InputStream bodyStream) { + this.server = server; this.path = path; if (compiledPath.variables() > 0) { @@ -35,6 +39,10 @@ public void init(String[] path, CompiledPath compiledPath, Map h this.bodyStream = bodyStream; } + public HTTPServer server() { + return server; + } + public Map variables() { return variables; } @@ -43,6 +51,10 @@ public Map query() { return query; } + public Map headers() { + return Collections.unmodifiableMap(headers); + } + public String header(String name) { return headers.getOrDefault(name.toLowerCase(), ""); } @@ -60,7 +72,14 @@ public InputStream bodyStream() { } public byte[] bodyBytes() throws IOException { - return bodyStream().readAllBytes(); + var h = header("content-length"); + + if (h.isEmpty()) { + return bodyStream().readAllBytes(); + } + + int len = Integer.parseInt(h); + return bodyStream().readNBytes(len); } public String body() throws IOException { diff --git a/src/main/java/dev/latvian/apps/tinyserver/http/response/HTTPResponseBuilder.java b/src/main/java/dev/latvian/apps/tinyserver/http/response/HTTPResponseBuilder.java index 5ca468a..46cc60d 100644 --- a/src/main/java/dev/latvian/apps/tinyserver/http/response/HTTPResponseBuilder.java +++ b/src/main/java/dev/latvian/apps/tinyserver/http/response/HTTPResponseBuilder.java @@ -1,6 +1,9 @@ package dev.latvian.apps.tinyserver.http.response; import dev.latvian.apps.tinyserver.content.ResponseContent; +import dev.latvian.apps.tinyserver.ws.WSResponse; +import dev.latvian.apps.tinyserver.ws.WSSession; +import org.jetbrains.annotations.Nullable; import java.io.OutputStream; import java.nio.charset.StandardCharsets; @@ -17,6 +20,7 @@ public class HTTPResponseBuilder { private HTTPStatus status = HTTPStatus.NO_CONTENT; private final Map headers = new HashMap<>(); private ResponseContent body = null; + private WSSession wsSession = null; public void setStatus(HTTPStatus status) { this.status = status; @@ -60,4 +64,17 @@ public void write(OutputStream out, boolean writeBody) throws Exception { body.write(out); } } + + public void setResponse(HTTPResponse response) throws Exception { + response.build(this); + + if (response instanceof WSResponse res) { + wsSession = res.session(); + } + } + + @Nullable + public WSSession wsSession() { + return wsSession; + } } diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/EmptyWSHandler.java b/src/main/java/dev/latvian/apps/tinyserver/ws/EmptyWSHandler.java index bdc2230..deb32cc 100644 --- a/src/main/java/dev/latvian/apps/tinyserver/ws/EmptyWSHandler.java +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/EmptyWSHandler.java @@ -14,6 +14,10 @@ public Map> sessions() { return Map.of(); } + @Override + public void broadcast(Frame frame) { + } + @Override public void broadcastText(String payload) { } diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/Frame.java b/src/main/java/dev/latvian/apps/tinyserver/ws/Frame.java new file mode 100644 index 0000000..8abf03f --- /dev/null +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/Frame.java @@ -0,0 +1,121 @@ +package dev.latvian.apps.tinyserver.ws; + +import org.jetbrains.annotations.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.Random; + +public record Frame( + Opcode opcode, + boolean mask, + boolean fin, + boolean rsv1, + boolean rsv2, + boolean rsv3, + byte[] payload +) { + public static Frame simple(Opcode opcode, boolean mask, byte[] payload) { + return new Frame(opcode, mask, true, false, false, false, payload); + } + + public static Frame text(String text) { + return simple(Opcode.TEXT, false, text.getBytes(StandardCharsets.UTF_8)); + } + + public static Frame binary(byte[] bytes) { + return simple(Opcode.BINARY, false, bytes); + } + + public static Frame read(InputStream stream) throws IOException { + int b1 = stream.read(); + var opcode = Opcode.get(b1 & 0x0F); + boolean fin = (b1 & 0x80) != 0; + boolean rsv1 = (b1 & 0x40) != 0; + boolean rsv2 = (b1 & 0x20) != 0; + boolean rsv3 = (b1 & 0x10) != 0; + + int b2 = stream.read(); + boolean mask = (b2 & -128) != 0; + int payloadlength = (byte) (b2 & ~(byte) 128); + + if (payloadlength == 126) { + var sizebytes = new byte[3]; + sizebytes[1] = (byte) stream.read(); + sizebytes[2] = (byte) stream.read(); + payloadlength = new BigInteger(sizebytes).intValue(); + } else if (payloadlength == 127) { + byte[] bytes = new byte[8]; + stream.read(bytes); + payloadlength = (int) new BigInteger(bytes).longValue(); + } + + var payload = new byte[payloadlength]; + + if (mask) { + var maskKey = new byte[4]; + stream.read(maskKey); + + for (int i = 0; i < payloadlength; i++) { + payload[i] = (byte) (stream.read() ^ maskKey[i % 4]); + } + } else { + stream.read(payload); + } + + return new Frame(opcode, mask, fin, rsv1, rsv2, rsv3, payload); + } + + public void write(Random random, OutputStream stream) throws IOException { + stream.write((fin ? 0x80 : 0) + | (rsv1 ? 0x40 : 0) + | (rsv2 ? 0x20 : 0) + | (rsv3 ? 0x10 : 0) + | opcode.opcode + ); + + if (payload.length < 126) { + stream.write((mask ? 0x80 : 0) | payload.length); + } else if (payload.length < 65536) { + stream.write((mask ? 0x80 : 0) | 126); + stream.write(payload.length >> 8); + stream.write(payload.length); + } else { + stream.write((mask ? 0x80 : 0) | 127); + stream.write(0); + stream.write(0); + stream.write(0); + stream.write(0); + stream.write(payload.length >> 24); + stream.write(payload.length >> 16); + stream.write(payload.length >> 8); + stream.write(payload.length); + } + + if (mask) { + var maskKey = new byte[4]; + random.nextBytes(maskKey); + stream.write(maskKey); + + for (int i = 0; i < payload.length; i++) { + stream.write(payload[i] ^ maskKey[i % 4]); + } + } else { + stream.write(payload); + } + } + + public Frame appendTo(@Nullable Frame previous) { + if (previous != null) { + byte[] newPayload = new byte[previous.payload.length + payload.length]; + System.arraycopy(previous.payload, 0, newPayload, 0, previous.payload.length); + System.arraycopy(payload, 0, newPayload, previous.payload.length, payload.length); + return new Frame(previous.opcode, previous.mask, fin, previous.rsv1, previous.rsv2, previous.rsv3, newPayload); + } + + return this; + } +} diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/Opcode.java b/src/main/java/dev/latvian/apps/tinyserver/ws/Opcode.java new file mode 100644 index 0000000..9de6adf --- /dev/null +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/Opcode.java @@ -0,0 +1,31 @@ +package dev.latvian.apps.tinyserver.ws; + +public enum Opcode { + CONTINUOUS(0), + TEXT(1), + BINARY(2), + + CLOSING(8), + PING(9), + PONG(10), + + ; + + public static Opcode get(int opcode) { + return switch (opcode) { + case 0 -> CONTINUOUS; + case 1 -> TEXT; + case 2 -> BINARY; + case 8 -> CLOSING; + case 9 -> PING; + case 10 -> PONG; + default -> throw new IllegalArgumentException("Invalid opcode: " + opcode); + }; + } + + public final byte opcode; + + Opcode(int opcode) { + this.opcode = (byte) opcode; + } +} \ No newline at end of file diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/RXThread.java b/src/main/java/dev/latvian/apps/tinyserver/ws/RXThread.java new file mode 100644 index 0000000..74c832a --- /dev/null +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/RXThread.java @@ -0,0 +1,57 @@ +package dev.latvian.apps.tinyserver.ws; + +import dev.latvian.apps.tinyserver.StatusCode; + +import java.nio.charset.StandardCharsets; +import java.util.concurrent.locks.LockSupport; + +class RXThread extends Thread { + private final WSSession session; + private Frame lastFrame; + + public RXThread(WSSession session) { + super("WSSession-" + session.id + "-RX"); + this.session = session; + } + + @Override + public void run() { + while (session.txThread.closeReason == null) { + try { + var frame = Frame.read(session.txThread.in); + var payload = frame.payload(); + + switch (frame.opcode()) { + case CONTINUOUS, TEXT, BINARY -> { + lastFrame = frame.appendTo(lastFrame); + + if (frame.fin()) { + switch (lastFrame.opcode()) { + case TEXT -> session.onTextMessage(new String(lastFrame.payload(), StandardCharsets.UTF_8)); + case BINARY -> session.onBinaryMessage(lastFrame.payload()); + } + + lastFrame = null; + } + } + case PING -> session.send(new Frame(Opcode.PONG, frame.mask(), frame.fin(), frame.rsv1(), frame.rsv2(), frame.rsv3(), payload)); + case CLOSING -> { + if (payload.length > 0) { + var code = (payload[0] << 8) | payload[1]; + session.txThread.closeReason = new StatusCode(code, new String(payload, 2, payload.length - 2, StandardCharsets.UTF_8)); + } else { + session.txThread.closeReason = WSCloseStatus.CLOSED.statusCode; + } + + session.txThread.remoteClosed = true; + session.send(Frame.simple(Opcode.CLOSING, false, payload)); + session.rxThread = null; + LockSupport.unpark(session.txThread); + } + } + } catch (Exception ex) { + session.onError(ex); + } + } + } +} diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/TXThread.java b/src/main/java/dev/latvian/apps/tinyserver/ws/TXThread.java new file mode 100644 index 0000000..bb0f8e2 --- /dev/null +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/TXThread.java @@ -0,0 +1,80 @@ +package dev.latvian.apps.tinyserver.ws; + +import dev.latvian.apps.tinyserver.StatusCode; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.util.Deque; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.locks.LockSupport; + +class TXThread extends Thread { + private final WSSession session; + private final Socket socket; + final InputStream in; + private final OutputStream out; + StatusCode closeReason; + boolean remoteClosed; + Deque queue; + private final Random random; + + public TXThread(WSSession session, Socket socket, InputStream in, OutputStream out) { + super("WSSession-" + session.id + "-TX"); + this.session = session; + this.socket = socket; + this.in = in; + this.out = out; + this.queue = new ConcurrentLinkedDeque<>(); + this.random = new Random(); + } + + @Override + public void run() { + while (session.txThread == this) { + var p = queue.poll(); + + if (p != null) { + try { + p.write(random, out); + } catch (IOException e) { + session.onError(e); + break; + } + } else { + try { + out.flush(); + } catch (IOException ignored) { + break; + } + + if (closeReason != null) { + break; + } else { + LockSupport.park(); + } + } + } + + session.sessionMap.remove(session.id); + session.rxThread = null; + session.onClose(closeReason, remoteClosed); + + try { + in.close(); + } catch (Exception ignore) { + } + + try { + out.close(); + } catch (Exception ignore) { + } + + try { + socket.close(); + } catch (Exception ignore) { + } + } +} diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/WSCloseStatus.java b/src/main/java/dev/latvian/apps/tinyserver/ws/WSCloseStatus.java new file mode 100644 index 0000000..8cdd134 --- /dev/null +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/WSCloseStatus.java @@ -0,0 +1,18 @@ +package dev.latvian.apps.tinyserver.ws; + +import dev.latvian.apps.tinyserver.StatusCode; + +public enum WSCloseStatus { + CLOSED(1000, "Closed"), + GOING_AWAY(1001, "Going Away"), + PROTOCOL_ERROR(1002, "Protocol Error"), + UNSUPPORTED_DATA(1003, "Unsupported Data"), + + ; + + public final StatusCode statusCode; + + WSCloseStatus(int code, String reason) { + this.statusCode = new StatusCode(code, reason); + } +} diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/WSEndpointHandler.java b/src/main/java/dev/latvian/apps/tinyserver/ws/WSEndpointHandler.java new file mode 100644 index 0000000..3b16108 --- /dev/null +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/WSEndpointHandler.java @@ -0,0 +1,54 @@ +package dev.latvian.apps.tinyserver.ws; + +import dev.latvian.apps.tinyserver.http.HTTPHandler; +import dev.latvian.apps.tinyserver.http.HTTPRequest; +import dev.latvian.apps.tinyserver.http.response.HTTPResponse; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.Base64; +import java.util.Map; +import java.util.UUID; + +public record WSEndpointHandler>(WSSessionFactory factory, Map sessions, boolean daemon) implements WSHandler, HTTPHandler { + private static final byte[] WEB_SOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(StandardCharsets.UTF_8); + + @Override + public HTTPResponse handle(REQ req) throws Exception { + var session = factory.create(); + var uuidBase64 = req.header("sec-websocket-key").getBytes(StandardCharsets.UTF_8); + session.id = UUID.nameUUIDFromBytes(Base64.getDecoder().decode(uuidBase64)); + + /* + System.out.println("Request: " + String.join("/", req.path())); + System.out.println("- Query:"); + + for (var e : req.query().entrySet()) { + System.out.println(" " + e.getKey() + ": " + e.getValue()); + } + + System.out.println("- Variables:"); + + for (var e : req.variables().entrySet()) { + System.out.println(" " + e.getKey() + ": " + e.getValue()); + } + + System.out.println("- Headers:"); + + for (var e : req.headers().entrySet()) { + System.out.println(" " + e.getKey() + ": " + e.getValue()); + } + + System.out.println("UUID: " + session.id); + */ + + var digest = MessageDigest.getInstance("SHA-1"); + digest.update(uuidBase64); + digest.update(WEB_SOCKET_GUID); + byte[] sha1 = digest.digest(); + + session.sessionMap = (Map) sessions; + sessions.put(session.id, session); + return new WSResponse(session, sha1); + } +} \ No newline at end of file diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/WSHandler.java b/src/main/java/dev/latvian/apps/tinyserver/ws/WSHandler.java index b111b24..68e2db9 100644 --- a/src/main/java/dev/latvian/apps/tinyserver/ws/WSHandler.java +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/WSHandler.java @@ -2,7 +2,6 @@ import dev.latvian.apps.tinyserver.http.HTTPRequest; -import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.UUID; import java.util.function.Supplier; @@ -14,11 +13,21 @@ static > WSHandler Map sessions(); + default void broadcast(Frame frame) { + var s = sessions().values(); + + if (!s.isEmpty()) { + for (var session : s) { + session.send(frame); + } + } + } + default void broadcastText(String payload) { var s = sessions().values(); if (!s.isEmpty()) { - var p = new WSPayload(true, payload.getBytes(StandardCharsets.UTF_8)); + var p = Frame.text(payload); for (var session : s) { session.send(p); @@ -30,7 +39,7 @@ default void broadcastText(Supplier payload) { var s = sessions().values(); if (!s.isEmpty()) { - var p = new WSPayload(true, payload.get().getBytes(StandardCharsets.UTF_8)); + var p = Frame.text(payload.get()); for (var session : s) { session.send(p); @@ -42,7 +51,7 @@ default void broadcastBinary(byte[] payload) { var s = sessions().values(); if (!s.isEmpty()) { - var p = new WSPayload(false, payload); + var p = Frame.binary(payload); for (var session : s) { session.send(p); @@ -54,7 +63,7 @@ default void broadcastBinary(Supplier payload) { var s = sessions().values(); if (!s.isEmpty()) { - var p = new WSPayload(false, payload.get()); + var p = Frame.binary(payload.get()); for (var session : s) { session.send(p); diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/WSPayload.java b/src/main/java/dev/latvian/apps/tinyserver/ws/WSPayload.java deleted file mode 100644 index 36707a3..0000000 --- a/src/main/java/dev/latvian/apps/tinyserver/ws/WSPayload.java +++ /dev/null @@ -1,4 +0,0 @@ -package dev.latvian.apps.tinyserver.ws; - -public record WSPayload(boolean text, byte[] bytes) { -} diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/WSResponse.java b/src/main/java/dev/latvian/apps/tinyserver/ws/WSResponse.java new file mode 100644 index 0000000..2afc5ee --- /dev/null +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/WSResponse.java @@ -0,0 +1,17 @@ +package dev.latvian.apps.tinyserver.ws; + +import dev.latvian.apps.tinyserver.http.response.HTTPResponse; +import dev.latvian.apps.tinyserver.http.response.HTTPResponseBuilder; +import dev.latvian.apps.tinyserver.http.response.HTTPStatus; + +import java.util.Base64; + +public record WSResponse(WSSession session, byte[] accept) implements HTTPResponse { + @Override + public void build(HTTPResponseBuilder payload) { + payload.setStatus(HTTPStatus.SWITCHING_PROTOCOLS); + payload.setHeader("Upgrade", "websocket"); + payload.setHeader("Connection", "Upgrade"); + payload.setHeader("Sec-WebSocket-Accept", Base64.getEncoder().encodeToString(accept)); + } +} diff --git a/src/main/java/dev/latvian/apps/tinyserver/ws/WSSession.java b/src/main/java/dev/latvian/apps/tinyserver/ws/WSSession.java index 6412257..ec0008b 100644 --- a/src/main/java/dev/latvian/apps/tinyserver/ws/WSSession.java +++ b/src/main/java/dev/latvian/apps/tinyserver/ws/WSSession.java @@ -3,28 +3,45 @@ import dev.latvian.apps.tinyserver.StatusCode; import dev.latvian.apps.tinyserver.http.HTTPRequest; -import java.nio.charset.StandardCharsets; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.util.Map; import java.util.UUID; +import java.util.concurrent.locks.LockSupport; public class WSSession { - WSHandler handler; + Map> sessionMap; UUID id; - StatusCode closeReason; + TXThread txThread; + RXThread rxThread; - public UUID id() { + public final void start(Socket socket, InputStream in, OutputStream out) { + this.txThread = new TXThread(this, socket, in, out); + this.txThread.setDaemon(true); + + this.rxThread = new RXThread(this); + this.rxThread.setDaemon(true); + + this.txThread.start(); + this.rxThread.start(); + } + + public final UUID id() { return id; } - public void send(WSPayload payload) { - // FIXME + public final void send(Frame frame) { + txThread.queue.add(frame); + LockSupport.unpark(txThread); } - public void sendText(String payload) { - send(new WSPayload(true, payload.getBytes(StandardCharsets.UTF_8))); + public final void sendText(String payload) { + send(Frame.text(payload)); } - public void sendBinary(byte[] payload) { - send(new WSPayload(false, payload)); + public final void sendBinary(byte[] payload) { + send(Frame.binary(payload)); } public void onOpen(REQ req) { @@ -34,7 +51,6 @@ public void onClose(StatusCode reason, boolean remote) { } public void onError(Throwable error) { - error.printStackTrace(); } public void onTextMessage(String message) { @@ -43,7 +59,9 @@ public void onTextMessage(String message) { public void onBinaryMessage(byte[] message) { } - public void close(String reason) { - closeReason = new StatusCode(1001, reason); // FIXME + public final void close(WSCloseStatus status, String reason) { + txThread.remoteClosed = false; + txThread.closeReason = new StatusCode(status.statusCode.code(), reason); + LockSupport.unpark(txThread); } } diff --git a/src/test/java/dev/latvian/apps/tinyserver/test/TestWSSession.java b/src/test/java/dev/latvian/apps/tinyserver/test/TestWSSession.java new file mode 100644 index 0000000..28e63ce --- /dev/null +++ b/src/test/java/dev/latvian/apps/tinyserver/test/TestWSSession.java @@ -0,0 +1,22 @@ +package dev.latvian.apps.tinyserver.test; + +import dev.latvian.apps.tinyserver.StatusCode; +import dev.latvian.apps.tinyserver.http.HTTPRequest; +import dev.latvian.apps.tinyserver.ws.WSSession; + +public class TestWSSession extends WSSession { + @Override + public void onOpen(HTTPRequest req) { + sendText("Hello from " + id() + "! " + req.variables() + ", " + req.headers()); + } + + @Override + public void onClose(StatusCode reason, boolean remote) { + System.out.println("WS " + id() + " Closed: " + reason + ", remote: " + remote); + } + + @Override + public void onTextMessage(String message) { + System.out.println("WS: " + message); + } +} diff --git a/src/test/java/dev/latvian/apps/tinyserver/test/TinyServerTest.java b/src/test/java/dev/latvian/apps/tinyserver/test/TinyServerTest.java index 1bebb0a..73ac261 100644 --- a/src/test/java/dev/latvian/apps/tinyserver/test/TinyServerTest.java +++ b/src/test/java/dev/latvian/apps/tinyserver/test/TinyServerTest.java @@ -12,7 +12,7 @@ public class TinyServerTest { public static HTTPServer server; public static WSHandler> wsHandler; - public static void main(String[] args) throws Exception { + public static void main(String[] args) { server = new HTTPServer<>(HTTPRequest::new); server.setServerName("TinyServer Test"); server.setAddress("127.0.0.1"); @@ -27,7 +27,7 @@ public static void main(String[] args) throws Exception { server.get("/redirect", TinyServerTest::redirect); server.post("/console", TinyServerTest::console); server.get("/stop", TinyServerTest::stop); - wsHandler = server.ws("/console/{console-type}"); + wsHandler = server.ws("/console/{console-type}", TestWSSession::new); System.out.println("Started server at https://localhost:" + server.start()); } @@ -41,22 +41,22 @@ private static HTTPResponse test(HTTPRequest req) { } private static HTTPResponse variable(HTTPRequest req) { - return HTTPResponse.ok().text("Test: " + req.variables().get("test")); + return HTTPResponse.ok().text("Test: " + req.variables().get("test")).header("X-ABC", "Def"); } private static HTTPResponse varpath(HTTPRequest req) { return HTTPResponse.ok().text("Test: " + req.variables().get("test")); } - private static HTTPResponse redirect(HTTPRequest req) { - return HTTPResponse.redirect("/"); - } - private static HTTPResponse console(HTTPRequest req) throws IOException { wsHandler.broadcastText(req.body()); return HTTPResponse.noContent(); } + private static HTTPResponse redirect(HTTPRequest req) { + return HTTPResponse.redirect("/"); + } + private static HTTPResponse stop(HTTPRequest req) { server.stop(); return HTTPResponse.noContent();