From 630fa657b113df6ed9458046723296df6a7954c6 Mon Sep 17 00:00:00 2001 From: Nick Bobrowski <39348559+bonk1t@users.noreply.github.com> Date: Fri, 20 Dec 2024 02:29:28 +0000 Subject: [PATCH] Resolve Langfuse issues --- agency_swarm/threads/thread.py | 25 ++++++-- agency_swarm/util/tracking/__init__.py | 4 +- agency_swarm/util/tracking/langchain_types.py | 63 ++++++++++++------- tests/demos/demo_observability.py | 7 ++- 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/agency_swarm/threads/thread.py b/agency_swarm/threads/thread.py index 0ca27891..5e4b3d5c 100644 --- a/agency_swarm/threads/thread.py +++ b/agency_swarm/threads/thread.py @@ -182,7 +182,10 @@ def get_completion( # Chain start if self.callback_handler: self.callback_handler.on_chain_start( - serialized={"name": f"Thread.get_completion -> {recipient_agent.name}"}, + serialized={ + "name": f"Thread.get_completion -> {recipient_agent.name}", + "id": [self._run.id], + }, inputs={"message": message}, run_id=chain_run_id, parent_run_id=parent_run_id, @@ -195,12 +198,21 @@ def get_completion( # chat model start callback if self.callback_handler: - chat_messages = ( - [[HumanMessage(content=message)]] if isinstance(message, str) else [] - ) - if chat_messages: + chat_messages = [] + if isinstance(message, str): + chat_messages = [[HumanMessage(content=message)]] + + kwargs = { + "invocation_params": { + "_type": "openai", + "model": self._run.model, + "temperature": self._run.temperature, + }, + "name": recipient_agent.name if recipient_agent else "Unknown", + } + self.callback_handler.on_chat_model_start( - serialized={"name": self._run.model}, + serialized={"name": kwargs["name"], "id": [self._run.id]}, messages=chat_messages, run_id=self._run.id, parent_run_id=chain_run_id, @@ -208,6 +220,7 @@ def get_completion( "agent_name": self.agent.name, "recipient_agent_name": recipient_agent.name, }, + **kwargs, ) try: diff --git a/agency_swarm/util/tracking/__init__.py b/agency_swarm/util/tracking/__init__.py index 0a995dbb..f61f570d 100644 --- a/agency_swarm/util/tracking/__init__.py +++ b/agency_swarm/util/tracking/__init__.py @@ -1,6 +1,8 @@ import threading from typing import Callable, Literal +from .langchain_types import use_langchain_types + _callback_handler = None _lock = threading.Lock() @@ -25,8 +27,6 @@ def init_tracking(tracker_name: SUPPORTED_TRACKERS_TYPE, **kwargs): if tracker_name not in SUPPORTED_TRACKERS: raise ValueError(f"Invalid tracker name: {tracker_name}") - from .langchain_types import use_langchain_types - use_langchain_types() if tracker_name == "local": diff --git a/agency_swarm/util/tracking/langchain_types.py b/agency_swarm/util/tracking/langchain_types.py index 0a8f4bfd..fe268e8e 100644 --- a/agency_swarm/util/tracking/langchain_types.py +++ b/agency_swarm/util/tracking/langchain_types.py @@ -1,42 +1,61 @@ -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Generic, TypeVar, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field if TYPE_CHECKING: - from langchain.schema import AgentAction as LangchainAgentAction - from langchain.schema import AgentFinish as LangchainAgentFinish - from langchain.schema import HumanMessage as LangchainHumanMessage + from langchain_core.schema import AgentAction as LangchainAgentAction + from langchain_core.schema import AgentFinish as LangchainAgentFinish + from langchain_core.schema import HumanMessage as LangchainHumanMessage # Create base classes that match langchain's structure class BaseAgentAction(BaseModel): tool: str - tool_input: Dict[str, Any] | str - log: str + tool_input: Union[str, Dict[str, Any]] = Field(default_factory=dict) + log: str = "" class BaseAgentFinish(BaseModel): - return_values: Dict[str, Any] - log: str + return_values: Dict[str, Any] = Field(default_factory=dict) + log: str = "" class BaseHumanMessage(BaseModel): - content: str + content: str = "" -# Initialize with our base implementations first -AgentAction = BaseAgentAction -AgentFinish = BaseAgentFinish -HumanMessage = BaseHumanMessage +T = TypeVar("T") -def use_langchain_types(): +class Proxy(Generic[T]): + def __init__(self, default_impl: T): + self._impl: T = default_impl + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self._impl(*args, **kwargs) + + def set_implementation(self, impl: T) -> None: + self._impl = impl + + +# Initialize with our base implementations +AgentAction = Proxy[Union[BaseAgentAction, "LangchainAgentAction"]](BaseAgentAction) +AgentFinish = Proxy[Union[BaseAgentFinish, "LangchainAgentFinish"]](BaseAgentFinish) +HumanMessage = Proxy[Union[BaseHumanMessage, "LangchainHumanMessage"]](BaseHumanMessage) + + +def use_langchain_types() -> None: """Switch to using langchain types after langchain is imported""" global AgentAction, AgentFinish, HumanMessage - from langchain.schema import AgentAction as LangchainAgentAction - from langchain.schema import AgentFinish as LangchainAgentFinish - from langchain.schema import HumanMessage as LangchainHumanMessage - - AgentAction = LangchainAgentAction - AgentFinish = LangchainAgentFinish - HumanMessage = LangchainHumanMessage + from langchain_core.schema import AgentAction as LangchainAgentAction + from langchain_core.schema import AgentFinish as LangchainAgentFinish + from langchain_core.schema import HumanMessage as LangchainHumanMessage + + # Call model_rebuild on these imported classes to resolve forward references + LangchainAgentAction.model_rebuild() + LangchainAgentFinish.model_rebuild() + LangchainHumanMessage.model_rebuild() + + AgentAction.set_implementation(LangchainAgentAction) + AgentFinish.set_implementation(LangchainAgentFinish) + HumanMessage.set_implementation(LangchainHumanMessage) diff --git a/tests/demos/demo_observability.py b/tests/demos/demo_observability.py index fc455118..d830b33c 100644 --- a/tests/demos/demo_observability.py +++ b/tests/demos/demo_observability.py @@ -1,15 +1,18 @@ +import logging + from dotenv import load_dotenv from agency_swarm import Agency, Agent from agency_swarm.util import init_tracking load_dotenv() +logging.basicConfig(level=logging.INFO) def main(): # Set the tracker type - # TRACKER = "local" - TRACKER = "langfuse" + TRACKER = "local" + # TRACKER = "langfuse" # Initialize tracking based on the selected tracker init_tracking(TRACKER)