Skip to content

Commit

Permalink
wip: almost working tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aryanjassal committed Dec 5, 2024
1 parent 4419604 commit c892786
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 42 deletions.
86 changes: 46 additions & 40 deletions src/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -497,31 +497,36 @@ class RPCServer {

const prom = (async () => {
const id = await this.idGen();
const headTransformStream = middleware.binaryToJsonMessageStream(
const transformStream = middleware.binaryToJsonHeaderMessageStream(
utils.parseJSONRPCRequest,
);
// // Transparent transform used as a point to cancel the input stream from
// Transparent transform used as a point to cancel the input stream from
const passthroughTransform = new TransformStream<
Uint8Array,
Uint8Array
>();
const inputStream = passthroughTransform.readable;
const inputStreamEndProm = rpcStream.readable
.pipeTo(passthroughTransform.writable)
// Ignore any errors here, we only care that it ended
.catch(() => {});
const cleanUp = async (reason: any) => {
await inputStream.cancel(reason);
const cleanup = async (reason: any) => {
// Release resources
await transformStream.readable.cancel(reason);
await transformStream.writable.abort(reason);
await passthroughTransform.readable.cancel(reason);
await rpcStream.writable.abort(reason);
await inputStreamEndProm;
// Stop the timer
timer.cancel(cleanupReason);
await timer.catch(() => {});
};
const reader = inputStream.getReader();
console.log('about to read header message');
passthroughTransform.readable
.pipeTo(transformStream.writable)
.catch(() => {});
const reader = transformStream.readable.getReader();
// Allows timing out when waiting for the first message
let headerMessage:
| ReadableStreamDefaultReadResult<Uint8Array>
| ReadableStreamDefaultReadResult<JSONRPCRequest | Uint8Array>
| undefined
| void;
try {
Expand All @@ -533,7 +538,7 @@ class RPCServer {
),
]);
} catch (e) {
const newErr = new errors.ErrorRPCHandlerFailed(
const err = new errors.ErrorRPCHandlerFailed(
'Stream failed waiting for header',
{ cause: e },
);
Expand All @@ -544,70 +549,63 @@ class RPCServer {
new events.RPCErrorEvent({
detail: new errors.ErrorRPCOutputStreamError(
'Stream failed waiting for header',
{ cause: newErr },
{ cause: err },
),
}),
);
return;
}
// Downgrade back to the raw stream
reader.releaseLock();
console.log('read header message');
// There are 2 conditions where we just end here
// 1. The timeout timer resolves before the first message
// 2. the stream ends before the first message
if (headerMessage == null) {
const newErr = new errors.ErrorRPCTimedOut(
const err = new errors.ErrorRPCTimedOut(
'Timed out waiting for header',
{ cause: new errors.ErrorRPCStreamEnded() },
);
await cleanUp(newErr);
await cleanup(err);
this.dispatchEvent(
new events.RPCErrorEvent({
detail: new errors.ErrorRPCTimedOut(
'Timed out waiting for header',
{
cause: newErr,
cause: err,
},
),
}),
);
return;
}
if (headerMessage.done) {
const newErr = new errors.ErrorMissingHeader('Missing header');
await cleanUp(newErr);
const err = new errors.ErrorMissingHeader('Missing header');
await cleanup(err);
this.dispatchEvent(
new events.RPCErrorEvent({
detail: new errors.ErrorRPCOutputStreamError(),
}),
);
return;
}
console.log('resulting header message', headerMessage);
const headerStream = new ReadableStream<Uint8Array>({
start: async (controller) => {
controller.enqueue(headerMessage!.value);
controller.close();
},
});
console.log('piping readable stream to head transform writable');
await headerStream.pipeTo(headTransformStream.writable);
console.log('piping finished');
// Read the transformed header message
const transformedReader = headTransformStream.readable.getReader();
const transformedHeaderMessage = (await transformedReader.read()).value;
if (transformedHeaderMessage == null) utils.never();
console.log('got header properly', transformedHeaderMessage);
// Use the parsed header message
const method = transformedHeaderMessage.method;
if (headerMessage.value instanceof Uint8Array) {
const err = new errors.ErrorRPCParse('Invalid message type');
await cleanup(err);
this.dispatchEvent(
new events.RPCErrorEvent({
detail: new errors.ErrorRPCParse(),
}),
);
return;
}
const method = headerMessage.value.method;
const handler = this.handlerMap.get(method);
if (handler == null) {
await cleanUp(new errors.ErrorRPCHandlerFailed('Missing handler'));
await cleanup(new errors.ErrorRPCHandlerFailed('Missing handler'));
return;
}
if (abortController.signal.aborted) {
await cleanUp(
await cleanup(
new errors.ErrorHandlerAborted('Aborted', {
cause: new errors.ErrorHandlerAborted(),
}),
Expand All @@ -625,12 +623,24 @@ class RPCServer {
timer.refresh();
}
}
// Set up a wrapper ReadableStream of the correct type
const binaryReadableStream = new ReadableStream<Uint8Array>({
pull: async (controller) => {
for await (const chunk of transformStream.readable) {
// The transformStream is guaranteed to return binary data after
// sending the header message;
if (!(chunk instanceof Uint8Array)) utils.never();
controller.enqueue(chunk);
}
controller.close();
}
})
this.logger.info(`Handling stream with method (${method})`);
let handlerResult: [JSONObject | undefined, ReadableStream<Uint8Array>];
const headerWriter = rpcStream.writable.getWriter();
try {
handlerResult = await handler(
[transformedHeaderMessage, inputStream],
[headerMessage.value, binaryReadableStream],
rpcStream.cancel,
rpcStream.meta,
{ signal: abortController.signal, timer },
Expand Down Expand Up @@ -681,10 +691,6 @@ class RPCServer {
const outputStreamEndProm = outputStream
.pipeTo(rpcStream.writable)
.catch(() => {}); // Ignore any errors, we only care that it finished
// let inputStreamEndProm = (async () => {
// await rpcStream.readable.cancel(cleanupReason);
// await rpcStream.writable.abort(cleanupReason);
// })();
await Promise.allSettled([inputStreamEndProm, outputStreamEndProm]);
this.logger.info(`Handled stream with method (${method})`);
// Cleaning up abort and timer
Expand Down
67 changes: 67 additions & 0 deletions src/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,72 @@ function binaryToJsonMessageStream<T extends JSONRPCMessage>(
});
}

/**
* This function is a factory to create a TransformStream that will
* transform a `Uint8Array` stream to a stream containing the JSON header
* message and the rest of the data in `Uint8Array` format.
* The header message will be validated with the provided messageParser, this
* also infers the type of the stream output.
* @param messageParser - Validates the JSONRPC messages, so you can select for a
* specific type of message
* @param bufferByteLimit - sets the number of bytes buffered before throwing an
* error. This is used to avoid infinitely buffering the input.
*/
function binaryToJsonHeaderMessageStream<T extends JSONRPCMessage>(
messageParser: (message: unknown) => T,
bufferByteLimit: number = 1024 * 1024,
): TransformStream<Uint8Array, T | Uint8Array> {
const parser = new JSONParser({
separator: '',
paths: ['$'],
});
let bytesWritten: number = 0;
let accumulator = new Uint8Array([]);
let rawStream = false;

return new TransformStream<Uint8Array, T | Uint8Array>({
flush: async () => {
// Avoid potential race conditions by allowing parser to end first
const waitP = utils.promise();
parser.onEnd = () => waitP.resolveP();
parser.end();
await waitP.p;
},
start: (controller) => {
parser.onValue = (value) => {
// Enqueue the regular JSON message
const jsonMessage = messageParser(value.value);
controller.enqueue(jsonMessage);
// Remove the header message from the accumulated data
const headerLength = Buffer.from(JSON.stringify(jsonMessage)).length;
accumulator = accumulator.slice(headerLength);
if (accumulator.length > 0) controller.enqueue(accumulator);
// Set system state
bytesWritten = 0;
rawStream = true;
};
},
transform: (chunk, controller) => {
try {
bytesWritten += chunk.byteLength;
if (rawStream) {
// Send raw binary data directly
controller.enqueue(chunk);
} else {
// Prepare the data to be parsed to JSON
accumulator = new Uint8Array(Buffer.concat([accumulator, chunk]));
parser.write(chunk);
}
} catch (e) {
throw new rpcErrors.ErrorRPCParse(undefined, { cause: e });
}
if (bytesWritten > bufferByteLimit) {
throw new rpcErrors.ErrorRPCMessageLength();
}
},
});
}

/**
* This function is a factory for a TransformStream that will transform
* JsonRPCMessages into the `Uint8Array` form. This is used for the stream
Expand Down Expand Up @@ -270,6 +336,7 @@ const defaultClientMiddlewareWrapper = (

export {
binaryToJsonMessageStream,
binaryToJsonHeaderMessageStream,
jsonMessageToBinaryStream,
timeoutMiddlewareClient,
timeoutMiddlewareServer,
Expand Down
2 changes: 1 addition & 1 deletion tests/RPC.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ describe('RPC', () => {
return {
forward: new TransformStream({
start: (controller) => {
// Controller.terminate();
// controller.terminate();
controller.error(Error('SOME ERROR'));
},
}),
Expand Down
2 changes: 1 addition & 1 deletion tests/RPCServer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ describe(`${RPCServer.name}`, () => {
messages: specificMessageArb,
},
// { numRuns: 1 },
{ seed: 1292472631, path: "0", endOnFailure: true }
{ seed: 1292472631, path: '0', endOnFailure: true },
)('reverse middlewares', async ({ messages }) => {
const stream = rpcTestUtils.messagesToReadableStream(messages);
class TestMethod extends DuplexHandler {
Expand Down

0 comments on commit c892786

Please sign in to comment.