diff --git a/app/(playground)/p/[agentId]/actions.ts b/app/(playground)/p/[agentId]/actions.ts index 10594435..458275fa 100644 --- a/app/(playground)/p/[agentId]/actions.ts +++ b/app/(playground)/p/[agentId]/actions.ts @@ -1,370 +1,24 @@ "use server"; -import { db } from "@/drizzle"; import { ExternalServiceName, VercelBlobOperation, createLogger, waitForTelemetryExport, withCountMeasurement, - withTokenMeasurement, } from "@/lib/opentelemetry"; -import { anthropic } from "@ai-sdk/anthropic"; -import { google } from "@ai-sdk/google"; -import { openai } from "@ai-sdk/openai"; -import { toJsonSchema } from "@valibot/to-json-schema"; import { type ListBlobResult, del, list, put } from "@vercel/blob"; -import { type LanguageModelV1, jsonSchema, streamObject } from "ai"; -import { createStreamableValue } from "ai/rsc"; -import { MockLanguageModelV1, simulateReadableStream } from "ai/test"; -import HandleBars from "handlebars"; -import Langfuse from "langfuse"; import { UnstructuredClient } from "unstructured-client"; import { Strategy } from "unstructured-client/sdk/models/shared"; -import * as v from "valibot"; -import { vercelBlobFileFolder, vercelBlobGraphFolder } from "./constants"; +import { vercelBlobFileFolder } from "./constants"; -import { textGenerationPrompt } from "./lib/prompts"; import { buildFileFolderPath, buildGraphPath, elementsToMarkdown, - langfuseModel, pathJoin, - toErrorWithMessage, } from "./lib/utils"; -import type { - AgentId, - ArtifactId, - FileData, - FileId, - Graph, - GraphId, - NodeHandle, - NodeHandleId, - NodeId, - TextArtifactObject, - TextGenerateActionContent, -} from "./types"; - -function resolveLanguageModel( - llm: TextGenerateActionContent["llm"], -): LanguageModelV1 { - const [provider, model] = llm.split(":"); - if (provider === "openai") { - return openai(model); - } - if (provider === "anthropic") { - return anthropic(model); - } - if (provider === "google") { - return google(model); - } - if (provider === "dev") { - return new MockLanguageModelV1({ - defaultObjectGenerationMode: "json", - doStream: async () => ({ - stream: simulateReadableStream({ - chunks: [{ type: "error", error: "a" }], - }), - rawCall: { rawPrompt: null, rawSettings: {} }, - }), - }); - } - throw new Error("Unsupported model provider"); -} - -const artifactSchema = v.object({ - plan: v.pipe( - v.string(), - v.description( - "How you think about the content of the artefact (purpose, structure, essentials) and how you intend to output it", - ), - ), - title: v.pipe(v.string(), v.description("The title of the artefact")), - content: v.pipe( - v.string(), - v.description("The content of the artefact formatted markdown."), - ), - description: v.pipe( - v.string(), - v.description( - "Explanation of the Artifact and what the intention was in creating this Artifact. Add any suggestions for making it even better.", - ), - ), -}); - -interface ActionSourceBase { - type: string; - nodeId: NodeId; -} - -interface TextSource extends ActionSourceBase { - type: "text"; - content: string; -} -interface FileSource extends ActionSourceBase { - type: "file"; - title: string; - content: string; -} -interface TextGenerationSource extends ActionSourceBase { - type: "textGeneration"; - title: string; - content: string; -} - -type ActionSource = TextSource | TextGenerationSource | FileSource; - -export async function action( - artifactId: ArtifactId, - agentId: AgentId, - nodeId: NodeId, -) { - const startTime = Date.now(); - const lf = new Langfuse(); - const trace = lf.trace({ - sessionId: artifactId, - }); - - const agent = await db.query.agents.findFirst({ - where: (agents, { eq }) => eq(agents.id, agentId), - }); - if (agent === undefined || agent.graphUrl === null) { - throw new Error(`Agent with id ${agentId} not found`); - } - - const graph = await fetch(agent.graphUrl).then( - (res) => res.json() as unknown as Graph, - ); - const node = graph.nodes.find((node) => node.id === nodeId); - if (node === undefined) { - throw new Error("Node not found"); - } - /** - * This function is a helper that retrieves a node from the graph - * based on its NodeHandleId. It looks for a connection in the - * graph that matches the provided handleId and returns the - * corresponding node if found, or null if no such node exists. - */ - function findNode(handleId: NodeHandleId) { - const connection = graph.connections.find( - (connection) => connection.targetNodeHandleId === handleId, - ); - const node = graph.nodes.find( - (node) => node.id === connection?.sourceNodeId, - ); - if (node === undefined) { - return null; - } - return node; - } - - /** - * The resolveSources function maps over an array of NodeHandles, - * finds the corresponding nodes in the graph, and returns an - * array of ActionSources. It handles both text and text generation - * sources and filters out any null results. If a text node is - * found, it extracts the text content; if a textGeneration node - * is found, it retrieves the corresponding generatedArtifact. - */ - async function resolveSources(sources: NodeHandle[]) { - return Promise.all( - sources.map(async (source) => { - const node = findNode(source.id); - switch (node?.content.type) { - case "text": - return { - type: "text", - content: node.content.text, - nodeId: node.id, - } satisfies ActionSource; - case "file": { - if (node.content.data == null) { - throw new Error("File not found"); - } - if (node.content.data.status === "uploading") { - /** @todo Let user know file is uploading*/ - throw new Error("File is uploading"); - } - if (node.content.data.status === "processing") { - /** @todo Let user know file is processing*/ - throw new Error("File is processing"); - } - if (node.content.data.status === "failed") { - return null; - } - const text = await fetch(node.content.data.textDataUrl).then( - (res) => res.text(), - ); - return { - type: "file", - title: node.content.data.name, - content: text, - nodeId: node.id, - } satisfies ActionSource; - } - - case "files": { - return await Promise.all( - node.content.data.map(async (file) => { - if (file == null) { - throw new Error("File not found"); - } - if (file.status === "uploading") { - /** @todo Let user know file is uploading*/ - throw new Error("File is uploading"); - } - if (file.status === "processing") { - /** @todo Let user know file is processing*/ - throw new Error("File is processing"); - } - if (file.status === "failed") { - return null; - } - const text = await fetch(file.textDataUrl).then((res) => - res.text(), - ); - return { - type: "file", - title: file.name, - content: text, - nodeId: node.id, - } satisfies ActionSource; - }), - ); - } - case "textGeneration": { - const generatedArtifact = graph.artifacts.find( - (artifact) => artifact.creatorNodeId === node.id, - ); - if ( - generatedArtifact === undefined || - generatedArtifact.type !== "generatedArtifact" - ) { - return null; - } - return { - type: "textGeneration", - title: generatedArtifact.object.title, - content: generatedArtifact.object.content, - nodeId: node.id, - } satisfies ActionSource; - } - default: - return null; - } - }), - ).then((sources) => sources.filter((source) => source !== null).flat()); - } - - /** - * The resolveRequirement function retrieves the content of a - * specified requirement node, if it exists. It looks for - * the node in the graph based on the given NodeHandle. - * If the node is of type "text", it returns the text - * content; if it is of type "textGeneration", it looks - * for the corresponding generated artifact and returns - * its content. If the node is not found or does not match - * the expected types, it returns null. - */ - function resolveRequirement(requirement?: NodeHandle) { - if (requirement === undefined) { - return null; - } - const node = findNode(requirement.id); - switch (node?.content.type) { - case "text": - return node.content.text; - case "textGeneration": { - const generatedArtifact = graph.artifacts.find( - (artifact) => artifact.creatorNodeId === node.id, - ); - if ( - generatedArtifact === undefined || - generatedArtifact.type === "generatedArtifact" - ) { - return null; - } - return generatedArtifact.object.content; - } - default: - return null; - } - } - - // The main switch statement handles the different types of nodes - switch (node.content.type) { - case "textGeneration": { - const actionSources = await resolveSources(node.content.sources); - const requirement = resolveRequirement(node.content.requirement); - const model = resolveLanguageModel(node.content.llm); - const promptTemplate = HandleBars.compile( - node.content.system ?? textGenerationPrompt, - ); - const prompt = promptTemplate({ - instruction: node.content.instruction, - requirement, - sources: actionSources, - }); - const topP = node.content.topP; - const temperature = node.content.temperature; - const stream = createStreamableValue(); - - const generationTracer = trace.generation({ - name: "generate-text", - input: prompt, - model: langfuseModel(node.content.llm), - modelParameters: { - topP: node.content.topP, - temperature: node.content.temperature, - }, - }); - (async () => { - const { partialObjectStream, object, usage } = streamObject({ - model, - prompt, - schema: jsonSchema>( - toJsonSchema(artifactSchema), - ), - topP, - temperature, - }); - - for await (const partialObject of partialObjectStream) { - stream.update({ - type: "text", - title: partialObject.title ?? "", - content: partialObject.content ?? "", - messages: { - plan: partialObject.plan ?? "", - description: partialObject.description ?? "", - }, - }); - } - const result = await object; - - await withTokenMeasurement( - createLogger(node.content.type), - async () => { - generationTracer.end({ output: result }); - await lf.shutdownAsync(); - waitForTelemetryExport(); - return { usage: await usage }; - }, - model, - startTime, - ); - stream.done(); - })().catch((error) => { - stream.error(error); - }); - return stream.value; - } - default: - throw new Error("Invalid node type"); - } -} +import type { FileData, FileId, Graph } from "./types"; export async function parse(id: FileId, name: string, blobUrl: string) { const startTime = Date.now(); diff --git a/app/(playground)/p/[agentId]/components/properties-panel.tsx b/app/(playground)/p/[agentId]/components/properties-panel.tsx index 677b2173..0121bc93 100644 --- a/app/(playground)/p/[agentId]/components/properties-panel.tsx +++ b/app/(playground)/p/[agentId]/components/properties-panel.tsx @@ -29,7 +29,7 @@ import { useMemo, useState, } from "react"; -import { action, parse, remove } from "../actions"; +import { parse, remove } from "../actions"; import { vercelBlobFileFolder } from "../constants"; import { useDeveloperMode } from "../contexts/developer-mode"; import { useExecution } from "../contexts/execution"; @@ -288,7 +288,7 @@ export function PropertiesPanel() { const { graph, dispatch, flush } = useGraph(); const selectedNode = useSelectedNode(); const { open, setOpen, tab, setTab } = usePropertiesPanel(); - const { execute } = useExecution(); + const { executeNode } = useExecution(); return (
execute(selectedNode.id)} + onClick={() => executeNode(selectedNode.id)} > Generate @@ -630,7 +630,7 @@ export function PropertiesPanel() { setTab("Prompt"); }} onEditPrompt={() => setTab("Prompt")} - onGenerateText={() => execute(selectedNode.id)} + onGenerateText={() => executeNode(selectedNode.id)} /> )} diff --git a/app/(playground)/p/[agentId]/contexts/execution.tsx b/app/(playground)/p/[agentId]/contexts/execution.tsx index 26a51825..34f04a08 100644 --- a/app/(playground)/p/[agentId]/contexts/execution.tsx +++ b/app/(playground)/p/[agentId]/contexts/execution.tsx @@ -8,6 +8,7 @@ import { useContext, useState, } from "react"; +import { deriveFlows } from "../lib/graph"; import { createArtifactId, createExecutionId, @@ -37,6 +38,7 @@ import type { SkippedJobExecution, StepExecution, StepId, + TextArtifact, TextArtifactObject, } from "../types"; import { useGraph } from "./graph"; @@ -137,15 +139,28 @@ const processStreamContent = async ( return textArtifactObject; }; -const executeStep = async ( - stepExecution: StepExecution, +const executeStep = async ({ + stepExecution, + executeStepAction, + updateExecution, + updateArtifact, + onStepFinish, + onStepFail, +}: { + stepExecution: StepExecution; executeStepAction: ( stepId: StepId, - ) => Promise>, + ) => Promise>; updateExecution: ( updater: (prev: Execution | null) => Execution | null, - ) => void, -): Promise => { + ) => void; + updateArtifact: (artifactId: ArtifactId, content: TextArtifactObject) => void; + onStepFinish?: ( + stepExecution: CompletedStepExecution, + artifact: TextArtifact, + ) => void; + onStepFail?: (stepExecution: FailedStepExecution) => void; +}): Promise => { if (stepExecution.status === "completed") { return stepExecution; } @@ -185,29 +200,26 @@ const executeStep = async ( try { // Execute step and process stream const stream = await executeStepAction(stepExecution.stepId); - const finalArtifact = await processStreamContent(stream, (content) => { - updateExecution((prev) => { - if (!prev || prev.status !== "running") return null; - return { - ...prev, - artifacts: prev.artifacts.map((artifact) => - artifact.id === artifactId - ? { ...artifact, object: content } - : artifact, - ), - }; - }); - }); + const finalArtifact = await processStreamContent(stream, (content) => + updateArtifact(artifactId, content), + ); // Complete step execution const stepDurationMs = Date.now() - stepRunStartedAt; - const successStepExecution: CompletedStepExecution = { + const completedStepExecution: CompletedStepExecution = { ...stepExecution, status: "completed", runStartedAt: stepRunStartedAt, durationMs: stepDurationMs, }; + const generatedArtifact = { + id: artifactId, + type: "generatedArtifact", + creatorNodeId: stepExecution.nodeId, + createdAt: Date.now(), + object: finalArtifact, + } satisfies TextArtifact; updateExecution((prev) => { if (!prev || prev.status !== "running") return null; @@ -216,24 +228,16 @@ const executeStep = async ( jobExecutions: prev.jobExecutions.map((job) => ({ ...job, stepExecutions: job.stepExecutions.map((step) => - step.id === stepExecution.id ? successStepExecution : step, + step.id === stepExecution.id ? completedStepExecution : step, ), })), artifacts: prev.artifacts.map((artifact) => - artifact.id === artifactId - ? { - id: artifactId, - type: "generatedArtifact", - creatorNodeId: stepExecution.nodeId, - createdAt: Date.now(), - object: finalArtifact, - } - : artifact, + artifact.id === artifactId ? generatedArtifact : artifact, ), }; }); - - return successStepExecution; + onStepFinish?.(completedStepExecution, generatedArtifact); + return completedStepExecution; } catch (unknownError) { const error = toErrorWithMessage(unknownError).message; const stepDurationMs = Date.now() - stepRunStartedAt; @@ -260,19 +264,33 @@ const executeStep = async ( ), }; }); + onStepFail?.(failedStepExecution); return failedStepExecution; } }; -const executeJob = async ( - jobExecution: JobExecution, +const executeJob = async ({ + jobExecution, + executeStepAction, + updateArtifact, + updateExecution, + onStepFinish, + onStepFail, +}: { + jobExecution: JobExecution; executeStepAction: ( stepId: StepId, - ) => Promise>, + ) => Promise>; updateExecution: ( updater: (prev: Execution | null) => Execution | null, - ) => void, -): Promise => { + ) => void; + updateArtifact: (artifactId: ArtifactId, content: TextArtifactObject) => void; + onStepFinish?: ( + stepExecution: CompletedStepExecution, + artifact: TextArtifact, + ) => void; + onStepFail?: (stepExecution: FailedStepExecution) => void; +}): Promise => { const jobRunStartedAt = Date.now(); // Start job execution @@ -290,8 +308,15 @@ const executeJob = async ( // Execute all steps in parallel const stepExecutions = await Promise.all( - jobExecution.stepExecutions.map((step) => - executeStep(step, executeStepAction, updateExecution), + jobExecution.stepExecutions.map((stepExecution) => + executeStep({ + stepExecution, + executeStepAction, + updateExecution, + updateArtifact, + onStepFinish, + onStepFail, + }), ), ); @@ -345,12 +370,17 @@ const executeJob = async ( interface ExecutionContextType { execution: Execution | null; - execute: (nodeId: NodeId) => Promise; + executeNode: (nodeId: NodeId) => Promise; executeFlow: (flowId: FlowId) => Promise; retryFlowExecution: ( executionId: ExecutionId, forceRetryStepId?: StepId, ) => Promise; + recordAgentUsageAction: ( + startedAt: number, + endedAt: number, + totalDurationMs: number, + ) => Promise; } const ExecutionContext = createContext( @@ -372,23 +402,29 @@ type RetryStepAction = ( ) => Promise>; interface ExecutionProviderProps { children: ReactNode; - executeAction: ( - artifactId: ArtifactId, - nodeId: NodeId, - ) => Promise>; executeStepAction: ExecuteStepAction; putExecutionAction: ( executionSnapshot: ExecutionSnapshot, ) => Promise<{ blobUrl: string }>; retryStepAction: RetryStepAction; + executeNodeAction: ( + executionId: ExecutionId, + nodeId: NodeId, + ) => Promise>; + recordAgentUsageAction: ( + startedAt: number, + endedAt: number, + totalDurationMs: number, + ) => Promise; } export function ExecutionProvider({ children, - executeAction, executeStepAction, putExecutionAction, retryStepAction, + executeNodeAction, + recordAgentUsageAction, }: ExecutionProviderProps) { const { dispatch, flush, graph } = useGraph(); const { setTab } = usePropertiesPanel(); @@ -396,97 +432,7 @@ export function ExecutionProvider({ const { setPlaygroundMode } = usePlaygroundMode(); const [execution, setExecution] = useState(null); - const execute = useCallback( - async (nodeId: NodeId) => { - const artifactId = createArtifactId(); - dispatch({ - type: "upsertArtifact", - input: { - nodeId, - artifact: { - id: artifactId, - type: "streamArtifact", - creatorNodeId: nodeId, - object: { - type: "text", - title: "", - content: "", - messages: { - plan: "", - description: "", - }, - }, - }, - }, - }); - setTab("Result"); - await flush(); - try { - const stream = await executeAction(artifactId, nodeId); - - let textArtifactObject: TextArtifactObject = { - type: "text", - title: "", - content: "", - messages: { - plan: "", - description: "", - }, - }; - for await (const streamContent of readStreamableValue(stream)) { - if (streamContent === undefined) { - continue; - } - dispatch({ - type: "upsertArtifact", - input: { - nodeId, - artifact: { - id: artifactId, - type: "streamArtifact", - creatorNodeId: nodeId, - object: streamContent, - }, - }, - }); - textArtifactObject = { - ...textArtifactObject, - ...streamContent, - }; - } - dispatch({ - type: "upsertArtifact", - input: { - nodeId, - artifact: { - id: artifactId, - type: "generatedArtifact", - creatorNodeId: nodeId, - createdAt: Date.now(), - object: textArtifactObject, - }, - }, - }); - } catch (error) { - addToast({ - type: "error", - title: "Execution failed", - message: toErrorWithMessage(error).message, - }); - dispatch({ - type: "upsertArtifact", - input: { - nodeId, - artifact: null, - }, - }); - } - }, - [executeAction, dispatch, flush, setTab, addToast], - ); - interface ExecuteFlowParams { - flowId: FlowId; initialExecution: Execution; flow: Flow; nodes: Node[]; @@ -494,15 +440,24 @@ export function ExecutionProvider({ executeStepCallback: ( stepId: StepId, ) => Promise>; + updateArtifactCallback: ( + artifactId: ArtifactId, + content: TextArtifactObject, + ) => void; + onStepFinish?: ( + stepExecution: CompletedStepExecution, + artifact: TextArtifact, + ) => void; } const performFlowExecution = useCallback( async ({ - flowId, initialExecution, flow, nodes, connections, executeStepCallback, + updateArtifactCallback, + onStepFinish, }: ExecuteFlowParams) => { let currentExecution = initialExecution; let totalFlowDurationMs = 0; @@ -534,17 +489,25 @@ export function ExecutionProvider({ continue; } - const executedJob = await executeJob( + const executedJob = await executeJob({ jobExecution, - executeStepCallback, - (updater) => { + executeStepAction: executeStepCallback, + updateExecution: (updater) => { const updated = updater(currentExecution); if (updated) { currentExecution = updated; setExecution(updated); } }, - ); + updateArtifact: updateArtifactCallback, + onStepFinish, + onStepFail: (failedStep) => { + addToast({ + type: "error", + message: failedStep.error, + }); + }, + }); totalFlowDurationMs += executedJob.durationMs; if (executedJob.status === "failed") { @@ -578,20 +541,27 @@ export function ExecutionProvider({ }); const { blobUrl } = await putExecutionAction(executionSnapshot); + const runEndedAt = Date.now(); dispatch({ type: "addExecutionIndex", input: { executionIndex: { executionId: currentExecution.id, blobUrl, - completedAt: Date.now(), + completedAt: runEndedAt, }, }, }); + await recordAgentUsageAction( + currentExecution.runStartedAt, + runEndedAt, + currentExecution.durationMs, + ); + return currentExecution; }, - [dispatch, putExecutionAction], + [dispatch, putExecutionAction, addToast, recordAgentUsageAction], ); const executeFlow = useCallback( @@ -617,7 +587,6 @@ export function ExecutionProvider({ setExecution(initialExecution); const finalExecution = await performFlowExecution({ - flowId, initialExecution, flow, nodes: graph.nodes, @@ -629,6 +598,19 @@ export function ExecutionProvider({ stepId, initialExecution.artifacts, ), + updateArtifactCallback: (artifactId, content) => { + setExecution((prev) => { + if (!prev || prev.status !== "running") return null; + return { + ...prev, + artifacts: prev.artifacts.map((artifact) => + artifact.id === artifactId + ? { ...artifact, object: content } + : artifact, + ), + }; + }); + }, }); setExecution(finalExecution); }, @@ -664,7 +646,6 @@ export function ExecutionProvider({ runStartedAt: flowRunStartedAt, }; const finalExecution = await performFlowExecution({ - flowId: retryExecutionSnapshot.flow.id, initialExecution, flow: retryExecutionSnapshot.flow, nodes: retryExecutionSnapshot.nodes, @@ -676,15 +657,105 @@ export function ExecutionProvider({ stepId, initialExecution.artifacts, ), + updateArtifactCallback: (artifactId, content) => { + setExecution((prev) => { + if (!prev || prev.status !== "running") return null; + return { + ...prev, + artifacts: prev.artifacts.map((artifact) => + artifact.id === artifactId + ? { ...artifact, object: content } + : artifact, + ), + }; + }); + }, }); setExecution(finalExecution); }, [graph.executionIndexes, flush, retryStepAction, performFlowExecution], ); + const executeNode = useCallback( + async (nodeId: NodeId) => { + setTab("Result"); + const executionId = createExecutionId(); + const flowRunStartedAt = Date.now(); + const node = graph.nodes.find((node) => node.id === nodeId); + if (node === undefined) { + throw new Error("Node not found"); + } + + const tmpFlows = deriveFlows({ + nodes: [node], + connections: [], + }); + if (tmpFlows.length !== 1) { + throw new Error("Unexpected number of flows"); + } + const tmpFlow = tmpFlows[0]; + + // Initialize flow execution + const initialExecution: Execution = { + id: executionId, + status: "running", + jobExecutions: createInitialJobExecutions(tmpFlow), + artifacts: [], + runStartedAt: flowRunStartedAt, + }; + setExecution(initialExecution); + const finalExecution = await performFlowExecution({ + initialExecution, + flow: tmpFlow, + nodes: graph.nodes, + connections: graph.connections, + executeStepCallback: (stepId) => + executeNodeAction(executionId, node.id), + updateArtifactCallback: (artifactId, content) => { + dispatch({ + type: "upsertArtifact", + input: { + nodeId, + artifact: { + id: artifactId, + type: "streamArtifact", + creatorNodeId: nodeId, + object: content, + }, + }, + }); + }, + onStepFinish: (execution, artifact) => { + dispatch({ + type: "upsertArtifact", + input: { + nodeId, + artifact, + }, + }); + }, + }); + setExecution(finalExecution); + }, + [ + setTab, + executeNodeAction, + graph.connections, + graph.nodes, + performFlowExecution, + dispatch, + ], + ); + return ( {children} diff --git a/app/(playground)/p/[agentId]/lib/execution.ts b/app/(playground)/p/[agentId]/lib/execution.ts index d1e2cac5..2d972c5c 100644 --- a/app/(playground)/p/[agentId]/lib/execution.ts +++ b/app/(playground)/p/[agentId]/lib/execution.ts @@ -20,6 +20,7 @@ import * as v from "valibot"; import type { AgentId, Artifact, + Connection, ExecutionId, ExecutionSnapshot, FlowId, @@ -28,6 +29,7 @@ import type { NodeHandle, NodeHandleId, NodeId, + Step, StepId, TextArtifactObject, TextGenerateActionContent, @@ -62,6 +64,35 @@ function resolveLanguageModel( throw new Error("Unsupported model provider"); } +function nodeResolver(nodeHandleId: NodeHandleId, context: ExecutionContext) { + const connection = context.connections.find( + (connection) => connection.targetNodeHandleId === nodeHandleId, + ); + const node = context.nodes.find( + (node) => node.id === connection?.sourceNodeId, + ); + if (node === undefined) { + return null; + } + return node; +} + +function artifactResolver( + artifactCreatorNodeId: NodeId, + context: ExecutionContext, +) { + const generatedArtifact = context.artifacts.find( + (artifact) => artifact.creatorNodeId === artifactCreatorNodeId, + ); + if ( + generatedArtifact === undefined || + generatedArtifact.type !== "generatedArtifact" + ) { + return null; + } + return generatedArtifact; +} + const artifactSchema = v.object({ plan: v.pipe( v.string(), @@ -110,10 +141,13 @@ interface TextGenerationSource extends ExecutionSourceBase { } type ExecutionSource = TextSource | TextGenerationSource | FileSource; -async function resolveSources(sources: NodeHandle[], resolver: SourceResolver) { +async function resolveSources( + sources: NodeHandle[], + context: ExecutionContext, +) { return Promise.all( sources.map(async (source) => { - const node = resolver.nodeResolver(source.id); + const node = nodeResolver(source.id, context); switch (node?.content.type) { case "text": return { @@ -177,7 +211,7 @@ async function resolveSources(sources: NodeHandle[], resolver: SourceResolver) { ); } case "textGeneration": { - const generatedArtifact = resolver.artifactResolver(node.id); + const generatedArtifact = artifactResolver(node.id, context); if ( generatedArtifact === null || generatedArtifact.type !== "generatedArtifact" @@ -198,24 +232,19 @@ async function resolveSources(sources: NodeHandle[], resolver: SourceResolver) { ).then((sources) => sources.filter((source) => source !== null).flat()); } -interface RequirementResolver { - nodeResolver: NodeResolver; - artifactResolver: ArtifactResolver; -} - function resolveRequirement( requirement: NodeHandle | null, - resolver: RequirementResolver, + context: ExecutionContext, ) { if (requirement === null) { return null; } - const node = resolver.nodeResolver(requirement.id); + const node = nodeResolver(requirement.id, context); switch (node?.content.type) { case "text": return node.content.text; case "textGeneration": { - const generatedArtifact = resolver.artifactResolver(node.id); + const generatedArtifact = artifactResolver(node.id, context); if ( generatedArtifact === null || generatedArtifact.type === "generatedArtifact" @@ -231,11 +260,10 @@ function resolveRequirement( interface ExecutionContext { executionId: ExecutionId; - stepId: StepId; + node: Node; artifacts: Artifact[]; nodes: Node[]; - connections: ExecutionSnapshot["connections"]; - flow: ExecutionSnapshot["flow"]; + connections: Connection[]; } async function performFlowExecution(context: ExecutionContext) { @@ -244,56 +272,15 @@ async function performFlowExecution(context: ExecutionContext) { const trace = lf.trace({ sessionId: context.executionId, }); - - const step = context.flow.jobs - .flatMap((job) => job.steps) - .find((step) => step.id === context.stepId); - - if (step === undefined) { - throw new Error(`Step with id ${context.stepId} not found`); - } - - const node = context.nodes.find((node) => node.id === step.nodeId); - if (node === undefined) { - throw new Error("Node not found"); - } - - function nodeResolver(nodeHandleId: NodeHandleId) { - const connection = context.connections.find( - (connection) => connection.targetNodeHandleId === nodeHandleId, - ); - const node = context.nodes.find( - (node) => node.id === connection?.sourceNodeId, - ); - if (node === undefined) { - return null; - } - return node; - } - - function artifactResolver(artifactCreatorNodeId: NodeId) { - const generatedArtifact = context.artifacts.find( - (artifact) => artifact.creatorNodeId === artifactCreatorNodeId, - ); - if ( - generatedArtifact === undefined || - generatedArtifact.type !== "generatedArtifact" - ) { - return null; - } - return generatedArtifact; - } + const node = context.node; switch (node.content.type) { case "textGeneration": { - const actionSources = await resolveSources(node.content.sources, { - nodeResolver, - artifactResolver, - }); - const requirement = resolveRequirement(node.content.requirement ?? null, { - nodeResolver, - artifactResolver, - }); + const actionSources = await resolveSources(node.content.sources, context); + const requirement = resolveRequirement( + node.content.requirement ?? null, + context, + ); const model = resolveLanguageModel(node.content.llm); const promptTemplate = HandleBars.compile( node.content.system ?? textGenerationPrompt, @@ -397,13 +384,24 @@ export async function executeStep( throw new Error(`Flow with id ${flowId} not found`); } + const step = flow.jobs + .flatMap((job) => job.steps) + .find((step) => step.id === stepId); + + if (step === undefined) { + throw new Error(`Step with id ${stepId} not found`); + } + const node = graph.nodes.find((node) => node.id === step.nodeId); + if (node === undefined) { + throw new Error("Node not found"); + } + const context: ExecutionContext = { executionId, - stepId, + node, artifacts, nodes: graph.nodes, connections: graph.connections, - flow, }; return performFlowExecution(context); @@ -419,13 +417,58 @@ export async function retryStep( (res) => res.json() as unknown as ExecutionSnapshot, ); + const step = executionSnapshot.flow.jobs + .flatMap((job) => job.steps) + .find((step) => step.id === stepId); + + if (step === undefined) { + throw new Error(`Step with id ${stepId} not found`); + } + + const node = executionSnapshot.nodes.find((node) => node.id === step.nodeId); + if (node === undefined) { + throw new Error("Node not found"); + } + const context: ExecutionContext = { executionId, - stepId, + node, artifacts, nodes: executionSnapshot.nodes, connections: executionSnapshot.connections, - flow: executionSnapshot.flow, + }; + + return performFlowExecution(context); +} + +export async function executeNode( + agentId: AgentId, + executionId: ExecutionId, + nodeId: NodeId, +) { + const agent = await db.query.agents.findFirst({ + where: (agents, { eq }) => eq(agents.id, agentId), + }); + + if (agent === undefined || agent.graphUrl === null) { + throw new Error(`Agent with id ${agentId} not found`); + } + + const graph = await fetch(agent.graphUrl).then( + (res) => res.json() as unknown as Graph, + ); + + const node = graph.nodes.find((node) => node.id === nodeId); + if (node === undefined) { + throw new Error("Node not found"); + } + + const context: ExecutionContext = { + executionId, + node, + artifacts: graph.artifacts, + nodes: graph.nodes, + connections: graph.connections, }; return performFlowExecution(context); diff --git a/app/(playground)/p/[agentId]/lib/graph.test.ts b/app/(playground)/p/[agentId]/lib/graph.test.ts index ca097819..4f4919ae 100644 --- a/app/(playground)/p/[agentId]/lib/graph.test.ts +++ b/app/(playground)/p/[agentId]/lib/graph.test.ts @@ -150,9 +150,32 @@ describe("deriveFlows", () => { expect(flows[1].nodes.length).toBe(3); }); test("ignore ghost connectors", () => { - console.log(flows[1].jobs[2].steps); expect(flows[1].jobs[2].steps.length).toBe(1); }); + test("one node graph", () => { + const testFlows = deriveFlows({ + nodes: [ + { + id: "nd_onenode", + name: "Summary", + position: { x: 420, y: 180 }, + selected: false, + type: "action", + content: { + type: "textGeneration", + llm: "anthropic:claude-3-5-sonnet-latest", + temperature: 0.7, + topP: 1, + instruction: "Please let me know key takeaway about ", + sources: [], + }, + }, + ], + connections: [], + }); + expect(testFlows.length).toBe(1); + expect(testFlows[0].jobs[0].steps[0].nodeId).toBe("nd_onenode"); + }); }); describe("isLatestVersion", () => { diff --git a/app/(playground)/p/[agentId]/lib/graph.ts b/app/(playground)/p/[agentId]/lib/graph.ts index edc9b966..2ab40e96 100644 --- a/app/(playground)/p/[agentId]/lib/graph.ts +++ b/app/(playground)/p/[agentId]/lib/graph.ts @@ -12,7 +12,9 @@ import type { } from "../types"; import { createFlowId, createJobId, createStepId } from "./utils"; -export function deriveFlows(graph: Graph): Flow[] { +export function deriveFlows( + graph: Pick, +): Flow[] { const processedNodes = new Set(); const flows: Flow[] = []; const connectionMap = new Map>(); diff --git a/app/(playground)/p/[agentId]/page.tsx b/app/(playground)/p/[agentId]/page.tsx index bc02892a..b2e49662 100644 --- a/app/(playground)/p/[agentId]/page.tsx +++ b/app/(playground)/p/[agentId]/page.tsx @@ -9,11 +9,12 @@ import { withCountMeasurement, } from "@/lib/opentelemetry"; import { getUser } from "@/lib/supabase"; +import { recordAgentUsage } from "@/services/agents/activities"; import { del, list, put } from "@vercel/blob"; import { ReactFlowProvider } from "@xyflow/react"; import { eq } from "drizzle-orm"; import { notFound } from "next/navigation"; -import { action, putGraph } from "./actions"; +import { putGraph } from "./actions"; import { Playground } from "./components/playground"; import { AgentNameProvider } from "./contexts/agent-name"; import { DeveloperModeProvider } from "./contexts/developer-mode"; @@ -24,14 +25,12 @@ import { PlaygroundModeProvider } from "./contexts/playground-mode"; import { PropertiesPanelProvider } from "./contexts/properties-panel"; import { ToastProvider } from "./contexts/toast"; import { ToolbarContextProvider } from "./contexts/toolbar"; -import { executeStep, retryStep } from "./lib/execution"; +import { executeNode, executeStep, retryStep } from "./lib/execution"; import { isLatestVersion, migrateGraph } from "./lib/graph"; import { buildGraphExecutionPath, buildGraphFolderPath } from "./lib/utils"; import type { AgentId, Artifact, - ArtifactId, - Execution, ExecutionId, ExecutionSnapshot, FlowId, @@ -141,11 +140,6 @@ export default async function Page({ return agentName; } - async function execute(artifactId: ArtifactId, nodeId: NodeId) { - "use server"; - return await action(artifactId, agentId, nodeId); - } - async function executeStepAction( flowId: FlowId, executionId: ExecutionId, @@ -197,6 +191,20 @@ export default async function Page({ ); } + async function executeNodeAction(executionId: ExecutionId, nodeId: NodeId) { + "use server"; + return await executeNode(agentId, executionId, nodeId); + } + + async function recordAgentUsageAction( + startedAt: number, + endedAt: number, + totalDurationMs: number, + ) { + "use server"; + return await recordAgentUsage(agentId, startedAt, endedAt, totalDurationMs); + } + return ( diff --git a/app/(playground)/p/[agentId]/types.ts b/app/(playground)/p/[agentId]/types.ts index 67f8db59..52606eb3 100644 --- a/app/(playground)/p/[agentId]/types.ts +++ b/app/(playground)/p/[agentId]/types.ts @@ -162,7 +162,7 @@ export interface TextArtifactObject extends ArtifactObjectBase { completionTokens: number; }; } -interface TextArtifact extends GeneratedArtifact { +export interface TextArtifact extends GeneratedArtifact { object: TextArtifactObject; } interface TextStreamArtifact extends StreamAtrifact { diff --git a/services/agents/activities/types.ts b/services/agents/activities/deprecated-agent-activity.ts similarity index 97% rename from services/agents/activities/types.ts rename to services/agents/activities/deprecated-agent-activity.ts index fb05612b..92275e43 100644 --- a/services/agents/activities/types.ts +++ b/services/agents/activities/deprecated-agent-activity.ts @@ -1,5 +1,8 @@ import type { AgentId } from "../types"; +/** + * @deprecated + */ export class AgentActivity { private actions: AgentActivityAction[] = []; public agentId: AgentId; diff --git a/services/agents/activities/index.ts b/services/agents/activities/index.ts index 8e538a9c..84a48a3f 100644 --- a/services/agents/activities/index.ts +++ b/services/agents/activities/index.ts @@ -1,5 +1,7 @@ export { calculateAgentTimeUsageMs } from "./agent-time-usage"; export { AGENT_TIME_CHARGE_LIMIT_MINUTES } from "./constants"; +export { AgentActivity } from "./deprecated-agent-activity"; export { hasEnoughAgentTimeCharge } from "./has-enough-agent-time-charge"; -export { AgentActivity } from "./types"; +export { recordAgentUsage } from "./record-agent-usage"; +export { saveAgentActivity } from "./save-agent-activity"; export { getMonthlyBillingCycle } from "./utils"; diff --git a/services/agents/activities/record-agent-usage.ts b/services/agents/activities/record-agent-usage.ts new file mode 100644 index 00000000..1fcc4246 --- /dev/null +++ b/services/agents/activities/record-agent-usage.ts @@ -0,0 +1,21 @@ +import { toUTCDate } from "@/lib/date"; +import { reportAgentTimeUsage } from "@/services/usage-based-billing/report-agent-time-usage"; +import type { AgentId } from "../types"; +import { saveAgentActivity } from "./save-agent-activity"; + +export async function recordAgentUsage( + agentId: AgentId, + startedAt: number, + endedAt: number, + totalDurationMs: number, +) { + const startedAtDateUTC = toUTCDate(new Date(startedAt)); + const endedAtDateUTC = toUTCDate(new Date(endedAt)); + await saveAgentActivity( + agentId, + startedAtDateUTC, + endedAtDateUTC, + totalDurationMs, + ); + await reportAgentTimeUsage(endedAtDateUTC); +} diff --git a/services/agents/activities/save-agent-activity.ts b/services/agents/activities/save-agent-activity.ts new file mode 100644 index 00000000..9c02a01d --- /dev/null +++ b/services/agents/activities/save-agent-activity.ts @@ -0,0 +1,26 @@ +import { agentActivities, agents, db } from "@/drizzle"; +import type { AgentId } from "@/services/agents"; +import { eq } from "drizzle-orm"; + +export async function saveAgentActivity( + agentId: AgentId, + startedAt: Date, + endedAt: Date, + totalDurationMs: number, +) { + const records = await db + .select({ agentDbId: agents.dbId }) + .from(agents) + .where(eq(agents.id, agentId)); + if (records.length === 0) { + throw new Error(`Agent with id ${agentId} not found`); + } + const agentDbId = records[0].agentDbId; + + await db.insert(agentActivities).values({ + agentDbId, + startedAt: startedAt, + endedAt: endedAt, + totalDurationMs: totalDurationMs.toString(), + }); +} diff --git a/services/external/stripe/config.ts b/services/external/stripe/config.ts index b58ffc4a..dfaa4bff 100644 --- a/services/external/stripe/config.ts +++ b/services/external/stripe/config.ts @@ -1,6 +1,21 @@ import { Stripe } from "stripe"; -export const stripe = new Stripe(process.env.STRIPE_SECRET_KEY as string, { - // https://github.com/stripe/stripe-node#configuration - apiVersion: "2024-11-20.acacia", -}); +let stripeInstance: Stripe | null = null; + +const handler: ProxyHandler = { + get: (_target, prop: keyof Stripe | symbol) => { + if (!stripeInstance) { + const key = process.env.STRIPE_SECRET_KEY; + if (!key) { + throw new Error("STRIPE_SECRET_KEY is not configured"); + } + stripeInstance = new Stripe(key, { + // https://github.com/stripe/stripe-node#configuration + apiVersion: "2024-11-20.acacia", + }); + } + return stripeInstance[prop as keyof Stripe]; + }, +}; + +export const stripe: Stripe = new Proxy(new Stripe("dummy"), handler); diff --git a/services/usage-based-billing/report-agent-time-usage.ts b/services/usage-based-billing/report-agent-time-usage.ts new file mode 100644 index 00000000..b9932f26 --- /dev/null +++ b/services/usage-based-billing/report-agent-time-usage.ts @@ -0,0 +1,22 @@ +import { db } from "@/drizzle"; +import { stripe } from "@/services/external/stripe"; +import { fetchCurrentTeam } from "@/services/teams"; +import { processUnreportedActivities } from "@/services/usage-based-billing"; +import { AgentTimeUsageDAO } from "@/services/usage-based-billing/agent-time-usage-dao"; + +export async function reportAgentTimeUsage(targetDate: Date) { + const currentTeam = await fetchCurrentTeam(); + if (currentTeam.activeSubscriptionId == null) { + return; + } + return processUnreportedActivities( + { + teamDbId: currentTeam.dbId, + targetDate: targetDate, + }, + { + dao: new AgentTimeUsageDAO(db), + stripe: stripe, + }, + ); +}