Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 11, 2024
1 parent 89a0859 commit 3ad966e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 44 deletions.
1 change: 0 additions & 1 deletion libs/langgraph/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,6 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
ALL_CHECKPOINTERS_ASYNC = [
"memory",
"sqlite_aio",
"duckdb_aio",
"postgres_aio",
"postgres_aio_pipe",
"postgres_aio_pool",
Expand Down
35 changes: 17 additions & 18 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,16 +1875,15 @@ def __call__(self, state):
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Control):
state.state = update
return state
if isinstance(state, GraphCommand):
return state.copy(update=update)
else:
return update

def send_for_fun(state):
return [
Send("2", Control(send=Send("2", 3))),
Send("2", Control(send=Send("flaky", 4))),
Send("2", GraphCommand(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("flaky", 4))),
"3.1",
]

Expand All @@ -1906,8 +1905,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(["0"], thread1, debug=1) == [
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
]
assert builder.nodes["2"].runnable.func.ticks == 3
Expand All @@ -1922,8 +1921,8 @@ def route_to_three(state) -> Literal["3"]:
assert graph.invoke(None, thread1, debug=1) == [
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand All @@ -1945,8 +1944,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand Down Expand Up @@ -1981,8 +1980,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
],
Expand All @@ -1999,8 +1998,8 @@ def route_to_three(state) -> Literal["3"]:
"writes": {
"1": ["1"],
"2": [
["2|Control(send=Send(node='2', arg=3))"],
["2|Control(send=Send(node='flaky', arg=4))"],
["2|Command(send=Send(node='2', arg=3))"],
["2|Command(send=Send(node='flaky', arg=4))"],
["2|3"],
],
"flaky": ["flaky|4"],
Expand Down Expand Up @@ -2085,7 +2084,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Control(send=Send(node='2', arg=3))"],
result=["2|Command(send=Send(node='2', arg=3))"],
),
PregelTask(
id=AnyStr(),
Expand All @@ -2099,7 +2098,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Control(send=Send(node='flaky', arg=4))"],
result=["2|Command(send=Send(node='flaky', arg=4))"],
),
PregelTask(
id=AnyStr(),
Expand Down Expand Up @@ -2904,7 +2903,7 @@ def foo(call: ToolCall):

# interrupt-update-resume flow, creating new Send in update call

# TODO add here test with invoke(Control())
# TODO add here test with invoke(Command())


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
Expand Down
69 changes: 44 additions & 25 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2035,7 +2035,9 @@ async def route_to_three(state) -> Literal["3"]:
]


async def test_send_sequences() -> None:
@pytest.mark.repeat(10)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_sequences(checkpointer_name: str) -> None:
class Node:
def __init__(self, name: str):
self.name = name
Expand Down Expand Up @@ -2074,22 +2076,40 @@ async def route_to_three(state) -> Literal["3"]:
assert await graph.ainvoke(["0"]) == [
"0",
"1",
"3.1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='2', arg=4))",
"3",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
"3.1",
]

async with awith_checkpointer(checkpointer_name) as checkpointer:
graph = builder.compile(checkpointer=checkpointer, interrupt_before=["3.1"])
thread1 = {"configurable": {"thread_id": "1"}}
assert await graph.ainvoke(["0"], thread1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|3",
"2|4",
]
assert await graph.ainvoke(None, thread1) == [
"0",
"1",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='2', arg=4))",
"2|3",
"2|4",
"3",
"3.1",
]


@pytest.mark.repeat(20)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_send_dedupe_on_resume(checkpointer_name: str) -> None:
if checkpointer_name == "duckdb_aio":
pytest.skip("DuckDB isn't returning the right history")

class InterruptOnce:
ticks: int = 0

Expand All @@ -2112,16 +2132,15 @@ def __call__(self, state):
if isinstance(state, list)
else ["|".join((self.name, str(state)))]
)
if isinstance(state, Control):
state.state = update
return state
if isinstance(state, GraphCommand):
return state.copy(update=update)
else:
return update

def send_for_fun(state):
return [
Send("2", Control(send=Send("2", 3))),
Send("2", Control(send=Send("flaky", 4))),
Send("2", GraphCommand(send=Send("2", 3))),
Send("2", GraphCommand(send=Send("flaky", 4))),
"3.1",
]

Expand All @@ -2144,8 +2163,8 @@ def route_to_three(state) -> Literal["3"]:
assert await graph.ainvoke(["0"], thread1, debug=1) == [
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
]
assert builder.nodes["2"].runnable.func.ticks == 3
Expand All @@ -2154,8 +2173,8 @@ def route_to_three(state) -> Literal["3"]:
assert await graph.ainvoke(None, thread1, debug=1) == [
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand All @@ -2172,8 +2191,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
"3",
Expand Down Expand Up @@ -2208,8 +2227,8 @@ def route_to_three(state) -> Literal["3"]:
values=[
"0",
"1",
"2|Control(send=Send(node='2', arg=3))",
"2|Control(send=Send(node='flaky', arg=4))",
"2|Command(send=Send(node='2', arg=3))",
"2|Command(send=Send(node='flaky', arg=4))",
"2|3",
"flaky|4",
],
Expand All @@ -2226,8 +2245,8 @@ def route_to_three(state) -> Literal["3"]:
"writes": {
"1": ["1"],
"2": [
["2|Control(send=Send(node='2', arg=3))"],
["2|Control(send=Send(node='flaky', arg=4))"],
["2|Command(send=Send(node='2', arg=3))"],
["2|Command(send=Send(node='flaky', arg=4))"],
["2|3"],
],
"flaky": ["flaky|4"],
Expand Down Expand Up @@ -2312,7 +2331,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Control(send=Send(node='2', arg=3))"],
result=["2|Command(send=Send(node='2', arg=3))"],
),
PregelTask(
id=AnyStr(),
Expand All @@ -2326,7 +2345,7 @@ def route_to_three(state) -> Literal["3"]:
error=None,
interrupts=(),
state=None,
result=["2|Control(send=Send(node='flaky', arg=4))"],
result=["2|Command(send=Send(node='flaky', arg=4))"],
),
PregelTask(
id=AnyStr(),
Expand Down

0 comments on commit 3ad966e

Please sign in to comment.