Skip to content

Commit

Permalink
Merge pull request #74 from MatrixAI/feature-message-skip
Browse files Browse the repository at this point in the history
Fixing random CI failures due to message skips
  • Loading branch information
aryanjassal authored Dec 6, 2024
2 parents 7a50a39 + 769d66f commit 22c79e1
Show file tree
Hide file tree
Showing 6 changed files with 1,641 additions and 1,910 deletions.
73 changes: 42 additions & 31 deletions src/RPCServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,8 @@ class RPCServer {
yield await handler(inputVal, cancel, meta, ctx);
break;
}
for await (const _ of input) {
// Noop so that stream can close after flushing
}
// Noop so that stream can close after flushing
for await (const _ of input);
};
this.registerDuplexStreamHandler(method, wrapperDuplex, timeout);
}
Expand Down Expand Up @@ -498,51 +497,47 @@ class RPCServer {

const prom = (async () => {
const id = await this.idGen();
const headTransformStream = middleware.binaryToJsonMessageStream(
const transformStream = middleware.binaryToJsonHeaderMessageStream(
utils.parseJSONRPCRequest,
);
// Transparent transform used as a point to cancel the input stream from
const passthroughTransform = new TransformStream<
Uint8Array,
Uint8Array
>();
const inputStream = passthroughTransform.readable;
const inputStreamEndProm = rpcStream.readable
.pipeTo(passthroughTransform.writable)
// Ignore any errors here, we only care that it ended
.catch(() => {});
void inputStream
// Allow us to re-use the readable after reading the first message
.pipeTo(headTransformStream.writable, {
preventClose: true,
preventCancel: true,
})
// Ignore any errors here, we only care that it ended
.catch(() => {});
const cleanUp = async (reason: any) => {
await inputStream.cancel(reason);
// Release resources
await transformStream.readable.cancel(reason);
await transformStream.writable.abort(reason);
await passthroughTransform.readable.cancel(reason);
await rpcStream.writable.abort(reason);
await inputStreamEndProm;
// Stop the timer
timer.cancel(cleanupReason);
await timer.catch(() => {});
};
// Read a single empty value to consume the first message
const reader = headTransformStream.readable.getReader();
passthroughTransform.readable
.pipeTo(transformStream.writable)
.catch(() => {});
const reader = transformStream.readable.getReader();
// Allows timing out when waiting for the first message
let headerMessage:
| ReadableStreamDefaultReadResult<JSONRPCRequest>
| undefined
| void;
| ReadableStreamDefaultReadResult<JSONRPCRequest | Uint8Array>
| undefined;
try {
headerMessage = await Promise.race([
reader.read(),
timer.then(
() => undefined,
() => {},
() => undefined,
),
]);
} catch (e) {
const newErr = new errors.ErrorRPCHandlerFailed(
const err = new errors.ErrorRPCHandlerFailed(
'Stream failed waiting for header',
{ cause: e },
);
Expand All @@ -553,49 +548,61 @@ class RPCServer {
new events.RPCErrorEvent({
detail: new errors.ErrorRPCOutputStreamError(
'Stream failed waiting for header',
{ cause: newErr },
{ cause: err },
),
}),
);
return;
}
// Downgrade back to the raw stream
await reader.cancel();
reader.releaseLock();
// There are 2 conditions where we just end here
// 1. The timeout timer resolves before the first message
// 2. the stream ends before the first message
if (headerMessage == null) {
const newErr = new errors.ErrorRPCTimedOut(
const err = new errors.ErrorRPCTimedOut(
'Timed out waiting for header',
{ cause: new errors.ErrorRPCStreamEnded() },
);
await cleanUp(newErr);
await cleanUp(err);
this.dispatchEvent(
new events.RPCErrorEvent({
detail: new errors.ErrorRPCTimedOut(
'Timed out waiting for header',
{
cause: newErr,
cause: err,
},
),
}),
);
return;
}
if (headerMessage.done) {
const newErr = new errors.ErrorMissingHeader('Missing header');
await cleanUp(newErr);
const err = new errors.ErrorMissingHeader('Missing header');
await cleanUp(err);
this.dispatchEvent(
new events.RPCErrorEvent({
detail: new errors.ErrorRPCOutputStreamError(),
}),
);
return;
}
if (headerMessage.value instanceof Uint8Array) {
const err = new errors.ErrorRPCParse('Invalid message type');
await cleanUp(err);
this.dispatchEvent(
new events.RPCErrorEvent({
detail: new errors.ErrorRPCParse(),
}),
);
return;
}
const method = headerMessage.value.method;
const handler = this.handlerMap.get(method);
if (handler == null) {
await cleanUp(new errors.ErrorRPCHandlerFailed('Missing handler'));
await cleanUp(
new errors.ErrorRPCHandlerFailed(`Missing handler for ${method}`),
);
return;
}
if (abortController.signal.aborted) {
Expand All @@ -617,13 +624,17 @@ class RPCServer {
timer.refresh();
}
}

this.logger.info(`Handling stream with method (${method})`);
let handlerResult: [JSONObject | undefined, ReadableStream<Uint8Array>];
const headerWriter = rpcStream.writable.getWriter();
try {
// The as keyword is used here as the middleware will only return the
// first message as a JSONMessage, and others as raw Uint8Arrays.
handlerResult = await handler(
[headerMessage.value, inputStream],
[
headerMessage.value,
transformStream.readable as ReadableStream<Uint8Array>,
],
rpcStream.cancel,
rpcStream.meta,
{ signal: abortController.signal, timer },
Expand Down
76 changes: 76 additions & 0 deletions src/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,81 @@ function binaryToJsonMessageStream<T extends JSONRPCMessage>(
});
}

/**
* This function is a factory to create a TransformStream that will
* transform a `Uint8Array` stream to a stream containing the JSON header
* message and the rest of the data in `Uint8Array` format.
* The header message will be validated with the provided messageParser, this
* also infers the type of the stream output.
* @param messageParser - Validates the JSONRPC messages, so you can select for a
* specific type of message
* @param bufferByteLimit - sets the number of bytes buffered before throwing an
* error. This is used to avoid infinitely buffering the input.
*/
function binaryToJsonHeaderMessageStream<T extends JSONRPCMessage>(
messageParser: (message: unknown) => T,
bufferByteLimit: number = 1024 * 1024,
): TransformStream<Uint8Array, T | Uint8Array> {
const parser = new JSONParser({
separator: '',
paths: ['$'],
});
let bytesWritten: number = 0;
let accumulator = Buffer.alloc(0);
let rawStream = false;
let parserEnded = false;

const cleanUp = async () => {
// Avoid potential race conditions by allowing parser to end first
const waitP = utils.promise();
parser.onEnd = () => waitP.resolveP();
parser.end();
await waitP.p;
};

return new TransformStream<Uint8Array, T | Uint8Array>({
flush: async () => {
if (!parserEnded) await cleanUp();
},
start: (controller) => {
parser.onValue = async (value) => {
// Enqueue the regular JSON message
const jsonMessage = messageParser(value.value);
controller.enqueue(jsonMessage);
// Remove the header message from the accumulated data
const headerLength = Buffer.from(
JSON.stringify(jsonMessage),
).byteLength;
accumulator = accumulator.subarray(headerLength);
if (accumulator.length > 0) controller.enqueue(accumulator);
// Set system state
bytesWritten = 0;
rawStream = true;
await cleanUp();
parserEnded = true;
};
},
transform: (chunk, controller) => {
try {
bytesWritten += chunk.byteLength;
if (rawStream) {
// Send raw binary data directly
controller.enqueue(chunk);
} else {
// Prepare the data to be parsed to JSON
accumulator = Buffer.concat([accumulator, chunk]);
parser.write(chunk);
}
} catch (e) {
throw new rpcErrors.ErrorRPCParse(undefined, { cause: e });
}
if (bytesWritten > bufferByteLimit) {
throw new rpcErrors.ErrorRPCMessageLength();
}
},
});
}

/**
* This function is a factory for a TransformStream that will transform
* JsonRPCMessages into the `Uint8Array` form. This is used for the stream
Expand Down Expand Up @@ -270,6 +345,7 @@ const defaultClientMiddlewareWrapper = (

export {
binaryToJsonMessageStream,
binaryToJsonHeaderMessageStream,
jsonMessageToBinaryStream,
timeoutMiddlewareClient,
timeoutMiddlewareServer,
Expand Down
Loading

0 comments on commit 22c79e1

Please sign in to comment.