From 4caade92475c1b8fa802bf01f439ad257a3c06df Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Fri, 22 Nov 2024 21:05:57 -0800 Subject: [PATCH 01/11] Add Support for OCI DSC LLM AQUA Models in LlamaIndex --- .../.gitignore | 153 +++ .../llama-index-llms-oci-data-science/BUILD | 3 + .../Makefile | 17 + .../README.md | 29 + .../llama_index/llms/oci_data_science/BUILD | 1 + .../llms/oci_data_science/__init__.py | 4 + .../llama_index/llms/oci_data_science/base.py | 970 ++++++++++++++++++ .../llms/oci_data_science/client.py | 742 ++++++++++++++ .../llms/oci_data_science/utils.py | 264 +++++ .../pyproject.toml | 66 ++ .../tests/BUILD | 1 + .../tests/__init__.py | 0 .../tests/test_llms_oci_data_science.py | 344 +++++++ .../tests/test_oci_data_science_client.py | 694 +++++++++++++ .../tests/test_oci_data_science_utils.py | 340 ++++++ 15 files changed, 3628 insertions(+) create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/.gitignore create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/Makefile create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/__init__.py create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/__init__.py create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/.gitignore b/llama-index-integrations/llms/llama-index-llms-oci-data-science/.gitignore new file mode 100644 index 0000000000000..990c18de22908 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD new file mode 100644 index 0000000000000..0896ca890d8bf --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/Makefile b/llama-index-integrations/llms/llama-index-llms-oci-data-science/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. 
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md
new file mode 100644
index 0000000000000..26b100d41791b
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md
@@ -0,0 +1,29 @@
+# LlamaIndex Llms Integration: Oracle Cloud Infrastructure (OCI) Data Science Service
+
+Oracle Cloud Infrastructure (OCI) [Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed, serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure.
+
+It offers [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm), which can be used to deploy, evaluate, and fine-tune foundation models in OCI Data Science. AI Quick Actions target users who want to quickly leverage the capabilities of AI. They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free, and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook.
+
+
+## Installation
+
+Install the required packages:
+
+```bash
+pip install llama-index-llms-oci-data-science oracle-ads
+```
+
+The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) package is required to simplify authentication within OCI Data Science.
+
+
+## Basic Usage
+
+Use `ads.set_auth()` to configure authentication, then point the client at your model deployment endpoint:
+
+```python
+import ads
+from llama_index.llms.oci_data_science import OCIDataScience
+
+ads.set_auth(auth="security_token", profile="OC1")
+
+llm = OCIDataScience(
+    # Replace with the URI of your OCI Data Science model deployment endpoint.
+    endpoint="https://<model-deployment-endpoint>/predict",
+    model="odsc-llm",
+)
+
+response = llm.complete("What is the capital of France?")
+print(response)
+```
+
+## LLM Implementation example
+
+https://docs.llamaindex.ai/en/stable/examples/llm/oci_data_science/
diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/BUILD
new file mode 100644
index 0000000000000..db46e8d6c978c
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/__init__.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/__init__.py
new file mode 100644
index 0000000000000..d82f3b1b7f4a5
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/__init__.py
@@ -0,0 +1,4 @@
+from llama_index.llms.oci_data_science.base import OCIDataScience
+
+
+__all__ = ["OCIDataScience"]
diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py
new file mode 100644
index 0000000000000..8bd7212bcb4ac
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py
@@ -0,0 +1,970 @@
+import logging
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Sequence,
+    Union,
+)
+
+import llama_index.core.instrumentation as instrument
+from ads.common import auth as authutil
+from llama_index.core.base.llms.types import (
+    ChatMessage,
+    ChatResponse,
+    ChatResponseAsyncGen,
+    
ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.core.bridge.pydantic import ( + BaseModel, + Field, + PrivateAttr, + model_validator, +) +from llama_index.core.callbacks import CallbackManager +from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_TEMPERATURE +from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.llms.llm import ToolSelection +from llama_index.core.llms.utils import parse_partial_json +from llama_index.core.types import BaseOutputParser, Model, PydanticProgramMode +from llama_index.llms.oci_data_science.client import AsyncClient, Client +from llama_index.llms.oci_data_science.utils import ( + _from_completion_logprobs_dict, + _from_message_dict, + _from_token_logprob_dicts, + _get_response_token_counts, + _resolve_tool_choice, + _to_message_dicts, + _update_tool_calls, + _validate_dependency, +) + +dispatcher = instrument.get_dispatcher(__name__) +if TYPE_CHECKING: + from llama_index.core.tools.types import BaseTool + +DEFAULT_MODEL = "odsc-llm" +DEFAULT_MAX_TOKENS = 512 +DEFAULT_TIMEOUT = 120 +DEFAULT_MAX_RETRIES = 5 + +logger = logging.getLogger(__name__) + + +class OCIDataScience(FunctionCallingLLM): + """ + LLM deployed on OCI Data Science Model Deployment. + + **Setup:** + Install ``oracle-ads`` and ``llama-index-oci-data-science``. + + ```bash + pip install -U oracle-ads llama-index-oci-data-science + ``` + + Use `ads.set_auth()` to configure authentication. + For example, to use OCI resource_principal for authentication: + + ```python + import ads + ads.set_auth("resource_principal") + ``` + + For more details on authentication, see: + https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html + + Make sure to have the required policies to access the OCI Data + Science Model Deployment endpoint. See: + https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm + + To learn more about deploying LLM models in OCI Data Science, see: + https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm + + + **Examples:** + + **Basic Usage:** + + ```python + from llama_index.llms.oci_data_science import OCIDataScience + import ads + ads.set_auth(auth="security_token", profile="OC1") + + llm = OCIDataScience( + endpoint="https:///predict", + model="odsc-llm", + ) + prompt = "What is the capital of France?" 
+ response = llm.complete(prompt) + print(response) + ``` + + **Custom Parameters:** + + ```python + llm = OCIDataScience( + endpoint="https:///predict", + model="odsc-llm", + temperature=0.7, + max_tokens=150, + additional_kwargs={"top_p": 0.9}, + ) + ``` + + **Using Chat Interface:** + + ```python + messages = [ + ChatMessage(role="user", content="Tell me a joke."), + ChatMessage(role="assistant", content="Why did the chicken cross the road?"), + ChatMessage(role="user", content="I don't know, why?"), + ] + + chat_response = llm.chat(messages) + print(chat_response) + ``` + + **Streaming Completion:** + + ```python + for chunk in llm.stream_complete("Once upon a time"): + print(chunk.delta, end="") + ``` + + **Asynchronous Chat:** + + ```python + import asyncio + + async def async_chat(): + messages = [ + ChatMessage(role="user", content="What's the weather like today?") + ] + response = await llm.achat(messages) + print(response) + + asyncio.run(async_chat()) + ``` + + **Using Tools (Function Calling):** + + ```python + from llama_index.llms.oci_data_science import OCIDataScience + from llama_index.core.tools import FunctionTool + import ads + ads.set_auth(auth="security_token", profile="OC1") + + def multiply(a: float, b: float) -> float: + return a * b + + def add(a: float, b: float) -> float: + return a + b + + def subtract(a: float, b: float) -> float: + return a - b + + def divide(a: float, b: float) -> float: + return a / b + + + multiply_tool = FunctionTool.from_defaults(fn=multiply) + add_tool = FunctionTool.from_defaults(fn=add) + sub_tool = FunctionTool.from_defaults(fn=subtract) + divide_tool = FunctionTool.from_defaults(fn=divide) + + llm = OCIDataScience( + endpoint="https:///predict", + model="odsc-llm", + temperature=0.7, + max_tokens=150, + additional_kwargs={"top_p": 0.9}, + ) + + response = llm.chat_with_tools( + user_msg="Calculate the result of 2 + 2.", + tools=[multiply_tool, add_tool, sub_tool, divide_tool], + ) + print(response) + ``` + """ + + endpoint: str = Field( + default=None, description="The URI of the endpoint from the deployed model." + ) + + auth: Dict[str, Any] = Field( + default_factory=dict, + exclude=True, + description=( + "The authentication dictionary used for OCI API requests. Default is an empty dictionary. " + "If not provided, it will be autogenerated based on the environment variables. " + "https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html." + ), + ) + model: Optional[str] = Field( + default=DEFAULT_MODEL, + description="The OCI Data Science default model. 
Defaults to `odsc-llm`.", + ) + temperature: Optional[float] = Field( + default=DEFAULT_TEMPERATURE, + description="A non-negative float that tunes the degree of randomness in generation.", + ge=0.0, + le=1.0, + ) + max_tokens: Optional[int] = Field( + default=DEFAULT_MAX_TOKENS, + description="Denotes the number of tokens to predict per generation.", + gt=0, + ) + timeout: float = Field( + default=DEFAULT_TIMEOUT, description="The timeout to use in seconds.", ge=0 + ) + max_retries: int = Field( + default=DEFAULT_MAX_RETRIES, + description="The maximum number of API retries.", + ge=0, + ) + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, + ) + is_chat_model: bool = Field( + default=True, + description="If the model exposes a chat interface.", + ) + is_function_calling_model: bool = Field( + default=True, + description="If the model supports function calling messages.", + ) + additional_kwargs: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Additional kwargs for the OCI Data Science AI request.", + ) + strict: bool = Field( + default=False, + description="Whether to use strict mode for invoking tools/using schemas.", + ) + + _client: Client = PrivateAttr() + _async_client: AsyncClient = PrivateAttr() + + def __init__( + self, + endpoint: str, + auth: Optional[Dict[str, Any]] = None, + model: Optional[str] = DEFAULT_MODEL, + temperature: Optional[float] = DEFAULT_TEMPERATURE, + max_tokens: Optional[int] = DEFAULT_MAX_TOKENS, + context_window: Optional[int] = DEFAULT_CONTEXT_WINDOW, + timeout: Optional[float] = DEFAULT_TIMEOUT, + max_retries: Optional[int] = DEFAULT_MAX_RETRIES, + additional_kwargs: Optional[Dict[str, Any]] = None, + callback_manager: Optional[CallbackManager] = None, + is_chat_model: Optional[bool] = True, + is_function_calling_model: Optional[bool] = True, + # base class + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, + strict: bool = False, + **kwargs, + ) -> None: + """ + Initialize the OCIDataScience LLM class. + + Args: + endpoint (str): The URI of the endpoint from the deployed model. + auth (Optional[Dict[str, Any]]): Authentication dictionary for OCI API requests. + model (Optional[str]): The model name to use. Defaults to `odsc-llm`. + temperature (Optional[float]): Controls the randomness in generation. + max_tokens (Optional[int]): Number of tokens to predict per generation. + context_window (Optional[int]): Maximum number of context tokens for the model. + timeout (Optional[float]): Timeout for API requests in seconds. + max_retries (Optional[int]): Maximum number of API retries. + additional_kwargs (Optional[Dict[str, Any]]): Additional parameters for the API request. + callback_manager (Optional[CallbackManager]): Callback manager for LLM. + is_chat_model (Optional[bool]): If the model exposes a chat interface. Defaults to `True`. + is_function_calling_model (Optional[bool]): If the model supports function calling messages. Defaults to `True`. + system_prompt (Optional[str]): System prompt to use. + messages_to_prompt (Optional[Callable]): Function to convert messages to prompt. + completion_to_prompt (Optional[Callable]): Function to convert completion to prompt. 
+ pydantic_program_mode (PydanticProgramMode): Pydantic program mode. + output_parser (Optional[BaseOutputParser]): Output parser for the LLM. + strict (bool): Whether to use strict mode for invoking tools/using schemas. + **kwargs: Additional keyword arguments. + """ + super().__init__( + endpoint=endpoint, + model=model, + auth=auth or authutil.default_signer(), + temperature=temperature, + context_window=context_window, + max_tokens=max_tokens, + timeout=timeout, + max_retries=max_retries, + additional_kwargs=additional_kwargs or {}, + callback_manager=callback_manager or CallbackManager([]), + is_chat_model=is_chat_model, + is_function_calling_model=is_function_calling_model, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, + strict=strict, + **kwargs, + ) + + self._client: Client = None + self._async_client: AsyncClient = None + + logger.debug( + f"Initialized OCIDataScience LLM with endpoint: {self.endpoint} and model: {self.model}" + ) + + @model_validator(mode="before") + @_validate_dependency + def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate the environment and dependencies.""" + return values + + @property + def client(self) -> Client: + """ + Synchronous client for interacting with the OCI Data Science Model Deployment endpoint. + + Returns: + Client: The synchronous client instance. + """ + if self._client is None: + self._client = Client( + endpoint=self.endpoint, + auth=self.auth, + retries=self.max_retries, + timeout=self.timeout, + ) + return self._client + + @property + def async_client(self) -> AsyncClient: + """ + Asynchronous client for interacting with the OCI Data Science Model Deployment endpoint. + + Returns: + AsyncClient: The asynchronous client instance. + """ + if self._async_client is None: + self._async_client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth, + retries=self.max_retries, + timeout=self.timeout, + ) + return self._async_client + + @classmethod + def class_name(cls) -> str: + """ + Return the class name. + + Returns: + str: The name of the class. + """ + return "OCIDataScience_LLM" + + @property + def metadata(self) -> LLMMetadata: + """ + Return the metadata of the LLM. + + Returns: + LLMMetadata: The metadata of the LLM. + """ + return LLMMetadata( + context_window=self.context_window, + num_output=self.max_tokens or -1, + is_chat_model=self.is_chat_model, + is_function_calling_model=self.is_function_calling_model, + model_name=self.model, + ) + + def _model_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + """ + Get model-specific parameters for the API request. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + Dict[str, Any]: The combined model parameters. + """ + base_kwargs = { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + } + return {**base_kwargs, **self.additional_kwargs, **kwargs} + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """ + Generate a completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Returns: + CompletionResponse: The response from the LLM. 
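+
+        Example:
+            Illustrative sketch; assumes ``llm`` is an ``OCIDataScience``
+            instance configured as in the class-level examples.
+
+            ```python
+            response = llm.complete("What is the capital of France?")
+            print(response.text)
+            ```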
+ """ + logger.debug(f"Calling complete with prompt: {prompt}") + response = self.client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=False, + ) + + logger.debug(f"Received response: {response}") + try: + choice = response["choices"][0] + text = choice.get("text", "") + logprobs = _from_completion_logprobs_dict(choice.get("logprobs") or {}) + + return CompletionResponse( + text=text, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {str(e)}") from e + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + """ + Stream the completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Yields: + CompletionResponse: The streamed response from the LLM. + """ + logger.debug(f"Starting stream_complete with prompt: {prompt}") + text = "" + for response in self.client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=True, + ): + logger.debug(f"Received chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("text") + if delta is None: + delta = "" + else: + delta = "" + text += delta + + yield CompletionResponse( + delta=delta, + text=text, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + """ + Generate a chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Returns: + ChatResponse: The chat response from the LLM. + """ + logger.debug(f"Calling chat with messages: {messages}") + response = self.client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=False, + ) + + logger.debug(f"Received chat response: {response}") + try: + choice = response["choices"][0] + message = _from_message_dict(choice.get("message", "")) + logprobs = _from_token_logprob_dicts( + (choice.get("logprobs") or {}).get("content", []) + ) + return ChatResponse( + message=message, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {str(e)}") from e + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + """ + Stream the chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Yields: + ChatResponse: The streamed chat response from the LLM. 
+ """ + logger.debug(f"Starting stream_chat with messages: {messages}") + content = "" + is_function = False + tool_calls = [] + for response in self.client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=True, + ): + logger.debug(f"Received chat chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("delta") or {} + else: + delta = {} + + # Check if this chunk is the start of a function call + if delta.get("tool_calls"): + is_function = True + + # Update using deltas + role = delta.get("role") or MessageRole.ASSISTANT + content_delta = delta.get("content") or "" + content += content_delta + + additional_kwargs = {} + if is_function: + tool_calls = _update_tool_calls(tool_calls, delta.get("tool_calls")) + if tool_calls: + additional_kwargs["tool_calls"] = tool_calls + + yield ChatResponse( + message=ChatMessage( + role=role, + content=content, + additional_kwargs=additional_kwargs, + ), + delta=content_delta, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + @llm_completion_callback() + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """ + Asynchronously generate a completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Returns: + CompletionResponse: The response from the LLM. + """ + logger.debug(f"Calling acomplete with prompt: {prompt}") + response = await self.async_client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=False, + ) + + logger.debug(f"Received async response: {response}") + try: + choice = response["choices"][0] + text = choice.get("text", "") + logprobs = _from_completion_logprobs_dict(choice.get("logprobs", {}) or {}) + + return CompletionResponse( + text=text, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {str(e)}") from e + + @llm_completion_callback() + async def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + """ + Asynchronously stream the completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Yields: + CompletionResponse: The streamed response from the LLM. 
+ """ + + async def gen() -> CompletionResponseAsyncGen: + logger.debug(f"Starting astream_complete with prompt: {prompt}") + text = "" + + async for response in await self.async_client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=True, + ): + logger.debug(f"Received async chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("text") + if delta is None: + delta = "" + else: + delta = "" + text += delta + + yield CompletionResponse( + delta=delta, + text=text, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + return gen() + + @llm_chat_callback() + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + """ + Asynchronously generate a chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Returns: + ChatResponse: The chat response from the LLM. + """ + logger.debug(f"Calling achat with messages: {messages}") + response = await self.async_client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=False, + ) + + logger.debug(f"Received async chat response: {response}") + try: + choice = response["choices"][0] + message = _from_message_dict(choice.get("message", "")) + logprobs = _from_token_logprob_dicts( + (choice.get("logprobs") or {}).get("content", {}) + ) + return ChatResponse( + message=message, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {str(e)}") from e + + @llm_chat_callback() + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + """ + Asynchronously stream the chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Yields: + ChatResponse: The streamed chat response from the LLM. 
+ """ + + async def gen() -> ChatResponseAsyncGen: + logger.debug(f"Starting astream_chat with messages: {messages}") + content = "" + is_function = False + first_chat_chunk = True + tool_calls = [] + async for response in await self.async_client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=kwargs.pop("headers", None), + stream=True, + ): + logger.debug(f"Received async chat chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("delta") or {} + else: + delta = {} + + # Check if this chunk is the start of a function call + if delta.get("tool_calls"): + is_function = True + + # Update using deltas + role = delta.get("role") or MessageRole.ASSISTANT + content_delta = delta.get("content") or "" + content += content_delta + + additional_kwargs = {} + if is_function: + tool_calls = _update_tool_calls(tool_calls, delta.get("tool_calls")) + if tool_calls: + additional_kwargs["tool_calls"] = tool_calls + + yield ChatResponse( + message=ChatMessage( + role=role, + content=content, + additional_kwargs=additional_kwargs, + ), + delta=content_delta, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + return gen() + + def _prepare_chat_with_tools( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + tool_choice: Union[str, dict] = "auto", + strict: Optional[bool] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """ + Prepare the chat input with tools for function calling. + + Args: + tools (List[BaseTool]): A list of tools to use. + user_msg (Optional[Union[str, ChatMessage]]): The user's message. + chat_history (Optional[List[ChatMessage]]): The chat history. + verbose (bool): Whether to output verbose logs. + allow_parallel_tool_calls (bool): Whether to allow parallel tool calls. + tool_choice (Union[str, dict]): Tool choice strategy. + strict (Optional[bool]): Whether to enforce strict mode. + **kwargs: Additional keyword arguments. + + Returns: + Dict[str, Any]: The prepared parameters for the chat request. + """ + logger.debug( + f"Preparing chat with tools. Tools: {tools}, User message: {user_msg}, " + f"Chat history: {chat_history}" + ) + tool_specs = [tool.metadata.to_openai_tool() for tool in tools] + + # Determine strict mode + if strict is not None: + strict = strict + else: + strict = self.strict + + if self.metadata.is_function_calling_model: + for tool_spec in tool_specs: + if tool_spec["type"] == "function": + if strict: + tool_spec["function"]["strict"] = strict + tool_spec["function"]["parameters"]["additionalProperties"] = False + + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + + messages = chat_history or [] + if user_msg: + messages.append(user_msg) + + return { + "messages": messages, + "tools": tool_specs or None, + "tool_choice": _resolve_tool_choice(tool_choice) if tool_specs else None, + **kwargs, + } + + def _validate_chat_with_tools_response( + self, + response: ChatResponse, + tools: List["BaseTool"], + allow_parallel_tool_calls: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """ + Validate the response from chat_with_tools. + + Args: + response (ChatResponse): The chat response to validate. + tools (List[BaseTool]): A list of tools used. 
+ allow_parallel_tool_calls (bool): Whether parallel tool calls are allowed. + **kwargs: Additional keyword arguments. + + Returns: + ChatResponse: The validated chat response. + """ + if not allow_parallel_tool_calls: + # Ensures that the 'tool_calls' in the response contain only a single tool call. + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) > 1: + logger.debug( + "Multiple tool calls detected but parallel tool calls are not allowed. " + "Limiting to the first tool call." + ) + response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + return response + + def get_tool_calls_from_response( + self, + response: ChatResponse, + error_on_no_tool_call: bool = True, + **kwargs: Any, + ) -> List[ToolSelection]: + """ + Extract tool calls from the chat response. + + Args: + response (ChatResponse): The chat response containing tool calls. + error_on_no_tool_call (bool): Whether to raise an error if no tool calls are found. + **kwargs: Additional keyword arguments. + + Returns: + List[ToolSelection]: A list of tool selections extracted from the response. + + Raises: + ValueError: If no tool calls are found and error_on_no_tool_call is True. + """ + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + logger.debug(f"Extracted tool calls: {tool_calls}") + + if len(tool_calls) < 1: + if error_on_no_tool_call: + raise ValueError( + f"Expected at least one tool call, but got {len(tool_calls)} tool calls." + ) + else: + return [] + + tool_selections = [] + for tool_call in tool_calls: + if tool_call.get("type") != "function": + logger.error(f"Invalid tool type detected: {tool_call.get('type')}") + raise ValueError("Invalid tool type.") + + # Handle both complete and partial JSON + try: + argument_dict = parse_partial_json( + tool_call.get("function", {}).get("arguments", {}) + ) + except ValueError as e: + logger.debug(f"Failed to parse tool call arguments: {str(e)}") + argument_dict = {} + + tool_selections.append( + ToolSelection( + tool_id=tool_call.get("id"), + tool_name=tool_call.get("function", {}).get("name"), + tool_kwargs=argument_dict, + ) + ) + + return tool_selections + + @dispatcher.span + def structured_predict( + self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> BaseModel: + # force tool_choice to be required + llm_kwargs = llm_kwargs or {} + llm_kwargs["tool_choice"] = ( + "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] + ) + return super().structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs) + + @dispatcher.span + async def astructured_predict( + self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> BaseModel: + # force tool_choice to be required + llm_kwargs = llm_kwargs or {} + llm_kwargs["tool_choice"] = ( + "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] + ) + return await super().astructured_predict(*args, llm_kwargs=llm_kwargs, **kwargs) + + @dispatcher.span + def stream_structured_predict( + self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> Generator[Union[Model, List[Model]], None, None]: + # force tool_choice to be required + llm_kwargs = llm_kwargs or {} + llm_kwargs["tool_choice"] = ( + "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] + ) + return super().stream_structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs) + + @dispatcher.span + async def astream_structured_predict( + self, *args: Any, llm_kwargs: 
Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> AsyncGenerator[Union[Model, List[Model]], None]: + # force tool_choice to be required + llm_kwargs = llm_kwargs or {} + llm_kwargs["tool_choice"] = ( + "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] + ) + return await super().astream_structured_predict( + *args, llm_kwargs=llm_kwargs, **kwargs + ) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py new file mode 100644 index 0000000000000..0996257d50ac4 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py @@ -0,0 +1,742 @@ +import asyncio +import functools +import json +import logging +import time +from abc import ABC +from types import TracebackType +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) + +import httpx +import oci +import requests +from ads.common import auth as authutil +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception, + stop_after_attempt, + stop_after_delay, + wait_exponential, + wait_random_exponential, +) + +DEFAULT_RETRIES = 3 +DEFAULT_BACKOFF_FACTOR = 3 +TIMEOUT = 600 # Timeout in seconds +STATUS_FORCE_LIST = [429, 500, 502, 503, 504] +DEFAULT_ENCODING = "utf-8" + +_T = TypeVar("_T", bound="BaseClient") + +logger = logging.getLogger(__name__) + + +class OCIAuth(httpx.Auth): + """ + Custom HTTPX authentication class that uses the OCI Signer for request signing. + + Attributes: + signer (oci.signer.Signer): The OCI signer used to sign requests. + """ + + def __init__(self, signer: oci.signer.Signer): + """ + Initialize the OCIAuth instance. + + Args: + signer (oci.signer.Signer): The OCI signer to use for signing requests. + """ + self.signer = signer + + def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]: + """ + The authentication flow that signs the HTTPX request using the OCI signer. + + Args: + request (httpx.Request): The outgoing HTTPX request to be signed. + + Yields: + httpx.Request: The signed HTTPX request. + """ + # Create a requests.Request object from the HTTPX request + req = requests.Request( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + data=request.content, + ) + prepared_request = req.prepare() + + # Sign the request using the OCI Signer + self.signer.do_request_sign(prepared_request) + + # Update the original HTTPX request with the signed headers + request.headers.update(prepared_request.headers) + + # Proceed with the request + yield request + + +class ExtendedRequestException(Exception): + """ + Custom exception for handling request errors with additional context. + + Attributes: + original_exception (Exception): The original exception that caused the error. + response_text (str): The text of the response received from the request, if available. + """ + + def __init__(self, message: str, original_exception: Exception, response_text: str): + """ + Initialize the ExtendedRequestException. + + Args: + message (str): The error message associated with the exception. + original_exception (Exception): The original exception that caused the error. + response_text (str): The text of the response received from the request, if available. 
+ """ + super().__init__(message) + self.original_exception = original_exception + self.response_text = response_text + + +def _should_retry_exception(e: ExtendedRequestException) -> bool: + """ + Determine whether the exception should trigger a retry. + + Args: + e (ExtendedRequestException): The exception raised. + + Returns: + bool: True if the exception should trigger a retry, False otherwise. + """ + original_exception = e.original_exception if hasattr(e, "original_exception") else e + if isinstance(original_exception, httpx.HTTPStatusError): + return original_exception.response.status_code in STATUS_FORCE_LIST + elif isinstance(original_exception, httpx.RequestError): + return True + return False + + +def _create_retry_decorator( + max_retries: int, + backoff_factor: float, + random_exponential: bool = False, + stop_after_delay_seconds: Optional[float] = None, + min_seconds: float = 0, + max_seconds: float = 60, +) -> Callable[[Any], Any]: + """ + Create a tenacity retry decorator with the specified configuration. + + Args: + max_retries (int): The maximum number of retry attempts. + backoff_factor (float): The backoff factor for calculating retry delays. + random_exponential (bool): Whether to use random exponential backoff. + stop_after_delay_seconds (Optional[float]): Maximum total time to retry. + min_seconds (float): Minimum wait time between retries. + max_seconds (float): Maximum wait time between retries. + + Returns: + Callable[[Any], Any]: A tenacity retry decorator configured with the specified strategy. + """ + wait_strategy = ( + wait_random_exponential(min=min_seconds, max=max_seconds) + if random_exponential + else wait_exponential( + multiplier=backoff_factor, min=min_seconds, max=max_seconds + ) + ) + + stop_strategy = stop_after_attempt(max_retries) + if stop_after_delay_seconds is not None: + stop_strategy = stop_strategy | stop_after_delay(stop_after_delay_seconds) + + retry_strategy = retry_if_exception(_should_retry_exception) + return retry( + wait=wait_strategy, + stop=stop_strategy, + retry=retry_strategy, + reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _retry_decorator(f: Callable) -> Callable: + """ + Decorator to apply retry logic to a function using tenacity. + + Args: + f (Callable): The function to be decorated. + + Returns: + Callable: The decorated function with retry logic applied. + """ + + @functools.wraps(f) + def wrapper(self, *args: Any, **kwargs: Any): + retries = getattr(self, "retries", DEFAULT_RETRIES) + if retries <= 0: + return f(self, *args, **kwargs) + backoff_factor = getattr(self, "backoff_factor", DEFAULT_BACKOFF_FACTOR) + retry_func = _create_retry_decorator( + max_retries=retries, + backoff_factor=backoff_factor, + random_exponential=False, + stop_after_delay_seconds=getattr(self, "timeout", TIMEOUT), + min_seconds=0, + max_seconds=60, + ) + + return retry_func(f)(self, *args, **kwargs) + + return wrapper + + +class BaseClient(ABC): + """ + Base class for invoking models via HTTP requests with retry logic. + + Attributes: + endpoint (str): The URL endpoint to send the request. + auth (Any): The authentication signer for the requests. + retries (int): The number of retry attempts for the request. + backoff_factor (float): The factor to determine the delay between retries. + timeout (Union[float, Tuple[float, float]]): The timeout setting for the HTTP request. + kwargs (Dict): Additional keyword arguments. 
+ """ + + def __init__( + self, + endpoint: str, + auth: Optional[Any] = None, + retries: Optional[int] = DEFAULT_RETRIES, + backoff_factor: Optional[float] = DEFAULT_BACKOFF_FACTOR, + timeout: Optional[Union[float, Tuple[float, float]]] = None, + **kwargs: Any, + ) -> None: + """ + Initialize the BaseClient. + + Args: + endpoint (str): The URL endpoint to send the request. + auth (Optional[Any]): The authentication signer for the requests. + retries (Optional[int]): The number of retry attempts for the request. + backoff_factor (Optional[float]): The factor to determine the delay between retries. + timeout (Optional[Union[float, Tuple[float, float]]]): The timeout setting for the HTTP request. + **kwargs: Additional keyword arguments. + """ + self.endpoint = endpoint + self.retries = retries or DEFAULT_RETRIES + self.backoff_factor = backoff_factor or DEFAULT_BACKOFF_FACTOR + self.timeout = timeout or TIMEOUT + self.kwargs = kwargs + + # Validate auth object + auth = auth or authutil.default_signer() + if not callable(auth.get("signer")): + raise ValueError("Auth object must have a 'signer' callable attribute.") + self.auth = OCIAuth(auth["signer"]) + + logger.debug( + f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, " + f"retries={self.retries}, backoff_factor={self.backoff_factor}, timeout={self.timeout}" + ) + + def _parse_streaming_line( + self, line: Union[bytes, str] + ) -> Optional[Dict[str, Any]]: + """ + Parse a single line from the streaming response. + + Args: + line (Union[bytes, str]): A line of the response in bytes or string format. + + Returns: + Optional[Dict[str, Any]]: Parsed JSON object, or None if the line is to be ignored. + + Raises: + Exception: Raised if the line contains an error object. + json.JSONDecodeError: Raised if the line cannot be decoded as JSON. + """ + logger.debug(f"Parsing streaming line: {line}") + + if isinstance(line, bytes): + line = line.decode(DEFAULT_ENCODING) + + line = line.strip() + + if line.lower().startswith("data:"): + line = line[5:].lstrip() + + if not line or line.startswith("[DONE]"): + logger.debug("Received end of stream signal or empty line.") + return None + + try: + json_line = json.loads(line) + logger.debug(f"Parsed JSON line: {json_line}") + except json.JSONDecodeError as e: + logger.debug(f"Error decoding JSON from line: {line}") + raise json.JSONDecodeError( + f"Error decoding JSON from line: {str(e)}", e.doc, e.pos + ) from e + + if json_line.get("object") == "error": + # Raise an error for error objects in the stream + error_message = json_line.get("message", "Unknown error") + logger.debug(f"Error in streaming response: {error_message}") + raise Exception(f"Error in streaming response: {error_message}") + + return json_line + + def _prepare_headers( + self, + stream: bool, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + """ + Construct and return the headers for a request. + + Args: + stream (bool): Whether to use streaming for the response. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, str]: The prepared headers. 
+ """ + default_headers = { + "Content-Type": "application/json", + "Accept": "text/event-stream" if stream else "application/json", + } + if stream: + default_headers["enable-streaming"] = "true" + if headers: + default_headers.update(headers) + + logger.debug(f"Prepared headers: {default_headers}") + return default_headers + + +class Client(BaseClient): + """ + Synchronous HTTP client for invoking models with support for request and streaming APIs. + """ + + def __init__(self, *args, **kwargs): + """ + Initialize the Client. + + Args: + *args: Positional arguments forwarded to BaseClient. + **kwargs: Keyword arguments forwarded to BaseClient. + """ + super().__init__(*args, **kwargs) + self._client = httpx.Client(timeout=self.timeout) + + def is_closed(self) -> bool: + return self._client.is_closed + + def close(self) -> None: + """Close the underlying HTTPX client.""" + self._client.close() + + def __enter__(self: _T) -> _T: + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]] = None, + exc: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: + self.close() + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + @_retry_decorator + def _request( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """ + Send a POST request to the configured endpoint with retry and error handling. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, Any]: Decoded JSON response. + + Raises: + ExtendedRequestException: Raised when the request fails. + """ + logger.debug(f"Starting synchronous request with payload: {payload}") + try: + response = self._client.post( + self.endpoint, + headers=self._prepare_headers(stream=False, headers=headers), + auth=self.auth, + json=payload, + ) + logger.debug(f"Received response with status code: {response.status_code}") + response.raise_for_status() + json_response = response.json() + logger.debug(f"Response JSON: {json_response}") + return json_response + except Exception as e: + last_exception_text = ( + e.response.text if hasattr(e, "response") and e.response else str(e) + ) + logger.error( + f"Request failed. Error: {str(e)}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Request failed: {str(e)}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + def _stream( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> Iterator[Mapping[str, Any]]: + """ + Send a POST request expecting a streaming response. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Yields: + Mapping[str, Any]: Decoded JSON response line-by-line. + + Raises: + ExtendedRequestException: Raised when the request fails. 
+ """ + logger.debug(f"Starting synchronous streaming request with payload: {payload}") + last_exception_text = None + + for attempt in range(1, self.retries + 2): # retries + initial attempt + logger.debug(f"Attempt {attempt} for synchronous streaming request.") + try: + with self._client.stream( + "POST", + self.endpoint, + headers=self._prepare_headers(stream=True, headers=headers), + auth=self.auth, + json={**payload, "stream": True}, + ) as response: + try: + logger.debug( + f"Received streaming response with status code: {response.status_code}" + ) + response.raise_for_status() + for line in response.iter_lines(): + if not line: # Skip empty lines + continue + + parsed_line = self._parse_streaming_line(line) + if parsed_line: + logger.debug(f"Yielding parsed line: {parsed_line}") + yield parsed_line + return + except Exception as e: + last_exception_text = ( + e.response.read().decode( + e.response.encoding or DEFAULT_ENCODING + ) + if hasattr(e, "response") and e.response + else str(e) + ) + raise + + except Exception as e: + if attempt <= self.retries and _should_retry_exception(e): + delay = self.backoff_factor * (2 ** (attempt - 1)) + logger.warning( + f"Streaming attempt {attempt} failed: {e}. Retrying in {delay} seconds..." + ) + time.sleep(delay) + else: + logger.error( + f"Streaming request failed. Error: {str(e)}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Streaming request failed: {str(e)}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + def generate( + self, + prompt: str, + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = True, + ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: + """ + Generate text completion for the given prompt. + + Args: + prompt (str): Input text prompt for the model. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. + + Returns: + Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: A full JSON response or an iterator for streaming responses. + """ + logger.debug(f"Generating text with prompt: {prompt}, stream: {stream}") + payload = {**(payload or {}), "prompt": prompt} + if stream: + return self._stream(payload=payload, headers=headers) + return self._request(payload=payload, headers=headers) + + def chat( + self, + messages: List[Dict[str, Any]], + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = True, + ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: + """ + Perform a chat interaction with the model. + + Args: + messages (List[Dict[str, Any]]): List of message dictionaries for chat interaction. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. + + Returns: + Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: A full JSON response or an iterator for streaming responses. 
+ """ + logger.debug(f"Starting chat with messages: {messages}, stream: {stream}") + payload = {**(payload or {}), "messages": messages} + if stream: + return self._stream(payload=payload, headers=headers) + return self._request(payload=payload, headers=headers) + + +class AsyncClient(BaseClient): + """ + Asynchronous HTTP client for invoking models with support for request and streaming APIs, including retry logic. + """ + + def __init__(self, *args, **kwargs): + """ + Initialize the AsyncClient. + + Args: + *args: Positional arguments forwarded to BaseClient. + **kwargs: Keyword arguments forwarded to BaseClient. + """ + super().__init__(*args, **kwargs) + self._client = httpx.AsyncClient(timeout=self.timeout) + + def is_closed(self) -> bool: + return self._client.is_closed + + async def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + await self._client.aclose() + + async def __aenter__(self: _T) -> _T: + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]] = None, + exc: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: + await self.close() + + def __del__(self) -> None: + try: + if not self._client.is_closed: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + except Exception: + pass + + @_retry_decorator + async def _request( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """ + Send a POST request to the configured endpoint with retry and error handling. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, Any]: Decoded JSON response. + + Raises: + ExtendedRequestException: Raised when the request fails. + """ + logger.debug(f"Starting asynchronous request with payload: {payload}") + try: + response = await self._client.post( + self.endpoint, + headers=self._prepare_headers(stream=False, headers=headers), + auth=self.auth, + json=payload, + ) + logger.debug(f"Received response with status code: {response.status_code}") + response.raise_for_status() + json_response = response.json() + logger.debug(f"Response JSON: {json_response}") + return json_response + except Exception as e: + last_exception_text = ( + e.response.text if hasattr(e, "response") and e.response else str(e) + ) + logger.error( + f"Request failed. Error: {str(e)}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Request failed: {str(e)}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + async def _stream( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> AsyncIterator[Mapping[str, Any]]: + """ + Send a POST request expecting a streaming response with retry logic. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Yields: + Mapping[str, Any]: Decoded JSON response line-by-line. + + Raises: + ExtendedRequestException: Raised when the request fails. 
+ """ + logger.debug(f"Starting asynchronous streaming request with payload: {payload}") + last_exception_text = None + for attempt in range(1, self.retries + 2): # retries + initial attempt + logger.debug(f"Attempt {attempt} for asynchronous streaming request.") + try: + async with self._client.stream( + "POST", + self.endpoint, + headers=self._prepare_headers(stream=True, headers=headers), + auth=self.auth, + json={**payload, "stream": True}, + ) as response: + try: + logger.debug( + f"Received streaming response with status code: {response.status_code}" + ) + response.raise_for_status() + async for line in response.aiter_lines(): + if not line: # Skip empty lines + continue + parsed_line = self._parse_streaming_line(line) + if parsed_line: + logger.debug(f"Yielding parsed line: {parsed_line}") + yield parsed_line + return + except Exception as e: + if hasattr(e, "response") and e.response: + content = await e.response.aread() + last_exception_text = content.decode( + e.response.encoding or DEFAULT_ENCODING + ) + raise + except Exception as e: + if attempt <= self.retries and _should_retry_exception(e): + delay = self.backoff_factor * (2 ** (attempt - 1)) + logger.warning( + f"Streaming attempt {attempt} failed: {e}. Retrying in {delay} seconds..." + ) + await asyncio.sleep(delay) + else: + logger.error( + f"Streaming request failed. Error: {str(e)}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Streaming request failed: {str(e)}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + async def generate( + self, + prompt: str, + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = False, + ) -> Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: + """ + Generate text completion for the given prompt. + + Args: + prompt (str): Input text prompt for the model. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. + + Returns: + Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: A full JSON response or an async iterator for streaming responses. + """ + logger.debug(f"Generating text with prompt: {prompt}, stream: {stream}") + payload = {**(payload or {}), "prompt": prompt} + if stream: + return self._stream(payload=payload, headers=headers) + return await self._request(payload=payload, headers=headers) + + async def chat( + self, + messages: List[Dict[str, Any]], + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = False, + ) -> Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: + """ + Perform a chat interaction with the model. + + Args: + messages (List[Dict[str, Any]]): List of message dictionaries for chat interaction. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. + + Returns: + Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: A full JSON response or an async iterator for streaming responses. 
+ """ + logger.debug(f"Starting chat with messages: {messages}, stream: {stream}") + payload = {**(payload or {}), "messages": messages} + if stream: + return self._stream(payload=payload, headers=headers) + return await self._request(payload=payload, headers=headers) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py new file mode 100644 index 0000000000000..3ac495250bdde --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py @@ -0,0 +1,264 @@ +import logging +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +from llama_index.core.base.llms.types import ChatMessage, LogProb +from packaging import version + +MIN_ADS_VERSION = "2.12.6" + +logger = logging.getLogger(__name__) + + +class UnsupportedOracleAdsVersionError(Exception): + """ + Custom exception for unsupported `oracle-ads` versions. + + Attributes + ---------- + current_version : str + The installed version of `oracle-ads`. + required_version : str + The minimum required version of `oracle-ads`. + """ + + def __init__(self, current_version: str, required_version: str): + super().__init__( + f"The `oracle-ads` version {current_version} currently installed is incompatible with " + "the `llama-index-llms-oci-data-science` version in use. To resolve this issue, " + f"please upgrade to `oracle-ads:{required_version}` or later using the " + "command: `pip install oracle-ads -U`" + ) + + +def _validate_dependency(func: Callable[..., Any]) -> Callable[..., Any]: + """ + Decorator to validate the presence and version of the `oracle-ads` package. + + This decorator checks whether `oracle-ads` is installed and ensures its version meets + the minimum requirement. Raises an error if the conditions are not met. + + Parameters + ---------- + func : Callable[..., Any] + The function to wrap with the dependency validation. + + Returns + ------- + Callable[..., Any] + The wrapped function. + + Raises + ------ + ImportError + If `oracle-ads` is not installed. + UnsupportedOracleAdsVersionError + If the installed version is below the required version. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + try: + from ads import __version__ as ads_version + + if version.parse(ads_version) < version.parse(MIN_ADS_VERSION): + raise UnsupportedOracleAdsVersionError(ads_version, MIN_ADS_VERSION) + + except ImportError as ex: + raise ImportError( + "Could not import `oracle-ads` Python package. " + "Please install it with `pip install oracle-ads`." + ) from ex + return func(*args, **kwargs) + + return wrapper + + +def _to_message_dicts( + messages: Sequence[ChatMessage], drop_none: bool = False +) -> List[Dict[str, Any]]: + """ + Converts a sequence of ChatMessage objects to a list of dictionaries. + + Parameters + ---------- + messages : Sequence[ChatMessage] + The messages to convert. + drop_none : bool, optional + Whether to drop keys with `None` values. Defaults to False. + + Returns + ------- + List[Dict[str, Any]] + The converted list of message dictionaries. 
+ """ + message_dicts = [] + for message in messages: + message_dict = { + "role": message.role.value, + "content": message.content, + **message.additional_kwargs, + } + if drop_none: + message_dict = {k: v for k, v in message_dict.items() if v is not None} + message_dicts.append(message_dict) + return message_dicts + + +def _from_completion_logprobs_dict( + completion_logprobs_dict: Dict[str, Any] +) -> List[List[LogProb]]: + """ + Converts completion logprobs to a list of generic LogProb objects. + + Parameters + ---------- + completion_logprobs_dict : Dict[str, Any] + The completion logprobs to convert. + + Returns + ------- + List[List[LogProb]] + The converted logprobs. + """ + return [ + [ + LogProb(token=token, logprob=logprob, bytes=[]) + for token, logprob in logprob_dict.items() + ] + for logprob_dict in completion_logprobs_dict.get("top_logprobs", []) + ] + + +def _from_token_logprob_dicts( + token_logprob_dicts: Sequence[Dict[str, Any]], +) -> List[List[LogProb]]: + """ + Converts a sequence of token logprob dictionaries to a list of lists of LogProb objects. + + Parameters + ---------- + token_logprob_dicts : Sequence[Dict[str, Any]] + The token logprob dictionaries to convert. + + Returns + ------- + List[List[LogProb]] + The converted logprobs. + """ + result = [] + for token_logprob_dict in token_logprob_dicts: + try: + logprobs_list = [ + LogProb( + token=el.get("token"), + logprob=el.get("logprob"), + bytes=el.get("bytes") or [], + ) + for el in token_logprob_dict.get("top_logprobs", []) + ] + if logprobs_list: + result.append(logprobs_list) + except Exception as e: + logger.warning( + f"Error occurred in attempt to parse token logprob. " + f"Details: {e}. Src: {token_logprob_dict}" + ) + return result + + +def _from_message_dict(message_dict: Dict[str, Any]) -> ChatMessage: + """ + Converts a message dictionary to a generic ChatMessage object. + + Parameters + ---------- + message_dict : Dict[str, Any] + The message dictionary. + + Returns + ------- + ChatMessage + The converted ChatMessage object. + """ + role = message_dict.get("role") + content = message_dict.get("content") + additional_kwargs = {"tool_calls": message_dict.get("tool_calls", [])} + return ChatMessage(role=role, content=content, additional_kwargs=additional_kwargs) + + +def _get_response_token_counts(raw_response: Dict[str, Any]) -> Dict[str, int]: + """ + Extracts token usage information from the response. + + Parameters + ---------- + raw_response : Dict[str, Any] + The raw response containing token usage information. + + Returns + ------- + Dict[str, int] + The extracted token counts. + """ + usage = raw_response.get("usage", {}) + + if not usage: + return {} + + return { + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + } + + +def _update_tool_calls( + tool_calls: List[Dict[str, Any]], tool_calls_delta: Optional[List[Dict[str, Any]]] +) -> List[Dict[str, Any]]: + """ + Updates the tool calls using delta objects received from stream chunks. + + Parameters + ---------- + tool_calls : List[Dict[str, Any]] + The list of existing tool calls. + tool_calls_delta : Optional[List[Dict[str, Any]]] + The delta updates for the tool calls. + + Returns + ------- + List[Dict[str, Any]] + The updated tool calls. 
+ """ + if not tool_calls_delta: + return tool_calls + + delta_call = tool_calls_delta[0] + if not tool_calls or tool_calls[-1].get("index") != delta_call.get("index"): + tool_calls.append(delta_call) + else: + latest_call = tool_calls[-1] + latest_function = latest_call.setdefault("function", {}) + delta_function = delta_call.get("function", {}) + + latest_function["arguments"] = latest_function.get( + "arguments", "" + ) + delta_function.get("arguments", "") + latest_function["name"] = latest_function.get("name", "") + delta_function.get( + "name", "" + ) + latest_call["id"] = latest_call.get("id", "") + delta_call.get("id", "") + + return tool_calls + + +def _resolve_tool_choice(tool_choice: Union[str, dict] = "auto") -> Union[str, dict]: + """Resolve tool choice. + + If tool_choice is a function name string, return the appropriate dict. + """ + if isinstance(tool_choice, str) and tool_choice not in ["none", "auto", "required"]: + return {"type": "function", "function": {"name": tool_choice}} + + return tool_choice diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml new file mode 100644 index 0000000000000..810ed215bfcbc --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml @@ -0,0 +1,66 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = false +import_path = "llama_index.llms.oci_data_science" + +[tool.llamahub.class_authors] +OCIDataScience = "mrdzurb" + +[tool.mypy] +disallow_untyped_defs = true +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["Dmitrii Cherkasov "] +description = "llama-index llms OCI Data Science integration" +exclude = ["**/BUILD"] +license = "MIT" +name = "llama-index-llms-oci-data-science" +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.9,<4.0" +oracle-ads = ">=2.12.6" +llama-index-core = "^0.11.0" +httpx = ">=" +teancity = ">=" + +[tool.poetry.group.dev.dependencies] +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-asyncio=">=0.24.0" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" +types-setuptools = "67.1.0.0" + +[tool.poetry.group.dev.dependencies.black] +extras = ["jupyter"] +version = "<=23.9.1,>=23.7.0" + +[tool.poetry.group.dev.dependencies.codespell] +extras = ["toml"] +version = ">=v2.2.6" + +[[tool.poetry.packages]] +include = "llama_index/" diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD new file mode 100644 index 0000000000000..dabf212d7e716 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/__init__.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d 
diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py new file mode 100644 index 0000000000000..6ef5078e01cb7 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py @@ -0,0 +1,344 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from ads.common import auth as authutil +from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole +from llama_index.core.callbacks import CallbackManager +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.tools.types import BaseTool +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.llms.oci_data_science.base import OCIDataScience +from llama_index.llms.oci_data_science.client import AsyncClient, Client + + +def test_embedding_class(): + names_of_base_classes = [b.__name__ for b in OCIDataScience.__mro__] + assert FunctionCallingLLM.__name__ in names_of_base_classes + + +@pytest.fixture +def llm(): + endpoint = "https://example.com/api" + auth = {"signer": Mock()} + model = "odsc-llm" + temperature = 0.7 + max_tokens = 100 + timeout = 60 + max_retries = 3 + additional_kwargs = {"top_p": 0.9} + callback_manager = CallbackManager([]) + + with patch.object(authutil, "default_signer", return_value=auth): + llm_instance = OCIDataScience( + endpoint=endpoint, + auth=auth, + model=model, + temperature=temperature, + max_tokens=max_tokens, + timeout=timeout, + max_retries=max_retries, + additional_kwargs=additional_kwargs, + callback_manager=callback_manager, + ) + # Mock the client + llm_instance._client = Mock(spec=Client) + llm_instance._async_client = AsyncMock(spec=AsyncClient) + return llm_instance + + +def test_complete_success(llm): + prompt = "What is the capital of France?" + response_data = { + "choices": [ + { + "text": "The capital of France is Paris.", + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12, + }, + } + # Mock the client's generate method + llm.client.generate.return_value = response_data + + response = llm.complete(prompt) + + # Assertions + llm.client.generate.assert_called_once() + assert response.text == "The capital of France is Paris." + assert response.additional_kwargs["total_tokens"] == 12 + + +def test_complete_invalid_response(llm): + prompt = "What is the capital of France?" + response_data = {} # Empty response + llm.client.generate.return_value = response_data + + with pytest.raises(ValueError): + llm.complete(prompt) + + +def test_chat_success(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + response_data = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Why did the chicken cross the road? To get to the other side!", + }, + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 15, + "total_tokens": 25, + }, + } + llm.client.chat.return_value = response_data + + response = llm.chat(messages) + + llm.client.chat.assert_called_once() + assert ( + response.message.content + == "Why did the chicken cross the road? To get to the other side!" 
+ ) + assert response.additional_kwargs["total_tokens"] == 25 + + +def test_stream_complete(llm): + prompt = "Once upon a time" + # Mock the client's generate method to return an iterator + response_data = iter( + [ + {"choices": [{"text": "Once"}], "usage": {}}, + {"choices": [{"text": " upon"}], "usage": {}}, + {"choices": [{"text": " a"}], "usage": {}}, + {"choices": [{"text": " time."}], "usage": {}}, + ] + ) + llm.client.generate.return_value = response_data + + responses = list(llm.stream_complete(prompt)) + + llm.client.generate.assert_called_once() + assert len(responses) == 4 + assert responses[0].delta == "Once" + assert responses[1].delta == " upon" + assert responses[2].delta == " a" + assert responses[3].delta == " time." + assert responses[-1].text == "Once upon a time." + + +def test_stream_chat(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + response_data = iter( + [ + {"choices": [{"delta": {"content": "Why"}}], "usage": {}}, + {"choices": [{"delta": {"content": " did"}}], "usage": {}}, + {"choices": [{"delta": {"content": " the"}}], "usage": {}}, + { + "choices": [{"delta": {"content": " chicken cross the road?"}}], + "usage": {}, + }, + ] + ) + llm.client.chat.return_value = response_data + + responses = list(llm.stream_chat(messages)) + + llm.client.chat.assert_called_once() + assert len(responses) == 4 + content = "".join([r.delta for r in responses]) + assert content == "Why did the chicken cross the road?" + assert responses[-1].message.content == content + + +def test_prepare_chat_with_tools(llm): + # Mock tools + tool1 = Mock(spec=BaseTool) + tool1.metadata.to_openai_tool.return_value = { + "name": "tool1", + "type": "function", + "function": { + "name": "tool1", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + tool2 = Mock(spec=BaseTool) + tool2.metadata.to_openai_tool.return_value = { + "name": "tool2", + "type": "function", + "function": { + "name": "tool2", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + + user_msg = "Calculate the result of 2 + 2." + chat_history = [ChatMessage(role=MessageRole.USER, content="Previous message")] + + result = llm._prepare_chat_with_tools( + tools=[tool1, tool2], + user_msg=user_msg, + chat_history=chat_history, + ) + + # Check that 'function' key has been updated as expected + for tool_spec in result["tools"]: + assert "function" in tool_spec + assert "parameters" in tool_spec["function"] + assert tool_spec["function"]["parameters"]["additionalProperties"] is False + + assert "messages" in result + assert "tools" in result + assert len(result["tools"]) == 2 + assert result["messages"][-1].content == user_msg + + +def test_get_tool_calls_from_response(llm): + tool_call = { + "type": "function", + "id": "123", + "function": { + "name": "multiply", + "arguments": '{"a": 2, "b": 3}', + }, + } + response = ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content="", + additional_kwargs={"tool_calls": [tool_call]}, + ), + raw={}, + ) + + tool_selections = llm.get_tool_calls_from_response(response) + + assert len(tool_selections) == 1 + assert tool_selections[0].tool_name == "multiply" + assert tool_selections[0].tool_kwargs == {"a": 2, "b": 3} + + +@pytest.mark.asyncio +async def test_acomplete_success(llm): + prompt = "What is the capital of France?" 
+ response_data = { + "choices": [ + { + "text": "The capital of France is Paris.", + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12, + }, + } + llm.async_client.generate.return_value = response_data + + response = await llm.acomplete(prompt) + + llm.async_client.generate.assert_called_once() + assert response.text == "The capital of France is Paris." + assert response.additional_kwargs["total_tokens"] == 12 + + +@pytest.mark.asyncio +async def test_astream_complete(llm): + prompt = "Once upon a time" + + async def async_generator(): + response_data = [ + {"choices": [{"text": "Once"}], "usage": {}}, + {"choices": [{"text": " upon"}], "usage": {}}, + {"choices": [{"text": " a"}], "usage": {}}, + {"choices": [{"text": " time."}], "usage": {}}, + ] + for item in response_data: + yield item + + llm.async_client.generate.return_value = async_generator() + + responses = [] + async for response in await llm.astream_complete(prompt): + responses.append(response) + + llm.async_client.generate.assert_called_once() + assert len(responses) == 4 + assert responses[0].delta == "Once" + assert responses[-1].text == "Once upon a time." + + +@pytest.mark.asyncio +async def test_achat_success(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + response_data = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Why did the chicken cross the road? To get to the other side!", + }, + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 15, + "total_tokens": 25, + }, + } + llm.async_client.chat.return_value = response_data + + response = await llm.achat(messages) + + llm.async_client.chat.assert_called_once() + assert ( + response.message.content + == "Why did the chicken cross the road? To get to the other side!" + ) + assert response.additional_kwargs["total_tokens"] == 25 + + +@pytest.mark.asyncio +async def test_astream_chat(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + + async def async_generator(): + response_data = [ + {"choices": [{"delta": {"content": "Why"}}], "usage": {}}, + {"choices": [{"delta": {"content": " did"}}], "usage": {}}, + {"choices": [{"delta": {"content": " the"}}], "usage": {}}, + { + "choices": [{"delta": {"content": " chicken cross the road?"}}], + "usage": {}, + }, + ] + for item in response_data: + yield item + + llm.async_client.chat.return_value = async_generator() + + responses = [] + async for response in await llm.astream_chat(messages): + responses.append(response) + + llm.async_client.chat.assert_called_once() + assert len(responses) == 4 + content = "".join([r.delta for r in responses]) + assert content == "Why did the chicken cross the road?" 
+ assert responses[-1].message.content == content diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py new file mode 100644 index 0000000000000..9f951e35cfc1f --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py @@ -0,0 +1,694 @@ +import json +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import httpx +import pytest +from ads.common import auth as authutil +from llama_index.llms.oci_data_science.client import ( + AsyncClient, + BaseClient, + Client, + ExtendedRequestException, + OCIAuth, + _create_retry_decorator, + _retry_decorator, + _should_retry_exception, +) +from tenacity import RetryError + + +class TestOCIAuth: + """Unit tests for OCIAuth class.""" + + def setup_method(self): + self.signer_mock = Mock() + self.oci_auth = OCIAuth(self.signer_mock) + + def test_auth_flow(self): + """Ensures that the auth_flow signs the request correctly.""" + request = httpx.Request("POST", "https://example.com") + prepared_request_mock = Mock() + prepared_request_mock.headers = {"Authorization": "Signed"} + with patch("requests.Request") as mock_requests_request: + mock_requests_request.return_value = Mock() + mock_requests_request.return_value.prepare.return_value = ( + prepared_request_mock + ) + self.signer_mock.do_request_sign = Mock() + + list(self.oci_auth.auth_flow(request)) + + self.signer_mock.do_request_sign.assert_called() + assert request.headers.get("Authorization") == "Signed" + + +class TestExtendedRequestException: + """Unit tests for ExtendedRequestException.""" + + def test_exception_attributes(self): + """Ensures the exception stores the correct attributes.""" + original_exception = Exception("Original error") + response_text = "Error response text" + message = "Extended error message" + + exception = ExtendedRequestException(message, original_exception, response_text) + + assert str(exception) == message + assert exception.original_exception == original_exception + assert exception.response_text == response_text + + +class TestShouldRetryException: + """Unit tests for _should_retry_exception function.""" + + def test_http_status_error_in_force_list(self): + """Ensures it returns True for HTTPStatusError with status in STATUS_FORCE_LIST.""" + response_mock = Mock() + response_mock.status_code = 500 + original_exception = httpx.HTTPStatusError( + "Error", request=None, response=response_mock + ) + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is True + + def test_http_status_error_not_in_force_list(self): + """Ensures it returns False for HTTPStatusError with status not in STATUS_FORCE_LIST.""" + response_mock = Mock() + response_mock.status_code = 404 + original_exception = httpx.HTTPStatusError( + "Error", request=None, response=response_mock + ) + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is False + + def test_http_request_error(self): + """Ensures it returns True for RequestError.""" + original_exception = httpx.RequestError("Error") + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = 
_should_retry_exception(exception) + assert result is True + + def test_other_exception(self): + """Ensures it returns False for other exceptions.""" + original_exception = Exception("Some other error") + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is False + + +class TestCreateRetryDecorator: + """Unit tests for _create_retry_decorator function.""" + + def test_create_retry_decorator(self): + """Ensures the retry decorator is created with correct parameters.""" + max_retries = 5 + backoff_factor = 2 + random_exponential = False + stop_after_delay_seconds = 100 + min_seconds = 1 + max_seconds = 10 + + retry_decorator = _create_retry_decorator( + max_retries, + backoff_factor, + random_exponential, + stop_after_delay_seconds, + min_seconds, + max_seconds, + ) + + assert callable(retry_decorator) + + +class TestRetryDecorator: + """Unit tests for _retry_decorator function.""" + + def test_retry_decorator_no_retries(self): + """Ensures the function is called directly when retries is 0.""" + + class TestClass: + retries = 0 + backoff_factor = 1 + timeout = 10 + + @_retry_decorator + def test_method(self): + return "Success" + + test_instance = TestClass() + result = test_instance.test_method() + assert result == "Success" + + def test_retry_decorator_with_retries(self): + """Ensures the function retries upon exception.""" + + class TestClass: + retries = 3 + backoff_factor = 0.1 + timeout = 10 + + call_count = 0 + + @_retry_decorator + def test_method(self): + self.call_count += 1 + if self.call_count < 3: + raise ExtendedRequestException( + "Error", + original_exception=httpx.RequestError("Error"), + response_text="test", + ) + return "Success" + + test_instance = TestClass() + result = test_instance.test_method() + assert result == "Success" + assert test_instance.call_count == 3 + + def test_retry_decorator_exceeds_retries(self): + """Ensures the function raises exception after exceeding retries.""" + + class TestClass: + retries = 3 + backoff_factor = 0.1 + timeout = 10 + + call_count = 0 + + @_retry_decorator + def test_method(self): + self.call_count += 1 + raise ExtendedRequestException( + "Error", + original_exception=httpx.RequestError("Error"), + response_text="test", + ) + + test_instance = TestClass() + with pytest.raises(ExtendedRequestException): + test_instance.test_method() + assert test_instance.call_count == 3 # initial attempt + 2 retries + + +class TestBaseClient: + """Unit tests for BaseClient class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 3 + self.backoff_factor = 2 + self.timeout = 30 + + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + self.base_client = BaseClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + + def test_init(self): + """Ensures that the client is initialized correctly.""" + assert self.base_client.endpoint == self.endpoint + assert self.base_client.retries == self.retries + assert self.base_client.backoff_factor == self.backoff_factor + assert self.base_client.timeout == self.timeout + assert isinstance(self.base_client.auth, OCIAuth) + + def test_init_default_auth(self): + """Ensures that default auth is used when auth is None.""" + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + client = 
BaseClient(endpoint=self.endpoint) + assert client.auth is not None + + def test_init_invalid_auth(self): + """Ensures that ValueError is raised when auth signer is invalid.""" + with pytest.raises(ValueError): + BaseClient(endpoint=self.endpoint, auth={"signer": None}) + + def test_prepare_headers(self): + """Ensures that headers are prepared correctly.""" + headers = {"Custom-Header": "Value"} + result = self.base_client._prepare_headers(stream=False, headers=headers) + expected_headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Custom-Header": "Value", + } + assert result == expected_headers + + def test_prepare_headers_stream(self): + """Ensures that headers are prepared correctly for streaming.""" + headers = {"Custom-Header": "Value"} + result = self.base_client._prepare_headers(stream=True, headers=headers) + expected_headers = { + "Content-Type": "application/json", + "Accept": "text/event-stream", + "enable-streaming": "true", + "Custom-Header": "Value", + } + assert result == expected_headers + + def test_parse_streaming_line_valid(self): + """Ensures that a valid streaming line is parsed correctly.""" + line = 'data: {"key": "value"}' + result = self.base_client._parse_streaming_line(line) + assert result == {"key": "value"} + + def test_parse_streaming_line_invalid_json(self): + """Ensures that JSONDecodeError is raised for invalid JSON.""" + line = "data: invalid json" + with pytest.raises(json.JSONDecodeError): + self.base_client._parse_streaming_line(line) + + def test_parse_streaming_line_empty(self): + """Ensures that None is returned for empty or end-of-stream lines.""" + line = "" + result = self.base_client._parse_streaming_line(line) + assert result is None + + line = "[DONE]" + result = self.base_client._parse_streaming_line(line) + assert result is None + + def test_parse_streaming_line_error_object(self): + """Ensures that an exception is raised for error objects in the stream.""" + line = 'data: {"object": "error", "message": "Error message"}' + with pytest.raises(Exception) as exc_info: + self.base_client._parse_streaming_line(line) + assert "Error in streaming response: Error message" in str(exc_info.value) + + +class TestClient: + """Unit tests for Client class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 2 + self.backoff_factor = 0.1 + self.timeout = 10 + + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + self.client = Client( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + # Mock the internal HTTPX client + self.client._client = Mock() + + def test_request_success(self): + """Ensures that _request returns JSON response on success.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = Mock() + response_mock.json.return_value = response_json + response_mock.status_code = 200 + + self.client._client.post.return_value = response_mock + + result = self.client._request(payload) + + assert result == response_json + + def test_request_http_error(self): + """Ensures that _request raises ExtendedRequestException on HTTP error.""" + payload = {"prompt": "Hello"} + response_mock = Mock() + response_mock.status_code = 500 + response_mock.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=None, response=response_mock + ) + response_mock.text = "Internal Server 
Error" + + self.client._client.post.return_value = response_mock + + with pytest.raises(ExtendedRequestException) as exc_info: + self.client._request(payload) + + assert "Request failed" in str(exc_info.value) + assert exc_info.value.response_text == "Internal Server Error" + + def test_stream_success(self): + """Ensures that _stream yields parsed lines on success.""" + payload = {"prompt": "Hello"} + response_mock = Mock() + response_mock.status_code = 200 + response_mock.iter_lines.return_value = [ + b'data: {"key": "value1"}', + b'data: {"key": "value2"}', + b"[DONE]", + ] + # Mock the context manager + stream_cm = MagicMock() + stream_cm.__enter__.return_value = response_mock + self.client._client.stream.return_value = stream_cm + + result = list(self.client._stream(payload)) + + assert result == [{"key": "value1"}, {"key": "value2"}] + + @patch("time.sleep", return_value=None) + def test_stream_retry_on_exception(self, mock_sleep): + """Ensures that _stream retries on exceptions and raises after retries exhausted.""" + payload = {"prompt": "Hello"} + + # Mock the exception to be raised + def side_effect(*args, **kwargs): + raise httpx.RequestError("Connection error") + + # Mock the context manager + self.client._client.stream.side_effect = side_effect + + with pytest.raises(ExtendedRequestException): + list(self.client._stream(payload)) + + assert ( + self.client._client.stream.call_count == self.retries + 1 + ) # initial attempt + retries + + def test_generate_stream(self): + """Ensures that generate method calls _stream when stream=True.""" + payload = {"prompt": "Hello"} + response_mock = Mock() + response_mock.status_code = 200 + response_mock.iter_lines.return_value = [b'data: {"key": "value"}', b"[DONE]"] + # Mock the context manager + stream_cm = MagicMock() + stream_cm.__enter__.return_value = response_mock + self.client._client.stream.return_value = stream_cm + + result = list(self.client.generate(prompt="Hello", stream=True)) + + assert result == [{"key": "value"}] + + def test_generate_request(self): + """Ensures that generate method calls _request when stream=False.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = Mock() + response_mock.json.return_value = response_json + response_mock.status_code = 200 + + self.client._client.post.return_value = response_mock + + result = self.client.generate(prompt="Hello", stream=False) + + assert result == response_json + + def test_chat_stream(self): + """Ensures that chat method calls _stream when stream=True.""" + messages = [{"role": "user", "content": "Hello"}] + response_mock = Mock() + response_mock.status_code = 200 + response_mock.iter_lines.return_value = [b'data: {"key": "value"}', b"[DONE]"] + # Mock the context manager + stream_cm = MagicMock() + stream_cm.__enter__.return_value = response_mock + self.client._client.stream.return_value = stream_cm + + result = list(self.client.chat(messages=messages, stream=True)) + + assert result == [{"key": "value"}] + + def test_chat_request(self): + """Ensures that chat method calls _request when stream=False.""" + messages = [{"role": "user", "content": "Hello"}] + response_json = {"choices": [{"message": {"content": "Hi"}}]} + response_mock = Mock() + response_mock.json.return_value = response_json + response_mock.status_code = 200 + + self.client._client.post.return_value = response_mock + + result = self.client.chat(messages=messages, stream=False) + + assert result == response_json + + def test_close(self): + """Ensures that 
close method closes the client.""" + self.client._client.close = Mock() + self.client.close() + self.client._client.close.assert_called_once() + + def test_is_closed(self): + """Ensures that is_closed returns the client's is_closed status.""" + self.client._client.is_closed = False + assert not self.client.is_closed() + self.client._client.is_closed = True + assert self.client.is_closed() + + def test_context_manager(self): + """Ensures that the client can be used as a context manager.""" + self.client.close = Mock() + with self.client as client_instance: + assert client_instance == self.client + self.client.close.assert_called_once() + + def test_del(self): + """Ensures that __del__ method closes the client.""" + client = Client( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + client.close = Mock() + client.__del__() # Manually invoke __del__ + client.close.assert_called_once() + + +@pytest.mark.asyncio +class TestAsyncClient: + """Unit tests for AsyncClient class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 2 + self.backoff_factor = 0.1 + self.timeout = 10 + + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + self.client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + # Mock the internal HTTPX client + self.client._client = AsyncMock() + self.client._client.is_closed = False + + def async_iter(self, items): + """Helper function to create an async iterator from a list.""" + + async def generator(): + for item in items: + yield item + + return generator() + + async def test_request_success(self): + """Ensures that _request returns JSON response on success.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = AsyncMock() + response_mock.status_code = 200 + response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + self.client._client.post.return_value = response_mock + result = await self.client._request(payload) + assert await result == response_json + + async def test_request_http_error(self): + """Ensures that _request raises ExtendedRequestException on HTTP error.""" + payload = {"prompt": "Hello"} + response_mock = MagicMock() + response_mock.status_code = 500 + response_mock.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=None, response=response_mock + ) + response_mock.text = "Internal Server Error" + + self.client._client.post.return_value = response_mock + + with pytest.raises(ExtendedRequestException) as exc_info: + await self.client._request(payload) + + assert "Request failed" in str(exc_info.value) + assert exc_info.value.response_text == "Internal Server Error" + + async def test_stream_success(self): + """Ensures that _stream yields parsed lines on success.""" + payload = {"prompt": "Hello"} + response_mock = MagicMock() + response_mock.status_code = 200 + response_mock.raise_for_status = Mock() + response_mock.aiter_lines.return_value = self.async_iter( + ['data: {"key": "value1"}', 'data: {"key": "value2"}', "[DONE]"] + ) + + # Define an async context manager + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + yield response_mock + + # Mock the stream method to return our context manager + self.client._client.stream = 
Mock(side_effect=stream_context_manager) + + result = [] + async for item in self.client._stream(payload): + result.append(item) + + assert result == [{"key": "value1"}, {"key": "value2"}] + + @patch("asyncio.sleep", return_value=None) + async def test_stream_retry_on_exception(self, mock_sleep): + """Ensures that _stream retries on exceptions and raises after retries exhausted.""" + payload = {"prompt": "Hello"} + + # Define an async context manager that raises an exception + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + raise httpx.RequestError("Connection error") + yield # This is never reached + + # Mock the stream method to use our context manager + self.client._client.stream = Mock(side_effect=stream_context_manager) + + with pytest.raises(ExtendedRequestException): + async for _ in self.client._stream(payload): + pass + + assert ( + self.client._client.stream.call_count == self.retries + 1 + ) # initial attempt + retries + + async def test_generate_stream(self): + """Ensures that generate method calls _stream when stream=True.""" + response_mock = MagicMock() + response_mock.status_code = 200 + response_mock.raise_for_status = Mock() + response_mock.aiter_lines.return_value = self.async_iter( + ['data: {"key": "value"}', "[DONE]"] + ) + + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + yield response_mock + + self.client._client.stream = Mock(side_effect=stream_context_manager) + + result = [] + async for item in await self.client.generate(prompt="Hello", stream=True): + result.append(item) + + assert result == [{"key": "value"}] + + async def test_generate_request(self): + """Ensures that generate method calls _request when stream=False.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = AsyncMock() + response_mock.status_code = 200 + response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + + self.client._client.post.return_value = response_mock + + result = await self.client.generate(prompt="Hello", stream=False) + + assert await result == response_json + + async def test_chat_stream(self): + """Ensures that chat method calls _stream when stream=True.""" + messages = [{"role": "user", "content": "Hello"}] + response_mock = MagicMock() + response_mock.status_code = 200 + response_mock.raise_for_status = Mock() + response_mock.aiter_lines.return_value = self.async_iter( + ['data: {"key": "value"}', "[DONE]"] + ) + + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + yield response_mock + + self.client._client.stream = Mock(side_effect=stream_context_manager) + + result = [] + async for item in await self.client.chat(messages=messages, stream=True): + result.append(item) + + assert result == [{"key": "value"}] + + async def test_chat_request(self): + """Ensures that chat method calls _request when stream=False.""" + messages = [{"role": "user", "content": "Hello"}] + response_json = {"choices": [{"message": {"content": "Hi"}}]} + response_mock = AsyncMock() + response_mock.status_code = 200 + response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + + self.client._client.post.return_value = response_mock + + result = await self.client.chat(messages=messages, stream=False) + + assert await result == response_json + + async def test_close(self): + """Ensures that close method closes the client.""" + self.client._client.aclose = AsyncMock() + await 
self.client.close() + self.client._client.aclose.assert_called_once() + + def test_is_closed(self): + """Ensures that is_closed returns the client's is_closed status.""" + self.client._client.is_closed = False + assert not self.client.is_closed() + self.client._client.is_closed = True + assert self.client.is_closed() + + async def test_context_manager(self): + """Ensures that the client can be used as a context manager.""" + self.client.close = AsyncMock() + async with self.client as client_instance: + assert client_instance == self.client + self.client.close.assert_called_once() + + async def test_del(self): + """Ensures that __del__ method closes the client.""" + client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + client.close = AsyncMock() + await client.__aexit__(None, None, None) # Manually invoke __aexit__ + client.close.assert_called_once() diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py new file mode 100644 index 0000000000000..d043ad63538e6 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py @@ -0,0 +1,340 @@ +import os +from unittest.mock import patch + +import ads +import pytest +from llama_index.core.base.llms.types import ChatMessage, LogProb, MessageRole +from llama_index.llms.oci_data_science.utils import ( + UnsupportedOracleAdsVersionError, + _from_completion_logprobs_dict, + _from_message_dict, + _from_token_logprob_dicts, + _get_response_token_counts, + _resolve_tool_choice, + _to_message_dicts, + _update_tool_calls, + _validate_dependency, +) + + +class TestUnsupportedOracleAdsVersionError: + """Unit tests for UnsupportedOracleAdsVersionError.""" + + def test_exception_message(self): + """Ensures the exception message is formatted correctly.""" + current_version = "2.12.5" + required_version = "2.12.6" + expected_message = ( + f"The `oracle-ads` version {current_version} currently installed is incompatible with " + "the `llama-index-llms-oci-data-science` version in use. 
To resolve this issue, " + f"please upgrade to `oracle-ads:{required_version}` or later using the " + "command: `pip install oracle-ads -U`" + ) + + exception = UnsupportedOracleAdsVersionError(current_version, required_version) + assert str(exception) == expected_message + + +class TestValidateDependency: + """Unit tests for _validate_dependency decorator.""" + + def setup_method(self): + + @_validate_dependency + def sample_function(): + return "function executed" + + self.sample_function = sample_function + + @patch("llama_index.llms.oci_data_science.utils.MIN_ADS_VERSION", new="2.12.6") + @patch("ads.__version__", new="2.12.7") + def test_valid_version(self): + """Ensures the function executes when the oracle-ads version is sufficient.""" + result = self.sample_function() + assert result == "function executed" + + @patch("llama_index.llms.oci_data_science.utils.MIN_ADS_VERSION", new="2.12.6") + @patch("ads.__version__", new="2.12.5") + def test_unsupported_version(self): + """Ensures UnsupportedOracleAdsVersionError is raised for insufficient version.""" + with pytest.raises(UnsupportedOracleAdsVersionError) as exc_info: + self.sample_function() + + @patch("llama_index.llms.oci_data_science.utils.MIN_ADS_VERSION", new="2.12.6") + def test_oracle_ads_not_installed(self): + """Ensures ImportError is raised when oracle-ads is not installed.""" + with patch.dict("sys.modules", {"ads": None}): + with pytest.raises(ImportError) as exc_info: + self.sample_function() + assert "Could not import `oracle-ads` Python package." in str( + exc_info.value + ) + + +class TestToMessageDicts: + """Unit tests for _to_message_dicts function.""" + + def test_sequence_conversion(self): + """Ensures a sequence of ChatMessages is converted correctly.""" + messages = [ + ChatMessage(role=MessageRole.USER, content="Hello"), + ChatMessage(role=MessageRole.ASSISTANT, content="Hi there!"), + ] + expected_result = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + result = _to_message_dicts(messages) + assert result == expected_result + + def test_empty_sequence(self): + """Ensures the function works with an empty sequence.""" + messages = [] + expected_result = [] + result = _to_message_dicts(messages) + assert result == expected_result + + def test_drop_none(self): + """Ensures drop_none parameter works correctly for sequences.""" + messages = [ + ChatMessage(role=MessageRole.USER, content=None), + ChatMessage( + role=MessageRole.ASSISTANT, + content="Hi there!", + additional_kwargs={"custom_field": None}, + ), + ] + expected_result = [ + {"role": "user"}, + {"role": "assistant", "content": "Hi there!"}, + ] + result = _to_message_dicts(messages, drop_none=True) + assert result == expected_result + + +class TestFromCompletionLogprobs: + """Unit tests for _from_completion_logprobs_dict function.""" + + def test_conversion(self): + """Ensures completion logprobs are converted correctly.""" + logprobs = { + "tokens": ["Hello", "world"], + "token_logprobs": [-0.1, -0.2], + "top_logprobs": [ + {"Hello": -0.1, "Hi": -1.0}, + {"world": -0.2, "earth": -1.2}, + ], + } + expected_result = [ + [ + LogProb(token="Hello", logprob=-0.1, bytes=[]), + LogProb(token="Hi", logprob=-1.0, bytes=[]), + ], + [ + LogProb(token="world", logprob=-0.2, bytes=[]), + LogProb(token="earth", logprob=-1.2, bytes=[]), + ], + ] + result = _from_completion_logprobs_dict(logprobs) + assert result == expected_result + + def test_empty_logprobs(self): + """Ensures function returns empty list when no 
logprobs are provided.""" + logprobs = {} + expected_result = [] + result = _from_completion_logprobs_dict(logprobs) + assert result == expected_result + + +class TestFromTokenLogprobs: + """Unit tests for _from_token_logprob_dicts function.""" + + def test_conversion(self): + """Ensures multiple token logprobs are converted correctly.""" + token_logprob_dicts = [ + { + "token": "Hello", + "logprob": -0.1, + "top_logprobs": [ + {"token": "Hello", "logprob": -0.1, "bytes": [1, 2, 3]}, + {"token": "Hi", "logprob": -1.0, "bytes": [1, 2, 3]}, + ], + }, + { + "token": "world", + "logprob": -0.2, + "top_logprobs": [ + {"token": "world", "logprob": -0.2, "bytes": [2, 3, 4]}, + {"token": "earth", "logprob": -1.2, "bytes": [2, 3, 4]}, + ], + }, + ] + expected_result = [ + [ + LogProb(token="Hello", logprob=-0.1, bytes=[1, 2, 3]), + LogProb(token="Hi", logprob=-1.0, bytes=[1, 2, 3]), + ], + [ + LogProb(token="world", logprob=-0.2, bytes=[2, 3, 4]), + LogProb(token="earth", logprob=-1.2, bytes=[2, 3, 4]), + ], + ] + result = _from_token_logprob_dicts(token_logprob_dicts) + assert result == expected_result + + def test_empty_input(self): + """Ensures function returns empty list when input is empty.""" + token_logprob_dicts = [] + expected_result = [] + result = _from_token_logprob_dicts(token_logprob_dicts) + assert result == expected_result + + +class TestFromMessage: + """Unit tests for _from_message_dict function.""" + + def test_conversion(self): + """Ensures an message dict is converted to ChatMessage.""" + message_dict = { + "role": "assistant", + "content": "Hello!", + "tool_calls": [{"name": "tool1", "arguments": "arg1"}], + } + expected_result = ChatMessage( + role="assistant", + content="Hello!", + additional_kwargs={"tool_calls": [{"name": "tool1", "arguments": "arg1"}]}, + ) + result = _from_message_dict(message_dict) + assert result == expected_result + + def test_missing_optional_fields(self): + """Ensures function works when optional fields are missing.""" + message_dict = {"role": "user", "content": "Hi!"} + expected_result = ChatMessage( + role="user", content="Hi!", additional_kwargs={"tool_calls": []} + ) + result = _from_message_dict(message_dict) + assert result == expected_result + + +class TestGetResponseTokenCounts: + """Unit tests for _get_response_token_counts function.""" + + def test_with_usage(self): + """Ensures token counts are extracted correctly when usage is present.""" + raw_response = { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } + } + expected_result = { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } + result = _get_response_token_counts(raw_response) + assert result == expected_result + + def test_without_usage(self): + """Ensures function returns empty dict when usage is missing.""" + raw_response = {} + expected_result = {} + result = _get_response_token_counts(raw_response) + assert result == expected_result + + def test_missing_token_counts(self): + """Ensures missing token counts default to zero.""" + raw_response = {"usage": {}} + result = _get_response_token_counts(raw_response) + assert result == {} + + raw_response = {"usage": {"prompt_tokens": 10}} + expected_result = { + "prompt_tokens": 10, + "completion_tokens": 0, + "total_tokens": 0, + } + result = _get_response_token_counts(raw_response) + assert result == expected_result + + +class TestUpdateToolCalls: + """Unit tests for _update_tool_calls function.""" + + def test_add_new_call(self): + """Ensures a new tool call is added 
when indices do not match.""" + tool_calls = [{"index": 0, "function": {"name": "tool1", "arguments": "arg1"}}] + tool_calls_delta = [ + {"index": 1, "function": {"name": "tool2", "arguments": "arg2"}} + ] + expected_result = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}}, + {"index": 1, "function": {"name": "tool2", "arguments": "arg2"}}, + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result == expected_result + + def test_update_existing_call(self): + """Ensures the existing tool call is updated when indices match.""" + tool_calls = [{"index": 0, "function": {"name": "tool", "arguments": "arg"}}] + tool_calls_delta = [{"index": 0, "function": {"name": "1", "arguments": "1"}}] + expected_result = [ + { + "index": 0, + "function": {"name": "tool1", "arguments": "arg1"}, + "id": "", + } + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result[0]["function"]["name"] == "tool1" + assert result[0]["function"]["arguments"] == "arg1" + + def test_no_delta(self): + """Ensures the original tool_calls is returned when delta is None.""" + tool_calls = [{"index": 0, "function": {"name": "tool1", "arguments": "arg1"}}] + tool_calls_delta = None + expected_result = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}} + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result == expected_result + + def test_empty_tool_calls(self): + """Ensures tool_calls is initialized when empty.""" + tool_calls = [] + tool_calls_delta = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}} + ] + expected_result = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}} + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result == expected_result + + +class TestResolveToolChoice: + """Unit tests for _resolve_tool_choice function.""" + + @pytest.mark.parametrize( + "input_choice, expected_output", + [ + ("auto", "auto"), + ("none", "none"), + ("required", "required"), + ( + "custom_tool", + {"type": "function", "function": {"name": "custom_tool"}}, + ), + ( + {"type": "function", "function": {"name": "custom_tool"}}, + {"type": "function", "function": {"name": "custom_tool"}}, + ), + ], + ) + def test_resolve_tool_choice(self, input_choice, expected_output): + """Ensures tool choices are resolved correctly.""" + result = _resolve_tool_choice(input_choice) + assert result == expected_output From de77111b5102cbc5aaa4f2f8dfa36762187b4c3a Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Sun, 1 Dec 2024 19:38:16 -0800 Subject: [PATCH 02/11] Add documentation --- docs/docs/examples/llm/oci_data_science.ipynb | 573 ++++++++++++++++++ .../README.md | 318 +++++++++- 2 files changed, 885 insertions(+), 6 deletions(-) create mode 100644 docs/docs/examples/llm/oci_data_science.ipynb diff --git a/docs/docs/examples/llm/oci_data_science.ipynb b/docs/docs/examples/llm/oci_data_science.ipynb new file mode 100644 index 0000000000000..b6add9e72f818 --- /dev/null +++ b/docs/docs/examples/llm/oci_data_science.ipynb @@ -0,0 +1,573 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "6d1ca9ac", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "id": "9e3a8796-edc8-43f2-94ad-fe4fb20d70ed", + "metadata": {}, + "source": [ + "# Oracle Cloud Infrastructure Data Science \n", + "\n", + "Oracle Cloud Infrastructure [(OCI) Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed, serverless 
platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure.\n",
+    "\n",
+    "It offers [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm), which can be used to deploy, evaluate, and fine-tune foundation LLM models in OCI Data Science. AI Quick Actions target users who want to quickly leverage the capabilities of AI. They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free, and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook.\n",
+    "\n",
+    "Detailed documentation on how to deploy LLM models in OCI Data Science using AI Quick Actions is available [here](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/model-deployment-tips.md) and [here](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm).\n",
+    "\n",
+    "This notebook explains how to use OCI's Data Science models with LlamaIndex."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3802e8c4",
+   "metadata": {},
+   "source": [
+    "## Setup\n",
+    "\n",
+    "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bb0dd8c9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install llama-index-llms-oci-data-science"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "544d49f9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install llama-index"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c2921307",
+   "metadata": {},
+   "source": [
+    "You will also need to install the [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) SDK."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "378d5179",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install -U oracle-ads"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "737b5293",
+   "metadata": {},
+   "source": [
+    "## Authentication\n",
+    "The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) SDK helps to simplify authentication within OCI Data Science."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "03d4024a",
+   "metadata": {},
+   "source": [
+    "## Basic Usage\n",
+    "\n",
+    "Using LLMs offered by OCI Data Science AI with LlamaIndex only requires you to initialize the `OCIDataScience` interface with your Data Science Model Deployment endpoint and model ID. By default, all models deployed with AI Quick Actions get the `odsc-model` ID. However, this ID can be changed during deployment."
+ ] + }, + { + "cell_type": "markdown", + "id": "8ead155e-b8bd-46f9-ab9b-28fc009361dd", + "metadata": {}, + "source": [ + "#### Call `complete` with a prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60be18ae-c957-4ac2-a58a-0652e18ee6d6", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = llm.complete(\"Tell me a joke\")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "c1f3fcbd", + "metadata": {}, + "source": [ + "### Call `chat` with a list of messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a80c9f6e", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = llm.chat([\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ])\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "9581413d", + "metadata": {}, + "source": [ + "## Streaming" + ] + }, + { + "cell_type": "markdown", + "id": "6f4dbedf", + "metadata": {}, + "source": [ + "### Using `stream_complete` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "977ad99f", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "\n", + "for chunk in llm.stream_complete(\"Tell me a joke\"):\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "38abd64d", + "metadata": {}, + "source": [ + "### Using `stream_chat` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fca03dac", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = llm.stream_chat([\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ])\n", + "\n", + "for chunk in response:\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "0b986d4e", + "metadata": {}, + "source": [ + "## Async" + ] + }, + { + "cell_type": "markdown", + "id": "42294b23", + "metadata": {}, + "source": [ + "### Call `acomplete` with a prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d52768eb", + 
"metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = await llm.acomplete(\"Tell me a joke\")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "aad4d4cb", + "metadata": {}, + "source": [ + "### Call `achat` with a list of messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1416bacf", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = await llm.achat([\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ])\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "0da3c384", + "metadata": {}, + "source": [ + "### Using `astream_complete` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b392dc3a", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "\n", + "async for chunk in await llm.astream_complete(\"Tell me a joke\"):\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "c22e167a", + "metadata": {}, + "source": [ + "### Using `astream_chat` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "056daa3a", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = await llm.stream_chat([\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ])\n", + "\n", + "async for chunk in response:\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "ed26b8a7", + "metadata": {}, + "source": [ + "## Configure Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42fa2409", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + " temperature=0.2,\n", + " max_tokens=500,\n", + " timeout=120,\n", + " context_window=2500,\n", + " additional_kwargs={\n", + " \"top_p\": 0.75,\n", + " 
\"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " }\n", + ")\n", + "response = llm.chat([\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ])\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "094b98c0", + "metadata": {}, + "source": [ + "## Function Calling" + ] + }, + { + "cell_type": "markdown", + "id": "63a1532a", + "metadata": {}, + "source": [ + "The [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm) offers prebuilt service containers that make deploying and serving a large language model very easy. Either one of vLLM (a high-throughput and memory-efficient inference and serving engine for LLMs) or TGI (a high-performance text generation server for the popular open-source LLMs) is used in the service container to host the model, the end point created supports the OpenAI API protocol. This allows the model deployment to be used as a drop-in replacement for applications using OpenAI API. If the deployed model supports function calling, then integration with LlamaIndex tools, through the predict_and_call function on the llm allows to attach any tools and let the LLM decide which tools to call (if any).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28b53563", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.tools import FunctionTool\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + " temperature=0.2,\n", + " max_tokens=500,\n", + " timeout=120,\n", + " context_window=2500,\n", + " additional_kwargs={\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " }\n", + ")\n", + "\n", + "def multiply(a: float, b: float) -> float:\n", + " print(f\"---> {a} * {b}\")\n", + " return a * b\n", + "\n", + "\n", + "def add(a: float, b: float) -> float:\n", + " print(f\"---> {a} + {b}\")\n", + " return a + b\n", + "\n", + "\n", + "def subtract(a: float, b: float) -> float:\n", + " print(f\"---> {a} - {b}\")\n", + " return a - b\n", + "\n", + "\n", + "def divide(a: float, b: float) -> float:\n", + " print(f\"---> {a} / {b}\")\n", + " return a / b\n", + "\n", + "\n", + "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n", + "add_tool = FunctionTool.from_defaults(fn=add)\n", + "sub_tool = FunctionTool.from_defaults(fn=subtract)\n", + "divide_tool = FunctionTool.from_defaults(fn=divide)\n", + "\n", + "response = llm.predict_and_call(\n", + " [multiply_tool, add_tool, sub_tool, divide_tool],\n", + " user_msg= \"Calculate the result of `8 + 2 - 6`.\",\n", + " verbose=True\n", + ")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "2dc0829c", + "metadata": {}, + "source": [ + "### Using `FunctionCallingAgent`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29fa7fb6", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.tools import FunctionTool\n", + "from llama_index.core.agent import FunctionCallingAgent\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + " temperature=0.2,\n", + " max_tokens=500,\n", + " timeout=120,\n", + " 
context_window=2500,\n", + " additional_kwargs={\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " }\n", + ")\n", + "\n", + "def multiply(a: float, b: float) -> float:\n", + " print(f\"---> {a} * {b}\")\n", + " return a * b\n", + "\n", + "\n", + "def add(a: float, b: float) -> float:\n", + " print(f\"---> {a} + {b}\")\n", + " return a + b\n", + "\n", + "\n", + "def subtract(a: float, b: float) -> float:\n", + " print(f\"---> {a} - {b}\")\n", + " return a - b\n", + "\n", + "\n", + "def divide(a: float, b: float) -> float:\n", + " print(f\"---> {a} / {b}\")\n", + " return a / b\n", + "\n", + "\n", + "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n", + "add_tool = FunctionTool.from_defaults(fn=add)\n", + "sub_tool = FunctionTool.from_defaults(fn=subtract)\n", + "divide_tool = FunctionTool.from_defaults(fn=divide)\n", + "\n", + "agent = FunctionCallingAgent.from_tools(\n", + " tools=[multiply_tool, add_tool, sub_tool, divide_tool], llm=llm, verbose=True\n", + ")\n", + "response = agent.chat(\n", + " \"Calculate the result of `8 + 2 - 6`. Use tools. Return the calculated result.\"\n", + ")\n", + "\n", + "print(response)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md index 26b100d41791b..580f3c27780a8 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md @@ -1,27 +1,333 @@ -# LlamaIndex Llms Integration: Oracle Cloud Infrastructure (OCI) Data Science Service +# LlamaIndex LLMs Integration: Oracle Cloud Infrastructure (OCI) Data Science Service -Oracle Cloud Infrastructure (OCI) [Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed and serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure. +Oracle Cloud Infrastructure (OCI) [Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed, serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure. -It offers the [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm) that can be used to deploy, evaluate and fine tune foundation models in OCI Data Science. AI Quick Actions target a user who wants to quickly leverage the capabilities of AI. They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook. +It offers [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm), which can be used to deploy, evaluate, and fine-tune foundation models in OCI Data Science. AI Quick Actions target users who want to quickly leverage the capabilities of AI. 
They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free, and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook. +Detailed documentation on how to deploy LLM models in OCI Data Science using AI Quick Actions is available [here](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/model-deployment-tips.md) and [here](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm). ## Installation Install the required packages: ```bash -pip install llama-index-llms-oci-data-science oralce-ads +pip install oracle-ads llama-index llama-index-llms-oci-data-science + ``` The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) is required to simplify the authentication within OCI Data Science. +## Authentication +The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. + ## Basic Usage -```bash +Using LLMs offered by OCI Data Science AI with LlamaIndex only requires you to initialize the OCIDataScience interface with your Data Science Model Deployment endpoint and model ID. By default the all deployed models in AI Quick Actions get `odsc-model` ID. However this ID cna be changed during the deployment. 
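+
+For example, if the deployment was created with a custom model ID rather than the default, pass that ID through the `model` parameter. A minimal sketch, assuming a hypothetical custom ID and placeholder endpoint/profile values that you would replace with your own:
+
+```python
+import ads
+from llama_index.llms.oci_data_science import OCIDataScience
+
+# Authenticate with any of the supported methods (see the Authentication section above).
+ads.set_auth(auth="security_token", profile="<your-profile>")
+
+llm = OCIDataScience(
+    # Custom model ID chosen during deployment (hypothetical value).
+    model="my-finetuned-llm",
+    # Placeholder for your Model Deployment endpoint.
+    endpoint="https://<your-model-deployment-endpoint>/predict",
+)
+
+print(llm.complete("Hello!"))
+```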
+ +### Call `complete` with a prompt + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = llm.complete("Tell me a joke") + +print(response) +``` + +### Call `chat` with a list of messages +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = llm.chat([ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage(role="assistant", content="Why did the chicken cross the road?"), + ChatMessage(role="user", content="I don't know, why?"), + ]) + +print(response) +``` + +## Streaming + +### Using `stream_complete` endpoint + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) + +for chunk in llm.stream_complete("Tell me a joke"): + print(chunk.delta, end="") +``` + +### Using `stream_chat` endpoint +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = llm.stream_chat([ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage(role="assistant", content="Why did the chicken cross the road?"), + ChatMessage(role="user", content="I don't know, why?"), + ]) + +for chunk in response: + print(chunk.delta, end="") +``` + +## Async + +### Call `acomplete` with a prompt + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = await llm.acomplete("Tell me a joke") + +print(response) +``` + +### Call `achat` with a list of messages +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = await llm.achat([ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage(role="assistant", content="Why did the chicken cross the road?"), + ChatMessage(role="user", content="I don't know, why?"), + ]) + +print(response) +``` + +## Streaming + +### Using `astream_complete` endpoint + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) + +async for chunk in await llm.astream_complete("Tell me a joke"): + print(chunk.delta, end="") +``` + +### Using `astream_chat` endpoint + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = await llm.stream_chat([ + ChatMessage(role="user", content="Tell me a joke"), + 
ChatMessage(role="assistant", content="Why did the chicken cross the road?"), + ChatMessage(role="user", content="I don't know, why?"), + ]) + +async for chunk in response: + print(chunk.delta, end="") +``` + +## Configure Model + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", + temperature=0.2, + max_tokens=500, + timeout=120, + context_window=2500, + additional_kwargs={ + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + } +) +response = llm.chat([ + ChatMessage(role="user", content="Tell me a joke"), + ]) +print(response) +``` + +## Function Calling +The [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm) offers prebuilt service containers that make deploying and serving a large language model very easy. Either one of vLLM (a high-throughput and memory-efficient inference and serving engine for LLMs) or TGI (a high-performance text generation server for the popular open-source LLMs) is used in the service container to host the model, the end point created supports the OpenAI API protocol. This allows the model deployment to be used as a drop-in replacement for applications using OpenAI API. If the deployed model supports function calling, then integration with LlamaIndex tools, through the predict_and_call function on the llm allows to attach any tools and let the LLM decide which tools to call (if any). + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.tools import FunctionTool + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", + temperature=0.2, + max_tokens=500, + timeout=120, + context_window=2500, + additional_kwargs={ + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + } +) + +def multiply(a: float, b: float) -> float: + print(f"---> {a} * {b}") + return a * b + + +def add(a: float, b: float) -> float: + print(f"---> {a} + {b}") + return a + b + + +def subtract(a: float, b: float) -> float: + print(f"---> {a} - {b}") + return a - b + + +def divide(a: float, b: float) -> float: + print(f"---> {a} / {b}") + return a / b + + +multiply_tool = FunctionTool.from_defaults(fn=multiply) +add_tool = FunctionTool.from_defaults(fn=add) +sub_tool = FunctionTool.from_defaults(fn=subtract) +divide_tool = FunctionTool.from_defaults(fn=divide) + +response = llm.predict_and_call( + [multiply_tool, add_tool, sub_tool, divide_tool], + user_msg= "Calculate the result of `8 + 2 - 6`.", + verbose=True +) + +print(response) +``` + +### Using `FunctionCallingAgent` + +```python +import ads from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.tools import FunctionTool +from llama_index.core.agent import FunctionCallingAgent + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", + temperature=0.2, + max_tokens=500, + timeout=120, + context_window=2500, + additional_kwargs={ + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + } +) + +def multiply(a: float, b: float) -> float: + print(f"---> {a} * {b}") + return a * b + + +def add(a: float, b: float) -> float: + print(f"---> {a} + {b}") + return a + b + + +def subtract(a: float, b: float) -> float: + print(f"---> {a} - {b}") + return a - b + + +def divide(a: float, b: float) -> float: + 
print(f"---> {a} / {b}") + return a / b + + +multiply_tool = FunctionTool.from_defaults(fn=multiply) +add_tool = FunctionTool.from_defaults(fn=add) +sub_tool = FunctionTool.from_defaults(fn=subtract) +divide_tool = FunctionTool.from_defaults(fn=divide) + +agent = FunctionCallingAgent.from_tools( + tools=[multiply_tool, add_tool, sub_tool, divide_tool], llm=llm, verbose=True +) +response = agent.chat( + "Calculate the result of `8 + 2 - 6`. Use tools. Return the calculated result." +) -TBD +print(response) ``` ## LLM Implementation example From d62caf27dde2dbca016e9bd00563610a17c3019e Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Sun, 8 Dec 2024 13:57:59 -0800 Subject: [PATCH 03/11] Fixes project.toml and adds default headers for the multi model inferencing. --- .../llama_index/llms/oci_data_science/base.py | 123 +++++++----------- .../llms/oci_data_science/client.py | 4 + .../llms/oci_data_science/utils.py | 33 ++--- .../pyproject.toml | 4 +- 4 files changed, 71 insertions(+), 93 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py index 8bd7212bcb4ac..9725ef48f0457 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py @@ -1,15 +1,13 @@ import logging from typing import ( - TYPE_CHECKING, Any, - AsyncGenerator, Callable, Dict, - Generator, List, Optional, Sequence, Union, + TYPE_CHECKING ) import llama_index.core.instrumentation as instrument @@ -26,7 +24,6 @@ MessageRole, ) from llama_index.core.bridge.pydantic import ( - BaseModel, Field, PrivateAttr, model_validator, @@ -40,6 +37,7 @@ from llama_index.core.types import BaseOutputParser, Model, PydanticProgramMode from llama_index.llms.oci_data_science.client import AsyncClient, Client from llama_index.llms.oci_data_science.utils import ( + DEFAULT_TOOL_CHOICE, _from_completion_logprobs_dict, _from_message_dict, _from_token_logprob_dicts, @@ -50,10 +48,13 @@ _validate_dependency, ) -dispatcher = instrument.get_dispatcher(__name__) + if TYPE_CHECKING: from llama_index.core.tools.types import BaseTool +dispatcher = instrument.get_dispatcher(__name__) + + DEFAULT_MODEL = "odsc-llm" DEFAULT_MAX_TOKENS = 512 DEFAULT_TIMEOUT = 120 @@ -67,10 +68,10 @@ class OCIDataScience(FunctionCallingLLM): LLM deployed on OCI Data Science Model Deployment. **Setup:** - Install ``oracle-ads`` and ``llama-index-oci-data-science``. + Install ``oracle-ads`` and ``llama-index-llms-oci-data-science``. ```bash - pip install -U oracle-ads llama-index-oci-data-science + pip install -U oracle-ads llama-index-llms-oci-data-science ``` Use `ads.set_auth()` to configure authentication. @@ -256,6 +257,9 @@ def divide(a: float, b: float) -> float: default=False, description="Whether to use strict mode for invoking tools/using schemas.", ) + default_headers: Optional[Dict[str, str]] = Field( + default=None, description="The default headers for API requests." 
+ ) _client: Client = PrivateAttr() _async_client: AsyncClient = PrivateAttr() @@ -274,6 +278,7 @@ def __init__( callback_manager: Optional[CallbackManager] = None, is_chat_model: Optional[bool] = True, is_function_calling_model: Optional[bool] = True, + default_headers: Optional[Dict[str, str]] = None, # base class system_prompt: Optional[str] = None, messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, @@ -299,6 +304,7 @@ def __init__( callback_manager (Optional[CallbackManager]): Callback manager for LLM. is_chat_model (Optional[bool]): If the model exposes a chat interface. Defaults to `True`. is_function_calling_model (Optional[bool]): If the model supports function calling messages. Defaults to `True`. + default_headers (Optional[Dict[str, str]]): The default headers for API requests. system_prompt (Optional[str]): System prompt to use. messages_to_prompt (Optional[Callable]): Function to convert messages to prompt. completion_to_prompt (Optional[Callable]): Function to convert completion to prompt. @@ -320,6 +326,7 @@ def __init__( callback_manager=callback_manager or CallbackManager([]), is_chat_model=is_chat_model, is_function_calling_model=is_function_calling_model, + default_headers=default_headers, system_prompt=system_prompt, messages_to_prompt=messages_to_prompt, completion_to_prompt=completion_to_prompt, @@ -419,6 +426,21 @@ def _model_kwargs(self, **kwargs: Any) -> Dict[str, Any]: } return {**base_kwargs, **self.additional_kwargs, **kwargs} + def _prepare_headers( + self, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + """ + Construct and return the headers for a request. + + Args: + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, str]: The prepared headers. 
+ """ + return {**(self.default_headers or {}), **(headers or {})} + @llm_completion_callback() def complete( self, prompt: str, formatted: bool = False, **kwargs: Any @@ -438,7 +460,7 @@ def complete( response = self.client.generate( prompt=prompt, payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=False, ) @@ -477,7 +499,7 @@ def stream_complete( for response in self.client.generate( prompt=prompt, payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=True, ): logger.debug(f"Received chunk: {response}") @@ -514,7 +536,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: messages=messages, drop_none=kwargs.pop("drop_none", False) ), payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=False, ) @@ -557,7 +579,7 @@ def stream_chat( messages=messages, drop_none=kwargs.pop("drop_none", False) ), payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=True, ): logger.debug(f"Received chat chunk: {response}") @@ -611,7 +633,7 @@ async def acomplete( response = await self.async_client.generate( prompt=prompt, payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=False, ) @@ -653,7 +675,7 @@ async def gen() -> CompletionResponseAsyncGen: async for response in await self.async_client.generate( prompt=prompt, payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=True, ): logger.debug(f"Received async chunk: {response}") @@ -694,7 +716,7 @@ async def achat( messages=messages, drop_none=kwargs.pop("drop_none", False) ), payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=False, ) @@ -733,14 +755,13 @@ async def gen() -> ChatResponseAsyncGen: logger.debug(f"Starting astream_chat with messages: {messages}") content = "" is_function = False - first_chat_chunk = True tool_calls = [] async for response in await self.async_client.chat( messages=_to_message_dicts( messages=messages, drop_none=kwargs.pop("drop_none", False) ), payload=self._model_kwargs(**kwargs), - headers=kwargs.pop("headers", None), + headers=self._prepare_headers(kwargs.pop("headers", {})), stream=True, ): logger.debug(f"Received async chat chunk: {response}") @@ -784,7 +805,7 @@ def _prepare_chat_with_tools( chat_history: Optional[List[ChatMessage]] = None, verbose: bool = False, allow_parallel_tool_calls: bool = False, - tool_choice: Union[str, dict] = "auto", + tool_choice: Union[str, dict] = DEFAULT_TOOL_CHOICE, strict: Optional[bool] = None, **kwargs: Any, ) -> Dict[str, Any]: @@ -804,17 +825,15 @@ def _prepare_chat_with_tools( Returns: Dict[str, Any]: The prepared parameters for the chat request. """ + tool_specs = [tool.metadata.to_openai_tool() for tool in tools] + logger.debug( - f"Preparing chat with tools. Tools: {tools}, User message: {user_msg}, " + f"Preparing chat with tools. 
Tools: {tool_specs}, User message: {user_msg}, " f"Chat history: {chat_history}" ) - tool_specs = [tool.metadata.to_openai_tool() for tool in tools] # Determine strict mode - if strict is not None: - strict = strict - else: - strict = self.strict + strict = strict or self.strict if self.metadata.is_function_calling_model: for tool_spec in tool_specs: @@ -833,7 +852,7 @@ def _prepare_chat_with_tools( return { "messages": messages, "tools": tool_specs or None, - "tool_choice": _resolve_tool_choice(tool_choice) if tool_specs else None, + "tool_choice": (_resolve_tool_choice(tool_choice) if tool_specs else None), **kwargs, } @@ -860,7 +879,7 @@ def _validate_chat_with_tools_response( # Ensures that the 'tool_calls' in the response contain only a single tool call. tool_calls = response.message.additional_kwargs.get("tool_calls", []) if len(tool_calls) > 1: - logger.debug( + logger.warning( "Multiple tool calls detected but parallel tool calls are not allowed. " "Limiting to the first tool call." ) @@ -888,7 +907,7 @@ def get_tool_calls_from_response( ValueError: If no tool calls are found and error_on_no_tool_call is True. """ tool_calls = response.message.additional_kwargs.get("tool_calls", []) - logger.debug(f"Extracted tool calls: {tool_calls}") + logger.debug(f"Getting tool calls from response: {tool_calls}") if len(tool_calls) < 1: if error_on_no_tool_call: @@ -901,8 +920,7 @@ def get_tool_calls_from_response( tool_selections = [] for tool_call in tool_calls: if tool_call.get("type") != "function": - logger.error(f"Invalid tool type detected: {tool_call.get('type')}") - raise ValueError("Invalid tool type.") + raise ValueError(f"Invalid tool type detected: {tool_call.get('type')}") # Handle both complete and partial JSON try: @@ -921,50 +939,7 @@ def get_tool_calls_from_response( ) ) - return tool_selections - - @dispatcher.span - def structured_predict( - self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> BaseModel: - # force tool_choice to be required - llm_kwargs = llm_kwargs or {} - llm_kwargs["tool_choice"] = ( - "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] - ) - return super().structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs) - - @dispatcher.span - async def astructured_predict( - self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> BaseModel: - # force tool_choice to be required - llm_kwargs = llm_kwargs or {} - llm_kwargs["tool_choice"] = ( - "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] - ) - return await super().astructured_predict(*args, llm_kwargs=llm_kwargs, **kwargs) - - @dispatcher.span - def stream_structured_predict( - self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Generator[Union[Model, List[Model]], None, None]: - # force tool_choice to be required - llm_kwargs = llm_kwargs or {} - llm_kwargs["tool_choice"] = ( - "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] - ) - return super().stream_structured_predict(*args, llm_kwargs=llm_kwargs, **kwargs) - - @dispatcher.span - async def astream_structured_predict( - self, *args: Any, llm_kwargs: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> AsyncGenerator[Union[Model, List[Model]], None]: - # force tool_choice to be required - llm_kwargs = llm_kwargs or {} - llm_kwargs["tool_choice"] = ( - "required" if "tool_choice" not in llm_kwargs else llm_kwargs["tool_choice"] - ) - return await super().astream_structured_predict( - 
*args, llm_kwargs=llm_kwargs, **kwargs + logger.debug( + f"Extracted tool calls: { [tool_selection.model_dump() for tool_selection in tool_selections] }" ) + return tool_selections diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py index 0996257d50ac4..a7b10882a68f0 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py @@ -505,6 +505,7 @@ def generate( """ logger.debug(f"Generating text with prompt: {prompt}, stream: {stream}") payload = {**(payload or {}), "prompt": prompt} + headers = {"route": "/v1/completions", **(headers or {})} if stream: return self._stream(payload=payload, headers=headers) return self._request(payload=payload, headers=headers) @@ -530,6 +531,7 @@ def chat( """ logger.debug(f"Starting chat with messages: {messages}, stream: {stream}") payload = {**(payload or {}), "messages": messages} + headers = {"route": "/v1/chat/completions", **(headers or {})} if stream: return self._stream(payload=payload, headers=headers) return self._request(payload=payload, headers=headers) @@ -712,6 +714,7 @@ async def generate( """ logger.debug(f"Generating text with prompt: {prompt}, stream: {stream}") payload = {**(payload or {}), "prompt": prompt} + headers = {"route": "/v1/completions", **(headers or {})} if stream: return self._stream(payload=payload, headers=headers) return await self._request(payload=payload, headers=headers) @@ -737,6 +740,7 @@ async def chat( """ logger.debug(f"Starting chat with messages: {messages}, stream: {stream}") payload = {**(payload or {}), "messages": messages} + headers = {"route": "/v1/chat/completions", **(headers or {})} if stream: return self._stream(payload=payload, headers=headers) return await self._request(payload=payload, headers=headers) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py index 3ac495250bdde..ac62f6cf0eda4 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py @@ -6,6 +6,9 @@ from packaging import version MIN_ADS_VERSION = "2.12.6" +SUPPORTED_TOOL_CHOICES = ["none", "auto", "required"] +DEFAULT_TOOL_CHOICE = "auto" + logger = logging.getLogger(__name__) @@ -181,10 +184,11 @@ def _from_message_dict(message_dict: Dict[str, Any]) -> ChatMessage: ChatMessage The converted ChatMessage object. """ - role = message_dict.get("role") - content = message_dict.get("content") - additional_kwargs = {"tool_calls": message_dict.get("tool_calls", [])} - return ChatMessage(role=role, content=content, additional_kwargs=additional_kwargs) + return ChatMessage( + role=message_dict.get("role"), + content=message_dict.get("content"), + additional_kwargs={"tool_calls": message_dict.get("tool_calls", [])}, + ) def _get_response_token_counts(raw_response: Dict[str, Any]) -> Dict[str, int]: @@ -201,15 +205,13 @@ def _get_response_token_counts(raw_response: Dict[str, Any]) -> Dict[str, int]: Dict[str, int] The extracted token counts. 
""" - usage = raw_response.get("usage", {}) - - if not usage: + if not raw_response.get("usage"): return {} return { - "prompt_tokens": usage.get("prompt_tokens", 0), - "completion_tokens": usage.get("completion_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), + "prompt_tokens": raw_response["usage"].get("prompt_tokens", 0), + "completion_tokens": raw_response["usage"].get("completion_tokens", 0), + "total_tokens": raw_response["usage"].get("total_tokens", 0), } @@ -253,12 +255,11 @@ def _update_tool_calls( return tool_calls -def _resolve_tool_choice(tool_choice: Union[str, dict] = "auto") -> Union[str, dict]: - """Resolve tool choice. - - If tool_choice is a function name string, return the appropriate dict. - """ - if isinstance(tool_choice, str) and tool_choice not in ["none", "auto", "required"]: +def _resolve_tool_choice( + tool_choice: Union[str, dict] = DEFAULT_TOOL_CHOICE +) -> Union[str, dict]: + """If tool_choice is a function name string, return the appropriate dict.""" + if isinstance(tool_choice, str) and tool_choice not in SUPPORTED_TOOL_CHOICES: return {"type": "function", "function": {"name": tool_choice}} return tool_choice diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml index 810ed215bfcbc..63c35a41ce12c 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml @@ -32,9 +32,7 @@ version = "0.1.0" [tool.poetry.dependencies] python = ">=3.9,<4.0" oracle-ads = ">=2.12.6" -llama-index-core = "^0.11.0" -httpx = ">=" -teancity = ">=" +llama-index-core = "^0.12.0" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" From 9434fc58b6a8aeebce6a7919e2a029e40759f446 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 10 Dec 2024 12:23:19 -0800 Subject: [PATCH 04/11] Ruff fixes. --- .../README.md | 106 +++++++---- .../llama_index/llms/oci_data_science/base.py | 12 +- .../llms/oci_data_science/client.py | 28 +-- .../llms/oci_data_science/utils.py | 176 +++++++----------- .../pyproject.toml | 2 +- .../tests/test_llms_oci_data_science.py | 10 +- .../tests/test_oci_data_science_client.py | 3 +- .../tests/test_oci_data_science_utils.py | 4 +- 8 files changed, 166 insertions(+), 175 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md index 580f3c27780a8..ed19b0aec93b9 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md @@ -18,12 +18,12 @@ pip install oracle-ads llama-index llama-index-llms-oci-data-science The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) is required to simplify the authentication within OCI Data Science. ## Authentication -The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). 
Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. +The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. ## Basic Usage -Using LLMs offered by OCI Data Science AI with LlamaIndex only requires you to initialize the OCIDataScience interface with your Data Science Model Deployment endpoint and model ID. By default the all deployed models in AI Quick Actions get `odsc-model` ID. However this ID cna be changed during the deployment. +Using LLMs offered by OCI Data Science AI with LlamaIndex only requires you to initialize the OCIDataScience interface with your Data Science Model Deployment endpoint and model ID. By default the all deployed models in AI Quick Actions get `odsc-model` ID. However this ID can be changed during the deployment. ### Call `complete` with a prompt @@ -43,6 +43,7 @@ print(response) ``` ### Call `chat` with a list of messages + ```python import ads from llama_index.llms.oci_data_science import OCIDataScience @@ -54,11 +55,15 @@ llm = OCIDataScience( model="odsc-llm", endpoint="https:///predict", ) -response = llm.chat([ - ChatMessage(role="user", content="Tell me a joke"), - ChatMessage(role="assistant", content="Why did the chicken cross the road?"), - ChatMessage(role="user", content="I don't know, why?"), - ]) +response = llm.chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" + ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) print(response) ``` @@ -83,6 +88,7 @@ for chunk in llm.stream_complete("Tell me a joke"): ``` ### Using `stream_chat` endpoint + ```python import ads from llama_index.llms.oci_data_science import OCIDataScience @@ -94,11 +100,15 @@ llm = OCIDataScience( model="odsc-llm", endpoint="https:///predict", ) -response = llm.stream_chat([ - ChatMessage(role="user", content="Tell me a joke"), - ChatMessage(role="assistant", content="Why did the chicken cross the road?"), - ChatMessage(role="user", content="I don't know, why?"), - ]) +response = llm.stream_chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" 
+ ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) for chunk in response: print(chunk.delta, end="") @@ -124,6 +134,7 @@ print(response) ``` ### Call `achat` with a list of messages + ```python import ads from llama_index.llms.oci_data_science import OCIDataScience @@ -135,11 +146,15 @@ llm = OCIDataScience( model="odsc-llm", endpoint="https:///predict", ) -response = await llm.achat([ - ChatMessage(role="user", content="Tell me a joke"), - ChatMessage(role="assistant", content="Why did the chicken cross the road?"), - ChatMessage(role="user", content="I don't know, why?"), - ]) +response = await llm.achat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" + ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) print(response) ``` @@ -176,11 +191,15 @@ llm = OCIDataScience( model="odsc-llm", endpoint="https:///predict", ) -response = await llm.stream_chat([ - ChatMessage(role="user", content="Tell me a joke"), - ChatMessage(role="assistant", content="Why did the chicken cross the road?"), - ChatMessage(role="user", content="I don't know, why?"), - ]) +response = await llm.stream_chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" + ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) async for chunk in response: print(chunk.delta, end="") @@ -202,18 +221,21 @@ llm = OCIDataScience( timeout=120, context_window=2500, additional_kwargs={ - "top_p": 0.75, - "logprobs": True, - "top_logprobs": 3, - } + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + }, +) +response = llm.chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ] ) -response = llm.chat([ - ChatMessage(role="user", content="Tell me a joke"), - ]) print(response) ``` ## Function Calling + The [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm) offers prebuilt service containers that make deploying and serving a large language model very easy. Either one of vLLM (a high-throughput and memory-efficient inference and serving engine for LLMs) or TGI (a high-performance text generation server for the popular open-source LLMs) is used in the service container to host the model, the end point created supports the OpenAI API protocol. This allows the model deployment to be used as a drop-in replacement for applications using OpenAI API. If the deployed model supports function calling, then integration with LlamaIndex tools, through the predict_and_call function on the llm allows to attach any tools and let the LLM decide which tools to call (if any). 
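+
+Under the hood, the integration converts each LlamaIndex tool into an OpenAI-style tool spec via `tool.metadata.to_openai_tool()` (see `_prepare_chat_with_tools` in `base.py`) before sending the request to the deployment's chat completions route. A minimal sketch for inspecting the spec that would be sent, using a hypothetical `multiply` tool:
+
+```python
+from llama_index.core.tools import FunctionTool
+
+
+def multiply(a: float, b: float) -> float:
+    """Multiply two numbers."""
+    return a * b
+
+
+multiply_tool = FunctionTool.from_defaults(fn=multiply)
+
+# The OpenAI-style spec (name, description, JSON schema of parameters)
+# is what the deployed model receives as the tool definition.
+print(multiply_tool.metadata.to_openai_tool())
+```
+
+The full example below registers several such tools and lets `predict_and_call` decide which one to invoke.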
```python @@ -231,12 +253,13 @@ llm = OCIDataScience( timeout=120, context_window=2500, additional_kwargs={ - "top_p": 0.75, - "logprobs": True, - "top_logprobs": 3, - } + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + }, ) + def multiply(a: float, b: float) -> float: print(f"---> {a} * {b}") return a * b @@ -264,8 +287,8 @@ divide_tool = FunctionTool.from_defaults(fn=divide) response = llm.predict_and_call( [multiply_tool, add_tool, sub_tool, divide_tool], - user_msg= "Calculate the result of `8 + 2 - 6`.", - verbose=True + user_msg="Calculate the result of `8 + 2 - 6`.", + verbose=True, ) print(response) @@ -289,12 +312,13 @@ llm = OCIDataScience( timeout=120, context_window=2500, additional_kwargs={ - "top_p": 0.75, - "logprobs": True, - "top_logprobs": 3, - } + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + }, ) + def multiply(a: float, b: float) -> float: print(f"---> {a} * {b}") return a * b @@ -321,7 +345,9 @@ sub_tool = FunctionTool.from_defaults(fn=subtract) divide_tool = FunctionTool.from_defaults(fn=divide) agent = FunctionCallingAgent.from_tools( - tools=[multiply_tool, add_tool, sub_tool, divide_tool], llm=llm, verbose=True + tools=[multiply_tool, add_tool, sub_tool, divide_tool], + llm=llm, + verbose=True, ) response = agent.chat( "Calculate the result of `8 + 2 - 6`. Use tools. Return the calculated result." diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py index 9725ef48f0457..cabc288c8dfe2 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py @@ -34,7 +34,7 @@ from llama_index.core.llms.function_calling import FunctionCallingLLM from llama_index.core.llms.llm import ToolSelection from llama_index.core.llms.utils import parse_partial_json -from llama_index.core.types import BaseOutputParser, Model, PydanticProgramMode +from llama_index.core.types import BaseOutputParser, PydanticProgramMode from llama_index.llms.oci_data_science.client import AsyncClient, Client from llama_index.llms.oci_data_science.utils import ( DEFAULT_TOOL_CHOICE, @@ -477,7 +477,7 @@ def complete( additional_kwargs=_get_response_token_counts(response), ) except (IndexError, KeyError, TypeError) as e: - raise ValueError(f"Failed to parse response: {str(e)}") from e + raise ValueError(f"Failed to parse response: {e!s}") from e @llm_completion_callback() def stream_complete( @@ -554,7 +554,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: additional_kwargs=_get_response_token_counts(response), ) except (IndexError, KeyError, TypeError) as e: - raise ValueError(f"Failed to parse response: {str(e)}") from e + raise ValueError(f"Failed to parse response: {e!s}") from e @llm_chat_callback() def stream_chat( @@ -650,7 +650,7 @@ async def acomplete( additional_kwargs=_get_response_token_counts(response), ) except (IndexError, KeyError, TypeError) as e: - raise ValueError(f"Failed to parse response: {str(e)}") from e + raise ValueError(f"Failed to parse response: {e!s}") from e @llm_completion_callback() async def astream_complete( @@ -734,7 +734,7 @@ async def achat( additional_kwargs=_get_response_token_counts(response), ) except (IndexError, KeyError, TypeError) as e: - raise ValueError(f"Failed to parse 
response: {str(e)}") from e + raise ValueError(f"Failed to parse response: {e!s}") from e @llm_chat_callback() async def astream_chat( @@ -928,7 +928,7 @@ def get_tool_calls_from_response( tool_call.get("function", {}).get("arguments", {}) ) except ValueError as e: - logger.debug(f"Failed to parse tool call arguments: {str(e)}") + logger.debug(f"Failed to parse tool call arguments: {e!s}") argument_dict = {} tool_selections.append( diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py index a7b10882a68f0..8e36894ccd454 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py @@ -39,6 +39,8 @@ STATUS_FORCE_LIST = [429, 500, 502, 503, 504] DEFAULT_ENCODING = "utf-8" +from typing import Self + _T = TypeVar("_T", bound="BaseClient") logger = logging.getLogger(__name__) @@ -292,7 +294,7 @@ def _parse_streaming_line( except json.JSONDecodeError as e: logger.debug(f"Error decoding JSON from line: {line}") raise json.JSONDecodeError( - f"Error decoding JSON from line: {str(e)}", e.doc, e.pos + f"Error decoding JSON from line: {e!s}", e.doc, e.pos ) from e if json_line.get("object") == "error": @@ -336,7 +338,7 @@ class Client(BaseClient): Synchronous HTTP client for invoking models with support for request and streaming APIs. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """ Initialize the Client. @@ -354,7 +356,7 @@ def close(self) -> None: """Close the underlying HTTPX client.""" self._client.close() - def __enter__(self: _T) -> _T: + def __enter__(self: _T) -> Self: return self def __exit__( @@ -406,10 +408,10 @@ def _request( e.response.text if hasattr(e, "response") and e.response else str(e) ) logger.error( - f"Request failed. Error: {str(e)}. Details: {last_exception_text}" + f"Request failed. Error: {e!s}. Details: {last_exception_text}" ) raise ExtendedRequestException( - f"Request failed: {str(e)}. Details: {last_exception_text}", + f"Request failed: {e!s}. Details: {last_exception_text}", e, last_exception_text, ) from e @@ -476,10 +478,10 @@ def _stream( time.sleep(delay) else: logger.error( - f"Streaming request failed. Error: {str(e)}. Details: {last_exception_text}" + f"Streaming request failed. Error: {e!s}. Details: {last_exception_text}" ) raise ExtendedRequestException( - f"Streaming request failed: {str(e)}. Details: {last_exception_text}", + f"Streaming request failed: {e!s}. Details: {last_exception_text}", e, last_exception_text, ) from e @@ -542,7 +544,7 @@ class AsyncClient(BaseClient): Asynchronous HTTP client for invoking models with support for request and streaming APIs, including retry logic. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """ Initialize the AsyncClient. @@ -563,7 +565,7 @@ async def close(self) -> None: """ await self._client.aclose() - async def __aenter__(self: _T) -> _T: + async def __aenter__(self: _T) -> Self: return self async def __aexit__( @@ -620,10 +622,10 @@ async def _request( e.response.text if hasattr(e, "response") and e.response else str(e) ) logger.error( - f"Request failed. Error: {str(e)}. Details: {last_exception_text}" + f"Request failed. Error: {e!s}. 
Details: {last_exception_text}" ) raise ExtendedRequestException( - f"Request failed: {str(e)}. Details: {last_exception_text}", + f"Request failed: {e!s}. Details: {last_exception_text}", e, last_exception_text, ) from e @@ -685,10 +687,10 @@ async def _stream( await asyncio.sleep(delay) else: logger.error( - f"Streaming request failed. Error: {str(e)}. Details: {last_exception_text}" + f"Streaming request failed. Error: {e!s}. Details: {last_exception_text}" ) raise ExtendedRequestException( - f"Streaming request failed: {str(e)}. Details: {last_exception_text}", + f"Streaming request failed: {e!s}. Details: {last_exception_text}", e, last_exception_text, ) from e diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py index ac62f6cf0eda4..52a201978149b 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py @@ -9,20 +9,15 @@ SUPPORTED_TOOL_CHOICES = ["none", "auto", "required"] DEFAULT_TOOL_CHOICE = "auto" - logger = logging.getLogger(__name__) class UnsupportedOracleAdsVersionError(Exception): - """ - Custom exception for unsupported `oracle-ads` versions. - - Attributes - ---------- - current_version : str - The installed version of `oracle-ads`. - required_version : str - The minimum required version of `oracle-ads`. + """Custom exception for unsupported `oracle-ads` versions. + + Attributes: + current_version: The installed version of `oracle-ads`. + required_version: The minimum required version of `oracle-ads`. """ def __init__(self, current_version: str, required_version: str): @@ -35,28 +30,20 @@ def __init__(self, current_version: str, required_version: str): def _validate_dependency(func: Callable[..., Any]) -> Callable[..., Any]: - """ - Decorator to validate the presence and version of the `oracle-ads` package. + """Decorator to validate the presence and version of `oracle-ads`. - This decorator checks whether `oracle-ads` is installed and ensures its version meets - the minimum requirement. Raises an error if the conditions are not met. + This decorator checks that `oracle-ads` is installed and that its version meets + the minimum requirement. If not, it raises an error. - Parameters - ---------- - func : Callable[..., Any] - The function to wrap with the dependency validation. + Args: + func: The function to wrap with the dependency validation. - Returns - ------- - Callable[..., Any] + Returns: The wrapped function. - Raises - ------ - ImportError - If `oracle-ads` is not installed. - UnsupportedOracleAdsVersionError - If the installed version is below the required version. + Raises: + ImportError: If `oracle-ads` is not installed. + UnsupportedOracleAdsVersionError: If the installed version is below the required version. """ @wraps(func) @@ -66,7 +53,6 @@ def wrapper(*args, **kwargs) -> Any: if version.parse(ads_version) < version.parse(MIN_ADS_VERSION): raise UnsupportedOracleAdsVersionError(ads_version, MIN_ADS_VERSION) - except ImportError as ex: raise ImportError( "Could not import `oracle-ads` Python package. 
" @@ -80,20 +66,14 @@ def wrapper(*args, **kwargs) -> Any: def _to_message_dicts( messages: Sequence[ChatMessage], drop_none: bool = False ) -> List[Dict[str, Any]]: - """ - Converts a sequence of ChatMessage objects to a list of dictionaries. - - Parameters - ---------- - messages : Sequence[ChatMessage] - The messages to convert. - drop_none : bool, optional - Whether to drop keys with `None` values. Defaults to False. - - Returns - ------- - List[Dict[str, Any]] - The converted list of message dictionaries. + """Convert a sequence of ChatMessage objects to a list of dictionaries. + + Args: + messages: The messages to convert. + drop_none: Whether to drop keys with `None` values. Defaults to False. + + Returns: + A list of message dictionaries. """ message_dicts = [] for message in messages: @@ -111,18 +91,13 @@ def _to_message_dicts( def _from_completion_logprobs_dict( completion_logprobs_dict: Dict[str, Any] ) -> List[List[LogProb]]: - """ - Converts completion logprobs to a list of generic LogProb objects. + """Convert completion logprobs to a list of generic LogProb objects. - Parameters - ---------- - completion_logprobs_dict : Dict[str, Any] - The completion logprobs to convert. + Args: + completion_logprobs_dict: The completion logprobs to convert. - Returns - ------- - List[List[LogProb]] - The converted logprobs. + Returns: + A list of lists of LogProb objects. """ return [ [ @@ -134,20 +109,18 @@ def _from_completion_logprobs_dict( def _from_token_logprob_dicts( - token_logprob_dicts: Sequence[Dict[str, Any]], + token_logprob_dicts: Sequence[Dict[str, Any]] ) -> List[List[LogProb]]: - """ - Converts a sequence of token logprob dictionaries to a list of lists of LogProb objects. + """Convert a sequence of token logprob dictionaries to a list of LogProb objects. + + Args: + token_logprob_dicts: The token logprob dictionaries to convert. - Parameters - ---------- - token_logprob_dicts : Sequence[Dict[str, Any]] - The token logprob dictionaries to convert. + Returns: + A list of lists of LogProb objects. - Returns - ------- - List[List[LogProb]] - The converted logprobs. + Raises: + Warning: Logs a warning if an error occurs while parsing token logprobs. """ result = [] for token_logprob_dict in token_logprob_dicts: @@ -164,25 +137,20 @@ def _from_token_logprob_dicts( result.append(logprobs_list) except Exception as e: logger.warning( - f"Error occurred in attempt to parse token logprob. " + "Error occurred in attempt to parse token logprob. " f"Details: {e}. Src: {token_logprob_dict}" ) return result def _from_message_dict(message_dict: Dict[str, Any]) -> ChatMessage: - """ - Converts a message dictionary to a generic ChatMessage object. + """Convert a message dictionary to a ChatMessage object. - Parameters - ---------- - message_dict : Dict[str, Any] - The message dictionary. + Args: + message_dict: The message dictionary. - Returns - ------- - ChatMessage - The converted ChatMessage object. + Returns: + A ChatMessage object representing the given dictionary. """ return ChatMessage( role=message_dict.get("role"), @@ -192,18 +160,13 @@ def _from_message_dict(message_dict: Dict[str, Any]) -> ChatMessage: def _get_response_token_counts(raw_response: Dict[str, Any]) -> Dict[str, int]: - """ - Extracts token usage information from the response. + """Extract token usage information from the response. - Parameters - ---------- - raw_response : Dict[str, Any] - The raw response containing token usage information. 
+ Args: + raw_response: The raw response containing token usage information. - Returns - ------- - Dict[str, int] - The extracted token counts. + Returns: + A dictionary containing token counts, or an empty dictionary if usage info is not found. """ if not raw_response.get("usage"): return {} @@ -218,20 +181,14 @@ def _get_response_token_counts(raw_response: Dict[str, Any]) -> Dict[str, int]: def _update_tool_calls( tool_calls: List[Dict[str, Any]], tool_calls_delta: Optional[List[Dict[str, Any]]] ) -> List[Dict[str, Any]]: - """ - Updates the tool calls using delta objects received from stream chunks. - - Parameters - ---------- - tool_calls : List[Dict[str, Any]] - The list of existing tool calls. - tool_calls_delta : Optional[List[Dict[str, Any]]] - The delta updates for the tool calls. - - Returns - ------- - List[Dict[str, Any]] - The updated tool calls. + """Update the tool calls using delta objects received from stream chunks. + + Args: + tool_calls: The list of existing tool calls. + tool_calls_delta: The delta updates for the tool calls (if any). + + Returns: + The updated list of tool calls. """ if not tool_calls_delta: return tool_calls @@ -244,11 +201,11 @@ def _update_tool_calls( latest_function = latest_call.setdefault("function", {}) delta_function = delta_call.get("function", {}) - latest_function["arguments"] = latest_function.get( - "arguments", "" - ) + delta_function.get("arguments", "") - latest_function["name"] = latest_function.get("name", "") + delta_function.get( - "name", "" + latest_function["arguments"] = ( + latest_function.get("arguments", "") + delta_function.get("arguments", "") + ) + latest_function["name"] = ( + latest_function.get("name", "") + delta_function.get("name", "") ) latest_call["id"] = latest_call.get("id", "") + delta_call.get("id", "") @@ -258,8 +215,17 @@ def _update_tool_calls( def _resolve_tool_choice( tool_choice: Union[str, dict] = DEFAULT_TOOL_CHOICE ) -> Union[str, dict]: - """If tool_choice is a function name string, return the appropriate dict.""" + """Resolve the tool choice into a string or a dictionary. + + If the tool_choice is a string that is not in SUPPORTED_TOOL_CHOICES, a dictionary + representing a function call is returned. + + Args: + tool_choice: The desired tool choice, which can be a string or a dictionary. Defaults to "auto". + + Returns: + Either the original tool_choice if valid or a dictionary representing a function call. 
+ """ if isinstance(tool_choice, str) and tool_choice not in SUPPORTED_TOOL_CHOICES: return {"type": "function", "function": {"name": tool_choice}} - return tool_choice diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml index 63c35a41ce12c..8b9078ad59c14 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml @@ -41,7 +41,7 @@ mypy = "0.991" pre-commit = "3.2.0" pylint = "2.15.10" pytest = "7.2.1" -pytest-asyncio=">=0.24.0" +pytest-asyncio = ">=0.24.0" pytest-mock = "3.11.1" ruff = "0.0.292" tree-sitter-languages = "^1.8.0" diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py index 6ef5078e01cb7..c916df6e89ef6 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py @@ -16,7 +16,7 @@ def test_embedding_class(): assert FunctionCallingLLM.__name__ in names_of_base_classes -@pytest.fixture +@pytest.fixture() def llm(): endpoint = "https://example.com/api" auth = {"signer": Mock()} @@ -232,7 +232,7 @@ def test_get_tool_calls_from_response(llm): assert tool_selections[0].tool_kwargs == {"a": 2, "b": 3} -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_acomplete_success(llm): prompt = "What is the capital of France?" response_data = { @@ -257,7 +257,7 @@ async def test_acomplete_success(llm): assert response.additional_kwargs["total_tokens"] == 12 -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_astream_complete(llm): prompt = "Once upon a time" @@ -283,7 +283,7 @@ async def async_generator(): assert responses[-1].text == "Once upon a time." 
-@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_achat_success(llm): messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] response_data = { @@ -314,7 +314,7 @@ async def test_achat_success(llm): assert response.additional_kwargs["total_tokens"] == 25 -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_astream_chat(llm): messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py index 9f951e35cfc1f..b926c4039eb9b 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py @@ -15,7 +15,6 @@ _retry_decorator, _should_retry_exception, ) -from tenacity import RetryError class TestOCIAuth: @@ -479,7 +478,7 @@ def test_del(self): client.close.assert_called_once() -@pytest.mark.asyncio +@pytest.mark.asyncio() class TestAsyncClient: """Unit tests for AsyncClient class.""" diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py index d043ad63538e6..7333a015bb373 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py @@ -1,7 +1,5 @@ -import os from unittest.mock import patch -import ads import pytest from llama_index.core.base.llms.types import ChatMessage, LogProb, MessageRole from llama_index.llms.oci_data_science.utils import ( @@ -319,7 +317,7 @@ class TestResolveToolChoice: """Unit tests for _resolve_tool_choice function.""" @pytest.mark.parametrize( - "input_choice, expected_output", + ("input_choice", "expected_output"), [ ("auto", "auto"), ("none", "none"), From 8d87880ceded924adae249ba3bfdacdcbed8d683 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 10 Dec 2024 12:37:51 -0800 Subject: [PATCH 05/11] Fixes collab link. --- docs/docs/examples/llm/oci_data_science.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/examples/llm/oci_data_science.ipynb b/docs/docs/examples/llm/oci_data_science.ipynb index b6add9e72f818..54ca3074d7c59 100644 --- a/docs/docs/examples/llm/oci_data_science.ipynb +++ b/docs/docs/examples/llm/oci_data_science.ipynb @@ -6,7 +6,7 @@ "id": "6d1ca9ac", "metadata": {}, "source": [ - "\"Open" + "\"Open" ] }, { From 980bf0479aa128ae71f54772b59c974cbd416e13 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 10 Dec 2024 22:23:24 -0800 Subject: [PATCH 06/11] Adjustments by black formatter. 
--- .../llama_index/llms/oci_data_science/base.py | 11 +---------- .../llama_index/llms/oci_data_science/utils.py | 10 +++++----- .../tests/test_oci_data_science_utils.py | 1 - 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py index cabc288c8dfe2..05c26308b51e1 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py @@ -1,14 +1,5 @@ import logging -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Sequence, - Union, - TYPE_CHECKING -) +from typing import Any, Callable, Dict, List, Optional, Sequence, Union, TYPE_CHECKING import llama_index.core.instrumentation as instrument from ads.common import auth as authutil diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py index 52a201978149b..8201d416192c6 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py @@ -201,11 +201,11 @@ def _update_tool_calls( latest_function = latest_call.setdefault("function", {}) delta_function = delta_call.get("function", {}) - latest_function["arguments"] = ( - latest_function.get("arguments", "") + delta_function.get("arguments", "") - ) - latest_function["name"] = ( - latest_function.get("name", "") + delta_function.get("name", "") + latest_function["arguments"] = latest_function.get( + "arguments", "" + ) + delta_function.get("arguments", "") + latest_function["name"] = latest_function.get("name", "") + delta_function.get( + "name", "" ) latest_call["id"] = latest_call.get("id", "") + delta_call.get("id", "") diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py index 7333a015bb373..7b0c420761ae4 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py @@ -37,7 +37,6 @@ class TestValidateDependency: """Unit tests for _validate_dependency decorator.""" def setup_method(self): - @_validate_dependency def sample_function(): return "function executed" From 4c10eaf42fd566eb2fc0c50f3db4565c1568591a Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Wed, 11 Dec 2024 11:54:17 -0800 Subject: [PATCH 07/11] Fixed the notebook formatting. 
--- docs/docs/examples/llm/oci_data_science.ipynb | 98 ++++++++++++------- .../llms/oci_data_science/client.py | 6 +- 2 files changed, 62 insertions(+), 42 deletions(-) diff --git a/docs/docs/examples/llm/oci_data_science.ipynb b/docs/docs/examples/llm/oci_data_science.ipynb index 54ca3074d7c59..1f6b2e1ee56ff 100644 --- a/docs/docs/examples/llm/oci_data_science.ipynb +++ b/docs/docs/examples/llm/oci_data_science.ipynb @@ -146,11 +146,15 @@ " model=\"odsc-llm\",\n", " endpoint=\"https:///predict\",\n", ")\n", - "response = llm.chat([\n", - " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", - " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", - " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", - " ])\n", + "response = llm.chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", "\n", "print(response)" ] @@ -217,11 +221,15 @@ " model=\"odsc-llm\",\n", " endpoint=\"https:///predict\",\n", ")\n", - "response = llm.stream_chat([\n", - " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", - " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", - " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", - " ])\n", + "response = llm.stream_chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", "\n", "for chunk in response:\n", " print(chunk.delta, end=\"\")" @@ -289,11 +297,15 @@ " model=\"odsc-llm\",\n", " endpoint=\"https:///predict\",\n", ")\n", - "response = await llm.achat([\n", - " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", - " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", - " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", - " ])\n", + "response = await llm.achat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", "\n", "print(response)" ] @@ -352,11 +364,15 @@ " model=\"odsc-llm\",\n", " endpoint=\"https:///predict\",\n", ")\n", - "response = await llm.stream_chat([\n", - " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", - " ChatMessage(role=\"assistant\", content=\"Why did the chicken cross the road?\"),\n", - " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", - " ])\n", + "response = await llm.stream_chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", "\n", "async for chunk in response:\n", " print(chunk.delta, end=\"\")" @@ -390,14 +406,16 @@ " timeout=120,\n", " context_window=2500,\n", " additional_kwargs={\n", - " \"top_p\": 0.75,\n", - " \"logprobs\": True,\n", - " \"top_logprobs\": 3,\n", - " }\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " },\n", + 
")\n", + "response = llm.chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ]\n", ")\n", - "response = llm.chat([\n", - " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", - " ])\n", "print(response)" ] }, @@ -438,12 +456,13 @@ " timeout=120,\n", " context_window=2500,\n", " additional_kwargs={\n", - " \"top_p\": 0.75,\n", - " \"logprobs\": True,\n", - " \"top_logprobs\": 3,\n", - " }\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " },\n", ")\n", "\n", + "\n", "def multiply(a: float, b: float) -> float:\n", " print(f\"---> {a} * {b}\")\n", " return a * b\n", @@ -471,8 +490,8 @@ "\n", "response = llm.predict_and_call(\n", " [multiply_tool, add_tool, sub_tool, divide_tool],\n", - " user_msg= \"Calculate the result of `8 + 2 - 6`.\",\n", - " verbose=True\n", + " user_msg=\"Calculate the result of `8 + 2 - 6`.\",\n", + " verbose=True,\n", ")\n", "\n", "print(response)" @@ -508,12 +527,13 @@ " timeout=120,\n", " context_window=2500,\n", " additional_kwargs={\n", - " \"top_p\": 0.75,\n", - " \"logprobs\": True,\n", - " \"top_logprobs\": 3,\n", - " }\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " },\n", ")\n", "\n", + "\n", "def multiply(a: float, b: float) -> float:\n", " print(f\"---> {a} * {b}\")\n", " return a * b\n", @@ -540,7 +560,9 @@ "divide_tool = FunctionTool.from_defaults(fn=divide)\n", "\n", "agent = FunctionCallingAgent.from_tools(\n", - " tools=[multiply_tool, add_tool, sub_tool, divide_tool], llm=llm, verbose=True\n", + " tools=[multiply_tool, add_tool, sub_tool, divide_tool],\n", + " llm=llm,\n", + " verbose=True,\n", ")\n", "response = agent.chat(\n", " \"Calculate the result of `8 + 2 - 6`. Use tools. Return the calculated result.\"\n", diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py index 8e36894ccd454..783318b6b070a 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py @@ -39,8 +39,6 @@ STATUS_FORCE_LIST = [429, 500, 502, 503, 504] DEFAULT_ENCODING = "utf-8" -from typing import Self - _T = TypeVar("_T", bound="BaseClient") logger = logging.getLogger(__name__) @@ -356,7 +354,7 @@ def close(self) -> None: """Close the underlying HTTPX client.""" self._client.close() - def __enter__(self: _T) -> Self: + def __enter__(self: _T) -> _T: # noqa: PYI019 return self def __exit__( @@ -565,7 +563,7 @@ async def close(self) -> None: """ await self._client.aclose() - async def __aenter__(self: _T) -> Self: + async def __aenter__(self: _T) -> _T: # noqa: PYI019 return self async def __aexit__( From f6e5b729ce4a1f74fe055cc731ad120609bbe426 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Wed, 11 Dec 2024 17:50:23 -0800 Subject: [PATCH 08/11] Adds requirements.txt --- .../llms/llama-index-llms-oci-data-science/requirements.txt | 1 + 1 file changed, 1 insertion(+) create mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt b/llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt new file mode 100644 index 0000000000000..344d2da70731e --- 
/dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt @@ -0,0 +1 @@ +oracle-ads From 67f5f731d162c803a8c3e8db0d88ac04c38695c2 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Fri, 13 Dec 2024 09:55:15 -0800 Subject: [PATCH 09/11] Modifies BUILD file. Adds ads mapping. --- .../llms/llama-index-llms-oci-data-science/BUILD | 1 + .../llms/llama-index-llms-oci-data-science/requirements.txt | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD index 0896ca890d8bf..a40903f2ec7e7 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD @@ -1,3 +1,4 @@ poetry_requirements( name="poetry", + module_mapping={"oracle-ads": ["ads"]}, ) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt b/llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt deleted file mode 100644 index 344d2da70731e..0000000000000 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -oracle-ads From 89fd22d57c613cc519aa08d9f98e984574047238 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Wed, 18 Dec 2024 20:48:40 -0600 Subject: [PATCH 10/11] restrict tests --- .../llms/llama-index-llms-oci-data-science/tests/BUILD | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD index dabf212d7e716..89a6d832d22a7 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD @@ -1 +1,3 @@ -python_tests() +python_tests( + dependencies=["==3.9.*"] +) From 30766e466431b5a930376592f9a407b70fa5e2a1 Mon Sep 17 00:00:00 2001 From: Logan Markewich Date: Wed, 18 Dec 2024 20:58:10 -0600 Subject: [PATCH 11/11] wrong key name --- .../llms/llama-index-llms-oci-data-science/tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD index 89a6d832d22a7..b9078077b4d11 100644 --- a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD @@ -1,3 +1,3 @@ python_tests( - dependencies=["==3.9.*"] + interpreter_constraints=["==3.9.*"] )
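With the full series applied, the integration is consumed like any other LlamaIndex LLM. The sketch below is assembled from the notebook cells and tests in these patches and is illustrative only: the class name OCIDataScience, the ads.set_auth() call, and the endpoint value are assumptions inferred from the package path llama_index.llms.oci_data_science and the oracle-ads dependency, not API surface confirmed by this diff.

import ads
from llama_index.core.base.llms.types import ChatMessage
# Class name assumed from the package path; adjust if the actual export differs.
from llama_index.llms.oci_data_science import OCIDataScience

# Authenticate against OCI; resource principals are one option inside an
# OCI Data Science notebook session (assumption -- adapt to your tenancy).
ads.set_auth("resource_principal")

llm = OCIDataScience(
    model="odsc-llm",
    endpoint="https://<model-deployment-url>/predict",  # hypothetical placeholder
    timeout=120,
    context_window=2500,
    additional_kwargs={"top_p": 0.75, "logprobs": True, "top_logprobs": 3},
)

# Single-turn completion; response.text carries the generated string.
print(llm.complete("What is the capital of France?").text)

# Multi-turn chat with streaming, mirroring the reformatted notebook cells.
for chunk in llm.stream_chat(
    [
        ChatMessage(role="user", content="Tell me a joke"),
        ChatMessage(
            role="assistant", content="Why did the chicken cross the road?"
        ),
        ChatMessage(role="user", content="I don't know, why?"),
    ]
):
    print(chunk.delta, end="")

The same object backs tool calling through predict_and_call and FunctionCallingAgent, and the async variants (acomplete, achat, astream_chat) follow the patterns covered by the asyncio-marked tests.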