Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Nov 14, 2024
1 parent c473a76 commit bd255c2
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
85 changes: 85 additions & 0 deletions libs/langgraph/langgraph/func/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import asyncio
import concurrent
import concurrent.futures
import types
from functools import partial, update_wrapper
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Optional,
ParamSpec,
TypeVar,
Union,
overload,
)

from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.channels.last_value import LastValue
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.constants import END, START, TAG_HIDDEN
from langgraph.pregel import Pregel
from langgraph.pregel.call import get_runnable_for_func
from langgraph.pregel.read import PregelNode
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.types import RetryPolicy, acall, call

P = ParamSpec("P")
T = TypeVar("T")


@overload
def task(
*, retry: Optional[RetryPolicy] = None
) -> Callable[
[Callable[P, Coroutine[None, None, T]]], Callable[P, asyncio.Future[T]]
]: ...


@overload
def task(
*, retry: Optional[RetryPolicy] = None
) -> Callable[[Callable[P, T]], Callable[P, concurrent.futures.Future[T]]]: ...


def task(
*, retry: Optional[RetryPolicy] = None
) -> Callable[
[Callable[P, Union[T, Awaitable[T]]]],
Callable[P, Union[concurrent.futures.Future[T], asyncio.Future[T]]],
]:
def _task(func: Callable[P, T]) -> Callable[P, concurrent.futures.Future[T]]:
if asyncio.iscoroutinefunction(func):
return update_wrapper(partial(acall, func), func)
else:
return update_wrapper(partial(call, func), func)

return _task


def imp(
*,
checkpointer: Optional[BaseCheckpointSaver] = None,
store: Optional[BaseStore] = None,
) -> Callable[[types.FunctionType], Pregel]:
def _imp(func: types.FunctionType):
return Pregel(
nodes={
func.__name__: PregelNode(
bound=get_runnable_for_func(func),
triggers=[START],
channels=[START],
writers=[ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN])],
)
},
channels={START: EphemeralValue(Any, START), END: LastValue(Any, END)},
input_channels=START,
output_channels=END,
stream_mode="updates",
checkpointer=checkpointer,
store=store,
)

return _imp
24 changes: 24 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
START,
)
from langgraph.errors import InvalidUpdateError, MultipleSubgraphsError, NodeInterrupt
from langgraph.func import imp, task
from langgraph.graph import END, Graph, GraphCommand, StateGraph
from langgraph.graph.message import MessageGraph, MessagesState, add_messages
from langgraph.managed.shared_value import SharedValue
Expand Down Expand Up @@ -1966,6 +1967,29 @@ def route_to_three(state) -> Literal["3"]:
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_imp_task(request: pytest.FixtureRequest, checkpointer_name: str) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

@task()
def mapper(input: str) -> str:
print(f"mapper {input}")
return input * 2

@imp(checkpointer=checkpointer)
def graph(input: list[str]) -> list[str]:
futures = [mapper(i) for i in input]
mapped = [f.result() for f in futures]
# answer = interrupt("question")
# TODO raises NodeInterrupt if no answer provided yet
# returns answer (saved in writes?) if provided
# what is the API for passing the answer?
return mapped

thread1 = {"configurable": {"thread_id": "1"}}
assert graph.invoke(["0", "1"], thread1) == ["00", "11"]


@pytest.mark.repeat(20)
@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_send_dedupe_on_resume(
Expand Down

0 comments on commit bd255c2

Please sign in to comment.