From ac8b51f1f2513daed4e01bf8619af230a02b95d0 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Thu, 10 Oct 2024 02:21:17 +0200 Subject: [PATCH 1/4] fix(debug): add failing test for self-referencing --- libs/langgraph/tests/test_pregel.py | 68 +++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index e75df2245..7ee8dc926 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -11584,6 +11584,66 @@ def baz(state: State): assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"} +def test_debug_retry(): + class State(TypedDict): + messages: Annotated[list[str], operator.add] + + def node(name): + def _node(state: State): + return {"messages": [f"entered {name} node"]} + + return _node + + builder = StateGraph(State) + builder.add_node("one", node("one")) + builder.add_node("two", node("two")) + builder.add_edge(START, "one") + builder.add_edge("one", "two") + builder.add_edge("two", END) + + saver = MemorySaver() + + graph = builder.compile(checkpointer=saver) + + config = {"configurable": {"thread_id": "1"}} + graph.invoke({"messages": []}, config=config) + + # re-run step: 1 + target_config = next( + c.parent_config for c in saver.list(config) if c.metadata["step"] == 1 + ) + update_config = graph.update_state(target_config, values=None) + + events = [*graph.stream(None, config=update_config, stream_mode="debug")] + + checkpoint_events = list( + reversed([e["payload"] for e in events if e["type"] == "checkpoint"]) + ) + + checkpoint_history = { + c.config["configurable"]["checkpoint_id"]: c + for c in graph.get_state_history(config) + } + + def lax_normalize_config(config: Optional[dict]) -> Optional[dict]: + if config is None: + return None + return config["configurable"] + + for stream in checkpoint_events: + stream_conf = lax_normalize_config(stream["config"]) + stream_parent_conf = lax_normalize_config(stream["parent_config"]) + assert stream_conf != stream_parent_conf + + # ensure the streamed checkpoint == checkpoint from checkpointer.list() + history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]] + history_conf = lax_normalize_config(history.config) + assert stream_conf == history_conf + + history_parent_conf = lax_normalize_config(history.parent_config) + assert stream_parent_conf == history_parent_conf + + def test_debug_subgraphs(): class State(TypedDict): messages: Annotated[list[str], operator.add] @@ -11627,7 +11687,7 @@ def _node(state: State): assert len(checkpoint_events) == len(checkpoint_history) - def normalize_config(config: Optional[dict]) -> Optional[dict]: + def lax_normalize_config(config: Optional[dict]) -> Optional[dict]: if config is None: return None return config["configurable"] @@ -11635,8 +11695,10 @@ def normalize_config(config: Optional[dict]) -> Optional[dict]: for stream, history in zip(checkpoint_events, checkpoint_history): assert stream["values"] == history.values assert stream["next"] == list(history.next) - assert normalize_config(stream["config"]) == normalize_config(history.config) - assert normalize_config(stream["parent_config"]) == normalize_config( + assert lax_normalize_config(stream["config"]) == lax_normalize_config( + history.config + ) + assert lax_normalize_config(stream["parent_config"]) == lax_normalize_config( history.parent_config ) From fb8c3869588c8dae270211427feee5d62f1dd47b Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Thu, 10 Oct 2024 02:21:32 +0200 Subject: [PATCH 2/4] fix(debug): address self-referencing --- libs/langgraph/langgraph/pregel/loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 362756f32..83ce36379 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -740,6 +740,7 @@ def __enter__(self) -> Self: **saved.config.get(CONF, {}), }, } + self.prev_checkpoint_config = saved.parent_config self.checkpoint = saved.checkpoint self.checkpoint_metadata = saved.metadata self.checkpoint_pending_writes = ( @@ -867,6 +868,7 @@ async def __aenter__(self) -> Self: **saved.config.get(CONF, {}), }, } + self.prev_checkpoint_config = saved.parent_config self.checkpoint = saved.checkpoint self.checkpoint_metadata = saved.metadata self.checkpoint_pending_writes = ( From 4d69331a52b63b329bb3705c5c677b5f0335d4ac Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Thu, 10 Oct 2024 02:29:28 +0200 Subject: [PATCH 3/4] Add async tests --- libs/langgraph/tests/test_pregel_async.py | 65 +++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 3be52643b..5a25a3f01 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -9820,6 +9820,71 @@ async def node(input: State, config: RunnableConfig, store: BaseStore): ) # still overwriting the same one +async def test_debug_retry(): + class State(TypedDict): + messages: Annotated[list[str], operator.add] + + def node(name): + async def _node(state: State): + return {"messages": [f"entered {name} node"]} + + return _node + + builder = StateGraph(State) + builder.add_node("one", node("one")) + builder.add_node("two", node("two")) + builder.add_edge(START, "one") + builder.add_edge("one", "two") + builder.add_edge("two", END) + + saver = MemorySaver() + + graph = builder.compile(checkpointer=saver) + + config = {"configurable": {"thread_id": "1"}} + await graph.ainvoke({"messages": []}, config=config) + + # re-run step: 1 + async for c in saver.alist(config): + if c.metadata["step"] == 1: + target_config = c.parent_config + break + assert target_config is not None + + update_config = await graph.aupdate_state(target_config, values=None) + + events = [ + c async for c in graph.astream(None, config=update_config, stream_mode="debug") + ] + + checkpoint_events = list( + reversed([e["payload"] for e in events if e["type"] == "checkpoint"]) + ) + + checkpoint_history = { + c.config["configurable"]["checkpoint_id"]: c + async for c in graph.aget_state_history(config) + } + + def lax_normalize_config(config: Optional[dict]) -> Optional[dict]: + if config is None: + return None + return config["configurable"] + + for stream in checkpoint_events: + stream_conf = lax_normalize_config(stream["config"]) + stream_parent_conf = lax_normalize_config(stream["parent_config"]) + assert stream_conf != stream_parent_conf + + # ensure the streamed checkpoint == checkpoint from checkpointer.list() + history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]] + history_conf = lax_normalize_config(history.config) + assert stream_conf == history_conf + + history_parent_conf = lax_normalize_config(history.parent_config) + assert stream_parent_conf == history_parent_conf + + async def test_debug_subgraphs(): class State(TypedDict): messages: Annotated[list[str], operator.add] From be47752f0ece96eba6e46b129ab42773675719c5 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Thu, 10 Oct 2024 11:15:54 +0200 Subject: [PATCH 4/4] Initialise to None --- libs/langgraph/langgraph/pregel/loop.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 83ce36379..2ac371f2a 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -250,13 +250,7 @@ def __init__( if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS) else () ) - self.prev_checkpoint_config = ( - self.checkpoint_config - if self.checkpoint_config - and CONF in self.checkpoint_config - and CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF] - else None - ) + self.prev_checkpoint_config = None def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None: """Put writes for a task, to be read by the next tick."""