Skip to content

Commit

Permalink
Merge pull request #36 from route06inc/infer-flow
Browse files Browse the repository at this point in the history
Indicate the Final Node
  • Loading branch information
shige authored Oct 23, 2024
2 parents 8d02886 + f8a324b commit eb35581
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ export function GiselleNode(props: GiselleNodeProps) {
)}
/>
{props.object === "node" && (
<div className="absolute text-black-30 font-rosart text-[12px] -translate-y-full left-[8px] -top-[2px]">
<div className="absolute text-black-30 font-rosart text-[12px] -translate-y-full left-[8px] -top-[2px] flex items-center gap-[12px]">
{props.isFinal && <span>Final</span>}
{props.name}
</div>
)}
Expand Down Expand Up @@ -214,6 +215,7 @@ export function GiselleNode(props: GiselleNodeProps) {
<div>outgoing: {props.outgoingConnections?.length ?? 0}</div>
<div>property: {JSON.stringify(props.properties, null, 2)}</div>
<div>ui: {JSON.stringify(props.ui, null, 2)}</div>
<div>isFinal: {JSON.stringify(props.isFinal)}</div>
</div>
</div>
)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ export type GiselleNode = {
resultPortLabel: string;
properties: Record<string, unknown>;
output: unknown;
isFinal: boolean;
};

export type GiselleNodeArtifactElement = {
Expand Down
108 changes: 98 additions & 10 deletions app/(playground)/p/[agentId]/beta-proto/graph/actions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import {
parseFile,
uploadFile,
} from "./server-actions";
import { type V2NodeAction, updateNode } from "./v2/node";

export type AddNodeAction = {
type: "addNode";
Expand All @@ -68,6 +69,7 @@ type AddNodeArgs = {
node: GiselleNodeBlueprint;
position: XYPosition;
name: string;
isFinal?: boolean;
properties?: Record<string, unknown>;
};

Expand All @@ -92,6 +94,7 @@ export const addNode = (args: AddNodeArgs): AddNodeAction => {
ui: { position: args.position },
properties: args.properties ?? {},
state: giselleNodeState.idle,
isFinal: args.isFinal ?? false,
output: "",
},
},
Expand Down Expand Up @@ -168,6 +171,8 @@ export const addNodesAndConnect = (
args: AddNodesAndConnectArgs,
): ThunkAction => {
return (dispatch, getState) => {
const state = getState();
const hasFinalNode = state.graph.nodes.some((node) => node.isFinal);
const currentNodes = getState().graph.nodes;
const addSourceNode = addNode({
...args.sourceNode,
Expand All @@ -176,6 +181,7 @@ export const addNodesAndConnect = (
dispatch(addSourceNode);
const addTargetNode = addNode({
...args.targetNode,
isFinal: !hasFinalNode,
name: `Untitled node - ${currentNodes.length + 2}`,
});
dispatch(addTargetNode);
Expand Down Expand Up @@ -810,18 +816,20 @@ export function addSourceToPromptNode(
): ThunkAction {
return async (dispatch, getState) => {
const state = getState();
const updateNode = state.graph.nodes.find(
const targetPromptNode = state.graph.nodes.find(
(node) => node.id === args.promptNode.id,
);
if (updateNode === undefined) {
if (targetPromptNode === undefined) {
return;
}
if (updateNode.archetype !== giselleNodeArchetypes.prompt) {
if (targetPromptNode.archetype !== giselleNodeArchetypes.prompt) {
return;
}
const currentSources = updateNode.properties.sources ?? [];
const currentSources = targetPromptNode.properties.sources ?? [];
if (!Array.isArray(currentSources)) {
throw new Error(`${updateNode.id}'s sources property is not an array`);
throw new Error(
`${targetPromptNode.id}'s sources property is not an array`,
);
}
dispatch(
updateNodeProperty({
Expand Down Expand Up @@ -885,6 +893,28 @@ export function addSourceToPromptNode(
},
}),
);

const artifactGeneratorNode = state.graph.nodes.find(
(node) => node.id === artifact?.generatorNode.id,
);
if (artifactGeneratorNode?.isFinal) {
dispatch(
updateNode({
input: {
id: artifact.generatorNode.id,
isFinal: false,
},
}),
);
dispatch(
updateNode({
input: {
id: outgoingNode.id,
isFinal: true,
},
}),
);
}
}
} else if (args.source.object === "file") {
if (args.source.status === fileStatuses.uploading) {
Expand Down Expand Up @@ -986,6 +1016,27 @@ export function addSourceToPromptNode(
},
}),
);
const webSearchGeneratorNode = state.graph.nodes.find(
(node) => node.id === webSearch.generatorNode.id,
);
if (webSearchGeneratorNode?.isFinal) {
dispatch(
updateNode({
input: {
id: webSearchGeneratorNode.id,
isFinal: false,
},
}),
);
dispatch(
updateNode({
input: {
id: outgoingNode.id,
isFinal: true,
},
}),
);
}
}
}
};
Expand Down Expand Up @@ -1077,12 +1128,30 @@ export function removeSourceFromPromptNode(
},
}),
);
if (outgoingNode?.isFinal) {
dispatch(
updateNode({
input: {
id: outgoingNode.id,
isFinal: false,
},
}),
);
dispatch(
updateNode({
input: {
id: artifact.generatorNode.id,
isFinal: true,
},
}),
);
}
}
} else if (args.source.object === "webSearch") {
const artifact = state.graph.artifacts.find(
(artifact) => artifact.id === args.source.id,
const webSearch = state.graph.webSearches.find(
(webSearch) => webSearch.id === args.source.id,
);
if (artifact === undefined) {
if (webSearch === undefined) {
return;
}
const outgoingConnectors = state.graph.connectors.filter(
Expand All @@ -1099,7 +1168,7 @@ export function removeSourceFromPromptNode(
state.graph.connectors.find(
(connector) =>
connector.target === outgoingNode.id &&
connector.source === artifact.generatorNode.id,
connector.source === webSearch.generatorNode.id,
);
if (artifactCreatorNodeToOutgoingNodeConnector === undefined) {
continue;
Expand All @@ -1121,6 +1190,24 @@ export function removeSourceFromPromptNode(
},
}),
);
if (outgoingNode?.isFinal) {
dispatch(
updateNode({
input: {
id: outgoingNode.id,
isFinal: false,
},
}),
);
dispatch(
updateNode({
input: {
id: webSearch.generatorNode.id,
isFinal: true,
},
}),
);
}
}
}
};
Expand Down Expand Up @@ -1253,4 +1340,5 @@ export type GraphAction =
| RemoveArtifactAction
| AddParameterToNodeAction
| RemoveParameterFromNodeAction
| UpsertWebSearchAction;
| UpsertWebSearchAction
| V2NodeAction;
10 changes: 10 additions & 0 deletions app/(playground)/p/[agentId]/beta-proto/graph/reducer.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import type { GraphAction } from "./actions";
import type { GraphState } from "./types";
import { isV2NodeAction, v2NodeReducer } from "./v2/node";

