Skip to content

Commit

Permalink
Observer: Create callbacks for agent, router & tools (#54)
Browse files Browse the repository at this point in the history
* Setting up callbacks

* Adding callbacks to models

* Renaming to model_name

* Completing callbacks for agent start and ends

* Fixing precommit hook issues

* Updating poetry lock

* Removing line

* Fix for properly registering the agent models
  • Loading branch information
vizsatiz authored Nov 15, 2024
1 parent f581f3f commit 26e90c8
Show file tree
Hide file tree
Showing 19 changed files with 722 additions and 317 deletions.
9 changes: 9 additions & 0 deletions examples/llm_extensibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from flo_ai.tools.flo_tool import flotool
from flo_ai.state.flo_callbacks import flo_agent_callback, FloCallbackResponse

from dotenv import load_dotenv
import warnings
Expand Down Expand Up @@ -38,7 +39,15 @@ def email_tool(to: str, message: str):
return f'Email sent successfully to: {to}'


@flo_agent_callback
def agent_callback(response: FloCallbackResponse):
print('------------- START AGENT CALLBACK -----------')
print(response)
print('------------- END AGENT CALLBACK -----------')


session.register_tool('SendEmailTool', email_tool)
session.register_callback(agent_callback)

agent_yaml = """
apiVersion: flo/alpha-v1
Expand Down
4 changes: 2 additions & 2 deletions examples/simple_blogging_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini')
session = (
FloSession(llm, log_level='INFO')
FloSession(llm)
.register_tool(name='TavilySearchResults', tool=TavilySearchResults())
.register_tool(
name='DummyTool',
Expand All @@ -43,5 +43,5 @@

Flo.set_log_level('INFO')
flo: Flo = Flo.build(session, yaml=yaml_data)
# data = flo.invoke(input_prompt)
data = flo.invoke(input_prompt)
# print((data['messages'][-1]).content)
10 changes: 9 additions & 1 deletion flo_ai/common/flo_langchain_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from flo_ai.common.flo_logger import get_logger
from flo_ai.state.flo_callbacks import FloToolCallback


class FloLangchainLogger(BaseCallbackHandler):
def __init__(self, session_id: str):
def __init__(self, session_id: str, tool_callbacks: List[FloToolCallback] = []):
self.session_id = session_id
self.tool_callbacks = tool_callbacks

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
Expand Down Expand Up @@ -41,14 +43,20 @@ def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
get_logger().debug(f'onToolStart: {input_str}', self)
[
x.on_tool_start(serialized['name'], kwargs['inputs'], kwargs)
for x in self.tool_callbacks
]

def on_tool_end(self, output: str, **kwargs: Any) -> None:
get_logger().debug(f'onToolEnd: {output}', self)
[x.on_tool_end(kwargs['name'], output, kwargs) for x in self.tool_callbacks]

def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
get_logger().debug(f'onToolError: {error}', self)
[x.on_tool_error(kwargs['name'], error, kwargs) for x in self.tool_callbacks]

def on_text(self, text: str, **kwargs: Any) -> None:
get_logger().debug(f'onText: {text}', self)
Expand Down
8 changes: 6 additions & 2 deletions flo_ai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
set_logger_internal,
FloLogConfig,
)
from langchain.tools import StructuredTool


class Flo:
Expand Down Expand Up @@ -90,11 +91,14 @@ def draw_to_file(self, filename: str, xray=True):
def validate_invoke(self, session: FloSession):
async_coroutines = filter(
lambda x: (
hasattr(x, 'coroutine') and asyncio.iscoroutinefunction(x.coroutine)
isinstance(x, StructuredTool)
and hasattr(x, 'coroutine')
and asyncio.iscoroutinefunction(x.coroutine)
),
session.tools.values(),
)
if len(list(async_coroutines)) > 0:
async_tools = list(async_coroutines)
if len(async_tools) > 0:
raise FloException(
f"""You seem to have atleast one async tool registered in this session. Please use flo.async_invoke or flo.async_stream. Checkout {DOCUMENTATION_WEBSITE}"""
)
15 changes: 12 additions & 3 deletions flo_ai/factory/agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,30 @@ def __create_agentic_agent(
agent_model = AgentFactory.__resolve_model(session, agent.model)
tools = [tool_map[tool.name] for tool in agent.tools]
flo_agent: FloAgent = FloAgent.Builder(
session, agent, tools, llm=agent_model, on_error=session.on_agent_error
session,
agent,
tools,
llm=agent_model,
on_error=session.on_agent_error,
model_name=agent.model,
).build()
return flo_agent

@staticmethod
def __create_llm_agent(session: FloSession, agent: AgentConfig) -> FloLLMAgent:
agent_model = AgentFactory.__resolve_model(session, agent.model)
builder = FloLLMAgent.Builder(session, agent, llm=agent_model)
builder = FloLLMAgent.Builder(
session, agent, llm=agent_model, model_name=agent.model
)
llm_agent: FloLLMAgent = builder.build()
return llm_agent

@staticmethod
def __create_runnable_agent(session: FloSession, agent: AgentConfig) -> FloLLMAgent:
runnable = session.tools[agent.tools[0].name]
return FloToolAgent.Builder(session, agent, runnable).build()
return FloToolAgent.Builder(
session, agent, runnable, model_name=agent.model
).build()

@staticmethod
def __create_reflection_agent(
Expand Down
13 changes: 11 additions & 2 deletions flo_ai/models/flo_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@

class FloAgent(ExecutableFlo):
def __init__(
self, agent: Runnable, executor: AgentExecutor, config: AgentConfig
self,
agent: Runnable,
executor: AgentExecutor,
config: AgentConfig,
model_nick_name: str,
) -> None:
super().__init__(config.name, executor, ExecutableType.agentic)
self.model_name = model_nick_name
self.agent: Runnable = (agent,)
self.executor: AgentExecutor = executor
self.config: AgentConfig = config
Expand All @@ -30,9 +35,11 @@ def __init__(
role: Optional[str] = None,
llm: Union[BaseLanguageModel, None] = None,
on_error: Union[str, Callable] = True,
model_name: Union[str, None] = 'default',
) -> None:
prompt: Union[ChatPromptTemplate, str] = config.job
self.name: str = config.name
self.model_name = model_name
self.llm = llm if llm is not None else session.llm
self.config = config
system_prompts = (
Expand Down Expand Up @@ -60,4 +67,6 @@ def build(self) -> AgentExecutor:
return_intermediate_steps=True,
handle_parsing_errors=self.on_error,
)
return FloAgent(agent, executor, self.config)
return FloAgent(
agent, executor, self.config, model_nick_name=self.model_name
)
11 changes: 9 additions & 2 deletions flo_ai/models/flo_delegation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,29 @@


class FloDelegatorAgent(ExecutableFlo):
def __init__(self, executor: Runnable, config: AgentConfig) -> None:
def __init__(
self, executor: Runnable, config: AgentConfig, model_name: str
) -> None:
super().__init__(config.name, executor, ExecutableType.delegator)
self.executor: Runnable = executor
self.config: AgentConfig = config
self.model_name = model_name

class Builder:
def __init__(
self,
session: FloSession,
agentConfig: AgentConfig,
llm: Optional[BaseLanguageModel] = None,
model_name: str = None,
) -> None:
self.config = agentConfig
delegator_base_system_message = (
'You are a delegator tasked with routing a conversation between the'
' following {member_type}: {members}. Given the following rules,'
' respond with the worker to act next '
)
self.model_name = model_name
self.llm = session.llm if llm is None else llm
self.options = [x.name for x in agentConfig.to]
self.llm_router_prompt = ChatPromptTemplate.from_messages(
Expand Down Expand Up @@ -75,4 +80,6 @@ def build(self):
| JsonOutputFunctionsParser()
)

return FloDelegatorAgent(executor=chain, config=self.config)
return FloDelegatorAgent(
executor=chain, config=self.config, model_name=self.model_name
)
9 changes: 7 additions & 2 deletions flo_ai/models/flo_llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,23 @@


class FloLLMAgent(ExecutableFlo):
def __init__(self, executor: Runnable, config: AgentConfig) -> None:
def __init__(
self, executor: Runnable, config: AgentConfig, model_name: str
) -> None:
super().__init__(config.name, executor, ExecutableType.llm)
self.executor: Runnable = executor
self.config: AgentConfig = config
self.model_name: str = model_name

class Builder:
def __init__(
self,
session: FloSession,
config: AgentConfig,
llm: Union[BaseLanguageModel, None] = None,
model_name: str = None,
) -> None:
self.model_name = model_name
prompt: Union[ChatPromptTemplate, str] = config.job

self.name: str = config.name
Expand All @@ -42,4 +47,4 @@ def __init__(

def build(self) -> Runnable:
executor = self.prompt | self.llm | StrOutputParser()
return FloLLMAgent(executor, self.config)
return FloLLMAgent(executor, self.config, self.model_name)
98 changes: 92 additions & 6 deletions flo_ai/models/flo_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from langchain_core.messages import HumanMessage
from flo_ai.yaml.config import AgentConfig, TeamConfig
from flo_ai.models.flo_executable import ExecutableType
from typing import Union
from flo_ai.state.flo_session import FloSession
from typing import Union, Type, List
from flo_ai.state.flo_callbacks import FloAgentCallback, FloRouterCallback, FloCallback


class FloNode:
Expand All @@ -23,19 +25,26 @@ def __init__(
self.config: Union[AgentConfig | TeamConfig] = config

class Builder:
def __init__(self, session: FloSession) -> None:
self.session = session

def build_from_agent(self, flo_agent: FloAgent) -> 'FloNode':
agent_func = functools.partial(
FloNode.Builder.__teamflo_agent_node,
agent=flo_agent.runnable,
name=flo_agent.name,
agent_config=flo_agent.config,
session=self.session,
model_name=flo_agent.model_name,
)
return FloNode(agent_func, flo_agent.name, flo_agent.type, flo_agent.config)

def build_from_team(self, flo_team: FloRoutedTeam) -> 'FloNode':
team_chain = (
functools.partial(
FloNode.Builder.__teamflo_team_node, members=flo_team.runnable.nodes
FloNode.Builder.__teamflo_team_node,
members=flo_team.runnable.nodes,
session=self.session,
)
| flo_team.runnable
)
Expand All @@ -56,6 +65,8 @@ def build_from_router(self, flo_router) -> 'FloNode':
agent=flo_router.executor,
name=flo_router.router_name,
agent_config=flo_router.config,
session=self.session,
model_name=flo_router.model_name,
)
return FloNode(
router_func, flo_router.router_name, flo_router.type, flo_router.config
Expand All @@ -67,20 +78,95 @@ def __teamflo_agent_node(
agent: AgentExecutor,
name: str,
agent_config: AgentConfig,
session: FloSession,
model_name: str,
):
result = agent.invoke(state)
output = result if isinstance(result, str) else result['output']
agent_cbs: List[FloAgentCallback] = FloNode.Builder.__filter_callbacks(
session, FloAgentCallback
)
flo_cbs: List[FloCallback] = FloNode.Builder.__filter_callbacks(
session, FloCallback
)
[
callback.on_agent_start(name, model_name, state['messages'], **{})
for callback in agent_cbs
]
[
callback.on_agent_start(name, model_name, state['messages'], **{})
for callback in flo_cbs
]
try:
result = agent.invoke(state)
output = result if isinstance(result, str) else result['output']
except Exception as e:
[
callback.on_agent_error(name, model_name, e, **{})
for callback in agent_cbs
]
[
callback.on_agent_error(name, model_name, e, **{})
for callback in flo_cbs
]
raise e
[
callback.on_agent_end(name, model_name, output, **{})
for callback in agent_cbs
]
[
callback.on_agent_start(name, model_name, output, **{})
for callback in flo_cbs
]
return {STATE_NAME_MESSAGES: [HumanMessage(content=output, name=name)]}

@staticmethod
def __filter_callbacks(session: FloSession, type: Type):
cbs = session.callbacks
return list(filter(lambda callback: isinstance(callback, type), cbs))

@staticmethod
def __teamflo_router_node(
state: TeamFloAgentState,
agent: AgentExecutor,
name: str,
agent_config: AgentConfig,
session: FloSession,
model_name: str,
):
result = agent.invoke(state)
nextNode = result if isinstance(result, str) else result['next']
agent_cbs: List[FloRouterCallback] = FloNode.Builder.__filter_callbacks(
session, FloRouterCallback
)
flo_cbs: List[FloCallback] = FloNode.Builder.__filter_callbacks(
session, FloCallback
)
[
callback.on_router_start(name, model_name, state['messages'], **{})
for callback in agent_cbs
]
[
callback.on_router_start(name, model_name, state['messages'], **{})
for callback in flo_cbs
]
try:
result = agent.invoke(state)
nextNode = result if isinstance(result, str) else result['next']
except Exception as e:
[
callback.on_router_error(name, model_name, e, **{})
for callback in agent_cbs
]
[
callback.on_router_error(name, model_name, e, **{})
for callback in flo_cbs
]
raise e
[
callback.on_router_end(name, model_name, nextNode, **{})
for callback in agent_cbs
]
[
callback.on_router_start(name, model_name, nextNode, **{})
for callback in flo_cbs
]
return {'next': nextNode}

@staticmethod
Expand Down
Loading

0 comments on commit 26e90c8

Please sign in to comment.