Skip to content

Commit

Permalink
Add async tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dqbd committed Oct 10, 2024
1 parent fb8c386 commit 4d69331
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9820,6 +9820,71 @@ async def node(input: State, config: RunnableConfig, store: BaseStore):
) # still overwriting the same one


async def test_debug_retry():
class State(TypedDict):
messages: Annotated[list[str], operator.add]

def node(name):
async def _node(state: State):
return {"messages": [f"entered {name} node"]}

return _node

builder = StateGraph(State)
builder.add_node("one", node("one"))
builder.add_node("two", node("two"))
builder.add_edge(START, "one")
builder.add_edge("one", "two")
builder.add_edge("two", END)

saver = MemorySaver()

graph = builder.compile(checkpointer=saver)

config = {"configurable": {"thread_id": "1"}}
await graph.ainvoke({"messages": []}, config=config)

# re-run step: 1
async for c in saver.alist(config):
if c.metadata["step"] == 1:
target_config = c.parent_config
break
assert target_config is not None

update_config = await graph.aupdate_state(target_config, values=None)

events = [
c async for c in graph.astream(None, config=update_config, stream_mode="debug")
]

checkpoint_events = list(
reversed([e["payload"] for e in events if e["type"] == "checkpoint"])
)

checkpoint_history = {
c.config["configurable"]["checkpoint_id"]: c
async for c in graph.aget_state_history(config)
}

def lax_normalize_config(config: Optional[dict]) -> Optional[dict]:
if config is None:
return None
return config["configurable"]

for stream in checkpoint_events:
stream_conf = lax_normalize_config(stream["config"])
stream_parent_conf = lax_normalize_config(stream["parent_config"])
assert stream_conf != stream_parent_conf

# ensure the streamed checkpoint == checkpoint from checkpointer.list()
history = checkpoint_history[stream["config"]["configurable"]["checkpoint_id"]]
history_conf = lax_normalize_config(history.config)
assert stream_conf == history_conf

history_parent_conf = lax_normalize_config(history.parent_config)
assert stream_parent_conf == history_parent_conf


async def test_debug_subgraphs():
class State(TypedDict):
messages: Annotated[list[str], operator.add]
Expand Down

0 comments on commit 4d69331

Please sign in to comment.