Skip to content

Commit

Permalink
feat: Add cancel button to playground runs (#5566)
Browse files Browse the repository at this point in the history
* feat: Add cancel button to playground runs

Can cancel runs with and without datasets. Cancellation is purely client-side.
The backend should abort its tasks when connection with the client is lost.

* Abort network requests when gql query/mutation is disposed

* Reset loading state when canceling run

* Rethrow AbortError inside of authFetch
  • Loading branch information
cephalization authored Dec 2, 2024
1 parent 3330bff commit 6bafa46
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 40 deletions.
111 changes: 78 additions & 33 deletions app/src/RelayEnvironment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,50 +15,95 @@ import {
import { authFetch } from "@phoenix/authFetch";
import { BASE_URL, WS_BASE_URL } from "@phoenix/config";

import { isObject } from "./typeUtils";

const graphQLPath = BASE_URL + "/graphql";

const graphQLFetch = window.Config.authenticationEnabled ? authFetch : fetch;

/**
* Relay requires developers to configure a "fetch" function that tells Relay how to load
* the results of GraphQL queries from your server (or other data source). See more at
* https://relay.dev/docs/en/quick-start-guide#relay-environment.
* Create an observable that fetches JSON from the given input and returns an error if
* the data has errors.
*
* The observable aborts in-flight network requests when the unsubscribe function is
* called.
*
* @param input - The input to fetch from.
* @param init - The request init options.
* @param hasErrors - A function that returns an error if the data has errors.
* @returns An observable that emits the data or an error.
*/
const fetchRelay: FetchFunction = async (params, variables, _cacheConfig) => {
const response = await graphQLFetch(graphQLPath, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
query: params.text,
variables,
}),
});

// Get the response as JSON
const json = await response.json();
function fetchJsonObservable<T>(
input: RequestInfo | URL,
init?: RequestInit,
hasErrors?: (data: unknown) => Error | undefined
): Observable<T> {
return Observable.create((sink) => {
const controller = new AbortController();

// GraphQL returns exceptions (for example, a missing required variable) in the "errors"
// property of the response. If any exceptions occurred when processing the request,
// throw an error to indicate to the developer what went wrong.
if (Array.isArray(json.errors)) {
throw new Error(
`Error fetching GraphQL query '${
params.name
}' with variables '${JSON.stringify(variables)}': ${JSON.stringify(
json.errors
)}`
);
}
graphQLFetch(input, { ...init, signal: controller.signal })
.then((response) => response.json())
.then((data) => {
const error = hasErrors?.(data);
if (error) {
throw error;
}
sink.next(data as T);
sink.complete();
})
.catch((error) => {
if (error.name === "AbortError") {
// this is triggered when the controller is aborted
sink.complete();
} else {
// this is triggered when graphQLFetch throws an error or the response
// data has errors
sink.error(error);
}
});

// Otherwise, return the full payload.
return json;
};
return () => {
// abort the fetch request when the observable is unsubscribed
controller.abort();
};
});
}

/**
* Check whether or not we are running
* Relay requires developers to configure a "fetch" function that tells Relay how to load
* the results of GraphQL queries from your server (or other data source). See more at
* https://relay.dev/docs/en/quick-start-guide#relay-environment.
*/
const fetchRelay: FetchFunction = (params, variables, _cacheConfig) =>
fetchJsonObservable(
graphQLPath,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
query: params.text,
variables,
}),
},
// GraphQL returns exceptions (for example, a missing required variable) in the "errors"
// property of the response. If any exceptions occurred when processing the request,
// throw an error to indicate to the developer what went wrong.
(data) => {
if (!isObject(data) || !("errors" in data)) {
return;
}
if (Array.isArray(data.errors)) {
return new Error(
`Error fetching GraphQL query '${params.name}' with variables '${JSON.stringify(
variables
)}': ${JSON.stringify(data.errors)}`
);
}
}
);

