From 70bd6568604b16b10aa261158eeaab8ac27acc1b Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 11 Aug 2023 15:56:42 +1000 Subject: [PATCH] refactor: combined the websocket client and server stream implementation * Related #540 [ci skip] --- src/PolykeyAgent.ts | 3 - src/config.ts | 1 - src/websockets/WebSocketClient.ts | 262 +++--------------------------- src/websockets/WebSocketServer.ts | 218 +------------------------ src/websockets/WebSocketStream.ts | 249 +++++++++++++++++++++++++++- 5 files changed, 272 insertions(+), 461 deletions(-) diff --git a/src/PolykeyAgent.ts b/src/PolykeyAgent.ts index 078dc6f3a..a833a5a1c 100644 --- a/src/PolykeyAgent.ts +++ b/src/PolykeyAgent.ts @@ -55,7 +55,6 @@ type NetworkConfig = { clientHost?: string; clientPort?: number; // Websocket server config - maxReadableStreamBytes?: number; maxIdleTimeout?: number; pingIntervalTime?: number; pingTimeoutTimeTime?: number; @@ -496,11 +495,9 @@ class PolykeyAgent { (await WebSocketServer.createWebSocketServer({ connectionCallback: (rpcStream) => rpcServerClient!.handleStream(rpcStream), - fs, host: networkConfig_.clientHost, port: networkConfig_.clientPort, tlsConfig, - maxReadableStreamBytes: networkConfig_.maxReadableStreamBytes, maxIdleTimeout: networkConfig_.maxIdleTimeout, pingIntervalTime: networkConfig_.pingIntervalTime, pingTimeoutTimeTime: networkConfig_.pingTimeoutTimeTime, diff --git a/src/config.ts b/src/config.ts index 604b9cb79..9d75c4cf5 100644 --- a/src/config.ts +++ b/src/config.ts @@ -98,7 +98,6 @@ const config = { clientHost: '127.0.0.1', clientPort: 0, // Websocket server config - maxReadableStreamBytes: 1_000_000_000, // About 1 GB maxIdleTimeout: 120, // 2 minutes pingIntervalTime: 1_000, // 1 second pingTimeoutTimeTime: 10_000, // 10 seconds diff --git a/src/websockets/WebSocketClient.ts b/src/websockets/WebSocketClient.ts index ffa5d3510..481e9a8f1 100644 --- a/src/websockets/WebSocketClient.ts +++ b/src/websockets/WebSocketClient.ts @@ -1,12 +1,6 @@ import type { TLSSocket } from 'tls'; -import type { - ReadableStreamController, - WritableStreamDefaultController, -} from 'stream/web'; import type { ContextTimed } from '@matrixai/contexts'; import type { NodeId, NodeIdEncoded } from '../ids'; -import type { JSONValue } from '../types'; -import { WritableStream, ReadableStream } from 'stream/web'; import { createDestroy } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import WebSocket from 'ws'; @@ -33,7 +27,6 @@ class WebSocketClient { * Default is 1,000 milliseconds. * @param obj.pingTimeoutTimeTime - Time before connection is cleaned up after no ping responses. * Default is 10,000 milliseconds. - * @param obj.maxReadableStreamBytes - The number of bytes the readable stream will buffer until pausing. * @param obj.logger */ static async createWebSocketClient({ @@ -43,7 +36,6 @@ class WebSocketClient { connectionTimeoutTime = Infinity, pingIntervalTime = 1_000, pingTimeoutTimeTime = 10_000, - maxReadableStreamBytes = 1_000, // About 1kB logger = new Logger(this.name), }: { host: string; @@ -52,7 +44,6 @@ class WebSocketClient { connectionTimeoutTime?: number; pingIntervalTime?: number; pingTimeoutTimeTime?: number; - maxReadableStreamBytes?: number; logger?: Logger; }): Promise { logger.info(`Creating ${this.name}`); @@ -60,7 +51,6 @@ class WebSocketClient { logger, host, port, - maxReadableStreamBytes, expectedNodeIds, connectionTimeoutTime, pingIntervalTime, @@ -77,7 +67,6 @@ class WebSocketClient { protected logger: Logger, host: string, protected port: number, - protected maxReadableStreamBytes: number, protected expectedNodeIds: Array, protected connectionTimeoutTime: number, protected pingIntervalTime: number, @@ -126,7 +115,7 @@ class WebSocketClient { @createDestroy.ready(new webSocketErrors.ErrorClientDestroyed()) public async startConnection( ctx: Partial = {}, - ): Promise { + ): Promise { // Setting up abort/cancellation logic const abortRaceProm = promise(); // Ignore unhandled rejection @@ -161,7 +150,13 @@ class WebSocketClient { const address = `wss://${this.host}:${this.port}`; this.logger.info(`Connecting to ${address}`); const connectProm = promise(); - const authenticateProm = promise(); + const authenticateProm = promise<{ + nodeId: NodeIdEncoded; + localHost: string; + localPort: number; + remoteHost: string; + remotePort: number; + }>(); const ws = new WebSocket(address, { rejectUnauthorized: false, }); @@ -178,12 +173,21 @@ class WebSocketClient { ws.once('upgrade', async (request) => { const tlsSocket = request.socket as TLSSocket; const peerCert = tlsSocket.getPeerCertificate(true); - webSocketUtils - .verifyServerCertificateChain( + try { + const nodeId = await webSocketUtils.verifyServerCertificateChain( this.expectedNodeIds, webSocketUtils.detailedToCertChain(peerCert), - ) - .then(authenticateProm.resolveP, authenticateProm.rejectP); + ); + authenticateProm.resolveP({ + nodeId: nodesUtils.encodeNodeId(nodeId), + localHost: request.connection.localAddress ?? '', + localPort: request.connection.localPort ?? 0, + remoteHost: request.connection.remoteAddress ?? '', + remotePort: request.connection.remotePort ?? 0, + }); + } catch (e) { + authenticateProm.rejectP(e); + } }); ws.once('open', () => { this.logger.info('starting connection'); @@ -222,17 +226,14 @@ class WebSocketClient { // Constructing the `ReadableWritablePair`, the lifecycle is handed off to // the webSocketStream at this point. - const webSocketStreamClient = new WebSocketStreamClientInternal( + const webSocketStreamClient = new WebSocketStream( ws, - this.maxReadableStreamBytes, this.pingIntervalTime, this.pingTimeoutTimeTime, { - host: this.host, - nodeId: nodesUtils.encodeNodeId(await authenticateProm.p), - port: this.port, + ...(await authenticateProm.p), }, - this.logger, + this.logger.getChild(WebSocketStream.name), ); const abortStream = () => { webSocketStreamClient.cancel( @@ -258,219 +259,4 @@ class WebSocketClient { } // This is the internal implementation of the client's stream pair. -class WebSocketStreamClientInternal extends WebSocketStream { - protected readableController: - | ReadableStreamController - | undefined; - protected writableController: WritableStreamDefaultController | undefined; - - constructor( - protected ws: WebSocket, - maxReadableStreamBytes: number, - pingInterval: number, - pingTimeoutTime: number, - protected clientMetadata: { - nodeId: NodeIdEncoded; - host: string; - port: number; - }, - logger: Logger, - ) { - super(); - const readableLogger = logger.getChild('readable'); - const writableLogger = logger.getChild('writable'); - - this.readable = new ReadableStream( - { - start: (controller) => { - this.readableController = controller; - readableLogger.info('Starting'); - const messageHandler = (data) => { - readableLogger.debug(`Received ${data.toString()}`); - if (controller.desiredSize == null) { - controller.error(Error('NEVER')); - return; - } - if (controller.desiredSize < 0) { - readableLogger.debug('Applying readable backpressure'); - ws.pause(); - } - const message = data as Buffer; - if (message.length === 0) { - readableLogger.debug('Null message received'); - ws.removeListener('message', messageHandler); - if (!this._readableEnded) { - this.signalReadableEnd(); - readableLogger.debug('Closing'); - controller.close(); - } - if (this._writableEnded) { - logger.debug('Closing socket'); - ws.close(); - } - return; - } - controller.enqueue(message); - }; - readableLogger.debug('Registering socket message handler'); - ws.on('message', messageHandler); - ws.once('close', (code, reason) => { - logger.info('Socket closed'); - ws.removeListener('message', messageHandler); - if (!this._readableEnded) { - readableLogger.debug( - `Closed early, ${code}, ${reason.toString()}`, - ); - const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); - this.signalReadableEnd(e); - controller.error(e); - } - }); - ws.once('error', (e) => { - if (!this._readableEnded) { - readableLogger.error(e); - this.signalReadableEnd(e); - controller.error(e); - } - }); - }, - cancel: (reason) => { - readableLogger.debug('Cancelled'); - this.signalReadableEnd(reason); - if (!this._writableEnded) { - readableLogger.debug('Closing socket'); - this.signalWritableEnd(reason); - ws.close(); - } - }, - pull: () => { - readableLogger.debug('Releasing backpressure'); - ws.resume(); - }, - }, - { - highWaterMark: maxReadableStreamBytes, - size: (chunk) => chunk?.byteLength ?? 0, - }, - ); - this.writable = new WritableStream({ - start: (controller) => { - this.writableController = controller; - writableLogger.info('Starting'); - ws.once('error', (e) => { - if (!this._writableEnded) { - writableLogger.error(e); - this.signalWritableEnd(e); - controller.error(e); - } - }); - ws.once('close', (code, reason) => { - if (!this._writableEnded) { - writableLogger.debug(`Closed early, ${code}, ${reason.toString()}`); - const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); - this.signalWritableEnd(e); - controller.error(e); - } - }); - }, - close: () => { - writableLogger.debug('Closing, sending null message'); - ws.send(Buffer.from([])); - this.signalWritableEnd(); - if (this._readableEnded) { - writableLogger.debug('Closing socket'); - ws.close(); - } - }, - abort: (reason) => { - writableLogger.debug('Aborted'); - this.signalWritableEnd(reason); - if (this._readableEnded) { - writableLogger.debug('Closing socket'); - ws.close(); - } - }, - write: async (chunk, controller) => { - if (this._writableEnded) return; - writableLogger.debug(`Sending ${chunk?.toString()}`); - const wait = promise(); - ws.send(chunk, (e) => { - if (e != null && !this._writableEnded) { - // Opting to debug message here and not log an error, sending - // failure is common if we send before the close event. - writableLogger.debug('failed to send'); - const err = new webSocketErrors.ErrorClientConnectionEndedEarly( - undefined, - { - cause: e, - }, - ); - this.signalWritableEnd(err); - controller.error(err); - } - wait.resolveP(); - }); - await wait.p; - }, - }); - - // Setting up heartbeat - const pingTimer = setInterval(() => { - ws.ping(); - }, pingInterval); - const pingTimeoutTimeTimer = setTimeout(() => { - logger.debug('Ping timed out'); - ws.close(4002, 'Timed out'); - }, pingTimeoutTime); - ws.on('ping', () => { - logger.debug('Received ping'); - ws.pong(); - }); - ws.on('pong', () => { - logger.debug('Received pong'); - pingTimeoutTimeTimer.refresh(); - }); - ws.once('close', (code, reason) => { - logger.debug('WebSocket closed'); - const err = - code !== 1000 - ? new webSocketErrors.ErrorClientConnectionEndedEarly( - `ended with code ${code}, ${reason.toString()}`, - ) - : undefined; - this.signalWebSocketEnd(err); - logger.debug('Cleaning up timers'); - // Clean up timers - clearTimeout(pingTimer); - clearTimeout(pingTimeoutTimeTimer); - }); - } - - get meta(): Record { - // Spreading to avoid modifying the data - return { - ...this.clientMetadata, - }; - } - - cancel(reason?: any): void { - // Default error - const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); - // Close the streams with the given error, - if (!this._readableEnded) { - this.readableController?.error(err); - this.signalReadableEnd(err); - } - if (!this._writableEnded) { - this.writableController?.error(err); - this.signalWritableEnd(err); - } - // Then close the websocket - if (!this._webSocketEnded) { - this.ws.close(4000, 'Ending connection'); - this.signalWebSocketEnd(err); - } - } -} - export default WebSocketClient; diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index 43bd31ffe..9df3c552d 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -1,11 +1,5 @@ -import type { - ReadableStreamController, - WritableStreamDefaultController, -} from 'stream/web'; -import type { JSONValue } from '../types'; import type { TLSConfig } from '../network/types'; import type { IncomingMessage, ServerResponse } from 'http'; -import { WritableStream, ReadableStream } from 'stream/web'; import https from 'https'; import { startStop, status } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; @@ -39,7 +33,6 @@ class WebSocketServer extends EventTarget { * Default is 1,000 milliseconds. * @param obj.pingTimeoutTimeTime - Time before connection is cleaned up after no ping responses. * Default is 10,000 milliseconds. - * @param obj.maxReadableStreamBytes - The number of bytes the readable stream will buffer until pausing. * @param obj.logger */ static async createWebSocketServer({ @@ -93,7 +86,6 @@ class WebSocketServer extends EventTarget { /** * * @param logger - * @param maxReadableStreamBytes Max number of bytes stored in read buffer before error * @param maxIdleTimeout * @param pingIntervalTime * @param pingTimeoutTimeTime @@ -145,21 +137,10 @@ class WebSocketServer extends EventTarget { this.server.on('error', this.errorHandler); this.server.on('request', this.requestHandler); - // This.server.any('/*', (res, _) => { - // // Reject normal requests with an upgrade code - // res - // .writeStatus('426') - // .writeHeader('connection', 'Upgrade') - // .writeHeader('upgrade', 'websocket') - // .end('426 Upgrade Required', true); - // }); - - // TODO: tell normal requests to upgrade. const listenProm = promise(); this.server.listen(port ?? 0, host, listenProm.resolveP); await listenProm.p; const address = this.server.address(); - // TODO: handle string if (address == null || typeof address === 'string') never(); this._port = address.port; this.logger.debug(`Listening on port ${this._port}`); @@ -241,18 +222,18 @@ class WebSocketServer extends EventTarget { webSocket: ws.WebSocket, request: IncomingMessage, ) => { - const socket = request.connection; - const webSocketStream = new WebSocketStreamServerInternal( + const connection = request.connection; + const webSocketStream = new WebSocketStream( webSocket, this.pingIntervalTime, this.pingTimeoutTimeTime, { - localHost: socket.localAddress ?? '', - localPort: socket.localPort ?? 0, - remoteHost: socket.remoteAddress ?? '', - remotePort: socket.remotePort ?? '', + localHost: connection.localAddress ?? '', + localPort: connection.localPort ?? 0, + remoteHost: connection.remoteAddress ?? '', + remotePort: connection.remotePort ?? 0, }, - this.logger.getChild(WebSocketStreamServerInternal.name), + this.logger.getChild(WebSocketStream.name), ); // Adding socket to the active sockets map this.activeSockets.add(webSocketStream); @@ -306,189 +287,4 @@ class WebSocketServer extends EventTarget { }; } -class WebSocketStreamServerInternal extends WebSocketStream { - protected writableController: WritableStreamDefaultController | undefined; - protected readableController: - | ReadableStreamController - | undefined; - protected messageHandler: (data: ws.RawData, isBinary: boolean) => void; - - constructor( - protected webSocket: ws.WebSocket, - pingInterval: number, - pingTimeoutTime: number, - protected metadata: Record, - protected logger: Logger, - ) { - super(); - logger.info('WS opened'); - const writableLogger = logger.getChild('Writable'); - const readableLogger = logger.getChild('Readable'); - // Setting up the writable stream - this.writable = new WritableStream({ - start: (controller) => { - this.writableController = controller; - }, - write: async (chunk, controller) => { - const writeResultProm = promise(); - this.webSocket.send(chunk, (err) => { - if (err == null) writeResultProm.resolveP(); - else writeResultProm.rejectP(err); - }); - try { - await writeResultProm.p; - writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); - } catch (e) { - this.logger.error(e); - controller.error(new webSocketErrors.ErrorServerSendFailed()); - } - }, - close: async () => { - writableLogger.info('Closed, sending null message'); - if (!this._webSocketEnded) { - const endProm = promise(); - this.webSocket.send(Buffer.from([]), (err) => { - if (err == null) endProm.resolveP(); - else endProm.rejectP(err); - }); - await endProm.p; - } - this.signalWritableEnd(); - if (this._readableEnded && !this._webSocketEnded) { - writableLogger.debug('Ending socket'); - this.signalWebSocketEnd(); - this.webSocket.close(); - } - }, - abort: (reason) => { - writableLogger.info('Aborted'); - if (this._readableEnded && !this._webSocketEnded) { - writableLogger.debug('Ending socket'); - this.signalWebSocketEnd(reason); - this.webSocket.close(4000, 'Aborting connection'); - } - }, - }, - { - highWaterMark: 1, - }); - // Setting up the readable stream - this.messageHandler = (data: ws.RawData, isBinary: boolean) => { - if (!isBinary) never(); - if (data instanceof Array) never(); - const messageBuffer = Buffer.from(data); - readableLogger.debug(`Received ${messageBuffer.toString()}`); - if (messageBuffer.byteLength === 0) { - readableLogger.debug('Null message received'); - this.webSocket.off('message', this.messageHandler); - if (!this._readableEnded) { - readableLogger.debug('Closing'); - this.signalReadableEnd(); - this.readableController!.close(); - if (this._writableEnded && !this._webSocketEnded) { - readableLogger.debug('Ending socket'); - this.signalWebSocketEnd(); - this.webSocket.close(); - } - } - return; - } - this.readableController!.enqueue(messageBuffer); - if ( - this.readableController!.desiredSize != null && - this.readableController!.desiredSize < 0 - ) { - this.webSocket.pause(); - } - }; - this.readable = new ReadableStream( - { - start: (controller) => { - this.readableController = controller; - this.webSocket.on('message', this.messageHandler); - }, - pull: () => { - this.webSocket.resume(); - }, - cancel: (reason) => { - this.webSocket.off('message', this.messageHandler); - this.signalReadableEnd(reason); - if (this._writableEnded && !this._webSocketEnded) { - readableLogger.debug('Ending socket'); - this.signalWebSocketEnd(); - this.webSocket.close(); - } - }, - }, - { - highWaterMark: 1, - }, - ); - - const pingTimer = setInterval(() => { - this.webSocket.ping(); - }, pingInterval); - const pingTimeoutTimeTimer = setTimeout(() => { - logger.debug('Ping timed out'); - this.webSocket.close(); - }, pingTimeoutTime); - const pongHandler = (data: Buffer) => { - logger.debug(`Received pong with (${data.toString()})`); - pingTimeoutTimeTimer.refresh(); - }; - this.webSocket.on('pong', pongHandler); - - const closeHandler = () => { - logger.debug('Closing'); - this.signalWebSocketEnd(); - // Cleaning up timers - logger.debug('Cleaning up timers'); - clearTimeout(pingTimer); - clearTimeout(pingTimeoutTimeTimer); - // Closing streams - logger.debug('Cleaning streams'); - this.webSocket.off('message', this.messageHandler); - const err = new webSocketErrors.ErrorServerConnectionEndedEarly(); - if (!this._readableEnded) { - readableLogger.debug('Closing'); - this.signalReadableEnd(err); - this.webSocket.off('message', this.messageHandler); - this.readableController?.error(err); - } - if (!this._writableEnded) { - writableLogger.debug('Closing'); - this.signalWritableEnd(err); - this.writableController?.error(err); - } - }; - this.webSocket.once('close', closeHandler); - } - - get meta(): Record { - return { - ...this.metadata, - }; - } - - cancel(reason?: any): void { - // Default error - const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); - // Close the streams with the given error, - if (!this._readableEnded) { - this.webSocket.off('message', this.messageHandler); - this.readableController?.error(err); - this.signalReadableEnd(err); - } - if (!this._writableEnded) { - this.writableController?.error(err); - this.signalWritableEnd(err); - } - // Then close the websocket - if (!this._webSocketEnded) { - this.webSocket.terminate(); - this.signalWebSocketEnd(err); - } - } -} - export default WebSocketServer; diff --git a/src/websockets/WebSocketStream.ts b/src/websockets/WebSocketStream.ts index f71d50e26..d2d5627cf 100644 --- a/src/websockets/WebSocketStream.ts +++ b/src/websockets/WebSocketStream.ts @@ -1,13 +1,18 @@ +import type { ReadableWritablePair } from 'stream/web'; import type { - ReadableStream, - ReadableWritablePair, - WritableStream, + ReadableStreamController, + WritableStreamDefaultController, } from 'stream/web'; +import type * as ws from 'ws'; +import type Logger from '@matrixai/logger'; +import type { NodeIdEncoded } from '../ids/types'; +import type { JSONValue } from '../types'; +import { WritableStream, ReadableStream } from 'stream/web'; +import * as webSocketErrors from './errors'; +import * as utilsErrors from '../utils/errors'; import { promise } from '../utils'; -abstract class WebSocketStream - implements ReadableWritablePair -{ +class WebSocketStream implements ReadableWritablePair { public readable: ReadableStream; public writable: WritableStream; @@ -19,7 +24,24 @@ abstract class WebSocketStream protected _webSocketEndedProm = promise(); protected _endedProm: Promise; - protected constructor() { + protected readableController: + | ReadableStreamController + | undefined; + protected writableController: WritableStreamDefaultController | undefined; + + constructor( + protected ws: ws.WebSocket, + pingInterval: number, + pingTimeoutTime: number, + protected metadata: { + nodeId?: NodeIdEncoded; + localHost: string; + localPort: number; + remoteHost: string; + remotePort: number; + }, + logger: Logger, + ) { // Sanitise promises so they don't result in unhandled rejections this._readableEndedProm.p.catch(() => {}); this._writableEndedProm.p.catch(() => {}); @@ -42,6 +64,193 @@ abstract class WebSocketStream }); // Ignore errors if it's never used this._endedProm.catch(() => {}); + + logger.info('WS opened'); + const readableLogger = logger.getChild('readable'); + const writableLogger = logger.getChild('writable'); + // Setting up the readable stream + this.readable = new ReadableStream( + { + start: (controller) => { + readableLogger.debug('Starting'); + this.readableController = controller; + const messageHandler = (data: ws.RawData, isBinary: boolean) => { + if (!isBinary || data instanceof Array) { + controller.error(new utilsErrors.ErrorUtilsUndefinedBehaviour()); + return; + } + const message = data as Buffer; + readableLogger.debug(`Received ${message.toString()}`); + if (message.length === 0) { + readableLogger.debug('Null message received'); + ws.removeListener('message', messageHandler); + if (!this._readableEnded) { + readableLogger.debug('Closing'); + this.signalReadableEnd(); + controller.close(); + } + if (this._writableEnded) { + logger.debug('Closing socket'); + ws.close(); + } + return; + } + if (this._readableEnded) { + return; + } + controller.enqueue(message); + if (controller.desiredSize == null) { + controller.error(new utilsErrors.ErrorUtilsUndefinedBehaviour()); + return; + } + if (controller.desiredSize < 0) { + readableLogger.debug('Applying readable backpressure'); + ws.pause(); + } + }; + readableLogger.debug('Registering socket message handler'); + ws.on('message', messageHandler); + ws.once('close', (code, reason) => { + logger.info('Socket closed'); + ws.removeListener('message', messageHandler); + if (!this._readableEnded) { + readableLogger.debug( + `Closed early, ${code}, ${reason.toString()}`, + ); + const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); + this.signalReadableEnd(e); + controller.error(e); + } + }); + ws.once('error', (e) => { + if (!this._readableEnded) { + readableLogger.error(e); + this.signalReadableEnd(e); + controller.error(e); + } + }); + }, + cancel: (reason) => { + readableLogger.debug('Cancelled'); + this.signalReadableEnd(reason); + if (this._writableEnded) { + readableLogger.debug('Closing socket'); + this.signalWritableEnd(reason); + ws.close(); + } + }, + pull: () => { + readableLogger.debug('Releasing backpressure'); + ws.resume(); + }, + }, + { highWaterMark: 1 }, + ); + this.writable = new WritableStream( + { + start: (controller) => { + this.writableController = controller; + writableLogger.info('Starting'); + ws.once('error', (e) => { + if (!this._writableEnded) { + writableLogger.error(e); + this.signalWritableEnd(e); + controller.error(e); + } + }); + ws.once('close', (code, reason) => { + if (!this._writableEnded) { + writableLogger.debug( + `Closed early, ${code}, ${reason.toString()}`, + ); + const e = new webSocketErrors.ErrorClientConnectionEndedEarly(); + this.signalWritableEnd(e); + controller.error(e); + } + }); + }, + close: async () => { + writableLogger.debug('Closing, sending null message'); + const sendProm = promise(); + ws.send(Buffer.from([]), (err) => { + if (err == null) sendProm.resolveP(); + else sendProm.rejectP(err); + }); + await sendProm.p; + this.signalWritableEnd(); + if (this._readableEnded) { + writableLogger.debug('Closing socket'); + ws.close(); + } + }, + abort: (reason) => { + writableLogger.debug('Aborted'); + this.signalWritableEnd(reason); + if (this._readableEnded) { + writableLogger.debug('Closing socket'); + ws.close(4000, `Aborting connection with ${reason.message}`); + } + }, + write: async (chunk, controller) => { + if (this._writableEnded) return; + writableLogger.debug(`Sending ${chunk?.toString()}`); + const wait = promise(); + ws.send(chunk, (e) => { + if (e != null && !this._writableEnded) { + // Opting to debug message here and not log an error, sending + // failure is common if we send before the close event. + writableLogger.debug('failed to send'); + const err = new webSocketErrors.ErrorClientConnectionEndedEarly( + undefined, + { + cause: e, + }, + ); + this.signalWritableEnd(err); + controller.error(err); + } + wait.resolveP(); + }); + await wait.p; + }, + }, + { highWaterMark: 1 }, + ); + + // Setting up heartbeat + const pingTimer = setInterval(() => { + ws.ping(); + }, pingInterval); + const pingTimeoutTimeTimer = setTimeout(() => { + logger.debug('Ping timed out'); + ws.close(4002, 'Timed out'); + }, pingTimeoutTime); + const pingHandler = () => { + logger.debug('Received ping'); + ws.pong(); + }; + const pongHandler = () => { + logger.debug('Received pong'); + pingTimeoutTimeTimer.refresh(); + }; + ws.on('ping', pingHandler); + ws.on('pong', pongHandler); + ws.once('close', (code, reason) => { + ws.off('ping', pingHandler); + ws.off('pong', pongHandler); + logger.debug('WebSocket closed'); + const err = + code !== 1000 + ? new webSocketErrors.ErrorClientConnectionEndedEarly( + `ended with code ${code}, ${reason.toString()}`, + ) + : undefined; + this.signalWebSocketEnd(err); + logger.debug('Cleaning up timers'); + // Clean up timers + clearTimeout(pingTimer); + clearTimeout(pingTimeoutTimeTimer); + }); } get readableEnded() { @@ -88,10 +297,34 @@ abstract class WebSocketStream return this._endedProm; } + get meta(): Record { + // Spreading to avoid modifying the data + return { + ...this.metadata, + }; + } + /** * Forces the active stream to end early */ - abstract cancel(reason?: any): void; + public cancel(reason?: any): void { + // Default error + const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); + // Close the streams with the given error, + if (!this._readableEnded) { + this.readableController?.error(err); + this.signalReadableEnd(err); + } + if (!this._writableEnded) { + this.writableController?.error(err); + this.signalWritableEnd(err); + } + // Then close the websocket + if (!this._webSocketEnded) { + this.ws.close(4000, 'Ending connection'); + this.signalWebSocketEnd(err); + } + } /** * Signals the end of the ReadableStream. to be used with the extended class