diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index a9df35d92..f5c011646 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -2,39 +2,21 @@ import type { ReadableStreamController, WritableStreamDefaultController, } from 'stream/web'; -import type { - HttpRequest, - HttpResponse, - us_socket_context_t, - WebSocket, -} from 'uWebSockets.js'; -import type { FileSystem, JSONValue, PromiseDeconstructed } from '../types'; +import type { JSONValue } from '../types'; import type { TLSConfig } from '../network/types'; +import type { IncomingMessage } from 'http'; import { WritableStream, ReadableStream } from 'stream/web'; -import path from 'path'; -import os from 'os'; +import https from 'https'; import { startStop } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; -import uWebsocket from 'uWebSockets.js'; +import * as ws from 'ws'; import WebSocketStream from './WebSocketStream'; import * as webSocketErrors from './errors'; import * as webSocketEvents from './events'; -import { promise } from '../utils'; +import { never, promise } from '../utils'; type ConnectionCallback = (streamPair: WebSocketStream) => void; -type Context = { - message: ( - ws: WebSocket, - message: ArrayBuffer, - isBinary: boolean, - ) => void; - drain: (ws: WebSocket) => void; - close: (ws: WebSocket, code: number, message: ArrayBuffer) => void; - pong: (ws: WebSocket, message: ArrayBuffer) => void; - logger: Logger; -}; - /** * Events: * - start @@ -57,7 +39,6 @@ class WebSocketServer extends EventTarget { * Default is 1,000 milliseconds. * @param obj.pingTimeoutTimeTime - Time before connection is cleaned up after no ping responses. * Default is 10,000 milliseconds. - * @param obj.fs - FileSystem interface used for creating files. * @param obj.maxReadableStreamBytes - The number of bytes the readable stream will buffer until pausing. * @param obj.logger */ @@ -70,8 +51,7 @@ class WebSocketServer extends EventTarget { maxIdleTimeout = 120, pingIntervalTime = 1_000, pingTimeoutTimeTime = 10_000, - fs = require('fs'), - maxReadableStreamBytes = 1_000_000_000, // About 1 GB + maxReadableStreamBytes = 1_000, // About 1 GB logger = new Logger(this.name), }: { connectionCallback: ConnectionCallback; @@ -82,14 +62,12 @@ class WebSocketServer extends EventTarget { maxIdleTimeout?: number; pingIntervalTime?: number; pingTimeoutTimeTime?: number; - fs?: FileSystem; maxReadableStreamBytes?: number; logger?: Logger; }) { logger.info(`Creating ${this.name}`); const wsServer = new this( logger, - fs, maxReadableStreamBytes, maxIdleTimeout, pingIntervalTime, @@ -106,20 +84,18 @@ class WebSocketServer extends EventTarget { return wsServer; } - protected server: uWebsocket.TemplatedApp; - protected listenSocket: uWebsocket.us_listen_socket; + protected server: https.Server; + protected webSocketServer: ws.WebSocketServer; protected _port: number; protected _host: string; protected connectionEventHandler: ( event: webSocketEvents.ConnectionEvent, ) => void; protected activeSockets: Set = new Set(); - protected connectionIndex: number = 0; /** * * @param logger - * @param fs * @param maxReadableStreamBytes Max number of bytes stored in read buffer before error * @param maxIdleTimeout * @param pingIntervalTime @@ -127,7 +103,6 @@ class WebSocketServer extends EventTarget { */ constructor( protected logger: Logger, - protected fs: FileSystem, protected maxReadableStreamBytes, protected maxIdleTimeout: number | undefined, protected pingIntervalTime: number, @@ -138,7 +113,6 @@ class WebSocketServer extends EventTarget { public async start({ tlsConfig, - basePath = os.tmpdir(), host, port = 0, connectionCallback, @@ -158,47 +132,30 @@ class WebSocketServer extends EventTarget { }; this.addEventListener('connection', this.connectionEventHandler); } - await this.setupServer(basePath, tlsConfig); - this.server.ws('/*', { - sendPingsAutomatically: true, - idleTimeout: this.maxIdleTimeout, - upgrade: this.upgrade, - open: this.open, - message: this.message, - close: this.close, - drain: this.drain, - pong: this.pong, - // Ping uses default behaviour. - // We don't use subscriptions. + this.server = https.createServer({ + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, }); - this.server.any('/*', (res, _) => { - // Reject normal requests with an upgrade code - res - .writeStatus('426') - .writeHeader('connection', 'Upgrade') - .writeHeader('upgrade', 'websocket') - .end('426 Upgrade Required', true); + this.webSocketServer = new ws.WebSocketServer({ + host: this._host, + port: this._port, + server: this.server, }); + + this.webSocketServer.on('connection', this.connectionHandler); + // This.webSocketServer.on('error', console.error); + // this.webSocketServer.on('close', this.closeHandler); + + // TODO: tell normal requests to upgrade. const listenProm = promise(); - const listenCallback = (listenSocket) => { - if (listenSocket) { - this.listenSocket = listenSocket; - listenProm.resolveP(); - } else { - listenProm.rejectP(new webSocketErrors.ErrorServerPortUnavailable()); - } - }; - if (host != null) { - // With custom host - this.server.listen(host, port ?? 0, listenCallback); - } else { - // With default host - this.server.listen(port, listenCallback); - } + this.server.listen(port ?? 0, host, listenProm.resolveP); await listenProm.p; - this._port = uWebsocket.us_socket_local_port(this.listenSocket); + const address = this.server.address(); + // TODO: handle string + if (address == null || typeof address === 'string') never(); + this._port = address.port; this.logger.debug(`Listening on port ${this._port}`); - this._host = host ?? '127.0.0.1'; + this._host = address.address ?? '127.0.0.1'; this.dispatchEvent( new webSocketEvents.StartEvent({ detail: { @@ -213,7 +170,7 @@ class WebSocketServer extends EventTarget { public async stop(force: boolean = false): Promise { this.logger.info(`Stopping ${this.constructor.name}`); // Close the server by closing the underlying socket - uWebsocket.us_listen_socket_close(this.listenSocket); + this.server.close(); // Shutting down active websockets if (force) { for (const webSocketStream of this.activeSockets) { @@ -242,68 +199,31 @@ class WebSocketServer extends EventTarget { return this._host; } - /** - * This creates the pem files and starts the server with them. It ensures that - * files are cleaned up to the best of its ability. - */ - protected async setupServer(basePath: string, tlsConfig: TLSConfig) { - const tmpDir = await this.fs.promises.mkdtemp( - path.join(basePath, 'polykey-'), - ); - // TODO: The key file needs to be in the encrypted format - const keyFile = path.join(tmpDir, 'keyFile.pem'); - const certFile = path.join(tmpDir, 'certFile.pem'); - await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); - await this.fs.promises.writeFile(certFile, tlsConfig.certChainPem); - try { - this.server = uWebsocket.SSLApp({ - key_file_name: keyFile, - cert_file_name: certFile, - }); - } finally { - await this.fs.promises.rm(keyFile); - await this.fs.promises.rm(certFile); - await this.fs.promises.rm(tmpDir, { recursive: true, force: true }); - } - } - - /** - * Applies default upgrade behaviour and creates a UserData object we can - * mutate for the Context - */ - protected upgrade = ( - res: HttpResponse, - req: HttpRequest, - context: us_socket_context_t, - ) => { - const logger = this.logger.getChild(`Connection ${this.connectionIndex}`); - res.upgrade>( - { - logger, - }, - req.getHeader('sec-websocket-key'), - req.getHeader('sec-websocket-protocol'), - req.getHeader('sec-websocket-extensions'), - context, - ); - this.connectionIndex += 1; - }; - /** * Handles the creation of the `ReadableWritablePair` and provides it to the * StreamPair handler. */ - protected open = (ws: WebSocket) => { + protected connectionHandler = ( + webSocket: ws.WebSocket, + request: IncomingMessage, + ) => { + const socket = request.connection; const webSocketStream = new WebSocketStreamServerInternal( - ws, + webSocket, this.maxReadableStreamBytes, this.pingIntervalTime, this.pingTimeoutTimeTime, - {}, // TODO: fill in connection metadata + { + localHost: socket.localAddress ?? '', + localPort: socket.localPort ?? 0, + remoteHost: socket.remoteAddress ?? '', + remotePort: socket.remotePort ?? '', + }, + this.logger.getChild(WebSocketStreamServerInternal.name), ); // Adding socket to the active sockets map this.activeSockets.add(webSocketStream); - webSocketStream.endedProm + void webSocketStream.endedProm // Ignore errors, we only care that it finished .catch(() => {}) .finally(() => { @@ -320,54 +240,24 @@ class WebSocketServer extends EventTarget { }), ); }; - - /** - * Routes incoming messages to each stream using the `Context` message - * callback. - */ - protected message = ( - ws: WebSocket, - message: ArrayBuffer, - isBinary: boolean, - ) => { - ws.getUserData().message(ws, message, isBinary); - }; - - protected drain = (ws: WebSocket) => { - ws.getUserData().drain(ws); - }; - - protected close = ( - ws: WebSocket, - code: number, - message: ArrayBuffer, - ) => { - ws.getUserData().close(ws, code, message); - }; - - protected pong = (ws: WebSocket, message: ArrayBuffer) => { - ws.getUserData().pong(ws, message); - }; } class WebSocketStreamServerInternal extends WebSocketStream { - protected backPressure: PromiseDeconstructed | null = null; - protected writeBackpressure: boolean = false; + protected readableBackpressure: boolean = false; protected writableController: WritableStreamDefaultController | undefined; protected readableController: | ReadableStreamController | undefined; constructor( - protected ws: WebSocket, + protected webSocket: ws.WebSocket, maxReadBufferBytes: number, pingInterval: number, pingTimeoutTime: number, protected metadata: Record, + protected logger: Logger, ) { super(); - const context = ws.getUserData(); - const logger = context.logger; logger.info('WS opened'); const writableLogger = logger.getChild('Writable'); const readableLogger = logger.getChild('Readable'); @@ -377,38 +267,35 @@ class WebSocketStreamServerInternal extends WebSocketStream { this.writableController = controller; }, write: async (chunk, controller) => { - await this.backPressure?.p; - const writeResult = ws.send(chunk, true); - switch (writeResult) { - default: - case 2: - // Write failure, emit error - writableLogger.error('Send error'); - controller.error(new webSocketErrors.ErrorServerSendFailed()); - break; - case 0: - writableLogger.info('Write backpressure'); - // Signal backpressure - this.backPressure = promise(); - this.writeBackpressure = true; - this.backPressure.p.finally(() => { - this.writeBackpressure = false; - }); - break; - case 1: - // Success - writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); - break; + const writeResultProm = promise(); + this.webSocket.send(chunk, (err) => { + if (err == null) writeResultProm.resolveP(); + else writeResultProm.rejectP(err); + }); + try { + await writeResultProm.p; + writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); + } catch (e) { + console.error(`Failed sending`, e); + this.logger.error(e); + controller.error(new webSocketErrors.ErrorServerSendFailed()); } }, - close: () => { + close: async () => { writableLogger.info('Closed, sending null message'); - if (!this._webSocketEnded) ws.send(Buffer.from([]), true); + if (!this._webSocketEnded) { + const endProm = promise(); + this.webSocket.send(Buffer.from([]), (err) => { + if (err == null) endProm.resolveP(); + else endProm.rejectP(err); + }); + await endProm.p; + } this.signalWritableEnd(); if (this._readableEnded && !this._webSocketEnded) { writableLogger.debug('Ending socket'); this.signalWebSocketEnd(); - ws.end(); + this.webSocket.close(); } }, abort: (reason) => { @@ -416,7 +303,7 @@ class WebSocketStreamServerInternal extends WebSocketStream { if (this._readableEnded && !this._webSocketEnded) { writableLogger.debug('Ending socket'); this.signalWebSocketEnd(reason); - ws.end(4000, 'Aborting connection'); + this.webSocket.close(4000, 'Aborting connection'); } }, }); @@ -425,11 +312,15 @@ class WebSocketStreamServerInternal extends WebSocketStream { { start: (controller) => { this.readableController = controller; - context.message = (ws, message, _) => { - const messageBuffer = Buffer.from(message); + const messageHandler = (data: ws.RawData, isBinary: boolean) => { + if (!isBinary) never(); + console.log(data.toString()); + if (data instanceof Array) never(); + const messageBuffer = Buffer.from(data); readableLogger.debug(`Received ${messageBuffer.toString()}`); - if (message.byteLength === 0) { + if (messageBuffer.byteLength === 0) { readableLogger.debug('Null message received'); + this.webSocket.off('message', messageHandler); if (!this._readableEnded) { readableLogger.debug('Closing'); this.signalReadableEnd(); @@ -437,50 +328,52 @@ class WebSocketStreamServerInternal extends WebSocketStream { if (this._writableEnded && !this._webSocketEnded) { readableLogger.debug('Ending socket'); this.signalWebSocketEnd(); - ws.end(); + this.webSocket.close(); } } return; } + console.log(this._readableEnded); controller.enqueue(messageBuffer); if (controller.desiredSize != null && controller.desiredSize < 0) { - readableLogger.error('Read stream buffer full'); - const err = new webSocketErrors.ErrorServerReadableBufferLimit(); - if (!this._webSocketEnded) { - this.signalWebSocketEnd(err); - ws.end(4000, 'Read stream buffer full'); - } - controller.error(err); + this.webSocket.pause(); + this.readableBackpressure = true; } }; + this.webSocket.on('message', messageHandler); + }, + pull: () => { + this.webSocket.resume(); + this.readableBackpressure = false; }, cancel: (reason) => { this.signalReadableEnd(reason); if (this._writableEnded && !this._webSocketEnded) { readableLogger.debug('Ending socket'); this.signalWebSocketEnd(); - ws.end(); + this.webSocket.close(); } }, }, { - highWaterMark: maxReadBufferBytes, - size: (chunk) => chunk?.byteLength ?? 0, + highWaterMark: 1, }, ); const pingTimer = setInterval(() => { - ws.ping(); + this.webSocket.ping(); }, pingInterval); const pingTimeoutTimeTimer = setTimeout(() => { logger.debug('Ping timed out'); - ws.end(); + this.webSocket.close(); }, pingTimeoutTime); - context.pong = () => { - logger.debug('Received pong'); + const pongHandler = (data: Buffer) => { + logger.debug(`Received pong with (${data.toString()})`); pingTimeoutTimeTimer.refresh(); }; - context.close = () => { + this.webSocket.on('pong', pongHandler); + + const closeHandler = () => { logger.debug('Closing'); this.signalWebSocketEnd(); // Cleaning up timers @@ -493,6 +386,7 @@ class WebSocketStreamServerInternal extends WebSocketStream { if (!this._readableEnded) { readableLogger.debug('Closing'); this.signalReadableEnd(err); + console.log('EROROROROROED'); this.readableController?.error(err); } if (!this._writableEnded) { @@ -501,10 +395,7 @@ class WebSocketStreamServerInternal extends WebSocketStream { this.writableController?.error(err); } }; - context.drain = () => { - logger.debug('Drained'); - this.backPressure?.resolveP(); - }; + this.webSocket.once('close', closeHandler); } get meta(): Record { @@ -518,6 +409,7 @@ class WebSocketStreamServerInternal extends WebSocketStream { const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); // Close the streams with the given error, if (!this._readableEnded) { + console.log('ERRORORORROED'); this.readableController?.error(err); this.signalReadableEnd(err); } @@ -527,7 +419,7 @@ class WebSocketStreamServerInternal extends WebSocketStream { } // Then close the websocket if (!this._webSocketEnded) { - this.ws.end(4000, 'Ending connection'); + this.webSocket.close(4000, 'Ending connection'); this.signalWebSocketEnd(err); } } diff --git a/tests/scratch.test.ts b/tests/scratch.test.ts index c7c21d965..7c748cab7 100644 --- a/tests/scratch.test.ts +++ b/tests/scratch.test.ts @@ -1,5 +1,13 @@ +import type { IncomingMessage } from 'http'; +import type { TLSSocket } from 'tls'; +import https from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import * as ws from 'ws'; +import { sleep } from 'ix/asynciterable/_sleep'; import NodeManager from '@/nodes/NodeManager'; +import * as keysUtils from '@/keys/utils'; +import { promise } from '@/utils'; +import * as testsUtils from './utils'; // This is a 'scratch paper' test file for quickly running tests in the CI describe('scratch', () => { @@ -12,3 +20,54 @@ describe('scratch', () => { test('Should avoid empty test suite', async () => { expect(1 + 1).toBe(2); }); + +test('ws server', async () => { + const keyPair = keysUtils.generateKeyPair(); + const tlsConfig = await testsUtils.createTLSConfig(keyPair); + const server = https.createServer({ + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, + }); + console.log(tlsConfig); + const webSocketServer = new ws.WebSocketServer({ + server, + }); + server.on('listening', (...args) => console.log('listening', args)); + + webSocketServer.on( + 'connection', + function connection(ws, request: IncomingMessage) { + console.log(request.connection.localAddress); + console.log(request.connection.localPort); + console.log(request.connection.remoteAddress); + console.log(request.connection.remotePort); + const tlsSocket = request.connection as TLSSocket; + console.log(tlsSocket.getCertificate()); + console.log(tlsSocket.getPeerCertificate()); + ws.on('error', console.error); + + ws.on('message', function message(data) { + console.log('received: %s', data); + }); + + ws.send('something'); + }, + ); + const listenProm = promise(); + server.listen(55555, '127.0.0.1', listenProm.resolveP); + await listenProm.p; + console.log(server.address()); + + // Try connecting! + const webSocket = new ws.WebSocket('wss://127.0.0.1:55555', { + rejectUnauthorized: false, + }); + webSocket.on('error', console.error); + + webSocket.on('open', function open() { + webSocket.send(Buffer.from('HELLO!')); + }); + + await sleep(2000); + server.close(); +}); diff --git a/tests/websockets/WebSocket.test.ts b/tests/websockets/WebSocket.test.ts index 3d56191d6..8ac3b7cda 100644 --- a/tests/websockets/WebSocket.test.ts +++ b/tests/websockets/WebSocket.test.ts @@ -2,7 +2,6 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; import type { KeyPair } from '@/keys/types'; import type http from 'http'; -import type WebSocketStream from '@/websockets/WebSocketStream'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -23,7 +22,7 @@ import * as testsUtils from '../utils'; // This file tests both the client and server together. They're too interlinked // to be separate. describe('WebSocket', () => { - const logger = new Logger('websocket test', LogLevel.WARN, [ + const logger = new Logger('websocket test', LogLevel.DEBUG, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), @@ -227,131 +226,6 @@ describe('WebSocket', () => { } }, ); - test('reverse backpressure', async () => { - const backpressure = promise(); - const resumeWriting = promise(); - let webSocketStream: WebSocketStream | null = null; - webSocketServer = await WebSocketServer.createWebSocketServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void Promise.allSettled([ - (async () => { - for await (const _ of streamPair.readable) { - // No touch, only consume - } - })(), - (async () => { - // Kidnap the context - // @ts-ignore: kidnap protected property - for (const websocket of webSocketServer.activeSockets.values()) { - webSocketStream = websocket; - } - if (webSocketStream == null) { - await streamPair.writable.close(); - return; - } - // Write until backPressured - const message = Buffer.alloc(128, 0xf0); - const writer = streamPair.writable.getWriter(); - // @ts-ignore: kidnap protected property - while (!webSocketStream.writeBackpressure) { - await writer.write(message); - } - logger.info('BACK PRESSURED'); - backpressure.resolveP(); - await resumeWriting.p; - for (let i = 0; i < 100; i++) { - await writer.write(message); - } - await writer.close(); - logger.info('WRITING ENDED'); - })(), - ]).catch((e) => logger.error(e.toString())); - }, - basePath: dataDir, - tlsConfig, - host, - logger: logger.getChild('server'), - }); - logger.info(`Server started on port ${webSocketServer.getPort()}`); - webSocketClient = await WebSocketClient.createWebSocketClient({ - host, - port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await webSocketClient.startConnection(); - await websocket.writable.close(); - - await backpressure.p; - // @ts-ignore: kidnap protected property - expect(webSocketStream.writeBackpressure).toBeTrue(); - resumeWriting.resolveP(); - // Consume all the back-pressured data - for await (const _ of websocket.readable) { - // No touch, only consume - } - // @ts-ignore: kidnap protected property - expect(webSocketStream.writeBackpressure).toBeFalse(); - logger.info('ending'); - }); - // Readable backpressure is not actually supported. We're dealing with it by - // using a buffer with a provided limit that can be very large. - test('exceeding readable buffer limit causes error', async () => { - const startReading = promise(); - const handlingProm = promise(); - webSocketServer = await WebSocketServer.createWebSocketServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - Promise.all([ - (async () => { - await startReading.p; - logger.info('Starting consumption'); - for await (const _ of streamPair.readable) { - // No touch, only consume - } - logger.info('Reads ended'); - })(), - (async () => { - await streamPair.writable.close(); - })(), - ]) - .catch(() => {}) - .finally(() => handlingProm.resolveP()); - }, - basePath: dataDir, - tlsConfig, - host, - // Setting a really low buffer limit - maxReadableStreamBytes: 1500, - logger: logger.getChild('server'), - }); - logger.info(`Server started on port ${webSocketServer.getPort()}`); - webSocketClient = await WebSocketClient.createWebSocketClient({ - host, - port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await webSocketClient.startConnection(); - const message = Buffer.alloc(1_000, 0xf0); - const writer = websocket.writable.getWriter(); - logger.info('Starting writes'); - await expect(async () => { - for (let i = 0; i < 100; i++) { - await writer.write(message); - } - }).rejects.toThrow(); - startReading.resolveP(); - logger.info('writes ended'); - await expect(async () => { - for await (const _ of websocket.readable) { - // No touch, only consume - } - }).rejects.toThrow(); - await handlingProm.p; - logger.info('ending'); - }); test('client ends connection abruptly', async () => { const streamPairProm = promise>(); @@ -463,7 +337,7 @@ describe('WebSocket', () => { }); // These describe blocks contains tests specific to either the client or server describe('WebSocketServer', () => { - testProp( + testProp.only( 'allows half closed writable closes first', [messagesArb, messagesArb], async (messages1, messages2) => {