const wsClient = createClient({
url: `${WS_BASE_URL}/graphql`,
});
Expand Down
4 changes: 4 additions & 0 deletions app/src/authFetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ export async function authFetch(
// Retry the original request
return fetch(input, init);
}
if (error instanceof Error && error.name === "AbortError") {
// This is triggered when the controller is aborted
throw error;
}
}
throw new Error("An unexpected error occurred while fetching data");
}
Expand Down
9 changes: 8 additions & 1 deletion app/src/pages/playground/PlaygroundDatasetExamplesTable.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ export function PlaygroundDatasetExamplesTable({
}
};
} else {
const disposables: Disposable[] = [];
for (const instance of instances) {
const { activeRunId } = instance;
if (activeRunId === null) {
Expand All @@ -591,7 +592,7 @@ export function PlaygroundDatasetExamplesTable({
datasetId,
}),
};
generateChatCompletion({
const disposable = generateChatCompletion({
variables,
onCompleted: onCompleted(instance.id),
onError(error) {
Expand All @@ -612,7 +613,13 @@ export function PlaygroundDatasetExamplesTable({
}
},
});
disposables.push(disposable);
}
return () => {
for (const disposable of disposables) {
disposable.dispose();
}
};
}
}, [
credentials,
Expand Down
6 changes: 5 additions & 1 deletion app/src/pages/playground/PlaygroundOutput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ export function PlaygroundOutput(props: PlaygroundOutputProps) {

useEffect(() => {
if (!hasRunId) {
setLoading(false);
return;
}
setLoading(true);
Expand Down Expand Up @@ -378,7 +379,8 @@ export function PlaygroundOutput(props: PlaygroundOutputProps) {
const subscription = requestSubscription(environment, config);
return subscription.dispose;
}
generateChatCompletion({

const disposable = generateChatCompletion({
variables: {
input,
},
Expand All @@ -400,6 +402,8 @@ export function PlaygroundOutput(props: PlaygroundOutputProps) {
}
},
});

return disposable.dispose;
}, [
cleanup,
credentials,
Expand Down
34 changes: 29 additions & 5 deletions app/src/pages/playground/PlaygroundRunButton.tsx
Original file line number Diff line number Diff line change
@@ -1,28 +1,52 @@
import React from "react";
import { css } from "@emotion/react";

import { Button, Icon, Icons } from "@arizeai/components";

import { Loading } from "@phoenix/components";
import { usePlaygroundContext } from "@phoenix/contexts/PlaygroundContext";

export function PlaygroundRunButton() {
const runPlaygroundInstances = usePlaygroundContext(
(state) => state.runPlaygroundInstances
);
const cancelPlaygroundInstances = usePlaygroundContext(
(state) => state.cancelPlaygroundInstances
);
const isRunning = usePlaygroundContext((state) =>
state.instances.some((instance) => instance.activeRunId != null)
);
return (
<Button
variant="primary"
disabled={isRunning}
icon={<Icon svg={<Icons.PlayCircleOutline />} />}
loading={isRunning}
icon={
!isRunning ? (
<Icon svg={<Icons.PlayCircleOutline />} />
) : (
<div
css={css`
margin-right: var(--ac-global-dimension-static-size-50);
& > * {
height: 1em;
width: 1em;
font-size: 1.3rem;
}
`}
>
<Loading size="S" />
</div>
)
}
size="compact"
onClick={() => {
runPlaygroundInstances();
if (isRunning) {
cancelPlaygroundInstances();
} else {
runPlaygroundInstances();
}
}}
>
{isRunning ? "Running..." : "Run"}
{isRunning ? "Cancel" : "Run"}
</Button>
);
}
9 changes: 9 additions & 0 deletions app/src/store/playground/playgroundStore.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,15 @@ export const createPlaygroundStore = (initialProps: InitialPlaygroundState) => {
})),
});
},
cancelPlaygroundInstances: () => {
set({
instances: get().instances.map((instance) => ({
...instance,
activeRunId: null,
spanId: null,
})),
});
},
markPlaygroundInstanceComplete: (instanceId: number) => {
const instances = get().instances;
set({
Expand Down
4 changes: 4 additions & 0 deletions app/src/store/playground/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ export interface PlaygroundState extends PlaygroundProps {
* Run all the active playground Instances
*/
runPlaygroundInstances: () => void;
/**
* Cancel all the active playground Instances
*/
cancelPlaygroundInstances: () => void;
/**
* Mark a given playground instance as completed
*/
Expand Down
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ pass_env=
PHOENIX_HOST_ROOT_PATH
PHOENIX_SQL_DATABASE_URL
PHOENIX_SQL_DATABASE_SCHEMA
PHOENIX_ENABLE_AUTH
PHOENIX_SECRET
commands_pre =
uv tool install arize-phoenix@. \
--reinstall-package arize-phoenix \
Expand Down

0 comments on commit 6bafa46

Please sign in to comment.