diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index 5d99f79a5..43bd31ffe 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -1,6 +1,11 @@ +import type { + ReadableStreamController, + WritableStreamDefaultController, +} from 'stream/web'; +import type { JSONValue } from '../types'; import type { TLSConfig } from '../network/types'; import type { IncomingMessage, ServerResponse } from 'http'; -import type tls from 'tls'; +import { WritableStream, ReadableStream } from 'stream/web'; import https from 'https'; import { startStop, status } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; @@ -25,6 +30,7 @@ class WebSocketServer extends EventTarget { * @param obj * @param obj.connectionCallback - * @param obj.tlsConfig - TLSConfig containing the private key and cert chain used for TLS. + * @param obj.basePath - Directory path used for storing temp cert files for starting the `uWebsocket` server. * @param obj.host - Listen address to bind to. * @param obj.port - Listen port to bind to. * @param obj.maxIdleTimeout - Timeout time for when the connection is cleaned up after no activity. @@ -33,11 +39,13 @@ class WebSocketServer extends EventTarget { * Default is 1,000 milliseconds. * @param obj.pingTimeoutTimeTime - Time before connection is cleaned up after no ping responses. * Default is 10,000 milliseconds. + * @param obj.maxReadableStreamBytes - The number of bytes the readable stream will buffer until pausing. * @param obj.logger */ static async createWebSocketServer({ connectionCallback, tlsConfig, + basePath, host, port, maxIdleTimeout = 120, @@ -47,6 +55,7 @@ class WebSocketServer extends EventTarget { }: { connectionCallback: ConnectionCallback; tlsConfig: TLSConfig; + basePath?: string; host?: string; port?: number; maxIdleTimeout?: number; @@ -64,6 +73,7 @@ class WebSocketServer extends EventTarget { await wsServer.start({ connectionCallback, tlsConfig, + basePath, host, port, }); @@ -83,6 +93,7 @@ class WebSocketServer extends EventTarget { /** * * @param logger + * @param maxReadableStreamBytes Max number of bytes stored in read buffer before error * @param maxIdleTimeout * @param pingIntervalTime * @param pingTimeoutTimeTime @@ -103,6 +114,7 @@ class WebSocketServer extends EventTarget { connectionCallback, }: { tlsConfig: TLSConfig; + basePath?: string; host?: string; port?: number; connectionCallback?: ConnectionCallback; @@ -121,6 +133,8 @@ class WebSocketServer extends EventTarget { cert: tlsConfig.certChainPem, }); this.webSocketServer = new ws.WebSocketServer({ + host: this._host, + port: this._port, server: this.server, }); @@ -131,10 +145,21 @@ class WebSocketServer extends EventTarget { this.server.on('error', this.errorHandler); this.server.on('request', this.requestHandler); + // This.server.any('/*', (res, _) => { + // // Reject normal requests with an upgrade code + // res + // .writeStatus('426') + // .writeHeader('connection', 'Upgrade') + // .writeHeader('upgrade', 'websocket') + // .end('426 Upgrade Required', true); + // }); + + // TODO: tell normal requests to upgrade. const listenProm = promise(); this.server.listen(port ?? 0, host, listenProm.resolveP); await listenProm.p; const address = this.server.address(); + // TODO: handle string if (address == null || typeof address === 'string') never(); this._port = address.port; this.logger.debug(`Listening on port ${this._port}`); @@ -208,15 +233,6 @@ class WebSocketServer extends EventTarget { return this._host; } - @startStop.ready(new webSocketErrors.ErrorWebSocketServerNotRunning()) - public setTlsConfig(tlsConfig: TLSConfig): void { - const tlsServer = this.server as tls.Server; - tlsServer.setSecureContext({ - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }); - } - /** * Handles the creation of the `ReadableWritablePair` and provides it to the * StreamPair handler. @@ -225,18 +241,18 @@ class WebSocketServer extends EventTarget { webSocket: ws.WebSocket, request: IncomingMessage, ) => { - const connection = request.connection; - const webSocketStream = new WebSocketStream( + const socket = request.connection; + const webSocketStream = new WebSocketStreamServerInternal( webSocket, this.pingIntervalTime, this.pingTimeoutTimeTime, { - localHost: connection.localAddress ?? '', - localPort: connection.localPort ?? 0, - remoteHost: connection.remoteAddress ?? '', - remotePort: connection.remotePort ?? 0, + localHost: socket.localAddress ?? '', + localPort: socket.localPort ?? 0, + remoteHost: socket.remoteAddress ?? '', + remotePort: socket.remotePort ?? '', }, - this.logger.getChild(WebSocketStream.name), + this.logger.getChild(WebSocketStreamServerInternal.name), ); // Adding socket to the active sockets map this.activeSockets.add(webSocketStream); @@ -290,4 +306,189 @@ class WebSocketServer extends EventTarget { }; } +class WebSocketStreamServerInternal extends WebSocketStream { + protected writableController: WritableStreamDefaultController | undefined; + protected readableController: + | ReadableStreamController + | undefined; + protected messageHandler: (data: ws.RawData, isBinary: boolean) => void; + + constructor( + protected webSocket: ws.WebSocket, + pingInterval: number, + pingTimeoutTime: number, + protected metadata: Record, + protected logger: Logger, + ) { + super(); + logger.info('WS opened'); + const writableLogger = logger.getChild('Writable'); + const readableLogger = logger.getChild('Readable'); + // Setting up the writable stream + this.writable = new WritableStream({ + start: (controller) => { + this.writableController = controller; + }, + write: async (chunk, controller) => { + const writeResultProm = promise(); + this.webSocket.send(chunk, (err) => { + if (err == null) writeResultProm.resolveP(); + else writeResultProm.rejectP(err); + }); + try { + await writeResultProm.p; + writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); + } catch (e) { + this.logger.error(e); + controller.error(new webSocketErrors.ErrorServerSendFailed()); + } + }, + close: async () => { + writableLogger.info('Closed, sending null message'); + if (!this._webSocketEnded) { + const endProm = promise(); + this.webSocket.send(Buffer.from([]), (err) => { + if (err == null) endProm.resolveP(); + else endProm.rejectP(err); + }); + await endProm.p; + } + this.signalWritableEnd(); + if (this._readableEnded && !this._webSocketEnded) { + writableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); + this.webSocket.close(); + } + }, + abort: (reason) => { + writableLogger.info('Aborted'); + if (this._readableEnded && !this._webSocketEnded) { + writableLogger.debug('Ending socket'); + this.signalWebSocketEnd(reason); + this.webSocket.close(4000, 'Aborting connection'); + } + }, + }, + { + highWaterMark: 1, + }); + // Setting up the readable stream + this.messageHandler = (data: ws.RawData, isBinary: boolean) => { + if (!isBinary) never(); + if (data instanceof Array) never(); + const messageBuffer = Buffer.from(data); + readableLogger.debug(`Received ${messageBuffer.toString()}`); + if (messageBuffer.byteLength === 0) { + readableLogger.debug('Null message received'); + this.webSocket.off('message', this.messageHandler); + if (!this._readableEnded) { + readableLogger.debug('Closing'); + this.signalReadableEnd(); + this.readableController!.close(); + if (this._writableEnded && !this._webSocketEnded) { + readableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); + this.webSocket.close(); + } + } + return; + } + this.readableController!.enqueue(messageBuffer); + if ( + this.readableController!.desiredSize != null && + this.readableController!.desiredSize < 0 + ) { + this.webSocket.pause(); + } + }; + this.readable = new ReadableStream( + { + start: (controller) => { + this.readableController = controller; + this.webSocket.on('message', this.messageHandler); + }, + pull: () => { + this.webSocket.resume(); + }, + cancel: (reason) => { + this.webSocket.off('message', this.messageHandler); + this.signalReadableEnd(reason); + if (this._writableEnded && !this._webSocketEnded) { + readableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); + this.webSocket.close(); + } + }, + }, + { + highWaterMark: 1, + }, + ); + + const pingTimer = setInterval(() => { + this.webSocket.ping(); + }, pingInterval); + const pingTimeoutTimeTimer = setTimeout(() => { + logger.debug('Ping timed out'); + this.webSocket.close(); + }, pingTimeoutTime); + const pongHandler = (data: Buffer) => { + logger.debug(`Received pong with (${data.toString()})`); + pingTimeoutTimeTimer.refresh(); + }; + this.webSocket.on('pong', pongHandler); + + const closeHandler = () => { + logger.debug('Closing'); + this.signalWebSocketEnd(); + // Cleaning up timers + logger.debug('Cleaning up timers'); + clearTimeout(pingTimer); + clearTimeout(pingTimeoutTimeTimer); + // Closing streams + logger.debug('Cleaning streams'); + this.webSocket.off('message', this.messageHandler); + const err = new webSocketErrors.ErrorServerConnectionEndedEarly(); + if (!this._readableEnded) { + readableLogger.debug('Closing'); + this.signalReadableEnd(err); + this.webSocket.off('message', this.messageHandler); + this.readableController?.error(err); + } + if (!this._writableEnded) { + writableLogger.debug('Closing'); + this.signalWritableEnd(err); + this.writableController?.error(err); + } + }; + this.webSocket.once('close', closeHandler); + } + + get meta(): Record { + return { + ...this.metadata, + }; + } + + cancel(reason?: any): void { + // Default error + const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); + // Close the streams with the given error, + if (!this._readableEnded) { + this.webSocket.off('message', this.messageHandler); + this.readableController?.error(err); + this.signalReadableEnd(err); + } + if (!this._writableEnded) { + this.writableController?.error(err); + this.signalWritableEnd(err); + } + // Then close the websocket + if (!this._webSocketEnded) { + this.webSocket.terminate(); + this.signalWebSocketEnd(err); + } + } +} + export default WebSocketServer; diff --git a/tests/websockets/WebSocket.test.ts b/tests/websockets/WebSocket.test.ts index 64815dcf9..b7fb00356 100644 --- a/tests/websockets/WebSocket.test.ts +++ b/tests/websockets/WebSocket.test.ts @@ -1,7 +1,6 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; import type { KeyPair } from '@/keys/types'; -import type { NodeId } from '@/ids/types'; import type http from 'http'; import fs from 'fs'; import path from 'path'; @@ -11,6 +10,7 @@ import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; import { status } from '@matrixai/async-init'; +import { KeyRing } from '@/keys/index'; import WebSocketServer from '@/websockets/WebSocketServer'; import WebSocketClient from '@/websockets/WebSocketClient'; import { promise } from '@/utils'; @@ -28,8 +28,8 @@ describe('WebSocket', () => { formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); - let keyPair: KeyPair; - let nodeId: NodeId; + let dataDir: string; + let keyRing: KeyRing; let tlsConfig: TLSConfig; const host = '127.0.0.2'; let webSocketServer: WebSocketServer; @@ -63,14 +63,19 @@ describe('WebSocket', () => { dataDir = await fs.promises.mkdtemp( path.join(os.tmpdir(), 'polykey-test-'), ); - keyPair = keysUtils.generateKeyPair(); - nodeId = keysUtils.publicKeyToNodeId(keyPair.publicKey); - tlsConfig = await testsUtils.createTLSConfig(keyPair); + const keysPath = path.join(dataDir, 'keys'); + keyRing = await KeyRing.createKeyRing({ + keysPath: keysPath, + password: 'password', + logger: logger.getChild('keyRing'), + }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { logger.info('AFTEREACH'); await webSocketServer?.stop(true); await webSocketClient?.destroy(true); + await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); }); @@ -84,6 +89,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -92,7 +98,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -109,42 +115,6 @@ describe('WebSocket', () => { expect((await reader.read()).done).toBeTrue(); logger.info('ending'); }); - test('can change TLS config', async () => { - const keyPairNew = keysUtils.generateKeyPair(); - const nodeIdNew = keysUtils.publicKeyToNodeId(keyPairNew.publicKey); - const tlsConfigNew = await testsUtils.createTLSConfig(keyPairNew); - - webSocketServer = await WebSocketServer.createWebSocketServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => logger.info('STREAM HANDLING ENDED')); - }, - tlsConfig, - host, - logger: logger.getChild('server'), - }); - logger.info(`Server started on port ${webSocketServer.getPort()}`); - webSocketClient = await WebSocketClient.createWebSocketClient({ - host, - port: webSocketServer.getPort(), - expectedNodeIds: [nodeId, nodeIdNew], - logger: logger.getChild('clientClient'), - }); - const websocket = await webSocketClient.startConnection(); - expect(websocket.meta.nodeId).toBe(nodesUtils.encodeNodeId(nodeId)); - websocket.cancel(); - - // Changing certs - webSocketServer.setTlsConfig(tlsConfigNew); - const websocket2 = await webSocketClient.startConnection(); - expect(websocket2.meta.nodeId).toBe(nodesUtils.encodeNodeId(nodeIdNew)); - websocket2.cancel(); - - logger.info('ending'); - }); test('makes a connection over IPv6', async () => { webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { @@ -154,6 +124,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host: '::1', logger: logger.getChild('server'), @@ -162,7 +133,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host: '::1', port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -188,6 +159,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -196,7 +168,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -218,6 +190,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -226,7 +199,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); @@ -263,6 +236,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -289,6 +263,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -314,6 +289,7 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -331,7 +307,7 @@ describe('WebSocket', () => { env: { PK_TEST_HOST: host, PK_TEST_PORT: `${webSocketServer.getPort()}`, - PK_TEST_NODE_ID: nodesUtils.encodeNodeId(nodeId), + PK_TEST_NODE_ID: nodesUtils.encodeNodeId(keyRing.getNodeId()), }, }, logger, @@ -393,7 +369,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: await startedProm.p, - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -436,6 +412,7 @@ describe('WebSocket', () => { } })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -444,7 +421,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -476,6 +453,7 @@ describe('WebSocket', () => { await writer.close(); })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -484,7 +462,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -514,6 +492,7 @@ describe('WebSocket', () => { await writer.close(); })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -522,7 +501,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -542,6 +521,7 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -550,7 +530,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -582,6 +562,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -611,6 +592,7 @@ describe('WebSocket', () => { logger.info('inside callback'); // Hang connection }, + basePath: dataDir, tlsConfig, host, pingTimeoutTimeTime: 100, @@ -620,7 +602,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); await webSocketClient.startConnection(); @@ -637,6 +619,7 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -645,7 +628,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -680,6 +663,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -702,7 +686,7 @@ describe('WebSocket', () => { }); test('authenticates with multiple certs in chain', async () => { const keyPairs: Array = [ - keyPair, + keyRing.keyPair, keysUtils.generateKeyPair(), keysUtils.generateKeyPair(), keysUtils.generateKeyPair(), @@ -718,6 +702,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -747,6 +732,7 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -755,7 +741,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId, alternativeNodeId], + expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], logger: logger.getChild('clientClient'), }); await expect(webSocketClient.startConnection()).toResolve(); @@ -768,7 +754,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: 12345, - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], connectionTimeoutTime: 0, logger: logger.getChild('clientClient'), }); @@ -786,6 +772,7 @@ describe('WebSocket', () => { logger.info('inside callback'); // Hang connection }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -794,7 +781,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], pingTimeoutTimeTime: 100, logger: logger.getChild('clientClient'), }); @@ -822,6 +809,7 @@ describe('WebSocket', () => { })().catch(() => {}), ]); }, + basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -830,7 +818,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [nodeId], + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const abortController = new AbortController();