From 01bb6cbd4dcaa48474664d163dfc21c1c196a730 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 17 Sep 2024 17:27:37 -0700 Subject: [PATCH 1/3] Add stream_mode=messages - yields messages token-by-token from LLMs, and any messages returned from a node --- libs/langgraph/langgraph/pregel/__init__.py | 9 +- libs/langgraph/langgraph/pregel/algo.py | 31 +-- libs/langgraph/langgraph/pregel/loop.py | 5 +- libs/langgraph/langgraph/pregel/messages.py | 158 +++++++++++++ libs/langgraph/langgraph/pregel/types.py | 3 +- libs/langgraph/tests/fake_chat.py | 83 +++++++ libs/langgraph/tests/test_pregel.py | 237 +++++++++++++++++++- 7 files changed, 495 insertions(+), 31 deletions(-) create mode 100644 libs/langgraph/langgraph/pregel/messages.py create mode 100644 libs/langgraph/tests/fake_chat.py diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 1cf5ec982..aee89123f 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -78,6 +78,7 @@ from langgraph.pregel.io import read_channels from langgraph.pregel.loop import AsyncPregelLoop, StreamProtocol, SyncPregelLoop from langgraph.pregel.manager import AsyncChannelsManager, ChannelsManager +from langgraph.pregel.messages import StreamMessagesHandler from langgraph.pregel.read import PregelNode from langgraph.pregel.retry import RetryPolicy from langgraph.pregel.runner import PregelRunner @@ -1213,7 +1214,11 @@ def output() -> Iterator: interrupt_after=interrupt_after, debug=debug, ) - + # set up messages stream mode + if "messages" in stream_modes: + run_manager.inheritable_handlers.append( + StreamMessagesHandler(stream.put) + ) with SyncPregelLoop( input, stream=StreamProtocol(stream.put, stream_modes), @@ -1234,6 +1239,8 @@ def output() -> Iterator: # enable subgraph streaming if subgraphs: loop.config["configurable"][CONFIG_KEY_STREAM] = loop.stream + # enable concurrent streaming + if subgraphs or "messages" in stream_modes: # we are careful to have a single waiter live at any one time # because on exit we increment semaphore count by exactly 1 waiter: Optional[concurrent.futures.Future] = None diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index f23424ca0..c2c5866e4 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -350,12 +350,6 @@ def prepare_single_task( return # create task id triggers = [PUSH] - metadata = { - "langgraph_step": step, - "langgraph_node": packet.node, - "langgraph_triggers": triggers, - "langgraph_path": task_path, - } checkpoint_ns = ( f"{parent_ns}{NS_SEP}{packet.node}" if parent_ns else packet.node ) @@ -367,6 +361,14 @@ def prepare_single_task( PUSH, str(idx), ) + task_checkpoint_ns = f"{checkpoint_ns}:{task_id}" + metadata = { + "langgraph_step": step, + "langgraph_node": packet.node, + "langgraph_triggers": triggers, + "langgraph_path": task_path, + "langgraph_checkpoint_ns": task_checkpoint_ns, + } if task_id_checksum is not None: assert task_id == task_id_checksum if for_execution: @@ -376,7 +378,6 @@ def prepare_single_task( if proc.metadata: metadata.update(proc.metadata) writes = deque() - task_checkpoint_ns = f"{checkpoint_ns}:{task_id}" return PregelExecutableTask( packet.node, packet.arg, @@ -461,12 +462,6 @@ def prepare_single_task( return # create task id - metadata = { - "langgraph_step": step, - "langgraph_node": name, - "langgraph_triggers": triggers, - "langgraph_path": task_path, - } checkpoint_ns = 
f"{parent_ns}{NS_SEP}{name}" if parent_ns else name task_id = _uuid5_str( checkpoint_id, @@ -476,15 +471,21 @@ def prepare_single_task( PULL, *triggers, ) + task_checkpoint_ns = f"{checkpoint_ns}:{task_id}" + metadata = { + "langgraph_step": step, + "langgraph_node": name, + "langgraph_triggers": triggers, + "langgraph_path": task_path, + "langgraph_checkpoint_ns": task_checkpoint_ns, + } if task_id_checksum is not None: assert task_id == task_id_checksum - if for_execution: if node := proc.node: if proc.metadata: metadata.update(proc.metadata) writes = deque() - task_checkpoint_ns = f"{checkpoint_ns}:{task_id}" return PregelExecutableTask( name, val, diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index f1db60e0e..7560905ed 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -8,7 +8,6 @@ AsyncContextManager, Callable, ContextManager, - Iterable, Iterator, List, Literal, @@ -112,11 +111,11 @@ class StreamProtocol: modes: Sequence[Literal["values", "updates", "debug"]] - __call__: Callable[[Iterable[Tuple[str, str, Any]]], None] + __call__: Callable[[Tuple[str, str, Any]], None] def __init__( self, - __call__: Callable[[Iterable[Tuple[str, str, Any]]], None], + __call__: Callable[[Tuple[str, str, Any]], None], modes: Sequence[Literal["values", "updates", "debug"]], ) -> None: self.__call__ = __call__ diff --git a/libs/langgraph/langgraph/pregel/messages.py b/libs/langgraph/langgraph/pregel/messages.py new file mode 100644 index 000000000..2960294e3 --- /dev/null +++ b/libs/langgraph/langgraph/pregel/messages.py @@ -0,0 +1,158 @@ +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Tuple, +) +from uuid import UUID, uuid4 + +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatGenerationChunk, LLMResult +from langchain_core.tracers._streaming import T, _StreamingCallbackHandler + +from langgraph.constants import NS_SEP + + +class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler): + def __init__(self, stream: Callable[[Tuple[str, str, Any]], None]): + self.stream = stream + self.metadata: dict[str, tuple[str, dict[str, Any]]] = {} + self.seen = set() + + def _emit(self, meta: Tuple[str, dict[str, Any]], message: BaseMessage): + ident = id(message) + if ident in self.seen: + return + else: + if message.id is None: + message.id = str(uuid4()) + self.seen.add(ident) + self.stream((meta[0], "messages", (message, meta[1]))) + + def tap_output_aiter( + self, run_id: UUID, output: AsyncIterator[T] + ) -> AsyncIterator[T]: + return output + + def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]: + return output + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + if metadata: + self.metadata[run_id] = ( + tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)), + metadata, + ) + + def on_llm_new_token( + self, + token: str, + *, + chunk: Optional[ChatGenerationChunk] = None, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if not isinstance(chunk, ChatGenerationChunk): + return + if meta := self.metadata.get(run_id): + self._emit(meta, chunk.message) + + 
def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self.metadata.pop(run_id, None) + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self.metadata.pop(run_id, None) + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Any: + if metadata: + self.metadata[run_id] = ( + tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)), + metadata, + ) + + def on_chain_end( + self, + response: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + if meta := self.metadata.pop(run_id, None): + if isinstance(response, BaseMessage): + self._emit(meta, response) + elif isinstance(response, Sequence): + for value in response: + if isinstance(value, BaseMessage): + self._emit(meta, value) + elif isinstance(response, dict): + for value in response.values(): + if isinstance(value, BaseMessage): + self._emit(meta, value) + elif isinstance(value, Sequence): + for item in value: + if isinstance(item, BaseMessage): + self._emit(meta, item) + elif hasattr(response, "__dir__") and callable(response.__dir__): + for key in dir(response): + try: + value = getattr(response, key) + if isinstance(value, BaseMessage): + self._emit(meta, value) + elif isinstance(value, Sequence): + for item in value: + if isinstance(item, BaseMessage): + self._emit(meta, item) + except AttributeError: + pass + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + self.metadata.pop(run_id, None) diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index 589ed43ab..695647533 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -107,11 +107,12 @@ class StateSnapshot(NamedTuple): All = Literal["*"] -StreamMode = Literal["values", "updates", "debug"] +StreamMode = Literal["values", "updates", "debug", "messages"] """How the stream method should emit outputs. - 'values': Emit all values of the state for each step. - 'updates': Emit only the node name(s) and updates that were returned by the node(s) **after** each step. - 'debug': Emit debug events for each step. +- 'messages': Emit LLM messages token-by-token. 
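+
+For example (a sketch, assuming a compiled graph `app` whose nodes call a chat
+model), each streamed item is a `(message, metadata)` tuple:
+
+    for message, metadata in app.stream(inputs, stream_mode="messages"):
+        print(message.content, metadata["langgraph_node"])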
""" diff --git a/libs/langgraph/tests/fake_chat.py b/libs/langgraph/tests/fake_chat.py new file mode 100644 index 000000000..0d6caf2fd --- /dev/null +++ b/libs/langgraph/tests/fake_chat.py @@ -0,0 +1,83 @@ +import re +from typing import Any, Iterator, List, Optional, cast + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult + + +class FakeChatModel(GenericFakeChatModel): + messages: list[BaseMessage] + + i: int = 0 + + def bind_tools(self, functions: list): + return self + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + if self.i >= len(self.messages): + self.i = 0 + message = self.messages[self.i] + self.i += 1 + if isinstance(message, str): + message_ = AIMessage(content=message) + else: + message_ = message + generation = ChatGeneration(message=message_) + return ChatResult(generations=[generation]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream the output of the model.""" + chat_result = self._generate( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + if not isinstance(chat_result, ChatResult): + raise ValueError( + f"Expected generate to return a ChatResult, " + f"but got {type(chat_result)} instead." + ) + + message = chat_result.generations[0].message + + if not isinstance(message, AIMessage): + raise ValueError( + f"Expected invoke to return an AIMessage, " + f"but got {type(message)} instead." + ) + + content = message.content + + if content: + # Use a regular expression to split on whitespace with a capture group + # so that we can preserve the whitespace in the output. 
+ assert isinstance(content, str) + content_chunks = cast(list[str], re.split(r"(\s)", content)) + + for token in content_chunks: + chunk = ChatGenerationChunk( + message=AIMessageChunk(content=token, id=message.id) + ) + if run_manager: + run_manager.on_llm_new_token(token, chunk=chunk) + yield chunk + else: + args = message.__dict__ + args.pop("type") + chunk = ChatGenerationChunk(message=AIMessageChunk(**args)) + if run_manager: + run_manager.on_llm_new_token("", chunk=chunk) + yield chunk diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 9096803a5..e5f4c9b8f 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -74,9 +74,15 @@ from langgraph.store.memory import MemoryStore from tests.any_str import AnyDict, AnyStr, AnyVersion, UnsortedSequence from tests.conftest import ALL_CHECKPOINTERS_SYNC, SHOULD_CHECK_SNAPSHOTS +from tests.fake_chat import FakeChatModel from tests.fake_tracer import FakeTracer from tests.memory_assert import MemorySaverAssertCheckpointMetadata -from tests.messages import _AnyIdAIMessage, _AnyIdHumanMessage, _AnyIdToolMessage +from tests.messages import ( + _AnyIdAIMessage, + _AnyIdAIMessageChunk, + _AnyIdHumanMessage, + _AnyIdToolMessage, +) # define these objects to avoid importing langchain_core.agents @@ -3926,16 +3932,9 @@ def should_start(data: AgentState) -> str: def test_prebuilt_tool_chat(snapshot: SnapshotAssertion) -> None: - from langchain_core.language_models.fake_chat_models import ( - FakeMessagesListChatModel, - ) - from langchain_core.messages import AIMessage, HumanMessage + from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage from langchain_core.tools import tool - class FakeFuntionChatModel(FakeMessagesListChatModel): - def bind_tools(self, functions: list): - return self - @tool() def search_api(query: str) -> str: """Searches the API for the query.""" @@ -3943,8 +3942,8 @@ def search_api(query: str) -> str: tools = [search_api] - model = FakeFuntionChatModel( - responses=[ + model = FakeChatModel( + messages=[ AIMessage( content="", tool_calls=[ @@ -4032,6 +4031,222 @@ def search_api(query: str) -> str: ] } + assert [ + c + for c in app.stream( + {"messages": [HumanMessage(content="what is weather in sf")]}, + stream_mode="messages", + ) + ] == [ + ( + _AnyIdHumanMessage( + content="what is weather in sf", + ), + { + "langgraph_step": 0, + "langgraph_node": "__start__", + "langgraph_triggers": ["__start__"], + "langgraph_path": ("__pregel_pull", "__start__"), + "langgraph_checkpoint_ns": AnyStr("__start__:"), + }, + ), + ( + _AnyIdAIMessageChunk( + content="", + tool_calls=[ + { + "name": "search_api", + "args": {"query": "query"}, + "id": "tool_call123", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "search_api", + "args": '{"query": "query"}', + "id": "tool_call123", + "index": None, + "type": "tool_call_chunk", + } + ], + ), + { + "langgraph_step": 1, + "langgraph_node": "agent", + "langgraph_triggers": ["start:agent"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + "ls_provider": "fakechatmodel", + "ls_model_type": "chat", + }, + ), + ( + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "search_api", + "args": {"query": "query"}, + "id": "tool_call123", + "type": "tool_call", + } + ], + ), + { + "langgraph_step": 1, + "langgraph_node": "agent", + "langgraph_triggers": ["start:agent"], + "langgraph_path": 
("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + }, + ), + ( + _AnyIdToolMessage( + content="result for query", + name="search_api", + tool_call_id="tool_call123", + ), + { + "langgraph_step": 2, + "langgraph_node": "tools", + "langgraph_triggers": ["branch:agent:should_continue:tools"], + "langgraph_path": ("__pregel_pull", "tools"), + "langgraph_checkpoint_ns": AnyStr("tools:"), + }, + ), + ( + _AnyIdAIMessageChunk( + content="", + tool_calls=[ + { + "name": "search_api", + "args": {"query": "another"}, + "id": "tool_call234", + "type": "tool_call", + }, + { + "name": "search_api", + "args": {"query": "a third one"}, + "id": "tool_call567", + "type": "tool_call", + }, + ], + tool_call_chunks=[ + { + "name": "search_api", + "args": '{"query": "another"}', + "id": "tool_call234", + "index": None, + "type": "tool_call_chunk", + }, + { + "name": "search_api", + "args": '{"query": "a third one"}', + "id": "tool_call567", + "index": None, + "type": "tool_call_chunk", + }, + ], + ), + { + "langgraph_step": 3, + "langgraph_node": "agent", + "langgraph_triggers": ["tools"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + "ls_provider": "fakechatmodel", + "ls_model_type": "chat", + }, + ), + ( + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "name": "search_api", + "args": {"query": "another"}, + "id": "tool_call234", + "type": "tool_call", + }, + { + "name": "search_api", + "args": {"query": "a third one"}, + "id": "tool_call567", + "type": "tool_call", + }, + ], + ), + { + "langgraph_step": 3, + "langgraph_node": "agent", + "langgraph_triggers": ["tools"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + }, + ), + ( + _AnyIdToolMessage( + content="result for another", + name="search_api", + tool_call_id="tool_call234", + ), + { + "langgraph_step": 4, + "langgraph_node": "tools", + "langgraph_triggers": ["branch:agent:should_continue:tools"], + "langgraph_path": ("__pregel_pull", "tools"), + "langgraph_checkpoint_ns": AnyStr("tools:"), + }, + ), + ( + _AnyIdToolMessage( + content="result for a third one", + name="search_api", + tool_call_id="tool_call567", + ), + { + "langgraph_step": 4, + "langgraph_node": "tools", + "langgraph_triggers": ["branch:agent:should_continue:tools"], + "langgraph_path": ("__pregel_pull", "tools"), + "langgraph_checkpoint_ns": AnyStr("tools:"), + }, + ), + ( + _AnyIdAIMessageChunk( + content="answer", + ), + { + "langgraph_step": 5, + "langgraph_node": "agent", + "langgraph_triggers": ["tools"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + "ls_provider": "fakechatmodel", + "ls_model_type": "chat", + }, + ), + ( + _AnyIdAIMessage( + content="answer", + ), + { + "langgraph_step": 5, + "langgraph_node": "agent", + "langgraph_triggers": ["tools"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + }, + ), + ] + assert app.invoke( {"messages": [HumanMessage(content="what is weather in sf")]}, {"recursion_limit": 2}, From c0d8493db5b5d248b8d1ee735fadf7594d5d6e8e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 17 Sep 2024 17:32:14 -0700 Subject: [PATCH 2/3] Only look at chain runs which are nodes --- libs/langgraph/langgraph/pregel/messages.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langgraph/langgraph/pregel/messages.py b/libs/langgraph/langgraph/pregel/messages.py index 2960294e3..5ee839179 100644 --- a/libs/langgraph/langgraph/pregel/messages.py +++ b/libs/langgraph/langgraph/pregel/messages.py @@ -105,7 +105,7 @@ def on_chain_start( metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: - if metadata: + if metadata and kwargs.get("name") == metadata.get("langgraph_node"): self.metadata[run_id] = ( tuple(metadata["langgraph_checkpoint_ns"].split(NS_SEP)), metadata, From 9235178c880a6bbccac745b590e68a6f9a36ed82 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 18 Sep 2024 13:56:04 -0700 Subject: [PATCH 3/3] Finish --- libs/langgraph/langgraph/pregel/messages.py | 25 ++- libs/langgraph/tests/fake_chat.py | 5 +- libs/langgraph/tests/test_pregel.py | 206 +++++++---------- libs/langgraph/tests/test_pregel_async.py | 233 ++++++++++++++++---- 4 files changed, 288 insertions(+), 181 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/messages.py b/libs/langgraph/langgraph/pregel/messages.py index 5ee839179..3ade1dd5f 100644 --- a/libs/langgraph/langgraph/pregel/messages.py +++ b/libs/langgraph/langgraph/pregel/messages.py @@ -25,14 +25,23 @@ def __init__(self, stream: Callable[[Tuple[str, str, Any]], None]): self.metadata: dict[str, tuple[str, dict[str, Any]]] = {} self.seen = set() - def _emit(self, meta: Tuple[str, dict[str, Any]], message: BaseMessage): + def _emit( + self, + meta: Tuple[str, dict[str, Any]], + message: BaseMessage, + *, + dedupe: bool = False, + ): ident = id(message) - if ident in self.seen: + if dedupe and message.id in self.seen: + return + elif ident in self.seen: return else: if message.id is None: message.id = str(uuid4()) self.seen.add(ident) + self.seen.add(message.id) self.stream((meta[0], "messages", (message, meta[1]))) def tap_output_aiter( @@ -121,29 +130,29 @@ def on_chain_end( ) -> Any: if meta := self.metadata.pop(run_id, None): if isinstance(response, BaseMessage): - self._emit(meta, response) + self._emit(meta, response, dedupe=True) elif isinstance(response, Sequence): for value in response: if isinstance(value, BaseMessage): - self._emit(meta, value) + self._emit(meta, value, dedupe=True) elif isinstance(response, dict): for value in response.values(): if isinstance(value, BaseMessage): - self._emit(meta, value) + self._emit(meta, value, dedupe=True) elif isinstance(value, Sequence): for item in value: if isinstance(item, BaseMessage): - self._emit(meta, item) + self._emit(meta, item, dedupe=True) elif hasattr(response, "__dir__") and callable(response.__dir__): for key in dir(response): try: value = getattr(response, key) if isinstance(value, BaseMessage): - self._emit(meta, value) + self._emit(meta, value, dedupe=True) elif isinstance(value, Sequence): for item in value: if isinstance(item, BaseMessage): - self._emit(meta, item) + self._emit(meta, item, dedupe=True) except AttributeError: pass diff --git a/libs/langgraph/tests/fake_chat.py b/libs/langgraph/tests/fake_chat.py index 0d6caf2fd..c2a6b9b9e 100644 --- a/libs/langgraph/tests/fake_chat.py +++ b/libs/langgraph/tests/fake_chat.py @@ -30,7 +30,10 @@ def _generate( if isinstance(message, str): message_ = AIMessage(content=message) else: - message_ = message + if hasattr(message, "model_copy"): + message_ = message.model_copy() + else: + message_ = message.copy() generation = ChatGeneration(message=message_) return ChatResult(generations=[generation]) diff --git 
a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index e5f4c9b8f..f42370fed 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -3932,7 +3932,7 @@ def should_start(data: AgentState) -> str: def test_prebuilt_tool_chat(snapshot: SnapshotAssertion) -> None: - from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage + from langchain_core.messages import AIMessage, HumanMessage from langchain_core.tools import tool @tool() @@ -4082,27 +4082,6 @@ def search_api(query: str) -> str: "ls_model_type": "chat", }, ), - ( - _AnyIdAIMessage( - content="", - tool_calls=[ - { - "name": "search_api", - "args": {"query": "query"}, - "id": "tool_call123", - "type": "tool_call", - } - ], - ), - { - "langgraph_step": 1, - "langgraph_node": "agent", - "langgraph_triggers": ["start:agent"], - "langgraph_path": ("__pregel_pull", "agent"), - "langgraph_checkpoint_ns": AnyStr("agent:"), - "checkpoint_ns": AnyStr("agent:"), - }, - ), ( _AnyIdToolMessage( content="result for query", @@ -4162,33 +4141,6 @@ def search_api(query: str) -> str: "ls_model_type": "chat", }, ), - ( - _AnyIdAIMessage( - content="", - tool_calls=[ - { - "name": "search_api", - "args": {"query": "another"}, - "id": "tool_call234", - "type": "tool_call", - }, - { - "name": "search_api", - "args": {"query": "a third one"}, - "id": "tool_call567", - "type": "tool_call", - }, - ], - ), - { - "langgraph_step": 3, - "langgraph_node": "agent", - "langgraph_triggers": ["tools"], - "langgraph_path": ("__pregel_pull", "agent"), - "langgraph_checkpoint_ns": AnyStr("agent:"), - "checkpoint_ns": AnyStr("agent:"), - }, - ), ( _AnyIdToolMessage( content="result for another", @@ -4232,19 +4184,6 @@ def search_api(query: str) -> str: "ls_model_type": "chat", }, ), - ( - _AnyIdAIMessage( - content="answer", - ), - { - "langgraph_step": 5, - "langgraph_node": "agent", - "langgraph_triggers": ["tools"], - "langgraph_path": ("__pregel_pull", "agent"), - "langgraph_checkpoint_ns": AnyStr("agent:"), - "checkpoint_ns": AnyStr("agent:"), - }, - ), ] assert app.invoke( @@ -4260,76 +4199,79 @@ def search_api(query: str) -> str: model.i = 0 # reset the model - assert app.invoke( - {"messages": [HumanMessage(content="what is weather in sf")]}, - stream_mode="updates", - ) == [ - { - "agent": { - "messages": [ - _AnyIdAIMessage( - content="", - tool_calls=[ - { - "id": "tool_call123", - "name": "search_api", - "args": {"query": "query"}, - }, - ], - ) - ] - } - }, - { - "tools": { - "messages": [ - _AnyIdToolMessage( - content="result for query", - name="search_api", - tool_call_id="tool_call123", - ) - ] - } - }, - { - "agent": { - "messages": [ - _AnyIdAIMessage( - content="", - tool_calls=[ - { - "id": "tool_call234", - "name": "search_api", - "args": {"query": "another"}, - }, - { - "id": "tool_call567", - "name": "search_api", - "args": {"query": "a third one"}, - }, - ], - ) - ] - } - }, - { - "tools": { - "messages": [ - _AnyIdToolMessage( - content="result for another", - name="search_api", - tool_call_id="tool_call234", - ), - _AnyIdToolMessage( - content="result for a third one", - name="search_api", - tool_call_id="tool_call567", - ), - ] - } - }, - {"agent": {"messages": [_AnyIdAIMessage(content="answer")]}}, - ] + assert ( + app.invoke( + {"messages": [HumanMessage(content="what is weather in sf")]}, + stream_mode="updates", + )[0]["agent"]["messages"] + == [ + { + "agent": { + "messages": [ + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "id": "tool_call123", + 
"name": "search_api", + "args": {"query": "query"}, + }, + ], + ) + ] + } + }, + { + "tools": { + "messages": [ + _AnyIdToolMessage( + content="result for query", + name="search_api", + tool_call_id="tool_call123", + ) + ] + } + }, + { + "agent": { + "messages": [ + _AnyIdAIMessage( + content="", + tool_calls=[ + { + "id": "tool_call234", + "name": "search_api", + "args": {"query": "another"}, + }, + { + "id": "tool_call567", + "name": "search_api", + "args": {"query": "a third one"}, + }, + ], + ) + ] + } + }, + { + "tools": { + "messages": [ + _AnyIdToolMessage( + content="result for another", + name="search_api", + tool_call_id="tool_call234", + ), + _AnyIdToolMessage( + content="result for a third one", + name="search_api", + tool_call_id="tool_call567", + ), + ] + } + }, + {"agent": {"messages": [_AnyIdAIMessage(content="answer")]}}, + ][0]["agent"]["messages"] + ) assert [ *app.stream({"messages": [HumanMessage(content="what is weather in sf")]}) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 2e031e2f5..7b32be4a5 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -77,12 +77,18 @@ SHOULD_CHECK_SNAPSHOTS, awith_checkpointer, ) +from tests.fake_chat import FakeChatModel from tests.fake_tracer import FakeTracer from tests.memory_assert import ( MemorySaverAssertCheckpointMetadata, MemorySaverNoPending, ) -from tests.messages import _AnyIdAIMessage, _AnyIdHumanMessage, _AnyIdToolMessage +from tests.messages import ( + _AnyIdAIMessage, + _AnyIdAIMessageChunk, + _AnyIdHumanMessage, + _AnyIdToolMessage, +) pytestmark = pytest.mark.anyio @@ -3842,15 +3848,39 @@ def should_start(data: AgentState) -> str: async def test_prebuilt_tool_chat() -> None: - from langchain_core.language_models.fake_chat_models import ( - FakeMessagesListChatModel, - ) from langchain_core.messages import AIMessage, HumanMessage from langchain_core.tools import tool - class FakeFuntionChatModel(FakeMessagesListChatModel): - def bind_tools(self, functions: list): - return self + model = FakeChatModel( + messages=[ + AIMessage( + content="", + tool_calls=[ + { + "id": "tool_call123", + "name": "search_api", + "args": {"query": "query"}, + }, + ], + ), + AIMessage( + content="", + tool_calls=[ + { + "id": "tool_call234", + "name": "search_api", + "args": {"query": "another"}, + }, + { + "id": "tool_call567", + "name": "search_api", + "args": {"query": "a third one"}, + }, + ], + ), + AIMessage(content="answer"), + ] + ) @tool() def search_api(query: str) -> str: @@ -3859,39 +3889,7 @@ def search_api(query: str) -> str: tools = [search_api] - app = create_tool_calling_executor( - FakeFuntionChatModel( - responses=[ - AIMessage( - content="", - tool_calls=[ - { - "id": "tool_call123", - "name": "search_api", - "args": {"query": "query"}, - }, - ], - ), - AIMessage( - content="", - tool_calls=[ - { - "id": "tool_call234", - "name": "search_api", - "args": {"query": "another"}, - }, - { - "id": "tool_call567", - "name": "search_api", - "args": {"query": "a third one"}, - }, - ], - ), - AIMessage(content="answer"), - ] - ), - tools, - ) + app = create_tool_calling_executor(model, tools) assert await app.ainvoke( {"messages": [HumanMessage(content="what is weather in sf")]} @@ -3943,6 +3941,161 @@ def search_api(query: str) -> str: ] } + assert [ + c + for c in app.stream( + {"messages": [HumanMessage(content="what is weather in sf")]}, + stream_mode="messages", + ) + ] == [ + ( + _AnyIdHumanMessage( + content="what 
is weather in sf", + ), + { + "langgraph_step": 0, + "langgraph_node": "__start__", + "langgraph_triggers": ["__start__"], + "langgraph_path": ("__pregel_pull", "__start__"), + "langgraph_checkpoint_ns": AnyStr("__start__:"), + }, + ), + ( + _AnyIdAIMessageChunk( + content="", + tool_calls=[ + { + "name": "search_api", + "args": {"query": "query"}, + "id": "tool_call123", + "type": "tool_call", + } + ], + tool_call_chunks=[ + { + "name": "search_api", + "args": '{"query": "query"}', + "id": "tool_call123", + "index": None, + "type": "tool_call_chunk", + } + ], + ), + { + "langgraph_step": 1, + "langgraph_node": "agent", + "langgraph_triggers": ["start:agent"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + "ls_provider": "fakechatmodel", + "ls_model_type": "chat", + }, + ), + ( + _AnyIdToolMessage( + content="result for query", + name="search_api", + tool_call_id="tool_call123", + ), + { + "langgraph_step": 2, + "langgraph_node": "tools", + "langgraph_triggers": ["branch:agent:should_continue:tools"], + "langgraph_path": ("__pregel_pull", "tools"), + "langgraph_checkpoint_ns": AnyStr("tools:"), + }, + ), + ( + _AnyIdAIMessageChunk( + content="", + tool_calls=[ + { + "name": "search_api", + "args": {"query": "another"}, + "id": "tool_call234", + "type": "tool_call", + }, + { + "name": "search_api", + "args": {"query": "a third one"}, + "id": "tool_call567", + "type": "tool_call", + }, + ], + tool_call_chunks=[ + { + "name": "search_api", + "args": '{"query": "another"}', + "id": "tool_call234", + "index": None, + "type": "tool_call_chunk", + }, + { + "name": "search_api", + "args": '{"query": "a third one"}', + "id": "tool_call567", + "index": None, + "type": "tool_call_chunk", + }, + ], + ), + { + "langgraph_step": 3, + "langgraph_node": "agent", + "langgraph_triggers": ["tools"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + "ls_provider": "fakechatmodel", + "ls_model_type": "chat", + }, + ), + ( + _AnyIdToolMessage( + content="result for another", + name="search_api", + tool_call_id="tool_call234", + ), + { + "langgraph_step": 4, + "langgraph_node": "tools", + "langgraph_triggers": ["branch:agent:should_continue:tools"], + "langgraph_path": ("__pregel_pull", "tools"), + "langgraph_checkpoint_ns": AnyStr("tools:"), + }, + ), + ( + _AnyIdToolMessage( + content="result for a third one", + name="search_api", + tool_call_id="tool_call567", + ), + { + "langgraph_step": 4, + "langgraph_node": "tools", + "langgraph_triggers": ["branch:agent:should_continue:tools"], + "langgraph_path": ("__pregel_pull", "tools"), + "langgraph_checkpoint_ns": AnyStr("tools:"), + }, + ), + ( + _AnyIdAIMessageChunk( + content="answer", + ), + { + "langgraph_step": 5, + "langgraph_node": "agent", + "langgraph_triggers": ["tools"], + "langgraph_path": ("__pregel_pull", "agent"), + "langgraph_checkpoint_ns": AnyStr("agent:"), + "checkpoint_ns": AnyStr("agent:"), + "ls_provider": "fakechatmodel", + "ls_model_type": "chat", + }, + ), + ] + assert [ c async for c in app.astream(