From 1bbc81a6ce360afd5f7374e9df9bced7aac632dd Mon Sep 17 00:00:00 2001 From: Kamil Sobol Date: Tue, 29 Oct 2024 12:52:28 -0700 Subject: [PATCH 1/5] Propagate errors to AppSync --- .changeset/fresh-candles-leave.md | 5 + .../runtime/bedrock_converse_adapter.test.ts | 20 +- .../runtime/bedrock_converse_adapter.ts | 447 ++++++++++-------- .../conversation_turn_executor.test.ts | 21 +- .../conversation_turn_response_sender.test.ts | 22 +- .../conversation_turn_response_sender.ts | 59 ++- .../src/conversation/runtime/types.ts | 29 ++ .../conversation_handler_project.ts | 142 +++++- .../amplify/data/resource.ts | 9 + 9 files changed, 492 insertions(+), 262 deletions(-) create mode 100644 .changeset/fresh-candles-leave.md diff --git a/.changeset/fresh-candles-leave.md b/.changeset/fresh-candles-leave.md new file mode 100644 index 0000000000..67c5f741d6 --- /dev/null +++ b/.changeset/fresh-candles-leave.md @@ -0,0 +1,5 @@ +--- +'@aws-amplify/ai-constructs': minor +--- + +Propagate errors to AppSync diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts index ae424828c7..f621ebe3b5 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts @@ -200,7 +200,7 @@ void describe('Bedrock converse adapter', () => { ]); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, content); + assert.deepStrictEqual(responseContent, { content }); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1); @@ -339,7 +339,7 @@ void describe('Bedrock converse adapter', () => { assert.strictEqual(responseText, 'finalResponse'); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, content); + assert.deepStrictEqual(responseContent, { content }); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 3); @@ -510,7 +510,7 @@ void describe('Bedrock converse adapter', () => { assert.strictEqual(responseText, 'finalResponse'); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, content); + assert.deepStrictEqual(responseContent, { content }); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 2); @@ -598,7 +598,7 @@ void describe('Bedrock converse adapter', () => { assert.strictEqual(responseText, 'finalResponse'); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, content); + assert.deepStrictEqual(responseContent, { content }); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 2); @@ -711,11 +711,13 @@ void describe('Bedrock converse adapter', () => { ]); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, [ - { - toolUse: clientToolUse, - }, - ]); + assert.deepStrictEqual(responseContent, { + content: [ + { + toolUse: clientToolUse, + }, + ], + }); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1); diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts index 510925492a..f1bd37cead 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts @@ -13,7 +13,9 @@ import { ToolInputSchema, } from '@aws-sdk/client-bedrock-runtime'; import { + ConversationTurnError, ConversationTurnEvent, + ConversationTurnResponse, ExecutableTool, StreamingResponseChunk, ToolDefinition, @@ -95,80 +97,82 @@ export class BedrockConverseAdapter { } } - askBedrock = async (): Promise => { - const { modelId, systemPrompt, inferenceConfiguration } = - this.event.modelConfiguration; + askBedrock = async (): Promise => { + try { + const { modelId, systemPrompt, inferenceConfiguration } = + this.event.modelConfiguration; - const messages: Array = - await this.getEventMessagesAsBedrockMessages(); + const messages: Array = + await this.getEventMessagesAsBedrockMessages(); - let bedrockResponse: ConverseCommandOutput; - do { - const toolConfig = this.createToolConfiguration(); - const converseCommandInput: ConverseCommandInput = { - modelId, - messages: [...messages], - system: [{ text: systemPrompt }], - inferenceConfig: inferenceConfiguration, - toolConfig, - }; - this.logger.info('Sending Bedrock Converse request'); - this.logger.debug('Bedrock Converse request:', converseCommandInput); - bedrockResponse = await this.bedrockClient.send( - new ConverseCommand(converseCommandInput) - ); - this.logger.info( - `Received Bedrock Converse response, requestId=${bedrockResponse.$metadata.requestId}`, - bedrockResponse.usage - ); - this.logger.debug('Bedrock Converse response:', bedrockResponse); - if (bedrockResponse.output?.message) { - messages.push(bedrockResponse.output?.message); - } - if (bedrockResponse.stopReason === 'tool_use') { - const responseContentBlocks = - bedrockResponse.output?.message?.content ?? []; - const toolUseBlocks = responseContentBlocks.filter( - (block) => 'toolUse' in block - ) as Array; - const clientToolUseBlocks = responseContentBlocks.filter( - (block) => - block.toolUse?.name && - this.clientToolByName.has(block.toolUse?.name) + let bedrockResponse: ConverseCommandOutput; + do { + const toolConfig = this.createToolConfiguration(); + const converseCommandInput: ConverseCommandInput = { + modelId, + messages: [...messages], + system: [{ text: systemPrompt }], + inferenceConfig: inferenceConfiguration, + toolConfig, + }; + this.logger.info('Sending Bedrock Converse request'); + this.logger.debug('Bedrock Converse request:', converseCommandInput); + bedrockResponse = await this.bedrockClient.send( + new ConverseCommand(converseCommandInput) ); - if (clientToolUseBlocks.length > 0) { - // For now if any of client tools is used we ignore executable tools - // and propagate result back to client. - return clientToolUseBlocks; + this.logger.info( + `Received Bedrock Converse response, requestId=${bedrockResponse.$metadata.requestId}`, + bedrockResponse.usage + ); + this.logger.debug('Bedrock Converse response:', bedrockResponse); + if (bedrockResponse.output?.message) { + messages.push(bedrockResponse.output?.message); } - const toolResponseContentBlocks: Array = []; - for (const responseContentBlock of toolUseBlocks) { - const toolUseBlock = - responseContentBlock as ContentBlock.ToolUseMember; - const toolResultContentBlock = await this.executeTool(toolUseBlock); - toolResponseContentBlocks.push(toolResultContentBlock); + if (bedrockResponse.stopReason === 'tool_use') { + const responseContentBlocks = + bedrockResponse.output?.message?.content ?? []; + const toolUseBlocks = responseContentBlocks.filter( + (block) => 'toolUse' in block + ) as Array; + const clientToolUseBlocks = responseContentBlocks.filter( + (block) => + block.toolUse?.name && + this.clientToolByName.has(block.toolUse?.name) + ); + if (clientToolUseBlocks.length > 0) { + // For now if any of client tools is used we ignore executable tools + // and propagate result back to client. + return { content: clientToolUseBlocks }; + } + const toolResponseContentBlocks: Array = []; + for (const responseContentBlock of toolUseBlocks) { + const toolUseBlock = + responseContentBlock as ContentBlock.ToolUseMember; + const toolResultContentBlock = await this.executeTool(toolUseBlock); + toolResponseContentBlocks.push(toolResultContentBlock); + } + messages.push({ + role: 'user', + content: toolResponseContentBlocks, + }); } - messages.push({ - role: 'user', - content: toolResponseContentBlocks, - }); - } - } while (bedrockResponse.stopReason === 'tool_use'); + } while (bedrockResponse.stopReason === 'tool_use'); - return bedrockResponse.output?.message?.content ?? []; + return { content: bedrockResponse.output?.message?.content ?? [] }; + } catch (error) { + console.error('Conversation with Bedrock failed', error); + const conversationTurnError = + this.convertErrorToConversationTurnError(error); + return { + errors: [conversationTurnError], + }; + } }; /** * Asks Bedrock for response using streaming version of Converse API. */ async *askBedrockStreaming(): AsyncGenerator { - const { modelId, systemPrompt, inferenceConfiguration } = - this.event.modelConfiguration; - - const messages: Array = - await this.getEventMessagesAsBedrockMessages(); - - let bedrockResponse: ConverseStreamCommandOutput; // keep our own indexing for blocks instead of using Bedrock's indexes // since we stream subset of these upstream. let blockIndex = 0; @@ -177,160 +181,182 @@ export class BedrockConverseAdapter { // Accumulates client facing content per turn. // So that upstream can persist full message at the end of the streaming. const accumulatedTurnContent: Array = []; - do { - const toolConfig = this.createToolConfiguration(); - const converseCommandInput: ConverseStreamCommandInput = { - modelId, - messages: [...messages], - system: [{ text: systemPrompt }], - inferenceConfig: inferenceConfiguration, - toolConfig, - }; - this.logger.info('Sending Bedrock Converse Stream request'); - this.logger.debug( - 'Bedrock Converse Stream request:', - converseCommandInput - ); - bedrockResponse = await this.bedrockClient.send( - new ConverseStreamCommand(converseCommandInput) - ); - this.logger.info( - `Received Bedrock Converse Stream response, requestId=${bedrockResponse.$metadata.requestId}` - ); - if (!bedrockResponse.stream) { - throw new Error('Bedrock response is missing stream'); - } - let toolUseBlock: ContentBlock.ToolUseMember | undefined; - let clientToolsRequested = false; - let text: string = ''; - let toolUseInput: string = ''; - let blockDeltaIndex = 0; - let lastBlockDeltaIndex = 0; - // Accumulate current message for the tool use loop purpose. - const accumulatedAssistantMessage: Message = { - role: undefined, - content: [], - }; - for await (const chunk of bedrockResponse.stream) { - this.logger.debug('Bedrock Converse Stream response chunk:', chunk); - if (chunk.messageStart) { - accumulatedAssistantMessage.role = chunk.messageStart.role; - } else if (chunk.contentBlockStart) { - blockDeltaIndex = 0; - lastBlockDeltaIndex = 0; - if (chunk.contentBlockStart.start?.toolUse) { - toolUseBlock = { - toolUse: { - ...chunk.contentBlockStart.start?.toolUse, - input: undefined, - }, - }; - } - } else if (chunk.contentBlockDelta) { - if (chunk.contentBlockDelta.delta?.toolUse) { - if (!chunk.contentBlockDelta.delta.toolUse.input) { - toolUseInput = ''; + try { + const { modelId, systemPrompt, inferenceConfiguration } = + this.event.modelConfiguration; + + const messages: Array = + await this.getEventMessagesAsBedrockMessages(); + + let bedrockResponse: ConverseStreamCommandOutput; + do { + const toolConfig = this.createToolConfiguration(); + const converseCommandInput: ConverseStreamCommandInput = { + modelId, + messages: [...messages], + system: [{ text: systemPrompt }], + inferenceConfig: inferenceConfiguration, + toolConfig, + }; + this.logger.info('Sending Bedrock Converse Stream request'); + this.logger.debug( + 'Bedrock Converse Stream request:', + converseCommandInput + ); + bedrockResponse = await this.bedrockClient.send( + new ConverseStreamCommand(converseCommandInput) + ); + this.logger.info( + `Received Bedrock Converse Stream response, requestId=${bedrockResponse.$metadata.requestId}` + ); + if (!bedrockResponse.stream) { + throw new Error('Bedrock response is missing stream'); + } + let toolUseBlock: ContentBlock.ToolUseMember | undefined; + let clientToolsRequested = false; + let text: string = ''; + let toolUseInput: string = ''; + let blockDeltaIndex = 0; + let lastBlockDeltaIndex = 0; + // Accumulate current message for the tool use loop purpose. + const accumulatedAssistantMessage: Message = { + role: undefined, + content: [], + }; + + for await (const chunk of bedrockResponse.stream) { + this.logger.debug('Bedrock Converse Stream response chunk:', chunk); + if (chunk.messageStart) { + accumulatedAssistantMessage.role = chunk.messageStart.role; + } else if (chunk.contentBlockStart) { + blockDeltaIndex = 0; + lastBlockDeltaIndex = 0; + if (chunk.contentBlockStart.start?.toolUse) { + toolUseBlock = { + toolUse: { + ...chunk.contentBlockStart.start?.toolUse, + input: undefined, + }, + }; } - toolUseInput += chunk.contentBlockDelta.delta.toolUse.input; - } else if (chunk.contentBlockDelta.delta?.text) { - text += chunk.contentBlockDelta.delta.text; - yield { - accumulatedTurnContent: [...accumulatedTurnContent, { text }], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockText: chunk.contentBlockDelta.delta.text, - contentBlockIndex: blockIndex, - contentBlockDeltaIndex: blockDeltaIndex, - }; - lastBlockDeltaIndex = blockDeltaIndex; - blockDeltaIndex++; - } - } else if (chunk.contentBlockStop) { - if (toolUseBlock) { - toolUseBlock.toolUse.input = JSON.parse(toolUseInput); - accumulatedAssistantMessage.content?.push(toolUseBlock); - if ( - toolUseBlock.toolUse.name && - this.clientToolByName.has(toolUseBlock.toolUse.name) - ) { - clientToolsRequested = true; - accumulatedTurnContent.push(toolUseBlock); + } else if (chunk.contentBlockDelta) { + if (chunk.contentBlockDelta.delta?.toolUse) { + if (!chunk.contentBlockDelta.delta.toolUse.input) { + toolUseInput = ''; + } + toolUseInput += chunk.contentBlockDelta.delta.toolUse.input; + } else if (chunk.contentBlockDelta.delta?.text) { + text += chunk.contentBlockDelta.delta.text; + yield { + accumulatedTurnContent: [...accumulatedTurnContent, { text }], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockText: chunk.contentBlockDelta.delta.text, + contentBlockIndex: blockIndex, + contentBlockDeltaIndex: blockDeltaIndex, + }; + lastBlockDeltaIndex = blockDeltaIndex; + blockDeltaIndex++; + } + } else if (chunk.contentBlockStop) { + if (toolUseBlock) { + toolUseBlock.toolUse.input = JSON.parse(toolUseInput); + accumulatedAssistantMessage.content?.push(toolUseBlock); + if ( + toolUseBlock.toolUse.name && + this.clientToolByName.has(toolUseBlock.toolUse.name) + ) { + clientToolsRequested = true; + accumulatedTurnContent.push(toolUseBlock); + yield { + accumulatedTurnContent: [...accumulatedTurnContent], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockIndex: blockIndex, + contentBlockToolUse: JSON.stringify(toolUseBlock), + }; + lastBlockIndex = blockIndex; + blockIndex++; + } + toolUseBlock = undefined; + toolUseInput = ''; + } else { + accumulatedAssistantMessage.content?.push({ + text, + }); + accumulatedTurnContent.push({ text }); yield { accumulatedTurnContent: [...accumulatedTurnContent], conversationId: this.event.conversationId, associatedUserMessageId: this.event.currentMessageId, contentBlockIndex: blockIndex, - contentBlockToolUse: JSON.stringify(toolUseBlock), + contentBlockDoneAtIndex: lastBlockDeltaIndex, }; + text = ''; lastBlockIndex = blockIndex; blockIndex++; } - toolUseBlock = undefined; - toolUseInput = ''; - } else { - accumulatedAssistantMessage.content?.push({ - text, - }); - accumulatedTurnContent.push({ text }); - yield { - accumulatedTurnContent: [...accumulatedTurnContent], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockIndex: blockIndex, - contentBlockDoneAtIndex: lastBlockDeltaIndex, - }; - text = ''; - lastBlockIndex = blockIndex; - blockIndex++; + } else if (chunk.messageStop) { + stopReason = chunk.messageStop.stopReason ?? ''; } - } else if (chunk.messageStop) { - stopReason = chunk.messageStop.stopReason ?? ''; } - } - this.logger.debug( - 'Accumulated Bedrock Converse Stream response:', - accumulatedAssistantMessage - ); - if (clientToolsRequested) { - // For now if any of client tools is used we ignore executable tools - // and propagate result back to client. - yield { - accumulatedTurnContent: [...accumulatedTurnContent], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockIndex: lastBlockIndex, - stopReason: stopReason, - }; - return; - } - messages.push(accumulatedAssistantMessage); - if (stopReason === 'tool_use') { - const responseContentBlocks = accumulatedAssistantMessage.content ?? []; - const toolUseBlocks = responseContentBlocks.filter( - (block) => 'toolUse' in block - ) as Array; - const toolResponseContentBlocks: Array = []; - for (const responseContentBlock of toolUseBlocks) { - const toolUseBlock = - responseContentBlock as ContentBlock.ToolUseMember; - const toolResultContentBlock = await this.executeTool(toolUseBlock); - toolResponseContentBlocks.push(toolResultContentBlock); + this.logger.debug( + 'Accumulated Bedrock Converse Stream response:', + accumulatedAssistantMessage + ); + if (clientToolsRequested) { + // For now if any of client tools is used we ignore executable tools + // and propagate result back to client. + yield { + accumulatedTurnContent: [...accumulatedTurnContent], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockIndex: lastBlockIndex, + stopReason: stopReason, + }; + return; } - messages.push({ - role: 'user', - content: toolResponseContentBlocks, - }); - } - } while (stopReason === 'tool_use'); + messages.push(accumulatedAssistantMessage); + if (stopReason === 'tool_use') { + const responseContentBlocks = + accumulatedAssistantMessage.content ?? []; + const toolUseBlocks = responseContentBlocks.filter( + (block) => 'toolUse' in block + ) as Array; + const toolResponseContentBlocks: Array = []; + for (const responseContentBlock of toolUseBlocks) { + const toolUseBlock = + responseContentBlock as ContentBlock.ToolUseMember; + const toolResultContentBlock = await this.executeTool(toolUseBlock); + toolResponseContentBlocks.push(toolResultContentBlock); + } + messages.push({ + role: 'user', + content: toolResponseContentBlocks, + }); + } + } while (stopReason === 'tool_use'); - yield { - accumulatedTurnContent: [...accumulatedTurnContent], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockIndex: lastBlockIndex, - stopReason: stopReason, - }; + yield { + accumulatedTurnContent: [...accumulatedTurnContent], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockIndex: lastBlockIndex, + stopReason: stopReason, + }; + } catch (error) { + console.error('Streaming conversation with Bedrock failed', error); + const conversationTurnError = + this.convertErrorToConversationTurnError(error); + yield { + accumulatedTurnContent: [...accumulatedTurnContent], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockIndex: blockIndex, + errors: [conversationTurnError], + }; + } } /** @@ -433,4 +459,23 @@ export class BedrockConverseAdapter { }; } }; + + private convertErrorToConversationTurnError = ( + error: unknown + ): ConversationTurnError => { + let errorType = 'UnknownError'; + let message: string; + if (error instanceof Error) { + message = error.message; + if (error.name) { + errorType = error.name; + } + } else { + message = JSON.stringify(error); + } + return { + errorType, + message, + }; + }; } diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts index be96f22fa3..6d71fbf5ff 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts @@ -1,9 +1,12 @@ import { describe, it, mock } from 'node:test'; import assert from 'node:assert'; import { ConversationTurnExecutor } from './conversation_turn_executor'; -import { ConversationTurnEvent, StreamingResponseChunk } from './types'; +import { + ConversationTurnEvent, + ConversationTurnResponse, + StreamingResponseChunk, +} from './types'; import { BedrockConverseAdapter } from './bedrock_converse_adapter'; -import { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { ConversationTurnResponseSender } from './conversation_turn_response_sender'; void describe('Conversation turn executor', () => { @@ -28,10 +31,9 @@ void describe('Conversation turn executor', () => { void it('executes turn successfully', async () => { const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); - const bedrockResponse: Array = [ - { text: 'block1' }, - { text: 'block2' }, - ]; + const bedrockResponse: ConversationTurnResponse = { + content: [{ text: 'block1' }, { text: 'block2' }], + }; const bedrockConverseAdapterAskBedrockMock = mock.method( bedrockConverseAdapter, 'askBedrock', @@ -267,10 +269,9 @@ void describe('Conversation turn executor', () => { void it('logs and propagates error if response sender throws', async () => { const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); - const bedrockResponse: Array = [ - { text: 'block1' }, - { text: 'block2' }, - ]; + const bedrockResponse: ConversationTurnResponse = { + content: [{ text: 'block1' }, { text: 'block2' }], + }; const bedrockConverseAdapterAskBedrockMock = mock.method( bedrockConverseAdapter, 'askBedrock', diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts index b38ccbdff9..fad4130275 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts @@ -5,7 +5,11 @@ import { MutationResponseInput, MutationStreamingResponseInput, } from './conversation_turn_response_sender'; -import { ConversationTurnEvent, StreamingResponseChunk } from './types'; +import { + ConversationTurnEvent, + ConversationTurnResponse, + StreamingResponseChunk, +} from './types'; import { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { GraphqlRequest, @@ -45,12 +49,14 @@ void describe('Conversation turn response sender', () => { event, graphqlRequestExecutor ); - const response: Array = [ - { - text: 'block1', - }, - { text: 'block2' }, - ]; + const response: ConversationTurnResponse = { + content: [ + { + text: 'block1', + }, + { text: 'block2' }, + ], + }; await sender.sendResponse(response); assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); @@ -102,7 +108,7 @@ void describe('Conversation turn response sender', () => { }, }, }; - const response: Array = [toolUseBlock]; + const response: ConversationTurnResponse = { content: [toolUseBlock] }; await sender.sendResponse(response); assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts index 7590f5edcf..7ac14022c6 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts @@ -1,13 +1,26 @@ -import { ConversationTurnEvent, StreamingResponseChunk } from './types.js'; +import { + ConversationTurnError, + ConversationTurnEvent, + ConversationTurnResponse, + StreamingResponseChunk, +} from './types.js'; import type { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { GraphqlRequestExecutor } from './graphql_request_executor'; export type MutationResponseInput = { - input: { - conversationId: string; - content: ContentBlock[]; - associatedUserMessageId: string; - }; + input: + | { + associatedUserMessageId: string; + conversationId: string; + content: ContentBlock[]; + errors?: never; + } + | { + associatedUserMessageId: string; + conversationId: string; + content?: never; + errors: ConversationTurnError[]; + }; }; export type MutationStreamingResponseInput = { @@ -32,8 +45,8 @@ export class ConversationTurnResponseSender { private readonly logger = console ) {} - sendResponse = async (message: ContentBlock[]) => { - const responseMutationRequest = this.createMutationRequest(message); + sendResponse = async (response: ConversationTurnResponse) => { + const responseMutationRequest = this.createMutationRequest(response); this.logger.debug('Sending response mutation:', responseMutationRequest); await this.graphqlRequestExecutor.executeGraphql< MutationResponseInput, @@ -50,7 +63,7 @@ export class ConversationTurnResponseSender { >(responseMutationRequest); }; - private createMutationRequest = (content: ContentBlock[]) => { + private createMutationRequest = (response: ConversationTurnResponse) => { const query = ` mutation PublishModelResponse($input: ${this.event.responseMutation.inputTypeName}!) { ${this.event.responseMutation.name}(input: $input) { @@ -58,14 +71,26 @@ export class ConversationTurnResponseSender { } } `; - content = this.serializeContent(content); - const variables: MutationResponseInput = { - input: { - conversationId: this.event.conversationId, - content, - associatedUserMessageId: this.event.currentMessageId, - }, - }; + let variables: MutationResponseInput; + if (typeof response.content !== 'undefined') { + variables = { + input: { + conversationId: this.event.conversationId, + content: this.serializeContent(response.content), + associatedUserMessageId: this.event.currentMessageId, + }, + }; + } else if (typeof response.errors !== 'undefined') { + variables = { + input: { + conversationId: this.event.conversationId, + errors: response.errors, + associatedUserMessageId: this.event.currentMessageId, + }, + }; + } else { + throw new Error('Response contains neither content nor error'); + } return { query, variables }; }; diff --git a/packages/ai-constructs/src/conversation/runtime/types.ts b/packages/ai-constructs/src/conversation/runtime/types.ts index 330a4ae4c1..cdf925eb81 100644 --- a/packages/ai-constructs/src/conversation/runtime/types.ts +++ b/packages/ai-constructs/src/conversation/runtime/types.ts @@ -1,5 +1,6 @@ import * as bedrock from '@aws-sdk/client-bedrock-runtime'; import * as jsonSchemaToTypeScript from 'json-schema-to-ts'; +import type { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; /* Notice: This file contains types that are exposed publicly. @@ -95,6 +96,21 @@ export type ExecutableTool< execute: (input: TToolInput) => Promise; }; +export type ConversationTurnError = { + errorType: string; + message: string; +}; + +export type ConversationTurnResponse = + | { + content: ContentBlock[]; + errors?: never; + } + | { + content?: never; + errors: ConversationTurnError[]; + }; + export type StreamingResponseChunk = { // always required conversationId: string; @@ -108,6 +124,7 @@ export type StreamingResponseChunk = { contentBlockDeltaIndex: number; contentBlockDoneAtIndex?: never; contentBlockToolUse?: never; + errors?: never; stopReason?: never; } | { @@ -116,6 +133,7 @@ export type StreamingResponseChunk = { contentBlockText?: never; contentBlockDeltaIndex?: never; contentBlockToolUse?: never; + errors?: never; stopReason?: never; } | { @@ -124,6 +142,7 @@ export type StreamingResponseChunk = { contentBlockDoneAtIndex?: never; contentBlockText?: never; contentBlockDeltaIndex?: never; + errors?: never; stopReason?: never; } | { @@ -133,5 +152,15 @@ export type StreamingResponseChunk = { contentBlockText?: never; contentBlockDeltaIndex?: never; contentBlockToolUse?: never; + errors?: never; + } + | { + // error + errors: ConversationTurnError[]; + stopReason?: never; + contentBlockDoneAtIndex?: never; + contentBlockText?: never; + contentBlockDeltaIndex?: never; + contentBlockToolUse?: never; } ); diff --git a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts index e309ffc71f..53a9ea44a9 100644 --- a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts +++ b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts @@ -48,6 +48,7 @@ if (process.versions.node) { type ConversationTurnAppSyncResponse = { associatedUserMessageId: string; content: string; + errors?: Array; }; type ConversationMessage = { @@ -78,6 +79,11 @@ type CreateConversationMessageChatInput = ConversationMessage & { associatedUserMessageId?: string; }; +type ConversationTurnError = { + errorType: string; + message: string; +}; + type ConversationTurnAppSyncResponseChunk = { conversationId: string; associatedUserMessageId: string; @@ -87,6 +93,7 @@ type ConversationTurnAppSyncResponseChunk = { contentBlockDoneAtIndex?: number; contentBlockToolUse?: string; stopReason?: string; + errors?: Array; }; /** @@ -330,6 +337,26 @@ class ConversationHandlerTestProject extends TestProjectBase { true ) ); + + await this.executeWithRetry(() => + this.assertDefaultConversationHandlerCanPropagateError( + backendId, + authenticatedUserCredentials.accessToken, + dataUrl, + apolloClient, + true + ) + ); + + await this.executeWithRetry(() => + this.assertDefaultConversationHandlerCanPropagateError( + backendId, + authenticatedUserCredentials.accessToken, + dataUrl, + apolloClient, + false + ) + ); } private assertDefaultConversationHandlerCanExecuteTurn = async ( @@ -378,12 +405,12 @@ class ConversationHandlerTestProject extends TestProjectBase { } await this.insertMessage(apolloClient, message); - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient ); - assert.match(responseContent, /3\.14/); + assert.match(response.content, /3\.14/); }; private assertDefaultConversationHandlerCanExecuteTurnWithImage = async ( @@ -443,14 +470,14 @@ class ConversationHandlerTestProject extends TestProjectBase { ...this.getCommonEventProperties(streamResponse), }; await this.insertMessage(apolloClient, message); - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient ); // The image contains a logo of AWS. Responses may vary, but they should always contain statements below. - assert.match(responseContent, /logo/); - assert.match(responseContent, /(aws)|(AWS)|(Amazon Web Services)/); + assert.match(response.content, /logo/); + assert.match(response.content, /(aws)|(AWS)|(Amazon Web Services)/); }; private assertDefaultConversationHandlerCanExecuteTurnWithDataTool = async ( @@ -517,14 +544,14 @@ class ConversationHandlerTestProject extends TestProjectBase { }, ...this.getCommonEventProperties(streamResponse), }; - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient ); // Assert that tool was used. I.e. that LLM used value returned by the tool. assert.match( - responseContent, + response.content, new RegExp(expectedTemperatureInDataToolScenario.toString()) ); }; @@ -586,7 +613,7 @@ class ConversationHandlerTestProject extends TestProjectBase { }, ...this.getCommonEventProperties(streamResponse), }; - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, defaultConversationHandlerFunction, apolloClient @@ -594,17 +621,20 @@ class ConversationHandlerTestProject extends TestProjectBase { // Assert that tool use content blocks are emitted in case LLM selects client tool. // The content blocks are string serialized, but not as a proper JSON, // hence string matching is employed below to detect some signals that tool use blocks kinds were emitted. - assert.match(responseContent, /toolUse/); - assert.match(responseContent, /toolUseId/); + assert.match(response.content, /toolUse/); + assert.match(response.content, /toolUseId/); // Assert that LLM attempts to pass parameter when asking for tool use. - assert.match(responseContent, /"city":"Seattle"/); + assert.match(response.content, /"city":"Seattle"/); }; private executeConversationTurn = async ( event: ConversationTurnEvent, functionName: string, apolloClient: ApolloClient - ): Promise => { + ): Promise<{ + content: string; + errors?: Array; + }> => { console.log( `Sending event conversationId=${event.conversationId} currentMessageId=${event.currentMessageId}` ); @@ -649,6 +679,10 @@ class ConversationHandlerTestProject extends TestProjectBase { contentBlockToolUse conversationId createdAt + errors { + errorType + message + } id owner stopReason @@ -676,6 +710,13 @@ class ConversationHandlerTestProject extends TestProjectBase { assert.ok(chunks); + if (chunks.length === 1 && chunks[0].errors) { + return { + content: '', + errors: chunks[0].errors, + }; + } + chunks.sort((a, b) => { // This is very simplified sort by message,block and delta indexes; let aValue = 1000 * 1000 * a.contentBlockIndex; @@ -699,7 +740,7 @@ class ConversationHandlerTestProject extends TestProjectBase { return accumulated; }, ''); - return content; + return { content }; } const queryResult = await apolloClient.query<{ listConversationMessageAssistantResponses: { @@ -721,6 +762,10 @@ class ConversationHandlerTestProject extends TestProjectBase { updatedAt createdAt content + errors { + errorType + message + } associatedUserMessageId } nextToken @@ -739,8 +784,16 @@ class ConversationHandlerTestProject extends TestProjectBase { ); const response = queryResult.data.listConversationMessageAssistantResponses.items[0]; + + if (response.errors) { + return { + content: '', + errors: response.errors, + }; + } + assert.ok(response.content); - return response.content; + return { content: response.content }; }; private assertCustomConversationHandlerCanExecuteTurn = async ( @@ -780,26 +833,81 @@ class ConversationHandlerTestProject extends TestProjectBase { }, ...this.getCommonEventProperties(streamResponse), }; - const responseContent = await this.executeConversationTurn( + const response = await this.executeConversationTurn( event, customConversationHandlerFunction, apolloClient ); // Assert that tool was used. I.e. LLM used value provided by the tool. assert.match( - responseContent, + response.content, new RegExp( expectedTemperaturesInProgrammaticToolScenario.Seattle.toString() ) ); assert.match( - responseContent, + response.content, new RegExp( expectedTemperaturesInProgrammaticToolScenario.Boston.toString() ) ); }; + private assertDefaultConversationHandlerCanPropagateError = async ( + backendId: BackendIdentifier, + accessToken: string, + graphqlApiEndpoint: string, + apolloClient: ApolloClient, + streamResponse: boolean + ): Promise => { + const defaultConversationHandlerFunction = ( + await this.resourceFinder.findByBackendIdentifier( + backendId, + 'AWS::Lambda::Function', + (name) => name.includes('default') + ) + )[0]; + + const message: CreateConversationMessageChatInput = { + id: randomUUID().toString(), + conversationId: randomUUID().toString(), + role: 'user', + content: [ + { + text: 'What is the value of PI?', + }, + ], + }; + + // send event + const event: ConversationTurnEvent = { + conversationId: message.conversationId, + currentMessageId: message.id, + graphqlApiEndpoint: graphqlApiEndpoint, + request: { + headers: { authorization: accessToken }, + }, + ...this.getCommonEventProperties(streamResponse), + }; + + // Inject failure + event.modelConfiguration.modelId = 'invalidId'; + await this.insertMessage(apolloClient, message); + + const response = await this.executeConversationTurn( + event, + defaultConversationHandlerFunction, + apolloClient + ); + assert.ok(response.errors); + assert.ok(response.errors[0]); + assert.strictEqual(response.errors[0].errorType, 'ValidationException'); + assert.match( + response.errors[0].message, + /provided model identifier is invalid/ + ); + }; + private insertMessage = async ( apolloClient: ApolloClient, message: CreateConversationMessageChatInput diff --git a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts index 0b8f9dc75f..7e924e5aff 100644 --- a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts +++ b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts @@ -86,11 +86,17 @@ const schema = a.schema({ tools: a.ref('MockTool').array(), }), + MockConversationTurnError: a.customType({ + errorType: a.string(), + message: a.string(), + }), + ConversationMessageAssistantResponse: a .model({ conversationId: a.id(), associatedUserMessageId: a.id(), content: a.string(), + errors: a.ref('MockConversationTurnError').array(), }) .authorization((allow) => [allow.authenticated(), allow.owner()]), @@ -110,6 +116,9 @@ const schema = a.schema({ // when message is complete stopReason: a.string(), + + // error + errors: a.ref('MockConversationTurnError').array(), }) .secondaryIndexes((index) => [ index('conversationId').sortKeys(['associatedUserMessageId']), From 98df270688f2a7fc5e21176c79e383d9341ca70b Mon Sep 17 00:00:00 2001 From: Kamil Sobol Date: Wed, 30 Oct 2024 14:32:26 -0700 Subject: [PATCH 2/5] this works --- .../runtime/bedrock_converse_adapter.test.ts | 20 +- .../runtime/bedrock_converse_adapter.ts | 447 ++++++++---------- .../conversation_turn_executor.test.ts | 56 ++- .../runtime/conversation_turn_executor.ts | 37 +- .../conversation_turn_response_sender.test.ts | 74 ++- .../conversation_turn_response_sender.ts | 91 ++-- .../src/conversation/runtime/lazy.ts | 17 + .../src/conversation/runtime/types.ts | 24 - .../amplify/data/resource.ts | 2 +- 9 files changed, 412 insertions(+), 356 deletions(-) create mode 100644 packages/ai-constructs/src/conversation/runtime/lazy.ts diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts index f621ebe3b5..ae424828c7 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.test.ts @@ -200,7 +200,7 @@ void describe('Bedrock converse adapter', () => { ]); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, { content }); + assert.deepStrictEqual(responseContent, content); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1); @@ -339,7 +339,7 @@ void describe('Bedrock converse adapter', () => { assert.strictEqual(responseText, 'finalResponse'); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, { content }); + assert.deepStrictEqual(responseContent, content); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 3); @@ -510,7 +510,7 @@ void describe('Bedrock converse adapter', () => { assert.strictEqual(responseText, 'finalResponse'); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, { content }); + assert.deepStrictEqual(responseContent, content); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 2); @@ -598,7 +598,7 @@ void describe('Bedrock converse adapter', () => { assert.strictEqual(responseText, 'finalResponse'); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, { content }); + assert.deepStrictEqual(responseContent, content); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 2); @@ -711,13 +711,11 @@ void describe('Bedrock converse adapter', () => { ]); } else { const responseContent = await adapter.askBedrock(); - assert.deepStrictEqual(responseContent, { - content: [ - { - toolUse: clientToolUse, - }, - ], - }); + assert.deepStrictEqual(responseContent, [ + { + toolUse: clientToolUse, + }, + ]); } assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1); diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts index f1bd37cead..510925492a 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts @@ -13,9 +13,7 @@ import { ToolInputSchema, } from '@aws-sdk/client-bedrock-runtime'; import { - ConversationTurnError, ConversationTurnEvent, - ConversationTurnResponse, ExecutableTool, StreamingResponseChunk, ToolDefinition, @@ -97,82 +95,80 @@ export class BedrockConverseAdapter { } } - askBedrock = async (): Promise => { - try { - const { modelId, systemPrompt, inferenceConfiguration } = - this.event.modelConfiguration; + askBedrock = async (): Promise => { + const { modelId, systemPrompt, inferenceConfiguration } = + this.event.modelConfiguration; - const messages: Array = - await this.getEventMessagesAsBedrockMessages(); + const messages: Array = + await this.getEventMessagesAsBedrockMessages(); - let bedrockResponse: ConverseCommandOutput; - do { - const toolConfig = this.createToolConfiguration(); - const converseCommandInput: ConverseCommandInput = { - modelId, - messages: [...messages], - system: [{ text: systemPrompt }], - inferenceConfig: inferenceConfiguration, - toolConfig, - }; - this.logger.info('Sending Bedrock Converse request'); - this.logger.debug('Bedrock Converse request:', converseCommandInput); - bedrockResponse = await this.bedrockClient.send( - new ConverseCommand(converseCommandInput) - ); - this.logger.info( - `Received Bedrock Converse response, requestId=${bedrockResponse.$metadata.requestId}`, - bedrockResponse.usage + let bedrockResponse: ConverseCommandOutput; + do { + const toolConfig = this.createToolConfiguration(); + const converseCommandInput: ConverseCommandInput = { + modelId, + messages: [...messages], + system: [{ text: systemPrompt }], + inferenceConfig: inferenceConfiguration, + toolConfig, + }; + this.logger.info('Sending Bedrock Converse request'); + this.logger.debug('Bedrock Converse request:', converseCommandInput); + bedrockResponse = await this.bedrockClient.send( + new ConverseCommand(converseCommandInput) + ); + this.logger.info( + `Received Bedrock Converse response, requestId=${bedrockResponse.$metadata.requestId}`, + bedrockResponse.usage + ); + this.logger.debug('Bedrock Converse response:', bedrockResponse); + if (bedrockResponse.output?.message) { + messages.push(bedrockResponse.output?.message); + } + if (bedrockResponse.stopReason === 'tool_use') { + const responseContentBlocks = + bedrockResponse.output?.message?.content ?? []; + const toolUseBlocks = responseContentBlocks.filter( + (block) => 'toolUse' in block + ) as Array; + const clientToolUseBlocks = responseContentBlocks.filter( + (block) => + block.toolUse?.name && + this.clientToolByName.has(block.toolUse?.name) ); - this.logger.debug('Bedrock Converse response:', bedrockResponse); - if (bedrockResponse.output?.message) { - messages.push(bedrockResponse.output?.message); + if (clientToolUseBlocks.length > 0) { + // For now if any of client tools is used we ignore executable tools + // and propagate result back to client. + return clientToolUseBlocks; } - if (bedrockResponse.stopReason === 'tool_use') { - const responseContentBlocks = - bedrockResponse.output?.message?.content ?? []; - const toolUseBlocks = responseContentBlocks.filter( - (block) => 'toolUse' in block - ) as Array; - const clientToolUseBlocks = responseContentBlocks.filter( - (block) => - block.toolUse?.name && - this.clientToolByName.has(block.toolUse?.name) - ); - if (clientToolUseBlocks.length > 0) { - // For now if any of client tools is used we ignore executable tools - // and propagate result back to client. - return { content: clientToolUseBlocks }; - } - const toolResponseContentBlocks: Array = []; - for (const responseContentBlock of toolUseBlocks) { - const toolUseBlock = - responseContentBlock as ContentBlock.ToolUseMember; - const toolResultContentBlock = await this.executeTool(toolUseBlock); - toolResponseContentBlocks.push(toolResultContentBlock); - } - messages.push({ - role: 'user', - content: toolResponseContentBlocks, - }); + const toolResponseContentBlocks: Array = []; + for (const responseContentBlock of toolUseBlocks) { + const toolUseBlock = + responseContentBlock as ContentBlock.ToolUseMember; + const toolResultContentBlock = await this.executeTool(toolUseBlock); + toolResponseContentBlocks.push(toolResultContentBlock); } - } while (bedrockResponse.stopReason === 'tool_use'); + messages.push({ + role: 'user', + content: toolResponseContentBlocks, + }); + } + } while (bedrockResponse.stopReason === 'tool_use'); - return { content: bedrockResponse.output?.message?.content ?? [] }; - } catch (error) { - console.error('Conversation with Bedrock failed', error); - const conversationTurnError = - this.convertErrorToConversationTurnError(error); - return { - errors: [conversationTurnError], - }; - } + return bedrockResponse.output?.message?.content ?? []; }; /** * Asks Bedrock for response using streaming version of Converse API. */ async *askBedrockStreaming(): AsyncGenerator { + const { modelId, systemPrompt, inferenceConfiguration } = + this.event.modelConfiguration; + + const messages: Array = + await this.getEventMessagesAsBedrockMessages(); + + let bedrockResponse: ConverseStreamCommandOutput; // keep our own indexing for blocks instead of using Bedrock's indexes // since we stream subset of these upstream. let blockIndex = 0; @@ -181,182 +177,160 @@ export class BedrockConverseAdapter { // Accumulates client facing content per turn. // So that upstream can persist full message at the end of the streaming. const accumulatedTurnContent: Array = []; + do { + const toolConfig = this.createToolConfiguration(); + const converseCommandInput: ConverseStreamCommandInput = { + modelId, + messages: [...messages], + system: [{ text: systemPrompt }], + inferenceConfig: inferenceConfiguration, + toolConfig, + }; + this.logger.info('Sending Bedrock Converse Stream request'); + this.logger.debug( + 'Bedrock Converse Stream request:', + converseCommandInput + ); + bedrockResponse = await this.bedrockClient.send( + new ConverseStreamCommand(converseCommandInput) + ); + this.logger.info( + `Received Bedrock Converse Stream response, requestId=${bedrockResponse.$metadata.requestId}` + ); + if (!bedrockResponse.stream) { + throw new Error('Bedrock response is missing stream'); + } + let toolUseBlock: ContentBlock.ToolUseMember | undefined; + let clientToolsRequested = false; + let text: string = ''; + let toolUseInput: string = ''; + let blockDeltaIndex = 0; + let lastBlockDeltaIndex = 0; + // Accumulate current message for the tool use loop purpose. + const accumulatedAssistantMessage: Message = { + role: undefined, + content: [], + }; - try { - const { modelId, systemPrompt, inferenceConfiguration } = - this.event.modelConfiguration; - - const messages: Array = - await this.getEventMessagesAsBedrockMessages(); - - let bedrockResponse: ConverseStreamCommandOutput; - do { - const toolConfig = this.createToolConfiguration(); - const converseCommandInput: ConverseStreamCommandInput = { - modelId, - messages: [...messages], - system: [{ text: systemPrompt }], - inferenceConfig: inferenceConfiguration, - toolConfig, - }; - this.logger.info('Sending Bedrock Converse Stream request'); - this.logger.debug( - 'Bedrock Converse Stream request:', - converseCommandInput - ); - bedrockResponse = await this.bedrockClient.send( - new ConverseStreamCommand(converseCommandInput) - ); - this.logger.info( - `Received Bedrock Converse Stream response, requestId=${bedrockResponse.$metadata.requestId}` - ); - if (!bedrockResponse.stream) { - throw new Error('Bedrock response is missing stream'); - } - let toolUseBlock: ContentBlock.ToolUseMember | undefined; - let clientToolsRequested = false; - let text: string = ''; - let toolUseInput: string = ''; - let blockDeltaIndex = 0; - let lastBlockDeltaIndex = 0; - // Accumulate current message for the tool use loop purpose. - const accumulatedAssistantMessage: Message = { - role: undefined, - content: [], - }; - - for await (const chunk of bedrockResponse.stream) { - this.logger.debug('Bedrock Converse Stream response chunk:', chunk); - if (chunk.messageStart) { - accumulatedAssistantMessage.role = chunk.messageStart.role; - } else if (chunk.contentBlockStart) { - blockDeltaIndex = 0; - lastBlockDeltaIndex = 0; - if (chunk.contentBlockStart.start?.toolUse) { - toolUseBlock = { - toolUse: { - ...chunk.contentBlockStart.start?.toolUse, - input: undefined, - }, - }; - } - } else if (chunk.contentBlockDelta) { - if (chunk.contentBlockDelta.delta?.toolUse) { - if (!chunk.contentBlockDelta.delta.toolUse.input) { - toolUseInput = ''; - } - toolUseInput += chunk.contentBlockDelta.delta.toolUse.input; - } else if (chunk.contentBlockDelta.delta?.text) { - text += chunk.contentBlockDelta.delta.text; - yield { - accumulatedTurnContent: [...accumulatedTurnContent, { text }], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockText: chunk.contentBlockDelta.delta.text, - contentBlockIndex: blockIndex, - contentBlockDeltaIndex: blockDeltaIndex, - }; - lastBlockDeltaIndex = blockDeltaIndex; - blockDeltaIndex++; - } - } else if (chunk.contentBlockStop) { - if (toolUseBlock) { - toolUseBlock.toolUse.input = JSON.parse(toolUseInput); - accumulatedAssistantMessage.content?.push(toolUseBlock); - if ( - toolUseBlock.toolUse.name && - this.clientToolByName.has(toolUseBlock.toolUse.name) - ) { - clientToolsRequested = true; - accumulatedTurnContent.push(toolUseBlock); - yield { - accumulatedTurnContent: [...accumulatedTurnContent], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockIndex: blockIndex, - contentBlockToolUse: JSON.stringify(toolUseBlock), - }; - lastBlockIndex = blockIndex; - blockIndex++; - } - toolUseBlock = undefined; + for await (const chunk of bedrockResponse.stream) { + this.logger.debug('Bedrock Converse Stream response chunk:', chunk); + if (chunk.messageStart) { + accumulatedAssistantMessage.role = chunk.messageStart.role; + } else if (chunk.contentBlockStart) { + blockDeltaIndex = 0; + lastBlockDeltaIndex = 0; + if (chunk.contentBlockStart.start?.toolUse) { + toolUseBlock = { + toolUse: { + ...chunk.contentBlockStart.start?.toolUse, + input: undefined, + }, + }; + } + } else if (chunk.contentBlockDelta) { + if (chunk.contentBlockDelta.delta?.toolUse) { + if (!chunk.contentBlockDelta.delta.toolUse.input) { toolUseInput = ''; - } else { - accumulatedAssistantMessage.content?.push({ - text, - }); - accumulatedTurnContent.push({ text }); + } + toolUseInput += chunk.contentBlockDelta.delta.toolUse.input; + } else if (chunk.contentBlockDelta.delta?.text) { + text += chunk.contentBlockDelta.delta.text; + yield { + accumulatedTurnContent: [...accumulatedTurnContent, { text }], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockText: chunk.contentBlockDelta.delta.text, + contentBlockIndex: blockIndex, + contentBlockDeltaIndex: blockDeltaIndex, + }; + lastBlockDeltaIndex = blockDeltaIndex; + blockDeltaIndex++; + } + } else if (chunk.contentBlockStop) { + if (toolUseBlock) { + toolUseBlock.toolUse.input = JSON.parse(toolUseInput); + accumulatedAssistantMessage.content?.push(toolUseBlock); + if ( + toolUseBlock.toolUse.name && + this.clientToolByName.has(toolUseBlock.toolUse.name) + ) { + clientToolsRequested = true; + accumulatedTurnContent.push(toolUseBlock); yield { accumulatedTurnContent: [...accumulatedTurnContent], conversationId: this.event.conversationId, associatedUserMessageId: this.event.currentMessageId, contentBlockIndex: blockIndex, - contentBlockDoneAtIndex: lastBlockDeltaIndex, + contentBlockToolUse: JSON.stringify(toolUseBlock), }; - text = ''; lastBlockIndex = blockIndex; blockIndex++; } - } else if (chunk.messageStop) { - stopReason = chunk.messageStop.stopReason ?? ''; + toolUseBlock = undefined; + toolUseInput = ''; + } else { + accumulatedAssistantMessage.content?.push({ + text, + }); + accumulatedTurnContent.push({ text }); + yield { + accumulatedTurnContent: [...accumulatedTurnContent], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockIndex: blockIndex, + contentBlockDoneAtIndex: lastBlockDeltaIndex, + }; + text = ''; + lastBlockIndex = blockIndex; + blockIndex++; } + } else if (chunk.messageStop) { + stopReason = chunk.messageStop.stopReason ?? ''; } - this.logger.debug( - 'Accumulated Bedrock Converse Stream response:', - accumulatedAssistantMessage - ); - if (clientToolsRequested) { - // For now if any of client tools is used we ignore executable tools - // and propagate result back to client. - yield { - accumulatedTurnContent: [...accumulatedTurnContent], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockIndex: lastBlockIndex, - stopReason: stopReason, - }; - return; - } - messages.push(accumulatedAssistantMessage); - if (stopReason === 'tool_use') { - const responseContentBlocks = - accumulatedAssistantMessage.content ?? []; - const toolUseBlocks = responseContentBlocks.filter( - (block) => 'toolUse' in block - ) as Array; - const toolResponseContentBlocks: Array = []; - for (const responseContentBlock of toolUseBlocks) { - const toolUseBlock = - responseContentBlock as ContentBlock.ToolUseMember; - const toolResultContentBlock = await this.executeTool(toolUseBlock); - toolResponseContentBlocks.push(toolResultContentBlock); - } - messages.push({ - role: 'user', - content: toolResponseContentBlocks, - }); + } + this.logger.debug( + 'Accumulated Bedrock Converse Stream response:', + accumulatedAssistantMessage + ); + if (clientToolsRequested) { + // For now if any of client tools is used we ignore executable tools + // and propagate result back to client. + yield { + accumulatedTurnContent: [...accumulatedTurnContent], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockIndex: lastBlockIndex, + stopReason: stopReason, + }; + return; + } + messages.push(accumulatedAssistantMessage); + if (stopReason === 'tool_use') { + const responseContentBlocks = accumulatedAssistantMessage.content ?? []; + const toolUseBlocks = responseContentBlocks.filter( + (block) => 'toolUse' in block + ) as Array; + const toolResponseContentBlocks: Array = []; + for (const responseContentBlock of toolUseBlocks) { + const toolUseBlock = + responseContentBlock as ContentBlock.ToolUseMember; + const toolResultContentBlock = await this.executeTool(toolUseBlock); + toolResponseContentBlocks.push(toolResultContentBlock); } - } while (stopReason === 'tool_use'); + messages.push({ + role: 'user', + content: toolResponseContentBlocks, + }); + } + } while (stopReason === 'tool_use'); - yield { - accumulatedTurnContent: [...accumulatedTurnContent], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockIndex: lastBlockIndex, - stopReason: stopReason, - }; - } catch (error) { - console.error('Streaming conversation with Bedrock failed', error); - const conversationTurnError = - this.convertErrorToConversationTurnError(error); - yield { - accumulatedTurnContent: [...accumulatedTurnContent], - conversationId: this.event.conversationId, - associatedUserMessageId: this.event.currentMessageId, - contentBlockIndex: blockIndex, - errors: [conversationTurnError], - }; - } + yield { + accumulatedTurnContent: [...accumulatedTurnContent], + conversationId: this.event.conversationId, + associatedUserMessageId: this.event.currentMessageId, + contentBlockIndex: lastBlockIndex, + stopReason: stopReason, + }; } /** @@ -459,23 +433,4 @@ export class BedrockConverseAdapter { }; } }; - - private convertErrorToConversationTurnError = ( - error: unknown - ): ConversationTurnError => { - let errorType = 'UnknownError'; - let message: string; - if (error instanceof Error) { - message = error.message; - if (error.name) { - errorType = error.name; - } - } else { - message = JSON.stringify(error); - } - return { - errorType, - message, - }; - }; } diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts index 6d71fbf5ff..49ee1d342f 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts @@ -1,13 +1,11 @@ import { describe, it, mock } from 'node:test'; import assert from 'node:assert'; import { ConversationTurnExecutor } from './conversation_turn_executor'; -import { - ConversationTurnEvent, - ConversationTurnResponse, - StreamingResponseChunk, -} from './types'; +import { ConversationTurnEvent, StreamingResponseChunk } from './types'; import { BedrockConverseAdapter } from './bedrock_converse_adapter'; +import { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { ConversationTurnResponseSender } from './conversation_turn_response_sender'; +import { Lazy } from './lazy'; void describe('Conversation turn executor', () => { const event: ConversationTurnEvent = { @@ -31,9 +29,10 @@ void describe('Conversation turn executor', () => { void it('executes turn successfully', async () => { const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); - const bedrockResponse: ConversationTurnResponse = { - content: [{ text: 'block1' }, { text: 'block2' }], - }; + const bedrockResponse: Array = [ + { text: 'block1' }, + { text: 'block2' }, + ]; const bedrockConverseAdapterAskBedrockMock = mock.method( bedrockConverseAdapter, 'askBedrock', @@ -64,8 +63,8 @@ void describe('Conversation turn executor', () => { await new ConversationTurnExecutor( event, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(); @@ -158,8 +157,8 @@ void describe('Conversation turn executor', () => { await new ConversationTurnExecutor( streamingEvent, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(); @@ -219,10 +218,12 @@ void describe('Conversation turn executor', () => { const consoleErrorMock = mock.fn(); const consoleLogMock = mock.fn(); const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); const consoleMock = { error: consoleErrorMock, log: consoleLogMock, debug: consoleDebugMock, + warn: consoleWarnMock, } as unknown as Console; await assert.rejects( @@ -230,8 +231,8 @@ void describe('Conversation turn executor', () => { new ConversationTurnExecutor( event, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(), (error: Error) => { @@ -269,9 +270,10 @@ void describe('Conversation turn executor', () => { void it('logs and propagates error if response sender throws', async () => { const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); - const bedrockResponse: ConversationTurnResponse = { - content: [{ text: 'block1' }, { text: 'block2' }], - }; + const bedrockResponse: Array = [ + { text: 'block1' }, + { text: 'block2' }, + ]; const bedrockConverseAdapterAskBedrockMock = mock.method( bedrockConverseAdapter, 'askBedrock', @@ -291,6 +293,12 @@ void describe('Conversation turn executor', () => { () => Promise.resolve() ); + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + const consoleErrorMock = mock.fn(); const consoleLogMock = mock.fn(); const consoleDebugMock = mock.fn(); @@ -305,8 +313,8 @@ void describe('Conversation turn executor', () => { new ConversationTurnExecutor( event, [], - bedrockConverseAdapter, - responseSender, + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), consoleMock ).execute(), (error: Error) => { @@ -340,5 +348,15 @@ void describe('Conversation turn executor', () => { consoleErrorMock.mock.calls[0].arguments[1], responseSenderError ); + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'Failed to send response', + }, + ] + ); }); }); diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts index 04a7384207..9c5389f610 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.ts @@ -1,6 +1,7 @@ import { ConversationTurnResponseSender } from './conversation_turn_response_sender.js'; import { ConversationTurnEvent, ExecutableTool, JSONSchema } from './types.js'; import { BedrockConverseAdapter } from './bedrock_converse_adapter.js'; +import { Lazy } from './lazy'; /** * This class is responsible for orchestrating conversation turn execution. @@ -16,11 +17,13 @@ export class ConversationTurnExecutor { constructor( private readonly event: ConversationTurnEvent, additionalTools: Array, - private readonly bedrockConverseAdapter = new BedrockConverseAdapter( - event, - additionalTools + // We're deferring dependency initialization here so that we can capture all validation errors. + private readonly responseSender = new Lazy( + () => new ConversationTurnResponseSender(event) + ), + private readonly bedrockConverseAdapter = new Lazy( + () => new BedrockConverseAdapter(event, additionalTools) ), - private readonly responseSender = new ConversationTurnResponseSender(event), private readonly logger = console ) {} @@ -32,14 +35,14 @@ export class ConversationTurnExecutor { this.logger.debug('Event received:', this.event); if (this.event.streamResponse) { - const chunks = this.bedrockConverseAdapter.askBedrockStreaming(); + const chunks = this.bedrockConverseAdapter.value.askBedrockStreaming(); for await (const chunk of chunks) { - await this.responseSender.sendResponseChunk(chunk); + await this.responseSender.value.sendResponseChunk(chunk); } } else { const assistantResponse = - await this.bedrockConverseAdapter.askBedrock(); - await this.responseSender.sendResponse(assistantResponse); + await this.bedrockConverseAdapter.value.askBedrock(); + await this.responseSender.value.sendResponse(assistantResponse); } this.logger.log( @@ -50,10 +53,28 @@ export class ConversationTurnExecutor { `Failed to handle conversation turn event, currentMessageId=${this.event.currentMessageId}, conversationId=${this.event.conversationId}`, e ); + await this.tryForwardError(e); // Propagate error to mark lambda execution as failed in metrics. throw e; } }; + + private tryForwardError = async (e: unknown) => { + try { + let errorType = 'UnknownError'; + let message: string; + if (e instanceof Error) { + errorType = e.name; + message = e.message; + } else { + message = JSON.stringify(e); + } + await this.responseSender.value.sendErrors([{ errorType, message }]); + } catch (e) { + // Best effort, only log the fact that we tried to send error back to AppSync. + this.logger.warn('Failed to send error mutation', e); + } + }; } /** diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts index fad4130275..0e7e1fe71a 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.test.ts @@ -6,8 +6,8 @@ import { MutationStreamingResponseInput, } from './conversation_turn_response_sender'; import { + ConversationTurnError, ConversationTurnEvent, - ConversationTurnResponse, StreamingResponseChunk, } from './types'; import { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; @@ -49,14 +49,12 @@ void describe('Conversation turn response sender', () => { event, graphqlRequestExecutor ); - const response: ConversationTurnResponse = { - content: [ - { - text: 'block1', - }, - { text: 'block2' }, - ], - }; + const response: Array = [ + { + text: 'block1', + }, + { text: 'block2' }, + ]; await sender.sendResponse(response); assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); @@ -108,7 +106,7 @@ void describe('Conversation turn response sender', () => { }, }, }; - const response: ConversationTurnResponse = { content: [toolUseBlock] }; + const response: Array = [toolUseBlock]; await sender.sendResponse(response); assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); @@ -242,4 +240,60 @@ void describe('Conversation turn response sender', () => { }, }); }); + + void it('sends errors response back to appsync', async () => { + const graphqlRequestExecutor = new GraphqlRequestExecutor('', '', ''); + const executeGraphqlMock = mock.method( + graphqlRequestExecutor, + 'executeGraphql', + () => + // Mock successful Appsync response + Promise.resolve() + ); + const sender = new ConversationTurnResponseSender( + event, + graphqlRequestExecutor + ); + const errors: Array = [ + { + errorType: 'errorType1', + message: 'errorMessage1', + }, + { + errorType: 'errorType2', + message: 'errorMessage2', + }, + ]; + await sender.sendErrors(errors); + + assert.strictEqual(executeGraphqlMock.mock.calls.length, 1); + const request = executeGraphqlMock.mock.calls[0] + .arguments[0] as GraphqlRequest; + assert.deepStrictEqual(request, { + query: + '\n' + + ' mutation PublishModelResponse($input: testResponseMutationInputTypeName!) {\n' + + ' testResponseMutationName(input: $input) {\n' + + ' testSelectionSet\n' + + ' }\n' + + ' }\n' + + ' ', + variables: { + input: { + conversationId: event.conversationId, + errors: [ + { + errorType: 'errorType1', + message: 'errorMessage1', + }, + { + errorType: 'errorType2', + message: 'errorMessage2', + }, + ], + associatedUserMessageId: event.currentMessageId, + }, + }, + }); + }); }); diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts index 7ac14022c6..723b787959 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_response_sender.ts @@ -1,32 +1,31 @@ import { ConversationTurnError, ConversationTurnEvent, - ConversationTurnResponse, StreamingResponseChunk, } from './types.js'; import type { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; import { GraphqlRequestExecutor } from './graphql_request_executor'; export type MutationResponseInput = { - input: - | { - associatedUserMessageId: string; - conversationId: string; - content: ContentBlock[]; - errors?: never; - } - | { - associatedUserMessageId: string; - conversationId: string; - content?: never; - errors: ConversationTurnError[]; - }; + input: { + conversationId: string; + content: ContentBlock[]; + associatedUserMessageId: string; + }; }; export type MutationStreamingResponseInput = { input: StreamingResponseChunk; }; +export type MutationErrorsResponseInput = { + input: { + conversationId: string; + errors: ConversationTurnError[]; + associatedUserMessageId: string; + }; +}; + /** * This class is responsible for sending a response produced by Bedrock back to AppSync * in a form of mutation. @@ -45,8 +44,8 @@ export class ConversationTurnResponseSender { private readonly logger = console ) {} - sendResponse = async (response: ConversationTurnResponse) => { - const responseMutationRequest = this.createMutationRequest(response); + sendResponse = async (message: ContentBlock[]) => { + const responseMutationRequest = this.createMutationRequest(message); this.logger.debug('Sending response mutation:', responseMutationRequest); await this.graphqlRequestExecutor.executeGraphql< MutationResponseInput, @@ -63,7 +62,37 @@ export class ConversationTurnResponseSender { >(responseMutationRequest); }; - private createMutationRequest = (response: ConversationTurnResponse) => { + sendErrors = async (errors: ConversationTurnError[]) => { + const responseMutationRequest = this.createMutationErrorsRequest(errors); + this.logger.debug( + 'Sending errors response mutation:', + responseMutationRequest + ); + await this.graphqlRequestExecutor.executeGraphql< + MutationErrorsResponseInput, + void + >(responseMutationRequest); + }; + + private createMutationErrorsRequest = (errors: ConversationTurnError[]) => { + const query = ` + mutation PublishModelResponse($input: ${this.event.responseMutation.inputTypeName}!) { + ${this.event.responseMutation.name}(input: $input) { + ${this.event.responseMutation.selectionSet} + } + } + `; + const variables: MutationErrorsResponseInput = { + input: { + conversationId: this.event.conversationId, + errors, + associatedUserMessageId: this.event.currentMessageId, + }, + }; + return { query, variables }; + }; + + private createMutationRequest = (content: ContentBlock[]) => { const query = ` mutation PublishModelResponse($input: ${this.event.responseMutation.inputTypeName}!) { ${this.event.responseMutation.name}(input: $input) { @@ -71,26 +100,14 @@ export class ConversationTurnResponseSender { } } `; - let variables: MutationResponseInput; - if (typeof response.content !== 'undefined') { - variables = { - input: { - conversationId: this.event.conversationId, - content: this.serializeContent(response.content), - associatedUserMessageId: this.event.currentMessageId, - }, - }; - } else if (typeof response.errors !== 'undefined') { - variables = { - input: { - conversationId: this.event.conversationId, - errors: response.errors, - associatedUserMessageId: this.event.currentMessageId, - }, - }; - } else { - throw new Error('Response contains neither content nor error'); - } + content = this.serializeContent(content); + const variables: MutationResponseInput = { + input: { + conversationId: this.event.conversationId, + content, + associatedUserMessageId: this.event.currentMessageId, + }, + }; return { query, variables }; }; diff --git a/packages/ai-constructs/src/conversation/runtime/lazy.ts b/packages/ai-constructs/src/conversation/runtime/lazy.ts new file mode 100644 index 0000000000..7f5b2032ca --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/lazy.ts @@ -0,0 +1,17 @@ +/** + * A class that initializes lazily upon usage. + */ +export class Lazy { + #value?: T; + + /** + * Creates lazy instance. + */ + constructor(private readonly valueFactory: () => T) {} + /** + * Gets a value. Value is create at first access. + */ + public get value(): T { + return (this.#value ??= this.valueFactory()); + } +} diff --git a/packages/ai-constructs/src/conversation/runtime/types.ts b/packages/ai-constructs/src/conversation/runtime/types.ts index cdf925eb81..3d95030cab 100644 --- a/packages/ai-constructs/src/conversation/runtime/types.ts +++ b/packages/ai-constructs/src/conversation/runtime/types.ts @@ -1,6 +1,5 @@ import * as bedrock from '@aws-sdk/client-bedrock-runtime'; import * as jsonSchemaToTypeScript from 'json-schema-to-ts'; -import type { ContentBlock } from '@aws-sdk/client-bedrock-runtime'; /* Notice: This file contains types that are exposed publicly. @@ -101,16 +100,6 @@ export type ConversationTurnError = { message: string; }; -export type ConversationTurnResponse = - | { - content: ContentBlock[]; - errors?: never; - } - | { - content?: never; - errors: ConversationTurnError[]; - }; - export type StreamingResponseChunk = { // always required conversationId: string; @@ -124,7 +113,6 @@ export type StreamingResponseChunk = { contentBlockDeltaIndex: number; contentBlockDoneAtIndex?: never; contentBlockToolUse?: never; - errors?: never; stopReason?: never; } | { @@ -133,7 +121,6 @@ export type StreamingResponseChunk = { contentBlockText?: never; contentBlockDeltaIndex?: never; contentBlockToolUse?: never; - errors?: never; stopReason?: never; } | { @@ -142,7 +129,6 @@ export type StreamingResponseChunk = { contentBlockDoneAtIndex?: never; contentBlockText?: never; contentBlockDeltaIndex?: never; - errors?: never; stopReason?: never; } | { @@ -152,15 +138,5 @@ export type StreamingResponseChunk = { contentBlockText?: never; contentBlockDeltaIndex?: never; contentBlockToolUse?: never; - errors?: never; - } - | { - // error - errors: ConversationTurnError[]; - stopReason?: never; - contentBlockDoneAtIndex?: never; - contentBlockText?: never; - contentBlockDeltaIndex?: never; - contentBlockToolUse?: never; } ); diff --git a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts index 7e924e5aff..07c19400c2 100644 --- a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts +++ b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts @@ -105,7 +105,7 @@ const schema = a.schema({ // always conversationId: a.id().required(), associatedUserMessageId: a.id().required(), - contentBlockIndex: a.integer().required(), + contentBlockIndex: a.integer(), accumulatedTurnContent: a.ref('MockContentBlock').array(), // these describe chunks or end of block From fdad5b2469f80378b3e7bfcb35211817f8a607be Mon Sep 17 00:00:00 2001 From: Kamil Sobol Date: Wed, 30 Oct 2024 14:44:20 -0700 Subject: [PATCH 3/5] more tests --- .../conversation_turn_executor.test.ts | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts index 49ee1d342f..4a3b82b44c 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts @@ -215,6 +215,12 @@ void describe('Conversation turn executor', () => { () => Promise.resolve() ); + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + const consoleErrorMock = mock.fn(); const consoleLogMock = mock.fn(); const consoleDebugMock = mock.fn(); @@ -266,6 +272,16 @@ void describe('Conversation turn executor', () => { consoleErrorMock.mock.calls[0].arguments[1], bedrockError ); + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'Bedrock failed', + }, + ] + ); }); void it('logs and propagates error if response sender throws', async () => { @@ -302,10 +318,12 @@ void describe('Conversation turn executor', () => { const consoleErrorMock = mock.fn(); const consoleLogMock = mock.fn(); const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); const consoleMock = { error: consoleErrorMock, log: consoleLogMock, debug: consoleDebugMock, + warn: consoleWarnMock, } as unknown as Console; await assert.rejects( @@ -359,4 +377,114 @@ void describe('Conversation turn executor', () => { ] ); }); + + void it('throws original exception if error sender fails', async () => { + const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); + const originalError = new Error('original error'); + mock.method(bedrockConverseAdapter, 'askBedrock', () => + Promise.reject(originalError) + ); + const responseSender = new ConversationTurnResponseSender(event); + mock.method(responseSender, 'sendResponse', () => Promise.resolve()); + + mock.method(responseSender, 'sendResponseChunk', () => Promise.resolve()); + + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.reject(new Error('sender error')) + ); + + const consoleErrorMock = mock.fn(); + const consoleLogMock = mock.fn(); + const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); + const consoleMock = { + error: consoleErrorMock, + log: consoleLogMock, + debug: consoleDebugMock, + warn: consoleWarnMock, + } as unknown as Console; + + await assert.rejects( + () => + new ConversationTurnExecutor( + event, + [], + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), + consoleMock + ).execute(), + (error: Error) => { + assert.strictEqual(error, originalError); + return true; + } + ); + + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'original error', + }, + ] + ); + }); + + void it('serializes unknown errors', async () => { + const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); + const unknownError = { some: 'shape' }; + mock.method(bedrockConverseAdapter, 'askBedrock', () => + Promise.reject(unknownError) + ); + const responseSender = new ConversationTurnResponseSender(event); + mock.method(responseSender, 'sendResponse', () => Promise.resolve()); + + mock.method(responseSender, 'sendResponseChunk', () => Promise.resolve()); + + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + + const consoleErrorMock = mock.fn(); + const consoleLogMock = mock.fn(); + const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); + const consoleMock = { + error: consoleErrorMock, + log: consoleLogMock, + debug: consoleDebugMock, + warn: consoleWarnMock, + } as unknown as Console; + + await assert.rejects( + () => + new ConversationTurnExecutor( + event, + [], + new Lazy(() => responseSender), + new Lazy(() => bedrockConverseAdapter), + consoleMock + ).execute(), + (error: Error) => { + assert.strictEqual(error, unknownError); + return true; + } + ); + + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'UnknownError', + message: '{"some":"shape"}', + }, + ] + ); + }); }); From 1b20e0dc4b21f3f8859fcc98f432c4c76ae3a82f Mon Sep 17 00:00:00 2001 From: Kamil Sobol Date: Wed, 30 Oct 2024 14:50:01 -0700 Subject: [PATCH 4/5] validation error --- .../conversation/runtime/bedrock_converse_adapter.ts | 3 ++- .../ai-constructs/src/conversation/runtime/errors.ts | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 packages/ai-constructs/src/conversation/runtime/errors.ts diff --git a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts index 510925492a..c215efc0b4 100644 --- a/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts +++ b/packages/ai-constructs/src/conversation/runtime/bedrock_converse_adapter.ts @@ -21,6 +21,7 @@ import { import { ConversationTurnEventToolsProvider } from './event-tools-provider'; import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever'; import * as bedrock from '@aws-sdk/client-bedrock-runtime'; +import { ValidationError } from './errors'; /** * This class is responsible for interacting with Bedrock Converse API @@ -87,7 +88,7 @@ export class BedrockConverseAdapter { this.clientToolByName.set(t.name, t); }); if (duplicateTools.size > 0) { - throw new Error( + throw new ValidationError( `Tools must have unique names. Duplicate tools: ${[ ...duplicateTools, ].join(', ')}.` diff --git a/packages/ai-constructs/src/conversation/runtime/errors.ts b/packages/ai-constructs/src/conversation/runtime/errors.ts new file mode 100644 index 0000000000..1d3063dd49 --- /dev/null +++ b/packages/ai-constructs/src/conversation/runtime/errors.ts @@ -0,0 +1,12 @@ +/** + * Represents validation errors. + */ +export class ValidationError extends Error { + /** + * Creates validation error instance. + */ + constructor(message: string) { + super(message); + this.name = 'ValidationError'; + } +} From b035b67be167f9588599cdb56e70ad39d6bde77a Mon Sep 17 00:00:00 2001 From: Kamil Sobol Date: Wed, 30 Oct 2024 15:41:45 -0700 Subject: [PATCH 5/5] more tests --- .../conversation_turn_executor.test.ts | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts index 4a3b82b44c..8c42431b63 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_turn_executor.test.ts @@ -487,4 +487,59 @@ void describe('Conversation turn executor', () => { ] ); }); + + void it('reports initialization errors', async () => { + const bedrockConverseAdapter = new BedrockConverseAdapter(event, []); + mock.method(bedrockConverseAdapter, 'askBedrock', () => Promise.resolve()); + const responseSender = new ConversationTurnResponseSender(event); + mock.method(responseSender, 'sendResponse', () => Promise.resolve()); + + mock.method(responseSender, 'sendResponseChunk', () => Promise.resolve()); + + const responseSenderSendErrorsMock = mock.method( + responseSender, + 'sendErrors', + () => Promise.resolve() + ); + + const consoleErrorMock = mock.fn(); + const consoleLogMock = mock.fn(); + const consoleDebugMock = mock.fn(); + const consoleWarnMock = mock.fn(); + const consoleMock = { + error: consoleErrorMock, + log: consoleLogMock, + debug: consoleDebugMock, + warn: consoleWarnMock, + } as unknown as Console; + + const initializationError = new Error('initialization error'); + await assert.rejects( + () => + new ConversationTurnExecutor( + event, + [], + new Lazy(() => responseSender), + new Lazy(() => { + throw initializationError; + }), + consoleMock + ).execute(), + (error: Error) => { + assert.strictEqual(error, initializationError); + return true; + } + ); + + assert.strictEqual(responseSenderSendErrorsMock.mock.calls.length, 1); + assert.deepStrictEqual( + responseSenderSendErrorsMock.mock.calls[0].arguments[0], + [ + { + errorType: 'Error', + message: 'initialization error', + }, + ] + ); + }); });