Skip to content

Commit

Permalink
feat(middleware): Enhance middleware with timeout and encapsulation f…
Browse files Browse the repository at this point in the history
…eatures

- Import ClientRPCResponseResult and ClientRPCRequestParams from PK.
- Implement timeoutMiddlewareServer and timeoutMiddlewareClient.
- Integrate timeoutMiddleware into defaultMiddleware.
- Fix Jest test issues.
- Rename to RPCResponseResult and RPCRequestParams for clarity.
- Perform lint fixes and Jest tests.
  • Loading branch information
addievo authored and amydevs committed Oct 26, 2023
1 parent 8434c70 commit 826823b
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 95 deletions.
110 changes: 104 additions & 6 deletions src/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@ import type {
JSONRPCResponse,
JSONRPCResponseResult,
MiddlewareFactory,
JSONValue,
RPCRequestParams,
RPCResponseResult,
} from './types';
import type { ContextTimed } from '@matrixai/contexts';
import { TransformStream } from 'stream/web';
import { JSONParser } from '@streamparser/json';
import * as utils from './utils';
Expand Down Expand Up @@ -75,6 +79,96 @@ function jsonMessageToBinaryStream(): TransformStream<
});
}

function timeoutMiddlewareServer(
ctx: ContextTimed,
_cancel: (reason?: any) => void,
_meta: Record<string, JSONValue> | undefined,
) {
const currentTimeout = ctx.timer.delay;
// Flags for tracking if the first message has been processed
let forwardFirst = true;
let reverseFirst = true;
return {
forward: new TransformStream<
JSONRPCRequest<RPCRequestParams>,
JSONRPCRequest<RPCRequestParams>
>({
transform: (chunk, controller) => {
controller.enqueue(chunk);
if (forwardFirst) {
forwardFirst = false;
const clientTimeout = chunk.metadata?.timeout;

if (clientTimeout == null) return;
if (clientTimeout < currentTimeout) ctx.timer.reset(clientTimeout);
}
},
}),
reverse: new TransformStream<
JSONRPCResponse<RPCResponseResult>,
JSONRPCResponse<RPCResponseResult>
>({
transform: (chunk, controller) => {
if (reverseFirst) {
reverseFirst = false;
if ('result' in chunk) {
if (chunk.metadata == null) chunk.metadata = {};
chunk.metadata.timeout = currentTimeout;
}
}
controller.enqueue(chunk);
},
}),
};
}

/**
* This adds its own timeout to the forward metadata and updates it's timeout
* based on the reverse metadata.
* @param ctx
* @param _cancel
* @param _meta
*/
function timeoutMiddlewareClient(
ctx: ContextTimed,
_cancel: (reason?: any) => void,
_meta: Record<string, JSONValue> | undefined,
) {
const currentTimeout = ctx.timer.delay;
// Flags for tracking if the first message has been processed
let forwardFirst = true;
let reverseFirst = true;
return {
forward: new TransformStream<JSONRPCRequest, JSONRPCRequest>({
transform: (chunk, controller) => {
if (forwardFirst) {
forwardFirst = false;
if (chunk == null) chunk = { jsonrpc: '2.0', method: '' };
if (chunk.metadata == null) chunk.metadata = {};
(chunk.metadata as any).timeout = currentTimeout;
}
controller.enqueue(chunk);
},
}),
reverse: new TransformStream<
JSONRPCResponse<RPCResponseResult>,
JSONRPCResponse<RPCResponseResult>
>({
transform: (chunk, controller) => {
controller.enqueue(chunk);
if (reverseFirst) {
reverseFirst = false;
if ('result' in chunk) {
const clientTimeout = chunk.result?.metadata?.timeout;
if (clientTimeout == null) return;
if (clientTimeout < currentTimeout) ctx.timer.reset(clientTimeout);
}
}
},
}),
};
}

