diff --git a/.changeset/fresh-candles-leave.md b/.changeset/fresh-candles-leave.md new file mode 100644 index 0000000000..67c5f741d6 --- /dev/null +++ b/.changeset/fresh-candles-leave.md @@ -0,0 +1,5 @@ +--- +'@aws-amplify/ai-constructs': minor +--- + +Propagate errors to AppSync diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts index 510925492a..c215efc0b4 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts @@ -21,6 +21,7 @@ import { import { ConversationTurnEventToolsProvider } from './event-tools-provider'; import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever'; import * as bedrock from '@aws-sdk/client-bedrock-runtime'; +import { ValidationError } from './errors'; /** * This class is responsible for interacting with Bedrock Converse API @@ -87,7 +88,7 @@ export class BedrockConverseAdapter { this.clientToolByName.set(t.name, t); }); if (duplicateTools.size > 0) { - throw new Error( + throw new ValidationError( `Tools must have unique names. Duplicate tools: ${[ ...duplicateTools, ].join(', ')}.` diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts index be96f22fa3..8c42431b63 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts @@ -5,6 +5,7 @@ import { ConversationTurnEvent, StreamingResponseChunk } from './types'; import { BedrockConverseAdapter } from './bedrock_converse_adapter'; import { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { ConversationTurnResponseSender } from './conversation_turn_response_sender'; +import { Lazy } from './lazy'; void describe('Conversation turn executor', () => { const event: ConversationTurnEvent = { @@ -62,8 +63,8 @@ void describe('Conversation turn executor', () => { await new ConversationTurnExecutor( event, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(); @@ -156,8 +157,8 @@ void describe('Conversation turn executor', () => { await new ConversationTurnExecutor( streamingEvent, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(); @@ -214,13 +215,21 @@ void describe('Conversation turn executor', () => { () => Promise.resolve() ); + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + const consoleErrorMock = mock.fn(); const consoleLogMock = mock.fn(); const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); const consoleMock = { error: consoleErrorMock, log: consoleLogMock, debug: consoleDebugMock, + warn: consoleWarnMock, } as unknown as Console; await assert.rejects( @@ -228,8 +237,8 @@ void describe('Conversation turn executor', () => { new ConversationTurnExecutor( event, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(), (error: Error) => { @@ -263,6 +272,16 @@ void describe('Conversation turn executor', () => { consoleErrorMock.mock.calls[0].arguments[1], bedrockError ); + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'Bedrock failed', + }, + ] + ); }); void it('logs and propagates error if response sender throws', async () => { @@ -290,13 +309,21 @@ void describe('Conversation turn executor', () => { () => Promise.resolve() ); + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + const consoleErrorMock = mock.fn(); const consoleLogMock = mock.fn(); const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); const consoleMock = { error: consoleErrorMock, log: consoleLogMock, debug: consoleDebugMock, + warn: consoleWarnMock, } as unknown as Console; await assert.rejects( @@ -304,8 +331,8 @@ void describe('Conversation turn executor', () => { new ConversationTurnExecutor( event, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(), (error: Error) => { @@ -339,5 +366,180 @@ void describe('Conversation turn executor', () => { consoleErrorMock.mock.calls[0].arguments[1], responseSenderError ); + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'Failed to send response', + }, + ] + ); + }); + + void it('throws original exception if error sender fails', async () => { + const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); + const originalError = new Error('original error'); + mock.method(bedrockConverseAdapter, 'askBedrock', () => + Promise.reject(originalError) + ); + const responseSender = new ConversationTurnResponseSender(event); + mock.method(responseSender, 'sendResponse', () => Promise.resolve()); + + mock.method(responseSender, 'sendResponseChunk', () => Promise.resolve()); + + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.reject(new Error('sender error')) + ); + + const consoleErrorMock = mock.fn(); + const consoleLogMock = mock.fn(); + const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); + const consoleMock = { + error: consoleErrorMock, + log: consoleLogMock, + debug: consoleDebugMock, + warn: consoleWarnMock, + } as unknown as Console; + + await assert.rejects( + () => + new ConversationTurnExecutor( + event, + [], + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), + consoleMock + ).execute(), + (error: Error) => { + assert.strictEqual(error, originalError); + return true; + } + ); + + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'original error', + }, + ] + ); + }); + + void it('serializes unknown errors', async () => { + const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); + const unknownError = { some: 'shape' }; + mock.method(bedrockConverseAdapter, 'askBedrock', () => + Promise.reject(unknownError) + ); + const responseSender = new ConversationTurnResponseSender(event); + mock.method(responseSender, 'sendResponse', () => Promise.resolve()); + + mock.method(responseSender, 'sendResponseChunk', () => Promise.resolve()); + + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + + const consoleErrorMock = mock.fn(); + const consoleLogMock = mock.fn(); + const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); + const consoleMock = { + error: consoleErrorMock, + log: consoleLogMock, + debug: consoleDebugMock, + warn: consoleWarnMock, + } as unknown as Console; + + await assert.rejects( + () => + new ConversationTurnExecutor( + event, + [], + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), + consoleMock + ).execute(), + (error: Error) => { + assert.strictEqual(error, unknownError); + return true; + } + ); + + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'UnknownError', + message: '{"some":"shape"}', + }, + ] + ); + }); + + void it('reports initialization errors', async () => { + const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); + mock.method(bedrockConverseAdapter, 'askBedrock', () => Promise.resolve()); + const responseSender = new ConversationTurnResponseSender(event); + mock.method(responseSender, 'sendResponse', () => Promise.resolve()); + + mock.method(responseSender, 'sendResponseChunk', () => Promise.resolve()); + + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + + const consoleErrorMock = mock.fn(); + const consoleLogMock = mock.fn(); + const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); + const consoleMock = { + error: consoleErrorMock, + log: consoleLogMock, + debug: consoleDebugMock, + warn: consoleWarnMock, + } as unknown as Console; + + const initializationError = new Error('initialization error'); + await assert.rejects( + () => + new ConversationTurnExecutor( + event, + [], + new Lazy(() => responseSender), + new Lazy(() => { + throw initializationError; + }), + consoleMock + ).execute(), + (error: Error) => { + assert.strictEqual(error, initializationError); + return true; + } + ); + + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'initialization error', + }, + ] + ); }); }); diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts index 04a7384207..9c5389f610 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts @@ -1,6 +1,7 @@ import { ConversationTurnResponseSender } from './conversation_turn_response_sender.js'; import { ConversationTurnEvent, ExecutableTool, JSONSchema } from './types.js'; import { BedrockConverseAdapter } from './bedrock_converse_adapter.js'; +import { Lazy } from './lazy'; /** * This class is responsible for orchestrating conversation turn execution. @@ -16,11 +17,13 @@ export class ConversationTurnExecutor { constructor( private readonly event: ConversationTurnEvent, additionalTools: Array, - private readonly bedrockConverseAdapter = new BedrockConverseAdapter( - event, - additionalTools + // We're deferring dependency initialization here so that we can capture all validation errors. + private readonly responseSender = new Lazy( + () => new ConversationTurnResponseSender(event) + ), + private readonly bedrockConverseAdapter = new Lazy( + () => new BedrockConverseAdapter(event, additionalTools) ), - private readonly responseSender = new ConversationTurnResponseSender(event), private readonly logger = console ) {} @@ -32,14 +35,14 @@ export class ConversationTurnExecutor { this.logger.debug('Event received:', this.event); if (this.event.streamResponse) { - const chunks = this.bedrockConverseAdapter.askBedrockStreaming(); + const chunks = this.bedrockConverseAdapter.value.askBedrockStreaming(); for await (const chunk of chunks) { - await this.responseSender.sendResponseChunk(chunk); + await this.responseSender.value.sendResponseChunk(chunk); } } else { const assistantResponse = - await this.bedrockConverseAdapter.askBedrock(); - await this.responseSender.sendResponse(assistantResponse); + await this.bedrockConverseAdapter.value.askBedrock(); + await this.responseSender.value.sendResponse(assistantResponse); } this.logger.log( @@ -50,10 +53,28 @@ export class ConversationTurnExecutor { `Failed to handle conversation turn event, currentMessageId=${this.event.currentMessageId}, conversationId=${this.event.conversationId}`, e ); + await this.tryForwardError(e); // Propagate error to mark lambda execution as failed in metrics. throw e; } }; + + private tryForwardError = async (e: unknown) => { + try { + let errorType = 'UnknownError'; + let message: string; + if (e instanceof Error) { + errorType = e.name; + message = e.message; + } else { + message = JSON.stringify(e); + } + await this.responseSender.value.sendErrors([{ errorType, message }]); + } catch (e) { + // Best effort, only log the fact that we tried to send error back to AppSync. + this.logger.warn('Failed to send error mutation', e); + } + }; } /** diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts index b38ccbdff9..0e7e1fe71a 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts @@ -5,7 +5,11 @@ import { MutationResponseInput, MutationStreamingResponseInput, } from './conversation_turn_response_sender'; -import { ConversationTurnEvent, StreamingResponseChunk } from './types'; +import { + ConversationTurnError, + ConversationTurnEvent, + StreamingResponseChunk, +} from './types'; import { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { GraphqlRequest, @@ -236,4 +240,60 @@ void describe('Conversation turn response sender', () => { }, }); }); + + void it('sends errors response back to appsync', async () => { + const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', ''); + const executeGraphqlMock = mock.method( + graphqlRequestExecutor, + 'executeGraphql', + () => + // Mock successful Appsync response + Promise.resolve() + ); + const sender = new ConversationTurnResponseSender( + event, + graphqlRequestExecutor + ); + const errors: Array = [ + { + errorType: 'errorType1', + message: 'errorMessage1', + }, + { + errorType: 'errorType2', + message: 'errorMessage2', + }, + ]; + await sender.sendErrors(errors); + + assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); + const request = executeGraphqlMock.mock.calls[0] + .arguments[0] as GraphqlRequest; + assert.deepStrictEqual(request, { + query: + '\n' + + ' mutation PublishModelResponse($input: testResponseMutationInputTypeName!) {\n' + + ' testResponseMutationName(input: $input) {\n' + + ' testSelectionSet\n' + + ' }\n' + + ' }\n' + + ' ', + variables: { + input: { + conversationId: event.conversationId, + errors: [ + { + errorType: 'errorType1', + message: 'errorMessage1', + }, + { + errorType: 'errorType2', + message: 'errorMessage2', + }, + ], + associatedUserMessageId: event.currentMessageId, + }, + }, + }); + }); }); diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts index 7590f5edcf..723b787959 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts @@ -1,4 +1,8 @@ -import { ConversationTurnEvent, StreamingResponseChunk } from './types.js'; +import { + ConversationTurnError, + ConversationTurnEvent, + StreamingResponseChunk, +} from './types.js'; import type { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { GraphqlRequestExecutor } from './graphql_request_executor'; @@ -14,6 +18,14 @@ export type MutationStreamingResponseInput = { input: StreamingResponseChunk; }; +export type MutationErrorsResponseInput = { + input: { + conversationId: string; + errors: ConversationTurnError[]; + associatedUserMessageId: string; + }; +}; + /** * This class is responsible for sending a response produced by Bedrock back to AppSync * in a form of mutation. @@ -50,6 +62,36 @@ export class ConversationTurnResponseSender { >(responseMutationRequest); }; + sendErrors = async (errors: ConversationTurnError[]) => { + const responseMutationRequest = this.createMutationErrorsRequest(errors); + this.logger.debug( + 'Sending errors response mutation:', + responseMutationRequest + ); + await this.graphqlRequestExecutor.executeGraphql< + MutationErrorsResponseInput, + void + >(responseMutationRequest); + }; + + private createMutationErrorsRequest = (errors: ConversationTurnError[]) => { + const query = ` + mutation PublishModelResponse($input: ${this.event.responseMutation.inputTypeName}!) { + ${this.event.responseMutation.name}(input: $input) { + ${this.event.responseMutation.selectionSet} + } + } + `; + const variables: MutationErrorsResponseInput = { + input: { + conversationId: this.event.conversationId, + errors, + associatedUserMessageId: this.event.currentMessageId, + }, + }; + return { query, variables }; + }; + private createMutationRequest = (content: ContentBlock[]) => { const query = ` mutation PublishModelResponse($input: ${this.event.responseMutation.inputTypeName}!) { diff --git a/packages/ai-constructs/src/conversation/runtime/errors.ts b/packages/ai-constructs/src/conversation/runtime/errors.ts new file mode 100644 index 0000000000..1d3063dd49 --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/errors.ts @@ -0,0 +1,12 @@ +/** + * Represents validation errors. + */ +export class ValidationError extends Error { + /** + * Creates validation error instance. + */ + constructor(message: string) { + super(message); + this.name = 'ValidationError'; + } +} diff --git a/packages/ai-constructs/src/conversation/runtime/lazy.ts b/packages/ai-constructs/src/conversation/runtime/lazy.ts new file mode 100644 index 0000000000..7f5b2032ca --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/lazy.ts @@ -0,0 +1,17 @@ +/** + * A class that initializes lazily upon usage. + */ +export class Lazy { + #value?: T; + + /** + * Creates lazy instance. + */ + constructor(private readonly valueFactory: () => T) {} + /** + * Gets a value. Value is create at first access. + */ + public get value(): T { + return (this.#value ??= this.valueFactory()); + } +} diff --git a/packages/ai-constructs/src/conversation/runtime/types.ts b/packages/ai-constructs/src/conversation/runtime/types.ts index 330a4ae4c1..3d95030cab 100644 --- a/packages/ai-constructs/src/conversation/runtime/types.ts +++ b/packages/ai-constructs/src/conversation/runtime/types.ts @@ -95,6 +95,11 @@ export type ExecutableTool< execute: (input: TToolInput) => Promise; }; +export type ConversationTurnError = { + errorType: string; + message: string; +}; + export type StreamingResponseChunk = { // always required conversationId: string; diff --git a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts index e309ffc71f..53a9ea44a9 100644 --- a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts +++ b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts @@ -48,6 +48,7 @@ if (process.versions.node) { type ConversationTurnAppSyncResponse = { associatedUserMessageId: string; content: string; + errors?: Array; }; type ConversationMessage = { @@ -78,6 +79,11 @@ type CreateConversationMessageChatInput = ConversationMessage & { associatedUserMessageId?: string; }; +type ConversationTurnError = { + errorType: string; + message: string; +}; + type ConversationTurnAppSyncResponseChunk = { conversationId: string; associatedUserMessageId: string; @@ -87,6 +93,7 @@ type ConversationTurnAppSyncResponseChunk = { contentBlockDoneAtIndex?: number; contentBlockToolUse?: string; stopReason?: string; + errors?: Array; }; /** @@ -330,6 +337,26 @@ class ConversationHandlerTestProject extends TestProjectBase { true ) ); + + await this.executeWithRetry(() => + this.assertDefaultConversationHandlerCanPropagateError( + backendId, + authenticatedUserCredentials.accessToken, + dataUrl, + apolloClient, + true + ) + ); + + await this.executeWithRetry(() => + this.assertDefaultConversationHandlerCanPropagateError( + backendId, + authenticatedUserCredentials.accessToken, + dataUrl, + apolloClient, + false + ) + ); } private assertDefaultConversationHandlerCanExecuteTurn = async ( @@ -378,12 +405,12 @@ class ConversationHandlerTestProject extends TestProjectBase { } await this.insertMessage(apolloClient, message); - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient ); - assert.match(responseContent, /3\.14/); + assert.match(response.content, /3\.14/); }; private assertDefaultConversationHandlerCanExecuteTurnWithImage = async ( @@ -443,14 +470,14 @@ class ConversationHandlerTestProject extends TestProjectBase { ...this.getCommonEventProperties(streamResponse), }; await this.insertMessage(apolloClient, message); - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient ); // The image contains a logo of AWS. Responses may vary, but they should always contain statements below. - assert.match(responseContent, /logo/); - assert.match(responseContent, /(aws)|(AWS)|(Amazon Web Services)/); + assert.match(response.content, /logo/); + assert.match(response.content, /(aws)|(AWS)|(Amazon Web Services)/); }; private assertDefaultConversationHandlerCanExecuteTurnWithDataTool = async ( @@ -517,14 +544,14 @@ class ConversationHandlerTestProject extends TestProjectBase { }, ...this.getCommonEventProperties(streamResponse), }; - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient ); // Assert that tool was used. I.e. that LLM used value returned by the tool. assert.match( - responseContent, + response.content, new RegExp(expectedTemperatureInDataToolScenario.toString()) ); }; @@ -586,7 +613,7 @@ class ConversationHandlerTestProject extends TestProjectBase { }, ...this.getCommonEventProperties(streamResponse), }; - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient @@ -594,17 +621,20 @@ class ConversationHandlerTestProject extends TestProjectBase { // Assert that tool use content blocks are emitted in case LLM selects client tool. // The content blocks are string serialized, but not as a proper JSON, // hence string matching is employed below to detect some signals that tool use blocks kinds were emitted. - assert.match(responseContent, /toolUse/); - assert.match(responseContent, /toolUseId/); + assert.match(response.content, /toolUse/); + assert.match(response.content, /toolUseId/); // Assert that LLM attempts to pass parameter when asking for tool use. - assert.match(responseContent, /"city":"Seattle"/); + assert.match(response.content, /"city":"Seattle"/); }; private executeConversationTurn = async ( event: ConversationTurnEvent, functionName: string, apolloClient: ApolloClient - ): Promise => { + ): Promise<{ + content: string; + errors?: Array; + }> => { console.log( `Sending event conversationId=${event.conversationId} currentMessageId=${event.currentMessageId}` ); @@ -649,6 +679,10 @@ class ConversationHandlerTestProject extends TestProjectBase { contentBlockToolUse conversationId createdAt + errors { + errorType + message + } id owner stopReason @@ -676,6 +710,13 @@ class ConversationHandlerTestProject extends TestProjectBase { assert.ok(chunks); + if (chunks.length === 1 && chunks[0].errors) { + return { + content: '', + errors: chunks[0].errors, + }; + } + chunks.sort((a, b) => { // This is very simplified sort by message,block and delta indexes; let aValue = 1000 * 1000 * a.contentBlockIndex; @@ -699,7 +740,7 @@ class ConversationHandlerTestProject extends TestProjectBase { return accumulated; }, ''); - return content; + return { content }; } const queryResult = await apolloClient.query<{ listConversationMessageAssistantResponses: { @@ -721,6 +762,10 @@ class ConversationHandlerTestProject extends TestProjectBase { updatedAt createdAt content + errors { + errorType + message + } associatedUserMessageId } nextToken @@ -739,8 +784,16 @@ class ConversationHandlerTestProject extends TestProjectBase { ); const response = queryResult.data.listConversationMessageAssistantResponses.items[0]; + + if (response.errors) { + return { + content: '', + errors: response.errors, + }; + } + assert.ok(response.content); - return response.content; + return { content: response.content }; }; private assertCustomConversationHandlerCanExecuteTurn = async ( @@ -780,26 +833,81 @@ class ConversationHandlerTestProject extends TestProjectBase { }, ...this.getCommonEventProperties(streamResponse), }; - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, customConversationHandlerFunction, apolloClient ); // Assert that tool was used. I.e. LLM used value provided by the tool. assert.match( - responseContent, + response.content, new RegExp( expectedTemperaturesInProgrammaticToolScenario.Seattle.toString() ) ); assert.match( - responseContent, + response.content, new RegExp( expectedTemperaturesInProgrammaticToolScenario.Boston.toString() ) ); }; + private assertDefaultConversationHandlerCanPropagateError = async ( + backendId: BackendIdentifier, + accessToken: string, + graphqlApiEndpoint: string, + apolloClient: ApolloClient, + streamResponse: boolean + ): Promise => { + const defaultConversationHandlerFunction = ( + await this.resourceFinder.findByBackendIdentifier( + backendId, + 'AWS::Lambda::Function', + (name) => name.includes('default') + ) + )[0]; + + const message: CreateConversationMessageChatInput = { + id: randomUUID().toString(), + conversationId: randomUUID().toString(), + role: 'user', + content: [ + { + text: 'What is the value of PI?', + }, + ], + }; + + // send event + const event: ConversationTurnEvent = { + conversationId: message.conversationId, + currentMessageId: message.id, + graphqlApiEndpoint: graphqlApiEndpoint, + request: { + headers: { authorization: accessToken }, + }, + ...this.getCommonEventProperties(streamResponse), + }; + + // Inject failure + event.modelConfiguration.modelId = 'invalidId'; + await this.insertMessage(apolloClient, message); + + const response = await this.executeConversationTurn( + event, + defaultConversationHandlerFunction, + apolloClient + ); + assert.ok(response.errors); + assert.ok(response.errors[0]); + assert.strictEqual(response.errors[0].errorType, 'ValidationException'); + assert.match( + response.errors[0].message, + /provided model identifier is invalid/ + ); + }; + private insertMessage = async ( apolloClient: ApolloClient, message: CreateConversationMessageChatInput diff --git a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts index 0b8f9dc75f..07c19400c2 100644 --- a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts +++ b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts @@ -86,11 +86,17 @@ const schema = a.schema({ tools: a.ref('MockTool').array(), }), + MockConversationTurnError: a.customType({ + errorType: a.string(), + message: a.string(), + }), + ConversationMessageAssistantResponse: a .model({ conversationId: a.id(), associatedUserMessageId: a.id(), content: a.string(), + errors: a.ref('MockConversationTurnError').array(), }) .authorization((allow) => [allow.authenticated(), allow.owner()]), @@ -99,7 +105,7 @@ const schema = a.schema({ // always conversationId: a.id().required(), associatedUserMessageId: a.id().required(), - contentBlockIndex: a.integer().required(), + contentBlockIndex: a.integer(), accumulatedTurnContent: a.ref('MockContentBlock').array(), // these describe chunks or end of block @@ -110,6 +116,9 @@ const schema = a.schema({ // when message is complete stopReason: a.string(), + + // error + errors: a.ref('MockConversationTurnError').array(), }) .secondaryIndexes((index) => [ index('conversationId').sortKeys(['associatedUserMessageId']),