export const graphReducer = (
state: GraphState,
action: GraphAction,
): GraphState => {
if (isV2NodeAction(action)) {
return {
...state,
graph: {
...state.graph,
nodes: v2NodeReducer(state.graph.nodes, action),
},
};
}
switch (action.type) {
case "addNode":
return {
Expand Down
49 changes: 49 additions & 0 deletions app/(playground)/p/[agentId]/beta-proto/graph/v2/node.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import type { GiselleNode, GiselleNodeId } from "../../giselle-node/types";

const v2NodeActionTypes = {
updateNode: "v2.updateNode",
} as const;
type V2NodeActionType =
(typeof v2NodeActionTypes)[keyof typeof v2NodeActionTypes];
export function isV2NodeAction(action: unknown): action is V2NodeAction {
return Object.values(v2NodeActionTypes).includes(
(action as V2NodeAction).type,
);
}
interface UpdateNodeAction {
type: Extract<V2NodeActionType, "v2.updateNode">;
input: UpdateNodeInput;
}
interface UpdateNodeInput {
id: GiselleNodeId;
isFinal?: boolean;
}
export function updateNode({
input,
}: { input: UpdateNodeInput }): UpdateNodeAction {
return {
type: v2NodeActionTypes.updateNode,
input,
};
}

export type V2NodeAction = UpdateNodeAction;

export function v2NodeReducer(
nodes: GiselleNode[],
action: V2NodeAction,
): GiselleNode[] {
switch (action.type) {
case v2NodeActionTypes.updateNode:
return nodes.map((node) => {
if (node.id === action.input.id) {
return {
...node,
isFinal: action.input.isFinal ?? false,
};
}
return node;
});
}
return nodes;
}
56 changes: 56 additions & 0 deletions scripts/20241022-db-patch-node-is-final.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { giselleNodeCategories } from "@/app/(playground)/p/[agentId]/beta-proto/giselle-node/types";
import { agents, db } from "@/drizzle";
import { eq } from "drizzle-orm";

function chunkArray<T>(array: T[], chunkSize: number): T[][] {
const chunks: T[][] = [];
for (let i = 0; i < array.length; i += chunkSize) {
chunks.push(array.slice(i, i + chunkSize));
}
return chunks;
}

console.log("Add 'isFinal' to the all of the nodes...");
const listOfAgents = await db.select().from(agents);
console.log(`Updating ${listOfAgents.length} agents...`);

const agentChunks = chunkArray(listOfAgents, 10);

for (let i = 0; i < agentChunks.length; i++) {
const chunk = agentChunks[i];
console.log(
` ├ Processing chunk ${i + 1}/${agentChunks.length} (${chunk.length} agents)`,
);

await Promise.all(
chunk.map(async (agent) => {
const instructionNodes = agent.graphv2.nodes.filter(
(node) => node.category === giselleNodeCategories.instruction,
);
const actionNodes = agent.graphv2.nodes.filter(
(node) => node.category === giselleNodeCategories.action,
);
await db
.update(agents)
.set({
graphv2: {
...agent.graphv2,
nodes: [
...instructionNodes.map((node) => ({
...node,
isFinal: false,
})),
...actionNodes.map((node, index) => ({
...node,
isFinal: index === actionNodes.length - 1,
})),
],
},
})
.where(eq(agents.id, agent.id));
}),
);

console.log(` │ └ Completed chunk ${i + 1}/${i + 1}`);
}
console.log("All agents have been updated!");

0 comments on commit eb35581

Please sign in to comment.