Skip to content

Commit

Permalink
Merge pull request #2001 from solliancenet/cp-langraph-react-workflow
Browse files Browse the repository at this point in the history
Introduction of the LangGraph ReAct workflow
  • Loading branch information
ciprianjichici authored Nov 27, 2024
2 parents cd554f3 + 7eefe17 commit 78d4796
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace FoundationaLLM.Common.Models.ResourceProviders.Agent.AgentWorkflows
[JsonPolymorphic(TypeDiscriminatorPropertyName = "type")]
[JsonDerivedType(typeof(AzureOpenAIAssistantsAgentWorkflow), AgentWorkflowTypes.AzureOpenAIAssistants)]
[JsonDerivedType(typeof(LangChainExpressionLanguageAgentWorkflow), AgentWorkflowTypes.LangChainExpressionLanguage)]
[JsonDerivedType(typeof(LangGraphReactAgentWorkflow), AgentWorkflowTypes.LangGraphReactAgent)]
public class AgentWorkflowBase
{
/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,10 @@ public static class AgentWorkflowTypes
/// The LangChain Expression Language agent workflow.
/// </summary>
public const string LangChainExpressionLanguage = "langchain-expression-language-workflow";

/// <summary>
/// The LangGraph ReAct agent workflow.
/// </summary>
public const string LangGraphReactAgent = "langgraph-react-agent-workflow";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System.Text.Json.Serialization;

namespace FoundationaLLM.Common.Models.ResourceProviders.Agent.AgentWorkflows
{
/// <summary>
/// Provides an agent workflow configuration for a LangGraph ReAct Agent workflow.
/// </summary>
public class LangGraphReactAgentWorkflow : AgentWorkflowBase
{
/// <inheritdoc/>
[JsonIgnore]
public override string Type => AgentWorkflowTypes.LangGraphReactAgent;
}
}
1 change: 1 addition & 0 deletions src/python/PythonSDK/PythonSDK.pyproj
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
<Compile Include="foundationallm\models\agents\agent_workflows\agent_workflow_base.py" />
<Compile Include="foundationallm\models\agents\agent_workflows\azure_openai_assistants_agent_workflow.py" />
<Compile Include="foundationallm\models\agents\agent_workflows\langchain_expression_language_agent_workflow.py" />
<Compile Include="foundationallm\models\agents\agent_workflows\langgraph_react_agent_workflow.py" />
<Compile Include="foundationallm\models\attachments\attachment_properties.py" />
<Compile Include="foundationallm\models\attachments\attachment_providers.py" />
<Compile Include="foundationallm\models\attachments\__init__.py" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from langchain_core.language_models import BaseLanguageModel
from langchain_aws import ChatBedrockConverse
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI
from openai import AsyncAzureOpenAI as async_aoi
from foundationallm.config import Configuration, UserIdentity
Expand Down Expand Up @@ -173,6 +174,31 @@ def _build_conversation_history(self, messages:List[MessageHistoryItem]=None, me
chat_history += "\n\n"
return chat_history

def _build_conversation_history_message_list(self, messages:List[MessageHistoryItem]=None, message_count:int=None) -> List[BaseMessage]:
"""
Builds a LangChain Message chat history list from a list of MessageHistoryItem objects to
be added to the prompt template for the completion request.
Parameters
----------
messages : List[MessageHistoryItem]
The list of messages from which to build the chat history.
message_count : int
The number of messages to include in the chat history.
"""
if messages is None or len(messages)==0:
return []
if message_count is not None:
messages = messages[-message_count:]
history = []
for msg in messages:
# sender can be User (maps to HumanMessage) or Agent (maps to AIMessage)
if msg.sender == "User":
history.append(HumanMessage(content=msg.text))
else:
history.append(AIMessage(content=msg.text))
return history

def _record_full_prompt(self, prompt: str) -> str:
"""
Records the full prompt for the completion request.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import uuid
from langchain_community.callbacks import get_openai_callback
from langchain_community.callbacks.manager import get_bedrock_anthropic_callback
from langchain_core.messages import HumanMessage
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langgraph.prebuilt import create_react_agent
from foundationallm.langchain.agents import LangChainAgentBase
from foundationallm.langchain.exceptions import LangChainException
from foundationallm.langchain.retrievers import RetrieverFactory, CitationRetrievalBase
from foundationallm.models.agents.agent_workflows.azure_openai_assistants_agent_workflow import AzureOpenAIAssistantsAgentWorkflow
from foundationallm.models.agents import AzureOpenAIAssistantsAgentWorkflow, LangGraphReactAgentWorkflow
from foundationallm.models.constants import AgentCapabilityCategories
from foundationallm.models.operations import OperationTypes
from foundationallm.models.orchestration import (
CompletionRequestObjectKeys,
CompletionResponse,
OpenAITextMessageContentItem
)
from foundationallm.models.resource_providers.ai_models import CompletionAIModel
from foundationallm.models.resource_providers.configuration import APIEndpointConfiguration
from foundationallm.models.agents import (
AgentConversationHistorySettings,
Expand Down Expand Up @@ -44,7 +44,10 @@ class LangChainKnowledgeManagementAgent(LangChainAgentBase):
"""
The LangChain Knowledge Management agent.
"""


MAIN_MODEL_KEY = "main_model"
MAIN_PROMPT_KEY = "main_prompt"

def _get_document_retriever(
self,
request: KnowledgeManagementCompletionRequest,
Expand Down Expand Up @@ -174,12 +177,12 @@ def _validate_request(self, request: KnowledgeManagementCompletionRequest):
raise LangChainException("The objects property on the completion request cannot be null.", 400)

if request.agent.workflow is not None:
if request.agent.workflow.agent_workflow_ai_models["main_model"] is None:
if request.agent.workflow.agent_workflow_ai_models[self.MAIN_MODEL_KEY] is None:
raise LangChainException("The agent's workflow AI models requires a main_model.", 400)
if request.agent.workflow.prompt_object_ids["main_prompt"] is None:
if request.agent.workflow.prompt_object_ids[self.MAIN_PROMPT_KEY] is None:
raise LangChainException("The agent's workflow prompt object dictionary requires a main_prompt.", 400)
self.ai_model = self._get_ai_model_from_object_id(request.agent.workflow.agent_workflow_ai_models["main_model"].ai_model_object_id, request.objects)
self.prompt = self._get_prompt_from_object_id(request.agent.workflow.prompt_object_ids["main_prompt"], request.objects)
self.ai_model = self._get_ai_model_from_object_id(request.agent.workflow.agent_workflow_ai_models[self.MAIN_MODEL_KEY].ai_model_object_id, request.objects)
self.prompt = self._get_prompt_from_object_id(request.agent.workflow.prompt_object_ids[self.MAIN_PROMPT_KEY], request.objects)
else:
# Legacy code
self.ai_model = self._get_ai_model_from_object_id(request.agent.ai_model_object_id, request.objects)
Expand Down Expand Up @@ -275,6 +278,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
generated full prompt with context and token utilization and execution cost details.
"""
self._validate_request(request)
llm = self._get_language_model()

agent = request.agent
image_analysis_token_usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
Expand Down Expand Up @@ -373,7 +377,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
image_service = None
if "dalle-image-generation" in request.agent.tools:
dalle_tool = request.agent.tools["dalle-image-generation"]
model_object_id = dalle_tool["ai_model_object_ids"]["main_model"]
model_object_id = dalle_tool["ai_model_object_ids"][self.MAIN_MODEL_KEY]
image_generation_deployment_model = request.objects[model_object_id]["deployment_name"]
api_endpoint_object_id = request.objects[model_object_id]["endpoint_object_id"]
image_generation_client = self._get_image_gen_language_model(api_endpoint_object_id=api_endpoint_object_id, objects=request.objects)
Expand Down Expand Up @@ -414,6 +418,48 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
)
# End Assistants API implementation

# Start LangGraph ReAct Agent workflow implementation
if (agent.workflow is not None and isinstance(agent.workflow, LangGraphReactAgentWorkflow)):
# Temporary placeholder
from typing import Literal
from langchain_core.tools import tool
@tool
def get_weather(city: Literal["nyc", "sf"]):
"""Use this to get weather information."""
if city == "nyc":
return "It might be cloudy in nyc"
elif city == "sf":
return "It's always sunny in sf"
else:
raise AssertionError("Unknown city")
tools = [get_weather]
# End temporary placeholder

# Define the graph
graph = create_react_agent(llm, tools=tools, state_modifier=self.prompt.prefix)
messages = self._build_conversation_history_message_list(request.message_history, agent.conversation_history_settings.max_history)
messages.append(HumanMessage(content=request.user_prompt))
response = await graph.ainvoke({'messages': messages})
# TODO: process tool messages with analysis results AIMessage with content='' but has addition_kwargs={'tool_calls';[...]}
# print(response)
final_message = response["messages"][-1]
response_content = OpenAITextMessageContentItem(
value = final_message.content,
agent_capability_category = AgentCapabilityCategories.FOUNDATIONALLM_KNOWLEDGE_MANAGEMENT
)
return CompletionResponse(
operation_id = request.operation_id,
content = [response_content],
citations = [],
user_prompt = request.user_prompt,
full_prompt = self.prompt.prefix,
completion_tokens = final_message.usage_metadata["output_tokens"] or 0,
prompt_tokens = final_message.usage_metadata["input_tokens"] or 0,
total_tokens = final_message.usage_metadata["total_tokens"] or 0,
total_cost = 0
)
# End LangGraph ReAct Agent workflow implementation

# Start LangChain Expression Language (LCEL) implementation

# Get the vector document retriever, if it exists.
Expand Down Expand Up @@ -452,7 +498,7 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
chain_context
| prompt_template
| RunnableLambda(self._record_full_prompt)
| self._get_language_model()
| llm
)

