From b0876c8c21b115ca06be9c8518a6ad07fe72d750 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 24 Mar 2023 12:09:39 +1100 Subject: [PATCH] fix: `RPCServer` handling output stream cancellation [ci skip] --- src/rpc/RPCServer.ts | 49 +++++++++--------------- src/rpc/errors.ts | 6 +-- tests/rpc/RPCServer.test.ts | 74 ++++++++++++++++++++++++++++++++++--- 3 files changed, 89 insertions(+), 40 deletions(-) diff --git a/src/rpc/RPCServer.ts b/src/rpc/RPCServer.ts index 390609145..fe357dec1 100644 --- a/src/rpc/RPCServer.ts +++ b/src/rpc/RPCServer.ts @@ -228,12 +228,7 @@ class RPCServer extends EventTarget { result: response, id: null, }; - try { - yield responseMessage; - } catch(e) { - // This catches any exceptions thrown into the reverse stream - await handlerG.throw(e); - } + yield responseMessage; } }; const outputGenerator = outputGen(); @@ -258,35 +253,27 @@ class RPCServer extends EventTarget { id: null, }; controller.enqueue(rpcErrorMessage); - await forwardStream.cancel( - new rpcErrors.ErrorRPCHandlerFailed('Error clean up'), - ); + // Clean up the input stream here, ignore error if already ended + await forwardStream + .cancel(new rpcErrors.ErrorRPCHandlerFailed('Error clean up')) + .catch(() => {}); controller.close(); } }, cancel: async (reason) => { - try { - // Throw the reason into the reverse stream - await outputGenerator.throw(reason); - } catch (e) { - // If the e is the same as the reason - // then the handler did not care about the reason - // and we just discard it - if (e !== reason) { - this.dispatchEvent( - new rpcEvents.RPCErrorEvent({ - detail: new rpcErrors.ErrorRPCSendErrorFailed( - 'Stream has been cancelled', - { - cause: e, - } - ), - }), - ); - } - } - // await outputGenerator.nexj - // handlerAbortController.abort(reason); + this.dispatchEvent( + new rpcEvents.RPCErrorEvent({ + detail: new rpcErrors.ErrorRPCOutputStreamError( + 'Stream has been cancelled', + { + cause: reason, + }, + ), + }), + ); + // If the output stream path fails then we need to end the generator + // early. + await outputGenerator.return(undefined); }, }); void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); diff --git a/src/rpc/errors.ts b/src/rpc/errors.ts index ffb1b47e1..f62b03f3f 100644 --- a/src/rpc/errors.ts +++ b/src/rpc/errors.ts @@ -38,8 +38,8 @@ class ErrorRPCMissingResponse extends ErrorRPC { exitCode = sysexits.UNAVAILABLE; } -class ErrorRPCSendErrorFailed extends ErrorRPC { - static description = 'Failed to send error message'; +class ErrorRPCOutputStreamError extends ErrorRPC { + static description = 'Output stream failed, unable to send data'; exitCode = sysexits.UNAVAILABLE; } @@ -102,6 +102,6 @@ export { ErrorRPCHandlerFailed, ErrorRPCMessageLength, ErrorRPCMissingResponse, - ErrorRPCSendErrorFailed, + ErrorRPCOutputStreamError, ErrorPolykeyRemote, }; diff --git a/tests/rpc/RPCServer.test.ts b/tests/rpc/RPCServer.test.ts index cec0b5bd7..d1560c62c 100644 --- a/tests/rpc/RPCServer.test.ts +++ b/tests/rpc/RPCServer.test.ts @@ -24,6 +24,7 @@ import { UnaryHandler, } from '@/rpc/handlers'; import * as middlewareUtils from '@/rpc/utils/middleware'; +import { promise } from '@/utils'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -454,13 +455,70 @@ describe(`${RPCServer.name}`, () => { }, ); testProp( - 'should emit stream error', + 'should emit stream error if input stream fails', [specificMessageArb], async (messages) => { + const handlerEndedProm = promise(); + class TestMethod extends DuplexHandler { + public async *handle(input): AsyncIterable { + try { + for await (const _ of input) { + // Consume but don't yield anything + } + } finally { + handlerEndedProm.resolveP(); + } + } + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + }); + let resolve; + rpcServer.addEventListener('error', (thing: RPCErrorEvent) => { + resolve(thing); + }); + const passThroughStreamIn = new TransformStream(); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: passThroughStreamIn.readable, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + const writer = passThroughStreamIn.writable.getWriter(); + // Write messages + for (const message of messages) { + await writer.write(Buffer.from(JSON.stringify(message))); + } + // Abort stream + const writerReason = Symbol('writerAbort'); + await writer.abort(writerReason); + // We should get an error RPC message + await expect(outputResult).toResolve(); + const errorMessage = JSON.parse((await outputResult)[0].toString()); + // Parse without error + rpcUtils.parseJSONRPCResponseError(errorMessage); + // Check that the handler was cleaned up. + await expect(handlerEndedProm.p).toResolve(); + await rpcServer.destroy(); + }, + { numRuns: 1 }, + ); + testProp.only( + 'should emit stream error if output stream fails', + [specificMessageArb], + async (messages) => { + const handlerEndedProm = promise(); class TestMethod extends DuplexHandler { public async *handle(input): AsyncIterable { // Echo input - yield* input; + try { + yield* input; + } finally { + handlerEndedProm.resolveP(); + } } } const rpcServer = await RPCServer.createRPCServer({ @@ -494,14 +552,18 @@ describe(`${RPCServer.name}`, () => { await reader.read(); } // Abort stream - const writerReason = Symbol('writerAbort'); + // const writerReason = Symbol('writerAbort'); const readerReason = Symbol('readerAbort'); - await writer.abort(writerReason); + // Await writer.abort(writerReason); await reader.cancel(readerReason); // We should get an error event const event = await errorProm; - expect(event.detail.cause).toContain(writerReason); - expect(event.detail.cause).toContain(readerReason); + await writer.close(); + // Expect(event.detail.cause).toContain(writerReason); + expect(event.detail).toBeInstanceOf(rpcErrors.ErrorRPCOutputStreamError); + expect(event.detail.cause).toBe(readerReason); + // Check that the handler was cleaned up. + await expect(handlerEndedProm.p).toResolve(); await rpcServer.destroy(); }, { numRuns: 1 },