diff --git a/src/WebSocketStream.ts b/src/WebSocketStream.ts index 7a5672c3..8d10c3be 100644 --- a/src/WebSocketStream.ts +++ b/src/WebSocketStream.ts @@ -3,8 +3,8 @@ import type { StreamCodeToReason, StreamReasonToCode, } from './types.js'; -import type WebSocketConnection from './WebSocketConnection.js'; import type { StreamId, StreamMessage, VarInt } from './message/index.js'; +import type { Evented } from '@matrixai/events'; import type { ReadableWritablePair, WritableStreamDefaultController, @@ -102,7 +102,7 @@ class WebSocketStream implements ReadableWritablePair { public readonly writable: WritableStream; protected logger: Logger; - protected connection: WebSocketConnection; + protected connection: Evented & { meta: () => ConnectionMetadata }; protected reasonToCode: StreamReasonToCode; protected codeToReason: StreamCodeToReason; protected readableController: ReadableStreamDefaultController; @@ -199,7 +199,7 @@ class WebSocketStream implements ReadableWritablePair { }: { initiated: 'local' | 'peer'; streamId: StreamId; - connection: WebSocketConnection; + connection: Evented & { meta: () => ConnectionMetadata }; bufferSize: number; reasonToCode?: StreamReasonToCode; codeToReason?: StreamCodeToReason; diff --git a/tests/WebSocketStream.test.ts b/tests/WebSocketStream.test.ts index 120ade7e..0f5bddc6 100644 --- a/tests/WebSocketStream.test.ts +++ b/tests/WebSocketStream.test.ts @@ -1,9 +1,9 @@ import type { StreamId } from '#message/index.js'; +import type WebSocketConnection from '#WebSocketConnection.js'; import Logger, { formatting, LogLevel, StreamHandler } from '@matrixai/logger'; import { fc, test } from '@fast-check/jest'; import * as messageTestUtils from './message/utils.js'; import WebSocketStream from '#WebSocketStream.js'; -import WebSocketConnection from '#WebSocketConnection.js'; import * as events from '#events.js'; import * as utils from '#utils.js'; import * as messageUtils from '#message/utils.js'; @@ -28,121 +28,118 @@ const logger2 = new Logger('stream 2', LogLevel.WARN, [ let streamIdCounter = 0n; -jest.mock('#WebSocketConnection.js', () => { - return jest.fn().mockImplementation((streamOptions: StreamOptions = {}) => { - const instance = new EventTarget() as EventTarget & { - peerConnection: WebSocketConnection | undefined; - connectTo: (connection: WebSocketConnection) => void; - send: (data: Uint8Array) => Promise; - newStream: () => Promise; - streamMap: Map; - }; - instance.peerConnection = undefined; - instance.connectTo = (peerConnection: any) => { - instance.peerConnection = peerConnection; - peerConnection.peerConnection = instance; - }; - instance.streamMap = new Map(); - instance.newStream = async () => { - const stream = new WebSocketStream({ - initiated: 'local', - streamId: streamIdCounter as StreamId, +function createMockedWebSocketConnection(streamOptions: StreamOptions = {}) { + const instance = new EventTarget() as EventTarget & { + peerConnection: WebSocketConnection | undefined; + connectTo: (connection: WebSocketConnection) => void; + send: (data: Uint8Array) => Promise; + newStream: () => Promise; + streamMap: Map; + }; + instance.peerConnection = undefined; + instance.connectTo = (peerConnection: any) => { + instance.peerConnection = peerConnection; + peerConnection.peerConnection = instance; + }; + instance.streamMap = new Map(); + instance.newStream = async () => { + const stream = new WebSocketStream({ + initiated: 'local', + streamId: streamIdCounter as StreamId, + bufferSize: STREAM_BUFFER_SIZE, + connection: instance as any, + logger: logger1, + ...streamOptions, + }); + stream.addEventListener( + events.EventWebSocketStreamSend.name, + async (evt: any) => { + await instance.send(evt.msg); + }, + ); + stream.addEventListener( + events.EventWebSocketStreamStopped.name, + () => { + instance.streamMap.delete(stream.streamId); + }, + { once: true }, + ); + instance.streamMap.set(stream.streamId, stream); + await stream.start(); + streamIdCounter++; + return stream; + }; + instance.send = async (array: Uint8Array | Array) => { + let data: Uint8Array; + if (ArrayBuffer.isView(array)) { + data = array; + } else { + data = messageUtils.concatUInt8Array(...array); + } + const { data: streamId, remainder } = messageUtils.parseStreamId(data); + // @ts-ignore: protected property + let stream = instance.peerConnection!.streamMap.get(streamId); + if (stream == null) { + if ( + !(remainder.at(0) === 0 && remainder.at(1) === StreamMessageType.Ack) + ) { + return; + } + stream = new WebSocketStream({ + initiated: 'peer', + streamId, bufferSize: STREAM_BUFFER_SIZE, - connection: instance as any, - logger: logger1, + connection: instance.peerConnection!, + logger: logger2, ...streamOptions, }); stream.addEventListener( events.EventWebSocketStreamSend.name, async (evt: any) => { - await instance.send(evt.msg); + // @ts-ignore: protected property + await instance.peerConnection!.send(evt.msg); }, ); stream.addEventListener( events.EventWebSocketStreamStopped.name, () => { - instance.streamMap.delete(stream.streamId); + // @ts-ignore: protected property + instance.peerConnection!.streamMap.delete(streamId); }, { once: true }, ); - instance.streamMap.set(stream.streamId, stream); - await stream.start(); - streamIdCounter++; - return stream; - }; - instance.send = async (array: Uint8Array | Array) => { - let data: Uint8Array; - if (ArrayBuffer.isView(array)) { - data = array; - } else { - data = messageUtils.concatUInt8Array(...array); - } - const { data: streamId, remainder } = messageUtils.parseStreamId(data); // @ts-ignore: protected property - let stream = instance.peerConnection!.streamMap.get(streamId); - if (stream == null) { - if ( - !(remainder.at(0) === 0 && remainder.at(1) === StreamMessageType.Ack) - ) { - return; - } - stream = new WebSocketStream({ - initiated: 'peer', - streamId, - bufferSize: STREAM_BUFFER_SIZE, - connection: instance.peerConnection!, - logger: logger2, - ...streamOptions, - }); - stream.addEventListener( - events.EventWebSocketStreamSend.name, - async (evt: any) => { - // @ts-ignore: protected property - await instance.peerConnection!.send(evt.msg); - }, - ); - stream.addEventListener( - events.EventWebSocketStreamStopped.name, - () => { - // @ts-ignore: protected property - instance.peerConnection!.streamMap.delete(streamId); - }, - { once: true }, - ); - // @ts-ignore: protected property - instance.peerConnection!.streamMap.set(stream.streamId, stream); - await stream.start(); - instance.peerConnection!.dispatchEvent( - new events.EventWebSocketConnectionStream({ - detail: stream, - }), - ); - } - await stream.streamRecv(remainder); - }; - return instance; - }); -}); - -const connectionMock = jest.mocked(WebSocketConnection, { shallow: true }); + instance.peerConnection!.streamMap.set(stream.streamId, stream); + await stream.start(); + instance.peerConnection!.dispatchEvent( + new events.EventWebSocketConnectionStream({ + detail: stream, + }), + ); + } + await stream.streamRecv(remainder); + }; + return instance; +} describe(WebSocketStream.name, () => { - beforeEach(async () => { - connectionMock.mockClear(); - }); - async function createConnectionPair( streamOptions: StreamOptions = {}, - ): Promise<[WebSocketConnection, WebSocketConnection]> { - const connection1 = new (WebSocketConnection as any)(streamOptions); - const connection2 = new (WebSocketConnection as any)(streamOptions); + ): Promise< + [ + ReturnType, + ReturnType, + ] + > { + const connection1 = createMockedWebSocketConnection(streamOptions); + const connection2 = createMockedWebSocketConnection(streamOptions); (connection1 as any).connectTo(connection2); return [connection1, connection2]; } async function createStreamPairFrom( - connection1: WebSocketConnection, - connection2: WebSocketConnection, + connection1: ReturnType, + connection2: ReturnType, ): Promise<[WebSocketStream, WebSocketStream]> { const stream1 = await connection1.newStream(); const createStream2Prom = utils.promise();