Skip to content

Commit

Permalink
Merge pull request #2070 from langchain-ai/dqbd/debug-self-referencin…
Browse files Browse the repository at this point in the history
…g-checkpoint

fix(debug): self-referencing checkpoints when resuming streaming mid-thread
  • Loading branch information
dqbd authored Oct 10, 2024
2 parents aa83f4a + be47752 commit 28b5105
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 10 deletions.
10 changes: 3 additions & 7 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,7 @@ def __init__(
if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS)
else ()
)
self.prev_checkpoint_config = (
self.checkpoint_config
if self.checkpoint_config
and CONF in self.checkpoint_config
and CONFIG_KEY_CHECKPOINT_ID in self.checkpoint_config[CONF]
else None
)
self.prev_checkpoint_config = None

def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None:
"""Put writes for a task, to be read by the next tick."""
Expand Down Expand Up @@ -740,6 +734,7 @@ def __enter__(self) -> Self:
**saved.config.get(CONF, {}),
},
}
self.prev_checkpoint_config = saved.parent_config
self.checkpoint = saved.checkpoint
self.checkpoint_metadata = saved.metadata
self.checkpoint_pending_writes = (
Expand Down Expand Up @@ -867,6 +862,7 @@ async def __aenter__(self) -> Self:
**saved.config.get(CONF, {}),
},
}
self.prev_checkpoint_config = saved.parent_config
self.checkpoint = saved.checkpoint
self.checkpoint_metadata = saved.metadata
self.checkpoint_pending_writes = (
Expand Down
68 changes: 65 additions & 3 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11584,6 +11584,66 @@ def baz(state: State):
assert graph.invoke({"foo": "hello"}) == {"foo": "hello", "bar": "hello!"}


def test_debug_retry():
    """Check debug-stream checkpoints when a thread is resumed mid-way.

    A two-node graph is run once, then re-run from the parent of the
    step-1 checkpoint. Every ``checkpoint`` event streamed in debug mode
    must (a) not list itself as its own parent and (b) carry the same
    config / parent config as the matching entry in the state history.
    """

    class State(TypedDict):
        messages: Annotated[list[str], operator.add]

    def node(name):
        def _node(state: State):
            return {"messages": [f"entered {name} node"]}

        return _node

    # Linear graph: START -> one -> two -> END
    builder = StateGraph(State)
    for node_name in ("one", "two"):
        builder.add_node(node_name, node(node_name))
    builder.add_edge(START, "one")
    builder.add_edge("one", "two")
    builder.add_edge("two", END)

    saver = MemorySaver()
    graph = builder.compile(checkpointer=saver)

    config = {"configurable": {"thread_id": "1"}}
    graph.invoke({"messages": []}, config=config)

    # re-run step: 1 -- resume from the parent of the step-1 checkpoint
    target_config = next(
        saved.parent_config
        for saved in saver.list(config)
        if saved.metadata["step"] == 1
    )
    update_config = graph.update_state(target_config, values=None)

    events = list(graph.stream(None, config=update_config, stream_mode="debug"))

    # Payloads of the streamed checkpoint events, most recent first.
    checkpoint_events = [
        event["payload"] for event in events if event["type"] == "checkpoint"
    ][::-1]

    # Index the persisted history by checkpoint id for direct lookup.
    checkpoint_history = {
        snapshot.config["configurable"]["checkpoint_id"]: snapshot
        for snapshot in graph.get_state_history(config)
    }

    def lax_normalize_config(config: Optional[dict]) -> Optional[dict]:
        # Compare only the "configurable" portion of a config.
        return None if config is None else config["configurable"]

    for payload in checkpoint_events:
        streamed = lax_normalize_config(payload["config"])
        streamed_parent = lax_normalize_config(payload["parent_config"])
        # A checkpoint must never reference itself as its parent.
        assert streamed != streamed_parent

        # ensure the streamed checkpoint == checkpoint from checkpointer.list()
        checkpoint_id = payload["config"]["configurable"]["checkpoint_id"]
        snapshot = checkpoint_history[checkpoint_id]
        assert streamed == lax_normalize_config(snapshot.config)
        assert streamed_parent == lax_normalize_config(snapshot.parent_config)


def test_debug_subgraphs():
class State(TypedDict):
messages: Annotated[list[str], operator.add]
Expand Down Expand Up @@ -11627,16 +11687,18 @@ def _node(state: State):

assert len(checkpoint_events) == len(checkpoint_history)

def normalize_config(config: Optional[dict]) -> Optional[dict]:
def lax_normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
return config["configurable"]

for stream, history in zip(checkpoint_events, checkpoint_history):
assert stream["values"] == history.values
assert stream["next"] == list(history.next)
assert normalize_config(stream["config"]) == normalize_config(history.config)
assert normalize_config(stream["parent_config"]) == normalize_config(
assert lax_normalize_config(stream["config"]) == lax_normalize_config(
history.config
)
assert lax_normalize_config(stream["parent_config"]) == lax_normalize_config(
history.parent_config
)

Expand Down
65 changes: 65 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9820,6 +9820,71 @@ async def node(input: State, config: RunnableConfig, store: BaseStore):
) # still overwriting the same one


async def test_debug_retry():
    """Async check of debug-stream checkpoints when resuming mid-thread.

    A two-node graph is run once, then re-run from the parent of the
    step-1 checkpoint. Every ``checkpoint`` event streamed in debug mode
    must (a) not list itself as its own parent and (b) carry the same
    config / parent config as the matching entry in the state history.
    """

    class State(TypedDict):
        messages: Annotated[list[str], operator.add]

    def node(name):
        async def _node(state: State):
            return {"messages": [f"entered {name} node"]}

        return _node

    # Linear graph: START -> one -> two -> END
    builder = StateGraph(State)
    builder.add_node("one", node("one"))
    builder.add_node("two", node("two"))
    builder.add_edge(START, "one")
    builder.add_edge("one", "two")
    builder.add_edge("two", END)

    saver = MemorySaver()

    graph = builder.compile(checkpointer=saver)

    config = {"configurable": {"thread_id": "1"}}
    await graph.ainvoke({"messages": []}, config=config)

    # re-run step: 1
    # Pre-bind the name so that a missing step-1 checkpoint fails the
    # assertion below instead of raising an unbound-variable error.
    target_config: Optional[dict] = None
    async for c in saver.alist(config):
        if c.metadata["step"] == 1:
            target_config = c.parent_config
            break
    assert target_config is not None

    update_config = await graph.aupdate_state(target_config, values=None)

    events = [
        c async for c in graph.astream(None, config=update_config, stream_mode="debug")
    ]

    # Payloads of the streamed checkpoint events, most recent first.
    checkpoint_events = list(
        reversed([e["payload"] for e in events if e["type"] == "checkpoint"])
    )

    # Index the persisted history by checkpoint id for direct lookup.
    checkpoint_history = {
        c.config["configurable"]["checkpoint_id"]: c
        async for c in graph.aget_state_history(config)
    }

    def lax_normalize_config(config: Optional[dict]) -> Optional[dict]:
        # Compare only the "configurable" portion of a config.
        if config is None:
            return None
        return config["configurable"]

    for stream in checkpoint_events:
        stream_conf = lax_normalize_config(stream["config"])
        stream_parent_conf = lax_normalize_config(stream["parent_config"])
        # A checkpoint must never reference itself as its parent.
        assert stream_conf != stream_parent_conf

        # ensure the streamed checkpoint == checkpoint from checkpointer.list()
        history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]]
        history_conf = lax_normalize_config(history.config)
        assert stream_conf == history_conf

        history_parent_conf = lax_normalize_config(history.parent_config)
        assert stream_parent_conf == history_parent_conf


async def test_debug_subgraphs():
class State(TypedDict):
messages: Annotated[list[str], operator.add]
Expand Down

0 comments on commit 28b5105

Please sign in to comment.