From f8555f1080f7e6e542243450af5cfd96f7282e41 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 11 Aug 2023 17:11:58 +1000 Subject: [PATCH] fix: `WebSocketServer` can now update TLS certs while running * Related #540 [ci skip] --- src/PolykeyAgent.ts | 5 +- src/websockets/WebSocketServer.ts | 15 ++-- src/websockets/WebSocketStream.ts | 3 +- tests/websockets/WebSocket.test.ts | 106 ++++++++++++++++------------- 4 files changed, 71 insertions(+), 58 deletions(-) diff --git a/src/PolykeyAgent.ts b/src/PolykeyAgent.ts index a833a5a1c..9d2c62f88 100644 --- a/src/PolykeyAgent.ts +++ b/src/PolykeyAgent.ts @@ -739,10 +739,7 @@ class PolykeyAgent { keyPrivatePem: keysUtils.privateKeyToPEM(data.keyPair.privateKey), certChainPem: await this.certManager.getCertPEMsChainPEM(), }; - // FIXME: Can we even support updating TLS config anymore? - // We would need to shut down the Websocket server and re-create it with the new config. - // Right now graceful shutdown is not supported. - // this.grpcServerClient.setTLSConfig(tlsConfig); + this.webSocketServerClient.setTlsConfig(tlsConfig); this.nodeConnectionManager.updateTlsConfig(tlsConfig); this.logger.info(`${KeyRing.name} change propagated`); }, diff --git a/src/websockets/WebSocketServer.ts b/src/websockets/WebSocketServer.ts index 8db94519a..5d99f79a5 100644 --- a/src/websockets/WebSocketServer.ts +++ b/src/websockets/WebSocketServer.ts @@ -1,5 +1,6 @@ import type { TLSConfig } from '../network/types'; import type { IncomingMessage, ServerResponse } from 'http'; +import type tls from 'tls'; import https from 'https'; import { startStop, status } from '@matrixai/async-init'; import Logger from '@matrixai/logger'; @@ -24,7 +25,6 @@ class WebSocketServer extends EventTarget { * @param obj * @param obj.connectionCallback - * @param obj.tlsConfig - TLSConfig containing the private key and cert chain used for TLS. - * @param obj.basePath - Directory path used for storing temp cert files for starting the `uWebsocket` server. * @param obj.host - Listen address to bind to. * @param obj.port - Listen port to bind to. * @param obj.maxIdleTimeout - Timeout time for when the connection is cleaned up after no activity. @@ -38,7 +38,6 @@ class WebSocketServer extends EventTarget { static async createWebSocketServer({ connectionCallback, tlsConfig, - basePath, host, port, maxIdleTimeout = 120, @@ -48,7 +47,6 @@ class WebSocketServer extends EventTarget { }: { connectionCallback: ConnectionCallback; tlsConfig: TLSConfig; - basePath?: string; host?: string; port?: number; maxIdleTimeout?: number; @@ -66,7 +64,6 @@ class WebSocketServer extends EventTarget { await wsServer.start({ connectionCallback, tlsConfig, - basePath, host, port, }); @@ -106,7 +103,6 @@ class WebSocketServer extends EventTarget { connectionCallback, }: { tlsConfig: TLSConfig; - basePath?: string; host?: string; port?: number; connectionCallback?: ConnectionCallback; @@ -212,6 +208,15 @@ class WebSocketServer extends EventTarget { return this._host; } + @startStop.ready(new webSocketErrors.ErrorWebSocketServerNotRunning()) + public setTlsConfig(tlsConfig: TLSConfig): void { + const tlsServer = this.server as tls.Server; + tlsServer.setSecureContext({ + key: tlsConfig.keyPrivatePem, + cert: tlsConfig.certChainPem, + }); + } + /** * Handles the creation of the `ReadableWritablePair` and provides it to the * StreamPair handler. diff --git a/src/websockets/WebSocketStream.ts b/src/websockets/WebSocketStream.ts index d2d5627cf..ea5934cfc 100644 --- a/src/websockets/WebSocketStream.ts +++ b/src/websockets/WebSocketStream.ts @@ -6,7 +6,6 @@ import type { import type * as ws from 'ws'; import type Logger from '@matrixai/logger'; import type { NodeIdEncoded } from '../ids/types'; -import type { JSONValue } from '../types'; import { WritableStream, ReadableStream } from 'stream/web'; import * as webSocketErrors from './errors'; import * as utilsErrors from '../utils/errors'; @@ -297,7 +296,7 @@ class WebSocketStream implements ReadableWritablePair { return this._endedProm; } - get meta(): Record { + get meta() { // Spreading to avoid modifying the data return { ...this.metadata, diff --git a/tests/websockets/WebSocket.test.ts b/tests/websockets/WebSocket.test.ts index b7fb00356..64815dcf9 100644 --- a/tests/websockets/WebSocket.test.ts +++ b/tests/websockets/WebSocket.test.ts @@ -1,6 +1,7 @@ import type { ReadableWritablePair } from 'stream/web'; import type { TLSConfig } from '@/network/types'; import type { KeyPair } from '@/keys/types'; +import type { NodeId } from '@/ids/types'; import type http from 'http'; import fs from 'fs'; import path from 'path'; @@ -10,7 +11,6 @@ import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import { Timer } from '@matrixai/timer'; import { status } from '@matrixai/async-init'; -import { KeyRing } from '@/keys/index'; import WebSocketServer from '@/websockets/WebSocketServer'; import WebSocketClient from '@/websockets/WebSocketClient'; import { promise } from '@/utils'; @@ -28,8 +28,8 @@ describe('WebSocket', () => { formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, ), ]); - let dataDir: string; - let keyRing: KeyRing; + let keyPair: KeyPair; + let nodeId: NodeId; let tlsConfig: TLSConfig; const host = '127.0.0.2'; let webSocketServer: WebSocketServer; @@ -63,19 +63,14 @@ describe('WebSocket', () => { dataDir = await fs.promises.mkdtemp( path.join(os.tmpdir(), 'polykey-test-'), ); - const keysPath = path.join(dataDir, 'keys'); - keyRing = await KeyRing.createKeyRing({ - keysPath: keysPath, - password: 'password', - logger: logger.getChild('keyRing'), - }); - tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + keyPair = keysUtils.generateKeyPair(); + nodeId = keysUtils.publicKeyToNodeId(keyPair.publicKey); + tlsConfig = await testsUtils.createTLSConfig(keyPair); }); afterEach(async () => { logger.info('AFTEREACH'); await webSocketServer?.stop(true); await webSocketClient?.destroy(true); - await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); }); @@ -89,7 +84,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -98,7 +92,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -115,6 +109,42 @@ describe('WebSocket', () => { expect((await reader.read()).done).toBeTrue(); logger.info('ending'); }); + test('can change TLS config', async () => { + const keyPairNew = keysUtils.generateKeyPair(); + const nodeIdNew = keysUtils.publicKeyToNodeId(keyPairNew.publicKey); + const tlsConfigNew = await testsUtils.createTLSConfig(keyPairNew); + + webSocketServer = await WebSocketServer.createWebSocketServer({ + connectionCallback: (streamPair) => { + logger.info('inside callback'); + void streamPair.readable + .pipeTo(streamPair.writable) + .catch(() => {}) + .finally(() => logger.info('STREAM HANDLING ENDED')); + }, + tlsConfig, + host, + logger: logger.getChild('server'), + }); + logger.info(`Server started on port ${webSocketServer.getPort()}`); + webSocketClient = await WebSocketClient.createWebSocketClient({ + host, + port: webSocketServer.getPort(), + expectedNodeIds: [nodeId, nodeIdNew], + logger: logger.getChild('clientClient'), + }); + const websocket = await webSocketClient.startConnection(); + expect(websocket.meta.nodeId).toBe(nodesUtils.encodeNodeId(nodeId)); + websocket.cancel(); + + // Changing certs + webSocketServer.setTlsConfig(tlsConfigNew); + const websocket2 = await webSocketClient.startConnection(); + expect(websocket2.meta.nodeId).toBe(nodesUtils.encodeNodeId(nodeIdNew)); + websocket2.cancel(); + + logger.info('ending'); + }); test('makes a connection over IPv6', async () => { webSocketServer = await WebSocketServer.createWebSocketServer({ connectionCallback: (streamPair) => { @@ -124,7 +154,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host: '::1', logger: logger.getChild('server'), @@ -133,7 +162,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host: '::1', port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -159,7 +188,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -168,7 +196,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -190,7 +218,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -199,7 +226,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); @@ -236,7 +263,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -263,7 +289,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -289,7 +314,6 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -307,7 +331,7 @@ describe('WebSocket', () => { env: { PK_TEST_HOST: host, PK_TEST_PORT: `${webSocketServer.getPort()}`, - PK_TEST_NODE_ID: nodesUtils.encodeNodeId(keyRing.getNodeId()), + PK_TEST_NODE_ID: nodesUtils.encodeNodeId(nodeId), }, }, logger, @@ -369,7 +393,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: await startedProm.p, - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -412,7 +436,6 @@ describe('WebSocket', () => { } })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -421,7 +444,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -453,7 +476,6 @@ describe('WebSocket', () => { await writer.close(); })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -462,7 +484,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -492,7 +514,6 @@ describe('WebSocket', () => { await writer.close(); })().then(serverStreamProm.resolveP, serverStreamProm.rejectP); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -501,7 +522,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -521,7 +542,6 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -530,7 +550,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -562,7 +582,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -592,7 +611,6 @@ describe('WebSocket', () => { logger.info('inside callback'); // Hang connection }, - basePath: dataDir, tlsConfig, host, pingTimeoutTimeTime: 100, @@ -602,7 +620,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); await webSocketClient.startConnection(); @@ -619,7 +637,6 @@ describe('WebSocket', () => { logger.info('inside callback'); streamPairProm.resolveP(streamPair); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -628,7 +645,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const websocket = await webSocketClient.startConnection(); @@ -663,7 +680,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -686,7 +702,7 @@ describe('WebSocket', () => { }); test('authenticates with multiple certs in chain', async () => { const keyPairs: Array = [ - keyRing.keyPair, + keyPair, keysUtils.generateKeyPair(), keysUtils.generateKeyPair(), keysUtils.generateKeyPair(), @@ -702,7 +718,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -732,7 +747,6 @@ describe('WebSocket', () => { .catch(() => {}) .finally(() => logger.info('STREAM HANDLING ENDED')); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -741,7 +755,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId(), alternativeNodeId], + expectedNodeIds: [nodeId, alternativeNodeId], logger: logger.getChild('clientClient'), }); await expect(webSocketClient.startConnection()).toResolve(); @@ -754,7 +768,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: 12345, - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], connectionTimeoutTime: 0, logger: logger.getChild('clientClient'), }); @@ -772,7 +786,6 @@ describe('WebSocket', () => { logger.info('inside callback'); // Hang connection }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -781,7 +794,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], pingTimeoutTimeTime: 100, logger: logger.getChild('clientClient'), }); @@ -809,7 +822,6 @@ describe('WebSocket', () => { })().catch(() => {}), ]); }, - basePath: dataDir, tlsConfig, host, logger: logger.getChild('server'), @@ -818,7 +830,7 @@ describe('WebSocket', () => { webSocketClient = await WebSocketClient.createWebSocketClient({ host, port: webSocketServer.getPort(), - expectedNodeIds: [keyRing.getNodeId()], + expectedNodeIds: [nodeId], logger: logger.getChild('clientClient'), }); const abortController = new AbortController();