From ecf309cc2884486401f27f167b12da3f2f9ed73b Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 24 Feb 2023 18:26:21 +1100 Subject: [PATCH] wip: created agent handlers and tests [ci skip] --- src/RPC/RPCClient.ts | 4 +- src/RPC/errors.ts | 6 + src/RPC/middleware.ts | 8 ++ src/clientRPC/ClientServer.ts | 5 +- src/clientRPC/handlers/agentLockAll.ts | 30 ++++ src/clientRPC/handlers/agentStop.ts | 32 +++++ src/clientRPC/handlers/index.ts | 37 +++++ tests/clientRPC/handlers/agentLockAll.test.ts | 113 +++++++++++++++ tests/clientRPC/handlers/agentStop.test.ts | 134 ++++++++++++++++++ 9 files changed, 365 insertions(+), 4 deletions(-) create mode 100644 src/clientRPC/handlers/agentLockAll.ts create mode 100644 src/clientRPC/handlers/agentStop.ts create mode 100644 src/clientRPC/handlers/index.ts create mode 100644 tests/clientRPC/handlers/agentLockAll.test.ts create mode 100644 tests/clientRPC/handlers/agentStop.test.ts diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 91f2c729db..6d8e21bdd7 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -134,7 +134,7 @@ class RPCClient { await writer.write(parameters); const output = await reader.read(); if (output.done) { - throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); + throw new rpcErrors.ErrorRpcMissingResponse(); } await reader.cancel(); await writer.close(); @@ -165,7 +165,7 @@ class RPCClient { const reader = callerInterface.readable.getReader(); const output = reader.read().then(({ value, done }) => { if (done) { - throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); + throw new rpcErrors.ErrorRpcMissingResponse(); } return value; }); diff --git a/src/RPC/errors.ts b/src/RPC/errors.ts index c434efdb89..b8a1c2dc87 100644 --- a/src/RPC/errors.ts +++ b/src/RPC/errors.ts @@ -30,6 +30,11 @@ class ErrorRpcMessageLength extends ErrorRpc { exitCode = sysexits.DATAERR; } +class ErrorRpcMissingResponse extends ErrorRpc { + static description = 'Stream ended before response'; + exitCode = sysexits.UNAVAILABLE; +} + class ErrorRpcRemoteError extends ErrorRpc { static description = 'RPC Message exceeds maximum size'; exitCode = sysexits.UNAVAILABLE; @@ -51,6 +56,7 @@ export { ErrorRpcParse, ErrorRpcHandlerFailed, ErrorRpcMessageLength, + ErrorRpcMissingResponse, ErrorRpcRemoteError, ErrorRpcNoMessageError, ErrorRpcPlaceholderConnectionError, diff --git a/src/RPC/middleware.ts b/src/RPC/middleware.ts index 0f3150d832..b185e47ca8 100644 --- a/src/RPC/middleware.ts +++ b/src/RPC/middleware.ts @@ -8,6 +8,7 @@ import type { import { TransformStream } from 'stream/web'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; +import { promise } from '../utils'; const jsonStreamParsers = require('@streamparser/json'); function binaryToJsonMessageStream( @@ -22,6 +23,13 @@ function binaryToJsonMessageStream( let bytesWritten: number = 0; return new TransformStream({ + flush: async () => { + // Avoid potential race conditions by allowing parser to end first + const waitP = promise(); + parser.onEnd = () => waitP.resolveP(); + parser.end(); + await waitP.p; + }, start: (controller) => { if (firstMessage != null) controller.enqueue(firstMessage); parser.onValue = (value) => { diff --git a/src/clientRPC/ClientServer.ts b/src/clientRPC/ClientServer.ts index ee2c595e77..71cfda827e 100644 --- a/src/clientRPC/ClientServer.ts +++ b/src/clientRPC/ClientServer.ts @@ -288,7 +288,8 @@ class ClientServer { start: (controller) => { readableController = controller; context.message = (ws, message, _) => { - readableLogger.debug(`Received ${message.toString()}`); + const messageBuffer = Buffer.from(message); + readableLogger.debug(`Received ${messageBuffer.toString()}`); if (message.byteLength === 0) { readableLogger.debug('Null message received'); if (!readableClosed) { @@ -302,7 +303,7 @@ class ClientServer { } return; } - controller.enqueue(Buffer.from(message)); + controller.enqueue(messageBuffer); if (controller.desiredSize != null && controller.desiredSize < 0) { readableLogger.error('Read stream buffer full'); if (!wsClosed) ws.end(4001, 'Read stream buffer full'); diff --git a/src/clientRPC/handlers/agentLockAll.ts b/src/clientRPC/handlers/agentLockAll.ts new file mode 100644 index 0000000000..a79c5897e0 --- /dev/null +++ b/src/clientRPC/handlers/agentLockAll.ts @@ -0,0 +1,30 @@ +import type Logger from '@matrixai/logger'; +import type { RPCRequestParams, RPCResponseResult } from '../types'; +import type { DB } from '@matrixai/db'; +import type SessionManager from '../../sessions/SessionManager'; +import { UnaryHandler } from '../../RPC/handlers'; +import { UnaryCaller } from '../../RPC/callers'; + +const agentLockAllCaller = new UnaryCaller< + RPCRequestParams, + RPCResponseResult +>(); + +class AgentLockAllHandler extends UnaryHandler< + { + sessionManager: SessionManager; + db: DB; + logger: Logger; + }, + RPCRequestParams, + RPCResponseResult +> { + public async handle(): Promise { + await this.container.db.withTransactionF((tran) => + this.container.sessionManager.resetKey(tran), + ); + return {}; + } +} + +export { agentLockAllCaller, AgentLockAllHandler }; diff --git a/src/clientRPC/handlers/agentStop.ts b/src/clientRPC/handlers/agentStop.ts new file mode 100644 index 0000000000..368a2e7d3d --- /dev/null +++ b/src/clientRPC/handlers/agentStop.ts @@ -0,0 +1,32 @@ +import type Logger from '@matrixai/logger'; +import type { RPCRequestParams, RPCResponseResult } from '../types'; +import type PolykeyAgent from '../../PolykeyAgent'; +import { running, status } from '@matrixai/async-init'; +import { UnaryHandler } from '../../RPC/handlers'; +import { UnaryCaller } from '../../RPC/callers'; + +const agentStopCaller = new UnaryCaller(); + +class AgentStopHandler extends UnaryHandler< + { + pkAgent: PolykeyAgent; + logger: Logger; + }, + RPCRequestParams, + RPCResponseResult +> { + public async handle(): Promise { + const pkAgent = this.container.pkAgent; + // If not running or in stopping status, then respond successfully + if (!pkAgent[running] || pkAgent[status] === 'stopping') { + return {}; + } + // Stop PK agent in the background, allow the RPC time to respond + setTimeout(async () => { + await pkAgent.stop(); + }, 500); + return {}; + } +} + +export { agentStopCaller, AgentStopHandler }; diff --git a/src/clientRPC/handlers/index.ts b/src/clientRPC/handlers/index.ts new file mode 100644 index 0000000000..148f8f734d --- /dev/null +++ b/src/clientRPC/handlers/index.ts @@ -0,0 +1,37 @@ +import type Logger from '@matrixai/logger'; +import type SessionManager from '../../sessions/SessionManager'; +import type KeyRing from '../../keys/KeyRing'; +import type CertManager from '../../keys/CertManager'; +import type PolykeyAgent from '../../PolykeyAgent'; +import type { DB } from '@matrixai/db'; +import { agentStatusCaller, AgentStatusHandler } from './agentStatus'; +import { agentStopCaller, AgentStopHandler } from './agentStop'; +import { agentUnlockCaller, AgentUnlockHandler } from './agentUnlock'; +import { agentLockAllCaller, AgentLockAllHandler } from './agentLockAll'; + +const serverManifest = (container: { + pkAgent: PolykeyAgent; + keyRing: KeyRing; + certManager: CertManager; + db: DB; + sessionManager: SessionManager; + logger: Logger; +}) => { + // No type used here, it will override type inference + return { + agentLockAll: new AgentLockAllHandler(container), + agentStatus: new AgentStatusHandler(container), + agentStop: new AgentStopHandler(container), + agentUnlock: new AgentUnlockHandler(container), + }; +}; + +// No type used here, it will override type inference +const clientManifest = { + agentLockAll: agentLockAllCaller, + agentStatus: agentStatusCaller, + agentStop: agentStopCaller, + agentUnlock: agentUnlockCaller, +}; + +export { serverManifest, clientManifest }; diff --git a/tests/clientRPC/handlers/agentLockAll.test.ts b/tests/clientRPC/handlers/agentLockAll.test.ts new file mode 100644 index 0000000000..44093d2898 --- /dev/null +++ b/tests/clientRPC/handlers/agentLockAll.test.ts @@ -0,0 +1,113 @@ +import type { TLSConfig } from '@/network/types'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; +import { DB } from '@matrixai/db'; +import KeyRing from '@/keys/KeyRing'; +import * as keysUtils from '@/keys/utils'; +import RPCServer from '@/RPC/RPCServer'; +import TaskManager from '@/tasks/TaskManager'; +import { + agentLockAllCaller, + AgentLockAllHandler, +} from '@/clientRPC/handlers/agentLockAll'; +import RPCClient from '@/RPC/RPCClient'; +import { SessionManager } from '@/sessions'; +import ClientServer from '@/clientRPC/ClientServer'; +import ClientClient from '@/clientRPC/ClientClient'; +import * as testsUtils from '../../utils'; + +describe('agentLockAll', () => { + const logger = new Logger('agentUnlock test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + const password = 'helloWorld'; + const host = '127.0.0.1'; + let dataDir: string; + let db: DB; + let keyRing: KeyRing; + let taskManager: TaskManager; + let sessionManager: SessionManager; + let clientClient: ClientClient; + let clientServer: ClientServer; + let tlsConfig: TLSConfig; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + const dbPath = path.join(dataDir, 'db'); + db = await DB.createDB({ + dbPath, + logger, + }); + keyRing = await KeyRing.createKeyRing({ + password, + keysPath, + logger, + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + }); + taskManager = await TaskManager.createTaskManager({ db, logger }); + sessionManager = await SessionManager.createSessionManager({ + db, + keyRing, + logger, + }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + }); + afterEach(async () => { + await clientServer.stop(true); + await clientClient.destroy(true); + await taskManager.stop(); + await keyRing.stop(); + await db.stop(); + await fs.promises.rm(dataDir, { + force: true, + recursive: true, + }); + }); + test('Locks all current sessions', async () => { + // Setup + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + agentLockAll: new AgentLockAllHandler({ + db, + sessionManager, + logger, + }), + }, + logger, + }); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair, connectionInfo) => + rpcServer.handleStream(streamPair, connectionInfo), + host, + tlsConfig, + logger: logger.getChild('server'), + }); + clientClient = await ClientClient.createClientClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + logger: logger.getChild('client'), + port: clientServer.port, + }); + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + agentLockAll: agentLockAllCaller, + }, + streamPairCreateCallback: async () => clientClient.startConnection(), + logger: logger.getChild('clientRPC'), + }); + + // Doing the test + const token = await sessionManager.createToken(); + await rpcClient.methods.agentLockAll({}); + expect(await sessionManager.verifyToken(token)).toBeFalsy(); + }); +}); diff --git a/tests/clientRPC/handlers/agentStop.test.ts b/tests/clientRPC/handlers/agentStop.test.ts new file mode 100644 index 0000000000..cd248815b8 --- /dev/null +++ b/tests/clientRPC/handlers/agentStop.test.ts @@ -0,0 +1,134 @@ +import type { TLSConfig } from '@/network/types'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; +import { DB } from '@matrixai/db'; +import { running } from '@matrixai/async-init'; +import KeyRing from '@/keys/KeyRing'; +import * as keysUtils from '@/keys/utils'; +import RPCServer from '@/RPC/RPCServer'; +import TaskManager from '@/tasks/TaskManager'; +import { + agentStopCaller, + AgentStopHandler, +} from '@/clientRPC/handlers/agentStop'; +import RPCClient from '@/RPC/RPCClient'; +import ClientServer from '@/clientRPC/ClientServer'; +import ClientClient from '@/clientRPC/ClientClient'; +import config from '@/config'; +import PolykeyAgent from '@/PolykeyAgent'; +import * as testsUtils from '../../utils'; +import Status from '../../../src/status/Status'; + +describe('agentStop', () => { + const logger = new Logger('agentUnlock test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + const password = 'helloWorld'; + const host = '127.0.0.1'; + let dataDir: string; + let nodePath: string; + let db: DB; + let keyRing: KeyRing; + let taskManager: TaskManager; + let clientClient: ClientClient; + let clientServer: ClientServer; + let tlsConfig: TLSConfig; + let pkAgent: PolykeyAgent; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + nodePath = path.join(dataDir, 'polykey'); + const keysPath = path.join(dataDir, 'keys'); + const dbPath = path.join(dataDir, 'db'); + db = await DB.createDB({ + dbPath, + logger, + }); + keyRing = await KeyRing.createKeyRing({ + password, + keysPath, + logger, + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + }); + taskManager = await TaskManager.createTaskManager({ db, logger }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + pkAgent = await PolykeyAgent.createPolykeyAgent({ + password, + nodePath, + logger, + keyRingConfig: { + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + }, + }); + }); + afterEach(async () => { + await clientServer.stop(true); + await clientClient.destroy(true); + await taskManager.stop(); + await keyRing.stop(); + await db.stop(); + await fs.promises.rm(dataDir, { + force: true, + recursive: true, + }); + }); + test('Stops the agent', async () => { + // Setup + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + agentStop: new AgentStopHandler({ + pkAgent, + logger, + }), + }, + logger, + }); + clientServer = await ClientServer.createClientServer({ + connectionCallback: (streamPair, connectionInfo) => + rpcServer.handleStream(streamPair, connectionInfo), + host, + tlsConfig, + logger: logger.getChild('server'), + }); + clientClient = await ClientClient.createClientClient({ + expectedNodeIds: [keyRing.getNodeId()], + host, + logger: logger.getChild('client'), + port: clientServer.port, + }); + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + agentStop: agentStopCaller, + }, + streamPairCreateCallback: async () => clientClient.startConnection(), + logger: logger.getChild('clientRPC'), + }); + + // Doing the test + const statusPath = path.join(nodePath, config.defaults.statusBase); + const statusLockPath = path.join(nodePath, config.defaults.statusLockBase); + const status = new Status({ + statusPath, + statusLockPath, + fs, + logger, + }); + await rpcClient.methods.agentStop({}); + // It may already be stopping + expect(await status.readStatus()).toMatchObject({ + status: expect.stringMatching(/LIVE|STOPPING|DEAD/), + }); + await status.waitFor('DEAD'); + expect(pkAgent[running]).toBe(false); + }); +});