Skip to content

Commit

Permalink
Merge pull request #664 from MatrixAI/feature-paramaterized-rpc-middl…
Browse files Browse the repository at this point in the history
…eware

PolykeyAgent and PolykeyClient now accept the `rpcMiddlewareFactory` for Custom Middleware + `versionMetadata` option for defining version information for `agentStatus` RPC call
  • Loading branch information
amydevs authored Jan 10, 2024
2 parents 8d697de + f28988b commit 78651fe
Show file tree
Hide file tree
Showing 8 changed files with 402 additions and 100 deletions.
31 changes: 29 additions & 2 deletions src/PolykeyAgent.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import type { DeepPartial, FileSystem, ObjectEmpty } from './types';
import type {
JSONObject,
JSONRPCRequest,
JSONRPCResponse,
MiddlewareFactory,
} from '@matrixai/rpc';
import type { DeepPartial, FileSystem, ObjectEmpty, POJO } from './types';
import type { PolykeyWorkerManagerInterface } from './workers/types';
import type { TLSConfig } from './network/types';
import type { NodeAddress, NodeId, SeedNodes } from './nodes/types';
import type { Key, PasswordOpsLimit, PasswordMemLimit } from './keys/types';
import type {
ClientRPCRequestParams,
ClientRPCResponseResult,
} from './client/types';
import path from 'path';
import process from 'process';
import Logger from '@matrixai/logger';
Expand Down Expand Up @@ -42,7 +52,6 @@ import * as workersUtils from './workers/utils';
import * as clientMiddleware from './client/middleware';
import clientServerManifest from './client/handlers';
import agentServerManifest from './nodes/agent/handlers';

/**
* Optional configuration for `PolykeyAgent`.
*/
Expand Down Expand Up @@ -73,6 +82,12 @@ type PolykeyAgentOptions = {
keepAliveIntervalTime: number;
rpcCallTimeoutTime: number;
rpcParserBufferSize: number;
rpcMiddlewareFactory?: MiddlewareFactory<
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCResponse<ClientRPCResponseResult>,
JSONRPCResponse<ClientRPCResponseResult>
>;
};
nodes: {
connectionIdleTimeoutTimeMin: number;
Expand All @@ -89,6 +104,7 @@ type PolykeyAgentOptions = {
groups: Array<string>;
port: number;
};
versionMetadata: POJO;
};

