Imperative API #2378

Draft: wants to merge 15 commits into base: main
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/channels/base.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod

Check notice on line 1 in libs/langgraph/langgraph/channels/base.py (GitHub Actions / benchmark):

Benchmark results

| Benchmark | Mean ± std dev |
| --- | --- |
| fanout_to_subgraph_10x | 63.3 ms ± 1.7 ms |
| fanout_to_subgraph_10x_sync | 53.2 ms ± 0.7 ms |
| fanout_to_subgraph_10x_checkpoint | 93.4 ms ± 7.7 ms |
| fanout_to_subgraph_10x_checkpoint_sync | 96.0 ms ± 1.2 ms |
| fanout_to_subgraph_100x | 637 ms ± 14 ms |
| fanout_to_subgraph_100x_sync | 528 ms ± 14 ms |
| fanout_to_subgraph_100x_checkpoint | 976 ms ± 49 ms |
| fanout_to_subgraph_100x_checkpoint_sync | 957 ms ± 17 ms |
| react_agent_10x | 31.0 ms ± 0.6 ms |
| react_agent_10x_sync | 22.7 ms ± 0.2 ms |
| react_agent_10x_checkpoint | 47.0 ms ± 0.8 ms |
| react_agent_10x_checkpoint_sync | 36.7 ms ± 0.4 ms |
| react_agent_100x | 347 ms ± 6 ms |
| react_agent_100x_sync | 273 ms ± 3 ms |
| react_agent_100x_checkpoint | 933 ms ± 10 ms |
| react_agent_100x_checkpoint_sync | 828 ms ± 8 ms |
| wide_state_25x300 | 24.5 ms ± 0.4 ms |
| wide_state_25x300_sync | 15.9 ms ± 0.1 ms |
| wide_state_25x300_checkpoint | 277 ms ± 2 ms |
| wide_state_25x300_checkpoint_sync | 264 ms ± 5 ms |
| wide_state_15x600 | 28.8 ms ± 0.5 ms |
| wide_state_15x600_sync | 18.4 ms ± 0.2 ms |
| wide_state_15x600_checkpoint | 478 ms ± 7 ms |
| wide_state_15x600_checkpoint_sync | 461 ms ± 4 ms |
| wide_state_9x1200 | 28.9 ms ± 0.5 ms |
| wide_state_9x1200_sync | 18.4 ms ± 0.1 ms |
| wide_state_9x1200_checkpoint | 311 ms ± 3 ms |
| wide_state_9x1200_checkpoint_sync | 296 ms ± 3 ms |

Comparison against main

| Benchmark | main | changes |
| --- | --- | --- |
| wide_state_9x1200_checkpoint | 318 ms | 311 ms: 1.02x faster |
| react_agent_100x_checkpoint_sync | 846 ms | 828 ms: 1.02x faster |
| wide_state_9x1200_checkpoint_sync | 302 ms | 296 ms: 1.02x faster |
| wide_state_25x300_checkpoint | 282 ms | 277 ms: 1.02x faster |
| wide_state_15x600_checkpoint_sync | 470 ms | 461 ms: 1.02x faster |
| wide_state_25x300_checkpoint_sync | 269 ms | 264 ms: 1.02x faster |
| react_agent_100x_checkpoint | 947 ms | 933 ms: 1.01x faster |
| wide_state_15x600_checkpoint | 485 ms | 478 ms: 1.01x faster |
| react_agent_10x_checkpoint_sync | 36.5 ms | 36.7 ms: 1.00x slower |
| react_agent_10x_checkpoint | 46.8 ms | 47.0 ms: 1.01x slower |
| react_agent_100x | 344 ms | 347 ms: 1.01x slower |
| fanout_to_subgraph_10x_checkpoint_sync | 95.0 ms | 96.0 ms: 1.01x slower |
| wide_state_25x300 | 24.1 ms | 24.5 ms: 1.02x slower |
| wide_state_9x1200_sync | 18.0 ms | 18.4 ms: 1.02x slower |
| wide_state_15x600 | 28.1 ms | 28.8 ms: 1.02x slower |
| wide_state_9x1200 | 28.2 ms | 28.9 ms: 1.03x slower |
| wide_state_15x600_sync | 18.0 ms | 18.4 ms: 1.03x slower |
| wide_state_25x300_sync | 15.5 ms | 15.9 ms: 1.03x slower |
| fanout_to_subgraph_100x_checkpoint_sync | 932 ms | 957 ms: 1.03x slower |
| fanout_to_subgraph_10x_sync | 51.6 ms | 53.2 ms: 1.03x slower |
| fanout_to_subgraph_10x | 61.4 ms | 63.3 ms: 1.03x slower |
| fanout_to_subgraph_100x_checkpoint | 944 ms | 976 ms: 1.03x slower |
| fanout_to_subgraph_100x | 609 ms | 637 ms: 1.05x slower |
| fanout_to_subgraph_100x_sync | 502 ms | 528 ms: 1.05x slower |
| Geometric mean | (ref) | 1.01x slower |
from typing import Any, Generic, Optional, Sequence, Type, TypeVar
from typing import Any, Generic, Optional, Sequence, TypeVar

from typing_extensions import Self

@@ -13,7 +13,7 @@
class BaseChannel(Generic[Value, Update, C], ABC):
__slots__ = ("key", "typ")

def __init__(self, typ: Type[Any], key: str = "") -> None:
def __init__(self, typ: Any, key: str = "") -> None:
self.typ = typ
self.key = key

4 changes: 4 additions & 0 deletions libs/langgraph/langgraph/constants.py
@@ -40,12 +40,16 @@
# marker to signal node was scheduled (in distributed mode)
TASKS = sys.intern("__pregel_tasks")
# for Send objects returned by nodes/edges, corresponds to PUSH below
RETURN = sys.intern("__return__")
# for writes of a task where we simply record the return value

# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = sys.intern("__pregel_send")
# holds the `write` function that accepts writes to state/edges/reserved keys
CONFIG_KEY_READ = sys.intern("__pregel_read")
# holds the `read` function that returns a copy of the current state
CONFIG_KEY_CALL = sys.intern("__pregel_call")
# holds the `call` function that accepts a node/func, args and returns a future
CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer")
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
CONFIG_KEY_STREAM = sys.intern("__pregel_stream")
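The two new entries above are the hooks for the imperative API added below: `RETURN` names the channel a function-call task writes its return value to, and `CONFIG_KEY_CALL` carries the scheduler's `call` function into running tasks. The following is a minimal sketch of how code inside a task might look that hook up; the helper name and the exact calling convention are assumptions for illustration, not the library's API:

```python
from typing import Any, Callable

from langgraph.constants import CONFIG_KEY_CALL


def schedule_call(config: dict, func: Callable[..., Any], arg: Any) -> Any:
    # Hypothetical helper: fetch the `call` function injected by the Pregel
    # runtime under CONFIG_KEY_CALL. Per the comment in constants.py it
    # "accepts a node/func, args and returns a future".
    call_fn = config["configurable"][CONFIG_KEY_CALL]
    # Assumption: a positional (func, arg) call is accepted; the real
    # convention may differ. The returned object is a future to resolve later.
    return call_fn(func, arg)
```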
80 changes: 80 additions & 0 deletions libs/langgraph/langgraph/func/__init__.py
@@ -0,0 +1,80 @@
import asyncio
import concurrent
import concurrent.futures
import types
from functools import partial, update_wrapper
from typing import (
Any,
Awaitable,
Callable,
Optional,
ParamSpec,
TypeVar,
Union,
overload,
)

from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END, START, TAG_HIDDEN
from langgraph.pregel import Pregel
from langgraph.pregel.call import get_runnable_for_func
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import RetryPolicy, call

P = ParamSpec("P")
T = TypeVar("T")


@overload
def task(
*, retry: Optional[RetryPolicy] = None
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]]: ...


@overload
def task( # type: ignore[overload-cannot-match]
*, retry: Optional[RetryPolicy] = None
) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ...


def task(
*, retry: Optional[RetryPolicy] = None
) -> Union[
Callable[[Callable[P, Awaitable[T]]], Callable[P, asyncio.Future[T]]],
Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]],
]:
def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]:
return update_wrapper(partial(call, func, retry=retry), func)

return _task


def entrypoint(
*,
checkpointer: Optional[BaseCheckpointSaver] = None,
store: Optional[BaseStore] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType) -> Pregel:
return Pregel(
nodes={
func.__name__: PregelNode(
bound=get_runnable_for_func(func),
triggers=[START],
channels=[START],
writers=[ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN])],
)
},
channels={START: EphemeralValue(Any), END: LastValue(Any, END)},
input_channels=START,
output_channels=END,
stream_channels=END,
stream_mode="updates",
checkpointer=checkpointer,
store=store,
)

return _imp
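
Taken together, these two decorators are the surface of the imperative API: `@task` wraps a plain function so that calling it inside a graph goes through `call` and returns a future, while `@entrypoint()` turns a function into a single-node `Pregel` whose argument arrives on `START` and whose return value is written to `END`. Below is a minimal usage sketch based only on the definitions above; the checkpointer choice, the `thread_id`, and the `.result()` pattern are assumptions about how the pieces compose rather than documented behavior:

```python
from langgraph.checkpoint.memory import MemorySaver
from langgraph.func import entrypoint, task


@task()
def add_one(x: int) -> int:
    # Declared as a task: invoking it returns a future rather than running inline.
    return x + 1


@entrypoint(checkpointer=MemorySaver())
def workflow(x: int) -> int:
    # Each task call is scheduled by the runtime; .result() waits for its value.
    a = add_one(x).result()
    b = add_one(a).result()
    return b


# `entrypoint` returns a Pregel, so the usual runnable interface applies.
config = {"configurable": {"thread_id": "1"}}
print(workflow.invoke(2, config))  # expected to print 4 under these assumptions
```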
146 changes: 126 additions & 20 deletions libs/langgraph/langgraph/pregel/algo.py
@@ -51,18 +51,26 @@
PUSH,
RESERVED,
RESUME,
RETURN,
TAG_HIDDEN,
TASKS,
Send,
)
from langgraph.errors import EmptyChannelError, InvalidUpdateError
from langgraph.managed.base import ManagedValueMapping
from langgraph.pregel.call import get_runnable_for_func
from langgraph.pregel.io import read_channel, read_channels
from langgraph.pregel.log import logger
from langgraph.pregel.manager import ChannelsManager
from langgraph.pregel.read import PregelNode
from langgraph.store.base import BaseStore
from langgraph.types import All, LoopProtocol, PregelExecutableTask, PregelTask
from langgraph.types import (
All,
LoopProtocol,
PregelExecutableTask,
PregelTask,
RetryPolicy,
)
from langgraph.utils.config import merge_configs, patch_config

GetNextVersion = Callable[[Optional[V], BaseChannel], V]
@@ -95,6 +103,21 @@ class PregelTaskWrites(NamedTuple):
triggers: Sequence[str]


class Call:
__slots__ = ("func", "input", "retry")

func: Callable
input: Any
retry: Optional[RetryPolicy]

def __init__(
self, func: Callable, input: Any, *, retry: Optional[RetryPolicy]
) -> None:
self.func = func
self.input = input
self.retry = retry


def should_interrupt(
checkpoint: Checkpoint,
interrupt_nodes: Union[All, Sequence[str]],
@@ -177,7 +200,7 @@ def local_write(
"""Function injected under CONFIG_KEY_SEND in task config, to write to channels.
Validates writes and forwards them to `commit` function."""
for chan, value in writes:
if chan in (PUSH, TASKS):
if chan in (PUSH, TASKS) and value is not None:
if not isinstance(value, Send):
raise InvalidUpdateError(f"Expected Send, got {value}")
if value.node not in process_keys:
@@ -245,7 +268,7 @@ def apply_writes(
pending_writes_by_managed: dict[str, list[Any]] = defaultdict(list)
for task in tasks:
for chan, val in task.writes:
if chan in (NO_WRITES, PUSH, RESUME, INTERRUPT):
if chan in (NO_WRITES, PUSH, RESUME, INTERRUPT, RETURN):
pass
elif chan == TASKS: # TODO: remove branch in 1.0
checkpoint["pending_sends"].append(val)
@@ -436,7 +459,7 @@ def prepare_next_tasks(


def prepare_single_task(
task_path: tuple[Union[str, int, tuple], ...],
task_path: tuple[Any, ...],
task_id_checksum: Optional[str],
*,
checkpoint: Checkpoint,
@@ -457,7 +480,87 @@
configurable = config.get(CONF, {})
parent_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "")

if task_path[0] == PUSH:
if task_path[0] == PUSH and isinstance(task_path[-1], Call):
# (PUSH, parent task path, idx of PUSH write, id of parent task, Call)
task_path_t = cast(tuple[str, tuple, int, str, Call], task_path)
call = task_path_t[-1]
proc_ = get_runnable_for_func(call.func)
name = proc_.name
if name is None:
raise ValueError("`call` functions must have a `__name__` attribute")
# create task id
triggers = [PUSH]
checkpoint_ns = f"{parent_ns}{NS_SEP}{name}" if parent_ns else name
task_id = _uuid5_str(
checkpoint_id,
checkpoint_ns,
str(step),
name,
PUSH,
_tuple_str(task_path[1]),
str(task_path[2]),
)
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
metadata = {
"langgraph_step": step,
"langgraph_node": name,
"langgraph_triggers": triggers,
"langgraph_path": task_path[:3],
"langgraph_checkpoint_ns": task_checkpoint_ns,
}
if task_id_checksum is not None:
assert task_id == task_id_checksum, f"{task_id} != {task_id_checksum}"
if for_execution:
writes: deque[tuple[str, Any]] = deque()
return PregelExecutableTask(
name,
call.input,
proc_,
writes,
patch_config(
merge_configs(config, {"metadata": metadata}),
run_name=name,
callbacks=(
manager.get_child(f"graph:step:{step}") if manager else None
),
configurable={
CONFIG_KEY_TASK_ID: task_id,
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
writes.extend,
processes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
step,
checkpoint,
channels,
managed,
PregelTaskWrites(task_path[:3], name, writes, triggers),
config,
),
CONFIG_KEY_STORE: (store or configurable.get(CONFIG_KEY_STORE)),
CONFIG_KEY_CHECKPOINTER: (
checkpointer or configurable.get(CONFIG_KEY_CHECKPOINTER)
),
CONFIG_KEY_CHECKPOINT_MAP: {
**configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}),
parent_ns: checkpoint["id"],
},
CONFIG_KEY_CHECKPOINT_ID: None,
CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns,
},
),
triggers,
call.retry,
None,
task_id,
task_path[:3],
)
else:
return PregelTask(task_id, name, task_path[:3])
elif task_path[0] == PUSH:
if len(task_path) == 2: # TODO: remove branch in 1.0
# legacy SEND tasks, executed in superstep n+1
# (PUSH, idx of pending send)
@@ -488,17 +591,19 @@ def prepare_single_task(
PUSH,
str(idx),
)
elif len(task_path) == 4:
elif len(task_path) >= 4:
# new PUSH tasks, executed in superstep n
# (PUSH, parent task path, idx of PUSH write, id of parent task)
task_path_t = cast(tuple[str, tuple, int, str], task_path)
writes_for_path = [w for w in pending_writes if w[0] == task_path_t[3]]
if task_path_t[2] >= len(writes_for_path):
task_path_tt = cast(tuple[str, tuple, int, str], task_path)
writes_for_path = [w for w in pending_writes if w[0] == task_path_tt[3]]
if task_path_tt[2] >= len(writes_for_path):
logger.warning(
f"Ignoring invalid write index {task_path[2]} in pending writes"
)
return
packet = writes_for_path[task_path_t[2]][2]
packet = writes_for_path[task_path_tt[2]][2]
if packet is None:
return
if not isinstance(packet, Send):
logger.warning(
f"Ignoring invalid packet type {type(packet)} in pending writes"
@@ -531,7 +636,7 @@
"langgraph_step": step,
"langgraph_node": packet.node,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
"langgraph_path": task_path[:3],
"langgraph_checkpoint_ns": task_checkpoint_ns,
}
if task_id_checksum is not None:
@@ -541,7 +646,7 @@
if node := proc.node:
if proc.metadata:
metadata.update(proc.metadata)
writes: deque[tuple[str, Any]] = deque()
writes = deque()
return PregelExecutableTask(
packet.node,
packet.arg,
@@ -570,7 +675,7 @@
channels,
managed,
PregelTaskWrites(
task_path, packet.node, writes, triggers
task_path[:3], packet.node, writes, triggers
),
config,
),
@@ -601,12 +706,11 @@ def prepare_single_task(
proc.retry_policy,
None,
task_id,
task_path,
task_path[:3],
writers=proc.flat_writers,
)

else:
return PregelTask(task_id, packet.node, task_path)
return PregelTask(task_id, packet.node, task_path[:3])
elif task_path[0] == PULL:
# (PULL, node name)
name = cast(str, task_path[1])
@@ -650,7 +754,7 @@
"langgraph_step": step,
"langgraph_node": name,
"langgraph_triggers": triggers,
"langgraph_path": task_path,
"langgraph_path": task_path[:3],
"langgraph_checkpoint_ns": task_checkpoint_ns,
}
if task_id_checksum is not None:
@@ -689,7 +793,9 @@ def prepare_single_task(
checkpoint,
channels,
managed,
PregelTaskWrites(task_path, name, writes, triggers),
PregelTaskWrites(
task_path[:3], name, writes, triggers
),
config,
),
CONFIG_KEY_STORE: (
@@ -720,11 +826,11 @@
proc.retry_policy,
None,
task_id,
task_path,
task_path[:3],
writers=proc.flat_writers,
)
else:
return PregelTask(task_id, name, task_path)
return PregelTask(task_id, name, task_path[:3])


def _proc_input(
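
The thread running through the `algo.py` changes is a new task-path shape: a `@task` invocation records a write, and `prepare_single_task` now recognizes a path of the form `(PUSH, parent task path, idx of PUSH write, id of parent task, Call)` before falling back to the older `Send`-based `PUSH` branches, while every stored `langgraph_path` is truncated to its first three elements so the trailing `Call` never leaks into metadata. Below is a standalone sketch of that dispatch order with simplified types; the `PUSH` string value and the `describe_push_path` helper are assumptions for illustration only, not the actual implementation:

```python
from typing import Any, Callable, NamedTuple, Optional

PUSH = "__pregel_push"  # assumption: the marker string behind the PUSH constant


class Call(NamedTuple):
    # Simplified stand-in for the slotted Call class added in this diff.
    func: Callable[..., Any]
    input: Any
    retry: Optional[object] = None


def describe_push_path(task_path: tuple) -> str:
    """Hypothetical helper mirroring the branch order in prepare_single_task."""
    if task_path[0] != PUSH:
        return "not a PUSH path"
    if isinstance(task_path[-1], Call):
        # (PUSH, parent task path, idx of PUSH write, id of parent task, Call)
        call = task_path[-1]
        return f"function-call task for {call.func.__name__}({call.input!r})"
    if len(task_path) == 2:
        # legacy Send task, executed in superstep n+1: (PUSH, idx of pending send)
        return f"legacy Send task at pending_sends[{task_path[1]}]"
    # (PUSH, parent task path, idx of PUSH write, id of parent task)
    return f"Send task from write {task_path[2]} of parent {task_path[3]}"


# Metadata keeps only the first three path elements, matching task_path[:3] above.
print(describe_push_path((PUSH, (), 0, "parent-task-id", Call(len, "abc"))))
```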