/**
* This function is a factory for creating a pass-through streamPair. It is used
* as the default middleware for the middleware wrappers.
Expand Down Expand Up @@ -116,12 +210,14 @@ function defaultServerMiddlewareWrapper(
>();

const middleMiddleware = middlewareFactory(ctx, cancel, meta);
const timeoutMiddleware = timeoutMiddlewareServer(ctx, cancel, meta);

const forwardReadable = inputTransformStream.readable.pipeThrough(
middleMiddleware.forward,
); // Usual middleware here
const forwardReadable = inputTransformStream.readable
.pipeThrough(timeoutMiddleware.forward) // Timeout middleware here
.pipeThrough(middleMiddleware.forward); // Usual middleware here
const reverseReadable = outputTransformStream.readable
.pipeThrough(middleMiddleware.reverse) // Usual middleware here
.pipeThrough(timeoutMiddleware.reverse) // Timeout middleware here
.pipeThrough(jsonMessageToBinaryStream());

return {
Expand Down Expand Up @@ -172,13 +268,15 @@ const defaultClientMiddlewareWrapper = (
JSONRPCRequest
>();

const timeoutMiddleware = timeoutMiddlewareClient(ctx, cancel, meta);
const middleMiddleware = middleware(ctx, cancel, meta);
const forwardReadable = inputTransformStream.readable
.pipeThrough(timeoutMiddleware.forward)
.pipeThrough(middleMiddleware.forward) // Usual middleware here
.pipeThrough(jsonMessageToBinaryStream());
const reverseReadable = outputTransformStream.readable.pipeThrough(
middleMiddleware.reverse,
); // Usual middleware here
const reverseReadable = outputTransformStream.readable
.pipeThrough(middleMiddleware.reverse)
.pipeThrough(timeoutMiddleware.reverse); // Usual middleware here

return {
forward: {
Expand Down
30 changes: 27 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type JSONRPCRequestMessage<T extends JSONValue = JSONValue> = {
* SHOULD NOT contain fractional parts [2]
*/
id: string | number | null;
};
} & RPCResponseResult;

/**
* This is the JSON RPC notification object. this is used for a request that
Expand All @@ -60,7 +60,7 @@ type JSONRPCRequestNotification<T extends JSONValue = JSONValue> = {
* This member MAY be omitted.
*/
params?: T;
};
} & RPCResponseResult;

/**
* This is the JSON RPC response result object. It contains the response data for a
Expand All @@ -84,7 +84,7 @@ type JSONRPCResponseResult<T extends JSONValue = JSONValue> = {
* it MUST be Null.
*/
id: string | number | null;
};
} & RPCResponseResult;

/**
* This is the JSON RPC response Error object. It contains any errors that have
Expand All @@ -110,6 +110,28 @@ type JSONRPCResponseError = {
id: string | number | null;
};

type ObjectEmpty = NonNullable<unknown>;

// Prevent overwriting the metadata type with `Omit<>`
type RPCRequestParams<T extends Record<string, JSONValue> = ObjectEmpty> = {
metadata?: {
[Key: string]: JSONValue;
} & Partial<{
authorization: string;
timeout: number | null;
}>;
} & Omit<T, 'metadata'>;

// Prevent overwriting the metadata type with `Omit<>`
type RPCResponseResult<T extends Record<string, JSONValue> = ObjectEmpty> = {
metadata?: {
[Key: string]: JSONValue;
} & Partial<{
authorization: string;
timeout: number | null;
}>;
} & Omit<T, 'metadata'>;

/**
* This is a JSON RPC error object, it encodes the error data for the JSONRPCResponseError object.
*/
Expand Down Expand Up @@ -357,6 +379,8 @@ export type {
JSONRPCRequestNotification,
JSONRPCResponseResult,
JSONRPCResponseError,
RPCRequestParams,
RPCResponseResult,
JSONRPCError,
JSONRPCRequest,
JSONRPCResponse,
Expand Down
12 changes: 12 additions & 0 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import type {
JSONRPCResponseResult,
JSONValue,
PromiseDeconstructed,
RPCRequestParams,
RPCResponseResult,
ToError,
} from './types';
import { TransformStream } from 'stream/web';
Expand Down Expand Up @@ -262,6 +264,14 @@ function fromError(error: any): JSONValue {
return error;
}

function isRPCRequestParams(data: JSONValue): data is RPCRequestParams {
return typeof data === 'object' && !Array.isArray(data);
}

function isRPCResponseResult(data: JSONValue): data is RPCResponseResult {
return typeof data === 'object' && !Array.isArray(data);
}

/**
* Error constructors for non-Polykey rpcErrors
* Allows these rpcErrors to be reconstructed from RPC metadata
Expand Down Expand Up @@ -545,6 +555,8 @@ export {
parseJSONRPCResponseError,
parseJSONRPCResponse,
parseJSONRPCMessage,
isRPCRequestParams,
isRPCResponseResult,
filterSensitive,
fromError,
toError,
Expand Down
Loading

0 comments on commit 826823b

Please sign in to comment.