interface PolykeyAgent extends CreateDestroyStartStop {}
Expand Down Expand Up @@ -181,6 +197,7 @@ class PolykeyAgent {
groups: config.defaultsSystem.mdnsGroups,
port: config.defaultsSystem.mdnsPort,
},
versionMetadata: {},
});
// This can only happen if the caller didn't specify the node path and the
// automatic detection failed
Expand Down Expand Up @@ -417,6 +434,7 @@ class PolykeyAgent {
middlewareFactory: clientMiddleware.middlewareServer(
sessionManager,
keyRing,
optionsDefaulted.client.rpcMiddlewareFactory,
),
keepAliveTimeoutTime: optionsDefaulted.client.keepAliveTimeoutTime,
keepAliveIntervalTime: optionsDefaulted.client.keepAliveIntervalTime,
Expand Down Expand Up @@ -472,6 +490,7 @@ class PolykeyAgent {
clientService,
fs,
logger,
versionMetadata: optionsDefaulted.versionMetadata,
});
await pkAgent.start({
password,
Expand Down Expand Up @@ -514,6 +533,7 @@ class PolykeyAgent {
public readonly clientService: ClientService;
protected workerManager: PolykeyWorkerManagerInterface | undefined;
protected _startTime: number = 0;
protected _versionMetadata: JSONObject;

protected handleEventCertManagerCertChange = async (
evt: keysEvents.EventCertManagerCertChange,
Expand Down Expand Up @@ -558,6 +578,7 @@ class PolykeyAgent {
clientService,
fs,
logger,
versionMetadata,
}: {
nodePath: string;
audit: Audit;
Expand All @@ -581,6 +602,7 @@ class PolykeyAgent {
clientService: ClientService;
fs: FileSystem;
logger: Logger;
versionMetadata: POJO;
}) {
this.logger = logger;
this.nodePath = nodePath;
Expand All @@ -604,6 +626,7 @@ class PolykeyAgent {
this.sessionManager = sessionManager;
this.clientService = clientService;
this.fs = fs;
this._versionMetadata = versionMetadata;
}

@ready(new errors.ErrorPolykeyAgentNotRunning())
Expand All @@ -626,6 +649,10 @@ class PolykeyAgent {
return this.nodeConnectionManager.port;
}

get versionMetadata() {
return this._versionMetadata;
}

/**
* Returns the time the `PolykeyAgent` was started at in milliseconds since Unix epoch
*/
Expand Down
22 changes: 20 additions & 2 deletions src/PolykeyClient.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import type {
JSONRPCRequest,
JSONRPCResponse,
MiddlewareFactory,
} from '@matrixai/rpc';
import type { PromiseCancellable } from '@matrixai/async-cancellable';
import type { ContextTimed, ContextTimedInput } from '@matrixai/contexts';
import type { DeepPartial, FileSystem } from './types';
import type { OverrideRPClientType } from './client/types';
import type {
ClientRPCRequestParams,
ClientRPCResponseResult,
OverrideRPClientType,
} from './client/types';
import type { NodeId } from './ids/types';
import path from 'path';
import Logger from '@matrixai/logger';
Expand Down Expand Up @@ -33,6 +42,12 @@ type PolykeyClientOptions = {
keepAliveIntervalTime: number;
rpcCallTimeoutTime: number;
rpcParserBufferSize: number;
rpcMiddlewareFactory?: MiddlewareFactory<
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCResponse<ClientRPCResponseResult>,
JSONRPCResponse<ClientRPCResponseResult>
>;
};

interface PolykeyClient extends CreateDestroyStartStop {}
Expand Down Expand Up @@ -304,7 +319,10 @@ class PolykeyClient {
manifest: clientClientManifest,
streamFactory: () => webSocketClient.connection.newStream(),
middlewareFactory: rpcMiddleware.defaultClientMiddlewareWrapper(
clientMiddleware.middlewareClient(this.session),
clientMiddleware.middlewareClient(
this.session,
optionsDefaulted.rpcMiddlewareFactory,
),
optionsDefaulted.rpcParserBufferSize,
),
toError: networkUtils.toError,
Expand Down
1 change: 1 addition & 0 deletions src/client/handlers/AgentStatus.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class AgentStatus extends UnaryHandler<
sourceVersion: config.sourceVersion,
stateVersion: config.stateVersion,
networkVersion: config.networkVersion,
versionMetadata: polykeyAgent.versionMetadata,
};
};
}
Expand Down
52 changes: 44 additions & 8 deletions src/client/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import * as authenticationMiddlewareUtils from './authenticationMiddleware';
function middlewareServer(
sessionManager: SessionManager,
keyRing: KeyRing,
customMiddlewareFactory?: MiddlewareFactory<
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCResponse<ClientRPCResponseResult>,
JSONRPCResponse<ClientRPCResponseResult>
>,
): MiddlewareFactory<
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCRequest<ClientRPCRequestParams>,
Expand All @@ -25,22 +31,40 @@ function middlewareServer(
);
return (ctx, cancel, meta) => {
const authMiddleware = authMiddlewareFactory(ctx, cancel, meta);
// Order is auth -> timeout
const customMiddleware = customMiddlewareFactory?.(ctx, cancel, meta);
// Order is auth -> custom
return {
forward: {
writable: authMiddleware.forward.writable,
readable: authMiddleware.forward.readable,
readable:
customMiddleware == null
? authMiddleware.forward.readable
: authMiddleware.forward.readable.pipeThrough(
customMiddleware.forward,
),
},
reverse: {
writable: authMiddleware.reverse.writable,
readable: authMiddleware.reverse.readable,
writable:
customMiddleware?.reverse.writable ?? authMiddleware.reverse.writable,
readable:
customMiddleware == null
? authMiddleware.reverse.readable
: customMiddleware.reverse.readable.pipeThrough(
authMiddleware.reverse,
),
},
};
};
}

function middlewareClient(
session: Session,
customMiddlewareFactory?: MiddlewareFactory<
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCResponse<ClientRPCResponseResult>,
JSONRPCResponse<ClientRPCResponseResult>
>,
): MiddlewareFactory<
JSONRPCRequest<ClientRPCRequestParams>,
JSONRPCRequest<ClientRPCRequestParams>,
Expand All @@ -51,15 +75,27 @@ function middlewareClient(
authenticationMiddlewareUtils.authenticationMiddlewareClient(session);
return (ctx, cancel, meta) => {
const authMiddleware = authMiddlewareFactory(ctx, cancel, meta);
// Order is timeout -> auth
const customMiddleware = customMiddlewareFactory?.(ctx, cancel, meta);
// Order is custom -> auth
return {
forward: {
writable: authMiddleware.forward.writable,
readable: authMiddleware.forward.readable,
writable:
customMiddleware?.forward.writable ?? authMiddleware.forward.writable,
readable:
customMiddleware == null
? authMiddleware.forward.readable
: customMiddleware.forward.readable.pipeThrough(
authMiddleware.forward,
),
},
reverse: {
writable: authMiddleware.reverse.writable,
readable: authMiddleware.reverse.readable,
readable:
customMiddleware == null
? authMiddleware.reverse.readable
: authMiddleware.reverse.readable.pipeThrough(
customMiddleware.reverse,
),
},
};
};
Expand Down
1 change: 1 addition & 0 deletions src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type StatusResultMessage = {
sourceVersion: string;
stateVersion: number;
networkVersion: number;
versionMetadata: JSONObject;
};

// Identity messages
Expand Down
6 changes: 6 additions & 0 deletions tests/client/handlers/agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ describe('agentStatus', () => {
passwordMemLimit: keysUtils.passwordMemLimits.min,
strictMemoryLock: false,
},
versionMetadata: {
cliAgentCommitHash: 'test',
},
},
logger,
});
Expand Down Expand Up @@ -213,6 +216,9 @@ describe('agentStatus', () => {
sourceVersion: config.sourceVersion,
stateVersion: config.stateVersion,
networkVersion: config.networkVersion,
versionMetadata: {
cliAgentCommitHash: 'test',
},
});
});
});
Expand Down
Loading

0 comments on commit 78651fe

Please sign in to comment.