From eac61530e3f424305a23e31a047ab0075c29a8a7 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 14 Feb 2023 19:40:23 +1100 Subject: [PATCH 01/23] feat: creating `ClientServer` implementation - Supports SSL for secure communication. - Created tests [ci skip] --- package-lock.json | 9 + package.json | 1 + src/clientRPC/ClientServer.ts | 276 +++++++++++++++++++++++++ src/clientRPC/utils.ts | 9 +- tests/clientRPC/ClientServer.test.ts | 293 +++++++++++++++++++++++++++ tests/clientRPC/websocket.test.ts | 56 ++++- 6 files changed, 639 insertions(+), 5 deletions(-) create mode 100644 src/clientRPC/ClientServer.ts create mode 100644 tests/clientRPC/ClientServer.test.ts diff --git a/package-lock.json b/package-lock.json index c130defb4..09d004c19 100644 --- a/package-lock.json +++ b/package-lock.json @@ -52,6 +52,7 @@ "tslib": "^2.4.0", "tsyringe": "^4.7.0", "utp-native": "^2.5.3", + "uWebSockets.js": "github:uNetworking/uWebSockets.js#v20.19.0", "ws": "^8.12.0" }, "bin": { @@ -11836,6 +11837,10 @@ "uuid": "dist/bin/uuid" } }, + "node_modules/uWebSockets.js": { + "version": "20.19.0", + "resolved": "git+ssh://git@github.com/uNetworking/uWebSockets.js.git#42c9c0d5d31f46ca4115dc75672b0037ec970f28" + }, "node_modules/v8-compile-cache": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/v8-compile-cache/-/v8-compile-cache-2.3.0.tgz", @@ -20895,6 +20900,10 @@ "resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==" }, + "uWebSockets.js": { + "version": "git+ssh://git@github.com/uNetworking/uWebSockets.js.git#42c9c0d5d31f46ca4115dc75672b0037ec970f28", + "from": "uWebSockets.js@github:uNetworking/uWebSockets.js#v20.19.0" + }, "v8-compile-cache": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/v8-compile-cache/-/v8-compile-cache-2.3.0.tgz", diff --git a/package.json b/package.json index 929c868c6..3758fd9b1 100644 --- a/package.json +++ b/package.json @@ -122,6 +122,7 @@ "tslib": "^2.4.0", "tsyringe": "^4.7.0", "utp-native": "^2.5.3", + "uWebSockets.js": "github:uNetworking/uWebSockets.js#v20.19.0", "ws": "^8.12.0" }, "devDependencies": { diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts new file mode 100644 index 000000000..d182f23e9 --- /dev/null +++ b/src/clientRPC/ClientServer.ts @@ -0,0 +1,276 @@ +import type { ReadableWritablePair } from 'stream/web'; +import type { FileSystem, PromiseDeconstructed } from 'types'; +import type { TLSConfig } from 'network/types'; +import type { WebSocket } from 'uWebSockets.js'; +import { WritableStream, ReadableStream } from 'stream/web'; +import path from 'path'; +import { createDestroy } from '@matrixai/async-init'; +import Logger from '@matrixai/logger'; +import uWebsocket from 'uWebSockets.js'; +import { promise } from '../utils'; + +type ConnectionCallback = ( + streamPair: ReadableWritablePair, +) => void; + +type Context = { + message: ( + ws: WebSocket, + message: ArrayBuffer, + isBinary: boolean, + ) => void; + drain: (ws: WebSocket) => void; + close: (ws: WebSocket, code: number, message: ArrayBuffer) => void; + logger: Logger; +}; + +// TODO: +// - shutting down active connections +// - propagating backpressure. + +interface ClientServer extends createDestroy.CreateDestroy {} +@createDestroy.CreateDestroy() +class ClientServer { + static async createWSServer({ + connectionCallback, + tlsConfig, + basePath, + host, + port, + fs = require('fs'), + logger = new Logger(this.name), + }: { + connectionCallback: ConnectionCallback; + tlsConfig: TLSConfig; + basePath: string; + host?: string; + port?: number; + fs?: FileSystem; + logger?: Logger; + }) { + logger.info(`Creating ${this.name}`); + const wsServer = new this(logger, fs); + await wsServer.start({ + connectionCallback, + tlsConfig, + basePath, + host, + port, + }); + logger.info(`Created ${this.name}`); + return wsServer; + } + + protected server: uWebsocket.TemplatedApp; + protected listenSocket: uWebsocket.us_listen_socket; + protected host: string; + protected connectionCallback: ConnectionCallback; + protected activeSockets: Set> = new Set(); + protected waitForActive: PromiseDeconstructed | null = null; + + constructor(protected logger: Logger, protected fs: FileSystem) {} + + public async start({ + connectionCallback, + tlsConfig, + basePath, + host, + port, + }: { + connectionCallback: ConnectionCallback; + tlsConfig: TLSConfig; + basePath: string; + host?: string; + port?: number; + }): Promise { + this.logger.info(`Starting ${this.constructor.name}`); + this.connectionCallback = connectionCallback; + // TODO: take a TLS config, write the files in the temp directory and + // load them. + let count = 0; + const keyFile = path.join(basePath, 'keyFile.pem'); + const certFile = path.join(basePath, 'certFile.pem'); + await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); + await this.fs.promises.writeFile(certFile, tlsConfig.certChainPem); + this.server = uWebsocket.SSLApp({ + key_file_name: keyFile, + cert_file_name: certFile, + }); + await this.fs.promises.rm(keyFile); + await this.fs.promises.rm(certFile); + this.server.ws('/*', { + upgrade: (res, req, context) => { + // Req.forEach((k, v) => console.log(k, ':', v)); + const logger = this.logger.getChild(`Connection ${count}`); + res.upgrade>( + { + logger, + }, + req.getHeader('sec-websocket-key'), + req.getHeader('sec-websocket-protocol'), + req.getHeader('sec-websocket-extensions'), + context, + ); + count += 1; + }, + open: (ws: WebSocket) => { + if (this.waitForActive == null) this.waitForActive = promise(); + this.activeSockets.add(ws); + // Set up streams and context + this.handleOpen(ws); + }, + message: (ws: WebSocket, message, isBinary) => { + ws.getUserData().message(ws, message, isBinary); + }, + close: (ws, code, message) => { + this.activeSockets.delete(ws); + if (this.activeSockets.size === 0) this.waitForActive?.resolveP(); + ws.getUserData().close(ws, code, message); + }, + drain: (ws) => { + ws.getUserData().drain(ws); + }, + }); + const listenProm = promise(); + if (host != null) { + // With custom host + this.server.listen(host, port ?? 0, (listenSocket) => { + if (listenSocket) { + this.listenSocket = listenSocket; + listenProm.resolveP(); + } else { + listenProm.rejectP(Error('TMP, no port')); + } + }); + } else { + // With default host + this.server.listen(port ?? 0, (listenSocket) => { + if (listenSocket) { + this.listenSocket = listenSocket; + listenProm.resolveP(); + } else { + listenProm.rejectP(Error('TMP, no port')); + } + }); + } + await listenProm.p; + this.logger.debug( + `bound to port ${uWebsocket.us_socket_local_port(this.listenSocket)}`, + ); + this.host = host ?? '127.0.0.1'; + this.logger.info(`Started ${this.constructor.name}`); + } + + public async destroy(force: boolean = false): Promise { + this.logger.info(`Destroying ${this.constructor.name}`); + // Close the server by closing the underlying socket + uWebsocket.us_listen_socket_close(this.listenSocket); + // Shutting down active websockets + if (force) { + for (const ws of this.activeSockets.values()) { + ws.close(); + } + } + // Wait for all active websockets to close + await this.waitForActive?.p; + this.logger.info(`Destroyed ${this.constructor.name}`); + } + + get port() { + return uWebsocket.us_socket_local_port(this.listenSocket); + } + + protected handleOpen(ws: WebSocket) { + const context = ws.getUserData(); + const logger = context.logger; + logger.info('WS opened'); + let writableClosed = false; + let readableClosed = false; + // Setting up the writable stream + const writableStream = new WritableStream({ + write: (chunk) => { + logger.info('WRITABLE WRITE'); + const writeResult = ws.send(chunk, true); + switch (writeResult) { + case 0: + logger.info('DROPPED, backpressure'); + break; + case 2: + logger.info('BACKPRESSURE'); + break; + case 1: + default: + // Do nothing + break; + } + }, + close: () => { + logger.info('WRITABLE CLOSE'); + writableClosed = true; + if (readableClosed) { + logger.debug('ENDING WS'); + ws.end(); + } + }, + abort: () => { + logger.info('WRITABLE ABORT'); + if (readableClosed) { + logger.debug('ENDING WS'); + ws.end(); + } + }, + }); + // Setting up the readable stream + const readableStream = new ReadableStream({ + start: (controller) => { + context.message = (ws, message, _) => { + logger.debug('MESSAGE CALLED'); + if (message.byteLength === 0) { + logger.debug('NULL MESSAGE, CLOSING'); + if (!readableClosed) { + logger.debug('CLOSING READABLE'); + controller.close(); + readableClosed = true; + if (writableClosed) { + ws.end(); + } + } + return; + } + controller.enqueue(Buffer.from(message)); + }; + context.close = () => { + logger.debug('CLOSING CALLED'); + if (!readableClosed) { + logger.debug('CLOSING READABLE'); + controller.close(); + readableClosed = true; + } + }; + context.drain = () => { + logger.debug('DRAINING CALLED'); + }; + }, + cancel: () => { + readableClosed = true; + if (writableClosed) { + logger.debug('ENDING WS'); + ws.end(); + } + }, + }); + logger.info('callback'); + try { + this.connectionCallback({ + readable: readableStream, + writable: writableStream, + }); + } catch (e) { + logger.error(e); + // TODO: If the callback failed then we need to handle clean up + logger.error(e.toString()); + } + } +} + +export default ClientServer; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 9b280d77e..6d203b010 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -146,7 +146,7 @@ function writeableFromWebSocket( const wait = promise(); ws.send(chunk, (e) => { if (e != null) { - logger.error(`error: ${e}`); + // Logger.error(`${e}`); controller.error(e); } wait.resolveP(); @@ -186,9 +186,10 @@ function startConnection( // Ca: tlsConfig.certChainPem }); ws.once('close', () => logger.info('CLOSED')); - // Ws.once('upgrade', () => { - // // Const tlsSocket = request.socket as TLSSocket; - // // Console.log(tlsSocket.getPeerCertificate()); + // Ws.once('upgrade', (request) => { + // const tlsSocket = request.socket as TLSSocket; + // const peerCert = tlsSocket.getPeerCertificate(); + // console.log(peerCert.issuer.CN); // logger.info('Test early cancellation'); // // Request.destroy(Error('some error')); // // tlsSocket.destroy(Error('some error')); diff --git a/tests/clientRPC/ClientServer.test.ts b/tests/clientRPC/ClientServer.test.ts new file mode 100644 index 000000000..cd3b53a23 --- /dev/null +++ b/tests/clientRPC/ClientServer.test.ts @@ -0,0 +1,293 @@ +import type { ReadableWritablePair } from 'stream/web'; +import type { TLSConfig } from '@/network/types'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; +import { testProp, fc } from '@fast-check/jest'; +import { KeyRing } from '@/keys/index'; +import ClientServer from '@/clientRPC/ClientServer'; +import * as clientRPCUtils from '@/clientRPC/utils'; +import * as testsUtils from '../utils'; + +describe('ClientServer', () => { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + const loudLogger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + + let dataDir: string; + let keyRing: KeyRing; + let tlsConfig: TLSConfig; + const host = '127.0.0.2'; + let clientServer: ClientServer; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + keyRing = await KeyRing.createKeyRing({ + keysPath: keysPath, + password: 'password', + logger: logger.getChild('keyRing'), + }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + }); + afterEach(async () => { + await clientServer.destroy(true); + await keyRing.stop(); + await fs.promises.rm(dataDir, { force: true, recursive: true }); + }); + + test('Handles a connection', async () => { + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + const message1 = Buffer.from('1request1'); + await writer.write(message1); + expect((await reader.read()).value).toStrictEqual(message1); + const message2 = Buffer.from('1request2'); + await writer.write(message2); + expect((await reader.read()).value).toStrictEqual(message2); + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); + test('Handles a connection and closes before message', async () => { + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await websocket.writable.close(); + const reader = websocket.readable.getReader(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); + const messagesArb = fc.array( + fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), + ); + const streamsArb = fc.array(messagesArb, { minLength: 1 }).noShrink(); + testProp( + 'Handles multiple connections', + [streamsArb], + async (streamsData) => { + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + + const testStream = async (messages: Array) => { + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + for (const message of messages) { + await writer.write(message); + const response = await reader.read(); + expect(response.done).toBeFalse(); + expect(response.value?.toString()).toStrictEqual(message.toString()); + } + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + }; + const streams = streamsData.map((messages) => testStream(messages)); + await Promise.all(streams); + + logger.info('ending'); + }, + ); + const asyncReadWrite = async ( + messages: Array, + streampair: ReadableWritablePair, + ) => { + await Promise.allSettled([ + (async () => { + const writer = streampair.writable.getWriter(); + for (const message of messages) { + await writer.write(message); + } + await writer.close(); + })(), + (async () => { + for await (const _ of streampair.readable) { + // No touch, only consume + } + })(), + ]); + }; + testProp( + 'allows half closed writable closes first', + [messagesArb, messagesArb], + async (messages1, messages2) => { + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + }, + ); + testProp( + 'allows half closed readable closes first', + [messagesArb, messagesArb], + async (messages1, messages2) => { + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + }, + ); + testProp( + 'handles early close of readable', + [messagesArb, messagesArb], + async (messages1, messages2) => { + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + await streamPair.readable.cancel(); + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + }, + ); + test('Destroying ClientServer stops all connections', async () => { + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await clientServer.destroy(true); + for await (const _ of websocket.readable) { + // No touch, only consume + } + logger.info('ending'); + }); +}); diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index 813935ab7..577724a87 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -14,6 +14,7 @@ import { KeyRing } from '@/keys/index'; import * as clientRPCUtils from '@/clientRPC/utils'; import { UnaryHandler } from '@/RPC/handlers'; import { UnaryCaller } from '@/RPC/callers'; +import ClientServer from '@/clientRPC/ClientServer'; import * as testsUtils from '../utils/index'; describe('websocket', () => { @@ -22,12 +23,17 @@ describe('websocket', () => { formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); + const loudLogger = new Logger('websocket test', LogLevel.DEBUG, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); let dataDir: string; let keyRing: KeyRing; let tlsConfig: TLSConfig; let server: Server; let wss: WebSocketServer; - const host = '127.0.0.1'; + const host = '127.0.0.2'; let port: number; let rpcServer: RPCServer; let rpcClient_: RPCClient; @@ -111,4 +117,52 @@ describe('websocket', () => { rpcClient.unaryCaller('test3', { hello: 'world2' }), ).toReject(); }); + + test('Using uws', async () => { + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + const server = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${server.port}`); + + const websocket1 = await clientRPCUtils.startConnection( + host, + server.port, + logger.getChild('Connection'), + ); + + // Const websocket2 = await clientRPCUtils.startConnection( + // host, + // server.port, + // logger.getChild('Connection'), + // ); + + logger.info('doing things'); + const writer1 = websocket1.writable.getWriter(); + // Const writer2 = websocket2.writable.getWriter(); + await writer1.write(Buffer.from('1request1')); + // Await writer2.write(Buffer.from('2request1')); + await writer1.write(Buffer.from('1request2')); + // Await writer2.write(Buffer.from('2request2')); + await writer1.close(); + // Await writer2.close(); + for await (const val of websocket1.readable) { + logger.info(`Client1 message: ${val.toString()}`); + } + // For await (const val of websocket2.readable) { + // logger.info(`Client2 message: ${val.toString()}`); + // } + logger.info('ending'); + await server.destroy(); + }); }); From c433d92522252a2d7dc35c4116903cb4e66bff4a Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 17 Feb 2023 14:21:24 +1100 Subject: [PATCH 02/23] feat: `ClientServer` supports backpressure Only writable side of the websocket really supports backpressure. The incoming data is buffered up to a limit. If the limit is exceeded we throw and close the stream. [ci skip] --- src/clientRPC/ClientServer.ts | 121 +++++++++++++++++--------- src/clientRPC/utils.ts | 3 +- tests/clientRPC/ClientServer.test.ts | 124 ++++++++++++++++++++++++++- 3 files changed, 203 insertions(+), 45 deletions(-) diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index d182f23e9..2eaa54fcc 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -22,6 +22,7 @@ type Context = { drain: (ws: WebSocket) => void; close: (ws: WebSocket, code: number, message: ArrayBuffer) => void; logger: Logger; + writeBackpressure: boolean; }; // TODO: @@ -38,6 +39,7 @@ class ClientServer { host, port, fs = require('fs'), + maxReadBufferBytes = 1_000_000_000, // About 1 GB logger = new Logger(this.name), }: { connectionCallback: ConnectionCallback; @@ -46,10 +48,11 @@ class ClientServer { host?: string; port?: number; fs?: FileSystem; + maxReadBufferBytes?: number; logger?: Logger; }) { logger.info(`Creating ${this.name}`); - const wsServer = new this(logger, fs); + const wsServer = new this(logger, fs, maxReadBufferBytes); await wsServer.start({ connectionCallback, tlsConfig, @@ -68,7 +71,17 @@ class ClientServer { protected activeSockets: Set> = new Set(); protected waitForActive: PromiseDeconstructed | null = null; - constructor(protected logger: Logger, protected fs: FileSystem) {} + /** + * + * @param logger + * @param fs + * @param maxReadBufferBytes Max number of bytes stored in read buffer before error + */ + constructor( + protected logger: Logger, + protected fs: FileSystem, + protected maxReadBufferBytes, + ) {} public async start({ connectionCallback, @@ -119,7 +132,8 @@ class ClientServer { // Set up streams and context this.handleOpen(ws); }, - message: (ws: WebSocket, message, isBinary) => { + // TODO: could this take an async and apply backpressure implicitly? + message: async (ws: WebSocket, message, isBinary) => { ws.getUserData().message(ws, message, isBinary); }, close: (ws, code, message) => { @@ -186,21 +200,34 @@ class ClientServer { logger.info('WS opened'); let writableClosed = false; let readableClosed = false; + let backpressure: PromiseDeconstructed | null = null; + context.drain = () => { + logger.debug('DRAINING CALLED'); + backpressure?.resolveP(); + }; // Setting up the writable stream const writableStream = new WritableStream({ - write: (chunk) => { - logger.info('WRITABLE WRITE'); + write: async (chunk, controller) => { + // Logger.debug('WRITABLE WRITE'); + await backpressure?.p; const writeResult = ws.send(chunk, true); switch (writeResult) { - case 0: - logger.info('DROPPED, backpressure'); - break; + default: case 2: - logger.info('BACKPRESSURE'); + // Write failure, emit error + controller.error(Error('TMP Failed to write')); + break; + case 0: + logger.info('Write backpressure'); + // Signal backpressure + backpressure = promise(); + context.writeBackpressure = true; + backpressure.p.finally(() => { + context.writeBackpressure = false; + }); break; case 1: - default: - // Do nothing + // Success break; } }, @@ -221,44 +248,56 @@ class ClientServer { }, }); // Setting up the readable stream - const readableStream = new ReadableStream({ - start: (controller) => { - context.message = (ws, message, _) => { - logger.debug('MESSAGE CALLED'); - if (message.byteLength === 0) { - logger.debug('NULL MESSAGE, CLOSING'); + const readableStream = new ReadableStream( + { + start: (controller) => { + context.message = (ws, message, _) => { + // Logger.debug('MESSAGE CALLED'); + if (message.byteLength === 0) { + logger.debug('NULL MESSAGE, CLOSING'); + if (!readableClosed) { + logger.debug('CLOSING READABLE'); + controller.close(); + readableClosed = true; + if (writableClosed) { + ws.end(); + } + } + return; + } + controller.enqueue(Buffer.from(message)); + if ( + controller.desiredSize != null && + controller.desiredSize < -1000 + ) { + logger.error('Read stream buffer full'); + const err = Error('TMP read buffer limit'); + ws.end(4001, err.toString()); + controller.error(err); + } + }; + context.close = () => { + logger.debug('CLOSING CALLED'); if (!readableClosed) { logger.debug('CLOSING READABLE'); controller.close(); readableClosed = true; - if (writableClosed) { - ws.end(); - } } - return; + }; + }, + cancel: () => { + readableClosed = true; + if (writableClosed) { + logger.debug('ENDING WS'); + ws.end(); } - controller.enqueue(Buffer.from(message)); - }; - context.close = () => { - logger.debug('CLOSING CALLED'); - if (!readableClosed) { - logger.debug('CLOSING READABLE'); - controller.close(); - readableClosed = true; - } - }; - context.drain = () => { - logger.debug('DRAINING CALLED'); - }; + }, }, - cancel: () => { - readableClosed = true; - if (writableClosed) { - logger.debug('ENDING WS'); - ws.end(); - } + { + highWaterMark: this.maxReadBufferBytes, + size: (chunk) => chunk?.byteLength ?? 0, }, - }); + ); logger.info('callback'); try { this.connectionCallback({ diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 6d203b010..e9b41f834 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -142,11 +142,10 @@ function writeableFromWebSocket( ws.close(); }, write: async (chunk, controller) => { - logger.debug(`writing: ${chunk?.toString()}`); + // Logger.debug(`writing: ${chunk?.toString()}`); const wait = promise(); ws.send(chunk, (e) => { if (e != null) { - // Logger.error(`${e}`); controller.error(e); } wait.resolveP(); diff --git a/tests/clientRPC/ClientServer.test.ts b/tests/clientRPC/ClientServer.test.ts index cd3b53a23..f5bee8ebe 100644 --- a/tests/clientRPC/ClientServer.test.ts +++ b/tests/clientRPC/ClientServer.test.ts @@ -1,5 +1,6 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; +import type { WebSocket } from 'uWebSockets.js'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -8,6 +9,7 @@ import { testProp, fc } from '@fast-check/jest'; import { KeyRing } from '@/keys/index'; import ClientServer from '@/clientRPC/ClientServer'; import * as clientRPCUtils from '@/clientRPC/utils'; +import { promise } from '@/utils'; import * as testsUtils from '../utils'; describe('ClientServer', () => { @@ -16,7 +18,7 @@ describe('ClientServer', () => { formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); - const loudLogger = new Logger('websocket test', LogLevel.WARN, [ + const loudLogger = new Logger('websocket test', LogLevel.DEBUG, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), @@ -53,7 +55,7 @@ describe('ClientServer', () => { void streamPair.readable .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => logger.info('STREAM HANDLING ENDED')); + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); }, basePath: dataDir, tlsConfig, @@ -290,4 +292,122 @@ describe('ClientServer', () => { } logger.info('ending'); }); + test('Writable backpressure', async () => { + let context: { writeBackpressure: boolean } | undefined; + const backpressure = promise(); + const resumeWriting = promise(); + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void Promise.allSettled([ + (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })(), + (async () => { + // Kidnap the context + let ws: WebSocket<{ writeBackpressure: boolean }> | null = null; + // @ts-ignore: kidnap protected property + for (const websocket of clientServer.activeSockets.values()) { + ws = websocket; + } + if (ws == null) { + await streamPair.writable.close(); + return; + } + context = ws.getUserData(); + // Write until backPressured + const message = Buffer.alloc(128, 0xf0); + const writer = streamPair.writable.getWriter(); + while (!context.writeBackpressure) { + await writer.write(message); + } + loudLogger.info('BACK PRESSURED'); + backpressure.resolveP(); + await resumeWriting.p; + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + await writer.close(); + loudLogger.info('WRITING ENDED'); + })(), + ]).catch((e) => logger.error(e.toString())); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await websocket.writable.close(); + + await backpressure.p; + expect(context?.writeBackpressure).toBeTrue(); + resumeWriting.resolveP(); + // Consume all of the back-pressured data + for await (const _ of websocket.readable) { + // No touch, only consume + } + expect(context?.writeBackpressure).toBeFalse(); + loudLogger.info('ending'); + }); + // Readable backpressure is not actually supported. We're dealing with it by + // using an buffer with a provided limit that can be very large. + test.only('Exceeding readable buffer limit causes error', async () => { + const startReading = promise(); + const handlingProm = promise(); + clientServer = await ClientServer.createWSServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + Promise.all([ + (async () => { + await startReading.p; + loudLogger.info('Starting consumption'); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + loudLogger.info('Reads ended'); + })(), + (async () => { + await streamPair.writable.close(); + })(), + ]) + .catch(() => {}) + .finally(() => handlingProm.resolveP()); + }, + basePath: dataDir, + tlsConfig, + host, + // Setting a really low buffer limit + maxReadBufferBytes: 1500, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + loudLogger.getChild('Connection'), + ); + const message = Buffer.alloc(1_000, 0xf0); + const writer = websocket.writable.getWriter(); + loudLogger.info('Starting writes'); + await expect(async () => { + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + }).rejects.toThrow(); + startReading.resolveP(); + loudLogger.info('writes ended'); + for await (const _ of websocket.readable) { + // No touch, only consume + } + await handlingProm.p; + loudLogger.info('ending'); + }); }); From 0ed77d6535c79f41cd770dd4bc892a51391a7133 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 20 Feb 2023 16:27:47 +1100 Subject: [PATCH 03/23] fix: switched `ClientServer` to `StartStop` and other fixes Fixed the process not exiting when tests finished. [ci skip] --- src/clientRPC/ClientServer.ts | 18 +- tests/clientRPC/ClientServer.test.ts | 262 ++++++++++++++------------- tests/clientRPC/websocket.test.ts | 2 +- 3 files changed, 148 insertions(+), 134 deletions(-) diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index 2eaa54fcc..c1e24a49e 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -4,7 +4,7 @@ import type { TLSConfig } from 'network/types'; import type { WebSocket } from 'uWebSockets.js'; import { WritableStream, ReadableStream } from 'stream/web'; import path from 'path'; -import { createDestroy } from '@matrixai/async-init'; +import { startStop } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import uWebsocket from 'uWebSockets.js'; import { promise } from '../utils'; @@ -25,14 +25,10 @@ type Context = { writeBackpressure: boolean; }; -// TODO: -// - shutting down active connections -// - propagating backpressure. - -interface ClientServer extends createDestroy.CreateDestroy {} -@createDestroy.CreateDestroy() +interface ClientServer extends startStop.StartStop {} +@startStop.StartStop() class ClientServer { - static async createWSServer({ + static async createClientServer({ connectionCallback, tlsConfig, basePath, @@ -175,8 +171,8 @@ class ClientServer { this.logger.info(`Started ${this.constructor.name}`); } - public async destroy(force: boolean = false): Promise { - this.logger.info(`Destroying ${this.constructor.name}`); + public async stop(force: boolean = false): Promise { + this.logger.info(`Stopping ${this.constructor.name}`); // Close the server by closing the underlying socket uWebsocket.us_listen_socket_close(this.listenSocket); // Shutting down active websockets @@ -187,7 +183,7 @@ class ClientServer { } // Wait for all active websockets to close await this.waitForActive?.p; - this.logger.info(`Destroyed ${this.constructor.name}`); + this.logger.info(`Stopped ${this.constructor.name}`); } get port() { diff --git a/tests/clientRPC/ClientServer.test.ts b/tests/clientRPC/ClientServer.test.ts index f5bee8ebe..e8fc75670 100644 --- a/tests/clientRPC/ClientServer.test.ts +++ b/tests/clientRPC/ClientServer.test.ts @@ -18,7 +18,7 @@ describe('ClientServer', () => { formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); - const loudLogger = new Logger('websocket test', LogLevel.DEBUG, [ + const loudLogger = new Logger('websocket test', LogLevel.WARN, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), @@ -43,13 +43,13 @@ describe('ClientServer', () => { tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { - await clientServer.destroy(true); + await clientServer.stop(true); await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); }); test('Handles a connection', async () => { - clientServer = await ClientServer.createWSServer({ + clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -81,7 +81,7 @@ describe('ClientServer', () => { logger.info('ending'); }); test('Handles a connection and closes before message', async () => { - clientServer = await ClientServer.createWSServer({ + clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -113,42 +113,48 @@ describe('ClientServer', () => { 'Handles multiple connections', [streamsArb], async (streamsData) => { - clientServer = await ClientServer.createWSServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => logger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - - const testStream = async (messages: Array) => { - const websocket = await clientRPCUtils.startConnection( + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, host, - clientServer.port, - logger.getChild('Connection'), - ); - const writer = websocket.writable.getWriter(); - const reader = websocket.readable.getReader(); - for (const message of messages) { - await writer.write(message); - const response = await reader.read(); - expect(response.done).toBeFalse(); - expect(response.value?.toString()).toStrictEqual(message.toString()); - } - await writer.close(); - expect((await reader.read()).done).toBeTrue(); - }; - const streams = streamsData.map((messages) => testStream(messages)); - await Promise.all(streams); + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); - logger.info('ending'); + const testStream = async (messages: Array) => { + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + for (const message of messages) { + await writer.write(message); + const response = await reader.read(); + expect(response.done).toBeFalse(); + expect(response.value?.toString()).toStrictEqual( + message.toString(), + ); + } + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + }; + const streams = streamsData.map((messages) => testStream(messages)); + await Promise.all(streams); + + logger.info('ending'); + } finally { + await clientServer.stop(true); + } }, ); const asyncReadWrite = async ( @@ -174,101 +180,113 @@ describe('ClientServer', () => { 'allows half closed writable closes first', [messagesArb, messagesArb], async (messages1, messages2) => { - clientServer = await ClientServer.createWSServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void (async () => { - const writer = streamPair.writable.getWriter(); - for await (const val of messages2) { - await writer.write(val); - } - await writer.close(); - for await (const _ of streamPair.readable) { - // No touch, only consume - } - })().catch((e) => logger.error(e)); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( - host, - clientServer.port, - logger.getChild('Connection'), - ); - await asyncReadWrite(messages1, websocket); - logger.info('ending'); + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await clientServer.stop(true); + } }, ); testProp( 'allows half closed readable closes first', [messagesArb, messagesArb], async (messages1, messages2) => { - clientServer = await ClientServer.createWSServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void (async () => { - for await (const _ of streamPair.readable) { - // No touch, only consume - } - const writer = streamPair.writable.getWriter(); - for await (const val of messages2) { - await writer.write(val); - } - await writer.close(); - })().catch((e) => logger.error(e)); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( - host, - clientServer.port, - logger.getChild('Connection'), - ); - await asyncReadWrite(messages1, websocket); - logger.info('ending'); + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await clientServer.stop(true); + } }, ); testProp( 'handles early close of readable', [messagesArb, messagesArb], async (messages1, messages2) => { - clientServer = await ClientServer.createWSServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void (async () => { - await streamPair.readable.cancel(); - const writer = streamPair.writable.getWriter(); - for await (const val of messages2) { - await writer.write(val); - } - await writer.close(); - })().catch((e) => logger.error(e)); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( - host, - clientServer.port, - logger.getChild('Connection'), - ); - await asyncReadWrite(messages1, websocket); - logger.info('ending'); + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + await streamPair.readable.cancel(); + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const websocket = await clientRPCUtils.startConnection( + host, + clientServer.port, + logger.getChild('Connection'), + ); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await clientServer.stop(true); + } }, ); test('Destroying ClientServer stops all connections', async () => { - clientServer = await ClientServer.createWSServer({ + clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -286,7 +304,7 @@ describe('ClientServer', () => { clientServer.port, logger.getChild('Connection'), ); - await clientServer.destroy(true); + await clientServer.stop(true); for await (const _ of websocket.readable) { // No touch, only consume } @@ -296,7 +314,7 @@ describe('ClientServer', () => { let context: { writeBackpressure: boolean } | undefined; const backpressure = promise(); const resumeWriting = promise(); - clientServer = await ClientServer.createWSServer({ + clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void Promise.allSettled([ @@ -359,10 +377,10 @@ describe('ClientServer', () => { }); // Readable backpressure is not actually supported. We're dealing with it by // using an buffer with a provided limit that can be very large. - test.only('Exceeding readable buffer limit causes error', async () => { + test('Exceeding readable buffer limit causes error', async () => { const startReading = promise(); const handlingProm = promise(); - clientServer = await ClientServer.createWSServer({ + clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); Promise.all([ diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index 577724a87..2c7d8bfd7 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -120,7 +120,7 @@ describe('websocket', () => { test('Using uws', async () => { tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - const server = await ClientServer.createWSServer({ + const server = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable From 485e4a84b622327c1301b56c38db18b04f5036ce Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 20 Feb 2023 18:06:42 +1100 Subject: [PATCH 04/23] feat: creating `ClientClient` implementation [ci skip] --- src/clientRPC/ClientClient.ts | 222 +++++++++++++++++++++++++++ src/clientRPC/ClientServer.ts | 22 +-- tests/clientRPC/ClientServer.test.ts | 178 +++++++++++++-------- tests/clientRPC/websocket.test.ts | 168 -------------------- 4 files changed, 350 insertions(+), 240 deletions(-) create mode 100644 src/clientRPC/ClientClient.ts delete mode 100644 tests/clientRPC/websocket.test.ts diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/ClientClient.ts new file mode 100644 index 000000000..2d42a391d --- /dev/null +++ b/src/clientRPC/ClientClient.ts @@ -0,0 +1,222 @@ +import type { ReadableWritablePair } from 'stream/web'; +import type { TLSSocket } from 'tls'; +import type { NodeId } from 'ids/index'; +import { WritableStream, ReadableStream } from 'stream/web'; +import { createDestroy } from '@matrixai/async-init'; +import Logger from '@matrixai/logger'; +import WebSocket from 'ws'; +import { PromiseCancellable } from '@matrixai/async-cancellable'; +import { promise } from '../utils'; + +interface ClientClient extends createDestroy.CreateDestroy {} +@createDestroy.CreateDestroy() +class ClientClient { + static async createClientClient({ + host, + port, + nodeId, + maxReadableStreamBytes = 1000, // About 1kB + logger = new Logger(this.name), + }: { + host: string; + port: number; + nodeId: NodeId; + maxReadableStreamBytes?: number; + logger?: Logger; + }): Promise { + logger.info(`Creating ${this.name}`); + const clientClient = new this( + logger, + host, + port, + maxReadableStreamBytes, + nodeId, + ); + logger.info(`Created ${this.name}`); + return clientClient; + } + + protected activeConnections: Set> = new Set(); + + constructor( + protected logger: Logger, + protected host: string, + protected port: number, + protected maxReadableStreamBytes: number, + protected nodeId: NodeId, + ) {} + + public async destroy(force: boolean = false) { + this.logger.info(`Destroying ${this.constructor.name}`); + if (force) { + for (const activeConnection of this.activeConnections) { + activeConnection.cancel(); + } + } + for (const activeConnection of this.activeConnections) { + await activeConnection; + } + this.logger.info(`Destroyed ${this.constructor.name}`); + } + + @createDestroy.ready(Error('TMP destroyed')) + public async startConnection(): Promise< + ReadableWritablePair + > { + const address = `wss://${this.host}:${this.port}`; + this.logger.info(`Connecting to ${address}`); + const connectProm = promise(); + const ws = new WebSocket(address, { + rejectUnauthorized: false, + }); + // Creating logic for awaiting active connections and terminating them + const abortHandler = () => { + ws.terminate(); + }; + const abortController = new AbortController(); + const connectionProm = new PromiseCancellable((resolve) => { + ws.once('close', () => { + abortController.signal.removeEventListener('abort', abortHandler); + resolve(); + }); + }, abortController); + abortController.signal.addEventListener('abort', abortHandler); + this.activeConnections.add(connectionProm); + connectionProm.finally(() => this.activeConnections.delete(connectionProm)); + + // Handle connection failure + const openErrorHandler = (e) => { + connectProm.rejectP(Error('TMP ERROR Connection failure', { cause: e })); + }; + ws.once('error', openErrorHandler); + // Authenticate server's certificates + ws.once('upgrade', (request) => { + const tlsSocket = request.socket as TLSSocket; + const peerCert = tlsSocket.getPeerCertificate(); + // TODO: custom authentication here + this.logger.info(`server NodeId ${peerCert.issuer.CN}`); + }); + ws.once('open', () => { + this.logger.info('starting connection'); + connectProm.resolveP(); + }); + // TODO: Race with a connection timeout here + await connectProm.p; + // Cleaning up connection error + ws.removeEventListener('error', openErrorHandler); + + let readableClosed = false; + let writableClosed = false; + const readableLogger = this.logger.getChild('readable'); + const writableLogger = this.logger.getChild('writable'); + const readableStream = new ReadableStream( + { + start: (controller) => { + readableLogger.info('STARTING'); + const messageHandler = (data) => { + // ReadableLogger.debug(`message: ${data.toString()}`); + if (controller.desiredSize == null) { + controller.error(Error('NEVER')); + return; + } + if (controller.desiredSize < 0) { + // ReadableLogger.debug('PAUSING'); + ws.pause(); + } + const message = data as Buffer; + if (message.length === 0) { + readableLogger.info('CLOSING, NULL MESSAGE'); + ws.removeListener('message', messageHandler); + if (!readableClosed) { + controller.close(); + readableClosed = true; + } + if (writableClosed) { + ws.close(); + } + return; + } + controller.enqueue(message); + }; + ws.on('message', messageHandler); + ws.once('close', () => { + readableLogger.info('CLOSED, WS CLOSED'); + ws.removeListener('message', messageHandler); + if (!readableClosed) { + controller.close(); + readableClosed = true; + } + }); + ws.once('error', (e) => readableLogger.error(e)); + }, + cancel: () => { + readableLogger.info('CANCELLED'); + if (!readableClosed) { + ws.close(); + readableClosed = true; + } + }, + pull: () => { + // ReadableLogger.debug('RESUMING'); + ws.resume(); + }, + }, + { + highWaterMark: this.maxReadableStreamBytes, + size: (chunk) => chunk?.byteLength ?? 0, + }, + ); + const writableStream = new WritableStream({ + start: (controller) => { + writableLogger.info('STARTING'); + ws.once('error', (e) => { + writableLogger.error(`error: ${e}`); + if (!writableClosed) { + controller.error(e); + writableClosed = true; + } + }); + ws.once('close', (code, reason) => { + if (!writableClosed) { + writableLogger.info( + `ws closing early! with code: ${code} and reason: ${reason.toString()}`, + ); + controller.error(Error('TMP WebSocket Closed early')); + } + }); + }, + close: () => { + writableLogger.info('CLOSING'); + ws.send(Buffer.from([])); + writableClosed = true; + if (readableClosed) { + ws.close(); + } + }, + abort: () => { + writableLogger.info('ABORTED'); + writableClosed = true; + if (readableClosed) { + ws.close(); + } + }, + write: async (chunk, controller) => { + // WritableLogger.debug(`writing: ${chunk?.toString()}`); + const wait = promise(); + ws.send(chunk, (e) => { + if (e != null) { + controller.error(e); + } + wait.resolveP(); + }); + await wait.p; + }, + }); + return { + readable: readableStream, + writable: writableStream, + }; + } +} + +export default ClientClient; diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index c1e24a49e..7c7c871d6 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -177,8 +177,8 @@ class ClientServer { uWebsocket.us_listen_socket_close(this.listenSocket); // Shutting down active websockets if (force) { - for (const ws of this.activeSockets.values()) { - ws.close(); + for (const ws of this.activeSockets) { + ws.end(); } } // Wait for all active websockets to close @@ -196,6 +196,7 @@ class ClientServer { logger.info('WS opened'); let writableClosed = false; let readableClosed = false; + let wsClosed = false; let backpressure: PromiseDeconstructed | null = null; context.drain = () => { logger.debug('DRAINING CALLED'); @@ -204,7 +205,7 @@ class ClientServer { // Setting up the writable stream const writableStream = new WritableStream({ write: async (chunk, controller) => { - // Logger.debug('WRITABLE WRITE'); + // Logger.debug(`WRITABLE WRITE ${chunk.toString()}`); await backpressure?.p; const writeResult = ws.send(chunk, true); switch (writeResult) { @@ -229,15 +230,16 @@ class ClientServer { }, close: () => { logger.info('WRITABLE CLOSE'); + if (!wsClosed) ws.send(Buffer.from([]), true); writableClosed = true; - if (readableClosed) { + if (readableClosed && !wsClosed) { logger.debug('ENDING WS'); ws.end(); } }, abort: () => { logger.info('WRITABLE ABORT'); - if (readableClosed) { + if (readableClosed && !wsClosed) { logger.debug('ENDING WS'); ws.end(); } @@ -248,14 +250,14 @@ class ClientServer { { start: (controller) => { context.message = (ws, message, _) => { - // Logger.debug('MESSAGE CALLED'); + // Logger.debug(`MESSAGE CALLED ${message.toString()}`); if (message.byteLength === 0) { logger.debug('NULL MESSAGE, CLOSING'); if (!readableClosed) { logger.debug('CLOSING READABLE'); controller.close(); readableClosed = true; - if (writableClosed) { + if (writableClosed && !wsClosed) { ws.end(); } } @@ -268,12 +270,13 @@ class ClientServer { ) { logger.error('Read stream buffer full'); const err = Error('TMP read buffer limit'); - ws.end(4001, err.toString()); + if (!wsClosed) ws.end(4001, err.toString()); controller.error(err); } }; context.close = () => { logger.debug('CLOSING CALLED'); + wsClosed = true; if (!readableClosed) { logger.debug('CLOSING READABLE'); controller.close(); @@ -283,7 +286,7 @@ class ClientServer { }, cancel: () => { readableClosed = true; - if (writableClosed) { + if (writableClosed && !wsClosed) { logger.debug('ENDING WS'); ws.end(); } @@ -301,7 +304,6 @@ class ClientServer { writable: writableStream, }); } catch (e) { - logger.error(e); // TODO: If the callback failed then we need to handle clean up logger.error(e.toString()); } diff --git a/tests/clientRPC/ClientServer.test.ts b/tests/clientRPC/ClientServer.test.ts index e8fc75670..735e1034f 100644 --- a/tests/clientRPC/ClientServer.test.ts +++ b/tests/clientRPC/ClientServer.test.ts @@ -8,8 +8,8 @@ import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { KeyRing } from '@/keys/index'; import ClientServer from '@/clientRPC/ClientServer'; -import * as clientRPCUtils from '@/clientRPC/utils'; import { promise } from '@/utils'; +import ClientClient from '@/clientRPC/ClientClient'; import * as testsUtils from '../utils'; describe('ClientServer', () => { @@ -29,6 +29,31 @@ describe('ClientServer', () => { let tlsConfig: TLSConfig; const host = '127.0.0.2'; let clientServer: ClientServer; + let clientClient: ClientClient; + + const messagesArb = fc.array( + fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), + ); + const streamsArb = fc.array(messagesArb, { minLength: 1 }).noShrink(); + const asyncReadWrite = async ( + messages: Array, + streampair: ReadableWritablePair, + ) => { + await Promise.allSettled([ + (async () => { + const writer = streampair.writable.getWriter(); + for (const message of messages) { + await writer.write(message); + } + await writer.close(); + })(), + (async () => { + for await (const _ of streampair.readable) { + // No touch, only consume + } + })(), + ]); + }; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -43,12 +68,14 @@ describe('ClientServer', () => { tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { + logger.info('AFTEREACH'); await clientServer.stop(true); + await clientClient.destroy(); await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); }); - test('Handles a connection', async () => { + test('makes a connection', async () => { clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); @@ -63,11 +90,14 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - logger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + const writer = websocket.writable.getWriter(); const reader = websocket.readable.getReader(); const message1 = Buffer.from('1request1'); @@ -95,20 +125,18 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - logger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); await websocket.writable.close(); const reader = websocket.readable.getReader(); expect((await reader.read()).done).toBeTrue(); logger.info('ending'); }); - const messagesArb = fc.array( - fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), - ); - const streamsArb = fc.array(messagesArb, { minLength: 1 }).noShrink(); testProp( 'Handles multiple connections', [streamsArb], @@ -128,13 +156,15 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); const testStream = async (messages: Array) => { - const websocket = await clientRPCUtils.startConnection( - host, - clientServer.port, - logger.getChild('Connection'), - ); + const websocket = await clientClient.startConnection(); const writer = websocket.writable.getWriter(); const reader = websocket.readable.getReader(); for (const message of messages) { @@ -157,25 +187,6 @@ describe('ClientServer', () => { } }, ); - const asyncReadWrite = async ( - messages: Array, - streampair: ReadableWritablePair, - ) => { - await Promise.allSettled([ - (async () => { - const writer = streampair.writable.getWriter(); - for (const message of messages) { - await writer.write(message); - } - await writer.close(); - })(), - (async () => { - for await (const _ of streampair.readable) { - // No touch, only consume - } - })(), - ]); - }; testProp( 'allows half closed writable closes first', [messagesArb, messagesArb], @@ -201,11 +212,13 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - logger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); await asyncReadWrite(messages1, websocket); logger.info('ending'); } finally { @@ -238,11 +251,13 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - logger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); await asyncReadWrite(messages1, websocket); logger.info('ending'); } finally { @@ -273,11 +288,13 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - logger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); await asyncReadWrite(messages1, websocket); logger.info('ending'); } finally { @@ -299,17 +316,47 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - logger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); await clientServer.stop(true); for await (const _ of websocket.readable) { // No touch, only consume } logger.info('ending'); }); + test('Destroying ClientClient stops all connections', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable.pipeTo(streamPair.writable).catch((e) => { + logger.error(e); + }); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await clientClient.destroy(true); + for await (const _ of websocket.readable) { + // No touch, only consume + } + await clientServer.stop(); + logger.info('ending'); + }); test('Writable backpressure', async () => { let context: { writeBackpressure: boolean } | undefined; const backpressure = promise(); @@ -358,11 +405,13 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - logger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); await websocket.writable.close(); await backpressure.p; @@ -407,11 +456,13 @@ describe('ClientServer', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - const websocket = await clientRPCUtils.startConnection( + clientClient = await ClientClient.createClientClient({ host, - clientServer.port, - loudLogger.getChild('Connection'), - ); + port: clientServer.port, + nodeId: keyRing.getNodeId(), + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); const message = Buffer.alloc(1_000, 0xf0); const writer = websocket.writable.getWriter(); loudLogger.info('Starting writes'); @@ -428,4 +479,7 @@ describe('ClientServer', () => { await handlingProm.p; loudLogger.info('ending'); }); + test.todo('client ends connection abruptly'); + test.todo('Server ends connection abruptly'); + test.todo('Client rejects bad server certificate'); }); diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts deleted file mode 100644 index 2c7d8bfd7..000000000 --- a/tests/clientRPC/websocket.test.ts +++ /dev/null @@ -1,168 +0,0 @@ -import type { TLSConfig } from '@/network/types'; -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; -import type { ClientManifest } from '@/RPC/types'; -import type { JSONValue } from '@/types'; -import fs from 'fs'; -import path from 'path'; -import os from 'os'; -import { createServer } from 'https'; -import Logger, { LogLevel, StreamHandler, formatting } from '@matrixai/logger'; -import RPCServer from '@/RPC/RPCServer'; -import RPCClient from '@/RPC/RPCClient'; -import { KeyRing } from '@/keys/index'; -import * as clientRPCUtils from '@/clientRPC/utils'; -import { UnaryHandler } from '@/RPC/handlers'; -import { UnaryCaller } from '@/RPC/callers'; -import ClientServer from '@/clientRPC/ClientServer'; -import * as testsUtils from '../utils/index'; - -describe('websocket', () => { - const logger = new Logger('websocket test', LogLevel.WARN, [ - new StreamHandler( - formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, - ), - ]); - const loudLogger = new Logger('websocket test', LogLevel.DEBUG, [ - new StreamHandler( - formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, - ), - ]); - let dataDir: string; - let keyRing: KeyRing; - let tlsConfig: TLSConfig; - let server: Server; - let wss: WebSocketServer; - const host = '127.0.0.2'; - let port: number; - let rpcServer: RPCServer; - let rpcClient_: RPCClient; - - beforeEach(async () => { - dataDir = await fs.promises.mkdtemp( - path.join(os.tmpdir(), 'polykey-test-'), - ); - const keysPath = path.join(dataDir, 'keys'); - keyRing = await KeyRing.createKeyRing({ - keysPath: keysPath, - password: 'password', - logger: logger.getChild('keyRing'), - }); - tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, host); - }); - afterEach(async () => { - await rpcClient_?.destroy(); - await rpcServer?.destroy(); - wss?.close(); - server.close(); - await keyRing.stop(); - await fs.promises.rm(dataDir, { force: true, recursive: true }); - }); - - test('websocket should work with RPC', async () => { - // Setting up server - class Test1 extends UnaryHandler { - public async handle(input: JSONValue): Promise { - return input; - } - } - class Test2 extends UnaryHandler { - public async handle(): Promise { - return { hello: 'not world' }; - } - } - rpcServer = await RPCServer.createRPCServer({ - manifest: { - test1: new Test1({}), - test2: new Test2({}), - }, - logger: logger.getChild('RPCServer'), - }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('client'), - ); - - // Setting up client - const rpcClient = await RPCClient.createRPCClient({ - manifest: { - test1: new UnaryCaller(), - test2: new UnaryCaller(), - }, - logger: logger.getChild('RPCClient'), - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - host, - port, - logger.getChild('Connection'), - ); - }, - }); - rpcClient_ = rpcClient; - - // Making the call - await expect( - rpcClient.methods.test1({ hello: 'world2' }), - ).resolves.toStrictEqual({ hello: 'world2' }); - await expect( - rpcClient.methods.test2({ hello: 'world2' }), - ).resolves.toStrictEqual({ hello: 'not world' }); - await expect( - rpcClient.unaryCaller('test3', { hello: 'world2' }), - ).toReject(); - }); - - test('Using uws', async () => { - tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - const server = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => logger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${server.port}`); - - const websocket1 = await clientRPCUtils.startConnection( - host, - server.port, - logger.getChild('Connection'), - ); - - // Const websocket2 = await clientRPCUtils.startConnection( - // host, - // server.port, - // logger.getChild('Connection'), - // ); - - logger.info('doing things'); - const writer1 = websocket1.writable.getWriter(); - // Const writer2 = websocket2.writable.getWriter(); - await writer1.write(Buffer.from('1request1')); - // Await writer2.write(Buffer.from('2request1')); - await writer1.write(Buffer.from('1request2')); - // Await writer2.write(Buffer.from('2request2')); - await writer1.close(); - // Await writer2.close(); - for await (const val of websocket1.readable) { - logger.info(`Client1 message: ${val.toString()}`); - } - // For await (const val of websocket2.readable) { - // logger.info(`Client2 message: ${val.toString()}`); - // } - logger.info('ending'); - await server.destroy(); - }); -}); From 52500cae71e7618ef44d1727abe0208c6942c113 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 21 Feb 2023 16:00:35 +1100 Subject: [PATCH 05/23] feat: client authenticates server [ci skip] --- src/clientRPC/ClientClient.ts | 33 ++- src/clientRPC/ClientServer.ts | 1 - src/clientRPC/utils.ts | 341 +++++++++++---------------- tests/clientRPC/ClientServer.test.ts | 121 +++++++++- tests/utils/utils.ts | 34 +++ 5 files changed, 309 insertions(+), 221 deletions(-) diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/ClientClient.ts index 2d42a391d..5780a3004 100644 --- a/src/clientRPC/ClientClient.ts +++ b/src/clientRPC/ClientClient.ts @@ -6,6 +6,7 @@ import { createDestroy } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import WebSocket from 'ws'; import { PromiseCancellable } from '@matrixai/async-cancellable'; +import * as clientRpcUtils from './utils'; import { promise } from '../utils'; interface ClientClient extends createDestroy.CreateDestroy {} @@ -14,13 +15,13 @@ class ClientClient { static async createClientClient({ host, port, - nodeId, + expectedNodeIds, maxReadableStreamBytes = 1000, // About 1kB logger = new Logger(this.name), }: { host: string; port: number; - nodeId: NodeId; + expectedNodeIds: Array; maxReadableStreamBytes?: number; logger?: Logger; }): Promise { @@ -30,7 +31,7 @@ class ClientClient { host, port, maxReadableStreamBytes, - nodeId, + expectedNodeIds, ); logger.info(`Created ${this.name}`); return clientClient; @@ -43,7 +44,7 @@ class ClientClient { protected host: string, protected port: number, protected maxReadableStreamBytes: number, - protected nodeId: NodeId, + protected expectedNodeIds: Array, ) {} public async destroy(force: boolean = false) { @@ -66,6 +67,7 @@ class ClientClient { const address = `wss://${this.host}:${this.port}`; this.logger.info(`Connecting to ${address}`); const connectProm = promise(); + const authenticateProm = promise(); const ws = new WebSocket(address, { rejectUnauthorized: false, }); @@ -83,25 +85,36 @@ class ClientClient { abortController.signal.addEventListener('abort', abortHandler); this.activeConnections.add(connectionProm); connectionProm.finally(() => this.activeConnections.delete(connectionProm)); - // Handle connection failure const openErrorHandler = (e) => { connectProm.rejectP(Error('TMP ERROR Connection failure', { cause: e })); }; ws.once('error', openErrorHandler); // Authenticate server's certificates - ws.once('upgrade', (request) => { + ws.once('upgrade', async (request) => { const tlsSocket = request.socket as TLSSocket; - const peerCert = tlsSocket.getPeerCertificate(); - // TODO: custom authentication here - this.logger.info(`server NodeId ${peerCert.issuer.CN}`); + const peerCert = tlsSocket.getPeerCertificate(true); + clientRpcUtils + .verifyServerCertificateChain( + this.expectedNodeIds, + clientRpcUtils.detailedToCertChain(peerCert), + ) + .then(authenticateProm.resolveP, authenticateProm.rejectP); }); ws.once('open', () => { this.logger.info('starting connection'); connectProm.resolveP(); }); // TODO: Race with a connection timeout here - await connectProm.p; + try { + await Promise.all([authenticateProm.p, connectProm.p]); + } catch (e) { + // Clean up + ws.close(); + await connectionProm; + throw e; + } + // Cleaning up connection error ws.removeEventListener('error', openErrorHandler); diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index 7c7c871d6..a385712ed 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -109,7 +109,6 @@ class ClientServer { await this.fs.promises.rm(certFile); this.server.ws('/*', { upgrade: (res, req, context) => { - // Req.forEach((k, v) => console.log(k, ':', v)); const logger = this.logger.getChild(`Connection ${count}`); res.upgrade>( { diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index e9b41f834..318aaa73c 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -3,18 +3,13 @@ import type KeyRing from '../keys/KeyRing'; import type SessionManager from '../sessions/SessionManager'; import type { RPCRequestParams } from './types'; import type { JsonRpcRequest } from '../RPC/types'; -import type { ReadableWritablePair } from 'stream/web'; -import type Logger from '@matrixai/logger'; -import type { ConnectionInfo, Host, Port } from '../network/types'; -import type RPCServer from '../RPC/RPCServer'; -import type { TLSSocket } from 'tls'; -import type { Server } from 'https'; -import type net from 'net'; -import type https from 'https'; -import { ReadableStream, WritableStream } from 'stream/web'; -import WebSocket, { WebSocketServer } from 'ws'; +import type { Certificate } from 'keys/types'; +import type { DetailedPeerCertificate } from 'tls'; +import type { NodeId } from 'ids/index'; +import * as x509 from '@peculiar/x509'; import * as clientErrors from '../client/errors'; -import { promise } from '../utils'; +import * as networkErrors from '../network/errors'; +import * as keysUtils from '../keys/utils/index'; async function authenticate( sessionManager: SessionManager, @@ -65,201 +60,149 @@ function encodeAuthFromPassword(password: string): string { return `Basic ${encoded}`; } -function readableFromWebSocket( - ws: WebSocket, - logger: Logger, -): ReadableStream { - return new ReadableStream({ - start: (controller) => { - logger.info('starting'); - const messageHandler = (data) => { - logger.debug(`message: ${data.toString()}`); - ws.pause(); - const message = data as Buffer; - if (message.length === 0) { - logger.info('ENDING'); - ws.removeAllListeners('message'); - try { - controller.close(); - } catch { - // Ignore already closed - } - return; - } - controller.enqueue(message); - }; - ws.on('message', messageHandler); - ws.once('close', () => { - logger.info('closed'); - ws.removeListener('message', messageHandler); - try { - controller.close(); - } catch { - // Ignore already closed - } - }); - ws.once('error', (e) => { - controller.error(e); - }); - }, - cancel: () => { - logger.info('cancelled'); - ws.close(); - }, - pull: () => { - logger.debug('resuming'); - ws.resume(); - }, - }); +function detailedToCertChain( + cert: DetailedPeerCertificate, +): Array { + const certChain: Array = []; + let currentCert = cert; + while (true) { + certChain.unshift(new x509.X509Certificate(currentCert.raw)); + if (currentCert === currentCert.issuerCertificate) break; + currentCert = currentCert.issuerCertificate; + } + return certChain; } -function writeableFromWebSocket( - ws: WebSocket, - holdOpen: boolean, - logger: Logger, -): WritableStream { - return new WritableStream({ - start: (controller) => { - logger.info('starting'); - ws.once('error', (e) => { - logger.error(`error: ${e}`); - controller.error(e); - }); - ws.once('close', (code, reason) => { - logger.info( - `ws closing early! with code: ${code} and reason: ${reason.toString()}`, +/** + * Verify the server certificate chain when connecting to it from a client + * This is a custom verification intended to verify that the server owned + * the relevant NodeId. + * It is possible that the server has a new NodeId. In that case we will + * verify that the new NodeId is the true descendant of the target NodeId. + */ +async function verifyServerCertificateChain( + nodeIds: Array, + certChain: Array, +): Promise { + if (!certChain.length) { + throw new networkErrors.ErrorCertChainEmpty( + 'No certificates available to verify', + ); + } + if (!nodeIds.length) { + throw new networkErrors.ErrorConnectionNodesEmpty( + 'No nodes were provided to verify against', + ); + } + const now = new Date(); + let certClaim: Certificate | null = null; + let certClaimIndex: number | null = null; + let verifiedNodeId: NodeId | null = null; + for (let certIndex = 0; certIndex < certChain.length; certIndex++) { + const cert = certChain[certIndex]; + if (now < cert.notBefore || now > cert.notAfter) { + throw new networkErrors.ErrorCertChainDateInvalid( + 'Chain certificate date is invalid', + { + data: { + cert, + certIndex, + notBefore: cert.notBefore, + notAfter: cert.notAfter, + now, + }, + }, + ); + } + const certNodeId = keysUtils.certNodeId(cert); + if (certNodeId == null) { + throw new networkErrors.ErrorCertChainNameInvalid( + 'Chain certificate common name attribute is missing', + { + data: { + cert, + certIndex, + }, + }, + ); + } + const certPublicKey = keysUtils.certPublicKey(cert); + if (certPublicKey == null) { + throw new networkErrors.ErrorCertChainKeyInvalid( + 'Chain certificate public key is missing', + { + data: { + cert, + certIndex, + }, + }, + ); + } + if (!(await keysUtils.certNodeSigned(cert))) { + throw new networkErrors.ErrorCertChainSignatureInvalid( + 'Chain certificate does not have a valid node-signature', + { + data: { + cert, + certIndex, + nodeId: keysUtils.publicKeyToNodeId(certPublicKey), + commonName: certNodeId, + }, + }, + ); + } + for (const nodeId of nodeIds) { + if (certNodeId.equals(nodeId)) { + // Found the certificate claiming the nodeId + certClaim = cert; + certClaimIndex = certIndex; + verifiedNodeId = nodeId; + } + } + // If cert is found then break out of loop + if (verifiedNodeId != null) break; + } + if (certClaimIndex == null || certClaim == null || verifiedNodeId == null) { + throw new networkErrors.ErrorCertChainUnclaimed( + 'Node IDs is not claimed by any certificate', + { + data: { nodeIds }, + }, + ); + } + if (certClaimIndex > 0) { + let certParent: Certificate; + let certChild: Certificate; + for (let certIndex = certClaimIndex; certIndex > 0; certIndex--) { + certParent = certChain[certIndex]; + certChild = certChain[certIndex - 1]; + if ( + !keysUtils.certIssuedBy(certParent, certChild) || + !(await keysUtils.certSignedBy( + certParent, + keysUtils.certPublicKey(certChild)!, + )) + ) { + throw new networkErrors.ErrorCertChainBroken( + 'Chain certificate is not signed by parent certificate', + { + data: { + cert: certChild, + certIndex: certIndex - 1, + certParent, + }, + }, ); - controller.error(Error('TMP WebSocket Closed early')); - }); - }, - close: () => { - logger.info('stream closing'); - ws.send(Buffer.from([])); - if (!holdOpen) ws.terminate(); - }, - abort: () => { - logger.info('aborting'); - ws.close(); - }, - write: async (chunk, controller) => { - // Logger.debug(`writing: ${chunk?.toString()}`); - const wait = promise(); - ws.send(chunk, (e) => { - if (e != null) { - controller.error(e); - } - wait.resolveP(); - }); - await wait.p; - }, - }); -} - -function webSocketToWebStreamPair( - ws: WebSocket, - holdOpen: boolean, - logger: Logger, -): ReadableWritablePair { - return { - readable: readableFromWebSocket(ws, logger.getChild('readable')), - writable: writeableFromWebSocket(ws, holdOpen, logger.getChild('writable')), - }; -} - -function startConnection( - host: string, - port: number, - logger: Logger, -): Promise> { - const ws = new WebSocket(`wss://${host}:${port}`, { - // CheckServerIdentity: ( - // servername: string, - // cert: WebSocket.CertMeta, - // ): boolean => { - // console.log('CHECKING IDENTITY'); - // console.log(servername); - // console.log(cert); - // return false; - // }, - rejectUnauthorized: false, - // Ca: tlsConfig.certChainPem - }); - ws.once('close', () => logger.info('CLOSED')); - // Ws.once('upgrade', (request) => { - // const tlsSocket = request.socket as TLSSocket; - // const peerCert = tlsSocket.getPeerCertificate(); - // console.log(peerCert.issuer.CN); - // logger.info('Test early cancellation'); - // // Request.destroy(Error('some error')); - // // tlsSocket.destroy(Error('some error')); - // // ws.close(12345, 'some reason'); - // // TODO: Use the existing verify method from the GRPC implementation - // // TODO: Have this emit an error on verification failure. - // // It's fine for the server side to close abruptly without error - // }); - const prom = promise>(); - ws.once('open', () => { - logger.info('starting connection'); - prom.resolveP(webSocketToWebStreamPair(ws, true, logger)); - }); - return prom.p; -} - -function handleConnection(ws: WebSocket, logger: Logger): void { - ws.once('close', () => logger.info('CLOSED')); - const readable = readableFromWebSocket(ws, logger.getChild('readable')); - const writable = writeableFromWebSocket( - ws, - false, - logger.getChild('writable'), - ); - void readable.pipeTo(writable).catch((e) => logger.error(e)); -} - -function createClientServer( - server: Server, - rpcServer: RPCServer, - logger: Logger, -) { - logger.info('created server'); - const wss = new WebSocketServer({ - server, - }); - wss.on('error', (e) => logger.error(e)); - logger.info('created wss'); - wss.on('connection', (ws, req) => { - logger.info('connection!'); - const socket = req.socket as TLSSocket; - const streamPair = webSocketToWebStreamPair(ws, false, logger); - rpcServer.handleStream(streamPair, { - localHost: socket.localAddress! as Host, - localPort: socket.localPort! as Port, - remoteCertificates: socket.getPeerCertificate(), - remoteHost: socket.remoteAddress! as Host, - remotePort: socket.remotePort! as Port, - } as unknown as ConnectionInfo); - }); - wss.once('close', () => { - wss.removeAllListeners('error'); - wss.removeAllListeners('connection'); - }); - return wss; -} - -async function listen(server: https.Server, host?: string, port?: number) { - await new Promise((resolve) => { - server.listen(port, host ?? '127.0.0.1', undefined, () => resolve()); - }); - const addressInfo = server.address() as net.AddressInfo; - return addressInfo.port; + } + } + } + return verifiedNodeId; } export { authenticate, decodeAuth, encodeAuthFromPassword, - startConnection, - handleConnection, - createClientServer, - listen, + detailedToCertChain, + verifyServerCertificateChain, }; diff --git a/tests/clientRPC/ClientServer.test.ts b/tests/clientRPC/ClientServer.test.ts index 735e1034f..68dd1a3e0 100644 --- a/tests/clientRPC/ClientServer.test.ts +++ b/tests/clientRPC/ClientServer.test.ts @@ -1,6 +1,7 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; import type { WebSocket } from 'uWebSockets.js'; +import type { KeyPair } from '@/keys/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -10,6 +11,9 @@ import { KeyRing } from '@/keys/index'; import ClientServer from '@/clientRPC/ClientServer'; import { promise } from '@/utils'; import ClientClient from '@/clientRPC/ClientClient'; +import * as keysUtils from '@/keys/utils'; +import * as networkErrors from '@/network/errors'; +import * as testNodeUtils from '../nodes/utils'; import * as testsUtils from '../utils'; describe('ClientServer', () => { @@ -93,7 +97,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -128,7 +132,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -159,7 +163,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); @@ -215,7 +219,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -254,7 +258,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -291,7 +295,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -319,7 +323,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -346,7 +350,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -408,7 +412,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -459,7 +463,7 @@ describe('ClientServer', () => { clientClient = await ClientClient.createClientClient({ host, port: clientServer.port, - nodeId: keyRing.getNodeId(), + expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); @@ -481,5 +485,100 @@ describe('ClientServer', () => { }); test.todo('client ends connection abruptly'); test.todo('Server ends connection abruptly'); - test.todo('Client rejects bad server certificate'); + test('Client rejects bad server certificate', async () => { + const invalidNodeId = testNodeUtils.generateRandomNodeId(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [invalidNodeId], + logger: logger.getChild('clientClient'), + }); + await expect(clientClient.startConnection()).rejects.toThrow( + networkErrors.ErrorCertChainUnclaimed, + ); + // @ts-ignore: kidnap protected property + const activeConnections = clientClient.activeConnections; + expect(activeConnections.size).toBe(0); + logger.info('ending'); + }); + test('Client authenticates with multiple certs in chain', async () => { + const keyPairs: Array = [ + keyRing.keyPair, + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + ]; + const tlsConfig = await testsUtils.createTLSConfigWithChain(keyPairs); + const nodeId = keysUtils.publicKeyToNodeId(keyPairs[1].publicKey); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [nodeId], + logger: logger.getChild('clientClient'), + }); + const connProm = clientClient.startConnection(); + await connProm; + await expect(connProm).toResolve(); + // @ts-ignore: kidnap protected property + const activeConnections = clientClient.activeConnections; + expect(activeConnections.size).toBe(1); + logger.info('ending'); + }); + test('Client authenticates with multiple expected nodes', async () => { + const alternativeNodeId = testNodeUtils.generateRandomNodeId(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], + logger: logger.getChild('clientClient'), + }); + await expect(clientClient.startConnection()).toResolve(); + // @ts-ignore: kidnap protected property + const activeConnections = clientClient.activeConnections; + expect(activeConnections.size).toBe(1); + logger.info('ending'); + }); }); diff --git a/tests/utils/utils.ts b/tests/utils/utils.ts index 79c529794..a3d7e8ac0 100644 --- a/tests/utils/utils.ts +++ b/tests/utils/utils.ts @@ -3,6 +3,7 @@ import type { NodeId, CertId } from '@/ids/types'; import type { StatusLive } from '@/status/types'; import type { TLSConfig } from '@/network/types'; import type { CertificatePEMChain, KeyPair } from '@/keys/types'; +import type { Certificate } from '@/keys/types'; import path from 'path'; import fs from 'fs'; import readline from 'readline'; @@ -126,6 +127,38 @@ async function createTLSConfig( }; } +async function createTLSConfigWithChain( + keyPairs: Array, + generateCertId?: () => CertId, +): Promise { + if (keyPairs.length === 0) throw Error('Must have at least 1 keypair'); + generateCertId = generateCertId ?? keysUtils.createCertIdGenerator(); + let previousCert: Certificate | null = null; + let previousKeyPair: KeyPair | null = null; + const certChain: Array = []; + for (const keyPair of keyPairs) { + const newCert = await keysUtils.generateCertificate({ + certId: generateCertId(), + duration: 31536000, + issuerPrivateKey: previousKeyPair?.privateKey ?? keyPair.privateKey, + subjectKeyPair: keyPair, + issuerAttrsExtra: previousCert?.subjectName.toJSON(), + }); + certChain.unshift(newCert); + previousCert = newCert; + previousKeyPair = keyPair; + } + let certChainPEM = ''; + for (const certificate of certChain) { + certChainPEM += keysUtils.certToPEM(certificate); + } + + return { + keyPrivatePem: keysUtils.privateKeyToPEM(previousKeyPair!.privateKey), + certChainPem: certChainPEM as CertificatePEMChain, + }; +} + export { setupTestAgent, generateRandomNodeId, @@ -133,4 +166,5 @@ export { testIf, describeIf, createTLSConfig, + createTLSConfigWithChain, }; From eba92b2f494bba3cbb2a9781291cf3a27a9752db Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 21 Feb 2023 16:21:00 +1100 Subject: [PATCH 06/23] tests: re-ordering tests [ci skip] --- tests/clientRPC/ClientServer.test.ts | 584 -------------------------- tests/clientRPC/clientRPC.test.ts | 597 +++++++++++++++++++++++++++ 2 files changed, 597 insertions(+), 584 deletions(-) delete mode 100644 tests/clientRPC/ClientServer.test.ts create mode 100644 tests/clientRPC/clientRPC.test.ts diff --git a/tests/clientRPC/ClientServer.test.ts b/tests/clientRPC/ClientServer.test.ts deleted file mode 100644 index 68dd1a3e0..000000000 --- a/tests/clientRPC/ClientServer.test.ts +++ /dev/null @@ -1,584 +0,0 @@ -import type { ReadableWritablePair } from 'stream/web'; -import type { TLSConfig } from '@/network/types'; -import type { WebSocket } from 'uWebSockets.js'; -import type { KeyPair } from '@/keys/types'; -import fs from 'fs'; -import path from 'path'; -import os from 'os'; -import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; -import { testProp, fc } from '@fast-check/jest'; -import { KeyRing } from '@/keys/index'; -import ClientServer from '@/clientRPC/ClientServer'; -import { promise } from '@/utils'; -import ClientClient from '@/clientRPC/ClientClient'; -import * as keysUtils from '@/keys/utils'; -import * as networkErrors from '@/network/errors'; -import * as testNodeUtils from '../nodes/utils'; -import * as testsUtils from '../utils'; - -describe('ClientServer', () => { - const logger = new Logger('websocket test', LogLevel.WARN, [ - new StreamHandler( - formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, - ), - ]); - const loudLogger = new Logger('websocket test', LogLevel.WARN, [ - new StreamHandler( - formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, - ), - ]); - - let dataDir: string; - let keyRing: KeyRing; - let tlsConfig: TLSConfig; - const host = '127.0.0.2'; - let clientServer: ClientServer; - let clientClient: ClientClient; - - const messagesArb = fc.array( - fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), - ); - const streamsArb = fc.array(messagesArb, { minLength: 1 }).noShrink(); - const asyncReadWrite = async ( - messages: Array, - streampair: ReadableWritablePair, - ) => { - await Promise.allSettled([ - (async () => { - const writer = streampair.writable.getWriter(); - for (const message of messages) { - await writer.write(message); - } - await writer.close(); - })(), - (async () => { - for await (const _ of streampair.readable) { - // No touch, only consume - } - })(), - ]); - }; - - beforeEach(async () => { - dataDir = await fs.promises.mkdtemp( - path.join(os.tmpdir(), 'polykey-test-'), - ); - const keysPath = path.join(dataDir, 'keys'); - keyRing = await KeyRing.createKeyRing({ - keysPath: keysPath, - password: 'password', - logger: logger.getChild('keyRing'), - }); - tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - }); - afterEach(async () => { - logger.info('AFTEREACH'); - await clientServer.stop(true); - await clientClient.destroy(); - await keyRing.stop(); - await fs.promises.rm(dataDir, { force: true, recursive: true }); - }); - - test('makes a connection', async () => { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - - const writer = websocket.writable.getWriter(); - const reader = websocket.readable.getReader(); - const message1 = Buffer.from('1request1'); - await writer.write(message1); - expect((await reader.read()).value).toStrictEqual(message1); - const message2 = Buffer.from('1request2'); - await writer.write(message2); - expect((await reader.read()).value).toStrictEqual(message2); - await writer.close(); - expect((await reader.read()).done).toBeTrue(); - logger.info('ending'); - }); - test('Handles a connection and closes before message', async () => { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => logger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await websocket.writable.close(); - const reader = websocket.readable.getReader(); - expect((await reader.read()).done).toBeTrue(); - logger.info('ending'); - }); - testProp( - 'Handles multiple connections', - [streamsArb], - async (streamsData) => { - try { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => logger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - - const testStream = async (messages: Array) => { - const websocket = await clientClient.startConnection(); - const writer = websocket.writable.getWriter(); - const reader = websocket.readable.getReader(); - for (const message of messages) { - await writer.write(message); - const response = await reader.read(); - expect(response.done).toBeFalse(); - expect(response.value?.toString()).toStrictEqual( - message.toString(), - ); - } - await writer.close(); - expect((await reader.read()).done).toBeTrue(); - }; - const streams = streamsData.map((messages) => testStream(messages)); - await Promise.all(streams); - - logger.info('ending'); - } finally { - await clientServer.stop(true); - } - }, - ); - testProp( - 'allows half closed writable closes first', - [messagesArb, messagesArb], - async (messages1, messages2) => { - try { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void (async () => { - const writer = streamPair.writable.getWriter(); - for await (const val of messages2) { - await writer.write(val); - } - await writer.close(); - for await (const _ of streamPair.readable) { - // No touch, only consume - } - })().catch((e) => logger.error(e)); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await asyncReadWrite(messages1, websocket); - logger.info('ending'); - } finally { - await clientServer.stop(true); - } - }, - ); - testProp( - 'allows half closed readable closes first', - [messagesArb, messagesArb], - async (messages1, messages2) => { - try { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void (async () => { - for await (const _ of streamPair.readable) { - // No touch, only consume - } - const writer = streamPair.writable.getWriter(); - for await (const val of messages2) { - await writer.write(val); - } - await writer.close(); - })().catch((e) => logger.error(e)); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await asyncReadWrite(messages1, websocket); - logger.info('ending'); - } finally { - await clientServer.stop(true); - } - }, - ); - testProp( - 'handles early close of readable', - [messagesArb, messagesArb], - async (messages1, messages2) => { - try { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void (async () => { - await streamPair.readable.cancel(); - const writer = streamPair.writable.getWriter(); - for await (const val of messages2) { - await writer.write(val); - } - await writer.close(); - })().catch((e) => logger.error(e)); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await asyncReadWrite(messages1, websocket); - logger.info('ending'); - } finally { - await clientServer.stop(true); - } - }, - ); - test('Destroying ClientServer stops all connections', async () => { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch((e) => logger.error(e)); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await clientServer.stop(true); - for await (const _ of websocket.readable) { - // No touch, only consume - } - logger.info('ending'); - }); - test('Destroying ClientClient stops all connections', async () => { - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable.pipeTo(streamPair.writable).catch((e) => { - logger.error(e); - }); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await clientClient.destroy(true); - for await (const _ of websocket.readable) { - // No touch, only consume - } - await clientServer.stop(); - logger.info('ending'); - }); - test('Writable backpressure', async () => { - let context: { writeBackpressure: boolean } | undefined; - const backpressure = promise(); - const resumeWriting = promise(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void Promise.allSettled([ - (async () => { - for await (const _ of streamPair.readable) { - // No touch, only consume - } - })(), - (async () => { - // Kidnap the context - let ws: WebSocket<{ writeBackpressure: boolean }> | null = null; - // @ts-ignore: kidnap protected property - for (const websocket of clientServer.activeSockets.values()) { - ws = websocket; - } - if (ws == null) { - await streamPair.writable.close(); - return; - } - context = ws.getUserData(); - // Write until backPressured - const message = Buffer.alloc(128, 0xf0); - const writer = streamPair.writable.getWriter(); - while (!context.writeBackpressure) { - await writer.write(message); - } - loudLogger.info('BACK PRESSURED'); - backpressure.resolveP(); - await resumeWriting.p; - for (let i = 0; i < 100; i++) { - await writer.write(message); - } - await writer.close(); - loudLogger.info('WRITING ENDED'); - })(), - ]).catch((e) => logger.error(e.toString())); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await websocket.writable.close(); - - await backpressure.p; - expect(context?.writeBackpressure).toBeTrue(); - resumeWriting.resolveP(); - // Consume all of the back-pressured data - for await (const _ of websocket.readable) { - // No touch, only consume - } - expect(context?.writeBackpressure).toBeFalse(); - loudLogger.info('ending'); - }); - // Readable backpressure is not actually supported. We're dealing with it by - // using an buffer with a provided limit that can be very large. - test('Exceeding readable buffer limit causes error', async () => { - const startReading = promise(); - const handlingProm = promise(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - Promise.all([ - (async () => { - await startReading.p; - loudLogger.info('Starting consumption'); - for await (const _ of streamPair.readable) { - // No touch, only consume - } - loudLogger.info('Reads ended'); - })(), - (async () => { - await streamPair.writable.close(); - })(), - ]) - .catch(() => {}) - .finally(() => handlingProm.resolveP()); - }, - basePath: dataDir, - tlsConfig, - host, - // Setting a really low buffer limit - maxReadBufferBytes: 1500, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - const message = Buffer.alloc(1_000, 0xf0); - const writer = websocket.writable.getWriter(); - loudLogger.info('Starting writes'); - await expect(async () => { - for (let i = 0; i < 100; i++) { - await writer.write(message); - } - }).rejects.toThrow(); - startReading.resolveP(); - loudLogger.info('writes ended'); - for await (const _ of websocket.readable) { - // No touch, only consume - } - await handlingProm.p; - loudLogger.info('ending'); - }); - test.todo('client ends connection abruptly'); - test.todo('Server ends connection abruptly'); - test('Client rejects bad server certificate', async () => { - const invalidNodeId = testNodeUtils.generateRandomNodeId(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [invalidNodeId], - logger: logger.getChild('clientClient'), - }); - await expect(clientClient.startConnection()).rejects.toThrow( - networkErrors.ErrorCertChainUnclaimed, - ); - // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; - expect(activeConnections.size).toBe(0); - logger.info('ending'); - }); - test('Client authenticates with multiple certs in chain', async () => { - const keyPairs: Array = [ - keyRing.keyPair, - keysUtils.generateKeyPair(), - keysUtils.generateKeyPair(), - keysUtils.generateKeyPair(), - keysUtils.generateKeyPair(), - ]; - const tlsConfig = await testsUtils.createTLSConfigWithChain(keyPairs); - const nodeId = keysUtils.publicKeyToNodeId(keyPairs[1].publicKey); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [nodeId], - logger: logger.getChild('clientClient'), - }); - const connProm = clientClient.startConnection(); - await connProm; - await expect(connProm).toResolve(); - // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; - expect(activeConnections.size).toBe(1); - logger.info('ending'); - }); - test('Client authenticates with multiple expected nodes', async () => { - const alternativeNodeId = testNodeUtils.generateRandomNodeId(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], - logger: logger.getChild('clientClient'), - }); - await expect(clientClient.startConnection()).toResolve(); - // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; - expect(activeConnections.size).toBe(1); - logger.info('ending'); - }); -}); diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts new file mode 100644 index 000000000..c4e26f020 --- /dev/null +++ b/tests/clientRPC/clientRPC.test.ts @@ -0,0 +1,597 @@ +import type { ReadableWritablePair } from 'stream/web'; +import type { TLSConfig } from '@/network/types'; +import type { WebSocket } from 'uWebSockets.js'; +import type { KeyPair } from '@/keys/types'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; +import { testProp, fc } from '@fast-check/jest'; +import { KeyRing } from '@/keys/index'; +import ClientServer from '@/clientRPC/ClientServer'; +import { promise } from '@/utils'; +import ClientClient from '@/clientRPC/ClientClient'; +import * as keysUtils from '@/keys/utils'; +import * as networkErrors from '@/network/errors'; +import * as testNodeUtils from '../nodes/utils'; +import * as testsUtils from '../utils'; + +// This file tests both the client and server together. They're too interlinked +// to be separate. +describe('ClientRPC', () => { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + const loudLogger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + + let dataDir: string; + let keyRing: KeyRing; + let tlsConfig: TLSConfig; + const host = '127.0.0.2'; + let clientServer: ClientServer; + let clientClient: ClientClient; + + const messagesArb = fc.array( + fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), + ); + const streamsArb = fc.array(messagesArb, { minLength: 1 }).noShrink(); + const asyncReadWrite = async ( + messages: Array, + streamPair: ReadableWritablePair, + ) => { + await Promise.allSettled([ + (async () => { + const writer = streamPair.writable.getWriter(); + for (const message of messages) { + await writer.write(message); + } + await writer.close(); + })(), + (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })(), + ]); + }; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + keyRing = await KeyRing.createKeyRing({ + keysPath: keysPath, + password: 'password', + logger: logger.getChild('keyRing'), + }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + }); + afterEach(async () => { + logger.info('AFTEREACH'); + await clientServer.stop(true); + await clientClient.destroy(); + await keyRing.stop(); + await fs.promises.rm(dataDir, { force: true, recursive: true }); + }); + + // These tests are share between client and server + test('makes a connection', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + const message1 = Buffer.from('1request1'); + await writer.write(message1); + expect((await reader.read()).value).toStrictEqual(message1); + const message2 = Buffer.from('1request2'); + await writer.write(message2); + expect((await reader.read()).value).toStrictEqual(message2); + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); + test('Handles a connection and closes before message', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await websocket.writable.close(); + const reader = websocket.readable.getReader(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); + testProp( + 'Handles multiple connections', + [streamsArb], + async (streamsData) => { + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + + const testStream = async (messages: Array) => { + const websocket = await clientClient.startConnection(); + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + for (const message of messages) { + await writer.write(message); + const response = await reader.read(); + expect(response.done).toBeFalse(); + expect(response.value?.toString()).toStrictEqual( + message.toString(), + ); + } + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + }; + const streams = streamsData.map((messages) => testStream(messages)); + await Promise.all(streams); + + logger.info('ending'); + } finally { + await clientServer.stop(true); + } + }, + ); + test.todo('client ends connection abruptly'); + test.todo('Server ends connection abruptly'); + + // These describe blocks contains tests specific to either the client or server + describe('ClientServer', () => { + testProp( + 'allows half closed writable closes first', + [messagesArb, messagesArb], + async (messages1, messages2) => { + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await clientServer.stop(true); + } + }, + ); + testProp( + 'allows half closed readable closes first', + [messagesArb, messagesArb], + async (messages1, messages2) => { + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await clientServer.stop(true); + } + }, + ); + testProp( + 'handles early close of readable', + [messagesArb, messagesArb], + async (messages1, messages2) => { + try { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void (async () => { + await streamPair.readable.cancel(); + const writer = streamPair.writable.getWriter(); + for await (const val of messages2) { + await writer.write(val); + } + await writer.close(); + })().catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await asyncReadWrite(messages1, websocket); + logger.info('ending'); + } finally { + await clientServer.stop(true); + } + }, + ); + test('Destroying ClientServer stops all connections', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch((e) => logger.error(e)); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await clientServer.stop(true); + for await (const _ of websocket.readable) { + // No touch, only consume + } + logger.info('ending'); + }); + test('Writable backpressure', async () => { + let context: { writeBackpressure: boolean } | undefined; + const backpressure = promise(); + const resumeWriting = promise(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void Promise.allSettled([ + (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })(), + (async () => { + // Kidnap the context + let ws: WebSocket<{ writeBackpressure: boolean }> | null = null; + // @ts-ignore: kidnap protected property + for (const websocket of clientServer.activeSockets.values()) { + ws = websocket; + } + if (ws == null) { + await streamPair.writable.close(); + return; + } + context = ws.getUserData(); + // Write until backPressured + const message = Buffer.alloc(128, 0xf0); + const writer = streamPair.writable.getWriter(); + while (!context.writeBackpressure) { + await writer.write(message); + } + loudLogger.info('BACK PRESSURED'); + backpressure.resolveP(); + await resumeWriting.p; + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + await writer.close(); + loudLogger.info('WRITING ENDED'); + })(), + ]).catch((e) => logger.error(e.toString())); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await websocket.writable.close(); + + await backpressure.p; + expect(context?.writeBackpressure).toBeTrue(); + resumeWriting.resolveP(); + // Consume all of the back-pressured data + for await (const _ of websocket.readable) { + // No touch, only consume + } + expect(context?.writeBackpressure).toBeFalse(); + loudLogger.info('ending'); + }); + // Readable backpressure is not actually supported. We're dealing with it by + // using an buffer with a provided limit that can be very large. + test('Exceeding readable buffer limit causes error', async () => { + const startReading = promise(); + const handlingProm = promise(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + Promise.all([ + (async () => { + await startReading.p; + loudLogger.info('Starting consumption'); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + loudLogger.info('Reads ended'); + })(), + (async () => { + await streamPair.writable.close(); + })(), + ]) + .catch(() => {}) + .finally(() => handlingProm.resolveP()); + }, + basePath: dataDir, + tlsConfig, + host, + // Setting a really low buffer limit + maxReadBufferBytes: 1500, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + const message = Buffer.alloc(1_000, 0xf0); + const writer = websocket.writable.getWriter(); + loudLogger.info('Starting writes'); + await expect(async () => { + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + }).rejects.toThrow(); + startReading.resolveP(); + loudLogger.info('writes ended'); + for await (const _ of websocket.readable) { + // No touch, only consume + } + await handlingProm.p; + loudLogger.info('ending'); + }); + }); + + describe('ClientClient', () => { + test('Destroying ClientClient stops all connections', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable.pipeTo(streamPair.writable).catch((e) => { + logger.error(e); + }); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await clientClient.destroy(true); + for await (const _ of websocket.readable) { + // No touch, only consume + } + await clientServer.stop(); + logger.info('ending'); + }); + test('Authentication rejects bad server certificate', async () => { + const invalidNodeId = testNodeUtils.generateRandomNodeId(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [invalidNodeId], + logger: logger.getChild('clientClient'), + }); + await expect(clientClient.startConnection()).rejects.toThrow( + networkErrors.ErrorCertChainUnclaimed, + ); + // @ts-ignore: kidnap protected property + const activeConnections = clientClient.activeConnections; + expect(activeConnections.size).toBe(0); + logger.info('ending'); + }); + test('Authenticates with multiple certs in chain', async () => { + const keyPairs: Array = [ + keyRing.keyPair, + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + keysUtils.generateKeyPair(), + ]; + const tlsConfig = await testsUtils.createTLSConfigWithChain(keyPairs); + const nodeId = keysUtils.publicKeyToNodeId(keyPairs[1].publicKey); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [nodeId], + logger: logger.getChild('clientClient'), + }); + const connProm = clientClient.startConnection(); + await connProm; + await expect(connProm).toResolve(); + // @ts-ignore: kidnap protected property + const activeConnections = clientClient.activeConnections; + expect(activeConnections.size).toBe(1); + logger.info('ending'); + }); + test('Authenticates with multiple expected nodes', async () => { + const alternativeNodeId = testNodeUtils.generateRandomNodeId(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], + logger: logger.getChild('clientClient'), + }); + await expect(clientClient.startConnection()).toResolve(); + // @ts-ignore: kidnap protected property + const activeConnections = clientClient.activeConnections; + expect(activeConnections.size).toBe(1); + logger.info('ending'); + }); + test.todo('Writable backpressure'); + test.todo('Readable backpressure'); + test.todo('Connection times out'); + }); +}); From e092e64b3179547eae9180ad6336ee1e6a17d027 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 21 Feb 2023 17:59:47 +1100 Subject: [PATCH 07/23] tests: tests for abrupt connection ending [ci skip] --- tests/clientRPC/clientRPC.test.ts | 345 +++++++++++++++++++----------- 1 file changed, 217 insertions(+), 128 deletions(-) diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index c4e26f020..84e9fc466 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -194,8 +194,212 @@ describe('ClientRPC', () => { } }, ); - test.todo('client ends connection abruptly'); - test.todo('Server ends connection abruptly'); + test('reverse backpressure', async () => { + let context: { writeBackpressure: boolean } | undefined; + const backpressure = promise(); + const resumeWriting = promise(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void Promise.allSettled([ + (async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + })(), + (async () => { + // Kidnap the context + let ws: WebSocket<{ writeBackpressure: boolean }> | null = null; + // @ts-ignore: kidnap protected property + for (const websocket of clientServer.activeSockets.values()) { + ws = websocket; + } + if (ws == null) { + await streamPair.writable.close(); + return; + } + context = ws.getUserData(); + // Write until backPressured + const message = Buffer.alloc(128, 0xf0); + const writer = streamPair.writable.getWriter(); + while (!context.writeBackpressure) { + await writer.write(message); + } + loudLogger.info('BACK PRESSURED'); + backpressure.resolveP(); + await resumeWriting.p; + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + await writer.close(); + loudLogger.info('WRITING ENDED'); + })(), + ]).catch((e) => logger.error(e.toString())); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await websocket.writable.close(); + + await backpressure.p; + expect(context?.writeBackpressure).toBeTrue(); + resumeWriting.resolveP(); + // Consume all of the back-pressured data + for await (const _ of websocket.readable) { + // No touch, only consume + } + expect(context?.writeBackpressure).toBeFalse(); + loudLogger.info('ending'); + }); + // Readable backpressure is not actually supported. We're dealing with it by + // using an buffer with a provided limit that can be very large. + test('Exceeding readable buffer limit causes error', async () => { + const startReading = promise(); + const handlingProm = promise(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + Promise.all([ + (async () => { + await startReading.p; + loudLogger.info('Starting consumption'); + for await (const _ of streamPair.readable) { + // No touch, only consume + } + loudLogger.info('Reads ended'); + })(), + (async () => { + await streamPair.writable.close(); + })(), + ]) + .catch(() => {}) + .finally(() => handlingProm.resolveP()); + }, + basePath: dataDir, + tlsConfig, + host, + // Setting a really low buffer limit + maxReadBufferBytes: 1500, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + const message = Buffer.alloc(1_000, 0xf0); + const writer = websocket.writable.getWriter(); + loudLogger.info('Starting writes'); + await expect(async () => { + for (let i = 0; i < 100; i++) { + await writer.write(message); + } + }).rejects.toThrow(); + startReading.resolveP(); + loudLogger.info('writes ended'); + for await (const _ of websocket.readable) { + // No touch, only consume + } + await handlingProm.p; + loudLogger.info('ending'); + }); + // To fully test these two I need to start the client or server in a separate process and kill that process. + // These require the ping/pong connection watchdogs to be implemented. + test.skip('client ends connection abruptly', async () => { + const handlerProm = promise(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .then(handlerProm.resolveP, handlerProm.rejectP) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + // @ts-ignore: kidnap protected property + const activeConnections = clientClient.activeConnections; + for (const activeConnection of activeConnections) { + activeConnection.cancel(); + } + await expect(handlerProm.p).toResolve(); + for await (const _ of websocket.readable) { + // Do nothing + } + logger.info('ending'); + }); + test.skip('Server ends connection abruptly', async () => { + const streamPairProm = + promise>(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair_) => { + logger.info('inside callback'); + // Don't do anything with the handler + streamPairProm.resolveP(streamPair_); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + // @ts-ignore: kidnap protected property + const activeSockets = clientServer.activeSockets; + for (const activeSocket of activeSockets) { + activeSocket.close(); + } + const streamPair = await streamPairProm.p; + // Expect both readable to throw + const handlerReadProm = (async () => { + for await (const _ of streamPair.readable) { + // Do nothing + } + })(); + await expect(handlerReadProm).rejects.toThrow(); + const clientReadProm = (async () => { + for await (const _ of websocket.readable) { + // Do nothing + } + })(); + await expect(clientReadProm).rejects.toThrow(); + // Both writables should throw. + const handlerWritable = streamPair.writable.getWriter(); + await expect(handlerWritable.write(Buffer.from('test'))).rejects.toThrow(); + const clientWritable = websocket.writable.getWriter(); + await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); + logger.info('ending'); + }); // These describe blocks contains tests specific to either the client or server describe('ClientServer', () => { @@ -315,12 +519,13 @@ describe('ClientRPC', () => { }, ); test('Destroying ClientServer stops all connections', async () => { + const handlerProm = promise(); clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable .pipeTo(streamPair.writable) - .catch((e) => logger.error(e)); + .then(handlerProm.resolveP, handlerProm.rejectP); }, basePath: dataDir, tlsConfig, @@ -336,132 +541,16 @@ describe('ClientRPC', () => { }); const websocket = await clientClient.startConnection(); await clientServer.stop(true); - for await (const _ of websocket.readable) { - // No touch, only consume - } - logger.info('ending'); - }); - test('Writable backpressure', async () => { - let context: { writeBackpressure: boolean } | undefined; - const backpressure = promise(); - const resumeWriting = promise(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void Promise.allSettled([ - (async () => { - for await (const _ of streamPair.readable) { - // No touch, only consume - } - })(), - (async () => { - // Kidnap the context - let ws: WebSocket<{ writeBackpressure: boolean }> | null = null; - // @ts-ignore: kidnap protected property - for (const websocket of clientServer.activeSockets.values()) { - ws = websocket; - } - if (ws == null) { - await streamPair.writable.close(); - return; - } - context = ws.getUserData(); - // Write until backPressured - const message = Buffer.alloc(128, 0xf0); - const writer = streamPair.writable.getWriter(); - while (!context.writeBackpressure) { - await writer.write(message); - } - loudLogger.info('BACK PRESSURED'); - backpressure.resolveP(); - await resumeWriting.p; - for (let i = 0; i < 100; i++) { - await writer.write(message); - } - await writer.close(); - loudLogger.info('WRITING ENDED'); - })(), - ]).catch((e) => logger.error(e.toString())); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await websocket.writable.close(); - - await backpressure.p; - expect(context?.writeBackpressure).toBeTrue(); - resumeWriting.resolveP(); - // Consume all of the back-pressured data - for await (const _ of websocket.readable) { - // No touch, only consume - } - expect(context?.writeBackpressure).toBeFalse(); - loudLogger.info('ending'); - }); - // Readable backpressure is not actually supported. We're dealing with it by - // using an buffer with a provided limit that can be very large. - test('Exceeding readable buffer limit causes error', async () => { - const startReading = promise(); - const handlingProm = promise(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - Promise.all([ - (async () => { - await startReading.p; - loudLogger.info('Starting consumption'); - for await (const _ of streamPair.readable) { - // No touch, only consume - } - loudLogger.info('Reads ended'); - })(), - (async () => { - await streamPair.writable.close(); - })(), - ]) - .catch(() => {}) - .finally(() => handlingProm.resolveP()); - }, - basePath: dataDir, - tlsConfig, - host, - // Setting a really low buffer limit - maxReadBufferBytes: 1500, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - const message = Buffer.alloc(1_000, 0xf0); - const writer = websocket.writable.getWriter(); - loudLogger.info('Starting writes'); - await expect(async () => { - for (let i = 0; i < 100; i++) { - await writer.write(message); + const clientReadProm = (async () => { + for await (const _ of websocket.readable) { + // No touch, only consume } - }).rejects.toThrow(); - startReading.resolveP(); - loudLogger.info('writes ended'); - for await (const _ of websocket.readable) { - // No touch, only consume - } - await handlingProm.p; - loudLogger.info('ending'); + })(); + await expect(clientReadProm).rejects.toThrow(); + const writer = websocket.writable.getWriter(); + await expect(handlerProm).toReject(); + await expect(writer.write(Buffer.from('test'))).rejects.toThrow(); + logger.info('ending'); }); }); From 17cb1109ddeceefe14db78b24246808450017af9 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 21 Feb 2023 18:47:25 +1100 Subject: [PATCH 08/23] feat: client rejects normal HTTP requests [ci skip] --- src/clientRPC/ClientServer.ts | 7 +++++ tests/clientRPC/clientRPC.test.ts | 43 ++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index a385712ed..1b5e2ab9e 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -143,6 +143,13 @@ class ClientServer { const listenProm = promise(); if (host != null) { // With custom host + this.server.any('/*', (res, req) => { + res + .writeStatus('426') + .writeHeader('connection', 'Upgrade') + .writeHeader('upgrade', 'websocket') + .end('426 Upgrade Required', true); + }); this.server.listen(host, port ?? 0, (listenSocket) => { if (listenSocket) { this.listenSocket = listenSocket; diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index 84e9fc466..e4c9eaf71 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -2,9 +2,11 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; import type { WebSocket } from 'uWebSockets.js'; import type { KeyPair } from '@/keys/types'; +import type http from 'http'; import fs from 'fs'; import path from 'path'; import os from 'os'; +import https from 'https'; import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { KeyRing } from '@/keys/index'; @@ -75,8 +77,8 @@ describe('ClientRPC', () => { }); afterEach(async () => { logger.info('AFTEREACH'); - await clientServer.stop(true); - await clientClient.destroy(); + await clientServer?.stop(true); + await clientClient?.destroy(true); await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); }); @@ -546,12 +548,45 @@ describe('ClientRPC', () => { // No touch, only consume } })(); - await expect(clientReadProm).rejects.toThrow(); + await expect(clientReadProm).toResolve(); const writer = websocket.writable.getWriter(); - await expect(handlerProm).toReject(); + await expect(handlerProm.p).toResolve(); await expect(writer.write(Buffer.from('test'))).rejects.toThrow(); logger.info('ending'); }); + test('Server rejects normal HTTPS requests', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + const getResProm = promise(); + https.get( + `https://${host}:${clientServer.port}/`, + { rejectUnauthorized: false }, + getResProm.resolveP, + ); + const res = await getResProm.p; + const contentProm = promise(); + res.once('data', (d) => contentProm.resolveP(d.toString())); + const endProm = promise(); + res.on('error', endProm.rejectP); + res.on('close', endProm.resolveP); + + expect(res.statusCode).toBe(426); + await expect(contentProm.p).resolves.toBe('426 Upgrade Required'); + expect(res.headers['connection']).toBe('Upgrade'); + expect(res.headers['upgrade']).toBe('websocket'); + }); }); describe('ClientClient', () => { From c21ebd1df8cab6f069a97add85c4fadd5619bdf1 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 22 Feb 2023 13:02:43 +1100 Subject: [PATCH 09/23] feat: `ClientClient` connection timeout [ci skip] --- src/clientRPC/ClientClient.ts | 52 +++++++++++++++++++++++++------ src/clientRPC/ClientServer.ts | 2 +- tests/clientRPC/clientRPC.test.ts | 21 ++++++++++--- 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/ClientClient.ts index 5780a3004..90e9d4474 100644 --- a/src/clientRPC/ClientClient.ts +++ b/src/clientRPC/ClientClient.ts @@ -6,9 +6,12 @@ import { createDestroy } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import WebSocket from 'ws'; import { PromiseCancellable } from '@matrixai/async-cancellable'; +import { Timer } from '@matrixai/timer'; import * as clientRpcUtils from './utils'; import { promise } from '../utils'; +const timeoutSymbol = Symbol('TimedOutSymbol'); + interface ClientClient extends createDestroy.CreateDestroy {} @createDestroy.CreateDestroy() class ClientClient { @@ -16,12 +19,14 @@ class ClientClient { host, port, expectedNodeIds, + connectionTimeout, maxReadableStreamBytes = 1000, // About 1kB logger = new Logger(this.name), }: { host: string; port: number; expectedNodeIds: Array; + connectionTimeout?: number; maxReadableStreamBytes?: number; logger?: Logger; }): Promise { @@ -32,6 +37,7 @@ class ClientClient { port, maxReadableStreamBytes, expectedNodeIds, + connectionTimeout, ); logger.info(`Created ${this.name}`); return clientClient; @@ -45,6 +51,7 @@ class ClientClient { protected port: number, protected maxReadableStreamBytes: number, protected expectedNodeIds: Array, + protected connectionTimeout: number | undefined, ) {} public async destroy(force: boolean = false) { @@ -61,9 +68,19 @@ class ClientClient { } @createDestroy.ready(Error('TMP destroyed')) - public async startConnection(): Promise< - ReadableWritablePair - > { + public async startConnection({ + timeoutTimer, + }: { + timeoutTimer?: Timer; + } = {}): Promise> { + // Use provided timer + let timer: Timer | undefined = timeoutTimer; + // If no timer provided use provided default timeout + if (timeoutTimer == null && this.connectionTimeout != null) { + timer = new Timer({ + delay: this.connectionTimeout, + }); + } const address = `wss://${this.host}:${this.port}`; this.logger.info(`Connecting to ${address}`); const connectProm = promise(); @@ -76,15 +93,17 @@ class ClientClient { ws.terminate(); }; const abortController = new AbortController(); - const connectionProm = new PromiseCancellable((resolve) => { + const activeConnectionProm = new PromiseCancellable((resolve) => { ws.once('close', () => { abortController.signal.removeEventListener('abort', abortHandler); resolve(); }); }, abortController); abortController.signal.addEventListener('abort', abortHandler); - this.activeConnections.add(connectionProm); - connectionProm.finally(() => this.activeConnections.delete(connectionProm)); + this.activeConnections.add(activeConnectionProm); + activeConnectionProm.finally(() => + this.activeConnections.delete(activeConnectionProm), + ); // Handle connection failure const openErrorHandler = (e) => { connectProm.rejectP(Error('TMP ERROR Connection failure', { cause: e })); @@ -106,12 +125,27 @@ class ClientClient { connectProm.resolveP(); }); // TODO: Race with a connection timeout here + // There are 3 resolve conditions here. + // 1. Connection established and authenticated + // 2. connection error or authentication failure + // 3. connection timed out try { - await Promise.all([authenticateProm.p, connectProm.p]); + const result = await Promise.race([ + timer?.then(() => timeoutSymbol) ?? new Promise(() => {}), + await Promise.all([authenticateProm.p, connectProm.p]), + ]); + if (result === timeoutSymbol) throw Error('TMP timed out'); } catch (e) { // Clean up - ws.close(); - await connectionProm; + // unregister handlers + ws.removeAllListeners('error'); + ws.removeAllListeners('upgrade'); + ws.removeAllListeners('open'); + // Close the ws if it's open at this stage + ws.terminate(); + // Ensure the connection is removed from the active connection set before + // returning. + await activeConnectionProm; throw e; } diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index 1b5e2ab9e..fab21481c 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -143,7 +143,7 @@ class ClientServer { const listenProm = promise(); if (host != null) { // With custom host - this.server.any('/*', (res, req) => { + this.server.any('/*', (res, _) => { res .writeStatus('426') .writeHeader('connection', 'Upgrade') diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index e4c9eaf71..3626311db 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -9,6 +9,7 @@ import os from 'os'; import https from 'https'; import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; +import { Timer } from '@matrixai/timer'; import { KeyRing } from '@/keys/index'; import ClientServer from '@/clientRPC/ClientServer'; import { promise } from '@/utils'; @@ -588,7 +589,6 @@ describe('ClientRPC', () => { expect(res.headers['upgrade']).toBe('websocket'); }); }); - describe('ClientClient', () => { test('Destroying ClientClient stops all connections', async () => { clientServer = await ClientServer.createClientServer({ @@ -714,8 +714,21 @@ describe('ClientRPC', () => { expect(activeConnections.size).toBe(1); logger.info('ending'); }); - test.todo('Writable backpressure'); - test.todo('Readable backpressure'); - test.todo('Connection times out'); + test('Connection times out', async () => { + clientClient = await ClientClient.createClientClient({ + host, + port: 12345, + expectedNodeIds: [keyRing.getNodeId()], + connectionTimeout: 0, + logger: logger.getChild('clientClient'), + }); + await expect(clientClient.startConnection({})).rejects.toThrow(); + await expect( + clientClient.startConnection({ + timeoutTimer: new Timer({ delay: 0 }), + }), + ).rejects.toThrow(); + logger.info('ending'); + }); }); }); From 4ceb65714741a5f8ed0080855c4f041392c83d31 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 22 Feb 2023 18:20:13 +1100 Subject: [PATCH 10/23] tests: testing abruptly dropped connections [ci skip] --- src/clientRPC/ClientServer.ts | 13 ++-- tests/clientRPC/clientRPC.test.ts | 121 +++++++++++++++++++----------- tests/clientRPC/testClient.ts | 31 ++++++++ tests/clientRPC/testServer.ts | 36 +++++++++ 4 files changed, 152 insertions(+), 49 deletions(-) create mode 100644 tests/clientRPC/testClient.ts create mode 100644 tests/clientRPC/testServer.ts diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index fab21481c..d11af0b23 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -4,6 +4,8 @@ import type { TLSConfig } from 'network/types'; import type { WebSocket } from 'uWebSockets.js'; import { WritableStream, ReadableStream } from 'stream/web'; import path from 'path'; +import os from 'os'; +import fs from 'fs'; import { startStop } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import uWebsocket from 'uWebSockets.js'; @@ -40,7 +42,7 @@ class ClientServer { }: { connectionCallback: ConnectionCallback; tlsConfig: TLSConfig; - basePath: string; + basePath?: string; host?: string; port?: number; fs?: FileSystem; @@ -82,13 +84,13 @@ class ClientServer { public async start({ connectionCallback, tlsConfig, - basePath, + basePath = os.tmpdir(), host, port, }: { connectionCallback: ConnectionCallback; tlsConfig: TLSConfig; - basePath: string; + basePath?: string; host?: string; port?: number; }): Promise { @@ -97,8 +99,9 @@ class ClientServer { // TODO: take a TLS config, write the files in the temp directory and // load them. let count = 0; - const keyFile = path.join(basePath, 'keyFile.pem'); - const certFile = path.join(basePath, 'certFile.pem'); + const tmpDir = await fs.promises.mkdtemp(path.join(basePath, 'polykey-')); + const keyFile = path.join(tmpDir, 'keyFile.pem'); + const certFile = path.join(tmpDir, 'certFile.pem'); await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); await this.fs.promises.writeFile(certFile, tlsConfig.certChainPem); this.server = uWebsocket.SSLApp({ diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index 3626311db..81ae8c5b6 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -16,18 +16,19 @@ import { promise } from '@/utils'; import ClientClient from '@/clientRPC/ClientClient'; import * as keysUtils from '@/keys/utils'; import * as networkErrors from '@/network/errors'; +import * as nodesUtils from '@/nodes/utils'; import * as testNodeUtils from '../nodes/utils'; import * as testsUtils from '../utils'; // This file tests both the client and server together. They're too interlinked // to be separate. describe('ClientRPC', () => { - const logger = new Logger('websocket test', LogLevel.WARN, [ + const logger = new Logger('websocket test', LogLevel.DEBUG, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); - const loudLogger = new Logger('websocket test', LogLevel.WARN, [ + const loudLogger = new Logger('websocket test', LogLevel.DEBUG, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), @@ -321,7 +322,7 @@ describe('ClientRPC', () => { }); // To fully test these two I need to start the client or server in a separate process and kill that process. // These require the ping/pong connection watchdogs to be implemented. - test.skip('client ends connection abruptly', async () => { + test('client ends connection abruptly', async () => { const handlerProm = promise(); clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { @@ -337,69 +338,101 @@ describe('ClientRPC', () => { logger: loudLogger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), + + const testProcess = await testsUtils.spawn( + 'ts-node', + [ + '--project', + testsUtils.tsConfigPath, + `${globalThis.testDir}/clientRPC/testClient.ts`, + ], + { + env: { + PK_TEST_HOST: host, + PK_TEST_PORT: `${clientServer.port}`, + PK_TEST_NODE_ID: nodesUtils.encodeNodeId(keyRing.getNodeId()), + }, + }, + logger, + ); + const startedProm = promise(); + testProcess.stdout!.on('data', (data) => { + startedProm.resolveP(data.toString()); }); - const websocket = await clientClient.startConnection(); + testProcess.stderr!.on('data', (data) => + startedProm.rejectP(data.toString()), + ); + const exitedProm = promise(); + testProcess.once('exit', () => exitedProm.resolveP()); + await startedProm.p; + + // Killing the client + testProcess.kill('SIGTERM'); + await exitedProm.p; + // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; - for (const activeConnection of activeConnections) { - activeConnection.cancel(); - } + const activeSockets = clientServer.activeSockets; + expect(activeSockets.size).toBe(0); + // Connection failure should cause handler to error await expect(handlerProm.p).toResolve(); - for await (const _ of websocket.readable) { - // Do nothing - } logger.info('ending'); }); - test.skip('Server ends connection abruptly', async () => { - const streamPairProm = - promise>(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair_) => { - logger.info('inside callback'); - // Don't do anything with the handler - streamPairProm.resolveP(streamPair_); + test('Server ends connection abruptly', async () => { + const testProcess = await testsUtils.spawn( + 'ts-node', + [ + '--project', + testsUtils.tsConfigPath, + `${globalThis.testDir}/clientRPC/testServer.ts`, + ], + { + env: { + PK_TEST_KEY_PRIVATE_PEM: tlsConfig.keyPrivatePem, + PK_TEST_CERT_CHAIN_PEM: tlsConfig.certChainPem, + PK_TEST_HOST: host, + }, }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), + logger, + ); + const startedProm = promise(); + testProcess.stdout!.on('data', (data) => { + startedProm.resolveP(parseInt(data.toString())); }); - logger.info(`Server started on port ${clientServer.port}`); + testProcess.stderr!.on('data', (data) => + startedProm.rejectP(data.toString()), + ); + const exitedProm = promise(); + testProcess.once('exit', () => exitedProm.resolveP()); + + logger.info(`Server started on port ${await startedProm.p}`); clientClient = await ClientClient.createClientClient({ host, - port: clientServer.port, + port: await startedProm.p, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); + + // Killing the server + testProcess.kill('SIGTERM'); + await exitedProm.p; + // @ts-ignore: kidnap protected property - const activeSockets = clientServer.activeSockets; - for (const activeSocket of activeSockets) { - activeSocket.close(); + const activeConnections = clientClient.activeConnections; + for (const activeConnection of activeConnections) { + await activeConnection; } - const streamPair = await streamPairProm.p; - // Expect both readable to throw - const handlerReadProm = (async () => { - for await (const _ of streamPair.readable) { - // Do nothing - } - })(); - await expect(handlerReadProm).rejects.toThrow(); + // Checking client's response to connection dropping + // client's readable should throw const clientReadProm = (async () => { for await (const _ of websocket.readable) { // Do nothing } })(); await expect(clientReadProm).rejects.toThrow(); - // Both writables should throw. - const handlerWritable = streamPair.writable.getWriter(); - await expect(handlerWritable.write(Buffer.from('test'))).rejects.toThrow(); + // Client's writable should throw const clientWritable = websocket.writable.getWriter(); + logger.info('asd'); await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); logger.info('ending'); }); diff --git a/tests/clientRPC/testClient.ts b/tests/clientRPC/testClient.ts new file mode 100644 index 000000000..28bf498e9 --- /dev/null +++ b/tests/clientRPC/testClient.ts @@ -0,0 +1,31 @@ +/** + * This is spawned as a background process for use in some NodeConnection.test.ts tests + * This process will not preserve jest testing environment, + * any usage of jest globals will result in an error + * Beware of propagated usage of jest globals through the script dependencies + * @module + */ +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import ClientClient from '@/clientRPC/ClientClient'; +import * as nodesUtils from '@/nodes/utils'; + +async function main() { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const clientClient = await ClientClient.createClientClient({ + expectedNodeIds: [nodesUtils.decodeNodeId(process.env.PK_TEST_NODE_ID!)!], + host: process.env.PK_TEST_HOST ?? '127.0.0.1', + port: parseInt(process.env.PK_TEST_PORT!), + logger, + }); + // Ignore streams, make connection hang + await clientClient.startConnection(); + process.stdout.write(`ready`); +} + +if (require.main === module) { + void main(); +} + +export default main; diff --git a/tests/clientRPC/testServer.ts b/tests/clientRPC/testServer.ts new file mode 100644 index 000000000..8a24c0bbd --- /dev/null +++ b/tests/clientRPC/testServer.ts @@ -0,0 +1,36 @@ +/** + * This is spawned as a background process for use in some NodeConnection.test.ts tests + * This process will not preserve jest testing environment, + * any usage of jest globals will result in an error + * Beware of propagated usage of jest globals through the script dependencies + * @module + */ +import type { CertificatePEMChain, PrivateKeyPEM } from '@/keys/types'; +import type { TLSConfig } from '@/network/types'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import ClientServer from '@/clientRPC/ClientServer'; + +async function main() { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const tlsConfig: TLSConfig = { + keyPrivatePem: process.env.PK_TEST_KEY_PRIVATE_PEM as PrivateKeyPEM, + certChainPem: process.env.PK_TEST_CERT_CHAIN_PEM as CertificatePEMChain, + }; + const clientServer = await ClientServer.createClientServer({ + connectionCallback: (_) => { + // Ignore streams and hang connections + }, + host: process.env.PK_TEST_HOST ?? '127.0.0.1', + tlsConfig, + logger, + }); + process.stdout.write(`${clientServer.port}`); +} + +if (require.main === module) { + void main(); +} + +export default main; From d80ad80855e8de15a8a80b9cabf5f566a3673d21 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 23 Feb 2023 16:36:05 +1100 Subject: [PATCH 11/23] feat: `ClientClient` keepalive and heartbeat Also updated tests and logic for what happens when a connection drops. Now when a connection drops the readable and writable throw. [ci skip] --- src/clientRPC/ClientClient.ts | 37 ++++++++- src/clientRPC/ClientServer.ts | 51 ++++++++---- src/types.ts | 1 + tests/clientRPC/clientRPC.test.ts | 124 +++++++++++++++++++++--------- 4 files changed, 159 insertions(+), 54 deletions(-) diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/ClientClient.ts index 90e9d4474..d6ffc2f9d 100644 --- a/src/clientRPC/ClientClient.ts +++ b/src/clientRPC/ClientClient.ts @@ -20,6 +20,8 @@ class ClientClient { port, expectedNodeIds, connectionTimeout, + pingInterval = 1000, + pingTimeout = 10000, maxReadableStreamBytes = 1000, // About 1kB logger = new Logger(this.name), }: { @@ -27,6 +29,8 @@ class ClientClient { port: number; expectedNodeIds: Array; connectionTimeout?: number; + pingInterval?: number; + pingTimeout?: number; maxReadableStreamBytes?: number; logger?: Logger; }): Promise { @@ -38,6 +42,8 @@ class ClientClient { maxReadableStreamBytes, expectedNodeIds, connectionTimeout, + pingInterval, + pingTimeout, ); logger.info(`Created ${this.name}`); return clientClient; @@ -52,6 +58,8 @@ class ClientClient { protected maxReadableStreamBytes: number, protected expectedNodeIds: Array, protected connectionTimeout: number | undefined, + protected pingInterval: number, + protected pingTimeout: number, ) {} public async destroy(force: boolean = false) { @@ -148,7 +156,6 @@ class ClientClient { await activeConnectionProm; throw e; } - // Cleaning up connection error ws.removeEventListener('error', openErrorHandler); @@ -190,7 +197,7 @@ class ClientClient { readableLogger.info('CLOSED, WS CLOSED'); ws.removeListener('message', messageHandler); if (!readableClosed) { - controller.close(); + controller.error(Error('TMP WebSocket Closed early CR')); readableClosed = true; } }); @@ -228,7 +235,7 @@ class ClientClient { writableLogger.info( `ws closing early! with code: ${code} and reason: ${reason.toString()}`, ); - controller.error(Error('TMP WebSocket Closed early')); + controller.error(Error('TMP WebSocket Closed early CW')); } }); }, @@ -259,6 +266,30 @@ class ClientClient { await wait.p; }, }); + + // Setting up heartbeat + const pingTimer = setInterval(() => { + ws.ping(); + }, this.pingInterval); + const pingTimeoutTimer = setTimeout(() => { + this.logger.debug('PING TIMED OUT'); + ws.close(4002, 'Timed out'); + }, this.pingTimeout); + ws.on('ping', () => { + this.logger.debug('received ping'); + ws.pong(); + }); + ws.on('pong', () => { + this.logger.debug('received pong'); + pingTimeoutTimer.refresh(); + }); + ws.once('close', () => { + this.logger.debug('Cleaning up timers'); + // Clean up timers + clearTimeout(pingTimer); + clearTimeout(pingTimeoutTimer); + }); + return { readable: readableStream, writable: writableStream, diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index d11af0b23..8e32b3b6b 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -1,11 +1,14 @@ -import type { ReadableWritablePair } from 'stream/web'; +import type { + ReadableStreamController, + ReadableWritablePair, + WritableStreamDefaultController, +} from 'stream/web'; import type { FileSystem, PromiseDeconstructed } from 'types'; import type { TLSConfig } from 'network/types'; import type { WebSocket } from 'uWebSockets.js'; import { WritableStream, ReadableStream } from 'stream/web'; import path from 'path'; import os from 'os'; -import fs from 'fs'; import { startStop } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import uWebsocket from 'uWebSockets.js'; @@ -23,6 +26,7 @@ type Context = { ) => void; drain: (ws: WebSocket) => void; close: (ws: WebSocket, code: number, message: ArrayBuffer) => void; + pong: (ws: WebSocket, message: ArrayBuffer) => void; logger: Logger; writeBackpressure: boolean; }; @@ -36,6 +40,7 @@ class ClientServer { basePath, host, port, + idleTimeout, fs = require('fs'), maxReadBufferBytes = 1_000_000_000, // About 1 GB logger = new Logger(this.name), @@ -45,12 +50,13 @@ class ClientServer { basePath?: string; host?: string; port?: number; + idleTimeout?: number; fs?: FileSystem; maxReadBufferBytes?: number; logger?: Logger; }) { logger.info(`Creating ${this.name}`); - const wsServer = new this(logger, fs, maxReadBufferBytes); + const wsServer = new this(logger, fs, maxReadBufferBytes, idleTimeout); await wsServer.start({ connectionCallback, tlsConfig, @@ -74,11 +80,13 @@ class ClientServer { * @param logger * @param fs * @param maxReadBufferBytes Max number of bytes stored in read buffer before error + * @param idleTimeout */ constructor( protected logger: Logger, protected fs: FileSystem, protected maxReadBufferBytes, + protected idleTimeout: number | undefined, ) {} public async start({ @@ -99,7 +107,9 @@ class ClientServer { // TODO: take a TLS config, write the files in the temp directory and // load them. let count = 0; - const tmpDir = await fs.promises.mkdtemp(path.join(basePath, 'polykey-')); + const tmpDir = await this.fs.promises.mkdtemp( + path.join(basePath, 'polykey-'), + ); const keyFile = path.join(tmpDir, 'keyFile.pem'); const certFile = path.join(tmpDir, 'certFile.pem'); await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); @@ -111,6 +121,8 @@ class ClientServer { await this.fs.promises.rm(keyFile); await this.fs.promises.rm(certFile); this.server.ws('/*', { + sendPingsAutomatically: true, + idleTimeout: this.idleTimeout, upgrade: (res, req, context) => { const logger = this.logger.getChild(`Connection ${count}`); res.upgrade>( @@ -207,12 +219,31 @@ class ClientServer { let readableClosed = false; let wsClosed = false; let backpressure: PromiseDeconstructed | null = null; + let writableController: WritableStreamDefaultController | undefined; + let readableController: ReadableStreamController | undefined; + context.close = () => { + logger.debug('CLOSING CALLED'); + wsClosed = true; + if (!readableClosed) { + logger.debug('CLOSING READABLE'); + readableController?.error(Error('TMP Web stream closed early SR')); + readableClosed = true; + } + if (!writableClosed) { + logger.debug('CLOSING Writable'); + writableController?.error(Error('TMP Web stream closed early SW')); + writableClosed = true; + } + }; context.drain = () => { logger.debug('DRAINING CALLED'); backpressure?.resolveP(); }; // Setting up the writable stream const writableStream = new WritableStream({ + start: (controller) => { + writableController = controller; + }, write: async (chunk, controller) => { // Logger.debug(`WRITABLE WRITE ${chunk.toString()}`); await backpressure?.p; @@ -258,6 +289,7 @@ class ClientServer { const readableStream = new ReadableStream( { start: (controller) => { + readableController = controller; context.message = (ws, message, _) => { // Logger.debug(`MESSAGE CALLED ${message.toString()}`); if (message.byteLength === 0) { @@ -283,15 +315,6 @@ class ClientServer { controller.error(err); } }; - context.close = () => { - logger.debug('CLOSING CALLED'); - wsClosed = true; - if (!readableClosed) { - logger.debug('CLOSING READABLE'); - controller.close(); - readableClosed = true; - } - }; }, cancel: () => { readableClosed = true; @@ -313,7 +336,7 @@ class ClientServer { writable: writableStream, }); } catch (e) { - // TODO: If the callback failed then we need to handle clean up + context.close(ws, 0, Buffer.from('')); logger.error(e.toString()); } } diff --git a/src/types.ts b/src/types.ts index 9a5289884..2f937bc51 100644 --- a/src/types.ts +++ b/src/types.ts @@ -110,6 +110,7 @@ interface FileSystem { readdir: typeof fs.promises.readdir; rename: typeof fs.promises.rename; open: typeof fs.promises.open; + mkdtemp: typeof fs.promises.mkdtemp; }; constants: typeof fs.constants; } diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index 81ae8c5b6..55f268fbb 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -12,7 +12,7 @@ import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; import { KeyRing } from '@/keys/index'; import ClientServer from '@/clientRPC/ClientServer'; -import { promise } from '@/utils'; +import { promise, sleep } from '@/utils'; import ClientClient from '@/clientRPC/ClientClient'; import * as keysUtils from '@/keys/utils'; import * as networkErrors from '@/network/errors'; @@ -314,23 +314,23 @@ describe('ClientRPC', () => { }).rejects.toThrow(); startReading.resolveP(); loudLogger.info('writes ended'); - for await (const _ of websocket.readable) { - // No touch, only consume - } + await expect(async () => { + for await (const _ of websocket.readable) { + // No touch, only consume + } + }).rejects.toThrow(); await handlingProm.p; loudLogger.info('ending'); }); // To fully test these two I need to start the client or server in a separate process and kill that process. // These require the ping/pong connection watchdogs to be implemented. test('client ends connection abruptly', async () => { - const handlerProm = promise(); + const streamPairProm = + promise>(); clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .then(handlerProm.resolveP, handlerProm.rejectP) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + streamPairProm.resolveP(streamPair); }, basePath: dataDir, tlsConfig, @@ -370,11 +370,15 @@ describe('ClientRPC', () => { testProcess.kill('SIGTERM'); await exitedProm.p; - // @ts-ignore: kidnap protected property - const activeSockets = clientServer.activeSockets; - expect(activeSockets.size).toBe(0); - // Connection failure should cause handler to error - await expect(handlerProm.p).toResolve(); + const streamPair = await streamPairProm.p; + // Everything should throw after websocket ends early + await expect(async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + const serverWritable = streamPair.writable.getWriter(); + await expect(serverWritable.write(Buffer.from('test'))).rejects.toThrow(); logger.info('ending'); }); test('Server ends connection abruptly', async () => { @@ -423,19 +427,45 @@ describe('ClientRPC', () => { await activeConnection; } // Checking client's response to connection dropping - // client's readable should throw - const clientReadProm = (async () => { + await expect(async () => { for await (const _ of websocket.readable) { - // Do nothing + // No touch, only consume } - })(); - await expect(clientReadProm).rejects.toThrow(); - // Client's writable should throw + }).rejects.toThrow(); const clientWritable = websocket.writable.getWriter(); - logger.info('asd'); await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); logger.info('ending'); }); + test('ping pong', async () => { + const waitP = promise(); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void waitP.p.then(() => { + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }); + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + await sleep(10000); + waitP.resolveP(); + await asyncReadWrite([], websocket); + logger.info('ending'); + }); // These describe blocks contains tests specific to either the client or server describe('ClientServer', () => { @@ -555,13 +585,12 @@ describe('ClientRPC', () => { }, ); test('Destroying ClientServer stops all connections', async () => { - const handlerProm = promise(); + const streamPairProm = + promise>(); clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); - void streamPair.readable - .pipeTo(streamPair.writable) - .then(handlerProm.resolveP, handlerProm.rejectP); + streamPairProm.resolveP(streamPair); }, basePath: dataDir, tlsConfig, @@ -577,15 +606,22 @@ describe('ClientRPC', () => { }); const websocket = await clientClient.startConnection(); await clientServer.stop(true); - const clientReadProm = (async () => { + const streamPair = await streamPairProm.p; + // Everything should throw after websocket ends early + await expect(async () => { for await (const _ of websocket.readable) { // No touch, only consume } - })(); - await expect(clientReadProm).toResolve(); - const writer = websocket.writable.getWriter(); - await expect(handlerProm.p).toResolve(); - await expect(writer.write(Buffer.from('test'))).rejects.toThrow(); + }).rejects.toThrow(); + await expect(async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + const clientWritable = websocket.writable.getWriter(); + const serverWritable = streamPair.writable.getWriter(); + await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); + await expect(serverWritable.write(Buffer.from('test'))).rejects.toThrow(); logger.info('ending'); }); test('Server rejects normal HTTPS requests', async () => { @@ -624,12 +660,12 @@ describe('ClientRPC', () => { }); describe('ClientClient', () => { test('Destroying ClientClient stops all connections', async () => { + const streamPairProm = + promise>(); clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); - void streamPair.readable.pipeTo(streamPair.writable).catch((e) => { - logger.error(e); - }); + streamPairProm.resolveP(streamPair); }, basePath: dataDir, tlsConfig, @@ -644,10 +680,24 @@ describe('ClientRPC', () => { logger: logger.getChild('clientClient'), }); const websocket = await clientClient.startConnection(); + // Destroying the client, force close connections await clientClient.destroy(true); - for await (const _ of websocket.readable) { - // No touch, only consume - } + const streamPair = await streamPairProm.p; + // Everything should throw after websocket ends early + await expect(async () => { + for await (const _ of websocket.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + await expect(async () => { + for await (const _ of streamPair.readable) { + // No touch, only consume + } + }).rejects.toThrow(); + const clientWritable = websocket.writable.getWriter(); + const serverWritable = streamPair.writable.getWriter(); + await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); + await expect(serverWritable.write(Buffer.from('test'))).rejects.toThrow(); await clientServer.stop(); logger.info('ending'); }); From 2d01f278260715c58bcad42daf0f4fd1001a8872 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 23 Feb 2023 17:23:32 +1100 Subject: [PATCH 12/23] feat: `ClientServer` keepalive and heartbeat [ci skip] --- src/clientRPC/ClientServer.ts | 72 ++++++++++++++++++++-------- tests/clientRPC/clientRPC.test.ts | 78 +++++++++++++++++++------------ 2 files changed, 100 insertions(+), 50 deletions(-) diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index 8e32b3b6b..567015d89 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -41,6 +41,8 @@ class ClientServer { host, port, idleTimeout, + pingInterval = 1000, + pingTimeout = 10000, fs = require('fs'), maxReadBufferBytes = 1_000_000_000, // About 1 GB logger = new Logger(this.name), @@ -51,12 +53,21 @@ class ClientServer { host?: string; port?: number; idleTimeout?: number; + pingInterval?: number; + pingTimeout?: number; fs?: FileSystem; maxReadBufferBytes?: number; logger?: Logger; }) { logger.info(`Creating ${this.name}`); - const wsServer = new this(logger, fs, maxReadBufferBytes, idleTimeout); + const wsServer = new this( + logger, + fs, + maxReadBufferBytes, + idleTimeout, + pingInterval, + pingTimeout, + ); await wsServer.start({ connectionCallback, tlsConfig, @@ -81,12 +92,16 @@ class ClientServer { * @param fs * @param maxReadBufferBytes Max number of bytes stored in read buffer before error * @param idleTimeout + * @param pingInterval + * @param pingTimeout */ constructor( protected logger: Logger, protected fs: FileSystem, protected maxReadBufferBytes, protected idleTimeout: number | undefined, + protected pingInterval: number, + protected pingTimeout: number, ) {} public async start({ @@ -221,24 +236,6 @@ class ClientServer { let backpressure: PromiseDeconstructed | null = null; let writableController: WritableStreamDefaultController | undefined; let readableController: ReadableStreamController | undefined; - context.close = () => { - logger.debug('CLOSING CALLED'); - wsClosed = true; - if (!readableClosed) { - logger.debug('CLOSING READABLE'); - readableController?.error(Error('TMP Web stream closed early SR')); - readableClosed = true; - } - if (!writableClosed) { - logger.debug('CLOSING Writable'); - writableController?.error(Error('TMP Web stream closed early SW')); - writableClosed = true; - } - }; - context.drain = () => { - logger.debug('DRAINING CALLED'); - backpressure?.resolveP(); - }; // Setting up the writable stream const writableStream = new WritableStream({ start: (controller) => { @@ -329,6 +326,43 @@ class ClientServer { size: (chunk) => chunk?.byteLength ?? 0, }, ); + + const pingTimer = setInterval(() => { + ws.ping(); + }, this.pingInterval); + const pingTimeoutTimer = setTimeout(() => { + logger.debug('ping timed out'); + ws.end(); + }, this.pingTimeout); + context.pong = () => { + logger.debug('received pong'); + pingTimeoutTimer.refresh(); + }; + context.close = () => { + logger.debug('CLOSING CALLED'); + wsClosed = true; + // Cleaning up timers + logger.debug('Cleaning up timers'); + clearTimeout(pingTimer); + clearTimeout(pingTimeoutTimer); + // Closing streams + logger.debug('cleaning streams'); + if (!readableClosed) { + logger.debug('CLOSING READABLE'); + readableController?.error(Error('TMP Web stream closed early SR')); + readableClosed = true; + } + if (!writableClosed) { + logger.debug('CLOSING Writable'); + writableController?.error(Error('TMP Web stream closed early SW')); + writableClosed = true; + } + }; + context.drain = () => { + logger.debug('DRAINING CALLED'); + backpressure?.resolveP(); + }; + logger.info('callback'); try { this.connectionCallback({ diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index 55f268fbb..a16bb72f8 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -12,7 +12,7 @@ import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; import { KeyRing } from '@/keys/index'; import ClientServer from '@/clientRPC/ClientServer'; -import { promise, sleep } from '@/utils'; +import { promise } from '@/utils'; import ClientClient from '@/clientRPC/ClientClient'; import * as keysUtils from '@/keys/utils'; import * as networkErrors from '@/network/errors'; @@ -436,36 +436,6 @@ describe('ClientRPC', () => { await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); logger.info('ending'); }); - test('ping pong', async () => { - const waitP = promise(); - clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - logger.info('inside callback'); - void waitP.p.then(() => { - void streamPair.readable - .pipeTo(streamPair.writable) - .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); - }); - }, - basePath: dataDir, - tlsConfig, - host, - logger: loudLogger.getChild('server'), - }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ - host, - port: clientServer.port, - expectedNodeIds: [keyRing.getNodeId()], - logger: logger.getChild('clientClient'), - }); - const websocket = await clientClient.startConnection(); - await sleep(10000); - waitP.resolveP(); - await asyncReadWrite([], websocket); - logger.info('ending'); - }); // These describe blocks contains tests specific to either the client or server describe('ClientServer', () => { @@ -657,6 +627,29 @@ describe('ClientRPC', () => { expect(res.headers['connection']).toBe('Upgrade'); expect(res.headers['upgrade']).toBe('websocket'); }); + test('ping timeout', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (_) => { + logger.info('inside callback'); + // Hang connection + }, + basePath: dataDir, + tlsConfig, + host, + pingTimeout: 100, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + await clientClient.startConnection(); + await clientClient.destroy(); + logger.info('ending'); + }); }); describe('ClientClient', () => { test('Destroying ClientClient stops all connections', async () => { @@ -813,5 +806,28 @@ describe('ClientRPC', () => { ).rejects.toThrow(); logger.info('ending'); }); + test('ping timeout', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (_) => { + logger.info('inside callback'); + // Hang connection + }, + basePath: dataDir, + tlsConfig, + host, + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host, + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + pingTimeout: 100, + logger: logger.getChild('clientClient'), + }); + await clientClient.startConnection(); + await clientClient.destroy(); + logger.info('ending'); + }); }); }); From 5da726e2996417938498fd0efd0da7c1c45b55a7 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 23 Feb 2023 17:55:33 +1100 Subject: [PATCH 13/23] feat: `IPv6` support [ci skip] --- src/clientRPC/ClientClient.ts | 15 +++++++++++-- tests/clientRPC/clientRPC.test.ts | 35 +++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/ClientClient.ts index d6ffc2f9d..10fda4ec6 100644 --- a/src/clientRPC/ClientClient.ts +++ b/src/clientRPC/ClientClient.ts @@ -7,6 +7,7 @@ import Logger from '@matrixai/logger'; import WebSocket from 'ws'; import { PromiseCancellable } from '@matrixai/async-cancellable'; import { Timer } from '@matrixai/timer'; +import { Validator } from 'ip-num'; import * as clientRpcUtils from './utils'; import { promise } from '../utils'; @@ -49,18 +50,28 @@ class ClientClient { return clientClient; } + protected host: string; + protected activeConnections: Set> = new Set(); constructor( protected logger: Logger, - protected host: string, + host: string, protected port: number, protected maxReadableStreamBytes: number, protected expectedNodeIds: Array, protected connectionTimeout: number | undefined, protected pingInterval: number, protected pingTimeout: number, - ) {} + ) { + if (Validator.isValidIPv4String(host)[0]) { + this.host = host; + } else if (Validator.isValidIPv6String(host)[0]) { + this.host = `[${host}]`; + } else { + throw Error('TMP Invalid host'); + } + } public async destroy(force: boolean = false) { this.logger.info(`Destroying ${this.constructor.name}`); diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index a16bb72f8..3fb847234 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -121,6 +121,41 @@ describe('ClientRPC', () => { expect((await reader.read()).done).toBeTrue(); logger.info('ending'); }); + test('makes a connection over IPv6', async () => { + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host: '::1', + logger: loudLogger.getChild('server'), + }); + logger.info(`Server started on port ${clientServer.port}`); + clientClient = await ClientClient.createClientClient({ + host: '::1', + port: clientServer.port, + expectedNodeIds: [keyRing.getNodeId()], + logger: logger.getChild('clientClient'), + }); + const websocket = await clientClient.startConnection(); + + const writer = websocket.writable.getWriter(); + const reader = websocket.readable.getReader(); + const message1 = Buffer.from('1request1'); + await writer.write(message1); + expect((await reader.read()).value).toStrictEqual(message1); + const message2 = Buffer.from('1request2'); + await writer.write(message2); + expect((await reader.read()).value).toStrictEqual(message2); + await writer.close(); + expect((await reader.read()).done).toBeTrue(); + logger.info('ending'); + }); test('Handles a connection and closes before message', async () => { clientServer = await ClientServer.createClientServer({ connectionCallback: (streamPair) => { From d096ce295c7b519591b5bdf2600997f45bb7fae9 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 23 Feb 2023 19:42:41 +1100 Subject: [PATCH 14/23] fix: cleaning up and errors [ci skip] --- src/clientRPC/ClientClient.ts | 82 ++++++++++++++++-------- src/clientRPC/ClientServer.ts | 116 +++++++++++++++++----------------- src/clientRPC/errors.ts | 66 +++++++++++++++++++ 3 files changed, 178 insertions(+), 86 deletions(-) create mode 100644 src/clientRPC/errors.ts diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/ClientClient.ts index 10fda4ec6..d6ca4d915 100644 --- a/src/clientRPC/ClientClient.ts +++ b/src/clientRPC/ClientClient.ts @@ -9,6 +9,7 @@ import { PromiseCancellable } from '@matrixai/async-cancellable'; import { Timer } from '@matrixai/timer'; import { Validator } from 'ip-num'; import * as clientRpcUtils from './utils'; +import * as clientRPCErrors from './errors'; import { promise } from '../utils'; const timeoutSymbol = Symbol('TimedOutSymbol'); @@ -69,7 +70,7 @@ class ClientClient { } else if (Validator.isValidIPv6String(host)[0]) { this.host = `[${host}]`; } else { - throw Error('TMP Invalid host'); + throw new clientRPCErrors.ErrorClientInvalidHost(); } } @@ -86,7 +87,7 @@ class ClientClient { this.logger.info(`Destroyed ${this.constructor.name}`); } - @createDestroy.ready(Error('TMP destroyed')) + @createDestroy.ready(new clientRPCErrors.ErrorClientDestroyed()) public async startConnection({ timeoutTimer, }: { @@ -125,7 +126,11 @@ class ClientClient { ); // Handle connection failure const openErrorHandler = (e) => { - connectProm.rejectP(Error('TMP ERROR Connection failure', { cause: e })); + connectProm.rejectP( + new clientRPCErrors.ErrorClientConnectionFailed(undefined, { + cause: e, + }), + ); }; ws.once('error', openErrorHandler); // Authenticate server's certificates @@ -153,7 +158,9 @@ class ClientClient { timer?.then(() => timeoutSymbol) ?? new Promise(() => {}), await Promise.all([authenticateProm.p, connectProm.p]), ]); - if (result === timeoutSymbol) throw Error('TMP timed out'); + if (result === timeoutSymbol) { + throw new clientRPCErrors.ErrorClientConnectionTimedOut(); + } } catch (e) { // Clean up // unregister handlers @@ -177,52 +184,67 @@ class ClientClient { const readableStream = new ReadableStream( { start: (controller) => { - readableLogger.info('STARTING'); + readableLogger.info('Starting'); const messageHandler = (data) => { - // ReadableLogger.debug(`message: ${data.toString()}`); + readableLogger.debug(`Received ${data.toString()}`); if (controller.desiredSize == null) { controller.error(Error('NEVER')); return; } if (controller.desiredSize < 0) { - // ReadableLogger.debug('PAUSING'); + readableLogger.debug('Applying readable backpressure'); ws.pause(); } const message = data as Buffer; if (message.length === 0) { - readableLogger.info('CLOSING, NULL MESSAGE'); + readableLogger.debug('Null message received'); ws.removeListener('message', messageHandler); if (!readableClosed) { - controller.close(); readableClosed = true; + readableLogger.debug('Closing'); + controller.close(); } if (writableClosed) { + this.logger.debug('Closing socket'); ws.close(); } return; } controller.enqueue(message); }; + readableLogger.debug('Registering socket message handler'); ws.on('message', messageHandler); - ws.once('close', () => { - readableLogger.info('CLOSED, WS CLOSED'); + ws.once('close', (code, reason) => { + this.logger.info('Socket closed'); ws.removeListener('message', messageHandler); if (!readableClosed) { - controller.error(Error('TMP WebSocket Closed early CR')); readableClosed = true; + readableLogger.debug( + `Closed early, ${code}, ${reason.toString()}`, + ); + controller.error( + new clientRPCErrors.ErrorClientConnectionEndedEarly(), + ); + } + }); + ws.once('error', (e) => { + if (!readableClosed) { + readableClosed = true; + readableLogger.error(e); + controller.error(e); } }); - ws.once('error', (e) => readableLogger.error(e)); }, cancel: () => { - readableLogger.info('CANCELLED'); + readableLogger.debug('Cancelled'); if (!readableClosed) { - ws.close(); + readableLogger.debug('Closing socket'); readableClosed = true; + ws.close(); } }, pull: () => { - // ReadableLogger.debug('RESUMING'); + readableLogger.debug('Releasing backpressure'); ws.resume(); }, }, @@ -233,43 +255,47 @@ class ClientClient { ); const writableStream = new WritableStream({ start: (controller) => { - writableLogger.info('STARTING'); + writableLogger.info('Starting'); ws.once('error', (e) => { - writableLogger.error(`error: ${e}`); if (!writableClosed) { - controller.error(e); writableClosed = true; + writableLogger.error(e.toString()); + controller.error(e); } }); ws.once('close', (code, reason) => { if (!writableClosed) { - writableLogger.info( - `ws closing early! with code: ${code} and reason: ${reason.toString()}`, + writableClosed = true; + writableLogger.debug(`Closed early, ${code}, ${reason.toString()}`); + controller.error( + new clientRPCErrors.ErrorClientConnectionEndedEarly(), ); - controller.error(Error('TMP WebSocket Closed early CW')); } }); }, close: () => { - writableLogger.info('CLOSING'); + writableLogger.debug('Closing, sending null message'); ws.send(Buffer.from([])); writableClosed = true; if (readableClosed) { + writableLogger.debug('Closing socket'); ws.close(); } }, abort: () => { - writableLogger.info('ABORTED'); + writableLogger.debug('Aborted'); writableClosed = true; if (readableClosed) { + writableLogger.debug('Closing socket'); ws.close(); } }, write: async (chunk, controller) => { - // WritableLogger.debug(`writing: ${chunk?.toString()}`); + writableLogger.debug(`Sending ${chunk?.toString()}`); const wait = promise(); ws.send(chunk, (e) => { if (e != null) { + writableLogger.error(e.toString()); controller.error(e); } wait.resolveP(); @@ -283,15 +309,15 @@ class ClientClient { ws.ping(); }, this.pingInterval); const pingTimeoutTimer = setTimeout(() => { - this.logger.debug('PING TIMED OUT'); + this.logger.debug('Ping timed out'); ws.close(4002, 'Timed out'); }, this.pingTimeout); ws.on('ping', () => { - this.logger.debug('received ping'); + this.logger.debug('Received ping'); ws.pong(); }); ws.on('pong', () => { - this.logger.debug('received pong'); + this.logger.debug('Received pong'); pingTimeoutTimer.refresh(); }); ws.once('close', () => { diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index 567015d89..a241306f4 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -12,6 +12,7 @@ import os from 'os'; import { startStop } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import uWebsocket from 'uWebSockets.js'; +import * as clientRPCErrors from './errors'; import { promise } from '../utils'; type ConnectionCallback = ( @@ -109,7 +110,7 @@ class ClientServer { tlsConfig, basePath = os.tmpdir(), host, - port, + port = 0, }: { connectionCallback: ConnectionCallback; tlsConfig: TLSConfig; @@ -119,12 +120,11 @@ class ClientServer { }): Promise { this.logger.info(`Starting ${this.constructor.name}`); this.connectionCallback = connectionCallback; - // TODO: take a TLS config, write the files in the temp directory and - // load them. let count = 0; const tmpDir = await this.fs.promises.mkdtemp( path.join(basePath, 'polykey-'), ); + // TODO: The key file needs to be in the encrypted format const keyFile = path.join(tmpDir, 'keyFile.pem'); const certFile = path.join(tmpDir, 'certFile.pem'); await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); @@ -170,38 +170,33 @@ class ClientServer { ws.getUserData().drain(ws); }, }); + this.server.any('/*', (res, _) => { + // Reject normal requests with an upgrade code + res + .writeStatus('426') + .writeHeader('connection', 'Upgrade') + .writeHeader('upgrade', 'websocket') + .end('426 Upgrade Required', true); + }); const listenProm = promise(); + const listenCallback = (listenSocket) => { + if (listenSocket) { + this.listenSocket = listenSocket; + listenProm.resolveP(); + } else { + listenProm.rejectP(new clientRPCErrors.ErrorServerPortUnavailable()); + } + }; if (host != null) { // With custom host - this.server.any('/*', (res, _) => { - res - .writeStatus('426') - .writeHeader('connection', 'Upgrade') - .writeHeader('upgrade', 'websocket') - .end('426 Upgrade Required', true); - }); - this.server.listen(host, port ?? 0, (listenSocket) => { - if (listenSocket) { - this.listenSocket = listenSocket; - listenProm.resolveP(); - } else { - listenProm.rejectP(Error('TMP, no port')); - } - }); + this.server.listen(host, port ?? 0, listenCallback); } else { // With default host - this.server.listen(port ?? 0, (listenSocket) => { - if (listenSocket) { - this.listenSocket = listenSocket; - listenProm.resolveP(); - } else { - listenProm.rejectP(Error('TMP, no port')); - } - }); + this.server.listen(port, listenCallback); } await listenProm.p; this.logger.debug( - `bound to port ${uWebsocket.us_socket_local_port(this.listenSocket)}`, + `Listening on port ${uWebsocket.us_socket_local_port(this.listenSocket)}`, ); this.host = host ?? '127.0.0.1'; this.logger.info(`Started ${this.constructor.name}`); @@ -236,23 +231,25 @@ class ClientServer { let backpressure: PromiseDeconstructed | null = null; let writableController: WritableStreamDefaultController | undefined; let readableController: ReadableStreamController | undefined; + const writableLogger = logger.getChild('Writable'); + const readableLogger = logger.getChild('Readable'); // Setting up the writable stream const writableStream = new WritableStream({ start: (controller) => { writableController = controller; }, write: async (chunk, controller) => { - // Logger.debug(`WRITABLE WRITE ${chunk.toString()}`); await backpressure?.p; const writeResult = ws.send(chunk, true); switch (writeResult) { default: case 2: // Write failure, emit error - controller.error(Error('TMP Failed to write')); + writableLogger.error('Send error'); + controller.error(new clientRPCErrors.ErrorServerSendFailed()); break; case 0: - logger.info('Write backpressure'); + writableLogger.info('Write backpressure'); // Signal backpressure backpressure = promise(); context.writeBackpressure = true; @@ -262,22 +259,23 @@ class ClientServer { break; case 1: // Success + writableLogger.debug(`Sending ${chunk.toString()}`); break; } }, close: () => { - logger.info('WRITABLE CLOSE'); + writableLogger.info('Closed, sending null message'); if (!wsClosed) ws.send(Buffer.from([]), true); writableClosed = true; if (readableClosed && !wsClosed) { - logger.debug('ENDING WS'); + writableLogger.debug('Ending socket'); ws.end(); } }, abort: () => { - logger.info('WRITABLE ABORT'); + writableLogger.info('Aborted'); if (readableClosed && !wsClosed) { - logger.debug('ENDING WS'); + writableLogger.debug('Ending socket'); ws.end(); } }, @@ -288,35 +286,34 @@ class ClientServer { start: (controller) => { readableController = controller; context.message = (ws, message, _) => { - // Logger.debug(`MESSAGE CALLED ${message.toString()}`); + readableLogger.debug(`Received ${message.toString()}`); if (message.byteLength === 0) { - logger.debug('NULL MESSAGE, CLOSING'); + readableLogger.debug('Null message received'); if (!readableClosed) { - logger.debug('CLOSING READABLE'); - controller.close(); readableClosed = true; + readableLogger.debug('Closing'); + controller.close(); if (writableClosed && !wsClosed) { + readableLogger.debug('Ending socket'); ws.end(); } } return; } controller.enqueue(Buffer.from(message)); - if ( - controller.desiredSize != null && - controller.desiredSize < -1000 - ) { - logger.error('Read stream buffer full'); - const err = Error('TMP read buffer limit'); - if (!wsClosed) ws.end(4001, err.toString()); - controller.error(err); + if (controller.desiredSize != null && controller.desiredSize < 0) { + readableLogger.error('Read stream buffer full'); + if (!wsClosed) ws.end(4001, 'Read stream buffer full'); + controller.error( + new clientRPCErrors.ErrorServerReadableBufferLimit(), + ); } }; }, cancel: () => { readableClosed = true; if (writableClosed && !wsClosed) { - logger.debug('ENDING WS'); + readableLogger.debug('Ending socket'); ws.end(); } }, @@ -331,39 +328,42 @@ class ClientServer { ws.ping(); }, this.pingInterval); const pingTimeoutTimer = setTimeout(() => { - logger.debug('ping timed out'); + logger.debug('Ping timed out'); ws.end(); }, this.pingTimeout); context.pong = () => { - logger.debug('received pong'); + logger.debug('Received pong'); pingTimeoutTimer.refresh(); }; context.close = () => { - logger.debug('CLOSING CALLED'); + logger.debug('Closing'); wsClosed = true; // Cleaning up timers logger.debug('Cleaning up timers'); clearTimeout(pingTimer); clearTimeout(pingTimeoutTimer); // Closing streams - logger.debug('cleaning streams'); + logger.debug('Cleaning streams'); if (!readableClosed) { - logger.debug('CLOSING READABLE'); - readableController?.error(Error('TMP Web stream closed early SR')); readableClosed = true; + readableLogger.debug('Closing'); + readableController?.error( + new clientRPCErrors.ErrorServerConnectionEndedEarly(), + ); } if (!writableClosed) { - logger.debug('CLOSING Writable'); - writableController?.error(Error('TMP Web stream closed early SW')); writableClosed = true; + writableLogger.debug('Closing'); + writableController?.error( + new clientRPCErrors.ErrorServerConnectionEndedEarly(), + ); } }; context.drain = () => { - logger.debug('DRAINING CALLED'); + logger.debug('Drained'); backpressure?.resolveP(); }; - - logger.info('callback'); + logger.debug('Calling handler callback'); try { this.connectionCallback({ readable: readableStream, diff --git a/src/clientRPC/errors.ts b/src/clientRPC/errors.ts new file mode 100644 index 000000000..bbe914582 --- /dev/null +++ b/src/clientRPC/errors.ts @@ -0,0 +1,66 @@ +import { ErrorPolykey, sysexits } from '../errors'; + +class ErrorClient extends ErrorPolykey {} + +class ErrorClientClient extends ErrorClient {} + +class ErrorClientDestroyed extends ErrorClientClient{ + static description = 'ClientClient has been destroyed'; + exitCode = sysexits.USAGE; +} + +class ErrorClientInvalidHost extends ErrorClientClient{ + static description = 'Host must be a valid IPv4 or IPv6 address string'; + exitCode = sysexits.USAGE; +} + +class ErrorClientConnectionFailed extends ErrorClientClient{ + static description = 'Failed to establish connection to server'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorClientConnectionTimedOut extends ErrorClientClient{ + static description = 'Connection timed out'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorClientConnectionEndedEarly extends ErrorClientClient{ + static description = 'Connection ended before stream ended'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorClientServer extends ErrorClient {} + +class ErrorServerPortUnavailable extends ErrorClientServer{ + static description = 'Failed to bind a free port'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorServerSendFailed extends ErrorClientServer{ + static description = 'Failed to send message'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorServerReadableBufferLimit extends ErrorClientServer{ + static description = 'Readable buffer is full, messages received too quickly'; + exitCode = sysexits.USAGE; +} + +class ErrorServerConnectionEndedEarly extends ErrorClientServer{ + static description = 'Connection ended before stream ended'; + exitCode = sysexits.UNAVAILABLE; +} + +export { + ErrorClientClient, + ErrorClientDestroyed, + ErrorClientInvalidHost, + ErrorClientConnectionFailed, + ErrorClientConnectionTimedOut, + ErrorClientConnectionEndedEarly, + ErrorClientServer, + ErrorServerPortUnavailable, + ErrorServerSendFailed, + ErrorServerReadableBufferLimit, + ErrorServerConnectionEndedEarly, +} From 1641c7e232f437964dacd26a0e35fd39511132be Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 24 Feb 2023 15:42:06 +1100 Subject: [PATCH 15/23] tests: cleaning up and fixing tests [ci skip] --- src/clientRPC/ClientClient.ts | 16 +++- src/clientRPC/errors.ts | 20 ++--- .../authenticationMiddleware.test.ts | 51 +++++++------ tests/clientRPC/clientRPC.test.ts | 74 +++++++++---------- tests/clientRPC/handlers/agentStatus.test.ts | 53 +++++++------ tests/clientRPC/handlers/agentUnlock.test.ts | 50 ++++++------- 6 files changed, 129 insertions(+), 135 deletions(-) diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/ClientClient.ts index d6ca4d915..57e279e78 100644 --- a/src/clientRPC/ClientClient.ts +++ b/src/clientRPC/ClientClient.ts @@ -259,7 +259,7 @@ class ClientClient { ws.once('error', (e) => { if (!writableClosed) { writableClosed = true; - writableLogger.error(e.toString()); + writableLogger.error(e); controller.error(e); } }); @@ -291,12 +291,20 @@ class ClientClient { } }, write: async (chunk, controller) => { + if (writableClosed) return; writableLogger.debug(`Sending ${chunk?.toString()}`); const wait = promise(); ws.send(chunk, (e) => { - if (e != null) { - writableLogger.error(e.toString()); - controller.error(e); + if (e != null && !writableClosed) { + writableClosed = true; + // Opting to debug message here and not log an error, sending + // failure is common if we send before the close event. + writableLogger.debug('failed to send'); + controller.error( + new clientRPCErrors.ErrorClientConnectionEndedEarly(undefined, { + cause: e, + }), + ); } wait.resolveP(); }); diff --git a/src/clientRPC/errors.ts b/src/clientRPC/errors.ts index bbe914582..a077575ed 100644 --- a/src/clientRPC/errors.ts +++ b/src/clientRPC/errors.ts @@ -4,49 +4,49 @@ class ErrorClient extends ErrorPolykey {} class ErrorClientClient extends ErrorClient {} -class ErrorClientDestroyed extends ErrorClientClient{ +class ErrorClientDestroyed extends ErrorClientClient { static description = 'ClientClient has been destroyed'; exitCode = sysexits.USAGE; } -class ErrorClientInvalidHost extends ErrorClientClient{ +class ErrorClientInvalidHost extends ErrorClientClient { static description = 'Host must be a valid IPv4 or IPv6 address string'; exitCode = sysexits.USAGE; } -class ErrorClientConnectionFailed extends ErrorClientClient{ +class ErrorClientConnectionFailed extends ErrorClientClient { static description = 'Failed to establish connection to server'; exitCode = sysexits.UNAVAILABLE; } -class ErrorClientConnectionTimedOut extends ErrorClientClient{ +class ErrorClientConnectionTimedOut extends ErrorClientClient { static description = 'Connection timed out'; exitCode = sysexits.UNAVAILABLE; } -class ErrorClientConnectionEndedEarly extends ErrorClientClient{ +class ErrorClientConnectionEndedEarly extends ErrorClientClient { static description = 'Connection ended before stream ended'; exitCode = sysexits.UNAVAILABLE; } class ErrorClientServer extends ErrorClient {} -class ErrorServerPortUnavailable extends ErrorClientServer{ +class ErrorServerPortUnavailable extends ErrorClientServer { static description = 'Failed to bind a free port'; exitCode = sysexits.UNAVAILABLE; } -class ErrorServerSendFailed extends ErrorClientServer{ +class ErrorServerSendFailed extends ErrorClientServer { static description = 'Failed to send message'; exitCode = sysexits.UNAVAILABLE; } -class ErrorServerReadableBufferLimit extends ErrorClientServer{ +class ErrorServerReadableBufferLimit extends ErrorClientServer { static description = 'Readable buffer is full, messages received too quickly'; exitCode = sysexits.USAGE; } -class ErrorServerConnectionEndedEarly extends ErrorClientServer{ +class ErrorServerConnectionEndedEarly extends ErrorClientServer { static description = 'Connection ended before stream ended'; exitCode = sysexits.UNAVAILABLE; } @@ -63,4 +63,4 @@ export { ErrorServerSendFailed, ErrorServerReadableBufferLimit, ErrorServerConnectionEndedEarly, -} +}; diff --git a/tests/clientRPC/authenticationMiddleware.test.ts b/tests/clientRPC/authenticationMiddleware.test.ts index 3e5d778c2..e8285a86d 100644 --- a/tests/clientRPC/authenticationMiddleware.test.ts +++ b/tests/clientRPC/authenticationMiddleware.test.ts @@ -1,10 +1,8 @@ -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; import type { RPCRequestParams, RPCResponseResult } from '@/clientRPC/types'; +import type { ConnectionInfo, TLSConfig } from '../../src/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -20,12 +18,15 @@ import { UnaryCaller } from '@/RPC/callers'; import { UnaryHandler } from '@/RPC/handlers'; import * as middlewareUtils from '@/RPC/middleware'; import * as testsUtils from '../utils'; +import ClientServer from '../../src/clientRPC/ClientServer'; +import ClientClient from '../../src/clientRPC/ClientClient'; describe('agentUnlock', () => { const logger = new Logger('agentUnlock test', LogLevel.WARN, [ new StreamHandler(), ]); const password = 'helloworld'; + const host = '127.0.0.1'; let dataDir: string; let db: DB; let keyRing: KeyRing; @@ -33,9 +34,9 @@ describe('agentUnlock', () => { let certManager: CertManager; let session: Session; let sessionManager: SessionManager; - let server: Server; - let wss: WebSocketServer; - let port: number; + let clientServer: ClientServer; + let clientClient: ClientClient; + let tlsConfig: TLSConfig; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -72,16 +73,11 @@ describe('agentUnlock', () => { keyRing, logger, }); - const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, '127.0.0.1'); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { - wss?.close(); - server.close(); + await clientServer?.stop(true); + await clientClient?.destroy(true); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -111,22 +107,25 @@ describe('agentUnlock', () => { ), logger, }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('server'), - ); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + rpcServer.handleStream(streamPair, {} as ConnectionInfo); + }, + host, + tlsConfig, + logger, + }); + clientClient = await ClientClient.createClientClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + port: clientServer.port, + logger, + }); const rpcClient = await RPCClient.createRPCClient({ manifest: { agentUnlock: new UnaryCaller(), }, - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - '127.0.0.1', - port, - logger.getChild('client'), - ); - }, + streamPairCreateCallback: async () => clientClient.startConnection(), middleware: middlewareUtils.defaultClientMiddlewareWrapper( authMiddleware.authenticationMiddlewareClient(session), ), diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/clientRPC.test.ts index 3fb847234..6037ba7bc 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/clientRPC.test.ts @@ -23,17 +23,11 @@ import * as testsUtils from '../utils'; // This file tests both the client and server together. They're too interlinked // to be separate. describe('ClientRPC', () => { - const logger = new Logger('websocket test', LogLevel.DEBUG, [ + const logger = new Logger('websocket test', LogLevel.WARN, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); - const loudLogger = new Logger('websocket test', LogLevel.DEBUG, [ - new StreamHandler( - formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, - ), - ]); - let dataDir: string; let keyRing: KeyRing; let tlsConfig: TLSConfig; @@ -93,12 +87,12 @@ describe('ClientRPC', () => { void streamPair.readable .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + .finally(() => logger.info('STREAM HANDLING ENDED')); }, basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -128,12 +122,12 @@ describe('ClientRPC', () => { void streamPair.readable .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + .finally(() => logger.info('STREAM HANDLING ENDED')); }, basePath: dataDir, tlsConfig, host: '::1', - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -168,7 +162,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -199,7 +193,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -264,21 +258,21 @@ describe('ClientRPC', () => { while (!context.writeBackpressure) { await writer.write(message); } - loudLogger.info('BACK PRESSURED'); + logger.info('BACK PRESSURED'); backpressure.resolveP(); await resumeWriting.p; for (let i = 0; i < 100; i++) { await writer.write(message); } await writer.close(); - loudLogger.info('WRITING ENDED'); + logger.info('WRITING ENDED'); })(), ]).catch((e) => logger.error(e.toString())); }, basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -298,7 +292,7 @@ describe('ClientRPC', () => { // No touch, only consume } expect(context?.writeBackpressure).toBeFalse(); - loudLogger.info('ending'); + logger.info('ending'); }); // Readable backpressure is not actually supported. We're dealing with it by // using an buffer with a provided limit that can be very large. @@ -311,11 +305,11 @@ describe('ClientRPC', () => { Promise.all([ (async () => { await startReading.p; - loudLogger.info('Starting consumption'); + logger.info('Starting consumption'); for await (const _ of streamPair.readable) { // No touch, only consume } - loudLogger.info('Reads ended'); + logger.info('Reads ended'); })(), (async () => { await streamPair.writable.close(); @@ -329,7 +323,7 @@ describe('ClientRPC', () => { host, // Setting a really low buffer limit maxReadBufferBytes: 1500, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -341,24 +335,22 @@ describe('ClientRPC', () => { const websocket = await clientClient.startConnection(); const message = Buffer.alloc(1_000, 0xf0); const writer = websocket.writable.getWriter(); - loudLogger.info('Starting writes'); + logger.info('Starting writes'); await expect(async () => { for (let i = 0; i < 100; i++) { await writer.write(message); } }).rejects.toThrow(); startReading.resolveP(); - loudLogger.info('writes ended'); + logger.info('writes ended'); await expect(async () => { for await (const _ of websocket.readable) { // No touch, only consume } }).rejects.toThrow(); await handlingProm.p; - loudLogger.info('ending'); + logger.info('ending'); }); - // To fully test these two I need to start the client or server in a separate process and kill that process. - // These require the ping/pong connection watchdogs to be implemented. test('client ends connection abruptly', async () => { const streamPairProm = promise>(); @@ -370,7 +362,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); @@ -496,7 +488,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -535,7 +527,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -572,7 +564,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -600,7 +592,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -636,12 +628,12 @@ describe('ClientRPC', () => { void streamPair.readable .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + .finally(() => logger.info('STREAM HANDLING ENDED')); }, basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); const getResProm = promise(); @@ -672,7 +664,7 @@ describe('ClientRPC', () => { tlsConfig, host, pingTimeout: 100, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -698,7 +690,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -737,12 +729,12 @@ describe('ClientRPC', () => { void streamPair.readable .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + .finally(() => logger.info('STREAM HANDLING ENDED')); }, basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -775,12 +767,12 @@ describe('ClientRPC', () => { void streamPair.readable .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + .finally(() => logger.info('STREAM HANDLING ENDED')); }, basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -805,12 +797,12 @@ describe('ClientRPC', () => { void streamPair.readable .pipeTo(streamPair.writable) .catch(() => {}) - .finally(() => loudLogger.info('STREAM HANDLING ENDED')); + .finally(() => logger.info('STREAM HANDLING ENDED')); }, basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ @@ -850,7 +842,7 @@ describe('ClientRPC', () => { basePath: dataDir, tlsConfig, host, - logger: loudLogger.getChild('server'), + logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); clientClient = await ClientClient.createClientClient({ diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 0ec84f1f2..cec1081c3 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -1,9 +1,7 @@ -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; +import type { ConnectionInfo, TLSConfig } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -16,8 +14,9 @@ import { AgentStatusHandler, } from '@/clientRPC/handlers/agentStatus'; import RPCClient from '@/RPC/RPCClient'; -import * as clientRPCUtils from '@/clientRPC/utils'; import * as nodesUtils from '@/nodes/utils'; +import ClientClient from '@/clientRPC/ClientClient'; +import ClientServer from '@/clientRPC/ClientServer'; import * as testsUtils from '../../utils'; describe('agentStatus', () => { @@ -25,15 +24,15 @@ describe('agentStatus', () => { new StreamHandler(), ]); const password = 'helloworld'; + const host = '127.0.0.1'; let dataDir: string; let db: DB; let keyRing: KeyRing; let taskManager: TaskManager; let certManager: CertManager; - let server: Server; - let wss: WebSocketServer; - const host = '127.0.0.1'; - let port: number; + let clientServer: ClientServer; + let clientClient: ClientClient; + let tlsConfig: TLSConfig; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -60,16 +59,11 @@ describe('agentStatus', () => { taskManager, logger, }); - const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, host); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { - wss?.close(); - server?.close(); + await clientServer?.stop(true); + await clientClient?.destroy(true); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -91,22 +85,25 @@ describe('agentStatus', () => { }, logger: logger.getChild('RPCServer'), }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('server'), - ); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => { + rpcServer.handleStream(streamPair, {} as ConnectionInfo); + }, + host, + tlsConfig, + logger, + }); + clientClient = await ClientClient.createClientClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + port: clientServer.port, + logger, + }); const rpcClient = await RPCClient.createRPCClient({ manifest: { agentStatus: agentStatusCaller, }, - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - host, - port, - logger.getChild('client'), - ); - }, + streamPairCreateCallback: async () => clientClient.startConnection(), logger: logger.getChild('RPCClient'), }); // Doing the test diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 1b592af3b..0bd84c4aa 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -1,9 +1,7 @@ -import type { Server } from 'https'; -import type { WebSocketServer } from 'ws'; +import type { ConnectionInfo, TLSConfig } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -20,6 +18,8 @@ import { Session, SessionManager } from '@/sessions'; import * as clientRPCUtils from '@/clientRPC/utils'; import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; import * as middlewareUtils from '@/RPC/middleware'; +import ClientServer from '@/clientRPC/ClientServer'; +import ClientClient from '@/clientRPC/ClientClient'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { @@ -27,6 +27,7 @@ describe('agentUnlock', () => { new StreamHandler(), ]); const password = 'helloworld'; + const host = '127.0.0.1'; let dataDir: string; let db: DB; let keyRing: KeyRing; @@ -34,9 +35,9 @@ describe('agentUnlock', () => { let certManager: CertManager; let session: Session; let sessionManager: SessionManager; - let server: Server; - let wss: WebSocketServer; - let port: number; + let clientClient: ClientClient; + let clientServer: ClientServer; + let tlsConfig: TLSConfig; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -73,16 +74,11 @@ describe('agentUnlock', () => { keyRing, logger, }); - const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); - server = createServer({ - cert: tlsConfig.certChainPem, - key: tlsConfig.keyPrivatePem, - }); - port = await clientRPCUtils.listen(server, '127.0.0.1'); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); }); afterEach(async () => { - wss?.close(); - server.close(); + await clientServer.stop(true); + await clientClient.destroy(true); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -103,22 +99,24 @@ describe('agentUnlock', () => { ), logger, }); - wss = clientRPCUtils.createClientServer( - server, - rpcServer, - logger.getChild('server'), - ); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair) => + rpcServer.handleStream(streamPair, {} as ConnectionInfo), + host, + tlsConfig, + logger, + }); + clientClient = await ClientClient.createClientClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + logger, + port: clientServer.port, + }); const rpcClient = await RPCClient.createRPCClient({ manifest: { agentUnlock: agentUnlockCaller, }, - streamPairCreateCallback: async () => { - return clientRPCUtils.startConnection( - '127.0.0.1', - port, - logger.getChild('client'), - ); - }, + streamPairCreateCallback: async () => clientClient.startConnection(), middleware: middlewareUtils.defaultClientMiddlewareWrapper( authMiddleware.authenticationMiddlewareClient(session), ), From d2eea49d209ecca8f66337653be6f800a250ec0f Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 24 Feb 2023 16:29:21 +1100 Subject: [PATCH 16/23] feat: stream handler callback provided with `ConnectionInfo` [ci skip] --- src/RPC/RPCServer.ts | 2 +- src/RPC/handlers.ts | 2 +- src/RPC/types.ts | 22 ++++++++++++++++++- src/clientRPC/ClientServer.ts | 20 +++++++++++++---- .../authenticationMiddleware.test.ts | 6 ++--- tests/clientRPC/handlers/agentStatus.test.ts | 6 ++--- tests/clientRPC/handlers/agentUnlock.test.ts | 6 ++--- 7 files changed, 48 insertions(+), 16 deletions(-) diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index d33d8f053..664811c2f 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -10,10 +10,10 @@ import type { RawHandlerImplementation, ServerHandlerImplementation, UnaryHandlerImplementation, + ConnectionInfo, } from './types'; import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue } from '../types'; -import type { ConnectionInfo } from '../network/types'; import type { RPCErrorEvent } from './utils'; import type { MiddlewareFactory } from './types'; import { ReadableStream } from 'stream/web'; diff --git a/src/RPC/handlers.ts b/src/RPC/handlers.ts index c738c74e8..3ca13ce5b 100644 --- a/src/RPC/handlers.ts +++ b/src/RPC/handlers.ts @@ -2,7 +2,7 @@ import type { JSONValue } from 'types'; import type { ContainerType } from 'RPC/types'; import type { ReadableStream } from 'stream/web'; import type { JsonRpcRequest } from 'RPC/types'; -import type { ConnectionInfo } from '../network/types'; +import type { ConnectionInfo } from './types'; import type { ContextCancellable } from '../contexts/types'; abstract class Handler< diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 4d96fcc0c..e34d0fabc 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -1,5 +1,4 @@ import type { JSONValue } from '../types'; -import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; import type { ReadableStream, ReadableWritablePair } from 'stream/web'; import type { Handler } from './handlers'; @@ -11,6 +10,8 @@ import type { ClientCaller, UnaryCaller, } from './callers'; +import type { NodeId } from '../nodes/types'; +import type { Certificate } from '../keys/types'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -108,6 +109,24 @@ type JsonRpcMessage = | JsonRpcRequest | JsonRpcResponse; +/** + * Proxy connection information + * @property remoteNodeId - NodeId of the remote connecting node + * @property remoteCertificates - Certificate chain of the remote connecting node + * @property localHost - Proxy host of the local connecting node + * @property localPort - Proxy port of the local connecting node + * @property remoteHost - Proxy host of the remote connecting node + * @property remotePort - Proxy port of the remote connecting node + */ +type ConnectionInfo = Partial<{ + remoteNodeId: NodeId; + remoteCertificates: Array; + localHost: string; + localPort: number; + remoteHost: string; + remotePort: number; +}>; + // Handler types type HandlerImplementation = ( input: I, @@ -218,6 +237,7 @@ export type { JsonRpcRequest, JsonRpcResponse, JsonRpcMessage, + ConnectionInfo, HandlerImplementation, RawHandlerImplementation, DuplexHandlerImplementation, diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index a241306f4..ee2c595e7 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -6,6 +6,7 @@ import type { import type { FileSystem, PromiseDeconstructed } from 'types'; import type { TLSConfig } from 'network/types'; import type { WebSocket } from 'uWebSockets.js'; +import type { ConnectionInfo } from '../RPC/types'; import { WritableStream, ReadableStream } from 'stream/web'; import path from 'path'; import os from 'os'; @@ -17,6 +18,7 @@ import { promise } from '../utils'; type ConnectionCallback = ( streamPair: ReadableWritablePair, + connectionInfo: ConnectionInfo, ) => void; type Context = { @@ -364,11 +366,21 @@ class ClientServer { backpressure?.resolveP(); }; logger.debug('Calling handler callback'); + // There is not nodeId or certs for the client, and we can't get the remote + // port from the `uWebsocket` library. + const connectionInfo: ConnectionInfo = { + remoteHost: Buffer.from(ws.getRemoteAddressAsText()).toString(), + localHost: this.host, + localPort: this.port, + }; try { - this.connectionCallback({ - readable: readableStream, - writable: writableStream, - }); + this.connectionCallback( + { + readable: readableStream, + writable: writableStream, + }, + connectionInfo, + ); } catch (e) { context.close(ws, 0, Buffer.from('')); logger.error(e.toString()); diff --git a/tests/clientRPC/authenticationMiddleware.test.ts b/tests/clientRPC/authenticationMiddleware.test.ts index e8285a86d..59de3addd 100644 --- a/tests/clientRPC/authenticationMiddleware.test.ts +++ b/tests/clientRPC/authenticationMiddleware.test.ts @@ -1,5 +1,5 @@ import type { RPCRequestParams, RPCResponseResult } from '@/clientRPC/types'; -import type { ConnectionInfo, TLSConfig } from '../../src/network/types'; +import type { TLSConfig } from '../../src/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -108,8 +108,8 @@ describe('agentUnlock', () => { logger, }); clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - rpcServer.handleStream(streamPair, {} as ConnectionInfo); + connectionCallback: (streamPair, connectionInfo) => { + rpcServer.handleStream(streamPair, connectionInfo); }, host, tlsConfig, diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index cec1081c3..7b0f6a4d4 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -1,4 +1,4 @@ -import type { ConnectionInfo, TLSConfig } from '@/network/types'; +import type { TLSConfig } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -86,8 +86,8 @@ describe('agentStatus', () => { logger: logger.getChild('RPCServer'), }); clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => { - rpcServer.handleStream(streamPair, {} as ConnectionInfo); + connectionCallback: (streamPair, connectionInfo) => { + rpcServer.handleStream(streamPair, connectionInfo); }, host, tlsConfig, diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 0bd84c4aa..695cf9a0b 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -1,4 +1,4 @@ -import type { ConnectionInfo, TLSConfig } from '@/network/types'; +import type { TLSConfig } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -100,8 +100,8 @@ describe('agentUnlock', () => { logger, }); clientServer = await ClientServer.createClientServer({ - connectionCallback: (streamPair) => - rpcServer.handleStream(streamPair, {} as ConnectionInfo), + connectionCallback: (streamPair, connectionInfo) => + rpcServer.handleStream(streamPair, connectionInfo), host, tlsConfig, logger, From 00a2094a680884cdd023afe8fad04e2b0e0d638a Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 27 Feb 2023 15:45:21 +1100 Subject: [PATCH 17/23] fix: changed names of `ClientClient` and `ClientServer` to `WebSocketClient` and `WebSocketServer` respectively [ci skip] --- .../{ClientClient.ts => WebSocketClient.ts} | 10 +-- .../{ClientServer.ts => WebSocketServer.ts} | 8 +- .../{clientRPC.test.ts => WebSocket.test.ts} | 86 +++++++++---------- .../authenticationMiddleware.test.ts | 12 +-- tests/clientRPC/handlers/agentStatus.test.ts | 12 +-- tests/clientRPC/handlers/agentUnlock.test.ts | 12 +-- tests/clientRPC/testClient.ts | 4 +- tests/clientRPC/testServer.ts | 4 +- 8 files changed, 74 insertions(+), 74 deletions(-) rename src/clientRPC/{ClientClient.ts => WebSocketClient.ts} (98%) rename src/clientRPC/{ClientServer.ts => WebSocketServer.ts} (98%) rename tests/clientRPC/{clientRPC.test.ts => WebSocket.test.ts} (91%) diff --git a/src/clientRPC/ClientClient.ts b/src/clientRPC/WebSocketClient.ts similarity index 98% rename from src/clientRPC/ClientClient.ts rename to src/clientRPC/WebSocketClient.ts index 57e279e78..27415965c 100644 --- a/src/clientRPC/ClientClient.ts +++ b/src/clientRPC/WebSocketClient.ts @@ -14,10 +14,10 @@ import { promise } from '../utils'; const timeoutSymbol = Symbol('TimedOutSymbol'); -interface ClientClient extends createDestroy.CreateDestroy {} +interface WebSocketClient extends createDestroy.CreateDestroy {} @createDestroy.CreateDestroy() -class ClientClient { - static async createClientClient({ +class WebSocketClient { + static async createWebSocketClient({ host, port, expectedNodeIds, @@ -35,7 +35,7 @@ class ClientClient { pingTimeout?: number; maxReadableStreamBytes?: number; logger?: Logger; - }): Promise { + }): Promise { logger.info(`Creating ${this.name}`); const clientClient = new this( logger, @@ -342,4 +342,4 @@ class ClientClient { } } -export default ClientClient; +export default WebSocketClient; diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/WebSocketServer.ts similarity index 98% rename from src/clientRPC/ClientServer.ts rename to src/clientRPC/WebSocketServer.ts index ee2c595e7..0efa3676e 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/WebSocketServer.ts @@ -34,10 +34,10 @@ type Context = { writeBackpressure: boolean; }; -interface ClientServer extends startStop.StartStop {} +interface WebSocketServer extends startStop.StartStop {} @startStop.StartStop() -class ClientServer { - static async createClientServer({ +class WebSocketServer { + static async createWebSocketServer({ connectionCallback, tlsConfig, basePath, @@ -388,4 +388,4 @@ class ClientServer { } } -export default ClientServer; +export default WebSocketServer; diff --git a/tests/clientRPC/clientRPC.test.ts b/tests/clientRPC/WebSocket.test.ts similarity index 91% rename from tests/clientRPC/clientRPC.test.ts rename to tests/clientRPC/WebSocket.test.ts index 6037ba7bc..34e6ce003 100644 --- a/tests/clientRPC/clientRPC.test.ts +++ b/tests/clientRPC/WebSocket.test.ts @@ -11,9 +11,9 @@ import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; import { KeyRing } from '@/keys/index'; -import ClientServer from '@/clientRPC/ClientServer'; +import WebSocketServer from '@/clientRPC/WebSocketServer'; import { promise } from '@/utils'; -import ClientClient from '@/clientRPC/ClientClient'; +import WebSocketClient from '@/clientRPC/WebSocketClient'; import * as keysUtils from '@/keys/utils'; import * as networkErrors from '@/network/errors'; import * as nodesUtils from '@/nodes/utils'; @@ -22,7 +22,7 @@ import * as testsUtils from '../utils'; // This file tests both the client and server together. They're too interlinked // to be separate. -describe('ClientRPC', () => { +describe('WebSocket', () => { const logger = new Logger('websocket test', LogLevel.WARN, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, @@ -32,8 +32,8 @@ describe('ClientRPC', () => { let keyRing: KeyRing; let tlsConfig: TLSConfig; const host = '127.0.0.2'; - let clientServer: ClientServer; - let clientClient: ClientClient; + let clientServer: WebSocketServer; + let clientClient: WebSocketClient; const messagesArb = fc.array( fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), @@ -81,7 +81,7 @@ describe('ClientRPC', () => { // These tests are share between client and server test('makes a connection', async () => { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -95,7 +95,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -116,7 +116,7 @@ describe('ClientRPC', () => { logger.info('ending'); }); test('makes a connection over IPv6', async () => { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -130,7 +130,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host: '::1', port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -151,7 +151,7 @@ describe('ClientRPC', () => { logger.info('ending'); }); test('Handles a connection and closes before message', async () => { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -165,7 +165,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -182,7 +182,7 @@ describe('ClientRPC', () => { [streamsArb], async (streamsData) => { try { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -196,7 +196,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -231,7 +231,7 @@ describe('ClientRPC', () => { let context: { writeBackpressure: boolean } | undefined; const backpressure = promise(); const resumeWriting = promise(); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void Promise.allSettled([ @@ -275,7 +275,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -299,7 +299,7 @@ describe('ClientRPC', () => { test('Exceeding readable buffer limit causes error', async () => { const startReading = promise(); const handlingProm = promise(); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); Promise.all([ @@ -326,7 +326,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -354,7 +354,7 @@ describe('ClientRPC', () => { test('client ends connection abruptly', async () => { const streamPairProm = promise>(); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); @@ -436,7 +436,7 @@ describe('ClientRPC', () => { testProcess.once('exit', () => exitedProm.resolveP()); logger.info(`Server started on port ${await startedProm.p}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: await startedProm.p, expectedNodeIds: [keyRing.getNodeId()], @@ -465,13 +465,13 @@ describe('ClientRPC', () => { }); // These describe blocks contains tests specific to either the client or server - describe('ClientServer', () => { + describe('WebSocketServer', () => { testProp( 'allows half closed writable closes first', [messagesArb, messagesArb], async (messages1, messages2) => { try { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void (async () => { @@ -491,7 +491,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -510,7 +510,7 @@ describe('ClientRPC', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void (async () => { @@ -530,7 +530,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -549,7 +549,7 @@ describe('ClientRPC', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void (async () => { @@ -567,7 +567,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -584,7 +584,7 @@ describe('ClientRPC', () => { test('Destroying ClientServer stops all connections', async () => { const streamPairProm = promise>(); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); @@ -595,7 +595,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -622,7 +622,7 @@ describe('ClientRPC', () => { logger.info('ending'); }); test('Server rejects normal HTTPS requests', async () => { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -655,7 +655,7 @@ describe('ClientRPC', () => { expect(res.headers['upgrade']).toBe('websocket'); }); test('ping timeout', async () => { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (_) => { logger.info('inside callback'); // Hang connection @@ -667,7 +667,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -678,11 +678,11 @@ describe('ClientRPC', () => { logger.info('ending'); }); }); - describe('ClientClient', () => { + describe('WebSocketClient', () => { test('Destroying ClientClient stops all connections', async () => { const streamPairProm = promise>(); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); @@ -693,7 +693,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], @@ -723,7 +723,7 @@ describe('ClientRPC', () => { }); test('Authentication rejects bad server certificate', async () => { const invalidNodeId = testNodeUtils.generateRandomNodeId(); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -737,7 +737,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [invalidNodeId], @@ -761,7 +761,7 @@ describe('ClientRPC', () => { ]; const tlsConfig = await testsUtils.createTLSConfigWithChain(keyPairs); const nodeId = keysUtils.publicKeyToNodeId(keyPairs[1].publicKey); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -775,7 +775,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [nodeId], @@ -791,7 +791,7 @@ describe('ClientRPC', () => { }); test('Authenticates with multiple expected nodes', async () => { const alternativeNodeId = testNodeUtils.generateRandomNodeId(); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -805,7 +805,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], @@ -818,7 +818,7 @@ describe('ClientRPC', () => { logger.info('ending'); }); test('Connection times out', async () => { - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: 12345, expectedNodeIds: [keyRing.getNodeId()], @@ -834,7 +834,7 @@ describe('ClientRPC', () => { logger.info('ending'); }); test('ping timeout', async () => { - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (_) => { logger.info('inside callback'); // Hang connection @@ -845,7 +845,7 @@ describe('ClientRPC', () => { logger: logger.getChild('server'), }); logger.info(`Server started on port ${clientServer.port}`); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ host, port: clientServer.port, expectedNodeIds: [keyRing.getNodeId()], diff --git a/tests/clientRPC/authenticationMiddleware.test.ts b/tests/clientRPC/authenticationMiddleware.test.ts index 59de3addd..685652e4c 100644 --- a/tests/clientRPC/authenticationMiddleware.test.ts +++ b/tests/clientRPC/authenticationMiddleware.test.ts @@ -17,9 +17,9 @@ import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; import { UnaryCaller } from '@/RPC/callers'; import { UnaryHandler } from '@/RPC/handlers'; import * as middlewareUtils from '@/RPC/middleware'; +import WebSocketServer from '@/clientRPC/WebSocketServer'; +import WebSocketClient from '@/clientRPC/WebSocketClient'; import * as testsUtils from '../utils'; -import ClientServer from '../../src/clientRPC/ClientServer'; -import ClientClient from '../../src/clientRPC/ClientClient'; describe('agentUnlock', () => { const logger = new Logger('agentUnlock test', LogLevel.WARN, [ @@ -34,8 +34,8 @@ describe('agentUnlock', () => { let certManager: CertManager; let session: Session; let sessionManager: SessionManager; - let clientServer: ClientServer; - let clientClient: ClientClient; + let clientServer: WebSocketServer; + let clientClient: WebSocketClient; let tlsConfig: TLSConfig; beforeEach(async () => { @@ -107,7 +107,7 @@ describe('agentUnlock', () => { ), logger, }); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair, connectionInfo) => { rpcServer.handleStream(streamPair, connectionInfo); }, @@ -115,7 +115,7 @@ describe('agentUnlock', () => { tlsConfig, logger, }); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ expectedNodeIds: [keyRing.getNodeId()], host, port: clientServer.port, diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 7b0f6a4d4..1f4943bb6 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -15,8 +15,8 @@ import { } from '@/clientRPC/handlers/agentStatus'; import RPCClient from '@/RPC/RPCClient'; import * as nodesUtils from '@/nodes/utils'; -import ClientClient from '@/clientRPC/ClientClient'; -import ClientServer from '@/clientRPC/ClientServer'; +import WebSocketClient from '@/clientRPC/WebSocketClient'; +import WebSocketServer from '@/clientRPC/WebSocketServer'; import * as testsUtils from '../../utils'; describe('agentStatus', () => { @@ -30,8 +30,8 @@ describe('agentStatus', () => { let keyRing: KeyRing; let taskManager: TaskManager; let certManager: CertManager; - let clientServer: ClientServer; - let clientClient: ClientClient; + let clientServer: WebSocketServer; + let clientClient: WebSocketClient; let tlsConfig: TLSConfig; beforeEach(async () => { @@ -85,7 +85,7 @@ describe('agentStatus', () => { }, logger: logger.getChild('RPCServer'), }); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair, connectionInfo) => { rpcServer.handleStream(streamPair, connectionInfo); }, @@ -93,7 +93,7 @@ describe('agentStatus', () => { tlsConfig, logger, }); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ expectedNodeIds: [keyRing.getNodeId()], host, port: clientServer.port, diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 695cf9a0b..c927f5cac 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -18,8 +18,8 @@ import { Session, SessionManager } from '@/sessions'; import * as clientRPCUtils from '@/clientRPC/utils'; import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; import * as middlewareUtils from '@/RPC/middleware'; -import ClientServer from '@/clientRPC/ClientServer'; -import ClientClient from '@/clientRPC/ClientClient'; +import WebSocketServer from '@/clientRPC/WebSocketServer'; +import WebSocketClient from '@/clientRPC/WebSocketClient'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { @@ -35,8 +35,8 @@ describe('agentUnlock', () => { let certManager: CertManager; let session: Session; let sessionManager: SessionManager; - let clientClient: ClientClient; - let clientServer: ClientServer; + let clientClient: WebSocketClient; + let clientServer: WebSocketServer; let tlsConfig: TLSConfig; beforeEach(async () => { @@ -99,14 +99,14 @@ describe('agentUnlock', () => { ), logger, }); - clientServer = await ClientServer.createClientServer({ + clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair, connectionInfo) => rpcServer.handleStream(streamPair, connectionInfo), host, tlsConfig, logger, }); - clientClient = await ClientClient.createClientClient({ + clientClient = await WebSocketClient.createWebSocketClient({ expectedNodeIds: [keyRing.getNodeId()], host, logger, diff --git a/tests/clientRPC/testClient.ts b/tests/clientRPC/testClient.ts index 28bf498e9..bfc33232a 100644 --- a/tests/clientRPC/testClient.ts +++ b/tests/clientRPC/testClient.ts @@ -6,14 +6,14 @@ * @module */ import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import ClientClient from '@/clientRPC/ClientClient'; +import WebSocketClient from '@/clientRPC/WebSocketClient'; import * as nodesUtils from '@/nodes/utils'; async function main() { const logger = new Logger('websocket test', LogLevel.WARN, [ new StreamHandler(), ]); - const clientClient = await ClientClient.createClientClient({ + const clientClient = await WebSocketClient.createWebSocketClient({ expectedNodeIds: [nodesUtils.decodeNodeId(process.env.PK_TEST_NODE_ID!)!], host: process.env.PK_TEST_HOST ?? '127.0.0.1', port: parseInt(process.env.PK_TEST_PORT!), diff --git a/tests/clientRPC/testServer.ts b/tests/clientRPC/testServer.ts index 8a24c0bbd..5225b954c 100644 --- a/tests/clientRPC/testServer.ts +++ b/tests/clientRPC/testServer.ts @@ -8,7 +8,7 @@ import type { CertificatePEMChain, PrivateKeyPEM } from '@/keys/types'; import type { TLSConfig } from '@/network/types'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import ClientServer from '@/clientRPC/ClientServer'; +import WebSocketServer from '@/clientRPC/WebSocketServer'; async function main() { const logger = new Logger('websocket test', LogLevel.WARN, [ @@ -18,7 +18,7 @@ async function main() { keyPrivatePem: process.env.PK_TEST_KEY_PRIVATE_PEM as PrivateKeyPEM, certChainPem: process.env.PK_TEST_CERT_CHAIN_PEM as CertificatePEMChain, }; - const clientServer = await ClientServer.createClientServer({ + const clientServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (_) => { // Ignore streams and hang connections }, From 0f9cbcd4441fefd09f64adcb27d29aef587af7a6 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 27 Feb 2023 16:34:50 +1100 Subject: [PATCH 18/23] fix: `WebSocketServer` using protected arrow functions for the `uWebsocket` handlers [ci skip] --- src/clientRPC/WebSocketServer.ts | 112 +++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 35 deletions(-) diff --git a/src/clientRPC/WebSocketServer.ts b/src/clientRPC/WebSocketServer.ts index 0efa3676e..527d07c42 100644 --- a/src/clientRPC/WebSocketServer.ts +++ b/src/clientRPC/WebSocketServer.ts @@ -5,7 +5,12 @@ import type { } from 'stream/web'; import type { FileSystem, PromiseDeconstructed } from 'types'; import type { TLSConfig } from 'network/types'; -import type { WebSocket } from 'uWebSockets.js'; +import type { + HttpRequest, + HttpResponse, + us_socket_context_t, + WebSocket, +} from 'uWebSockets.js'; import type { ConnectionInfo } from '../RPC/types'; import { WritableStream, ReadableStream } from 'stream/web'; import path from 'path'; @@ -88,6 +93,7 @@ class WebSocketServer { protected connectionCallback: ConnectionCallback; protected activeSockets: Set> = new Set(); protected waitForActive: PromiseDeconstructed | null = null; + protected connectionIndex: number = 0; /** * @@ -122,7 +128,6 @@ class WebSocketServer { }): Promise { this.logger.info(`Starting ${this.constructor.name}`); this.connectionCallback = connectionCallback; - let count = 0; const tmpDir = await this.fs.promises.mkdtemp( path.join(basePath, 'polykey-'), ); @@ -140,37 +145,14 @@ class WebSocketServer { this.server.ws('/*', { sendPingsAutomatically: true, idleTimeout: this.idleTimeout, - upgrade: (res, req, context) => { - const logger = this.logger.getChild(`Connection ${count}`); - res.upgrade>( - { - logger, - }, - req.getHeader('sec-websocket-key'), - req.getHeader('sec-websocket-protocol'), - req.getHeader('sec-websocket-extensions'), - context, - ); - count += 1; - }, - open: (ws: WebSocket) => { - if (this.waitForActive == null) this.waitForActive = promise(); - this.activeSockets.add(ws); - // Set up streams and context - this.handleOpen(ws); - }, - // TODO: could this take an async and apply backpressure implicitly? - message: async (ws: WebSocket, message, isBinary) => { - ws.getUserData().message(ws, message, isBinary); - }, - close: (ws, code, message) => { - this.activeSockets.delete(ws); - if (this.activeSockets.size === 0) this.waitForActive?.resolveP(); - ws.getUserData().close(ws, code, message); - }, - drain: (ws) => { - ws.getUserData().drain(ws); - }, + upgrade: this.upgrade, + open: this.open, + message: this.message, + close: this.close, + drain: this.drain, + pong: this.pong, + // Ping uses default behaviour. + // We don't use subscriptions. }); this.server.any('/*', (res, _) => { // Reject normal requests with an upgrade code @@ -223,7 +205,37 @@ class WebSocketServer { return uWebsocket.us_socket_local_port(this.listenSocket); } - protected handleOpen(ws: WebSocket) { + /** + * Applies default upgrade behaviour and creates a UserData object we can + * mutate for the Context + */ + protected upgrade = ( + res: HttpResponse, + req: HttpRequest, + context: us_socket_context_t, + ) => { + const logger = this.logger.getChild(`Connection ${this.connectionIndex}`); + res.upgrade>( + { + logger, + }, + req.getHeader('sec-websocket-key'), + req.getHeader('sec-websocket-protocol'), + req.getHeader('sec-websocket-extensions'), + context, + ); + this.connectionIndex += 1; + }; + + /** + * Handles the creation of the `ReadableWritablePair` and provides it to the + * StreamPair handler. + */ + protected open = (ws: WebSocket) => { + if (this.waitForActive == null) this.waitForActive = promise(); + // Adding socket to the active sockets map + this.activeSockets.add(ws); + const context = ws.getUserData(); const logger = context.logger; logger.info('WS opened'); @@ -385,7 +397,37 @@ class WebSocketServer { context.close(ws, 0, Buffer.from('')); logger.error(e.toString()); } - } + }; + + /** + * Routes incoming messages to each stream using the `Context` message + * callback. + */ + protected message = ( + ws: WebSocket, + message: ArrayBuffer, + isBinary: boolean, + ) => { + ws.getUserData().message(ws, message, isBinary); + }; + + protected drain = (ws: WebSocket) => { + ws.getUserData().drain(ws); + }; + + protected close = ( + ws: WebSocket, + code: number, + message: ArrayBuffer, + ) => { + this.activeSockets.delete(ws); + if (this.activeSockets.size === 0) this.waitForActive?.resolveP(); + ws.getUserData().close(ws, code, message); + }; + + protected pong = (ws: WebSocket, message: ArrayBuffer) => { + ws.getUserData().pong(ws, message); + }; } export default WebSocketServer; From 4a4427a6a8e44f8b8aea64028db3556f74fec9cd Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 27 Feb 2023 19:43:51 +1100 Subject: [PATCH 19/23] feat: `WebSocketClient` now extends the `WebSocketStream` class for it's streams [ci skip] --- src/clientRPC/WebSocketClient.ts | 166 +++++++++++++++++------------- src/clientRPC/WebSocketStream.ts | 130 +++++++++++++++++++++++ tests/clientRPC/WebSocket.test.ts | 12 +-- 3 files changed, 227 insertions(+), 81 deletions(-) create mode 100644 src/clientRPC/WebSocketStream.ts diff --git a/src/clientRPC/WebSocketClient.ts b/src/clientRPC/WebSocketClient.ts index 27415965c..79eaa1680 100644 --- a/src/clientRPC/WebSocketClient.ts +++ b/src/clientRPC/WebSocketClient.ts @@ -1,13 +1,12 @@ -import type { ReadableWritablePair } from 'stream/web'; import type { TLSSocket } from 'tls'; import type { NodeId } from 'ids/index'; import { WritableStream, ReadableStream } from 'stream/web'; import { createDestroy } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import WebSocket from 'ws'; -import { PromiseCancellable } from '@matrixai/async-cancellable'; import { Timer } from '@matrixai/timer'; import { Validator } from 'ip-num'; +import WebSocketStream from './WebSocketStream'; import * as clientRpcUtils from './utils'; import * as clientRPCErrors from './errors'; import { promise } from '../utils'; @@ -52,8 +51,7 @@ class WebSocketClient { } protected host: string; - - protected activeConnections: Set> = new Set(); + protected activeConnections: Set = new Set(); constructor( protected logger: Logger, @@ -78,11 +76,11 @@ class WebSocketClient { this.logger.info(`Destroying ${this.constructor.name}`); if (force) { for (const activeConnection of this.activeConnections) { - activeConnection.cancel(); + activeConnection.end(); } } for (const activeConnection of this.activeConnections) { - await activeConnection; + await activeConnection.endedProm.catch(() => {}); // Ignore errors here } this.logger.info(`Destroyed ${this.constructor.name}`); } @@ -92,7 +90,7 @@ class WebSocketClient { timeoutTimer, }: { timeoutTimer?: Timer; - } = {}): Promise> { + } = {}): Promise { // Use provided timer let timer: Timer | undefined = timeoutTimer; // If no timer provided use provided default timeout @@ -108,22 +106,6 @@ class WebSocketClient { const ws = new WebSocket(address, { rejectUnauthorized: false, }); - // Creating logic for awaiting active connections and terminating them - const abortHandler = () => { - ws.terminate(); - }; - const abortController = new AbortController(); - const activeConnectionProm = new PromiseCancellable((resolve) => { - ws.once('close', () => { - abortController.signal.removeEventListener('abort', abortHandler); - resolve(); - }); - }, abortController); - abortController.signal.addEventListener('abort', abortHandler); - this.activeConnections.add(activeConnectionProm); - activeConnectionProm.finally(() => - this.activeConnections.delete(activeConnectionProm), - ); // Handle connection failure const openErrorHandler = (e) => { connectProm.rejectP( @@ -148,7 +130,10 @@ class WebSocketClient { this.logger.info('starting connection'); connectProm.resolveP(); }); - // TODO: Race with a connection timeout here + const earlyCloseProm = promise(); + ws.once('close', () => { + earlyCloseProm.resolveP(); + }); // There are 3 resolve conditions here. // 1. Connection established and authenticated // 2. connection error or authentication failure @@ -171,17 +156,45 @@ class WebSocketClient { ws.terminate(); // Ensure the connection is removed from the active connection set before // returning. - await activeConnectionProm; + await earlyCloseProm.p; throw e; } // Cleaning up connection error ws.removeEventListener('error', openErrorHandler); - let readableClosed = false; - let writableClosed = false; - const readableLogger = this.logger.getChild('readable'); - const writableLogger = this.logger.getChild('writable'); - const readableStream = new ReadableStream( + // Constructing the `ReadableWritablePair`, the lifecycle is handed off to + // the webSocketStream at this point. + const webSocketStreamClient = new WebSocketStreamClientInternal( + ws, + this.maxReadableStreamBytes, + this.pingInterval, + this.pingTimeout, + this.logger, + ); + // Setting up activeStream map lifecycle + this.activeConnections.add(webSocketStreamClient); + void webSocketStreamClient.endedProm + .catch(() => {}) // Ignore errors + .finally(() => { + this.activeConnections.delete(webSocketStreamClient); + }); + return webSocketStreamClient; + } +} + +// This is the internal implementation of the client's stream pair. +class WebSocketStreamClientInternal extends WebSocketStream { + constructor( + protected ws: WebSocket, + maxReadableStreamBytes: number, + pingInterval: number, + pingTimeout: number, + logger: Logger, + ) { + super(); + const readableLogger = logger.getChild('readable'); + const writableLogger = logger.getChild('writable'); + this.readable = new ReadableStream( { start: (controller) => { readableLogger.info('Starting'); @@ -199,13 +212,13 @@ class WebSocketClient { if (message.length === 0) { readableLogger.debug('Null message received'); ws.removeListener('message', messageHandler); - if (!readableClosed) { - readableClosed = true; + if (!this.readableEnded) { + this.endReadable(); readableLogger.debug('Closing'); controller.close(); } - if (writableClosed) { - this.logger.debug('Closing socket'); + if (this.writableEnded) { + logger.debug('Closing socket'); ws.close(); } return; @@ -215,31 +228,30 @@ class WebSocketClient { readableLogger.debug('Registering socket message handler'); ws.on('message', messageHandler); ws.once('close', (code, reason) => { - this.logger.info('Socket closed'); + logger.info('Socket closed'); ws.removeListener('message', messageHandler); - if (!readableClosed) { - readableClosed = true; + if (!this.readableEnded) { readableLogger.debug( `Closed early, ${code}, ${reason.toString()}`, ); - controller.error( - new clientRPCErrors.ErrorClientConnectionEndedEarly(), - ); + const e = new clientRPCErrors.ErrorClientConnectionEndedEarly(); + this.endReadable(e); + controller.error(e); } }); ws.once('error', (e) => { - if (!readableClosed) { - readableClosed = true; + if (!this.readableEnded) { readableLogger.error(e); + this.endReadable(e); controller.error(e); } }); }, cancel: () => { readableLogger.debug('Cancelled'); - if (!readableClosed) { + if (!this.writableEnded) { readableLogger.debug('Closing socket'); - readableClosed = true; + this.endReadable(); ws.close(); } }, @@ -249,62 +261,63 @@ class WebSocketClient { }, }, { - highWaterMark: this.maxReadableStreamBytes, + highWaterMark: maxReadableStreamBytes, size: (chunk) => chunk?.byteLength ?? 0, }, ); - const writableStream = new WritableStream({ + this.writable = new WritableStream({ start: (controller) => { writableLogger.info('Starting'); ws.once('error', (e) => { - if (!writableClosed) { - writableClosed = true; + if (!this.writableEnded) { writableLogger.error(e); + this.endWritable(e); controller.error(e); } }); ws.once('close', (code, reason) => { - if (!writableClosed) { - writableClosed = true; + if (!this.writableEnded) { writableLogger.debug(`Closed early, ${code}, ${reason.toString()}`); - controller.error( - new clientRPCErrors.ErrorClientConnectionEndedEarly(), - ); + const e = new clientRPCErrors.ErrorClientConnectionEndedEarly(); + this.endWritable(e); + controller.error(e); } }); }, close: () => { writableLogger.debug('Closing, sending null message'); ws.send(Buffer.from([])); - writableClosed = true; - if (readableClosed) { + this.endWritable(); + if (this.readableEnded) { writableLogger.debug('Closing socket'); ws.close(); } }, abort: () => { writableLogger.debug('Aborted'); - writableClosed = true; - if (readableClosed) { + this.endWritable(Error('TMP ABORTED')); + if (this.readableEnded) { writableLogger.debug('Closing socket'); ws.close(); } }, write: async (chunk, controller) => { - if (writableClosed) return; + if (this.writableEnded) return; writableLogger.debug(`Sending ${chunk?.toString()}`); const wait = promise(); ws.send(chunk, (e) => { - if (e != null && !writableClosed) { - writableClosed = true; + if (e != null && !this.writableEnded) { // Opting to debug message here and not log an error, sending // failure is common if we send before the close event. writableLogger.debug('failed to send'); - controller.error( - new clientRPCErrors.ErrorClientConnectionEndedEarly(undefined, { + const err = new clientRPCErrors.ErrorClientConnectionEndedEarly( + undefined, + { cause: e, - }), + }, ); + this.endWritable(err); + controller.error(err); } wait.resolveP(); }); @@ -315,30 +328,35 @@ class WebSocketClient { // Setting up heartbeat const pingTimer = setInterval(() => { ws.ping(); - }, this.pingInterval); + }, pingInterval); const pingTimeoutTimer = setTimeout(() => { - this.logger.debug('Ping timed out'); + logger.debug('Ping timed out'); ws.close(4002, 'Timed out'); - }, this.pingTimeout); + }, pingTimeout); ws.on('ping', () => { - this.logger.debug('Received ping'); + logger.debug('Received ping'); ws.pong(); }); ws.on('pong', () => { - this.logger.debug('Received pong'); + logger.debug('Received pong'); pingTimeoutTimer.refresh(); }); - ws.once('close', () => { - this.logger.debug('Cleaning up timers'); + ws.once('close', (code, reason) => { + logger.debug('WebSocket closed'); + const err = + code !== 1000 + ? Error(`TMP WebSocket ended with code ${code}, ${reason.toString()}`) + : undefined; + this.endWebSocket(err); + logger.debug('Cleaning up timers'); // Clean up timers clearTimeout(pingTimer); clearTimeout(pingTimeoutTimer); }); + } - return { - readable: readableStream, - writable: writableStream, - }; + end(): void { + this.ws.terminate(); } } diff --git a/src/clientRPC/WebSocketStream.ts b/src/clientRPC/WebSocketStream.ts new file mode 100644 index 000000000..63a14c198 --- /dev/null +++ b/src/clientRPC/WebSocketStream.ts @@ -0,0 +1,130 @@ +import type { + ReadableStream, + ReadableWritablePair, + WritableStream, +} from 'stream/web'; +import { promise } from '../utils'; + +abstract class WebSocketStream + implements ReadableWritablePair +{ + public readable: ReadableStream; + public writable: WritableStream; + + protected readableEnded_ = false; + protected readableEndedProm_ = promise(); + protected writableEnded_ = false; + protected writableEndedProm_ = promise(); + protected webSocketEnded_ = false; + protected webSocketEndedProm_ = promise(); + protected endedProm_: Promise; + + protected constructor() { + // Sanitise promises so they don't result in unhandled rejections + this.readableEndedProm_.p.catch(() => {}); + this.writableEndedProm_.p.catch(() => {}); + this.webSocketEndedProm_.p.catch(() => {}); + // Creating the endedPromise + this.endedProm_ = Promise.allSettled([ + this.readableEndedProm_.p, + this.writableEndedProm_.p, + this.webSocketEndedProm_.p, + ]).then((result) => { + if ( + result[0].status === 'rejected' || + result[1].status === 'rejected' || + result[2].status === 'rejected' + ) { + // Throw a compound error + throw Error('TMP Stream failed', { cause: result }); + } + // Otherwise return nothing + }); + // Ignore errors if it's never used + this.endedProm_.catch(() => {}); + } + + get readableEnded() { + return this.readableEnded_; + } + + /** + * Resolves when the readable has ended and rejects with any errors. + */ + get readableEndedProm() { + return this.readableEndedProm_.p; + } + + get writableEnded() { + return this.writableEnded_; + } + + /** + * Resolves when the writable has ended and rejects with any errors. + */ + get writableEndedProm() { + return this.writableEndedProm_.p; + } + + get webSocketEnded() { + return this.webSocketEnded_; + } + + /** + * Resolves when the webSocket has ended and rejects with any errors. + */ + get webSocketEndedProm() { + return this.webSocketEndedProm_.p; + } + + get ended() { + return this.readableEnded_ && this.writableEnded_; + } + + /** + * Resolves when the stream has fully closed + */ + get endedProm(): Promise { + return this.endedProm_; + } + + /** + * Forces the active stream to end early + */ + abstract end(): void; + + /** + * Signals the end of the ReadableStream. to be used with the extended class + * to track the streams state. + */ + protected endReadable(e?: Error) { + if (this.readableEnded_) return; + this.readableEnded_ = true; + if (e == null) this.readableEndedProm_.resolveP(); + else this.readableEndedProm_.rejectP(e); + } + + /** + * Signals the end of the WritableStream. to be used with the extended class + * to track the streams state. + */ + protected endWritable(e?: Error) { + if (this.writableEnded_) return; + this.writableEnded_ = true; + if (e == null) this.writableEndedProm_.resolveP(); + else this.writableEndedProm_.rejectP(e); + } + + /** + * Signals the end of the WebSocket. to be used with the extended class + * to track the streams state. + */ + protected endWebSocket(e?: Error) { + if (this.webSocketEnded_) return; + this.webSocketEnded_ = true; + if (e == null) this.webSocketEndedProm_.resolveP(); + else this.webSocketEndedProm_.rejectP(e); + } +} + +export default WebSocketStream; diff --git a/tests/clientRPC/WebSocket.test.ts b/tests/clientRPC/WebSocket.test.ts index 34e6ce003..7cdf55d93 100644 --- a/tests/clientRPC/WebSocket.test.ts +++ b/tests/clientRPC/WebSocket.test.ts @@ -287,7 +287,7 @@ describe('WebSocket', () => { await backpressure.p; expect(context?.writeBackpressure).toBeTrue(); resumeWriting.resolveP(); - // Consume all of the back-pressured data + // Consume all the back-pressured data for await (const _ of websocket.readable) { // No touch, only consume } @@ -295,7 +295,7 @@ describe('WebSocket', () => { logger.info('ending'); }); // Readable backpressure is not actually supported. We're dealing with it by - // using an buffer with a provided limit that can be very large. + // using a buffer with a provided limit that can be very large. test('Exceeding readable buffer limit causes error', async () => { const startReading = promise(); const handlingProm = promise(); @@ -448,11 +448,8 @@ describe('WebSocket', () => { testProcess.kill('SIGTERM'); await exitedProm.p; - // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; - for (const activeConnection of activeConnections) { - await activeConnection; - } + // Waiting for connections to end + await clientClient.destroy(); // Checking client's response to connection dropping await expect(async () => { for await (const _ of websocket.readable) { @@ -749,6 +746,7 @@ describe('WebSocket', () => { // @ts-ignore: kidnap protected property const activeConnections = clientClient.activeConnections; expect(activeConnections.size).toBe(0); + await clientServer.stop(); logger.info('ending'); }); test('Authenticates with multiple certs in chain', async () => { From bf3627b34c026936e0f57b32c2f7ddf9a509b7d4 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 27 Feb 2023 20:39:39 +1100 Subject: [PATCH 20/23] feat: `WebSocketServer` now extends the `WebSocketStream` class for it's streams [ci skip] --- src/clientRPC/WebSocketClient.ts | 44 +++--- src/clientRPC/WebSocketServer.ts | 224 +++++++++++++++------------- src/clientRPC/WebSocketStream.ts | 6 +- tests/clientRPC/WebSocket.test.ts | 239 +++++++++++++++--------------- 4 files changed, 268 insertions(+), 245 deletions(-) diff --git a/src/clientRPC/WebSocketClient.ts b/src/clientRPC/WebSocketClient.ts index 79eaa1680..cc7be82c6 100644 --- a/src/clientRPC/WebSocketClient.ts +++ b/src/clientRPC/WebSocketClient.ts @@ -212,12 +212,12 @@ class WebSocketStreamClientInternal extends WebSocketStream { if (message.length === 0) { readableLogger.debug('Null message received'); ws.removeListener('message', messageHandler); - if (!this.readableEnded) { - this.endReadable(); + if (!this.readableEnded_) { + this.signalReadableEnd(); readableLogger.debug('Closing'); controller.close(); } - if (this.writableEnded) { + if (this.writableEnded_) { logger.debug('Closing socket'); ws.close(); } @@ -230,28 +230,28 @@ class WebSocketStreamClientInternal extends WebSocketStream { ws.once('close', (code, reason) => { logger.info('Socket closed'); ws.removeListener('message', messageHandler); - if (!this.readableEnded) { + if (!this.readableEnded_) { readableLogger.debug( `Closed early, ${code}, ${reason.toString()}`, ); const e = new clientRPCErrors.ErrorClientConnectionEndedEarly(); - this.endReadable(e); + this.signalReadableEnd(e); controller.error(e); } }); ws.once('error', (e) => { - if (!this.readableEnded) { + if (!this.readableEnded_) { readableLogger.error(e); - this.endReadable(e); + this.signalReadableEnd(e); controller.error(e); } }); }, cancel: () => { readableLogger.debug('Cancelled'); - if (!this.writableEnded) { + if (!this.writableEnded_) { readableLogger.debug('Closing socket'); - this.endReadable(); + this.signalReadableEnd(); ws.close(); } }, @@ -269,17 +269,17 @@ class WebSocketStreamClientInternal extends WebSocketStream { start: (controller) => { writableLogger.info('Starting'); ws.once('error', (e) => { - if (!this.writableEnded) { + if (!this.writableEnded_) { writableLogger.error(e); - this.endWritable(e); + this.signalWritableEnd(e); controller.error(e); } }); ws.once('close', (code, reason) => { - if (!this.writableEnded) { + if (!this.writableEnded_) { writableLogger.debug(`Closed early, ${code}, ${reason.toString()}`); const e = new clientRPCErrors.ErrorClientConnectionEndedEarly(); - this.endWritable(e); + this.signalWritableEnd(e); controller.error(e); } }); @@ -287,26 +287,26 @@ class WebSocketStreamClientInternal extends WebSocketStream { close: () => { writableLogger.debug('Closing, sending null message'); ws.send(Buffer.from([])); - this.endWritable(); - if (this.readableEnded) { + this.signalWritableEnd(); + if (this.readableEnded_) { writableLogger.debug('Closing socket'); ws.close(); } }, abort: () => { writableLogger.debug('Aborted'); - this.endWritable(Error('TMP ABORTED')); - if (this.readableEnded) { + this.signalWritableEnd(Error('TMP ABORTED')); + if (this.readableEnded_) { writableLogger.debug('Closing socket'); ws.close(); } }, write: async (chunk, controller) => { - if (this.writableEnded) return; + if (this.writableEnded_) return; writableLogger.debug(`Sending ${chunk?.toString()}`); const wait = promise(); ws.send(chunk, (e) => { - if (e != null && !this.writableEnded) { + if (e != null && !this.writableEnded_) { // Opting to debug message here and not log an error, sending // failure is common if we send before the close event. writableLogger.debug('failed to send'); @@ -316,7 +316,7 @@ class WebSocketStreamClientInternal extends WebSocketStream { cause: e, }, ); - this.endWritable(err); + this.signalWritableEnd(err); controller.error(err); } wait.resolveP(); @@ -347,7 +347,7 @@ class WebSocketStreamClientInternal extends WebSocketStream { code !== 1000 ? Error(`TMP WebSocket ended with code ${code}, ${reason.toString()}`) : undefined; - this.endWebSocket(err); + this.signalWebSocketEnd(err); logger.debug('Cleaning up timers'); // Clean up timers clearTimeout(pingTimer); @@ -356,7 +356,7 @@ class WebSocketStreamClientInternal extends WebSocketStream { } end(): void { - this.ws.terminate(); + this.ws.close(4001, 'TMP ENDING CONNECTION'); } } diff --git a/src/clientRPC/WebSocketServer.ts b/src/clientRPC/WebSocketServer.ts index 527d07c42..03f556fd7 100644 --- a/src/clientRPC/WebSocketServer.ts +++ b/src/clientRPC/WebSocketServer.ts @@ -18,6 +18,7 @@ import os from 'os'; import { startStop } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import uWebsocket from 'uWebSockets.js'; +import WebSocketStream from './WebSocketStream'; import * as clientRPCErrors from './errors'; import { promise } from '../utils'; @@ -36,7 +37,6 @@ type Context = { close: (ws: WebSocket, code: number, message: ArrayBuffer) => void; pong: (ws: WebSocket, message: ArrayBuffer) => void; logger: Logger; - writeBackpressure: boolean; }; interface WebSocketServer extends startStop.StartStop {} @@ -91,8 +91,7 @@ class WebSocketServer { protected listenSocket: uWebsocket.us_listen_socket; protected host: string; protected connectionCallback: ConnectionCallback; - protected activeSockets: Set> = new Set(); - protected waitForActive: PromiseDeconstructed | null = null; + protected activeSockets: Set = new Set(); protected connectionIndex: number = 0; /** @@ -192,12 +191,14 @@ class WebSocketServer { uWebsocket.us_listen_socket_close(this.listenSocket); // Shutting down active websockets if (force) { - for (const ws of this.activeSockets) { - ws.end(); + for (const webSocketStream of this.activeSockets) { + webSocketStream.end(); } } // Wait for all active websockets to close - await this.waitForActive?.p; + for (const webSocketStream of this.activeSockets) { + webSocketStream.endedProm.catch(() => {}); // Ignore errors + } this.logger.info(`Stopped ${this.constructor.name}`); } @@ -232,28 +233,91 @@ class WebSocketServer { * StreamPair handler. */ protected open = (ws: WebSocket) => { - if (this.waitForActive == null) this.waitForActive = promise(); + const webSocketStream = new WebSocketStreamServerInternal( + ws, + this.maxReadBufferBytes, + this.pingInterval, + this.pingTimeout, + ); // Adding socket to the active sockets map - this.activeSockets.add(ws); + this.activeSockets.add(webSocketStream); + webSocketStream.endedProm + .catch(() => {}) // Ignore errors here + .finally(() => { + this.activeSockets.delete(webSocketStream); + }); + // There is not nodeId or certs for the client, and we can't get the remote + // port from the `uWebsocket` library. + const connectionInfo: ConnectionInfo = { + remoteHost: Buffer.from(ws.getRemoteAddressAsText()).toString(), + localHost: this.host, + localPort: this.port, + }; + const context = ws.getUserData(); + context.logger.debug('Calling callback'); + try { + this.connectionCallback(webSocketStream, connectionInfo); + } catch (e) { + context.close(ws, 0, Buffer.from('')); + context.logger.error(e.toString()); + } + }; + + /** + * Routes incoming messages to each stream using the `Context` message + * callback. + */ + protected message = ( + ws: WebSocket, + message: ArrayBuffer, + isBinary: boolean, + ) => { + ws.getUserData().message(ws, message, isBinary); + }; + + protected drain = (ws: WebSocket) => { + ws.getUserData().drain(ws); + }; + + protected close = ( + ws: WebSocket, + code: number, + message: ArrayBuffer, + ) => { + ws.getUserData().close(ws, code, message); + }; + + protected pong = (ws: WebSocket, message: ArrayBuffer) => { + ws.getUserData().pong(ws, message); + }; +} + +class WebSocketStreamServerInternal extends WebSocketStream { + protected backPressure: PromiseDeconstructed | null = null; + protected writeBackpressure: boolean = false; + + constructor( + protected ws: WebSocket, + maxReadBufferBytes: number, + pingInterval: number, + pingTimeout: number, + ) { + super(); const context = ws.getUserData(); const logger = context.logger; logger.info('WS opened'); - let writableClosed = false; - let readableClosed = false; - let wsClosed = false; - let backpressure: PromiseDeconstructed | null = null; let writableController: WritableStreamDefaultController | undefined; let readableController: ReadableStreamController | undefined; const writableLogger = logger.getChild('Writable'); const readableLogger = logger.getChild('Readable'); // Setting up the writable stream - const writableStream = new WritableStream({ + this.writable = new WritableStream({ start: (controller) => { writableController = controller; }, write: async (chunk, controller) => { - await backpressure?.p; + await this.backPressure?.p; const writeResult = ws.send(chunk, true); switch (writeResult) { default: @@ -265,169 +329,127 @@ class WebSocketServer { case 0: writableLogger.info('Write backpressure'); // Signal backpressure - backpressure = promise(); - context.writeBackpressure = true; - backpressure.p.finally(() => { - context.writeBackpressure = false; + this.backPressure = promise(); + this.writeBackpressure = true; + this.backPressure.p.finally(() => { + this.writeBackpressure = false; }); break; case 1: // Success - writableLogger.debug(`Sending ${chunk.toString()}`); + writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); break; } }, close: () => { writableLogger.info('Closed, sending null message'); - if (!wsClosed) ws.send(Buffer.from([]), true); - writableClosed = true; - if (readableClosed && !wsClosed) { + if (!this.webSocketEnded_) ws.send(Buffer.from([]), true); + this.signalWritableEnd(); + if (this.readableEnded_ && !this.webSocketEnded_) { writableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); ws.end(); } }, abort: () => { writableLogger.info('Aborted'); - if (readableClosed && !wsClosed) { + if (this.readableEnded_ && !this.webSocketEnded_) { writableLogger.debug('Ending socket'); - ws.end(); + this.signalWebSocketEnd(Error('TMP ERROR ABORTED')); + ws.end(4001, 'ABORTED'); } }, }); // Setting up the readable stream - const readableStream = new ReadableStream( + this.readable = new ReadableStream( { start: (controller) => { readableController = controller; context.message = (ws, message, _) => { - readableLogger.debug(`Received ${message.toString()}`); + const messageBuffer = Buffer.from(message); + readableLogger.debug(`Received ${messageBuffer.toString()}`); if (message.byteLength === 0) { readableLogger.debug('Null message received'); - if (!readableClosed) { - readableClosed = true; + if (!this.readableEnded_) { readableLogger.debug('Closing'); + this.signalReadableEnd(); controller.close(); - if (writableClosed && !wsClosed) { + if (this.writableEnded_ && !this.webSocketEnded_) { readableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); ws.end(); } } return; } - controller.enqueue(Buffer.from(message)); + controller.enqueue(messageBuffer); if (controller.desiredSize != null && controller.desiredSize < 0) { readableLogger.error('Read stream buffer full'); - if (!wsClosed) ws.end(4001, 'Read stream buffer full'); - controller.error( - new clientRPCErrors.ErrorServerReadableBufferLimit(), - ); + const err = new clientRPCErrors.ErrorServerReadableBufferLimit(); + if (!this.webSocketEnded_) { + this.signalWebSocketEnd(err); + ws.end(4001, 'Read stream buffer full'); + } + controller.error(err); } }; }, cancel: () => { - readableClosed = true; - if (writableClosed && !wsClosed) { + this.signalReadableEnd(Error('TMP READABLE CANCELLED')); + if (this.writableEnded_ && !this.webSocketEnded_) { readableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); ws.end(); } }, }, { - highWaterMark: this.maxReadBufferBytes, + highWaterMark: maxReadBufferBytes, size: (chunk) => chunk?.byteLength ?? 0, }, ); const pingTimer = setInterval(() => { ws.ping(); - }, this.pingInterval); + }, pingInterval); const pingTimeoutTimer = setTimeout(() => { logger.debug('Ping timed out'); ws.end(); - }, this.pingTimeout); + }, pingTimeout); context.pong = () => { logger.debug('Received pong'); pingTimeoutTimer.refresh(); }; context.close = () => { logger.debug('Closing'); - wsClosed = true; + this.signalWebSocketEnd(); // Cleaning up timers logger.debug('Cleaning up timers'); clearTimeout(pingTimer); clearTimeout(pingTimeoutTimer); // Closing streams logger.debug('Cleaning streams'); - if (!readableClosed) { - readableClosed = true; + const err = new clientRPCErrors.ErrorServerConnectionEndedEarly(); + if (!this.readableEnded_) { readableLogger.debug('Closing'); - readableController?.error( - new clientRPCErrors.ErrorServerConnectionEndedEarly(), - ); + this.signalReadableEnd(err); + readableController?.error(err); } - if (!writableClosed) { - writableClosed = true; + if (!this.writableEnded_) { writableLogger.debug('Closing'); - writableController?.error( - new clientRPCErrors.ErrorServerConnectionEndedEarly(), - ); + this.signalWritableEnd(err); + writableController?.error(err); } }; context.drain = () => { logger.debug('Drained'); - backpressure?.resolveP(); - }; - logger.debug('Calling handler callback'); - // There is not nodeId or certs for the client, and we can't get the remote - // port from the `uWebsocket` library. - const connectionInfo: ConnectionInfo = { - remoteHost: Buffer.from(ws.getRemoteAddressAsText()).toString(), - localHost: this.host, - localPort: this.port, + this.backPressure?.resolveP(); }; - try { - this.connectionCallback( - { - readable: readableStream, - writable: writableStream, - }, - connectionInfo, - ); - } catch (e) { - context.close(ws, 0, Buffer.from('')); - logger.error(e.toString()); - } - }; - - /** - * Routes incoming messages to each stream using the `Context` message - * callback. - */ - protected message = ( - ws: WebSocket, - message: ArrayBuffer, - isBinary: boolean, - ) => { - ws.getUserData().message(ws, message, isBinary); - }; - - protected drain = (ws: WebSocket) => { - ws.getUserData().drain(ws); - }; - - protected close = ( - ws: WebSocket, - code: number, - message: ArrayBuffer, - ) => { - this.activeSockets.delete(ws); - if (this.activeSockets.size === 0) this.waitForActive?.resolveP(); - ws.getUserData().close(ws, code, message); - }; + } - protected pong = (ws: WebSocket, message: ArrayBuffer) => { - ws.getUserData().pong(ws, message); - }; + end(): void { + this.ws.end(4001, 'TMP ENDING CONNECTION'); + } } export default WebSocketServer; diff --git a/src/clientRPC/WebSocketStream.ts b/src/clientRPC/WebSocketStream.ts index 63a14c198..a91d63a81 100644 --- a/src/clientRPC/WebSocketStream.ts +++ b/src/clientRPC/WebSocketStream.ts @@ -97,7 +97,7 @@ abstract class WebSocketStream * Signals the end of the ReadableStream. to be used with the extended class * to track the streams state. */ - protected endReadable(e?: Error) { + protected signalReadableEnd(e?: Error) { if (this.readableEnded_) return; this.readableEnded_ = true; if (e == null) this.readableEndedProm_.resolveP(); @@ -108,7 +108,7 @@ abstract class WebSocketStream * Signals the end of the WritableStream. to be used with the extended class * to track the streams state. */ - protected endWritable(e?: Error) { + protected signalWritableEnd(e?: Error) { if (this.writableEnded_) return; this.writableEnded_ = true; if (e == null) this.writableEndedProm_.resolveP(); @@ -119,7 +119,7 @@ abstract class WebSocketStream * Signals the end of the WebSocket. to be used with the extended class * to track the streams state. */ - protected endWebSocket(e?: Error) { + protected signalWebSocketEnd(e?: Error) { if (this.webSocketEnded_) return; this.webSocketEnded_ = true; if (e == null) this.webSocketEndedProm_.resolveP(); diff --git a/tests/clientRPC/WebSocket.test.ts b/tests/clientRPC/WebSocket.test.ts index 7cdf55d93..8f276c4e9 100644 --- a/tests/clientRPC/WebSocket.test.ts +++ b/tests/clientRPC/WebSocket.test.ts @@ -1,8 +1,8 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; -import type { WebSocket } from 'uWebSockets.js'; import type { KeyPair } from '@/keys/types'; import type http from 'http'; +import type WebSocketStream from '@/clientRPC/WebSocketStream'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -32,8 +32,8 @@ describe('WebSocket', () => { let keyRing: KeyRing; let tlsConfig: TLSConfig; const host = '127.0.0.2'; - let clientServer: WebSocketServer; - let clientClient: WebSocketClient; + let webSocketServer: WebSocketServer; + let webSocketClient: WebSocketClient; const messagesArb = fc.array( fc.uint8Array({ minLength: 1 }).map((d) => Buffer.from(d)), @@ -73,15 +73,15 @@ describe('WebSocket', () => { }); afterEach(async () => { logger.info('AFTEREACH'); - await clientServer?.stop(true); - await clientClient?.destroy(true); + await webSocketServer?.stop(true); + await webSocketClient?.destroy(true); await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); }); // These tests are share between client and server test('makes a connection', async () => { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -94,14 +94,14 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); const writer = websocket.writable.getWriter(); const reader = websocket.readable.getReader(); @@ -116,7 +116,7 @@ describe('WebSocket', () => { logger.info('ending'); }); test('makes a connection over IPv6', async () => { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -129,14 +129,14 @@ describe('WebSocket', () => { host: '::1', logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host: '::1', - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); const writer = websocket.writable.getWriter(); const reader = websocket.readable.getReader(); @@ -151,7 +151,7 @@ describe('WebSocket', () => { logger.info('ending'); }); test('Handles a connection and closes before message', async () => { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -164,14 +164,14 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); await websocket.writable.close(); const reader = websocket.readable.getReader(); expect((await reader.read()).done).toBeTrue(); @@ -182,7 +182,7 @@ describe('WebSocket', () => { [streamsArb], async (streamsData) => { try { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -195,16 +195,16 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); const testStream = async (messages: Array) => { - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); const writer = websocket.writable.getWriter(); const reader = websocket.readable.getReader(); for (const message of messages) { @@ -223,15 +223,15 @@ describe('WebSocket', () => { logger.info('ending'); } finally { - await clientServer.stop(true); + await webSocketServer.stop(true); } }, ); test('reverse backpressure', async () => { - let context: { writeBackpressure: boolean } | undefined; const backpressure = promise(); const resumeWriting = promise(); - clientServer = await WebSocketServer.createWebSocketServer({ + let webSocketStream: WebSocketStream | null = null; + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void Promise.allSettled([ @@ -242,20 +242,19 @@ describe('WebSocket', () => { })(), (async () => { // Kidnap the context - let ws: WebSocket<{ writeBackpressure: boolean }> | null = null; // @ts-ignore: kidnap protected property - for (const websocket of clientServer.activeSockets.values()) { - ws = websocket; + for (const websocket of webSocketServer.activeSockets.values()) { + webSocketStream = websocket; } - if (ws == null) { + if (webSocketStream == null) { await streamPair.writable.close(); return; } - context = ws.getUserData(); // Write until backPressured const message = Buffer.alloc(128, 0xf0); const writer = streamPair.writable.getWriter(); - while (!context.writeBackpressure) { + // @ts-ignore: kidnap protected property + while (!webSocketStream.writeBackpressure) { await writer.write(message); } logger.info('BACK PRESSURED'); @@ -274,24 +273,26 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); await websocket.writable.close(); await backpressure.p; - expect(context?.writeBackpressure).toBeTrue(); + // @ts-ignore: kidnap protected property + expect(webSocketStream.writeBackpressure).toBeTrue(); resumeWriting.resolveP(); // Consume all the back-pressured data for await (const _ of websocket.readable) { // No touch, only consume } - expect(context?.writeBackpressure).toBeFalse(); + // @ts-ignore: kidnap protected property + expect(webSocketStream.writeBackpressure).toBeFalse(); logger.info('ending'); }); // Readable backpressure is not actually supported. We're dealing with it by @@ -299,7 +300,7 @@ describe('WebSocket', () => { test('Exceeding readable buffer limit causes error', async () => { const startReading = promise(); const handlingProm = promise(); - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); Promise.all([ @@ -325,14 +326,14 @@ describe('WebSocket', () => { maxReadBufferBytes: 1500, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); const message = Buffer.alloc(1_000, 0xf0); const writer = websocket.writable.getWriter(); logger.info('Starting writes'); @@ -354,7 +355,7 @@ describe('WebSocket', () => { test('client ends connection abruptly', async () => { const streamPairProm = promise>(); - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); @@ -364,7 +365,7 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); + logger.info(`Server started on port ${webSocketServer.port}`); const testProcess = await testsUtils.spawn( 'ts-node', @@ -376,7 +377,7 @@ describe('WebSocket', () => { { env: { PK_TEST_HOST: host, - PK_TEST_PORT: `${clientServer.port}`, + PK_TEST_PORT: `${webSocketServer.port}`, PK_TEST_NODE_ID: nodesUtils.encodeNodeId(keyRing.getNodeId()), }, }, @@ -436,20 +437,20 @@ describe('WebSocket', () => { testProcess.once('exit', () => exitedProm.resolveP()); logger.info(`Server started on port ${await startedProm.p}`); - clientClient = await WebSocketClient.createWebSocketClient({ + webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: await startedProm.p, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); // Killing the server testProcess.kill('SIGTERM'); await exitedProm.p; // Waiting for connections to end - await clientClient.destroy(); + await webSocketClient.destroy(); // Checking client's response to connection dropping await expect(async () => { for await (const _ of websocket.readable) { @@ -468,7 +469,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void (async () => { @@ -487,18 +488,18 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); logger.info('ending'); } finally { - await clientServer.stop(true); + await webSocketServer.stop(true); } }, ); @@ -507,7 +508,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void (async () => { @@ -526,18 +527,18 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); logger.info('ending'); } finally { - await clientServer.stop(true); + await webSocketServer.stop(true); } }, ); @@ -546,7 +547,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void (async () => { @@ -563,25 +564,25 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); logger.info('ending'); } finally { - await clientServer.stop(true); + await webSocketServer.stop(true); } }, ); test('Destroying ClientServer stops all connections', async () => { const streamPairProm = promise>(); - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); @@ -591,15 +592,15 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); - await clientServer.stop(true); + const websocket = await webSocketClient.startConnection(); + await webSocketServer.stop(true); const streamPair = await streamPairProm.p; // Everything should throw after websocket ends early await expect(async () => { @@ -619,7 +620,7 @@ describe('WebSocket', () => { logger.info('ending'); }); test('Server rejects normal HTTPS requests', async () => { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -632,10 +633,10 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); + logger.info(`Server started on port ${webSocketServer.port}`); const getResProm = promise(); https.get( - `https://${host}:${clientServer.port}/`, + `https://${host}:${webSocketServer.port}/`, { rejectUnauthorized: false }, getResProm.resolveP, ); @@ -652,7 +653,7 @@ describe('WebSocket', () => { expect(res.headers['upgrade']).toBe('websocket'); }); test('ping timeout', async () => { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (_) => { logger.info('inside callback'); // Hang connection @@ -663,15 +664,15 @@ describe('WebSocket', () => { pingTimeout: 100, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - await clientClient.startConnection(); - await clientClient.destroy(); + await webSocketClient.startConnection(); + await webSocketClient.destroy(); logger.info('ending'); }); }); @@ -679,7 +680,7 @@ describe('WebSocket', () => { test('Destroying ClientClient stops all connections', async () => { const streamPairProm = promise>(); - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); @@ -689,16 +690,16 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], logger: logger.getChild('clientClient'), }); - const websocket = await clientClient.startConnection(); + const websocket = await webSocketClient.startConnection(); // Destroying the client, force close connections - await clientClient.destroy(true); + await webSocketClient.destroy(true); const streamPair = await streamPairProm.p; // Everything should throw after websocket ends early await expect(async () => { @@ -715,12 +716,12 @@ describe('WebSocket', () => { const serverWritable = streamPair.writable.getWriter(); await expect(clientWritable.write(Buffer.from('test'))).rejects.toThrow(); await expect(serverWritable.write(Buffer.from('test'))).rejects.toThrow(); - await clientServer.stop(); + await webSocketServer.stop(); logger.info('ending'); }); test('Authentication rejects bad server certificate', async () => { const invalidNodeId = testNodeUtils.generateRandomNodeId(); - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -733,20 +734,20 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [invalidNodeId], logger: logger.getChild('clientClient'), }); - await expect(clientClient.startConnection()).rejects.toThrow( + await expect(webSocketClient.startConnection()).rejects.toThrow( networkErrors.ErrorCertChainUnclaimed, ); // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; + const activeConnections = webSocketClient.activeConnections; expect(activeConnections.size).toBe(0); - await clientServer.stop(); + await webSocketServer.stop(); logger.info('ending'); }); test('Authenticates with multiple certs in chain', async () => { @@ -759,7 +760,7 @@ describe('WebSocket', () => { ]; const tlsConfig = await testsUtils.createTLSConfigWithChain(keyPairs); const nodeId = keysUtils.publicKeyToNodeId(keyPairs[1].publicKey); - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -772,24 +773,24 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); - const connProm = clientClient.startConnection(); + const connProm = webSocketClient.startConnection(); await connProm; await expect(connProm).toResolve(); // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; + const activeConnections = webSocketClient.activeConnections; expect(activeConnections.size).toBe(1); logger.info('ending'); }); test('Authenticates with multiple expected nodes', async () => { const alternativeNodeId = testNodeUtils.generateRandomNodeId(); - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); void streamPair.readable @@ -802,37 +803,37 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], logger: logger.getChild('clientClient'), }); - await expect(clientClient.startConnection()).toResolve(); + await expect(webSocketClient.startConnection()).toResolve(); // @ts-ignore: kidnap protected property - const activeConnections = clientClient.activeConnections; + const activeConnections = webSocketClient.activeConnections; expect(activeConnections.size).toBe(1); logger.info('ending'); }); test('Connection times out', async () => { - clientClient = await WebSocketClient.createWebSocketClient({ + webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: 12345, expectedNodeIds: [keyRing.getNodeId()], connectionTimeout: 0, logger: logger.getChild('clientClient'), }); - await expect(clientClient.startConnection({})).rejects.toThrow(); + await expect(webSocketClient.startConnection({})).rejects.toThrow(); await expect( - clientClient.startConnection({ + webSocketClient.startConnection({ timeoutTimer: new Timer({ delay: 0 }), }), ).rejects.toThrow(); logger.info('ending'); }); test('ping timeout', async () => { - clientServer = await WebSocketServer.createWebSocketServer({ + webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (_) => { logger.info('inside callback'); // Hang connection @@ -842,16 +843,16 @@ describe('WebSocket', () => { host, logger: logger.getChild('server'), }); - logger.info(`Server started on port ${clientServer.port}`); - clientClient = await WebSocketClient.createWebSocketClient({ + logger.info(`Server started on port ${webSocketServer.port}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ host, - port: clientServer.port, + port: webSocketServer.port, expectedNodeIds: [keyRing.getNodeId()], pingTimeout: 100, logger: logger.getChild('clientClient'), }); - await clientClient.startConnection(); - await clientClient.destroy(); + await webSocketClient.startConnection(); + await webSocketClient.destroy(); logger.info('ending'); }); }); From 953efbcbb4c9010bf2d5af05b82c0f35406be40a Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 28 Feb 2023 12:14:45 +1100 Subject: [PATCH 21/23] fix: separating the `websockets` code to its own domain [ci skip] --- src/clientRPC/errors.ts | 68 ++------ src/clientRPC/utils.ts | 163 +----------------- .../WebSocketClient.ts | 0 .../WebSocketServer.ts | 0 .../WebSocketStream.ts | 0 src/websockets/errors.ts | 119 +++++++++++++ src/websockets/utils.ts | 147 ++++++++++++++++ .../authenticationMiddleware.test.ts | 4 +- tests/clientRPC/handlers/agentStatus.test.ts | 4 +- tests/clientRPC/handlers/agentUnlock.test.ts | 4 +- .../WebSocket.test.ts | 14 +- tests/{clientRPC => websockets}/testClient.ts | 2 +- tests/{clientRPC => websockets}/testServer.ts | 2 +- 13 files changed, 302 insertions(+), 225 deletions(-) rename src/{clientRPC => websockets}/WebSocketClient.ts (100%) rename src/{clientRPC => websockets}/WebSocketServer.ts (100%) rename src/{clientRPC => websockets}/WebSocketStream.ts (100%) create mode 100644 src/websockets/errors.ts create mode 100644 src/websockets/utils.ts rename tests/{clientRPC => websockets}/WebSocket.test.ts (98%) rename tests/{clientRPC => websockets}/testClient.ts (94%) rename tests/{clientRPC => websockets}/testServer.ts (95%) diff --git a/src/clientRPC/errors.ts b/src/clientRPC/errors.ts index a077575ed..030bad8d5 100644 --- a/src/clientRPC/errors.ts +++ b/src/clientRPC/errors.ts @@ -1,66 +1,28 @@ import { ErrorPolykey, sysexits } from '../errors'; -class ErrorClient extends ErrorPolykey {} +class ErrorRPC extends ErrorPolykey {} -class ErrorClientClient extends ErrorClient {} +class ErrorRPCClient extends ErrorRPC {} -class ErrorClientDestroyed extends ErrorClientClient { - static description = 'ClientClient has been destroyed'; - exitCode = sysexits.USAGE; -} - -class ErrorClientInvalidHost extends ErrorClientClient { - static description = 'Host must be a valid IPv4 or IPv6 address string'; - exitCode = sysexits.USAGE; -} - -class ErrorClientConnectionFailed extends ErrorClientClient { - static description = 'Failed to establish connection to server'; - exitCode = sysexits.UNAVAILABLE; -} - -class ErrorClientConnectionTimedOut extends ErrorClientClient { - static description = 'Connection timed out'; - exitCode = sysexits.UNAVAILABLE; -} - -class ErrorClientConnectionEndedEarly extends ErrorClientClient { - static description = 'Connection ended before stream ended'; - exitCode = sysexits.UNAVAILABLE; -} - -class ErrorClientServer extends ErrorClient {} - -class ErrorServerPortUnavailable extends ErrorClientServer { - static description = 'Failed to bind a free port'; - exitCode = sysexits.UNAVAILABLE; -} - -class ErrorServerSendFailed extends ErrorClientServer { - static description = 'Failed to send message'; - exitCode = sysexits.UNAVAILABLE; +class ErrorClientAuthMissing extends ErrorRPCClient { + static description = 'Authorisation metadata is required but missing'; + exitCode = sysexits.NOPERM; } -class ErrorServerReadableBufferLimit extends ErrorClientServer { - static description = 'Readable buffer is full, messages received too quickly'; +class ErrorClientAuthFormat extends ErrorRPCClient { + static description = 'Authorisation metadata has invalid format'; exitCode = sysexits.USAGE; } -class ErrorServerConnectionEndedEarly extends ErrorClientServer { - static description = 'Connection ended before stream ended'; - exitCode = sysexits.UNAVAILABLE; +class ErrorClientAuthDenied extends ErrorRPCClient { + static description = 'Authorisation metadata is incorrect or expired'; + exitCode = sysexits.NOPERM; } export { - ErrorClientClient, - ErrorClientDestroyed, - ErrorClientInvalidHost, - ErrorClientConnectionFailed, - ErrorClientConnectionTimedOut, - ErrorClientConnectionEndedEarly, - ErrorClientServer, - ErrorServerPortUnavailable, - ErrorServerSendFailed, - ErrorServerReadableBufferLimit, - ErrorServerConnectionEndedEarly, + ErrorRPC, + ErrorRPCClient, + ErrorClientAuthMissing, + ErrorClientAuthFormat, + ErrorClientAuthDenied, }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 318aaa73c..8fea99443 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -1,15 +1,9 @@ -import type { SessionToken } from '../sessions/types'; -import type KeyRing from '../keys/KeyRing'; -import type SessionManager from '../sessions/SessionManager'; import type { RPCRequestParams } from './types'; -import type { JsonRpcRequest } from '../RPC/types'; -import type { Certificate } from 'keys/types'; -import type { DetailedPeerCertificate } from 'tls'; -import type { NodeId } from 'ids/index'; -import * as x509 from '@peculiar/x509'; -import * as clientErrors from '../client/errors'; -import * as networkErrors from '../network/errors'; -import * as keysUtils from '../keys/utils/index'; +import type SessionManager from 'sessions/SessionManager'; +import type KeyRing from 'keys/KeyRing'; +import type { JsonRpcRequest } from 'RPC/types'; +import type { SessionToken } from 'sessions/types'; +import * as clientErrors from './errors'; async function authenticate( sessionManager: SessionManager, @@ -60,149 +54,4 @@ function encodeAuthFromPassword(password: string): string { return `Basic ${encoded}`; } -function detailedToCertChain( - cert: DetailedPeerCertificate, -): Array { - const certChain: Array = []; - let currentCert = cert; - while (true) { - certChain.unshift(new x509.X509Certificate(currentCert.raw)); - if (currentCert === currentCert.issuerCertificate) break; - currentCert = currentCert.issuerCertificate; - } - return certChain; -} - -/** - * Verify the server certificate chain when connecting to it from a client - * This is a custom verification intended to verify that the server owned - * the relevant NodeId. - * It is possible that the server has a new NodeId. In that case we will - * verify that the new NodeId is the true descendant of the target NodeId. - */ -async function verifyServerCertificateChain( - nodeIds: Array, - certChain: Array, -): Promise { - if (!certChain.length) { - throw new networkErrors.ErrorCertChainEmpty( - 'No certificates available to verify', - ); - } - if (!nodeIds.length) { - throw new networkErrors.ErrorConnectionNodesEmpty( - 'No nodes were provided to verify against', - ); - } - const now = new Date(); - let certClaim: Certificate | null = null; - let certClaimIndex: number | null = null; - let verifiedNodeId: NodeId | null = null; - for (let certIndex = 0; certIndex < certChain.length; certIndex++) { - const cert = certChain[certIndex]; - if (now < cert.notBefore || now > cert.notAfter) { - throw new networkErrors.ErrorCertChainDateInvalid( - 'Chain certificate date is invalid', - { - data: { - cert, - certIndex, - notBefore: cert.notBefore, - notAfter: cert.notAfter, - now, - }, - }, - ); - } - const certNodeId = keysUtils.certNodeId(cert); - if (certNodeId == null) { - throw new networkErrors.ErrorCertChainNameInvalid( - 'Chain certificate common name attribute is missing', - { - data: { - cert, - certIndex, - }, - }, - ); - } - const certPublicKey = keysUtils.certPublicKey(cert); - if (certPublicKey == null) { - throw new networkErrors.ErrorCertChainKeyInvalid( - 'Chain certificate public key is missing', - { - data: { - cert, - certIndex, - }, - }, - ); - } - if (!(await keysUtils.certNodeSigned(cert))) { - throw new networkErrors.ErrorCertChainSignatureInvalid( - 'Chain certificate does not have a valid node-signature', - { - data: { - cert, - certIndex, - nodeId: keysUtils.publicKeyToNodeId(certPublicKey), - commonName: certNodeId, - }, - }, - ); - } - for (const nodeId of nodeIds) { - if (certNodeId.equals(nodeId)) { - // Found the certificate claiming the nodeId - certClaim = cert; - certClaimIndex = certIndex; - verifiedNodeId = nodeId; - } - } - // If cert is found then break out of loop - if (verifiedNodeId != null) break; - } - if (certClaimIndex == null || certClaim == null || verifiedNodeId == null) { - throw new networkErrors.ErrorCertChainUnclaimed( - 'Node IDs is not claimed by any certificate', - { - data: { nodeIds }, - }, - ); - } - if (certClaimIndex > 0) { - let certParent: Certificate; - let certChild: Certificate; - for (let certIndex = certClaimIndex; certIndex > 0; certIndex--) { - certParent = certChain[certIndex]; - certChild = certChain[certIndex - 1]; - if ( - !keysUtils.certIssuedBy(certParent, certChild) || - !(await keysUtils.certSignedBy( - certParent, - keysUtils.certPublicKey(certChild)!, - )) - ) { - throw new networkErrors.ErrorCertChainBroken( - 'Chain certificate is not signed by parent certificate', - { - data: { - cert: certChild, - certIndex: certIndex - 1, - certParent, - }, - }, - ); - } - } - } - return verifiedNodeId; -} - -export { - authenticate, - decodeAuth, - encodeAuthFromPassword, - detailedToCertChain, - verifyServerCertificateChain, -}; +export { authenticate, decodeAuth, encodeAuthFromPassword }; diff --git a/src/clientRPC/WebSocketClient.ts b/src/websockets/WebSocketClient.ts similarity index 100% rename from src/clientRPC/WebSocketClient.ts rename to src/websockets/WebSocketClient.ts diff --git a/src/clientRPC/WebSocketServer.ts b/src/websockets/WebSocketServer.ts similarity index 100% rename from src/clientRPC/WebSocketServer.ts rename to src/websockets/WebSocketServer.ts diff --git a/src/clientRPC/WebSocketStream.ts b/src/websockets/WebSocketStream.ts similarity index 100% rename from src/clientRPC/WebSocketStream.ts rename to src/websockets/WebSocketStream.ts diff --git a/src/websockets/errors.ts b/src/websockets/errors.ts new file mode 100644 index 000000000..002282359 --- /dev/null +++ b/src/websockets/errors.ts @@ -0,0 +1,119 @@ +import { ErrorPolykey, sysexits } from '../errors'; + +class ErrorWebSocket extends ErrorPolykey {} + +class ErrorWebSocketClient extends ErrorWebSocket {} + +class ErrorClientDestroyed extends ErrorWebSocketClient { + static description = 'ClientClient has been destroyed'; + exitCode = sysexits.USAGE; +} + +class ErrorClientInvalidHost extends ErrorWebSocketClient { + static description = 'Host must be a valid IPv4 or IPv6 address string'; + exitCode = sysexits.USAGE; +} + +class ErrorClientConnectionFailed extends ErrorWebSocketClient { + static description = 'Failed to establish connection to server'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorClientConnectionTimedOut extends ErrorWebSocketClient { + static description = 'Connection timed out'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorClientConnectionEndedEarly extends ErrorWebSocketClient { + static description = 'Connection ended before stream ended'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorWebSocketServer extends ErrorWebSocket {} + +class ErrorServerPortUnavailable extends ErrorWebSocketServer { + static description = 'Failed to bind a free port'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorServerSendFailed extends ErrorWebSocketServer { + static description = 'Failed to send message'; + exitCode = sysexits.UNAVAILABLE; +} + +class ErrorServerReadableBufferLimit extends ErrorWebSocketServer { + static description = 'Readable buffer is full, messages received too quickly'; + exitCode = sysexits.USAGE; +} + +class ErrorServerConnectionEndedEarly extends ErrorWebSocketServer { + static description = 'Connection ended before stream ended'; + exitCode = sysexits.UNAVAILABLE; +} + +/** + * Used for certificate verification + */ +class ErrorCertChain extends ErrorWebSocket {} + +class ErrorCertChainEmpty extends ErrorCertChain { + static description = 'Certificate chain is empty'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainUnclaimed extends ErrorCertChain { + static description = 'The target node id is not claimed by any certificate'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainBroken extends ErrorCertChain { + static description = 'The signature chain is broken'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainDateInvalid extends ErrorCertChain { + static description = 'Certificate in the chain is expired'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainNameInvalid extends ErrorCertChain { + static description = 'Certificate is missing the common name'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainKeyInvalid extends ErrorCertChain { + static description = 'Certificate public key does not generate the Node ID'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorCertChainSignatureInvalid extends ErrorCertChain { + static description = 'Certificate self-signed signature is invalid'; + exitCode = sysexits.PROTOCOL; +} + +class ErrorConnectionNodesEmpty extends ErrorWebSocket { + static description = 'Nodes list to verify against was empty'; + exitCode = sysexits.USAGE; +} + +export { + ErrorWebSocketClient, + ErrorClientDestroyed, + ErrorClientInvalidHost, + ErrorClientConnectionFailed, + ErrorClientConnectionTimedOut, + ErrorClientConnectionEndedEarly, + ErrorWebSocketServer, + ErrorServerPortUnavailable, + ErrorServerSendFailed, + ErrorServerReadableBufferLimit, + ErrorServerConnectionEndedEarly, + ErrorCertChainEmpty, + ErrorCertChainUnclaimed, + ErrorCertChainBroken, + ErrorCertChainDateInvalid, + ErrorCertChainNameInvalid, + ErrorCertChainKeyInvalid, + ErrorCertChainSignatureInvalid, + ErrorConnectionNodesEmpty, +}; diff --git a/src/websockets/utils.ts b/src/websockets/utils.ts new file mode 100644 index 000000000..638bdb181 --- /dev/null +++ b/src/websockets/utils.ts @@ -0,0 +1,147 @@ +import type { Certificate } from 'keys/types'; +import type { DetailedPeerCertificate } from 'tls'; +import type { NodeId } from 'ids/index'; +import * as x509 from '@peculiar/x509'; +import * as webSocketErrors from './errors'; +import * as keysUtils from '../keys/utils/index'; + +function detailedToCertChain( + cert: DetailedPeerCertificate, +): Array { + const certChain: Array = []; + let currentCert = cert; + while (true) { + certChain.unshift(new x509.X509Certificate(currentCert.raw)); + if (currentCert === currentCert.issuerCertificate) break; + currentCert = currentCert.issuerCertificate; + } + return certChain; +} + +/** + * Verify the server certificate chain when connecting to it from a client + * This is a custom verification intended to verify that the server owned + * the relevant NodeId. + * It is possible that the server has a new NodeId. In that case we will + * verify that the new NodeId is the true descendant of the target NodeId. + */ +async function verifyServerCertificateChain( + nodeIds: Array, + certChain: Array, +): Promise { + if (!certChain.length) { + throw new webSocketErrors.ErrorCertChainEmpty( + 'No certificates available to verify', + ); + } + if (!nodeIds.length) { + throw new webSocketErrors.ErrorConnectionNodesEmpty( + 'No nodes were provided to verify against', + ); + } + const now = new Date(); + let certClaim: Certificate | null = null; + let certClaimIndex: number | null = null; + let verifiedNodeId: NodeId | null = null; + for (let certIndex = 0; certIndex < certChain.length; certIndex++) { + const cert = certChain[certIndex]; + if (now < cert.notBefore || now > cert.notAfter) { + throw new webSocketErrors.ErrorCertChainDateInvalid( + 'Chain certificate date is invalid', + { + data: { + cert, + certIndex, + notBefore: cert.notBefore, + notAfter: cert.notAfter, + now, + }, + }, + ); + } + const certNodeId = keysUtils.certNodeId(cert); + if (certNodeId == null) { + throw new webSocketErrors.ErrorCertChainNameInvalid( + 'Chain certificate common name attribute is missing', + { + data: { + cert, + certIndex, + }, + }, + ); + } + const certPublicKey = keysUtils.certPublicKey(cert); + if (certPublicKey == null) { + throw new webSocketErrors.ErrorCertChainKeyInvalid( + 'Chain certificate public key is missing', + { + data: { + cert, + certIndex, + }, + }, + ); + } + if (!(await keysUtils.certNodeSigned(cert))) { + throw new webSocketErrors.ErrorCertChainSignatureInvalid( + 'Chain certificate does not have a valid node-signature', + { + data: { + cert, + certIndex, + nodeId: keysUtils.publicKeyToNodeId(certPublicKey), + commonName: certNodeId, + }, + }, + ); + } + for (const nodeId of nodeIds) { + if (certNodeId.equals(nodeId)) { + // Found the certificate claiming the nodeId + certClaim = cert; + certClaimIndex = certIndex; + verifiedNodeId = nodeId; + } + } + // If cert is found then break out of loop + if (verifiedNodeId != null) break; + } + if (certClaimIndex == null || certClaim == null || verifiedNodeId == null) { + throw new webSocketErrors.ErrorCertChainUnclaimed( + 'Node IDs is not claimed by any certificate', + { + data: { nodeIds }, + }, + ); + } + if (certClaimIndex > 0) { + let certParent: Certificate; + let certChild: Certificate; + for (let certIndex = certClaimIndex; certIndex > 0; certIndex--) { + certParent = certChain[certIndex]; + certChild = certChain[certIndex - 1]; + if ( + !keysUtils.certIssuedBy(certParent, certChild) || + !(await keysUtils.certSignedBy( + certParent, + keysUtils.certPublicKey(certChild)!, + )) + ) { + throw new webSocketErrors.ErrorCertChainBroken( + 'Chain certificate is not signed by parent certificate', + { + data: { + cert: certChild, + certIndex: certIndex - 1, + certParent, + }, + }, + ); + } + } + } + return verifiedNodeId; +} + +export { detailedToCertChain, verifyServerCertificateChain }; diff --git a/tests/clientRPC/authenticationMiddleware.test.ts b/tests/clientRPC/authenticationMiddleware.test.ts index 685652e4c..e204dda20 100644 --- a/tests/clientRPC/authenticationMiddleware.test.ts +++ b/tests/clientRPC/authenticationMiddleware.test.ts @@ -17,8 +17,8 @@ import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; import { UnaryCaller } from '@/RPC/callers'; import { UnaryHandler } from '@/RPC/handlers'; import * as middlewareUtils from '@/RPC/middleware'; -import WebSocketServer from '@/clientRPC/WebSocketServer'; -import WebSocketClient from '@/clientRPC/WebSocketClient'; +import WebSocketServer from '@/websockets/WebSocketServer'; +import WebSocketClient from '@/websockets/WebSocketClient'; import * as testsUtils from '../utils'; describe('agentUnlock', () => { diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 1f4943bb6..b40216b99 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -15,8 +15,8 @@ import { } from '@/clientRPC/handlers/agentStatus'; import RPCClient from '@/RPC/RPCClient'; import * as nodesUtils from '@/nodes/utils'; -import WebSocketClient from '@/clientRPC/WebSocketClient'; -import WebSocketServer from '@/clientRPC/WebSocketServer'; +import WebSocketClient from '@/websockets/WebSocketClient'; +import WebSocketServer from '@/websockets/WebSocketServer'; import * as testsUtils from '../../utils'; describe('agentStatus', () => { diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index c927f5cac..41d8c9b44 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -18,8 +18,8 @@ import { Session, SessionManager } from '@/sessions'; import * as clientRPCUtils from '@/clientRPC/utils'; import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; import * as middlewareUtils from '@/RPC/middleware'; -import WebSocketServer from '@/clientRPC/WebSocketServer'; -import WebSocketClient from '@/clientRPC/WebSocketClient'; +import WebSocketServer from '@/websockets/WebSocketServer'; +import WebSocketClient from '@/websockets/WebSocketClient'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { diff --git a/tests/clientRPC/WebSocket.test.ts b/tests/websockets/WebSocket.test.ts similarity index 98% rename from tests/clientRPC/WebSocket.test.ts rename to tests/websockets/WebSocket.test.ts index 8f276c4e9..d098d60e8 100644 --- a/tests/clientRPC/WebSocket.test.ts +++ b/tests/websockets/WebSocket.test.ts @@ -2,7 +2,7 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; import type { KeyPair } from '@/keys/types'; import type http from 'http'; -import type WebSocketStream from '@/clientRPC/WebSocketStream'; +import type WebSocketStream from '@/websockets/WebSocketStream'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -11,11 +11,11 @@ import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; import { KeyRing } from '@/keys/index'; -import WebSocketServer from '@/clientRPC/WebSocketServer'; +import WebSocketServer from '@/websockets/WebSocketServer'; +import WebSocketClient from '@/websockets/WebSocketClient'; import { promise } from '@/utils'; -import WebSocketClient from '@/clientRPC/WebSocketClient'; import * as keysUtils from '@/keys/utils'; -import * as networkErrors from '@/network/errors'; +import * as webSocketErrors from '@/websockets/errors'; import * as nodesUtils from '@/nodes/utils'; import * as testNodeUtils from '../nodes/utils'; import * as testsUtils from '../utils'; @@ -372,7 +372,7 @@ describe('WebSocket', () => { [ '--project', testsUtils.tsConfigPath, - `${globalThis.testDir}/clientRPC/testClient.ts`, + `${globalThis.testDir}/websockets/testClient.ts`, ], { env: { @@ -415,7 +415,7 @@ describe('WebSocket', () => { [ '--project', testsUtils.tsConfigPath, - `${globalThis.testDir}/clientRPC/testServer.ts`, + `${globalThis.testDir}/websockets/testServer.ts`, ], { env: { @@ -742,7 +742,7 @@ describe('WebSocket', () => { logger: logger.getChild('clientClient'), }); await expect(webSocketClient.startConnection()).rejects.toThrow( - networkErrors.ErrorCertChainUnclaimed, + webSocketErrors.ErrorCertChainUnclaimed, ); // @ts-ignore: kidnap protected property const activeConnections = webSocketClient.activeConnections; diff --git a/tests/clientRPC/testClient.ts b/tests/websockets/testClient.ts similarity index 94% rename from tests/clientRPC/testClient.ts rename to tests/websockets/testClient.ts index bfc33232a..52179d0c3 100644 --- a/tests/clientRPC/testClient.ts +++ b/tests/websockets/testClient.ts @@ -6,7 +6,7 @@ * @module */ import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import WebSocketClient from '@/clientRPC/WebSocketClient'; +import WebSocketClient from '@/websockets/WebSocketClient'; import * as nodesUtils from '@/nodes/utils'; async function main() { diff --git a/tests/clientRPC/testServer.ts b/tests/websockets/testServer.ts similarity index 95% rename from tests/clientRPC/testServer.ts rename to tests/websockets/testServer.ts index 5225b954c..0a7aac880 100644 --- a/tests/clientRPC/testServer.ts +++ b/tests/websockets/testServer.ts @@ -8,7 +8,7 @@ import type { CertificatePEMChain, PrivateKeyPEM } from '@/keys/types'; import type { TLSConfig } from '@/network/types'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import WebSocketServer from '@/clientRPC/WebSocketServer'; +import WebSocketServer from '@/websockets/WebSocketServer'; async function main() { const logger = new Logger('websocket test', LogLevel.WARN, [ From f260d73e42bc495f06c0e56e238de8dbbcdce1e0 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 28 Feb 2023 12:40:03 +1100 Subject: [PATCH 22/23] fix: abstracting `uWebsocket` server creation to its own protected method [ci skip] --- src/websockets/WebSocketServer.ts | 39 ++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index 03f556fd7..a57242dc5 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -127,20 +127,7 @@ class WebSocketServer { }): Promise { this.logger.info(`Starting ${this.constructor.name}`); this.connectionCallback = connectionCallback; - const tmpDir = await this.fs.promises.mkdtemp( - path.join(basePath, 'polykey-'), - ); - // TODO: The key file needs to be in the encrypted format - const keyFile = path.join(tmpDir, 'keyFile.pem'); - const certFile = path.join(tmpDir, 'certFile.pem'); - await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); - await this.fs.promises.writeFile(certFile, tlsConfig.certChainPem); - this.server = uWebsocket.SSLApp({ - key_file_name: keyFile, - cert_file_name: certFile, - }); - await this.fs.promises.rm(keyFile); - await this.fs.promises.rm(certFile); + await this.setupServer(basePath, tlsConfig); this.server.ws('/*', { sendPingsAutomatically: true, idleTimeout: this.idleTimeout, @@ -206,6 +193,30 @@ class WebSocketServer { return uWebsocket.us_socket_local_port(this.listenSocket); } + /** + * This creates the pem files and starts the server with them. It ensures that + * files are cleaned up to the best of its ability. + */ + protected async setupServer(basePath: string, tlsConfig: TLSConfig) { + const tmpDir = await this.fs.promises.mkdtemp( + path.join(basePath, 'polykey-'), + ); + // TODO: The key file needs to be in the encrypted format + const keyFile = path.join(tmpDir, 'keyFile.pem'); + const certFile = path.join(tmpDir, 'certFile.pem'); + try { + await this.fs.promises.writeFile(keyFile, tlsConfig.keyPrivatePem); + await this.fs.promises.writeFile(certFile, tlsConfig.certChainPem); + this.server = uWebsocket.SSLApp({ + key_file_name: keyFile, + cert_file_name: certFile, + }); + } finally { + await this.fs.promises.rm(keyFile); + await this.fs.promises.rm(certFile); + } + } + /** * Applies default upgrade behaviour and creates a UserData object we can * mutate for the Context From c6fecebf5a844daa72aa4f3975143199d7ef773c Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 28 Feb 2023 15:38:20 +1100 Subject: [PATCH 23/23] feat: `WebSocketServer` now extends `EventTarget` with `connection`, `start` and `stop` events [ci skip] --- src/websockets/WebSocketServer.ts | 61 ++++++++++++++++++++++++------- src/websockets/events.ts | 46 +++++++++++++++++++++++ 2 files changed, 93 insertions(+), 14 deletions(-) create mode 100644 src/websockets/events.ts diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index a57242dc5..5ab3a80c0 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -20,6 +20,7 @@ import Logger from '@matrixai/logger'; import uWebsocket from 'uWebSockets.js'; import WebSocketStream from './WebSocketStream'; import * as clientRPCErrors from './errors'; +import * as webSocketEvents from './events'; import { promise } from '../utils'; type ConnectionCallback = ( @@ -39,9 +40,15 @@ type Context = { logger: Logger; }; +/** + * Events: + * - start + * - stop + * - connection + */ interface WebSocketServer extends startStop.StartStop {} @startStop.StartStop() -class WebSocketServer { +class WebSocketServer extends EventTarget { static async createWebSocketServer({ connectionCallback, tlsConfig, @@ -90,7 +97,9 @@ class WebSocketServer { protected server: uWebsocket.TemplatedApp; protected listenSocket: uWebsocket.us_listen_socket; protected host: string; - protected connectionCallback: ConnectionCallback; + protected connectionEventHandler: ( + event: webSocketEvents.ConnectionEvent, + ) => void; protected activeSockets: Set = new Set(); protected connectionIndex: number = 0; @@ -110,23 +119,35 @@ class WebSocketServer { protected idleTimeout: number | undefined, protected pingInterval: number, protected pingTimeout: number, - ) {} + ) { + super(); + } public async start({ - connectionCallback, tlsConfig, basePath = os.tmpdir(), host, port = 0, + connectionCallback, }: { - connectionCallback: ConnectionCallback; tlsConfig: TLSConfig; basePath?: string; host?: string; port?: number; + connectionCallback?: ConnectionCallback; }): Promise { this.logger.info(`Starting ${this.constructor.name}`); - this.connectionCallback = connectionCallback; + if (connectionCallback != null) { + this.connectionEventHandler = ( + event: webSocketEvents.ConnectionEvent, + ) => { + connectionCallback( + event.detail.webSocketStream, + event.detail.connectionInfo, + ); + }; + this.addEventListener('connection', this.connectionEventHandler); + } await this.setupServer(basePath, tlsConfig); this.server.ws('/*', { sendPingsAutomatically: true, @@ -169,6 +190,14 @@ class WebSocketServer { `Listening on port ${uWebsocket.us_socket_local_port(this.listenSocket)}`, ); this.host = host ?? '127.0.0.1'; + this.dispatchEvent( + new webSocketEvents.StartEvent({ + detail: { + host: this.host, + port: this.port, + }, + }), + ); this.logger.info(`Started ${this.constructor.name}`); } @@ -186,6 +215,10 @@ class WebSocketServer { for (const webSocketStream of this.activeSockets) { webSocketStream.endedProm.catch(() => {}); // Ignore errors } + if (this.connectionEventHandler != null) { + this.removeEventListener('connection', this.connectionEventHandler); + } + this.dispatchEvent(new webSocketEvents.StopEvent()); this.logger.info(`Stopped ${this.constructor.name}`); } @@ -265,14 +298,14 @@ class WebSocketServer { localHost: this.host, localPort: this.port, }; - const context = ws.getUserData(); - context.logger.debug('Calling callback'); - try { - this.connectionCallback(webSocketStream, connectionInfo); - } catch (e) { - context.close(ws, 0, Buffer.from('')); - context.logger.error(e.toString()); - } + this.dispatchEvent( + new webSocketEvents.ConnectionEvent({ + detail: { + webSocketStream, + connectionInfo, + }, + }), + ); }; /** diff --git a/src/websockets/events.ts b/src/websockets/events.ts new file mode 100644 index 000000000..aaabb5842 --- /dev/null +++ b/src/websockets/events.ts @@ -0,0 +1,46 @@ +import type WebSocketStream from 'websockets/WebSocketStream'; +import type { ConnectionInfo } from 'RPC/types'; + +class StartEvent extends Event { + public detail: { + host: string; + port: number; + }; + constructor( + options: EventInit & { + detail: { + host: string; + port: number; + }; + }, + ) { + super('start', options); + this.detail = options.detail; + } +} + +class StopEvent extends Event { + constructor(options?: EventInit) { + super('stop', options); + } +} + +class ConnectionEvent extends Event { + public detail: { + webSocketStream: WebSocketStream; + connectionInfo: ConnectionInfo; + }; + constructor( + options: EventInit & { + detail: { + webSocketStream: WebSocketStream; + connectionInfo: ConnectionInfo; + }; + }, + ) { + super('connection', options); + this.detail = options.detail; + } +} + +export { StartEvent, StopEvent, ConnectionEvent };