Skip to content

Commit

Permalink
feat(playground): wire up tool calling ui (#5029)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Oct 16, 2024
1 parent 5485c03 commit 75b7000
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 27 deletions.
4 changes: 2 additions & 2 deletions app/src/components/generative/ToolChoiceSelector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type ToolChoicePickerProps = {
/**
* The current choice including the default {@link ToolChoice} and any user defined tools
*/
choice: ToolChoice;
choice: ToolChoice | undefined;
/**
* Callback for when the tool choice changes
*/
Expand All @@ -52,7 +52,7 @@ export function ToolChoicePicker({
toolNames,
}: ToolChoicePickerProps) {
const currentKey =
typeof choice === "string"
choice == null || typeof choice === "string"
? choice
: addToolNamePrefix(choice.function.name);
return (
Expand Down
4 changes: 3 additions & 1 deletion app/src/pages/playground/PlaygroundChatTemplate.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ function SortableMessageItem({
}
extra={
<Flex direction="row" gap="size-100">
<CopyToClipboardButton text={message.content} />
{message.content != null && (
<CopyToClipboardButton text={message.content} />
)}
<Button
aria-label="Delete message"
icon={<Icon svg={<Icons.TrashOutline />} />}
Expand Down
77 changes: 70 additions & 7 deletions app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import React, { useMemo, useState } from "react";
import { useSubscription } from "react-relay";
import { graphql, GraphQLSubscriptionConfig } from "relay-runtime";
import { css } from "@emotion/react";

import { Card, Flex, Icon, Icons } from "@arizeai/components";

import { useCredentialsContext } from "@phoenix/contexts/CredentialsContext";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";
import { useChatMessageStyles } from "@phoenix/hooks/useChatMessageStyles";
import type { ToolCall } from "@phoenix/store";
import { ChatMessage, generateMessageId } from "@phoenix/store";
import { assertUnreachable } from "@phoenix/typeUtils";

Expand All @@ -24,11 +26,37 @@ import { PlaygroundInstanceProps } from "./types";
interface PlaygroundOutputProps extends PlaygroundInstanceProps {}

function PlaygroundOutputMessage({ message }: { message: ChatMessage }) {
const styles = useChatMessageStyles(message.role);
const { role, content, toolCalls } = message;
const styles = useChatMessageStyles(role);

return (
<Card title={message.role} {...styles} variant="compact">
{message.content}
<Card title={role} {...styles} variant="compact">
{content != null && (
<Flex direction="column" alignItems="start">
{content}
</Flex>
)}
{toolCalls && toolCalls.length > 0
? toolCalls.map((toolCall) => {
return (
<pre
key={toolCall.id}
css={css`
text-wrap: wrap;
margin: var(--ac-global-dimension-static-size-100) 0;
`}
>
{toolCall.function.name}(
{JSON.stringify(
JSON.parse(toolCall.function.arguments),
null,
2
)}
)
</pre>
);
})
: null}
</Card>
);
}
Expand Down Expand Up @@ -106,13 +134,15 @@ function useChatCompletionSubscription({
$messages: [ChatCompletionMessageInput!]!
$model: GenerativeModelInput!
$invocationParameters: InvocationParameters!
$tools: [JSON!]
$apiKey: String
) {
chatCompletion(
input: {
messages: $messages
model: $model
invocationParameters: $invocationParameters
tools: $tools
apiKey: $apiKey
}
) {
Expand Down Expand Up @@ -196,7 +226,8 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
throw new Error("We only support chat templates for now");
}

const [output, setOutput] = useState<string>("");
const [output, setOutput] = useState<string | undefined>(undefined);
const [toolCalls, setToolCalls] = useState<ToolCall[]>([]);

useChatCompletionSubscription({
params: {
Expand All @@ -206,23 +237,54 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
name: instance.model.modelName || "",
},
invocationParameters: {
temperature: 0.1, // TODO: add invocation parameters
toolChoice: instance.toolChoice,
},
tools: instance.tools.map((tool) => tool.definition),
apiKey: credentials[instance.model.provider],
},
runId: instance.activeRunId,
onNext: (response) => {
const chatCompletion = response.chatCompletion;
if (chatCompletion.__typename === "TextChunk") {
setOutput((acc) => acc + chatCompletion.content);
setOutput((acc) => (acc || "") + chatCompletion.content);
} else if (chatCompletion.__typename === "ToolCallChunk") {
setToolCalls((toolCalls) => {
let toolCallExists = false;
const updated = toolCalls.map((toolCall) => {
if (toolCall.id === chatCompletion.id) {
toolCallExists = true;
return {
...toolCall,
function: {
...toolCall.function,
arguments:
toolCall.function.arguments +
chatCompletion.function.arguments,
},
};
} else {
return toolCall;
}
});
if (!toolCallExists) {
updated.push({
id: chatCompletion.id,
function: {
name: chatCompletion.function.name,
arguments: chatCompletion.function.arguments,
},
});
}
return updated;
});
}
},
onCompleted: () => {
markPlaygroundInstanceComplete(props.playgroundInstanceId);
},
});

if (!output) {
if (!output && (toolCalls.length === 0 || instance.isRunning)) {
return (
<Flex direction="row" gap="size-100" alignItems="center">
<Icon svg={<Icons.LoadingOutline />} />
Expand All @@ -236,6 +298,7 @@ function PlaygroundOutputText(props: PlaygroundInstanceProps) {
id: generateMessageId(),
content: output,
role: "ai",
toolCalls: toolCalls,
}}
/>
);
Expand Down
1 change: 1 addition & 0 deletions app/src/pages/playground/PlaygroundTool.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ export function PlaygroundTool({
instanceId: playgroundInstanceId,
patch: {
tools: instanceTools.filter((t) => t.id !== tool.id),
toolChoice: undefined,
},
});
}}
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion app/src/pages/playground/__tests__/playgroundUtils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const expectedPlaygroundInstanceWithIO: PlaygroundInstance = {
},
input: { variableKeys: [], variablesValueCache: {} },
tools: [],
toolChoice: "auto",
toolChoice: undefined,
template: {
__type: "chat",
// These id's are not 0, 1, 2, because we create a playground instance (including messages) at the top of the transformSpanAttributesToPlaygroundInstance function
Expand Down
9 changes: 5 additions & 4 deletions app/src/store/playground/playgroundStore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export function createPlaygroundInstance(): PlaygroundInstance {
template: generateChatCompletionTemplate(),
model: { provider: DEFAULT_MODEL_PROVIDER, modelName: "gpt-4o" },
tools: [],
toolChoice: "auto",
toolChoice: undefined,
// TODO(apowell) - use datasetId if in dataset mode
input: { variablesValueCache: {}, variableKeys: [] },
output: undefined,
Expand Down Expand Up @@ -310,9 +310,10 @@ export const createPlaygroundStore = (
// for each chat message in the instance
instance.template.messages.forEach((message) => {
// extract variables from the message content
const extractedVariables = utils.extractVariables(
message.content
);
const extractedVariables =
message.content == null
? []
: utils.extractVariables(message.content);
extractedVariables.forEach((variable) => {
variables.add(variable);
});
Expand Down
43 changes: 40 additions & 3 deletions app/src/store/playground/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,51 @@ export type GenAIOperationType = "chat" | "text_completion";
*/
export type PlaygroundInputMode = "manual" | "dataset";

/**
* A tool call that invokes a function with JSON arguments
* @example
* ```typescript
* {
* id: "1",
* function: {
* name: "getCurrentWeather",
* arguments: "{ \"city\": \"San Francisco\" }"
* }
* }
* ```
*/
export type ToolCall = {
id: string;
function: {
name: string;
arguments: string;
};
};

/**
* A chat message with a role and content
* @example { role: "user", content: "What is the meaning of life?" }
* @example { role: "user", content: "What is the weather in San Francisco?" }
* @example
* ```typescript
* {
* "role": "assistant",
* "toolCalls": [
* {
* "id": "1",
* "function": {
* "name": "getCurrentWeather",
* "arguments": "{ \"city\": \"San Francisco\" }"
* }
* }
* ]
* }
* ```
*/
export type ChatMessage = {
id: number;
role: ChatMessageRole;
content: string;
content?: string;
toolCalls?: ToolCall[];
};

/**
Expand Down Expand Up @@ -82,7 +119,7 @@ export interface PlaygroundInstance {
id: number;
template: PlaygroundTemplate;
tools: Tool[];
toolChoice: ToolChoice;
toolChoice: ToolChoice | undefined;
input: PlaygroundInput;
model: ModelConfig;
output: ChatMessage[] | undefined | string;
Expand Down
4 changes: 3 additions & 1 deletion src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from collections import defaultdict
from dataclasses import fields
from datetime import datetime
from itertools import chain
from typing import (
Expand Down Expand Up @@ -280,8 +281,9 @@ def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:


def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput))
yield INPUT_MIME_TYPE, JSON
yield INPUT_VALUE, safe_json_dumps(jsonify(input))
yield INPUT_VALUE, safe_json_dumps({k: v for k, v in jsonify(input).items() if k != api_key})


def _output_value_and_mime_type(output: Any) -> Iterator[Tuple[str, Any]]:
Expand Down

0 comments on commit 75b7000

Please sign in to comment.