diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index f5c011646..43bd31ffe 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -4,10 +4,10 @@ import type { } from 'stream/web'; import type { JSONValue } from '../types'; import type { TLSConfig } from '../network/types'; -import type { IncomingMessage } from 'http'; +import type { IncomingMessage, ServerResponse } from 'http'; import { WritableStream, ReadableStream } from 'stream/web'; import https from 'https'; -import { startStop } from '@matrixai/async-init'; +import { startStop, status } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; import * as ws from 'ws'; import WebSocketStream from './WebSocketStream'; @@ -51,7 +51,6 @@ class WebSocketServer extends EventTarget { maxIdleTimeout = 120, pingIntervalTime = 1_000, pingTimeoutTimeTime = 10_000, - maxReadableStreamBytes = 1_000, // About 1 GB logger = new Logger(this.name), }: { connectionCallback: ConnectionCallback; @@ -62,13 +61,11 @@ class WebSocketServer extends EventTarget { maxIdleTimeout?: number; pingIntervalTime?: number; pingTimeoutTimeTime?: number; - maxReadableStreamBytes?: number; logger?: Logger; }) { logger.info(`Creating ${this.name}`); const wsServer = new this( logger, - maxReadableStreamBytes, maxIdleTimeout, pingIntervalTime, pingTimeoutTimeTime, @@ -103,7 +100,6 @@ class WebSocketServer extends EventTarget { */ constructor( protected logger: Logger, - protected maxReadableStreamBytes, protected maxIdleTimeout: number | undefined, protected pingIntervalTime: number, protected pingTimeoutTimeTime: number, @@ -143,8 +139,20 @@ class WebSocketServer extends EventTarget { }); this.webSocketServer.on('connection', this.connectionHandler); - // This.webSocketServer.on('error', console.error); - // this.webSocketServer.on('close', this.closeHandler); + this.webSocketServer.on('close', this.closeHandler); + this.server.on('close', this.closeHandler); + this.webSocketServer.on('error', this.errorHandler); + this.server.on('error', this.errorHandler); + this.server.on('request', this.requestHandler); + + // This.server.any('/*', (res, _) => { + // // Reject normal requests with an upgrade code + // res + // .writeStatus('426') + // .writeHeader('connection', 'Upgrade') + // .writeHeader('upgrade', 'websocket') + // .end('426 Upgrade Required', true); + // }); // TODO: tell normal requests to upgrade. const listenProm = promise(); @@ -169,8 +177,6 @@ class WebSocketServer extends EventTarget { public async stop(force: boolean = false): Promise { this.logger.info(`Stopping ${this.constructor.name}`); - // Close the server by closing the underlying socket - this.server.close(); // Shutting down active websockets if (force) { for (const webSocketStream of this.activeSockets) { @@ -182,9 +188,37 @@ class WebSocketServer extends EventTarget { // Ignore errors, we only care that it finished webSocketStream.endedProm.catch(() => {}); } + // Close the server by closing the underlying socket + const wssCloseProm = promise(); + this.webSocketServer.close((e) => { + if (e == null || e.message === 'The server is not running') { + wssCloseProm.resolveP(); + } else { + wssCloseProm.rejectP(e); + } + }); + await wssCloseProm.p; + const serverCloseProm = promise(); + this.server.close((e) => { + if (e == null || e.message === 'Server is not running.') { + serverCloseProm.resolveP(); + } else { + serverCloseProm.rejectP(e); + } + }); + await serverCloseProm.p; + // Removing handlers if (this.connectionEventHandler != null) { this.removeEventListener('connection', this.connectionEventHandler); } + + this.webSocketServer.off('connection', this.connectionHandler); + this.webSocketServer.off('close', this.closeHandler); + this.server.off('close', this.closeHandler); + this.webSocketServer.off('error', this.errorHandler); + this.server.off('error', this.errorHandler); + this.server.on('request', this.requestHandler); + this.dispatchEvent(new webSocketEvents.StopEvent()); this.logger.info(`Stopped ${this.constructor.name}`); } @@ -210,7 +244,6 @@ class WebSocketServer extends EventTarget { const socket = request.connection; const webSocketStream = new WebSocketStreamServerInternal( webSocket, - this.maxReadableStreamBytes, this.pingIntervalTime, this.pingTimeoutTimeTime, { @@ -240,18 +273,48 @@ class WebSocketServer extends EventTarget { }), ); }; + + /** + * Used to trigger stopping if the underlying server fails + */ + protected closeHandler = async () => { + if (this[status] == null || this[status] === 'stopping') { + this.logger.debug('close event but already stopping'); + return; + } + this.logger.debug('close event, forcing stop'); + await this.stop(true); + }; + + /** + * Used to propagate error conditions + */ + protected errorHandler = (e: Error) => { + this.logger.error(e); + }; + + /** + * Will tell any normal HTTP request to upgrade + */ + protected requestHandler = (_req, res: ServerResponse) => { + res + .writeHead(426, '426 Upgrade Required', { + connection: 'Upgrade', + upgrade: 'websocket', + }) + .end('426 Upgrade Required'); + }; } class WebSocketStreamServerInternal extends WebSocketStream { - protected readableBackpressure: boolean = false; protected writableController: WritableStreamDefaultController | undefined; protected readableController: | ReadableStreamController | undefined; + protected messageHandler: (data: ws.RawData, isBinary: boolean) => void; constructor( protected webSocket: ws.WebSocket, - maxReadBufferBytes: number, pingInterval: number, pingTimeoutTime: number, protected metadata: Record, @@ -276,7 +339,6 @@ class WebSocketStreamServerInternal extends WebSocketStream { await writeResultProm.p; writableLogger.debug(`Sending ${Buffer.from(chunk).toString()}`); } catch (e) { - console.error(`Failed sending`, e); this.logger.error(e); controller.error(new webSocketErrors.ErrorServerSendFailed()); } @@ -306,47 +368,50 @@ class WebSocketStreamServerInternal extends WebSocketStream { this.webSocket.close(4000, 'Aborting connection'); } }, + }, + { + highWaterMark: 1, }); // Setting up the readable stream + this.messageHandler = (data: ws.RawData, isBinary: boolean) => { + if (!isBinary) never(); + if (data instanceof Array) never(); + const messageBuffer = Buffer.from(data); + readableLogger.debug(`Received ${messageBuffer.toString()}`); + if (messageBuffer.byteLength === 0) { + readableLogger.debug('Null message received'); + this.webSocket.off('message', this.messageHandler); + if (!this._readableEnded) { + readableLogger.debug('Closing'); + this.signalReadableEnd(); + this.readableController!.close(); + if (this._writableEnded && !this._webSocketEnded) { + readableLogger.debug('Ending socket'); + this.signalWebSocketEnd(); + this.webSocket.close(); + } + } + return; + } + this.readableController!.enqueue(messageBuffer); + if ( + this.readableController!.desiredSize != null && + this.readableController!.desiredSize < 0 + ) { + this.webSocket.pause(); + } + }; this.readable = new ReadableStream( { start: (controller) => { this.readableController = controller; - const messageHandler = (data: ws.RawData, isBinary: boolean) => { - if (!isBinary) never(); - console.log(data.toString()); - if (data instanceof Array) never(); - const messageBuffer = Buffer.from(data); - readableLogger.debug(`Received ${messageBuffer.toString()}`); - if (messageBuffer.byteLength === 0) { - readableLogger.debug('Null message received'); - this.webSocket.off('message', messageHandler); - if (!this._readableEnded) { - readableLogger.debug('Closing'); - this.signalReadableEnd(); - controller.close(); - if (this._writableEnded && !this._webSocketEnded) { - readableLogger.debug('Ending socket'); - this.signalWebSocketEnd(); - this.webSocket.close(); - } - } - return; - } - console.log(this._readableEnded); - controller.enqueue(messageBuffer); - if (controller.desiredSize != null && controller.desiredSize < 0) { - this.webSocket.pause(); - this.readableBackpressure = true; - } - }; - this.webSocket.on('message', messageHandler); + this.webSocket.on('message', this.messageHandler); }, pull: () => { this.webSocket.resume(); - this.readableBackpressure = false; }, cancel: (reason) => { + this.webSocket.off('message', this.messageHandler); this.signalReadableEnd(reason); if (this._writableEnded && !this._webSocketEnded) { readableLogger.debug('Ending socket'); @@ -382,11 +447,12 @@ class WebSocketStreamServerInternal extends WebSocketStream { clearTimeout(pingTimeoutTimeTimer); // Closing streams logger.debug('Cleaning streams'); + this.webSocket.off('message', this.messageHandler); const err = new webSocketErrors.ErrorServerConnectionEndedEarly(); if (!this._readableEnded) { readableLogger.debug('Closing'); this.signalReadableEnd(err); - console.log('EROROROROROED'); + this.webSocket.off('message', this.messageHandler); this.readableController?.error(err); } if (!this._writableEnded) { @@ -409,7 +475,7 @@ class WebSocketStreamServerInternal extends WebSocketStream { const err = reason ?? new webSocketErrors.ErrorClientConnectionEndedEarly(); // Close the streams with the given error, if (!this._readableEnded) { - console.log('ERRORORORROED'); + this.webSocket.off('message', this.messageHandler); this.readableController?.error(err); this.signalReadableEnd(err); } @@ -419,7 +485,7 @@ class WebSocketStreamServerInternal extends WebSocketStream { } // Then close the websocket if (!this._webSocketEnded) { - this.webSocket.close(4000, 'Ending connection'); + this.webSocket.terminate(); this.signalWebSocketEnd(err); } } diff --git a/tests/scratch.test.ts b/tests/scratch.test.ts index 7c748cab7..4959b5280 100644 --- a/tests/scratch.test.ts +++ b/tests/scratch.test.ts @@ -1,73 +1,13 @@ -import type { IncomingMessage } from 'http'; -import type { TLSSocket } from 'tls'; -import https from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import * as ws from 'ws'; -import { sleep } from 'ix/asynciterable/_sleep'; -import NodeManager from '@/nodes/NodeManager'; -import * as keysUtils from '@/keys/utils'; -import { promise } from '@/utils'; -import * as testsUtils from './utils'; // This is a 'scratch paper' test file for quickly running tests in the CI describe('scratch', () => { - const _logger = new Logger(`${NodeManager.name} test`, LogLevel.WARN, [ + const _logger = new Logger(`scratch test`, LogLevel.WARN, [ new StreamHandler(), ]); -}); - -// We can't have empty test files so here is a sanity test -test('Should avoid empty test suite', async () => { - expect(1 + 1).toBe(2); -}); - -test('ws server', async () => { - const keyPair = keysUtils.generateKeyPair(); - const tlsConfig = await testsUtils.createTLSConfig(keyPair); - const server = https.createServer({ - key: tlsConfig.keyPrivatePem, - cert: tlsConfig.certChainPem, - }); - console.log(tlsConfig); - const webSocketServer = new ws.WebSocketServer({ - server, - }); - server.on('listening', (...args) => console.log('listening', args)); - webSocketServer.on( - 'connection', - function connection(ws, request: IncomingMessage) { - console.log(request.connection.localAddress); - console.log(request.connection.localPort); - console.log(request.connection.remoteAddress); - console.log(request.connection.remotePort); - const tlsSocket = request.connection as TLSSocket; - console.log(tlsSocket.getCertificate()); - console.log(tlsSocket.getPeerCertificate()); - ws.on('error', console.error); - - ws.on('message', function message(data) { - console.log('received: %s', data); - }); - - ws.send('something'); - }, - ); - const listenProm = promise(); - server.listen(55555, '127.0.0.1', listenProm.resolveP); - await listenProm.p; - console.log(server.address()); - - // Try connecting! - const webSocket = new ws.WebSocket('wss://127.0.0.1:55555', { - rejectUnauthorized: false, + // We can't have empty test files so here is a sanity test + test('Should avoid empty test suite', async () => { + expect(1 + 1).toBe(2); }); - webSocket.on('error', console.error); - - webSocket.on('open', function open() { - webSocket.send(Buffer.from('HELLO!')); - }); - - await sleep(2000); - server.close(); }); diff --git a/tests/websockets/WebSocket.test.ts b/tests/websockets/WebSocket.test.ts index 8ac3b7cda..b7fb00356 100644 --- a/tests/websockets/WebSocket.test.ts +++ b/tests/websockets/WebSocket.test.ts @@ -9,6 +9,7 @@ import https from 'https'; import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; +import { status } from '@matrixai/async-init'; import { KeyRing } from '@/keys/index'; import WebSocketServer from '@/websockets/WebSocketServer'; import WebSocketClient from '@/websockets/WebSocketClient'; @@ -22,7 +23,7 @@ import * as testsUtils from '../utils'; // This file tests both the client and server together. They're too interlinked // to be separate. describe('WebSocket', () => { - const logger = new Logger('websocket test', LogLevel.DEBUG, [ + const logger = new Logger('websocket test', LogLevel.WARN, [ new StreamHandler( formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), @@ -226,6 +227,60 @@ describe('WebSocket', () => { } }, ); + test('handles https server failure', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.getPort()}`); + + const closeP = promise(); + // @ts-ignore: protected property + webSocketServer.server.close(() => { + closeP.resolveP(); + }); + await closeP.p; + // The webSocketServer should stop itself + expect(webSocketServer[status]).toBe(null); + + logger.info('ending'); + }); + test('handles webSocketServer server failure', async () => { + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + basePath: dataDir, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.getPort()}`); + + const closeP = promise(); + // @ts-ignore: protected property + webSocketServer.webSocketServer.close(() => { + closeP.resolveP(); + }); + await closeP.p; + // The webSocketServer should stop itself + expect(webSocketServer[status]).toBe(null); + + logger.info('ending'); + }); test('client ends connection abruptly', async () => { const streamPairProm = promise>(); @@ -337,11 +392,12 @@ describe('WebSocket', () => { }); // These describe blocks contains tests specific to either the client or server describe('WebSocketServer', () => { - testProp.only( + testProp( 'allows half closed writable closes first', [messagesArb, messagesArb], async (messages1, messages2) => { try { + const serverStreamProm = promise(); webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); @@ -354,7 +410,7 @@ describe('WebSocket', () => { for await (const _ of streamPair.readable) { // No touch, only consume } - })().catch((e) => logger.error(e)); + })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, basePath: dataDir, tlsConfig, @@ -370,6 +426,7 @@ describe('WebSocket', () => { }); const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); + await serverStreamProm.p; logger.info('ending'); } finally { await webSocketServer.stop(true); @@ -381,6 +438,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { + const serverStreamProm = promise(); webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); @@ -393,7 +451,7 @@ describe('WebSocket', () => { await writer.write(val); } await writer.close(); - })().catch((e) => logger.error(e)); + })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, basePath: dataDir, tlsConfig, @@ -409,6 +467,7 @@ describe('WebSocket', () => { }); const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); + await serverStreamProm.p; logger.info('ending'); } finally { await webSocketServer.stop(true); @@ -420,6 +479,7 @@ describe('WebSocket', () => { [messagesArb, messagesArb], async (messages1, messages2) => { try { + const serverStreamProm = promise(); webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { logger.info('inside callback'); @@ -430,7 +490,7 @@ describe('WebSocket', () => { await writer.write(val); } await writer.close(); - })().catch((e) => logger.error(e)); + })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, basePath: dataDir, tlsConfig, @@ -446,6 +506,7 @@ describe('WebSocket', () => { }); const websocket = await webSocketClient.startConnection(); await asyncReadWrite(messages1, websocket); + await serverStreamProm.p; logger.info('ending'); } finally { await webSocketServer.stop(true);