Skip to content

Commit

Permalink
Merge pull request #2011 from solliancenet/cp-dalle-image-gen-tool
Browse files Browse the repository at this point in the history
Initial implementation of the DALL-E Image Generation tool.
  • Loading branch information
ciprianjichici authored Nov 29, 2024
2 parents efcefb4 + a2d0412 commit c8ec3fc
Show file tree
Hide file tree
Showing 13 changed files with 252 additions and 39 deletions.
8 changes: 5 additions & 3 deletions docs/release-notes/breaking-changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ The following new App Configuration settings are required:

#### Agent Tool configuration changes

Agent tools are now an array of AgentTool objects rather than a dictionary.

When defining tools for an agent, each tool now requires a `package_name` property. This property is used to identify the package that contains the tool's implementation. If the tool is internal, the `package_name` should be set to `FoundationaLLM`, if the tool is external, the `package_name` should be set to the name of the external package.

#### Renamed classes

The following classes have been renamed:

Original Class | New Class
--- | ---
`FoundationaLLM.Common.Models.Orchestration.Response.Citation` | `FoundationaLLM.Common.Models.Orchestration.Response.ContentArtifact`
| Original Class | New Class |
| --- | --- |
| `FoundationaLLM.Common.Models.Orchestration.Response.Citation` | `FoundationaLLM.Common.Models.Orchestration.Response.ContentArtifact` |

## Starting with 0.8.4

Expand Down
4 changes: 2 additions & 2 deletions src/dotnet/Common/Models/ResourceProviders/Agent/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,13 @@ public class AgentBase : ResourceBase


/// <summary>
/// Gets or sets a dictionary of tools that are registered with the agent.
/// Gets or sets a list of tools that are registered with the agent.
/// </summary>
/// <remarks>
/// The key is the name of the tool, and the value is the <see cref="AgentTool"/> object.
/// </remarks>
[JsonPropertyName("tools")]
public Dictionary<string, AgentTool> Tools { get; set; } = [];
public AgentTool[] Tools { get; set; } = [];

/// <summary>
/// The object type of the agent.
Expand Down
10 changes: 5 additions & 5 deletions src/dotnet/Orchestration/Orchestration/OrchestrationBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,12 @@ await cosmosDBService.PatchOperationsItemPropertiesAsync<LongRunningOperationCon

List<string> toolNames = [];

