Skip to content

Commit

Permalink
[WIP] add agentic RAG notebook (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored Jun 24, 2024
1 parent 94ac16e commit 7c7a699
Show file tree
Hide file tree
Showing 23 changed files with 542 additions and 70 deletions.
455 changes: 455 additions & 0 deletions example_scripts/agentic_rag_toolservice.ipynb

Large diffs are not rendered by default.

18 changes: 15 additions & 3 deletions llama_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,24 @@
from llama_agents.orchestrators import (
AgentOrchestrator,
PipelineOrchestrator,
ServiceComponent,
ServiceTool,
)
from llama_agents.tools import MetaServiceTool
from llama_agents.tools import MetaServiceTool, ServiceComponent, ServiceTool
from llama_agents.services import AgentService, ToolService, HumanService

# configure logger
import logging

root_logger = logging.getLogger("llama_agents")

formatter = logging.Formatter("%(levelname)s:%(name)s - %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
root_logger.addHandler(console_handler)

root_logger.setLevel(logging.WARNING)
root_logger.propagate = False


__all__ = [
# services
"AgentService",
Expand Down
9 changes: 3 additions & 6 deletions llama_agents/control_plane/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
import uvicorn
from fastapi import FastAPI
from logging import getLogger
from typing import Dict, List, Optional

from llama_index.core import StorageContext, VectorStoreIndex
Expand All @@ -16,19 +17,15 @@
from llama_agents.message_queues.base import BaseMessageQueue, PublishCallback
from llama_agents.messages.base import QueueMessage
from llama_agents.orchestrators.base import BaseOrchestrator
from llama_agents.orchestrators.service_tool import ServiceTool
from llama_agents.tools import ServiceTool
from llama_agents.types import (
ActionTypes,
ServiceDefinition,
TaskDefinition,
TaskResult,
)

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = getLogger(__name__)


class ControlPlaneServer(BaseControlPlane):
Expand Down
9 changes: 7 additions & 2 deletions llama_agents/launchers/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def signal_handler(sig: Any, frame: Any) -> None:
return signal_handler

async def alaunch_single(self, initial_task: str) -> str:
# clear any result
self.result = None

# register human consumer
human_consumer = HumanMessageConsumer(
message_handler={
Expand All @@ -98,7 +101,9 @@ async def alaunch_single(self, initial_task: str) -> str:
# start services
bg_tasks: List[asyncio.Task] = []
for service in self.services:
bg_tasks.append(asyncio.create_task(service.launch_local()))
if hasattr(service, "raise_exceptions"):
service.raise_exceptions = True # ensure exceptions are raised
bg_tasks.append(await service.launch_local())

# publish initial task
await self.publish(
Expand All @@ -109,7 +114,7 @@ async def alaunch_single(self, initial_task: str) -> str:
),
)
# runs until the message queue is stopped by the human consumer
mq_task = asyncio.create_task(self.message_queue.launch_local())
mq_task = await self.message_queue.launch_local()
shutdown_handler = self.get_shutdown_handler([mq_task] + bg_tasks)
loop = asyncio.get_event_loop()
while loop.is_running():
Expand Down
7 changes: 5 additions & 2 deletions llama_agents/message_queues/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from llama_agents.message_queues.base import BaseMessageQueue
from llama_agents.message_queues.simple import SimpleMessageQueue
from llama_agents.message_queues.simple import (
SimpleMessageQueue,
SimpleRemoteClientMessageQueue,
)

__all__ = ["BaseMessageQueue", "SimpleMessageQueue"]
__all__ = ["BaseMessageQueue", "SimpleMessageQueue", "SimpleRemoteClientMessageQueue"]
20 changes: 11 additions & 9 deletions llama_agents/message_queues/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Message queue module."""

import asyncio
import inspect

from abc import ABC, abstractmethod
from logging import getLogger
from pydantic import BaseModel
from typing import Any, List, Optional, Protocol, TYPE_CHECKING

Expand All @@ -10,11 +12,7 @@
if TYPE_CHECKING:
from llama_agents.message_consumers.base import BaseMessageQueueConsumer

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = getLogger(__name__)


class MessageProcessor(Protocol):
Expand Down Expand Up @@ -49,10 +47,14 @@ async def publish(
self,
message: QueueMessage,
callback: Optional[PublishCallback] = None,
**kwargs: Any
**kwargs: Any,
) -> Any:
"""Send message to a consumer."""
logger.info("Publishing message: " + str(message))
logger.info(
f"Publishing message to '{message.type}' with action '{message.action}'"
)
logger.debug(f"Message: {message.model_dump()}")

message.stats.publish_time = message.stats.timestamp_str()
await self._publish(message)

Expand Down Expand Up @@ -88,7 +90,7 @@ async def processing_loop(self) -> None:
...

@abstractmethod
async def launch_local(self) -> None:
async def launch_local(self) -> asyncio.Task:
"""Launch the service in-process."""
...

Expand Down
12 changes: 5 additions & 7 deletions llama_agents/message_queues/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import asyncio
import httpx
import random
import logging
import uvicorn

from collections import deque
from contextlib import asynccontextmanager
from fastapi import FastAPI
from logging import getLogger
from pydantic import Field, PrivateAttr
from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urljoin
Expand All @@ -22,9 +22,7 @@
)
from llama_agents.types import PydanticValidatedUrl

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)
logger = getLogger(__name__)


class SimpleRemoteClientMessageQueue(BaseMessageQueue):
Expand Down Expand Up @@ -101,7 +99,7 @@ async def processing_loop(self) -> None:
"`procesing_loop()` is not implemented for this class."
)

async def launch_local(self) -> None:
async def launch_local(self) -> asyncio.Task:
raise NotImplementedError("`launch_local()` is not implemented for this class.")

async def launch_server(self) -> None:
Expand Down Expand Up @@ -283,9 +281,9 @@ async def lifespan(self, app: FastAPI) -> AsyncGenerator[None, None]:
yield
self.running = False

async def launch_local(self) -> None:
async def launch_local(self) -> asyncio.Task:
logger.info("Launching message queue locally")
asyncio.create_task(self.processing_loop())
return asyncio.create_task(self.processing_loop())

async def launch_server(self) -> None:
logger.info(f"Launching message queue server at {self.host}:{self.port}")
Expand Down
4 changes: 0 additions & 4 deletions llama_agents/orchestrators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from llama_agents.orchestrators.agent import AgentOrchestrator
from llama_agents.orchestrators.base import BaseOrchestrator
from llama_agents.orchestrators.pipeline import PipelineOrchestrator
from llama_agents.orchestrators.service_component import ServiceComponent
from llama_agents.orchestrators.service_tool import ServiceTool

__all__ = [
"BaseOrchestrator",
"PipelineOrchestrator",
"ServiceComponent",
"ServiceTool",
"AgentOrchestrator",
]
2 changes: 1 addition & 1 deletion llama_agents/orchestrators/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from llama_agents.messages.base import QueueMessage
from llama_agents.orchestrators.base import BaseOrchestrator
from llama_agents.orchestrators.service_tool import ServiceTool
from llama_agents.tools.service_tool import ServiceTool
from llama_agents.types import ActionTypes, ChatMessage, TaskDefinition, TaskResult

HISTORY_KEY = "chat_history"
Expand Down
2 changes: 1 addition & 1 deletion llama_agents/orchestrators/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from llama_agents.messages.base import QueueMessage
from llama_agents.orchestrators.base import BaseOrchestrator
from llama_agents.orchestrators.service_component import ServiceComponent
from llama_agents.tools.service_component import ServiceComponent
from llama_agents.types import ActionTypes, TaskDefinition, TaskResult

RUN_STATE_KEY = "run_state"
Expand Down
21 changes: 14 additions & 7 deletions llama_agents/services/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI
from logging import getLogger
from pydantic import PrivateAttr
from typing import AsyncGenerator, Dict, List, Literal, Optional

Expand All @@ -25,11 +26,7 @@
CONTROL_PLANE_NAME,
)

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)
logger = getLogger(__name__)


class AgentService(BaseService):
Expand All @@ -41,6 +38,7 @@ class AgentService(BaseService):
step_interval: float = 0.1
host: Optional[str] = None
port: Optional[int] = None
raise_exceptions: bool = False

_message_queue: BaseMessageQueue = PrivateAttr()
_app: FastAPI = PrivateAttr()
Expand All @@ -59,6 +57,7 @@ def __init__(
step_interval: float = 0.1,
host: Optional[str] = None,
port: Optional[int] = None,
raise_exceptions: bool = False,
) -> None:
super().__init__(
agent=agent,
Expand All @@ -69,6 +68,7 @@ def __init__(
prompt=prompt,
host=host,
port=port,
raise_exceptions=raise_exceptions,
)

self._message_queue = message_queue
Expand Down Expand Up @@ -173,6 +173,13 @@ async def processing_loop(self) -> None:
)
except Exception as e:
logger.error(f"Error in {self.service_name} processing_loop: {e}")
if self.raise_exceptions:
# Kill everything
# TODO: is there a better way to do this?
import signal

signal.raise_signal(signal.SIGINT)

continue

await asyncio.sleep(self.step_interval)
Expand All @@ -197,9 +204,9 @@ def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
handler=self.process_message,
)

async def launch_local(self) -> None:
async def launch_local(self) -> asyncio.Task:
logger.info(f"{self.service_name} launch_local")
asyncio.create_task(self.processing_loop())
return asyncio.create_task(self.processing_loop())

# ---- Server based methods ----

Expand Down
3 changes: 2 additions & 1 deletion llama_agents/services/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import httpx
from abc import ABC, abstractmethod
from pydantic import BaseModel
Expand Down Expand Up @@ -46,7 +47,7 @@ async def publish(self, message: QueueMessage, **kwargs: Any) -> None:
)

@abstractmethod
async def launch_local(self) -> None:
async def launch_local(self) -> asyncio.Task:
"""Launch the service in-process."""
...

Expand Down
10 changes: 4 additions & 6 deletions llama_agents/services/human.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import logging
import uuid
import uvicorn
from asyncio import Lock
from fastapi import FastAPI
from logging import getLogger
from pydantic import PrivateAttr
from typing import Dict, List, Optional

Expand All @@ -27,9 +27,7 @@
)


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)
logger = getLogger(__name__)


HELP_REQUEST_TEMPLATE_STR = (
Expand Down Expand Up @@ -200,9 +198,9 @@ def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
handler=self.process_message,
)

async def launch_local(self) -> None:
async def launch_local(self) -> asyncio.Task:
logger.info(f"{self.service_name} launch_local")
asyncio.create_task(self.processing_loop())
return asyncio.create_task(self.processing_loop())

# ---- Server based methods ----

Expand Down
10 changes: 4 additions & 6 deletions llama_agents/services/tool.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import logging
import uuid
import uvicorn
from asyncio import Lock
from contextlib import asynccontextmanager
from fastapi import FastAPI
from pydantic import PrivateAttr
from logging import getLogger
from typing import Any, AsyncGenerator, Dict, List, Optional

from llama_index.core.agent.function_calling.step import (
Expand All @@ -29,9 +29,7 @@
ServiceDefinition,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)
logger = getLogger(__name__)


class ToolService(BaseService):
Expand Down Expand Up @@ -198,8 +196,8 @@ def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
handler=self.process_message,
)

async def launch_local(self) -> None:
asyncio.create_task(self.processing_loop())
async def launch_local(self) -> asyncio.Task:
return asyncio.create_task(self.processing_loop())

# ---- Server based methods ----

Expand Down
4 changes: 3 additions & 1 deletion llama_agents/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from llama_agents.tools.meta_service_tool import MetaServiceTool
from llama_agents.tools.service_tool import ServiceTool
from llama_agents.tools.service_component import ServiceComponent


__all__ = ["MetaServiceTool"]
__all__ = ["MetaServiceTool", "ServiceTool", "ServiceComponent"]
Loading

0 comments on commit 7c7a699

Please sign in to comment.