diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 074d8587cf..f89e40cf84 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 # -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. # SPDX-License-Identifier: MIT """Create a OpenAI-compatible client for Gemini features. @@ -38,6 +38,8 @@ from __future__ import annotations import base64 +import copy +import json import logging import os import random @@ -45,24 +47,34 @@ import time import warnings from io import BytesIO -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import google.generativeai as genai +import PIL import requests import vertexai -from google.ai.generativelanguage import Content, Part +from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool +from google.ai.generativelanguage_v1beta.types import Schema from google.auth.credentials import Credentials -from openai.types.chat import ChatCompletion +from jsonschema import ValidationError +from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage from PIL import Image from pydantic import BaseModel -from vertexai.generative_models import Content as VertexAIContent +from vertexai.generative_models import ( + Content as VertexAIContent, +) +from vertexai.generative_models import FunctionDeclaration as vaiFunctionDeclaration from vertexai.generative_models import GenerativeModel from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold from vertexai.generative_models import HarmCategory as VertexAIHarmCategory +from 
vertexai.generative_models import Image as VertexAIImage from vertexai.generative_models import Part as VertexAIPart from vertexai.generative_models import SafetySetting as VertexAISafetySetting +from vertexai.generative_models import ( + Tool as vaiTool, +) logger = logging.getLogger(__name__) @@ -164,6 +176,7 @@ def get_usage(response) -> Dict: } def create(self, params: Dict) -> ChatCompletion: + if self.use_vertexai: self._initialize_vertexai(**params) else: @@ -171,7 +184,12 @@ def create(self, params: Dict) -> ChatCompletion: "location" not in params ), "Google Cloud project and compute location cannot be set when using an API Key!" model_name = params.get("model", "gemini-pro") - if not model_name: + + if model_name == "gemini-pro-vision": + raise ValueError( + "Gemini 1.0 Pro vision ('gemini-pro-vision') has been deprecated, please consider switching to a different model, for example 'gemini-1.5-flash'." + ) + elif not model_name: raise ValueError( "Please provide a model name for the Gemini Client. " "You can configure it in the OAI Config List file. " @@ -184,6 +202,10 @@ def create(self, params: Dict) -> ChatCompletion: n_response = params.get("n", 1) system_instruction = params.get("system_instruction", None) response_validation = params.get("response_validation", True) + if "tools" in params: + tools = self._tools_to_gemini_tools(params["tools"]) + else: + tools = None generation_config = { gemini_term: params[autogen_term] @@ -200,77 +222,92 @@ def create(self, params: Dict) -> ChatCompletion: "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.", UserWarning, ) + stream = False if n_response > 1: warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning) - if "vision" not in model_name: - # A. create and call the chat model. 
- gemini_messages = self._oai_messages_to_gemini_messages(messages) - if self.use_vertexai: - model = GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) - chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation) - else: - # we use chat model by default - model = genai.GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) - genai.configure(api_key=self.api_key) - chat = model.start_chat(history=gemini_messages[:-1]) - - response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings) - ans: str = chat.history[-1].parts[0].text - prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens - completion_tokens = model.count_tokens(ans).total_tokens - elif model_name == "gemini-pro-vision": - # B. handle the vision model - if self.use_vertexai: - model = GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) - else: - model = genai.GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) - genai.configure(api_key=self.api_key) - # Gemini's vision model does not support chat history yet - # chat = model.start_chat(history=gemini_messages[:-1]) - # response = chat.send_message(gemini_messages[-1].parts) - user_message = self._oai_content_to_gemini_content(messages[-1]["content"]) - if len(messages) > 2: - warnings.warn( - "Warning: Gemini's vision model does not support chat history yet.", - "We only use the last message as the prompt.", - UserWarning, - ) + autogen_tool_calls = [] - response = model.generate_content(user_message, stream=stream) - # ans = response.text - if self.use_vertexai: - ans: str = 
response.candidates[0].content.parts[0].text - else: - ans: str = response._result.candidates[0].content.parts[0].text + # Maps the function call ids to function names so we can inject it into FunctionResponse messages + self.tool_call_function_map: Dict[str, str] = {} - prompt_tokens = model.count_tokens(user_message).total_tokens - completion_tokens = model.count_tokens(ans).total_tokens + # A. create and call the chat model. + gemini_messages = self._oai_messages_to_gemini_messages(messages) + if self.use_vertexai: + model = GenerativeModel( + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, + tools=tools, + ) + + chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation) + else: + model = genai.GenerativeModel( + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, + tools=tools, + ) + + genai.configure(api_key=self.api_key) + chat = model.start_chat(history=gemini_messages[:-1]) + + response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings) + + # Extract text and tools from response + ans = "" + random_id = random.randint(0, 10000) + prev_function_calls = [] + for part in response.parts: + + # Function calls + if fn_call := part.function_call: + + # If we have a repeated function call, ignore it + if fn_call not in prev_function_calls: + autogen_tool_calls.append( + ChatCompletionMessageToolCall( + id=str(random_id), # tool-call ids are strings in the OpenAI schema + function={ + "name": fn_call.name, + "arguments": ( + json.dumps({key: val for key, val in fn_call.args.items()}) + if fn_call.args is not None + else "{}" # keep arguments valid JSON so json.loads() round-trips + ), + }, + type="function", + ) + ) + + prev_function_calls.append(fn_call) + random_id += 1 + + # Plain text content + elif text := part.text: + ans += text + + # If we have function calls, ignore the text + # as it can be Gemini guessing the function response + if 
len(autogen_tool_calls) != 0: + ans = "" + else: + autogen_tool_calls = None + + prompt_tokens = response.usage_metadata.prompt_token_count + completion_tokens = response.usage_metadata.candidates_token_count # 3. convert output - message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None) - choices = [Choice(finish_reason="stop", index=0, message=message)] + message = ChatCompletionMessage( + role="assistant", content=ans, function_call=None, tool_calls=autogen_tool_calls + ) + choices = [ + Choice(finish_reason="tool_calls" if autogen_tool_calls is not None else "stop", index=0, message=message) + ] response_oai = ChatCompletion( id=str(random.randint(0, 1000)), @@ -283,50 +320,105 @@ def create(self, params: Dict) -> ChatCompletion: completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ), - cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name), + cost=calculate_gemini_cost(self.use_vertexai, prompt_tokens, completion_tokens, model_name), ) return response_oai - def _oai_content_to_gemini_content(self, content: Union[str, List]) -> List: - """Convert content from OAI format to Gemini format""" + def _oai_content_to_gemini_content(self, message: Dict[str, Any]) -> Tuple[List, str]: + """Convert AutoGen content to Gemini parts, catering for text and tool calls""" rst = [] - if isinstance(content, str): + + if message["role"] == "tool": + # Tool call recommendation + + function_name = self.tool_call_function_map[message["tool_call_id"]] + + if self.use_vertexai: + rst.append( + VertexAIPart.from_function_response( + name=function_name, response={"result": self._to_json_or_str(message["content"])} + ) + ) + else: + rst.append( + Part( + function_response=FunctionResponse( + name=function_name, response={"result": self._to_json_or_str(message["content"])} + ) + ) + ) + + return rst, "tool" + elif "tool_calls" in message and len(message["tool_calls"]) != 0: + for tool_call in 
message["tool_calls"]: + + function_id = tool_call["id"] + function_name = tool_call["function"]["name"] + self.tool_call_function_map[function_id] = function_name + + if self.use_vertexai: + rst.append( + VertexAIPart.from_dict( + { + "functionCall": { + "name": function_name, + "args": json.loads(tool_call["function"]["arguments"]), + } + } + ) + ) + else: + rst.append( + Part( + function_call=FunctionCall( + name=function_name, + args=json.loads(tool_call["function"]["arguments"]), + ) + ) + ) + + return rst, "tool_call" + + elif isinstance(message["content"], str): + content = message["content"] if content == "": content = "empty" # Empty content is not allowed. if self.use_vertexai: rst.append(VertexAIPart.from_text(content)) else: rst.append(Part(text=content)) - return rst - - assert isinstance(content, list) - for msg in content: - if isinstance(msg, dict): - assert "type" in msg, f"Missing 'type' field in message: {msg}" - if msg["type"] == "text": - if self.use_vertexai: - rst.append(VertexAIPart.from_text(text=msg["text"])) + return rst, "text" + + # For images the message contains a list of text items + if isinstance(message["content"], list): + has_image = False + for msg in message["content"]: + if isinstance(msg, dict): + assert "type" in msg, f"Missing 'type' field in message: {msg}" + if msg["type"] == "text": + if self.use_vertexai: + rst.append(VertexAIPart.from_text(text=msg["text"])) + else: + rst.append(Part(text=msg["text"])) + elif msg["type"] == "image_url": + if self.use_vertexai: + img_url = msg["image_url"]["url"] + img_part = VertexAIPart.from_uri(img_url, mime_type="image/png") + rst.append(img_part) + else: + b64_img = get_image_data(msg["image_url"]["url"]) + rst.append(Part(inline_data={"mime_type": "image/png", "data": b64_img})) + + has_image = True else: - rst.append(Part(text=msg["text"])) - elif msg["type"] == "image_url": - if self.use_vertexai: - img_url = msg["image_url"]["url"] - re.match(r"data:image/(?:png|jpeg);base64,", 
img_url) - img = get_image_data(img_url, use_b64=False) - # image/png works with jpeg as well - img_part = VertexAIPart.from_data(img, mime_type="image/png") - rst.append(img_part) - else: - b64_img = get_image_data(msg["image_url"]["url"]) - img = _to_pil(b64_img) - rst.append(img) + raise ValueError(f"Unsupported message type: {msg['type']}") else: - raise ValueError(f"Unsupported message type: {msg['type']}") - else: - raise ValueError(f"Unsupported message type: {type(msg)}") - return rst + raise ValueError(f"Unsupported message type: {type(msg)}") + return rst, "image" if has_image else "text" + else: + raise Exception("Unable to convert content to Gemini format.") def _concat_parts(self, parts: List[Part]) -> List: """Concatenate parts with the same type. @@ -362,39 +454,178 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li Make sure the "user" role and "model" role are interleaved. Also, make sure the last item is from the "user" role. """ - prev_role = None rst = [] - curr_parts = [] - for i, message in enumerate(messages): - parts = self._oai_content_to_gemini_content(message["content"]) + for message in messages: + parts, part_type = self._oai_content_to_gemini_content(message) role = "user" if message["role"] in ["user", "system"] else "model" - if (prev_role is None) or (role == prev_role): - curr_parts += parts - elif role != prev_role: - if self.use_vertexai: - rst.append(VertexAIContent(parts=curr_parts, role=prev_role)) - else: - rst.append(Content(parts=curr_parts, role=prev_role)) - curr_parts = parts - prev_role = role - # handle the last message - if self.use_vertexai: - rst.append(VertexAIContent(parts=curr_parts, role=role)) - else: - rst.append(Content(parts=curr_parts, role=role)) + if part_type == "text": + rst.append( + VertexAIContent(parts=parts, role=role) + if self.use_vertexai + else Content(parts=parts, role=role) + ) + elif part_type == "tool": + rst.append( + 
VertexAIContent(parts=parts, role="function") + if self.use_vertexai + else Content(parts=parts, role="function") + ) + elif part_type == "tool_call": + rst.append( + VertexAIContent(parts=parts, role="function") + if self.use_vertexai + else Content(parts=parts, role="function") + ) + elif part_type == "image": + # Image has multiple parts, some can be text and some can be image based + text_parts = [] + image_parts = [] + for part in parts: + if isinstance(part, Part): + # Text or non-Vertex AI image part + text_parts.append(part) + elif isinstance(part, VertexAIPart): + # Image + image_parts.append(part) + else: + raise Exception("Unable to process image part") + + if len(text_parts) > 0: + rst.append( + VertexAIContent(parts=text_parts, role=role) + if self.use_vertexai + else Content(parts=text_parts, role=role) + ) + + if len(image_parts) > 0: + rst.append( + VertexAIContent(parts=image_parts, role=role) + if self.use_vertexai + else Content(parts=image_parts, role=role) + ) + + if len(rst) != 0 and rst[-1] is None: + rst.pop() # The Gemini is restrict on order of roles, such that # 1. The messages should be interleaved between user and model. # 2. The last message must be from the user role. # We add a dummy message "continue" if the last role is not the user. 
- if rst[-1].role != "user": + if rst[-1].role not in ["user", "function"]: + text_part, _ = self._oai_content_to_gemini_content({"role": "user", "content": "continue"}) # "role" is required by the converter + rst.append( + VertexAIContent(parts=text_part, role="user") + if self.use_vertexai + else Content(parts=text_part, role="user") + ) + + return rst + + def _tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]: + """Create Gemini tools (as typically requires Callables)""" + + functions = [] + for tool in tools: if self.use_vertexai: - rst.append(VertexAIContent(parts=self._oai_content_to_gemini_content("continue"), role="user")) + function = vaiFunctionDeclaration( + name=tool["function"]["name"], + description=tool["function"]["description"], + parameters=tool["function"]["parameters"], + ) else: - rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user")) + function = GeminiClient._create_gemini_function_declaration(tool) + functions.append(function) - return rst + if self.use_vertexai: + return [vaiTool(function_declarations=functions)] + else: + return [Tool(function_declarations=functions)] + + @staticmethod + def _create_gemini_function_declaration(tool: Dict) -> FunctionDeclaration: + function_declaration = FunctionDeclaration() + function_declaration.name = tool["function"]["name"] + function_declaration.description = tool["function"]["description"] + if len(tool["function"]["parameters"]["properties"]) != 0: + function_declaration.parameters = GeminiClient._create_gemini_function_parameters( + copy.deepcopy(tool["function"]["parameters"]) + ) + + return function_declaration + + @staticmethod + def _create_gemini_function_declaration_schema(json_data) -> Schema: + """Recursively creates Schema objects for FunctionDeclaration.""" + param_schema = Schema() + param_type = json_data["type"] + + """ + TYPE_UNSPECIFIED = 0 + STRING = 1 + INTEGER = 2 + NUMBER = 3 + OBJECT = 4 + ARRAY = 5 + BOOLEAN = 6 + """ + + if param_type == "integer": + param_schema.type_ = 2 
+ elif param_type == "number": + param_schema.type_ = 3 + elif param_type == "string": + param_schema.type_ = 1 + elif param_type == "boolean": + param_schema.type_ = 6 + elif param_type == "array": + param_schema.type_ = 5 + if "items" in json_data: + param_schema.items = GeminiClient._create_gemini_function_declaration_schema(json_data["items"]) + else: + print("Warning: Array schema missing 'items' definition.") + elif param_type == "object": + param_schema.type_ = 4 + param_schema.properties = {} + if "properties" in json_data: + for prop_name, prop_data in json_data["properties"].items(): + param_schema.properties[prop_name] = GeminiClient._create_gemini_function_declaration_schema( + prop_data + ) + else: + print("Warning: Object schema missing 'properties' definition.") + + elif param_type in ("null", "any"): + param_schema.type_ = 1 # Treating these as strings for simplicity + else: + print(f"Warning: Unsupported parameter type '{param_type}'.") + + if "description" in json_data: + param_schema.description = json_data["description"] + + return param_schema + + def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> dict[str, any]: + """Convert function parameters to Gemini format, recursive""" + + function_parameter["type_"] = function_parameter["type"].upper() + + # Parameter properties and items + if "properties" in function_parameter: + for key in function_parameter["properties"]: + function_parameter["properties"][key] = GeminiClient._create_gemini_function_parameters( + function_parameter["properties"][key] + ) + + if "items" in function_parameter: + function_parameter["items"] = GeminiClient._create_gemini_function_parameters(function_parameter["items"]) + + # Remove any attributes not needed + for attr in ["type", "default"]: + if attr in function_parameter: + del function_parameter[attr] + + return function_parameter @staticmethod def _to_vertexai_safety_settings(safety_settings): @@ -425,21 +656,13 @@ def 
_to_vertexai_safety_settings(safety_settings): else: return safety_settings - -def _to_pil(data: str) -> Image.Image: - """ - Converts a base64 encoded image data string to a PIL Image object. - - This function first decodes the base64 encoded string to bytes, then creates a BytesIO object from the bytes, - and finally creates and returns a PIL Image object from the BytesIO object. - - Parameters: - data (str): The base64 encoded image data string. - - Returns: - Image.Image: The PIL Image object created from the input data. - """ - return Image.open(BytesIO(base64.b64decode(data))) + @staticmethod + def _to_json_or_str(data: str) -> Union[Dict, str]: + try: + json_data = json.loads(data) + return json_data + except (json.JSONDecodeError, TypeError): # non-JSON text or non-string content is passed through unchanged + return data def get_image_data(image_file: str, use_b64=True) -> bytes: @@ -460,14 +683,76 @@ def get_image_data(image_file: str, use_b64=True) -> bytes: return content -def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float: - if "1.5" in model_name or "gemini-experimental" in model_name: - # "gemini-1.5-pro-preview-0409" - # Cost is $7 per million input tokens and $21 per million output tokens - return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 +def calculate_gemini_cost(use_vertexai: bool, input_tokens: int, output_tokens: int, model_name: str) -> float: + + def total_cost_mil(cost_per_mil_input: float, cost_per_mil_output: float): + # Cost per million + return cost_per_mil_input * input_tokens / 1e6 + cost_per_mil_output * output_tokens / 1e6 + + def total_cost_k(cost_per_k_input: float, cost_per_k_output: float): + # Cost per thousand + return cost_per_k_input * input_tokens / 1e3 + cost_per_k_output * output_tokens / 1e3 + + model_name = model_name.lower() + up_to_128k = input_tokens <= 128000 + + if use_vertexai: + # Vertex AI pricing - based on Text input + # https://cloud.google.com/vertex-ai/generative-ai/pricing#vertex-ai-pricing + + if 
"gemini-1.5-flash" in model_name: + if up_to_128k: + return total_cost_k(0.00001875, 0.000075) + else: + return total_cost_k(0.0000375, 0.00015) + + elif "gemini-1.5-pro" in model_name: + if up_to_128k: + return total_cost_k(0.0003125, 0.00125) + else: + return total_cost_k(0.000625, 0.0025) + + elif "gemini-1.0-pro" in model_name: + return total_cost_k(0.000125, 0.00001875) + + else: + warnings.warn( + f"Cost calculation is not implemented for model {model_name}. Cost will be calculated zero.", + UserWarning, + ) + return 0 + + else: + # Non-Vertex AI pricing + + if "gemini-1.5-flash-8b" in model_name: + # https://ai.google.dev/pricing#1_5flash-8B + if up_to_128k: + return total_cost_mil(0.0375, 0.15) + else: + return total_cost_mil(0.075, 0.3) - if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name: - warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning) + elif "gemini-1.5-flash" in model_name: + # https://ai.google.dev/pricing#1_5flash + if up_to_128k: + return total_cost_mil(0.075, 0.3) + else: + return total_cost_mil(0.15, 0.6) + + elif "gemini-1.5-pro" in model_name: + # https://ai.google.dev/pricing#1_5pro + if up_to_128k: + return total_cost_mil(1.25, 5.0) + else: + return total_cost_mil(2.50, 10.0) + + elif "gemini-1.0-pro" in model_name: + # https://ai.google.dev/pricing#1_5pro + return total_cost_mil(0.50, 1.5) - # Cost is $0.5 per million input tokens and $1.5 per million output tokens - return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6 + else: + warnings.warn( + f"Cost calculation is not implemented for model {model_name}. 
Cost will be calculated zero.", + UserWarning, + ) + return 0 diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py index b5b84cd028..7bc834b3bd 100644 --- a/test/oai/test_gemini.py +++ b/test/oai/test_gemini.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# Copyright (c) 2023 - 2024, Owners of https://github.com/autogenhub # # SPDX-License-Identifier: Apache-2.0 # @@ -114,28 +114,23 @@ def test_gemini_message_handling(gemini_client): {"role": "model", "content": "How can I help you?"}, {"role": "user", "content": "Which planet is the nearest to the sun?"}, {"role": "user", "content": "Which planet is the farthest from the sun?"}, - {"role": "model", "content": "Mercury is the closest palnet to the sun."}, - {"role": "model", "content": "Neptune is the farthest palnet from the sun."}, + {"role": "model", "content": "Mercury is the closest planet to the sun."}, + {"role": "model", "content": "Neptune is the farthest planet from the sun."}, {"role": "user", "content": "How can we determine the mass of a black hole?"}, ] # The datastructure below defines what the structure of the messages # should resemble after converting to Gemini format. 
- # Messages of similar roles are expected to be merged to a single message, - # where the contents of the original messages will be included in - # consecutive parts of the converted Gemini message + # Historically it has merged messages and ensured alternating roles, + # this no longer appears to be required by the Gemini API expected_gemini_struct = [ # system role is converted to user role {"role": "user", "parts": ["You are my personal assistant."]}, {"role": "model", "parts": ["How can I help you?"]}, - { - "role": "user", - "parts": ["Which planet is the nearest to the sun?", "Which planet is the farthest from the sun?"], - }, - { - "role": "model", - "parts": ["Mercury is the closest palnet to the sun.", "Neptune is the farthest palnet from the sun."], - }, + {"role": "user", "parts": ["Which planet is the nearest to the sun?"]}, + {"role": "user", "parts": ["Which planet is the farthest from the sun?"]}, + {"role": "model", "parts": ["Mercury is the closest planet to the sun."]}, + {"role": "model", "parts": ["Neptune is the farthest planet from the sun."]}, {"role": "user", "parts": ["How can we determine the mass of a black hole?"]}, ] @@ -286,22 +281,33 @@ def test_cost_calculation(gemini_client, mock_response): @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @patch("autogen.oai.gemini.genai.GenerativeModel") -@patch("autogen.oai.gemini.genai.configure") -def test_create_response(mock_configure, mock_generative_model, gemini_client): +# @patch("autogen.oai.gemini.genai.configure") +@patch("autogen.oai.gemini.calculate_gemini_cost") +def test_create_response_with_text(mock_calculate_cost, mock_generative_model, gemini_client): # Mock the genai model configuration and creation process mock_chat = MagicMock() mock_model = MagicMock() - mock_configure.return_value = None + # mock_configure.return_value = None mock_generative_model.return_value = mock_model mock_model.start_chat.return_value = mock_chat - # Set up a mock for the 
chat history item access and the text attribute return - mock_history_part = MagicMock() - mock_history_part.text = "Example response" - mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part + # Set up mock token counts with real integers + mock_usage_metadata = MagicMock() + mock_usage_metadata.prompt_token_count = 100 + mock_usage_metadata.candidates_token_count = 50 + + # Setup the mock to return a response with only text content + mock_text_part = MagicMock() + mock_text_part.text = "Example response" + mock_text_part.function_call = None - # Setup the mock to return a mocked chat response - mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])]) + mock_response = MagicMock() + mock_response.parts = [mock_text_part] + mock_response.usage_metadata = mock_usage_metadata + mock_chat.send_message.return_value = mock_response + + # Mock the calculate_gemini_cost function + mock_calculate_cost.return_value = 0.002 # Call the create method response = gemini_client.create( @@ -309,13 +315,25 @@ def test_create_response(mock_configure, mock_generative_model, gemini_client): ) # Assertions to check if response is structured as expected + # assert isinstance(response, ChatCompletion), "Response should be an instance of ChatCompletion" assert response.choices[0].message.content == "Example response", "Response content should match expected output" + assert not response.choices[0].message.tool_calls, "There should be no tool calls" + assert response.usage.prompt_tokens == 100, "Prompt tokens should match the mocked value" + assert response.usage.completion_tokens == 50, "Completion tokens should match the mocked value" + assert response.usage.total_tokens == 150, "Total tokens should be the sum of prompt and completion tokens" + assert response.cost == 0.002, "Cost should match the mocked calculate_gemini_cost return value" + + # Verify that calculate_gemini_cost was called with 
the correct arguments + mock_calculate_cost.assert_called_once_with(False, 100, 50, "gemini-pro") @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") @patch("autogen.oai.gemini.GenerativeModel") @patch("autogen.oai.gemini.vertexai.init") -def test_vertexai_create_response(mock_init, mock_generative_model, gemini_client_with_credentials): +@patch("autogen.oai.gemini.calculate_gemini_cost") +def test_vertexai_create_response( + mock_calculate_cost, mock_init, mock_generative_model, gemini_client_with_credentials +): # Mock the genai model configuration and creation process mock_chat = MagicMock() mock_model = MagicMock() @@ -323,139 +341,37 @@ def test_vertexai_create_response(mock_init, mock_generative_model, gemini_clien mock_generative_model.return_value = mock_model mock_model.start_chat.return_value = mock_chat - # Set up a mock for the chat history item access and the text attribute return - mock_history_part = MagicMock() - mock_history_part.text = "Example response" - mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part - - # Setup the mock to return a mocked chat response - mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])]) - - # Call the create method - response = gemini_client_with_credentials.create( - {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False} - ) - - # Assertions to check if response is structured as expected - assert response.choices[0].message.content == "Example response", "Response content should match expected output" - + # Set up mock token counts with real integers + mock_usage_metadata = MagicMock() + mock_usage_metadata.prompt_token_count = 100 + mock_usage_metadata.candidates_token_count = 50 -@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") -@patch("autogen.oai.gemini.GenerativeModel") -@patch("autogen.oai.gemini.vertexai.init") -def 
test_vertexai_default_auth_create_response(mock_init, mock_generative_model, gemini_google_auth_default_client): - # Mock the genai model configuration and creation process - mock_chat = MagicMock() - mock_model = MagicMock() - mock_init.return_value = None - mock_generative_model.return_value = mock_model - mock_model.start_chat.return_value = mock_chat + # Setup the mock to return a response with only text content + mock_text_part = MagicMock() + mock_text_part.text = "Example response" + mock_text_part.function_call = None - # Set up a mock for the chat history item access and the text attribute return - mock_history_part = MagicMock() - mock_history_part.text = "Example response" - mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part + mock_response = MagicMock() + mock_response.parts = [mock_text_part] + mock_response.usage_metadata = mock_usage_metadata + mock_chat.send_message.return_value = mock_response - # Setup the mock to return a mocked chat response - mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])]) + # Mock the calculate_gemini_cost function + mock_calculate_cost.return_value = 0.002 # Call the create method - response = gemini_google_auth_default_client.create( + response = gemini_client_with_credentials.create( {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False} ) # Assertions to check if response is structured as expected + # assert isinstance(response, ChatCompletion), "Response should be an instance of ChatCompletion" assert response.choices[0].message.content == "Example response", "Response content should match expected output" - - -@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") -@patch("autogen.oai.gemini.genai.GenerativeModel") -@patch("autogen.oai.gemini.genai.configure") -def test_create_vision_model_response(mock_configure, mock_generative_model, gemini_client): - # 
Mock the genai model configuration and creation process - mock_model = MagicMock() - mock_configure.return_value = None - mock_generative_model.return_value = mock_model - - # Set up a mock to simulate the vision model behavior - mock_vision_response = MagicMock() - mock_vision_part = MagicMock(text="Vision model output") - - # Setting up the chain of return values for vision model response - mock_vision_response._result.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = ( - mock_vision_part - ) - mock_model.generate_content.return_value = mock_vision_response - - # Call the create method with vision model parameters - response = gemini_client.create( - { - "model": "gemini-pro-vision", # Vision model name - "messages": [ - { - "content": [ - {"type": "text", "text": "Let's play a game."}, - { - "type": "image_url", - "image_url": { - "url": "" - }, - }, - ], - "role": "user", - } - ], # Assuming a simple content input for vision - "stream": False, - } - ) - - # Assertions to check if response is structured as expected - assert ( - response.choices[0].message.content == "Vision model output" - ), "Response content should match expected output from vision model" - - -@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed") -@patch("autogen.oai.gemini.GenerativeModel") -@patch("autogen.oai.gemini.vertexai.init") -def test_vertexai_create_vision_model_response(mock_init, mock_generative_model, gemini_google_auth_default_client): - # Mock the genai model configuration and creation process - mock_model = MagicMock() - mock_init.return_value = None - mock_generative_model.return_value = mock_model - - # Set up a mock to simulate the vision model behavior - mock_vision_response = MagicMock() - mock_vision_part = MagicMock(text="Vision model output") - - # Setting up the chain of return values for vision model response - mock_vision_response.candidates.__getitem__.return_value.content.parts.__getitem__.return_value = 
mock_vision_part - - mock_model.generate_content.return_value = mock_vision_response - - # Call the create method with vision model parameters - response = gemini_google_auth_default_client.create( - { - "model": "gemini-pro-vision", # Vision model name - "messages": [ - { - "content": [ - {"type": "text", "text": "Let's play a game."}, - { - "type": "image_url", - "image_url": { - "url": "" - }, - }, - ], - "role": "user", - } - ], # Assuming a simple content input for vision - "stream": False, - } - ) - - # Assertions to check if response is structured as expected - assert ( - response.choices[0].message.content == "Vision model output" - ), "Response content should match expected output from vision model" + assert not response.choices[0].message.tool_calls, "There should be no tool calls" + assert response.usage.prompt_tokens == 100, "Prompt tokens should match the mocked value" + assert response.usage.completion_tokens == 50, "Completion tokens should match the mocked value" + assert response.usage.total_tokens == 150, "Total tokens should be the sum of prompt and completion tokens" + assert response.cost == 0.002, "Cost should match the mocked calculate_gemini_cost return value" + + # Verify that calculate_gemini_cost was called with the correct arguments + mock_calculate_cost.assert_called_once_with(True, 100, 50, "gemini-pro")