foreach (var toolName in agentBase.Tools.Keys)
foreach (var tool in agentBase.Tools)
{
toolNames.Add(toolName);
explodedObjects[toolName] = agentBase.Tools[toolName];
toolNames.Add(tool.Name);
explodedObjects[tool.Name] = tool;

foreach (var aiModelObjectId in agentBase.Tools[toolName].AIModelObjectIds.Values)
foreach (var aiModelObjectId in tool.AIModelObjectIds.Values)
{
var toolAIModel = await aiModelResourceProvider.GetResourceAsync<AIModelBase>(
aiModelObjectId,
Expand All @@ -341,7 +341,7 @@ await cosmosDBService.PatchOperationsItemPropertiesAsync<LongRunningOperationCon
explodedObjects[toolAIModel.EndpointObjectId!] = toolAPIEndpointConfiguration;
}

foreach (var apiEndpointConfigurationObjectId in agentBase.Tools[toolName].APIEndpointConfigurationObjectIds.Values)
foreach (var apiEndpointConfigurationObjectId in tool.APIEndpointConfigurationObjectIds.Values)
{
var toolAPIEndpointConfiguration = await configurationResourceProvider.GetResourceAsync<APIEndpointConfiguration>(
apiEndpointConfigurationObjectId,
Expand Down
3 changes: 3 additions & 0 deletions src/python/PythonSDK/PythonSDK.pyproj
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@
<Compile Include="foundationallm\models\resource_providers\vectorization\profile_base.py" />
<Compile Include="foundationallm\models\resource_providers\vectorization\__init__.py" />
<Compile Include="foundationallm\models\resource_providers\__init__.py" />
<Compile Include="foundationallm\langchain\tools\dalle_image_generation_tool.py" />
<Compile Include="foundationallm\langchain\tools\fllm_tool.py" />
<Compile Include="foundationallm\langchain\tools\tool_factory.py" />
<Compile Include="foundationallm\utils\object_utils.py" />
<Compile Include="foundationallm\utils\openai_assistants_helpers.py" />
<Compile Include="foundationallm\utils\__init__.py" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
AgentConversationHistorySettings,
KnowledgeManagementAgent,
KnowledgeManagementCompletionRequest,
KnowledgeManagementIndexConfiguration
KnowledgeManagementIndexConfiguration,
AgentTool
)
from foundationallm.models.attachments import AttachmentProviders
from foundationallm.models.authentication import AuthenticationTypes
Expand All @@ -40,6 +41,8 @@
from foundationallm.services.gateway_text_embedding import GatewayTextEmbeddingService
from openai.types import CompletionUsage

from foundationallm.langchain.tools import ToolFactory

class LangChainKnowledgeManagementAgent(LangChainAgentBase):
"""
The LangChain Knowledge Management agent.
Expand Down Expand Up @@ -375,17 +378,17 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
)

image_service = None
if "dalle-image-generation" in request.agent.tools:
dalle_tool = request.agent.tools["dalle-image-generation"]
model_object_id = dalle_tool["ai_model_object_ids"][self.MAIN_MODEL_KEY]
if any(tool.name == "DALLEImageGeneration" for tool in request.agent.tools):
dalle_tool = next((tool for tool in request.agent.tools if tool.name == "DALLEImageGeneration"), None)
model_object_id = dalle_tool.ai_model_object_ids[self.MAIN_MODEL_KEY]
image_generation_deployment_model = request.objects[model_object_id]["deployment_name"]
api_endpoint_object_id = request.objects[model_object_id]["endpoint_object_id"]
image_generation_client = self._get_image_gen_language_model(api_endpoint_object_id=api_endpoint_object_id, objects=request.objects)
image_service=ImageService(
config=self.config,
client=image_generation_client,
deployment_name=image_generation_deployment_model,
image_generator_tool_description=dalle_tool["description"])
image_generator_tool_description=dalle_tool.description)

# invoke/run the service
assistant_response = await assistant_svc.run_async(
Expand Down Expand Up @@ -419,22 +422,14 @@ async def invoke_async(self, request: KnowledgeManagementCompletionRequest) -> C
# End Assistants API implementation

# Start LangGraph ReAct Agent workflow implementation
if (agent.workflow is not None and isinstance(agent.workflow, LangGraphReactAgentWorkflow)):
# Temporary placeholder
from typing import Literal
from langchain_core.tools import tool
@tool
def get_weather(city: Literal["nyc", "sf"]):
"""Use this to get weather information."""
if city == "nyc":
return "It might be cloudy in nyc"
elif city == "sf":
return "It's always sunny in sf"
else:
raise AssertionError("Unknown city")
tools = [get_weather]
# End temporary placeholder

if (agent.workflow is not None and isinstance(agent.workflow, LangGraphReactAgentWorkflow)):
tool_factory = ToolFactory()
tools = []

# Populate tools list from agent configuration
for tool in agent.tools:
tools.append(tool_factory.get_tool(tool, request.objects, self.config))

# Define the graph
graph = create_react_agent(llm, tools=tools, state_modifier=self.prompt.prefix)
messages = self._build_conversation_history_message_list(request.message_history, agent.conversation_history_settings.max_history)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .secure_sql_database_query_tool import SecureSQLDatabaseQueryTool
from .query_pandas_dataframe_tool import QueryPandasDataFrameTool
from .query_pandas_dataframe_tool import TypeConversionTool
from .fllm_tool import FLLMToolBase
from .dalle_image_generation_tool import DALLEImageGenerationTool
from .tool_factory import ToolFactory
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import json
from enum import Enum
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from langchain_core.callbacks import AsyncCallbackManagerForToolRun
from langchain_core.tools import ToolException
from openai import AsyncAzureOpenAI
from pydantic import BaseModel, Field
from typing import Optional, Type

from .fllm_tool import FLLMToolBase
from foundationallm.config import Configuration
from foundationallm.models.agents import AgentTool
from foundationallm.models.resource_providers.ai_models import AIModelBase
from foundationallm.models.resource_providers.configuration import APIEndpointConfiguration
from foundationallm.utils import ObjectUtils

class DALLEImageGenerationToolQualityEnum(str, Enum):
""" Enum for the quality parameter of the DALL-E image generation tool. """
standard = "standard"
hd = "hd"

class DALLEImageGenerationToolStyleEnum(str, Enum):
""" Enum for the style parameter of the DALL-E image generation tool. """
natural = "natural"
vivid = "vivid"

class DALLEImageGenerationToolSizeEnum(str, Enum):
""" Enum for the size parameter of the DALL-E image generation tool. """
size1024x1024 = "1024x1024"
size1792x1024 = "1792x1024"
size1024x1792 = "1024x1792"

class DALLEImageGenerationToolInput(BaseModel):
""" Input data model for the DALL-E image generation tool. """
prompt: str = Field(description="Prompt for the DALL-E image generation tool.", example="A cat in the forest.")
n: int = Field(description="Number of images to generate.", example=1, default=1)
quality: DALLEImageGenerationToolQualityEnum = Field(description="Quality of the generated images.", default=DALLEImageGenerationToolQualityEnum.hd)
style: DALLEImageGenerationToolStyleEnum = Field(description="Style of the generated images.", default=DALLEImageGenerationToolStyleEnum.natural)
size: DALLEImageGenerationToolSizeEnum = Field(description="Size of the generated images.", default=DALLEImageGenerationToolSizeEnum.size1024x1024)

class DALLEImageGenerationTool(FLLMToolBase):
"""
DALL-E image generation tool.
Supports only Azure Identity authentication.
"""
args_schema: Type[BaseModel] = DALLEImageGenerationToolInput

def __init__(self, tool_config: AgentTool, objects: dict, config: Configuration):
""" Initializes the DALLEImageGenerationTool class with the tool configuration,
exploded objects collection, and platform configuration. """
super().__init__(tool_config, objects, config)
self.ai_model = ObjectUtils.get_object_by_id(self.tool_config.ai_model_object_ids["main_model"], self.objects, AIModelBase)
self.api_endpoint = ObjectUtils.get_object_by_id(self.ai_model.endpoint_object_id, self.objects, APIEndpointConfiguration)
self.client = self._get_client()

def _run(self,
prompt: str,
n: int,
quality: DALLEImageGenerationToolQualityEnum,
style: DALLEImageGenerationToolStyleEnum,
size: DALLEImageGenerationToolSizeEnum,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
raise ToolException("This tool does not support synchronous execution. Please use the async version of the tool.")

async def _arun(self,
prompt: str,
n: int,
quality: DALLEImageGenerationToolQualityEnum,
style: DALLEImageGenerationToolStyleEnum,
size: DALLEImageGenerationToolSizeEnum,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None
) -> str:
"""
Generate an image using the Azure OpenAI client.
"""
print(f'Attempting to generate {n} images with a style of {style}, quality of {quality}, and a size of {size}.')
try:
result = await self.client.images.generate(
model = self.ai_model.deployment_name,
prompt = prompt,
n = n,
quality = quality,
style = style,
size = size
)
return json.loads(result.model_dump_json())
except Exception as e:
print(f'Image generation error code and message: {e.code}; {e}')
# Specifically handle content policy violation errors.
if e.code in ['contentFilter', 'content_policy_violation']:
err = e.message[e.message.find("{"):e.message.rfind("}")+1]
err_json = err.replace("'", '"')
err_json = err_json.replace("True", "true").replace("False", "false")
obj = json.loads(err_json)
cfr = obj['error']['inner_error']['content_filter_results']
filtered = [k for k, v in cfr.items() if v['filtered']]
error_fmt = f"The image generation request resulted in a content policy violation for the following category: {', '.join(filtered)}"
raise ToolException(error_fmt)
elif e.code in ['invalidPayload', 'invalid_payload']:
raise ToolException(f'The image generation request is invalid: {e.message}')
else:
raise ToolException(f"An {e.code} error occurred while attempting to generate the requested image: {e.message}")

def _get_client(self):
"""
Returns the an AsyncOpenAI client for DALL-E image generation.
"""
scope = self.api_endpoint.authentication_parameters.get('scope', 'https://cognitiveservices.azure.com/.default')
# Set up a Azure AD token provider.
token_provider = get_bearer_token_provider(
DefaultAzureCredential(exclude_environment_credential=True),
scope
)
return AsyncAzureOpenAI(
azure_endpoint = self.api_endpoint.url,
api_version = self.api_endpoint.api_version,
azure_ad_token_provider = token_provider
)
27 changes: 27 additions & 0 deletions src/python/PythonSDK/foundationallm/langchain/tools/fllm_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Class: FLLMToolBase
Description: FoundationaLLM base class for tools that uses the AgentTool model for its configuration.
"""
from langchain_core.tools import BaseTool
from foundationallm.config import Configuration
from foundationallm.langchain.exceptions import LangChainException
from foundationallm.models.agents import AgentTool
from foundationallm.models.resource_providers.configuration import APIEndpointConfiguration

class FLLMToolBase(BaseTool):
"""
FoundationaLLM base class for tools that uses the AgentTool model for its configuration.
"""
def __init__(self, tool_config: AgentTool, objects:dict, config: Configuration):
""" Initializes the FLLMToolBase class with the tool configuration. """
super().__init__(
name=tool_config.name,
description=tool_config.description
)
self.tool_config = tool_config
self.config = config
self.objects = objects

class Config:
""" Pydantic configuration for FLLMToolBase. """
extra = "allow"
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Class: ToolFactory
Description: Factory class for creating tools based on the AgentTool configuration.
"""
from foundationallm.config import Configuration
from foundationallm.langchain.exceptions import LangChainException
from foundationallm.models.agents import AgentTool
from foundationallm.langchain.tools import FLLMToolBase, DALLEImageGenerationTool

class ToolFactory:
"""
Factory class for creating tools based on the AgentTool configuration.
"""
FLLM_PACKAGE_NAME = "FoundationaLLM"
DALLE_IMAGE_GENERATION_TOOL_NAME = "DALLEImageGenerationTool"

def get_tool(
self,
tool_config: AgentTool,
objects: dict,
config: Configuration
) -> FLLMToolBase:
"""
Creates an instance of a tool based on the tool configuration.
"""
if tool_config.package_name == self.FLLM_PACKAGE_NAME:
# internal tools
match tool_config.name:
case DALLE_IMAGE_GENERATION_TOOL_NAME:
return DALLEImageGenerationTool(tool_config, objects, config)
# else: external tools

raise LangChainException(f"Tool {tool_config.name} not found in package {tool_config.package_name}")

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
AgentOrchestrationSettings,
AzureOpenAIAssistantsAgentWorkflow,
LangChainExpressionLanguageAgentWorkflow,
LangGraphReactAgentWorkflow
LangGraphReactAgentWorkflow,
AgentTool
)
from foundationallm.models.resource_providers import ResourceBase

Expand All @@ -22,7 +23,7 @@ class AgentBase(ResourceBase):
prompt_object_id: Optional[str] = Field(default=None, description="The object identifier of the Prompt object providing the prompt for the agent.")
ai_model_object_id: Optional[str] = Field(default=None, description="The object identifier of the AIModelBase object providing the AI model for the agent.")
capabilities:Optional[List[str]] = Field(default=[], description="The capabilities of the agent.")
tools: Optional[dict] = Field(default=[], description="A dictionary object with assigned agent tools.")
tools: Optional[List[AgentTool]] = Field(default=[], description="A list of assigned agent tools.")
workflow: Optional[
Annotated [
Union[AzureOpenAIAssistantsAgentWorkflow, LangChainExpressionLanguageAgentWorkflow, LangGraphReactAgentWorkflow],
Expand Down
32 changes: 32 additions & 0 deletions src/python/PythonSDK/foundationallm/utils/object_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import re
from typing import Type, TypeVar, Dict
from foundationallm.langchain.exceptions import LangChainException

T = TypeVar('T') # Generic type variable

class ObjectUtils:

Expand All @@ -20,3 +24,31 @@ def translate_keys(obj):
return [ObjectUtils.translate_keys(item) for item in obj] # Apply to each item in the list
else:
return obj # Return the item itself if it's not a dict or list

@staticmethod
def get_object_by_id(object_id: str, objects: dict, object_type: Type[T]) -> T:
"""
Generic method to retrieve an object of a specified type from a dictionary by its ID.
Args:
object_id (str): The ID of the object to retrieve.
objects (dict): A dictionary containing object data.
object_type (Type[T]): The type of the object to construct.
Returns:
T: An instance of the specified type.
Raises:
LangChainException: If the object ID is invalid or the object cannot be constructed.
"""
if not object_id:
raise LangChainException("Invalid object ID.", 400)

object_data = objects.get(object_id)
if not object_data:
raise LangChainException(f"Object with ID '{object_id}' not found in the dictionary.", 400)

try:
return object_type(**object_data)
except Exception as e:
raise LangChainException(f"Failed to construct object of type '{object_type.__name__}': {str(e)}", 400)
Loading

0 comments on commit c8ec3fc

Please sign in to comment.