Skip to content

Commit

Permalink
fix(api-graphql): events url pattern; non-retryable error handling (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
iartemiev authored Oct 29, 2024
1 parent 891dae5 commit e0fdeb7
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 19 deletions.
177 changes: 177 additions & 0 deletions packages/api-graphql/__tests__/AWSAppSyncEventProvider.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import { Observable, Observer } from 'rxjs';
import { Reachability } from '@aws-amplify/core/internals/utils';
import { ConsoleLogger } from '@aws-amplify/core';
import { MESSAGE_TYPES } from '../src/Providers/constants';
import * as constants from '../src/Providers/constants';

import { delay, FakeWebSocketInterface } from './helpers';
import { ConnectionState as CS } from '../src/types/PubSub';

import { AWSAppSyncEventProvider } from '../src/Providers/AWSAppSyncEventsProvider';

describe('AppSyncEventProvider', () => {
describe('subscribe()', () => {
describe('returned observer', () => {
describe('connection logic with mocked websocket', () => {
let fakeWebSocketInterface: FakeWebSocketInterface;
const loggerSpy: jest.SpyInstance = jest.spyOn(
ConsoleLogger.prototype,
'_log',
);

let provider: AWSAppSyncEventProvider;
let reachabilityObserver: Observer<{ online: boolean }>;

beforeEach(async () => {
// Set the network to "online" for these tests
jest
.spyOn(Reachability.prototype, 'networkMonitor')
.mockImplementationOnce(() => {
return new Observable(observer => {
reachabilityObserver = observer;
});
})
// Twice because we subscribe to get the initial state then again to monitor reachability
.mockImplementationOnce(() => {
return new Observable(observer => {
reachabilityObserver = observer;
});
});

fakeWebSocketInterface = new FakeWebSocketInterface();
provider = new AWSAppSyncEventProvider();

// Saving this spy and resetting it by hand causes badness
// Saving it causes new websockets to be reachable across past tests that have not fully closed
// Resetting it proactively causes those same past tests to be dealing with null while they reach a settled state
jest
.spyOn(provider as any, '_getNewWebSocket')
.mockImplementation(() => {
fakeWebSocketInterface.newWebSocket();
return fakeWebSocketInterface.webSocket as WebSocket;
});

// Reduce retry delay for tests to 100ms
Object.defineProperty(constants, 'MAX_DELAY_MS', {
value: 100,
});
// Reduce retry delay for tests to 100ms
Object.defineProperty(constants, 'RECONNECT_DELAY', {
value: 100,
});
});

afterEach(async () => {
provider?.close();
await fakeWebSocketInterface?.closeInterface();
fakeWebSocketInterface?.teardown();
loggerSpy.mockClear();
});

test('subscription observer error is triggered when a connection is formed and a non-retriable connection_error data message is received', async () => {
expect.assertions(3);

const socketCloseSpy = jest.spyOn(
fakeWebSocketInterface.webSocket,
'close',
);
fakeWebSocketInterface.webSocket.readyState = WebSocket.OPEN;

const observer = provider.subscribe({
appSyncGraphqlEndpoint: 'ws://localhost:8080',
});

observer.subscribe({
error: e => {
expect(e.errors[0].message).toEqual(
'Connection failed: UnauthorizedException',
);
},
});

await fakeWebSocketInterface?.readyForUse;
await fakeWebSocketInterface?.triggerOpen();

// Resolve the message delivery actions
await Promise.resolve(
fakeWebSocketInterface?.sendDataMessage({
type: MESSAGE_TYPES.GQL_CONNECTION_ERROR,
errors: [
{
errorType: 'UnauthorizedException', // - non-retriable
errorCode: 401,
},
],
}),
);

// Watching for raised exception to be caught and logged
expect(loggerSpy).toHaveBeenCalledWith(
'DEBUG',
expect.stringContaining('error on bound '),
expect.objectContaining({
message: expect.stringMatching('UnauthorizedException'),
}),
);

await delay(1);

expect(socketCloseSpy).toHaveBeenCalledWith(3001);
});

test('subscription observer error is not triggered when a connection is formed and a retriable connection_error data message is received', async () => {
expect.assertions(2);

const observer = provider.subscribe({
appSyncGraphqlEndpoint: 'ws://localhost:8080',
});

observer.subscribe({
error: x => {},
});

const openSocketAttempt = async () => {
await fakeWebSocketInterface?.readyForUse;
await fakeWebSocketInterface?.triggerOpen();

// Resolve the message delivery actions
await Promise.resolve(
fakeWebSocketInterface?.sendDataMessage({
type: MESSAGE_TYPES.GQL_CONNECTION_ERROR,
errors: [
{
errorType: 'Retriable Test',
errorCode: 408, // Request timed out - retriable
},
],
}),
);
await fakeWebSocketInterface?.resetWebsocket();
};

// Go through two connection attempts to excercise backoff and retriable raise
await openSocketAttempt();
await openSocketAttempt();

// Watching for raised exception to be caught and logged
expect(loggerSpy).toHaveBeenCalledWith(
'DEBUG',
expect.stringContaining('error on bound '),
expect.objectContaining({
message: expect.stringMatching('Retriable Test'),
}),
);

await fakeWebSocketInterface?.waitUntilConnectionStateIn([
CS.ConnectionDisrupted,
]);

expect(loggerSpy).toHaveBeenCalledWith(
'DEBUG',
'Connection failed: Retriable Test',
);
});
});
});
});
});
14 changes: 14 additions & 0 deletions packages/api-graphql/__tests__/appsyncUrl.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { getRealtimeEndpointUrl } from '../src/Providers/AWSWebSocketProvider/appsyncUrl';

describe('getRealtimeEndpointUrl', () => {
test('events', () => {
const httpUrl =
'https://abcdefghijklmnopqrstuvwxyz.appsync-api.us-east-1.amazonaws.com/event';

const res = getRealtimeEndpointUrl(httpUrl).toString();

expect(res).toEqual(
'wss://abcdefghijklmnopqrstuvwxyz.appsync-realtime-api.us-east-1.amazonaws.com/event/realtime',
);
});
});
3 changes: 2 additions & 1 deletion packages/api-graphql/__tests__/events.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { AppSyncEventProvider } from '../src/Providers/AWSAppSyncEventsProvider'

import { events } from '../src/';
import { appsyncRequest } from '../src/internals/events/appsyncRequest';

import { GraphQLAuthMode } from '@aws-amplify/core/internals/utils';

const abortController = new AbortController();
Expand Down Expand Up @@ -38,7 +39,7 @@ jest.mock('../src/internals/events/appsyncRequest', () => {
* so we're just sanity checking that the expected auth mode is passed to the provider in this test file.
*/

describe('Events', () => {
describe('Events client', () => {
afterAll(() => {
jest.resetAllMocks();
jest.clearAllMocks();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
} from '@aws-amplify/core/internals/utils';
import { CustomHeaders } from '@aws-amplify/data-schema/runtime';

import { MESSAGE_TYPES } from '../constants';
import { DEFAULT_KEEP_ALIVE_TIMEOUT, MESSAGE_TYPES } from '../constants';
import { AWSWebSocketProvider } from '../AWSWebSocketProvider';
import { awsRealTimeHeaderBasedAuth } from '../AWSWebSocketProvider/authHeaders';

Expand Down Expand Up @@ -44,7 +44,7 @@ interface DataResponse {
const PROVIDER_NAME = 'AWSAppSyncEventsProvider';
const WS_PROTOCOL_NAME = 'aws-appsync-event-ws';

class AWSAppSyncEventProvider extends AWSWebSocketProvider {
export class AWSAppSyncEventProvider extends AWSWebSocketProvider {
constructor() {
super({ providerName: PROVIDER_NAME, wsProtocolName: WS_PROTOCOL_NAME });
}
Expand Down Expand Up @@ -187,6 +187,21 @@ class AWSAppSyncEventProvider extends AWSWebSocketProvider {
type: MESSAGE_TYPES.EVENT_STOP,
};
}

protected _extractConnectionTimeout(data: Record<string, any>): number {
const { connectionTimeoutMs = DEFAULT_KEEP_ALIVE_TIMEOUT } = data;

return connectionTimeoutMs;
}

protected _extractErrorCodeAndType(data: Record<string, any>): {
errorCode: number;
errorType: string;
} {
const { errors: [{ errorType = '', errorCode = 0 } = {}] = [] } = data;

return { errorCode, errorType };
}
}

export const AppSyncEventProvider = new AWSAppSyncEventProvider();
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
} from '@aws-amplify/core/internals/utils';
import { CustomHeaders } from '@aws-amplify/data-schema/runtime';

import { MESSAGE_TYPES } from '../constants';
import { DEFAULT_KEEP_ALIVE_TIMEOUT, MESSAGE_TYPES } from '../constants';
import { AWSWebSocketProvider } from '../AWSWebSocketProvider';
import { awsRealTimeHeaderBasedAuth } from '../AWSWebSocketProvider/authHeaders';

Expand Down Expand Up @@ -158,4 +158,23 @@ export class AWSAppSyncRealTimeProvider extends AWSWebSocketProvider {
type: MESSAGE_TYPES.GQL_STOP,
};
}

protected _extractConnectionTimeout(data: Record<string, any>): number {
const {
payload: { connectionTimeoutMs = DEFAULT_KEEP_ALIVE_TIMEOUT } = {},
} = data;

return connectionTimeoutMs;
}

protected _extractErrorCodeAndType(data: any): {
errorCode: number;
errorType: string;
} {
const {
payload: { errors: [{ errorType = '', errorCode = 0 } = {}] = [] } = {},
} = data;

return { errorCode, errorType };
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const protocol = 'wss://';
const standardDomainPattern =
/^https:\/\/\w{26}\.appsync-api\.\w{2}(?:(?:-\w{2,})+)-\d\.amazonaws.com(?:\.cn)?\/graphql$/i;
const eventDomainPattern =
/^https:\/\/\w{26}\.ddpg-api\.\w{2}(?:(?:-\w{2,})+)-\d\.amazonaws.com(?:\.cn)?\/event$/i;
/^https:\/\/\w{26}\.\w+-api\.\w{2}(?:(?:-\w{2,})+)-\d\.amazonaws.com(?:\.cn)?\/event$/i;
const customDomainPath = '/realtime';

export const isCustomDomain = (url: string): boolean => {
Expand All @@ -31,7 +31,8 @@ export const getRealtimeEndpointUrl = (
if (isEventDomain(realtimeEndpoint)) {
realtimeEndpoint = realtimeEndpoint
.concat(customDomainPath)
.replace('ddpg-api', 'grt-gamma');
.replace('ddpg-api', 'grt-gamma')
.replace('appsync-api', 'appsync-realtime-api');
} else if (isCustomDomain(realtimeEndpoint)) {
realtimeEndpoint = realtimeEndpoint.concat(customDomainPath);
} else {
Expand Down
35 changes: 23 additions & 12 deletions packages/api-graphql/src/Providers/AWSWebSocketProvider/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
MAX_DELAY_MS,
MESSAGE_TYPES,
NON_RETRYABLE_CODES,
NON_RETRYABLE_ERROR_TYPES,
SOCKET_STATUS,
START_ACK_TIMEOUT,
SUBSCRIPTION_STATUS,
Expand Down Expand Up @@ -546,6 +547,15 @@ export abstract class AWSWebSocketProvider {
{ id: string; payload: string | Record<string, unknown>; type: string },
];

protected abstract _extractConnectionTimeout(
data: Record<string, any>,
): number;

protected abstract _extractErrorCodeAndType(data: Record<string, any>): {
errorCode: number;
errorType: string;
};

private _handleIncomingSubscriptionMessage(message: MessageEvent) {
if (typeof message.data !== 'string') {
return;
Expand Down Expand Up @@ -629,14 +639,14 @@ export abstract class AWSWebSocketProvider {
});

this.logger.debug(
`${CONTROL_MSG.CONNECTION_FAILED}: ${JSON.stringify(payload)}`,
`${CONTROL_MSG.CONNECTION_FAILED}: ${JSON.stringify(payload ?? data)}`,
);

observer.error({
errors: [
{
...new GraphQLError(
`${CONTROL_MSG.CONNECTION_FAILED}: ${JSON.stringify(payload)}`,
`${CONTROL_MSG.CONNECTION_FAILED}: ${JSON.stringify(payload ?? data)}`,
),
},
],
Expand Down Expand Up @@ -830,10 +840,10 @@ export abstract class AWSWebSocketProvider {
);

const data = JSON.parse(message.data) as ParsedMessagePayload;
const {
type,
payload: { connectionTimeoutMs = DEFAULT_KEEP_ALIVE_TIMEOUT } = {},
} = data;

const { type } = data;

const connectionTimeoutMs = this._extractConnectionTimeout(data);

if (type === MESSAGE_TYPES.GQL_CONNECTION_ACK) {
ackOk = true;
Expand All @@ -844,11 +854,7 @@ export abstract class AWSWebSocketProvider {
}

if (type === MESSAGE_TYPES.GQL_CONNECTION_ERROR) {
const {
payload: {
errors: [{ errorType = '', errorCode = 0 } = {}] = [],
} = {},
} = data;
const { errorType, errorCode } = this._extractErrorCodeAndType(data);

// TODO(Eslint): refactor to reject an Error object instead of a plain object
// eslint-disable-next-line prefer-promise-reject-errors
Expand Down Expand Up @@ -920,7 +926,12 @@ export abstract class AWSWebSocketProvider {
errorCode: number;
};

if (NON_RETRYABLE_CODES.includes(errorCode)) {
if (
NON_RETRYABLE_CODES.includes(errorCode) ||
// Event API does not currently return `errorCode`. This may change in the future.
// For now fall back to also checking known non-retryable error types
NON_RETRYABLE_ERROR_TYPES.includes(errorType)
) {
throw new NonRetryableError(errorType);
} else if (errorType) {
throw new Error(errorType);
Expand Down
4 changes: 4 additions & 0 deletions packages/api-graphql/src/Providers/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ export { AMPLIFY_SYMBOL } from '@aws-amplify/core/internals/utils';
export const MAX_DELAY_MS = 5000;

export const NON_RETRYABLE_CODES = [400, 401, 403];
export const NON_RETRYABLE_ERROR_TYPES = [
'BadRequestException',
'UnauthorizedException',
];

export const CONNECTION_STATE_CHANGE = 'ConnectionStateChange';

Expand Down
Loading

0 comments on commit e0fdeb7

Please sign in to comment.