retvalue = None
Expand Down
2 changes: 2 additions & 0 deletions src/python/PythonSDK/foundationallm/models/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .agent_conversation_history_settings import AgentConversationHistorySettings
from .agent_gatekeeper_settings import AgentGatekeeperSettings
from .agent_orchestration_settings import AgentOrchestrationSettings
from .agent_tool import AgentTool
from .agent_vectorization_settings import AgentVectorizationSettings
from .agent_workflows.agent_workflow_ai_model import AgentWorkflowAIModel
from .agent_workflows.agent_workflow_base import AgentWorkflowBase
from .agent_workflows.azure_openai_assistants_agent_workflow import AzureOpenAIAssistantsAgentWorkflow
from .agent_workflows.langchain_expression_language_agent_workflow import LangChainExpressionLanguageAgentWorkflow
from .agent_workflows.langgraph_react_agent_workflow import LangGraphReactAgentWorkflow
from .agent_base import AgentBase
from .knowledge_management_agent import KnowledgeManagementAgent
from .knowledge_management_completion_request import KnowledgeManagementCompletionRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
AgentGatekeeperSettings,
AgentOrchestrationSettings,
AzureOpenAIAssistantsAgentWorkflow,
LangChainExpressionLanguageAgentWorkflow
LangChainExpressionLanguageAgentWorkflow,
LangGraphReactAgentWorkflow
)
from foundationallm.models.resource_providers import ResourceBase

