Skip to content

Commit

Permalink
Resolve Langfuse issues
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Dec 20, 2024
1 parent 64f2b54 commit 630fa65
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 32 deletions.
25 changes: 19 additions & 6 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ def get_completion(
# Chain start
if self.callback_handler:
self.callback_handler.on_chain_start(
serialized={"name": f"Thread.get_completion -> {recipient_agent.name}"},
serialized={
"name": f"Thread.get_completion -> {recipient_agent.name}",
"id": [self._run.id],
},
inputs={"message": message},
run_id=chain_run_id,
parent_run_id=parent_run_id,
Expand All @@ -195,19 +198,29 @@ def get_completion(

# chat model start callback
if self.callback_handler:
chat_messages = (
[[HumanMessage(content=message)]] if isinstance(message, str) else []
)
if chat_messages:
chat_messages = []
if isinstance(message, str):
chat_messages = [[HumanMessage(content=message)]]

kwargs = {
"invocation_params": {
"_type": "openai",
"model": self._run.model,
"temperature": self._run.temperature,
},
"name": recipient_agent.name if recipient_agent else "Unknown",
}

self.callback_handler.on_chat_model_start(
serialized={"name": self._run.model},
serialized={"name": kwargs["name"], "id": [self._run.id]},
messages=chat_messages,
run_id=self._run.id,
parent_run_id=chain_run_id,
metadata={
"agent_name": self.agent.name,
"recipient_agent_name": recipient_agent.name,
},
**kwargs,
)

try:
Expand Down
4 changes: 2 additions & 2 deletions agency_swarm/util/tracking/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import threading
from typing import Callable, Literal

from .langchain_types import use_langchain_types

_callback_handler = None
_lock = threading.Lock()

Expand All @@ -25,8 +27,6 @@ def init_tracking(tracker_name: SUPPORTED_TRACKERS_TYPE, **kwargs):
if tracker_name not in SUPPORTED_TRACKERS:
raise ValueError(f"Invalid tracker name: {tracker_name}")

from .langchain_types import use_langchain_types

use_langchain_types()

if tracker_name == "local":
Expand Down
63 changes: 41 additions & 22 deletions agency_swarm/util/tracking/langchain_types.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,61 @@
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Generic, TypeVar, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field

if TYPE_CHECKING:
from langchain.schema import AgentAction as LangchainAgentAction
from langchain.schema import AgentFinish as LangchainAgentFinish
from langchain.schema import HumanMessage as LangchainHumanMessage
from langchain_core.schema import AgentAction as LangchainAgentAction
from langchain_core.schema import AgentFinish as LangchainAgentFinish
from langchain_core.schema import HumanMessage as LangchainHumanMessage


# Create base classes that match langchain's structure
class BaseAgentAction(BaseModel):
tool: str
tool_input: Dict[str, Any] | str
log: str
tool_input: Union[str, Dict[str, Any]] = Field(default_factory=dict)
log: str = ""


class BaseAgentFinish(BaseModel):
return_values: Dict[str, Any]
log: str
return_values: Dict[str, Any] = Field(default_factory=dict)
log: str = ""


class BaseHumanMessage(BaseModel):
content: str
content: str = ""


# Initialize with our base implementations first
AgentAction = BaseAgentAction
AgentFinish = BaseAgentFinish
HumanMessage = BaseHumanMessage
T = TypeVar("T")


def use_langchain_types():
class Proxy(Generic[T]):
def __init__(self, default_impl: T):
self._impl: T = default_impl

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._impl(*args, **kwargs)

def set_implementation(self, impl: T) -> None:
self._impl = impl


# Initialize with our base implementations
AgentAction = Proxy[Union[BaseAgentAction, "LangchainAgentAction"]](BaseAgentAction)
AgentFinish = Proxy[Union[BaseAgentFinish, "LangchainAgentFinish"]](BaseAgentFinish)
HumanMessage = Proxy[Union[BaseHumanMessage, "LangchainHumanMessage"]](BaseHumanMessage)


def use_langchain_types() -> None:
"""Switch to using langchain types after langchain is imported"""
global AgentAction, AgentFinish, HumanMessage
from langchain.schema import AgentAction as LangchainAgentAction
from langchain.schema import AgentFinish as LangchainAgentFinish
from langchain.schema import HumanMessage as LangchainHumanMessage

AgentAction = LangchainAgentAction
AgentFinish = LangchainAgentFinish
HumanMessage = LangchainHumanMessage
from langchain_core.schema import AgentAction as LangchainAgentAction
from langchain_core.schema import AgentFinish as LangchainAgentFinish
from langchain_core.schema import HumanMessage as LangchainHumanMessage

# Call model_rebuild on these imported classes to resolve forward references
LangchainAgentAction.model_rebuild()
LangchainAgentFinish.model_rebuild()
LangchainHumanMessage.model_rebuild()

AgentAction.set_implementation(LangchainAgentAction)
AgentFinish.set_implementation(LangchainAgentFinish)
HumanMessage.set_implementation(LangchainHumanMessage)
7 changes: 5 additions & 2 deletions tests/demos/demo_observability.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import logging

from dotenv import load_dotenv

from agency_swarm import Agency, Agent
from agency_swarm.util import init_tracking

load_dotenv()
logging.basicConfig(level=logging.INFO)


def main():
# Set the tracker type
# TRACKER = "local"
TRACKER = "langfuse"
TRACKER = "local"
# TRACKER = "langfuse"

# Initialize tracking based on the selected tracker
init_tracking(TRACKER)
Expand Down

0 comments on commit 630fa65

Please sign in to comment.