From 4d69331a52b63b329bb3705c5c677b5f0335d4ac Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Thu, 10 Oct 2024 02:29:28 +0200 Subject: [PATCH] Add async tests --- libs/langgraph/tests/test_pregel_async.py | 65 +++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 3be52643b..5a25a3f01 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -9820,6 +9820,71 @@ async def node(input: State, config: RunnableConfig, store: BaseStore): ) # still overwriting the same one +async def test_debug_retry(): + class State(TypedDict): + messages: Annotated[list[str], operator.add] + + def node(name): + async def _node(state: State): + return {"messages": [f"entered {name} node"]} + + return _node + + builder = StateGraph(State) + builder.add_node("one", node("one")) + builder.add_node("two", node("two")) + builder.add_edge(START, "one") + builder.add_edge("one", "two") + builder.add_edge("two", END) + + saver = MemorySaver() + + graph = builder.compile(checkpointer=saver) + + config = {"configurable": {"thread_id": "1"}} + await graph.ainvoke({"messages": []}, config=config) + + # re-run step: 1 + async for c in saver.alist(config): + if c.metadata["step"] == 1: + target_config = c.parent_config + break + assert target_config is not None + + update_config = await graph.aupdate_state(target_config, values=None) + + events = [ + c async for c in graph.astream(None, config=update_config, stream_mode="debug") + ] + + checkpoint_events = list( + reversed([e["payload"] for e in events if e["type"] == "checkpoint"]) + ) + + checkpoint_history = { + c.config["configurable"]["checkpoint_id"]: c + async for c in graph.aget_state_history(config) + } + + def lax_normalize_config(config: Optional[dict]) -> Optional[dict]: + if config is None: + return None + return config["configurable"] + + for stream in checkpoint_events: + stream_conf = lax_normalize_config(stream["config"]) + stream_parent_conf = lax_normalize_config(stream["parent_config"]) + assert stream_conf != stream_parent_conf + + # ensure the streamed checkpoint == checkpoint from checkpointer.list() + history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]] + history_conf = lax_normalize_config(history.config) + assert stream_conf == history_conf + + history_parent_conf = lax_normalize_config(history.parent_config) + assert stream_parent_conf == history_parent_conf + + async def test_debug_subgraphs(): class State(TypedDict): messages: Annotated[list[str], operator.add]