Expand All @@ -24,7 +25,7 @@ class AgentBase(ResourceBase):
tools: Optional[dict] = Field(default=[], description="A dictionary object with assigned agent tools.")
workflow: Optional[
Annotated [
Union[AzureOpenAIAssistantsAgentWorkflow, LangChainExpressionLanguageAgentWorkflow],
Union[AzureOpenAIAssistantsAgentWorkflow, LangChainExpressionLanguageAgentWorkflow, LangGraphReactAgentWorkflow],
Field(discriminator='type')
]
]= Field(default=None, description="The workflow configuration for the agent.")
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pydantic import Field
from typing import Any, Self, Literal
from foundationallm.langchain.exceptions import LangChainException
from foundationallm.utils import object_utils
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any, Self, Literal
from foundationallm.langchain.exceptions import LangChainException
from foundationallm.utils import object_utils
from .agent_workflow_base import AgentWorkflowBase

class LangGraphReactAgentWorkflow(AgentWorkflowBase):
"""
The configuration for a LangGraph ReAct agent workflow.
"""
type: Literal["langgraph-react-agent-workflow"] = "langgraph-react-agent-workflow"

@staticmethod
def from_object(obj: Any) -> Self:

workflow: LangGraphReactAgentWorkflow = None

try:
workflow = LangGraphReactAgentWorkflow(**object_utils.translate_keys(obj))
except Exception as e:
raise LangChainException(f"The LangGraph ReAct Agent Workflow object provided is invalid. {str(e)}", 400)

if workflow is None:
raise LangChainException("The LangGraph ReAct Agent Workflow object provided is invalid.", 400)

return workflow
1 change: 1 addition & 0 deletions src/python/PythonSDK/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ langchain==0.3.7
langchain-aws==0.2.7
langchain-experimental==0.3.3
langchain-openai==0.2.9
langgraph==0.2.53
openai==1.55.0
opentelemetry-api==1.27.0
opentelemetry-sdk==1.27.0
Expand Down

0 comments on commit 78d4796

Please sign in to comment.