diff --git a/docs/docs/examples/llm/oci_data_science.ipynb b/docs/docs/examples/llm/oci_data_science.ipynb new file mode 100644 index 0000000000000..1f6b2e1ee56ff --- /dev/null +++ b/docs/docs/examples/llm/oci_data_science.ipynb @@ -0,0 +1,595 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "6d1ca9ac", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "id": "9e3a8796-edc8-43f2-94ad-fe4fb20d70ed", + "metadata": {}, + "source": [ + "# Oracle Cloud Infrastructure Data Science \n", + "\n", + "Oracle Cloud Infrastructure [(OCI) Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed, serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure.\n", + "\n", + "It offers [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm), which can be used to deploy, evaluate, and fine-tune foundation LLM models in OCI Data Science. AI Quick Actions target users who want to quickly leverage the capabilities of AI. They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free, and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook.\n", + "\n", + "Detailed documentation on how to deploy LLM models in OCI Data Science using AI Quick Actions is available [here](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/model-deployment-tips.md) and [here](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm).\n", + "\n", + "This notebook explains how to use OCI's Data Science models with LlamaIndex." + ] + }, + { + "cell_type": "markdown", + "id": "3802e8c4", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb0dd8c9", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index-llms-oci-data-science" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "544d49f9", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install llama-index" + ] + }, + { + "cell_type": "markdown", + "id": "c2921307", + "metadata": {}, + "source": [ + "You will also need to install the [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) SDK." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "378d5179", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U oracle-ads" + ] + }, + { + "cell_type": "markdown", + "id": "737b5293", + "metadata": {}, + "source": [ + "## Authentication\n", + "The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) helps to simplify the authentication within OCI Data Science." 
+ ] + }, + { + "cell_type": "markdown", + "id": "03d4024a", + "metadata": {}, + "source": [ + "## Basic Usage\n", + "\n", + "Using LLMs offered by OCI Data Science AI with LlamaIndex only requires you to initialize the `OCIDataScience` interface with your Data Science Model Deployment endpoint and model ID. By default, all models deployed with AI Quick Actions get the `odsc-model` ID; however, this ID can be changed during deployment." + ] + }, + { + "cell_type": "markdown", + "id": "8ead155e-b8bd-46f9-ab9b-28fc009361dd", + "metadata": {}, + "source": [ + "### Call `complete` with a prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60be18ae-c957-4ac2-a58a-0652e18ee6d6", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = llm.complete(\"Tell me a joke\")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "c1f3fcbd", + "metadata": {}, + "source": [ + "### Call `chat` with a list of messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a80c9f6e", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = llm.chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "9581413d", + "metadata": {}, + "source": [ + "## Streaming" + ] + }, + { + "cell_type": "markdown", + "id": "6f4dbedf", + "metadata": {}, + "source": [ + "### Using `stream_complete` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "977ad99f", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "\n", + "for chunk in llm.stream_complete(\"Tell me a joke\"):\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "38abd64d", + "metadata": {}, + "source": [ + "### Using `stream_chat` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fca03dac", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = llm.stream_chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the 
road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", + "\n", + "for chunk in response:\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "0b986d4e", + "metadata": {}, + "source": [ + "## Async" + ] + }, + { + "cell_type": "markdown", + "id": "42294b23", + "metadata": {}, + "source": [ + "### Call `acomplete` with a prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d52768eb", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = await llm.acomplete(\"Tell me a joke\")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "aad4d4cb", + "metadata": {}, + "source": [ + "### Call `achat` with a list of messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1416bacf", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = await llm.achat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "0da3c384", + "metadata": {}, + "source": [ + "### Using `astream_complete` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b392dc3a", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "\n", + "async for chunk in await llm.astream_complete(\"Tell me a joke\"):\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "c22e167a", + "metadata": {}, + "source": [ + "### Using `astream_chat` endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "056daa3a", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.base.llms.types import ChatMessage\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + ")\n", + "response = await llm.stream_chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ChatMessage(\n", + " role=\"assistant\", content=\"Why did the chicken cross the road?\"\n", + " ),\n", + " ChatMessage(role=\"user\", content=\"I don't know, why?\"),\n", + " ]\n", + ")\n", + "\n", + "async for chunk in response:\n", + " print(chunk.delta, end=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "ed26b8a7", + "metadata": {}, + "source": [ + 
"## Configure Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42fa2409", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + " temperature=0.2,\n", + " max_tokens=500,\n", + " timeout=120,\n", + " context_window=2500,\n", + " additional_kwargs={\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " },\n", + ")\n", + "response = llm.chat(\n", + " [\n", + " ChatMessage(role=\"user\", content=\"Tell me a joke\"),\n", + " ]\n", + ")\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "094b98c0", + "metadata": {}, + "source": [ + "## Function Calling" + ] + }, + { + "cell_type": "markdown", + "id": "63a1532a", + "metadata": {}, + "source": [ + "The [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm) offers prebuilt service containers that make deploying and serving a large language model very easy. Either one of vLLM (a high-throughput and memory-efficient inference and serving engine for LLMs) or TGI (a high-performance text generation server for the popular open-source LLMs) is used in the service container to host the model, the end point created supports the OpenAI API protocol. This allows the model deployment to be used as a drop-in replacement for applications using OpenAI API. If the deployed model supports function calling, then integration with LlamaIndex tools, through the predict_and_call function on the llm allows to attach any tools and let the LLM decide which tools to call (if any).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28b53563", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.tools import FunctionTool\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + " temperature=0.2,\n", + " max_tokens=500,\n", + " timeout=120,\n", + " context_window=2500,\n", + " additional_kwargs={\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " },\n", + ")\n", + "\n", + "\n", + "def multiply(a: float, b: float) -> float:\n", + " print(f\"---> {a} * {b}\")\n", + " return a * b\n", + "\n", + "\n", + "def add(a: float, b: float) -> float:\n", + " print(f\"---> {a} + {b}\")\n", + " return a + b\n", + "\n", + "\n", + "def subtract(a: float, b: float) -> float:\n", + " print(f\"---> {a} - {b}\")\n", + " return a - b\n", + "\n", + "\n", + "def divide(a: float, b: float) -> float:\n", + " print(f\"---> {a} / {b}\")\n", + " return a / b\n", + "\n", + "\n", + "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n", + "add_tool = FunctionTool.from_defaults(fn=add)\n", + "sub_tool = FunctionTool.from_defaults(fn=subtract)\n", + "divide_tool = FunctionTool.from_defaults(fn=divide)\n", + "\n", + "response = llm.predict_and_call(\n", + " [multiply_tool, add_tool, sub_tool, divide_tool],\n", + " user_msg=\"Calculate the result of `8 + 2 - 6`.\",\n", + " verbose=True,\n", + ")\n", + "\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "2dc0829c", + "metadata": {}, + "source": [ + "### Using 
`FunctionCallingAgent`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29fa7fb6", + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from llama_index.llms.oci_data_science import OCIDataScience\n", + "from llama_index.core.tools import FunctionTool\n", + "from llama_index.core.agent import FunctionCallingAgent\n", + "\n", + "ads.set_auth(auth=\"security_token\", profile=\"\")\n", + "\n", + "llm = OCIDataScience(\n", + " model=\"odsc-llm\",\n", + " endpoint=\"https:///predict\",\n", + " temperature=0.2,\n", + " max_tokens=500,\n", + " timeout=120,\n", + " context_window=2500,\n", + " additional_kwargs={\n", + " \"top_p\": 0.75,\n", + " \"logprobs\": True,\n", + " \"top_logprobs\": 3,\n", + " },\n", + ")\n", + "\n", + "\n", + "def multiply(a: float, b: float) -> float:\n", + " print(f\"---> {a} * {b}\")\n", + " return a * b\n", + "\n", + "\n", + "def add(a: float, b: float) -> float:\n", + " print(f\"---> {a} + {b}\")\n", + " return a + b\n", + "\n", + "\n", + "def subtract(a: float, b: float) -> float:\n", + " print(f\"---> {a} - {b}\")\n", + " return a - b\n", + "\n", + "\n", + "def divide(a: float, b: float) -> float:\n", + " print(f\"---> {a} / {b}\")\n", + " return a / b\n", + "\n", + "\n", + "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n", + "add_tool = FunctionTool.from_defaults(fn=add)\n", + "sub_tool = FunctionTool.from_defaults(fn=subtract)\n", + "divide_tool = FunctionTool.from_defaults(fn=divide)\n", + "\n", + "agent = FunctionCallingAgent.from_tools(\n", + " tools=[multiply_tool, add_tool, sub_tool, divide_tool],\n", + " llm=llm,\n", + " verbose=True,\n", + ")\n", + "response = agent.chat(\n", + " \"Calculate the result of `8 + 2 - 6`. Use tools. Return the calculated result.\"\n", + ")\n", + "\n", + "print(response)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/.gitignore b/llama-index-integrations/llms/llama-index-llms-oci-data-science/.gitignore new file mode 100644 index 0000000000000..990c18de22908 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD new file mode 100644 index 0000000000000..a40903f2ec7e7 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/BUILD @@ -0,0 +1,4 @@ +poetry_requirements( + name="poetry", + module_mapping={"oracle-ads": ["ads"]}, +) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/Makefile b/llama-index-integrations/llms/llama-index-llms-oci-data-science/Makefile new file mode 100644 index 0000000000000..b9eab05aa3706 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. 
+ sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md new file mode 100644 index 0000000000000..ed19b0aec93b9 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/README.md @@ -0,0 +1,361 @@ +# LlamaIndex LLMs Integration: Oracle Cloud Infrastructure (OCI) Data Science Service + +Oracle Cloud Infrastructure (OCI) [Data Science](https://www.oracle.com/artificial-intelligence/data-science) is a fully managed, serverless platform for data science teams to build, train, and manage machine learning models in Oracle Cloud Infrastructure. + +It offers [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm), which can be used to deploy, evaluate, and fine-tune foundation models in OCI Data Science. AI Quick Actions target users who want to quickly leverage the capabilities of AI. They aim to expand the reach of foundation models to a broader set of users by providing a streamlined, code-free, and efficient environment for working with foundation models. AI Quick Actions can be accessed from the Data Science Notebook. + +Detailed documentation on how to deploy LLM models in OCI Data Science using AI Quick Actions is available [here](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/model-deployment-tips.md) and [here](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm). + +## Installation + +Install the required packages: + +```bash +pip install oracle-ads llama-index llama-index-llms-oci-data-science +``` + +The [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/index.html) SDK is required to simplify authentication within OCI Data Science. + +## Authentication + +The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the standard SDK authentication methods, specifically API Key, session token, instance principal, and resource principal. More details can be found [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html). Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm) to access the OCI Data Science Model Deployment endpoint. A minimal authentication setup is shown below. + +## Basic Usage + +Using LLMs offered by OCI Data Science AI with LlamaIndex only requires you to initialize the `OCIDataScience` interface with your Data Science Model Deployment endpoint and model ID. By default, all models deployed with AI Quick Actions get the `odsc-model` ID; however, this ID can be changed during deployment.
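+
+Before running the examples below, configure how `oracle-ads` authenticates with OCI. The snippet below is a minimal sketch: the `security_token` method and the `DEFAULT` profile name are assumptions for illustration; use whichever authentication method and profile match your environment.
+
+```python
+import ads
+
+# Session token (or API key) authentication using a profile from ~/.oci/config.
+# The profile name "DEFAULT" is an assumption; replace it with your own profile.
+ads.set_auth(auth="security_token", profile="DEFAULT")
+
+# Alternatively, when running inside OCI (for example, in a Data Science notebook
+# session or job), resource principal authentication needs no local config files:
+# ads.set_auth("resource_principal")
+```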
+ +### Call `complete` with a prompt + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = llm.complete("Tell me a joke") + +print(response) +``` + +### Call `chat` with a list of messages + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = llm.chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" + ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) + +print(response) +``` + +## Streaming + +### Using `stream_complete` endpoint + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) + +for chunk in llm.stream_complete("Tell me a joke"): + print(chunk.delta, end="") +``` + +### Using `stream_chat` endpoint + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = llm.stream_chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" + ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) + +for chunk in response: + print(chunk.delta, end="") +``` + +## Async + +### Call `acomplete` with a prompt + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = await llm.acomplete("Tell me a joke") + +print(response) +``` + +### Call `achat` with a list of messages + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = await llm.achat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" 
+ ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) + +print(response) +``` + +### Using `astream_complete` endpoint + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) + +async for chunk in await llm.astream_complete("Tell me a joke"): + print(chunk.delta, end="") +``` + +### Using `astream_chat` endpoint + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", +) +response = await llm.astream_chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ChatMessage( + role="assistant", content="Why did the chicken cross the road?" + ), + ChatMessage(role="user", content="I don't know, why?"), + ] +) + +async for chunk in response: + print(chunk.delta, end="") +``` + +## Configure Model + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.base.llms.types import ChatMessage + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", + temperature=0.2, + max_tokens=500, + timeout=120, + context_window=2500, + additional_kwargs={ + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + }, +) +response = llm.chat( + [ + ChatMessage(role="user", content="Tell me a joke"), + ] +) +print(response) +``` + +## Function Calling + +[AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm) offers prebuilt service containers that make deploying and serving a large language model straightforward. The service container hosts the model with either vLLM (a high-throughput and memory-efficient inference and serving engine for LLMs) or TGI (a high-performance text generation server for popular open-source LLMs), and the endpoint it creates supports the OpenAI API protocol. This allows the model deployment to be used as a drop-in replacement for applications that use the OpenAI API. If the deployed model supports function calling, the `predict_and_call` function on the LLM integrates with LlamaIndex tools: you can attach any tools and let the LLM decide which tools to call (if any).
+ +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.tools import FunctionTool + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", + temperature=0.2, + max_tokens=500, + timeout=120, + context_window=2500, + additional_kwargs={ + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + }, +) + + +def multiply(a: float, b: float) -> float: + print(f"---> {a} * {b}") + return a * b + + +def add(a: float, b: float) -> float: + print(f"---> {a} + {b}") + return a + b + + +def subtract(a: float, b: float) -> float: + print(f"---> {a} - {b}") + return a - b + + +def divide(a: float, b: float) -> float: + print(f"---> {a} / {b}") + return a / b + + +multiply_tool = FunctionTool.from_defaults(fn=multiply) +add_tool = FunctionTool.from_defaults(fn=add) +sub_tool = FunctionTool.from_defaults(fn=subtract) +divide_tool = FunctionTool.from_defaults(fn=divide) + +response = llm.predict_and_call( + [multiply_tool, add_tool, sub_tool, divide_tool], + user_msg="Calculate the result of `8 + 2 - 6`.", + verbose=True, +) + +print(response) +``` + +### Using `FunctionCallingAgent` + +```python +import ads +from llama_index.llms.oci_data_science import OCIDataScience +from llama_index.core.tools import FunctionTool +from llama_index.core.agent import FunctionCallingAgent + +ads.set_auth(auth="security_token", profile="") + +llm = OCIDataScience( + model="odsc-llm", + endpoint="https:///predict", + temperature=0.2, + max_tokens=500, + timeout=120, + context_window=2500, + additional_kwargs={ + "top_p": 0.75, + "logprobs": True, + "top_logprobs": 3, + }, +) + + +def multiply(a: float, b: float) -> float: + print(f"---> {a} * {b}") + return a * b + + +def add(a: float, b: float) -> float: + print(f"---> {a} + {b}") + return a + b + + +def subtract(a: float, b: float) -> float: + print(f"---> {a} - {b}") + return a - b + + +def divide(a: float, b: float) -> float: + print(f"---> {a} / {b}") + return a / b + + +multiply_tool = FunctionTool.from_defaults(fn=multiply) +add_tool = FunctionTool.from_defaults(fn=add) +sub_tool = FunctionTool.from_defaults(fn=subtract) +divide_tool = FunctionTool.from_defaults(fn=divide) + +agent = FunctionCallingAgent.from_tools( + tools=[multiply_tool, add_tool, sub_tool, divide_tool], + llm=llm, + verbose=True, +) +response = agent.chat( + "Calculate the result of `8 + 2 - 6`. Use tools. Return the calculated result." 
+) + +print(response) +``` + +## LLM Implementation example + +https://docs.llamaindex.ai/en/stable/examples/llm/oci_data_science/ diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/BUILD new file mode 100644 index 0000000000000..db46e8d6c978c --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/__init__.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/__init__.py new file mode 100644 index 0000000000000..d82f3b1b7f4a5 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/__init__.py @@ -0,0 +1,4 @@ +from llama_index.llms.oci_data_science.base import OCIDataScience + + +__all__ = ["OCIDataScience"] diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py new file mode 100644 index 0000000000000..05c26308b51e1 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/base.py @@ -0,0 +1,936 @@ +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Union, TYPE_CHECKING + +import llama_index.core.instrumentation as instrument +from ads.common import auth as authutil +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.core.bridge.pydantic import ( + Field, + PrivateAttr, + model_validator, +) +from llama_index.core.callbacks import CallbackManager +from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_TEMPERATURE +from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.llms.llm import ToolSelection +from llama_index.core.llms.utils import parse_partial_json +from llama_index.core.types import BaseOutputParser, PydanticProgramMode +from llama_index.llms.oci_data_science.client import AsyncClient, Client +from llama_index.llms.oci_data_science.utils import ( + DEFAULT_TOOL_CHOICE, + _from_completion_logprobs_dict, + _from_message_dict, + _from_token_logprob_dicts, + _get_response_token_counts, + _resolve_tool_choice, + _to_message_dicts, + _update_tool_calls, + _validate_dependency, +) + + +if TYPE_CHECKING: + from llama_index.core.tools.types import BaseTool + +dispatcher = instrument.get_dispatcher(__name__) + + +DEFAULT_MODEL = "odsc-llm" +DEFAULT_MAX_TOKENS = 512 +DEFAULT_TIMEOUT = 120 +DEFAULT_MAX_RETRIES = 5 + +logger = logging.getLogger(__name__) + + +class OCIDataScience(FunctionCallingLLM): + """ + LLM deployed on OCI Data Science Model Deployment. + + **Setup:** + Install ``oracle-ads`` and ``llama-index-llms-oci-data-science``. + + ```bash + pip install -U oracle-ads llama-index-llms-oci-data-science + ``` + + Use `ads.set_auth()` to configure authentication. 
+ For example, to use OCI resource_principal for authentication: + + ```python + import ads + ads.set_auth("resource_principal") + ``` + + For more details on authentication, see: + https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html + + Make sure to have the required policies to access the OCI Data + Science Model Deployment endpoint. See: + https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm + + To learn more about deploying LLM models in OCI Data Science, see: + https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm + + + **Examples:** + + **Basic Usage:** + + ```python + from llama_index.llms.oci_data_science import OCIDataScience + import ads + ads.set_auth(auth="security_token", profile="OC1") + + llm = OCIDataScience( + endpoint="https:///predict", + model="odsc-llm", + ) + prompt = "What is the capital of France?" + response = llm.complete(prompt) + print(response) + ``` + + **Custom Parameters:** + + ```python + llm = OCIDataScience( + endpoint="https:///predict", + model="odsc-llm", + temperature=0.7, + max_tokens=150, + additional_kwargs={"top_p": 0.9}, + ) + ``` + + **Using Chat Interface:** + + ```python + messages = [ + ChatMessage(role="user", content="Tell me a joke."), + ChatMessage(role="assistant", content="Why did the chicken cross the road?"), + ChatMessage(role="user", content="I don't know, why?"), + ] + + chat_response = llm.chat(messages) + print(chat_response) + ``` + + **Streaming Completion:** + + ```python + for chunk in llm.stream_complete("Once upon a time"): + print(chunk.delta, end="") + ``` + + **Asynchronous Chat:** + + ```python + import asyncio + + async def async_chat(): + messages = [ + ChatMessage(role="user", content="What's the weather like today?") + ] + response = await llm.achat(messages) + print(response) + + asyncio.run(async_chat()) + ``` + + **Using Tools (Function Calling):** + + ```python + from llama_index.llms.oci_data_science import OCIDataScience + from llama_index.core.tools import FunctionTool + import ads + ads.set_auth(auth="security_token", profile="OC1") + + def multiply(a: float, b: float) -> float: + return a * b + + def add(a: float, b: float) -> float: + return a + b + + def subtract(a: float, b: float) -> float: + return a - b + + def divide(a: float, b: float) -> float: + return a / b + + + multiply_tool = FunctionTool.from_defaults(fn=multiply) + add_tool = FunctionTool.from_defaults(fn=add) + sub_tool = FunctionTool.from_defaults(fn=subtract) + divide_tool = FunctionTool.from_defaults(fn=divide) + + llm = OCIDataScience( + endpoint="https:///predict", + model="odsc-llm", + temperature=0.7, + max_tokens=150, + additional_kwargs={"top_p": 0.9}, + ) + + response = llm.chat_with_tools( + user_msg="Calculate the result of 2 + 2.", + tools=[multiply_tool, add_tool, sub_tool, divide_tool], + ) + print(response) + ``` + """ + + endpoint: str = Field( + default=None, description="The URI of the endpoint from the deployed model." + ) + + auth: Dict[str, Any] = Field( + default_factory=dict, + exclude=True, + description=( + "The authentication dictionary used for OCI API requests. Default is an empty dictionary. " + "If not provided, it will be autogenerated based on the environment variables. " + "https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html." + ), + ) + model: Optional[str] = Field( + default=DEFAULT_MODEL, + description="The OCI Data Science default model. 
Defaults to `odsc-llm`.", + ) + temperature: Optional[float] = Field( + default=DEFAULT_TEMPERATURE, + description="A non-negative float that tunes the degree of randomness in generation.", + ge=0.0, + le=1.0, + ) + max_tokens: Optional[int] = Field( + default=DEFAULT_MAX_TOKENS, + description="Denotes the number of tokens to predict per generation.", + gt=0, + ) + timeout: float = Field( + default=DEFAULT_TIMEOUT, description="The timeout to use in seconds.", ge=0 + ) + max_retries: int = Field( + default=DEFAULT_MAX_RETRIES, + description="The maximum number of API retries.", + ge=0, + ) + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, + ) + is_chat_model: bool = Field( + default=True, + description="If the model exposes a chat interface.", + ) + is_function_calling_model: bool = Field( + default=True, + description="If the model supports function calling messages.", + ) + additional_kwargs: Optional[Dict[str, Any]] = Field( + default_factory=dict, + description="Additional kwargs for the OCI Data Science AI request.", + ) + strict: bool = Field( + default=False, + description="Whether to use strict mode for invoking tools/using schemas.", + ) + default_headers: Optional[Dict[str, str]] = Field( + default=None, description="The default headers for API requests." + ) + + _client: Client = PrivateAttr() + _async_client: AsyncClient = PrivateAttr() + + def __init__( + self, + endpoint: str, + auth: Optional[Dict[str, Any]] = None, + model: Optional[str] = DEFAULT_MODEL, + temperature: Optional[float] = DEFAULT_TEMPERATURE, + max_tokens: Optional[int] = DEFAULT_MAX_TOKENS, + context_window: Optional[int] = DEFAULT_CONTEXT_WINDOW, + timeout: Optional[float] = DEFAULT_TIMEOUT, + max_retries: Optional[int] = DEFAULT_MAX_RETRIES, + additional_kwargs: Optional[Dict[str, Any]] = None, + callback_manager: Optional[CallbackManager] = None, + is_chat_model: Optional[bool] = True, + is_function_calling_model: Optional[bool] = True, + default_headers: Optional[Dict[str, str]] = None, + # base class + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, + strict: bool = False, + **kwargs, + ) -> None: + """ + Initialize the OCIDataScience LLM class. + + Args: + endpoint (str): The URI of the endpoint from the deployed model. + auth (Optional[Dict[str, Any]]): Authentication dictionary for OCI API requests. + model (Optional[str]): The model name to use. Defaults to `odsc-llm`. + temperature (Optional[float]): Controls the randomness in generation. + max_tokens (Optional[int]): Number of tokens to predict per generation. + context_window (Optional[int]): Maximum number of context tokens for the model. + timeout (Optional[float]): Timeout for API requests in seconds. + max_retries (Optional[int]): Maximum number of API retries. + additional_kwargs (Optional[Dict[str, Any]]): Additional parameters for the API request. + callback_manager (Optional[CallbackManager]): Callback manager for LLM. + is_chat_model (Optional[bool]): If the model exposes a chat interface. Defaults to `True`. + is_function_calling_model (Optional[bool]): If the model supports function calling messages. Defaults to `True`. 
+ default_headers (Optional[Dict[str, str]]): The default headers for API requests. + system_prompt (Optional[str]): System prompt to use. + messages_to_prompt (Optional[Callable]): Function to convert messages to prompt. + completion_to_prompt (Optional[Callable]): Function to convert completion to prompt. + pydantic_program_mode (PydanticProgramMode): Pydantic program mode. + output_parser (Optional[BaseOutputParser]): Output parser for the LLM. + strict (bool): Whether to use strict mode for invoking tools/using schemas. + **kwargs: Additional keyword arguments. + """ + super().__init__( + endpoint=endpoint, + model=model, + auth=auth or authutil.default_signer(), + temperature=temperature, + context_window=context_window, + max_tokens=max_tokens, + timeout=timeout, + max_retries=max_retries, + additional_kwargs=additional_kwargs or {}, + callback_manager=callback_manager or CallbackManager([]), + is_chat_model=is_chat_model, + is_function_calling_model=is_function_calling_model, + default_headers=default_headers, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, + strict=strict, + **kwargs, + ) + + self._client: Client = None + self._async_client: AsyncClient = None + + logger.debug( + f"Initialized OCIDataScience LLM with endpoint: {self.endpoint} and model: {self.model}" + ) + + @model_validator(mode="before") + @_validate_dependency + def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate the environment and dependencies.""" + return values + + @property + def client(self) -> Client: + """ + Synchronous client for interacting with the OCI Data Science Model Deployment endpoint. + + Returns: + Client: The synchronous client instance. + """ + if self._client is None: + self._client = Client( + endpoint=self.endpoint, + auth=self.auth, + retries=self.max_retries, + timeout=self.timeout, + ) + return self._client + + @property + def async_client(self) -> AsyncClient: + """ + Asynchronous client for interacting with the OCI Data Science Model Deployment endpoint. + + Returns: + AsyncClient: The asynchronous client instance. + """ + if self._async_client is None: + self._async_client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth, + retries=self.max_retries, + timeout=self.timeout, + ) + return self._async_client + + @classmethod + def class_name(cls) -> str: + """ + Return the class name. + + Returns: + str: The name of the class. + """ + return "OCIDataScience_LLM" + + @property + def metadata(self) -> LLMMetadata: + """ + Return the metadata of the LLM. + + Returns: + LLMMetadata: The metadata of the LLM. + """ + return LLMMetadata( + context_window=self.context_window, + num_output=self.max_tokens or -1, + is_chat_model=self.is_chat_model, + is_function_calling_model=self.is_function_calling_model, + model_name=self.model, + ) + + def _model_kwargs(self, **kwargs: Any) -> Dict[str, Any]: + """ + Get model-specific parameters for the API request. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + Dict[str, Any]: The combined model parameters. + """ + base_kwargs = { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + } + return {**base_kwargs, **self.additional_kwargs, **kwargs} + + def _prepare_headers( + self, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + """ + Construct and return the headers for a request. 
+ + Args: + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, str]: The prepared headers. + """ + return {**(self.default_headers or {}), **(headers or {})} + + @llm_completion_callback() + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """ + Generate a completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Returns: + CompletionResponse: The response from the LLM. + """ + logger.debug(f"Calling complete with prompt: {prompt}") + response = self.client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=False, + ) + + logger.debug(f"Received response: {response}") + try: + choice = response["choices"][0] + text = choice.get("text", "") + logprobs = _from_completion_logprobs_dict(choice.get("logprobs") or {}) + + return CompletionResponse( + text=text, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {e!s}") from e + + @llm_completion_callback() + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + """ + Stream the completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Yields: + CompletionResponse: The streamed response from the LLM. + """ + logger.debug(f"Starting stream_complete with prompt: {prompt}") + text = "" + for response in self.client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=True, + ): + logger.debug(f"Received chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("text") + if delta is None: + delta = "" + else: + delta = "" + text += delta + + yield CompletionResponse( + delta=delta, + text=text, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + @llm_chat_callback() + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + """ + Generate a chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Returns: + ChatResponse: The chat response from the LLM. 
+ """ + logger.debug(f"Calling chat with messages: {messages}") + response = self.client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=False, + ) + + logger.debug(f"Received chat response: {response}") + try: + choice = response["choices"][0] + message = _from_message_dict(choice.get("message", "")) + logprobs = _from_token_logprob_dicts( + (choice.get("logprobs") or {}).get("content", []) + ) + return ChatResponse( + message=message, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {e!s}") from e + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + """ + Stream the chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Yields: + ChatResponse: The streamed chat response from the LLM. + """ + logger.debug(f"Starting stream_chat with messages: {messages}") + content = "" + is_function = False + tool_calls = [] + for response in self.client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=True, + ): + logger.debug(f"Received chat chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("delta") or {} + else: + delta = {} + + # Check if this chunk is the start of a function call + if delta.get("tool_calls"): + is_function = True + + # Update using deltas + role = delta.get("role") or MessageRole.ASSISTANT + content_delta = delta.get("content") or "" + content += content_delta + + additional_kwargs = {} + if is_function: + tool_calls = _update_tool_calls(tool_calls, delta.get("tool_calls")) + if tool_calls: + additional_kwargs["tool_calls"] = tool_calls + + yield ChatResponse( + message=ChatMessage( + role=role, + content=content, + additional_kwargs=additional_kwargs, + ), + delta=content_delta, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + @llm_completion_callback() + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + """ + Asynchronously generate a completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Returns: + CompletionResponse: The response from the LLM. 
+ """ + logger.debug(f"Calling acomplete with prompt: {prompt}") + response = await self.async_client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=False, + ) + + logger.debug(f"Received async response: {response}") + try: + choice = response["choices"][0] + text = choice.get("text", "") + logprobs = _from_completion_logprobs_dict(choice.get("logprobs", {}) or {}) + + return CompletionResponse( + text=text, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {e!s}") from e + + @llm_completion_callback() + async def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + """ + Asynchronously stream the completion for the given prompt. + + Args: + prompt (str): The prompt to generate a completion for. + formatted (bool): Whether the prompt is formatted. + **kwargs: Additional keyword arguments. + + Yields: + CompletionResponse: The streamed response from the LLM. + """ + + async def gen() -> CompletionResponseAsyncGen: + logger.debug(f"Starting astream_complete with prompt: {prompt}") + text = "" + + async for response in await self.async_client.generate( + prompt=prompt, + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=True, + ): + logger.debug(f"Received async chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("text") + if delta is None: + delta = "" + else: + delta = "" + text += delta + + yield CompletionResponse( + delta=delta, + text=text, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + return gen() + + @llm_chat_callback() + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + """ + Asynchronously generate a chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Returns: + ChatResponse: The chat response from the LLM. + """ + logger.debug(f"Calling achat with messages: {messages}") + response = await self.async_client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=False, + ) + + logger.debug(f"Received async chat response: {response}") + try: + choice = response["choices"][0] + message = _from_message_dict(choice.get("message", "")) + logprobs = _from_token_logprob_dicts( + (choice.get("logprobs") or {}).get("content", {}) + ) + return ChatResponse( + message=message, + raw=response, + logprobs=logprobs, + additional_kwargs=_get_response_token_counts(response), + ) + except (IndexError, KeyError, TypeError) as e: + raise ValueError(f"Failed to parse response: {e!s}") from e + + @llm_chat_callback() + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + """ + Asynchronously stream the chat completion based on the input messages. + + Args: + messages (Sequence[ChatMessage]): A sequence of chat messages. + **kwargs: Additional keyword arguments. + + Yields: + ChatResponse: The streamed chat response from the LLM. 
+ """ + + async def gen() -> ChatResponseAsyncGen: + logger.debug(f"Starting astream_chat with messages: {messages}") + content = "" + is_function = False + tool_calls = [] + async for response in await self.async_client.chat( + messages=_to_message_dicts( + messages=messages, drop_none=kwargs.pop("drop_none", False) + ), + payload=self._model_kwargs(**kwargs), + headers=self._prepare_headers(kwargs.pop("headers", {})), + stream=True, + ): + logger.debug(f"Received async chat chunk: {response}") + if len(response.get("choices", [])) > 0: + delta = response["choices"][0].get("delta") or {} + else: + delta = {} + + # Check if this chunk is the start of a function call + if delta.get("tool_calls"): + is_function = True + + # Update using deltas + role = delta.get("role") or MessageRole.ASSISTANT + content_delta = delta.get("content") or "" + content += content_delta + + additional_kwargs = {} + if is_function: + tool_calls = _update_tool_calls(tool_calls, delta.get("tool_calls")) + if tool_calls: + additional_kwargs["tool_calls"] = tool_calls + + yield ChatResponse( + message=ChatMessage( + role=role, + content=content, + additional_kwargs=additional_kwargs, + ), + delta=content_delta, + raw=response, + additional_kwargs=_get_response_token_counts(response), + ) + + return gen() + + def _prepare_chat_with_tools( + self, + tools: List["BaseTool"], + user_msg: Optional[Union[str, ChatMessage]] = None, + chat_history: Optional[List[ChatMessage]] = None, + verbose: bool = False, + allow_parallel_tool_calls: bool = False, + tool_choice: Union[str, dict] = DEFAULT_TOOL_CHOICE, + strict: Optional[bool] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """ + Prepare the chat input with tools for function calling. + + Args: + tools (List[BaseTool]): A list of tools to use. + user_msg (Optional[Union[str, ChatMessage]]): The user's message. + chat_history (Optional[List[ChatMessage]]): The chat history. + verbose (bool): Whether to output verbose logs. + allow_parallel_tool_calls (bool): Whether to allow parallel tool calls. + tool_choice (Union[str, dict]): Tool choice strategy. + strict (Optional[bool]): Whether to enforce strict mode. + **kwargs: Additional keyword arguments. + + Returns: + Dict[str, Any]: The prepared parameters for the chat request. + """ + tool_specs = [tool.metadata.to_openai_tool() for tool in tools] + + logger.debug( + f"Preparing chat with tools. Tools: {tool_specs}, User message: {user_msg}, " + f"Chat history: {chat_history}" + ) + + # Determine strict mode + strict = strict or self.strict + + if self.metadata.is_function_calling_model: + for tool_spec in tool_specs: + if tool_spec["type"] == "function": + if strict: + tool_spec["function"]["strict"] = strict + tool_spec["function"]["parameters"]["additionalProperties"] = False + + if isinstance(user_msg, str): + user_msg = ChatMessage(role=MessageRole.USER, content=user_msg) + + messages = chat_history or [] + if user_msg: + messages.append(user_msg) + + return { + "messages": messages, + "tools": tool_specs or None, + "tool_choice": (_resolve_tool_choice(tool_choice) if tool_specs else None), + **kwargs, + } + + def _validate_chat_with_tools_response( + self, + response: ChatResponse, + tools: List["BaseTool"], + allow_parallel_tool_calls: bool = False, + **kwargs: Any, + ) -> ChatResponse: + """ + Validate the response from chat_with_tools. + + Args: + response (ChatResponse): The chat response to validate. + tools (List[BaseTool]): A list of tools used. 
+ allow_parallel_tool_calls (bool): Whether parallel tool calls are allowed. + **kwargs: Additional keyword arguments. + + Returns: + ChatResponse: The validated chat response. + """ + if not allow_parallel_tool_calls: + # Ensures that the 'tool_calls' in the response contain only a single tool call. + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + if len(tool_calls) > 1: + logger.warning( + "Multiple tool calls detected but parallel tool calls are not allowed. " + "Limiting to the first tool call." + ) + response.message.additional_kwargs["tool_calls"] = [tool_calls[0]] + return response + + def get_tool_calls_from_response( + self, + response: ChatResponse, + error_on_no_tool_call: bool = True, + **kwargs: Any, + ) -> List[ToolSelection]: + """ + Extract tool calls from the chat response. + + Args: + response (ChatResponse): The chat response containing tool calls. + error_on_no_tool_call (bool): Whether to raise an error if no tool calls are found. + **kwargs: Additional keyword arguments. + + Returns: + List[ToolSelection]: A list of tool selections extracted from the response. + + Raises: + ValueError: If no tool calls are found and error_on_no_tool_call is True. + """ + tool_calls = response.message.additional_kwargs.get("tool_calls", []) + logger.debug(f"Getting tool calls from response: {tool_calls}") + + if len(tool_calls) < 1: + if error_on_no_tool_call: + raise ValueError( + f"Expected at least one tool call, but got {len(tool_calls)} tool calls." + ) + else: + return [] + + tool_selections = [] + for tool_call in tool_calls: + if tool_call.get("type") != "function": + raise ValueError(f"Invalid tool type detected: {tool_call.get('type')}") + + # Handle both complete and partial JSON + try: + argument_dict = parse_partial_json( + tool_call.get("function", {}).get("arguments", {}) + ) + except ValueError as e: + logger.debug(f"Failed to parse tool call arguments: {e!s}") + argument_dict = {} + + tool_selections.append( + ToolSelection( + tool_id=tool_call.get("id"), + tool_name=tool_call.get("function", {}).get("name"), + tool_kwargs=argument_dict, + ) + ) + + logger.debug( + f"Extracted tool calls: { [tool_selection.model_dump() for tool_selection in tool_selections] }" + ) + return tool_selections diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py new file mode 100644 index 0000000000000..783318b6b070a --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/client.py @@ -0,0 +1,746 @@ +import asyncio +import functools +import json +import logging +import time +from abc import ABC +from types import TracebackType +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) + +import httpx +import oci +import requests +from ads.common import auth as authutil +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception, + stop_after_attempt, + stop_after_delay, + wait_exponential, + wait_random_exponential, +) + +DEFAULT_RETRIES = 3 +DEFAULT_BACKOFF_FACTOR = 3 +TIMEOUT = 600 # Timeout in seconds +STATUS_FORCE_LIST = [429, 500, 502, 503, 504] +DEFAULT_ENCODING = "utf-8" + +_T = TypeVar("_T", bound="BaseClient") + +logger = logging.getLogger(__name__) + + +class OCIAuth(httpx.Auth): + """ + Custom HTTPX authentication class that 
uses the OCI Signer for request signing. + + Attributes: + signer (oci.signer.Signer): The OCI signer used to sign requests. + """ + + def __init__(self, signer: oci.signer.Signer): + """ + Initialize the OCIAuth instance. + + Args: + signer (oci.signer.Signer): The OCI signer to use for signing requests. + """ + self.signer = signer + + def auth_flow(self, request: httpx.Request) -> Iterator[httpx.Request]: + """ + The authentication flow that signs the HTTPX request using the OCI signer. + + Args: + request (httpx.Request): The outgoing HTTPX request to be signed. + + Yields: + httpx.Request: The signed HTTPX request. + """ + # Create a requests.Request object from the HTTPX request + req = requests.Request( + method=request.method, + url=str(request.url), + headers=dict(request.headers), + data=request.content, + ) + prepared_request = req.prepare() + + # Sign the request using the OCI Signer + self.signer.do_request_sign(prepared_request) + + # Update the original HTTPX request with the signed headers + request.headers.update(prepared_request.headers) + + # Proceed with the request + yield request + + +class ExtendedRequestException(Exception): + """ + Custom exception for handling request errors with additional context. + + Attributes: + original_exception (Exception): The original exception that caused the error. + response_text (str): The text of the response received from the request, if available. + """ + + def __init__(self, message: str, original_exception: Exception, response_text: str): + """ + Initialize the ExtendedRequestException. + + Args: + message (str): The error message associated with the exception. + original_exception (Exception): The original exception that caused the error. + response_text (str): The text of the response received from the request, if available. + """ + super().__init__(message) + self.original_exception = original_exception + self.response_text = response_text + + +def _should_retry_exception(e: ExtendedRequestException) -> bool: + """ + Determine whether the exception should trigger a retry. + + Args: + e (ExtendedRequestException): The exception raised. + + Returns: + bool: True if the exception should trigger a retry, False otherwise. + """ + original_exception = e.original_exception if hasattr(e, "original_exception") else e + if isinstance(original_exception, httpx.HTTPStatusError): + return original_exception.response.status_code in STATUS_FORCE_LIST + elif isinstance(original_exception, httpx.RequestError): + return True + return False + + +def _create_retry_decorator( + max_retries: int, + backoff_factor: float, + random_exponential: bool = False, + stop_after_delay_seconds: Optional[float] = None, + min_seconds: float = 0, + max_seconds: float = 60, +) -> Callable[[Any], Any]: + """ + Create a tenacity retry decorator with the specified configuration. + + Args: + max_retries (int): The maximum number of retry attempts. + backoff_factor (float): The backoff factor for calculating retry delays. + random_exponential (bool): Whether to use random exponential backoff. + stop_after_delay_seconds (Optional[float]): Maximum total time to retry. + min_seconds (float): Minimum wait time between retries. + max_seconds (float): Maximum wait time between retries. + + Returns: + Callable[[Any], Any]: A tenacity retry decorator configured with the specified strategy. 
+ """ + wait_strategy = ( + wait_random_exponential(min=min_seconds, max=max_seconds) + if random_exponential + else wait_exponential( + multiplier=backoff_factor, min=min_seconds, max=max_seconds + ) + ) + + stop_strategy = stop_after_attempt(max_retries) + if stop_after_delay_seconds is not None: + stop_strategy = stop_strategy | stop_after_delay(stop_after_delay_seconds) + + retry_strategy = retry_if_exception(_should_retry_exception) + return retry( + wait=wait_strategy, + stop=stop_strategy, + retry=retry_strategy, + reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def _retry_decorator(f: Callable) -> Callable: + """ + Decorator to apply retry logic to a function using tenacity. + + Args: + f (Callable): The function to be decorated. + + Returns: + Callable: The decorated function with retry logic applied. + """ + + @functools.wraps(f) + def wrapper(self, *args: Any, **kwargs: Any): + retries = getattr(self, "retries", DEFAULT_RETRIES) + if retries <= 0: + return f(self, *args, **kwargs) + backoff_factor = getattr(self, "backoff_factor", DEFAULT_BACKOFF_FACTOR) + retry_func = _create_retry_decorator( + max_retries=retries, + backoff_factor=backoff_factor, + random_exponential=False, + stop_after_delay_seconds=getattr(self, "timeout", TIMEOUT), + min_seconds=0, + max_seconds=60, + ) + + return retry_func(f)(self, *args, **kwargs) + + return wrapper + + +class BaseClient(ABC): + """ + Base class for invoking models via HTTP requests with retry logic. + + Attributes: + endpoint (str): The URL endpoint to send the request. + auth (Any): The authentication signer for the requests. + retries (int): The number of retry attempts for the request. + backoff_factor (float): The factor to determine the delay between retries. + timeout (Union[float, Tuple[float, float]]): The timeout setting for the HTTP request. + kwargs (Dict): Additional keyword arguments. + """ + + def __init__( + self, + endpoint: str, + auth: Optional[Any] = None, + retries: Optional[int] = DEFAULT_RETRIES, + backoff_factor: Optional[float] = DEFAULT_BACKOFF_FACTOR, + timeout: Optional[Union[float, Tuple[float, float]]] = None, + **kwargs: Any, + ) -> None: + """ + Initialize the BaseClient. + + Args: + endpoint (str): The URL endpoint to send the request. + auth (Optional[Any]): The authentication signer for the requests. + retries (Optional[int]): The number of retry attempts for the request. + backoff_factor (Optional[float]): The factor to determine the delay between retries. + timeout (Optional[Union[float, Tuple[float, float]]]): The timeout setting for the HTTP request. + **kwargs: Additional keyword arguments. + """ + self.endpoint = endpoint + self.retries = retries or DEFAULT_RETRIES + self.backoff_factor = backoff_factor or DEFAULT_BACKOFF_FACTOR + self.timeout = timeout or TIMEOUT + self.kwargs = kwargs + + # Validate auth object + auth = auth or authutil.default_signer() + if not callable(auth.get("signer")): + raise ValueError("Auth object must have a 'signer' callable attribute.") + self.auth = OCIAuth(auth["signer"]) + + logger.debug( + f"Initialized {self.__class__.__name__} with endpoint={self.endpoint}, " + f"retries={self.retries}, backoff_factor={self.backoff_factor}, timeout={self.timeout}" + ) + + def _parse_streaming_line( + self, line: Union[bytes, str] + ) -> Optional[Dict[str, Any]]: + """ + Parse a single line from the streaming response. + + Args: + line (Union[bytes, str]): A line of the response in bytes or string format. 
+ + Returns: + Optional[Dict[str, Any]]: Parsed JSON object, or None if the line is to be ignored. + + Raises: + Exception: Raised if the line contains an error object. + json.JSONDecodeError: Raised if the line cannot be decoded as JSON. + """ + logger.debug(f"Parsing streaming line: {line}") + + if isinstance(line, bytes): + line = line.decode(DEFAULT_ENCODING) + + line = line.strip() + + if line.lower().startswith("data:"): + line = line[5:].lstrip() + + if not line or line.startswith("[DONE]"): + logger.debug("Received end of stream signal or empty line.") + return None + + try: + json_line = json.loads(line) + logger.debug(f"Parsed JSON line: {json_line}") + except json.JSONDecodeError as e: + logger.debug(f"Error decoding JSON from line: {line}") + raise json.JSONDecodeError( + f"Error decoding JSON from line: {e!s}", e.doc, e.pos + ) from e + + if json_line.get("object") == "error": + # Raise an error for error objects in the stream + error_message = json_line.get("message", "Unknown error") + logger.debug(f"Error in streaming response: {error_message}") + raise Exception(f"Error in streaming response: {error_message}") + + return json_line + + def _prepare_headers( + self, + stream: bool, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + """ + Construct and return the headers for a request. + + Args: + stream (bool): Whether to use streaming for the response. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, str]: The prepared headers. + """ + default_headers = { + "Content-Type": "application/json", + "Accept": "text/event-stream" if stream else "application/json", + } + if stream: + default_headers["enable-streaming"] = "true" + if headers: + default_headers.update(headers) + + logger.debug(f"Prepared headers: {default_headers}") + return default_headers + + +class Client(BaseClient): + """ + Synchronous HTTP client for invoking models with support for request and streaming APIs. + """ + + def __init__(self, *args, **kwargs) -> None: + """ + Initialize the Client. + + Args: + *args: Positional arguments forwarded to BaseClient. + **kwargs: Keyword arguments forwarded to BaseClient. + """ + super().__init__(*args, **kwargs) + self._client = httpx.Client(timeout=self.timeout) + + def is_closed(self) -> bool: + return self._client.is_closed + + def close(self) -> None: + """Close the underlying HTTPX client.""" + self._client.close() + + def __enter__(self: _T) -> _T: # noqa: PYI019 + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]] = None, + exc: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: + self.close() + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + @_retry_decorator + def _request( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """ + Send a POST request to the configured endpoint with retry and error handling. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, Any]: Decoded JSON response. + + Raises: + ExtendedRequestException: Raised when the request fails. 
+ """ + logger.debug(f"Starting synchronous request with payload: {payload}") + try: + response = self._client.post( + self.endpoint, + headers=self._prepare_headers(stream=False, headers=headers), + auth=self.auth, + json=payload, + ) + logger.debug(f"Received response with status code: {response.status_code}") + response.raise_for_status() + json_response = response.json() + logger.debug(f"Response JSON: {json_response}") + return json_response + except Exception as e: + last_exception_text = ( + e.response.text if hasattr(e, "response") and e.response else str(e) + ) + logger.error( + f"Request failed. Error: {e!s}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Request failed: {e!s}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + def _stream( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> Iterator[Mapping[str, Any]]: + """ + Send a POST request expecting a streaming response. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Yields: + Mapping[str, Any]: Decoded JSON response line-by-line. + + Raises: + ExtendedRequestException: Raised when the request fails. + """ + logger.debug(f"Starting synchronous streaming request with payload: {payload}") + last_exception_text = None + + for attempt in range(1, self.retries + 2): # retries + initial attempt + logger.debug(f"Attempt {attempt} for synchronous streaming request.") + try: + with self._client.stream( + "POST", + self.endpoint, + headers=self._prepare_headers(stream=True, headers=headers), + auth=self.auth, + json={**payload, "stream": True}, + ) as response: + try: + logger.debug( + f"Received streaming response with status code: {response.status_code}" + ) + response.raise_for_status() + for line in response.iter_lines(): + if not line: # Skip empty lines + continue + + parsed_line = self._parse_streaming_line(line) + if parsed_line: + logger.debug(f"Yielding parsed line: {parsed_line}") + yield parsed_line + return + except Exception as e: + last_exception_text = ( + e.response.read().decode( + e.response.encoding or DEFAULT_ENCODING + ) + if hasattr(e, "response") and e.response + else str(e) + ) + raise + + except Exception as e: + if attempt <= self.retries and _should_retry_exception(e): + delay = self.backoff_factor * (2 ** (attempt - 1)) + logger.warning( + f"Streaming attempt {attempt} failed: {e}. Retrying in {delay} seconds..." + ) + time.sleep(delay) + else: + logger.error( + f"Streaming request failed. Error: {e!s}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Streaming request failed: {e!s}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + def generate( + self, + prompt: str, + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = True, + ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: + """ + Generate text completion for the given prompt. + + Args: + prompt (str): Input text prompt for the model. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. + + Returns: + Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: A full JSON response or an iterator for streaming responses. 
+ """ + logger.debug(f"Generating text with prompt: {prompt}, stream: {stream}") + payload = {**(payload or {}), "prompt": prompt} + headers = {"route": "/v1/completions", **(headers or {})} + if stream: + return self._stream(payload=payload, headers=headers) + return self._request(payload=payload, headers=headers) + + def chat( + self, + messages: List[Dict[str, Any]], + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = True, + ) -> Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: + """ + Perform a chat interaction with the model. + + Args: + messages (List[Dict[str, Any]]): List of message dictionaries for chat interaction. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. + + Returns: + Union[Dict[str, Any], Iterator[Mapping[str, Any]]]: A full JSON response or an iterator for streaming responses. + """ + logger.debug(f"Starting chat with messages: {messages}, stream: {stream}") + payload = {**(payload or {}), "messages": messages} + headers = {"route": "/v1/chat/completions", **(headers or {})} + if stream: + return self._stream(payload=payload, headers=headers) + return self._request(payload=payload, headers=headers) + + +class AsyncClient(BaseClient): + """ + Asynchronous HTTP client for invoking models with support for request and streaming APIs, including retry logic. + """ + + def __init__(self, *args, **kwargs) -> None: + """ + Initialize the AsyncClient. + + Args: + *args: Positional arguments forwarded to BaseClient. + **kwargs: Keyword arguments forwarded to BaseClient. + """ + super().__init__(*args, **kwargs) + self._client = httpx.AsyncClient(timeout=self.timeout) + + def is_closed(self) -> bool: + return self._client.is_closed + + async def close(self) -> None: + """Close the underlying HTTPX client. + + The client will *not* be usable after this. + """ + await self._client.aclose() + + async def __aenter__(self: _T) -> _T: # noqa: PYI019 + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]] = None, + exc: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: + await self.close() + + def __del__(self) -> None: + try: + if not self._client.is_closed: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + except Exception: + pass + + @_retry_decorator + async def _request( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + """ + Send a POST request to the configured endpoint with retry and error handling. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Returns: + Dict[str, Any]: Decoded JSON response. + + Raises: + ExtendedRequestException: Raised when the request fails. 
+ """ + logger.debug(f"Starting asynchronous request with payload: {payload}") + try: + response = await self._client.post( + self.endpoint, + headers=self._prepare_headers(stream=False, headers=headers), + auth=self.auth, + json=payload, + ) + logger.debug(f"Received response with status code: {response.status_code}") + response.raise_for_status() + json_response = response.json() + logger.debug(f"Response JSON: {json_response}") + return json_response + except Exception as e: + last_exception_text = ( + e.response.text if hasattr(e, "response") and e.response else str(e) + ) + logger.error( + f"Request failed. Error: {e!s}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Request failed: {e!s}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + async def _stream( + self, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> AsyncIterator[Mapping[str, Any]]: + """ + Send a POST request expecting a streaming response with retry logic. + + Args: + payload (Dict[str, Any]): Parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + + Yields: + Mapping[str, Any]: Decoded JSON response line-by-line. + + Raises: + ExtendedRequestException: Raised when the request fails. + """ + logger.debug(f"Starting asynchronous streaming request with payload: {payload}") + last_exception_text = None + for attempt in range(1, self.retries + 2): # retries + initial attempt + logger.debug(f"Attempt {attempt} for asynchronous streaming request.") + try: + async with self._client.stream( + "POST", + self.endpoint, + headers=self._prepare_headers(stream=True, headers=headers), + auth=self.auth, + json={**payload, "stream": True}, + ) as response: + try: + logger.debug( + f"Received streaming response with status code: {response.status_code}" + ) + response.raise_for_status() + async for line in response.aiter_lines(): + if not line: # Skip empty lines + continue + parsed_line = self._parse_streaming_line(line) + if parsed_line: + logger.debug(f"Yielding parsed line: {parsed_line}") + yield parsed_line + return + except Exception as e: + if hasattr(e, "response") and e.response: + content = await e.response.aread() + last_exception_text = content.decode( + e.response.encoding or DEFAULT_ENCODING + ) + raise + except Exception as e: + if attempt <= self.retries and _should_retry_exception(e): + delay = self.backoff_factor * (2 ** (attempt - 1)) + logger.warning( + f"Streaming attempt {attempt} failed: {e}. Retrying in {delay} seconds..." + ) + await asyncio.sleep(delay) + else: + logger.error( + f"Streaming request failed. Error: {e!s}. Details: {last_exception_text}" + ) + raise ExtendedRequestException( + f"Streaming request failed: {e!s}. Details: {last_exception_text}", + e, + last_exception_text, + ) from e + + async def generate( + self, + prompt: str, + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = False, + ) -> Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: + """ + Generate text completion for the given prompt. + + Args: + prompt (str): Input text prompt for the model. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. 
+ + Returns: + Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: A full JSON response or an async iterator for streaming responses. + """ + logger.debug(f"Generating text with prompt: {prompt}, stream: {stream}") + payload = {**(payload or {}), "prompt": prompt} + headers = {"route": "/v1/completions", **(headers or {})} + if stream: + return self._stream(payload=payload, headers=headers) + return await self._request(payload=payload, headers=headers) + + async def chat( + self, + messages: List[Dict[str, Any]], + payload: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + stream: bool = False, + ) -> Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: + """ + Perform a chat interaction with the model. + + Args: + messages (List[Dict[str, Any]]): List of message dictionaries for chat interaction. + payload (Optional[Dict[str, Any]]): Additional parameters for the request payload. + headers (Optional[Dict[str, str]]): HTTP headers to include in the request. + stream (bool): Whether to use streaming for the response. + + Returns: + Union[Dict[str, Any], AsyncIterator[Mapping[str, Any]]]: A full JSON response or an async iterator for streaming responses. + """ + logger.debug(f"Starting chat with messages: {messages}, stream: {stream}") + payload = {**(payload or {}), "messages": messages} + headers = {"route": "/v1/chat/completions", **(headers or {})} + if stream: + return self._stream(payload=payload, headers=headers) + return await self._request(payload=payload, headers=headers) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py new file mode 100644 index 0000000000000..8201d416192c6 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/llama_index/llms/oci_data_science/utils.py @@ -0,0 +1,231 @@ +import logging +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +from llama_index.core.base.llms.types import ChatMessage, LogProb +from packaging import version + +MIN_ADS_VERSION = "2.12.6" +SUPPORTED_TOOL_CHOICES = ["none", "auto", "required"] +DEFAULT_TOOL_CHOICE = "auto" + +logger = logging.getLogger(__name__) + + +class UnsupportedOracleAdsVersionError(Exception): + """Custom exception for unsupported `oracle-ads` versions. + + Attributes: + current_version: The installed version of `oracle-ads`. + required_version: The minimum required version of `oracle-ads`. + """ + + def __init__(self, current_version: str, required_version: str): + super().__init__( + f"The `oracle-ads` version {current_version} currently installed is incompatible with " + "the `llama-index-llms-oci-data-science` version in use. To resolve this issue, " + f"please upgrade to `oracle-ads:{required_version}` or later using the " + "command: `pip install oracle-ads -U`" + ) + + +def _validate_dependency(func: Callable[..., Any]) -> Callable[..., Any]: + """Decorator to validate the presence and version of `oracle-ads`. + + This decorator checks that `oracle-ads` is installed and that its version meets + the minimum requirement. If not, it raises an error. + + Args: + func: The function to wrap with the dependency validation. + + Returns: + The wrapped function. + + Raises: + ImportError: If `oracle-ads` is not installed. + UnsupportedOracleAdsVersionError: If the installed version is below the required version. 
+ """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + try: + from ads import __version__ as ads_version + + if version.parse(ads_version) < version.parse(MIN_ADS_VERSION): + raise UnsupportedOracleAdsVersionError(ads_version, MIN_ADS_VERSION) + except ImportError as ex: + raise ImportError( + "Could not import `oracle-ads` Python package. " + "Please install it with `pip install oracle-ads`." + ) from ex + return func(*args, **kwargs) + + return wrapper + + +def _to_message_dicts( + messages: Sequence[ChatMessage], drop_none: bool = False +) -> List[Dict[str, Any]]: + """Convert a sequence of ChatMessage objects to a list of dictionaries. + + Args: + messages: The messages to convert. + drop_none: Whether to drop keys with `None` values. Defaults to False. + + Returns: + A list of message dictionaries. + """ + message_dicts = [] + for message in messages: + message_dict = { + "role": message.role.value, + "content": message.content, + **message.additional_kwargs, + } + if drop_none: + message_dict = {k: v for k, v in message_dict.items() if v is not None} + message_dicts.append(message_dict) + return message_dicts + + +def _from_completion_logprobs_dict( + completion_logprobs_dict: Dict[str, Any] +) -> List[List[LogProb]]: + """Convert completion logprobs to a list of generic LogProb objects. + + Args: + completion_logprobs_dict: The completion logprobs to convert. + + Returns: + A list of lists of LogProb objects. + """ + return [ + [ + LogProb(token=token, logprob=logprob, bytes=[]) + for token, logprob in logprob_dict.items() + ] + for logprob_dict in completion_logprobs_dict.get("top_logprobs", []) + ] + + +def _from_token_logprob_dicts( + token_logprob_dicts: Sequence[Dict[str, Any]] +) -> List[List[LogProb]]: + """Convert a sequence of token logprob dictionaries to a list of LogProb objects. + + Args: + token_logprob_dicts: The token logprob dictionaries to convert. + + Returns: + A list of lists of LogProb objects. + + Raises: + Warning: Logs a warning if an error occurs while parsing token logprobs. + """ + result = [] + for token_logprob_dict in token_logprob_dicts: + try: + logprobs_list = [ + LogProb( + token=el.get("token"), + logprob=el.get("logprob"), + bytes=el.get("bytes") or [], + ) + for el in token_logprob_dict.get("top_logprobs", []) + ] + if logprobs_list: + result.append(logprobs_list) + except Exception as e: + logger.warning( + "Error occurred in attempt to parse token logprob. " + f"Details: {e}. Src: {token_logprob_dict}" + ) + return result + + +def _from_message_dict(message_dict: Dict[str, Any]) -> ChatMessage: + """Convert a message dictionary to a ChatMessage object. + + Args: + message_dict: The message dictionary. + + Returns: + A ChatMessage object representing the given dictionary. + """ + return ChatMessage( + role=message_dict.get("role"), + content=message_dict.get("content"), + additional_kwargs={"tool_calls": message_dict.get("tool_calls", [])}, + ) + + +def _get_response_token_counts(raw_response: Dict[str, Any]) -> Dict[str, int]: + """Extract token usage information from the response. + + Args: + raw_response: The raw response containing token usage information. + + Returns: + A dictionary containing token counts, or an empty dictionary if usage info is not found. 
+ """ + if not raw_response.get("usage"): + return {} + + return { + "prompt_tokens": raw_response["usage"].get("prompt_tokens", 0), + "completion_tokens": raw_response["usage"].get("completion_tokens", 0), + "total_tokens": raw_response["usage"].get("total_tokens", 0), + } + + +def _update_tool_calls( + tool_calls: List[Dict[str, Any]], tool_calls_delta: Optional[List[Dict[str, Any]]] +) -> List[Dict[str, Any]]: + """Update the tool calls using delta objects received from stream chunks. + + Args: + tool_calls: The list of existing tool calls. + tool_calls_delta: The delta updates for the tool calls (if any). + + Returns: + The updated list of tool calls. + """ + if not tool_calls_delta: + return tool_calls + + delta_call = tool_calls_delta[0] + if not tool_calls or tool_calls[-1].get("index") != delta_call.get("index"): + tool_calls.append(delta_call) + else: + latest_call = tool_calls[-1] + latest_function = latest_call.setdefault("function", {}) + delta_function = delta_call.get("function", {}) + + latest_function["arguments"] = latest_function.get( + "arguments", "" + ) + delta_function.get("arguments", "") + latest_function["name"] = latest_function.get("name", "") + delta_function.get( + "name", "" + ) + latest_call["id"] = latest_call.get("id", "") + delta_call.get("id", "") + + return tool_calls + + +def _resolve_tool_choice( + tool_choice: Union[str, dict] = DEFAULT_TOOL_CHOICE +) -> Union[str, dict]: + """Resolve the tool choice into a string or a dictionary. + + If the tool_choice is a string that is not in SUPPORTED_TOOL_CHOICES, a dictionary + representing a function call is returned. + + Args: + tool_choice: The desired tool choice, which can be a string or a dictionary. Defaults to "auto". + + Returns: + Either the original tool_choice if valid or a dictionary representing a function call. 
+ """ + if isinstance(tool_choice, str) and tool_choice not in SUPPORTED_TOOL_CHOICES: + return {"type": "function", "function": {"name": tool_choice}} + return tool_choice diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml new file mode 100644 index 0000000000000..8b9078ad59c14 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/pyproject.toml @@ -0,0 +1,64 @@ +[build-system] +build-backend = "poetry.core.masonry.api" +requires = ["poetry-core"] + +[tool.codespell] +check-filenames = true +check-hidden = true +skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb" + +[tool.llamahub] +contains_example = false +import_path = "llama_index.llms.oci_data_science" + +[tool.llamahub.class_authors] +OCIDataScience = "mrdzurb" + +[tool.mypy] +disallow_untyped_defs = true +exclude = ["_static", "build", "examples", "notebooks", "venv"] +ignore_missing_imports = true +python_version = "3.8" + +[tool.poetry] +authors = ["Dmitrii Cherkasov "] +description = "llama-index llms OCI Data Science integration" +exclude = ["**/BUILD"] +license = "MIT" +name = "llama-index-llms-oci-data-science" +readme = "README.md" +version = "0.1.0" + +[tool.poetry.dependencies] +python = ">=3.9,<4.0" +oracle-ads = ">=2.12.6" +llama-index-core = "^0.12.0" + +[tool.poetry.group.dev.dependencies] +ipython = "8.10.0" +jupyter = "^1.0.0" +mypy = "0.991" +pre-commit = "3.2.0" +pylint = "2.15.10" +pytest = "7.2.1" +pytest-asyncio = ">=0.24.0" +pytest-mock = "3.11.1" +ruff = "0.0.292" +tree-sitter-languages = "^1.8.0" +types-Deprecated = ">=0.1.0" +types-PyYAML = "^6.0.12.12" +types-protobuf = "^4.24.0.4" +types-redis = "4.5.5.0" +types-requests = "2.28.11.8" +types-setuptools = "67.1.0.0" + +[tool.poetry.group.dev.dependencies.black] +extras = ["jupyter"] +version = "<=23.9.1,>=23.7.0" + +[tool.poetry.group.dev.dependencies.codespell] +extras = ["toml"] +version = ">=v2.2.6" + +[[tool.poetry.packages]] +include = "llama_index/" diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD new file mode 100644 index 0000000000000..b9078077b4d11 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/BUILD @@ -0,0 +1,3 @@ +python_tests( + interpreter_constraints=["==3.9.*"] +) diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/__init__.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py new file mode 100644 index 0000000000000..c916df6e89ef6 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_llms_oci_data_science.py @@ -0,0 +1,344 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from ads.common import auth as authutil +from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole +from llama_index.core.callbacks import CallbackManager +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.tools.types import BaseTool +from llama_index.llms.oci_data_science import OCIDataScience +from 
llama_index.llms.oci_data_science.base import OCIDataScience +from llama_index.llms.oci_data_science.client import AsyncClient, Client + + +def test_embedding_class(): + names_of_base_classes = [b.__name__ for b in OCIDataScience.__mro__] + assert FunctionCallingLLM.__name__ in names_of_base_classes + + +@pytest.fixture() +def llm(): + endpoint = "https://example.com/api" + auth = {"signer": Mock()} + model = "odsc-llm" + temperature = 0.7 + max_tokens = 100 + timeout = 60 + max_retries = 3 + additional_kwargs = {"top_p": 0.9} + callback_manager = CallbackManager([]) + + with patch.object(authutil, "default_signer", return_value=auth): + llm_instance = OCIDataScience( + endpoint=endpoint, + auth=auth, + model=model, + temperature=temperature, + max_tokens=max_tokens, + timeout=timeout, + max_retries=max_retries, + additional_kwargs=additional_kwargs, + callback_manager=callback_manager, + ) + # Mock the client + llm_instance._client = Mock(spec=Client) + llm_instance._async_client = AsyncMock(spec=AsyncClient) + return llm_instance + + +def test_complete_success(llm): + prompt = "What is the capital of France?" + response_data = { + "choices": [ + { + "text": "The capital of France is Paris.", + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12, + }, + } + # Mock the client's generate method + llm.client.generate.return_value = response_data + + response = llm.complete(prompt) + + # Assertions + llm.client.generate.assert_called_once() + assert response.text == "The capital of France is Paris." + assert response.additional_kwargs["total_tokens"] == 12 + + +def test_complete_invalid_response(llm): + prompt = "What is the capital of France?" + response_data = {} # Empty response + llm.client.generate.return_value = response_data + + with pytest.raises(ValueError): + llm.complete(prompt) + + +def test_chat_success(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + response_data = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Why did the chicken cross the road? To get to the other side!", + }, + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 15, + "total_tokens": 25, + }, + } + llm.client.chat.return_value = response_data + + response = llm.chat(messages) + + llm.client.chat.assert_called_once() + assert ( + response.message.content + == "Why did the chicken cross the road? To get to the other side!" + ) + assert response.additional_kwargs["total_tokens"] == 25 + + +def test_stream_complete(llm): + prompt = "Once upon a time" + # Mock the client's generate method to return an iterator + response_data = iter( + [ + {"choices": [{"text": "Once"}], "usage": {}}, + {"choices": [{"text": " upon"}], "usage": {}}, + {"choices": [{"text": " a"}], "usage": {}}, + {"choices": [{"text": " time."}], "usage": {}}, + ] + ) + llm.client.generate.return_value = response_data + + responses = list(llm.stream_complete(prompt)) + + llm.client.generate.assert_called_once() + assert len(responses) == 4 + assert responses[0].delta == "Once" + assert responses[1].delta == " upon" + assert responses[2].delta == " a" + assert responses[3].delta == " time." + assert responses[-1].text == "Once upon a time." 
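+
+
+def test_update_tool_calls_merges_streamed_deltas():
+    """Illustrative sketch: checks that the `_update_tool_calls` helper from
+    `utils.py` (used by `stream_chat`/`astream_chat`) accumulates streamed
+    tool-call deltas. The chunk payloads below are made-up examples."""
+    from llama_index.llms.oci_data_science.utils import _update_tool_calls
+
+    tool_calls = []
+    # The first chunk introduces the call with its id, name, and the start of the arguments.
+    tool_calls = _update_tool_calls(
+        tool_calls,
+        [
+            {
+                "index": 0,
+                "id": "call_1",
+                "type": "function",
+                "function": {"name": "multiply", "arguments": '{"a": 2,'},
+            }
+        ],
+    )
+    # A later chunk for the same index only appends to the arguments string.
+    tool_calls = _update_tool_calls(
+        tool_calls,
+        [{"index": 0, "id": "", "function": {"name": "", "arguments": ' "b": 3}'}}],
+    )
+
+    assert len(tool_calls) == 1
+    assert tool_calls[0]["function"]["name"] == "multiply"
+    assert tool_calls[0]["function"]["arguments"] == '{"a": 2, "b": 3}'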
+ + +def test_stream_chat(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + response_data = iter( + [ + {"choices": [{"delta": {"content": "Why"}}], "usage": {}}, + {"choices": [{"delta": {"content": " did"}}], "usage": {}}, + {"choices": [{"delta": {"content": " the"}}], "usage": {}}, + { + "choices": [{"delta": {"content": " chicken cross the road?"}}], + "usage": {}, + }, + ] + ) + llm.client.chat.return_value = response_data + + responses = list(llm.stream_chat(messages)) + + llm.client.chat.assert_called_once() + assert len(responses) == 4 + content = "".join([r.delta for r in responses]) + assert content == "Why did the chicken cross the road?" + assert responses[-1].message.content == content + + +def test_prepare_chat_with_tools(llm): + # Mock tools + tool1 = Mock(spec=BaseTool) + tool1.metadata.to_openai_tool.return_value = { + "name": "tool1", + "type": "function", + "function": { + "name": "tool1", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + tool2 = Mock(spec=BaseTool) + tool2.metadata.to_openai_tool.return_value = { + "name": "tool2", + "type": "function", + "function": { + "name": "tool2", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + + user_msg = "Calculate the result of 2 + 2." + chat_history = [ChatMessage(role=MessageRole.USER, content="Previous message")] + + result = llm._prepare_chat_with_tools( + tools=[tool1, tool2], + user_msg=user_msg, + chat_history=chat_history, + ) + + # Check that 'function' key has been updated as expected + for tool_spec in result["tools"]: + assert "function" in tool_spec + assert "parameters" in tool_spec["function"] + assert tool_spec["function"]["parameters"]["additionalProperties"] is False + + assert "messages" in result + assert "tools" in result + assert len(result["tools"]) == 2 + assert result["messages"][-1].content == user_msg + + +def test_get_tool_calls_from_response(llm): + tool_call = { + "type": "function", + "id": "123", + "function": { + "name": "multiply", + "arguments": '{"a": 2, "b": 3}', + }, + } + response = ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content="", + additional_kwargs={"tool_calls": [tool_call]}, + ), + raw={}, + ) + + tool_selections = llm.get_tool_calls_from_response(response) + + assert len(tool_selections) == 1 + assert tool_selections[0].tool_name == "multiply" + assert tool_selections[0].tool_kwargs == {"a": 2, "b": 3} + + +@pytest.mark.asyncio() +async def test_acomplete_success(llm): + prompt = "What is the capital of France?" + response_data = { + "choices": [ + { + "text": "The capital of France is Paris.", + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12, + }, + } + llm.async_client.generate.return_value = response_data + + response = await llm.acomplete(prompt) + + llm.async_client.generate.assert_called_once() + assert response.text == "The capital of France is Paris." 
+ assert response.additional_kwargs["total_tokens"] == 12 + + +@pytest.mark.asyncio() +async def test_astream_complete(llm): + prompt = "Once upon a time" + + async def async_generator(): + response_data = [ + {"choices": [{"text": "Once"}], "usage": {}}, + {"choices": [{"text": " upon"}], "usage": {}}, + {"choices": [{"text": " a"}], "usage": {}}, + {"choices": [{"text": " time."}], "usage": {}}, + ] + for item in response_data: + yield item + + llm.async_client.generate.return_value = async_generator() + + responses = [] + async for response in await llm.astream_complete(prompt): + responses.append(response) + + llm.async_client.generate.assert_called_once() + assert len(responses) == 4 + assert responses[0].delta == "Once" + assert responses[-1].text == "Once upon a time." + + +@pytest.mark.asyncio() +async def test_achat_success(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + response_data = { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Why did the chicken cross the road? To get to the other side!", + }, + "logprobs": {}, + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 15, + "total_tokens": 25, + }, + } + llm.async_client.chat.return_value = response_data + + response = await llm.achat(messages) + + llm.async_client.chat.assert_called_once() + assert ( + response.message.content + == "Why did the chicken cross the road? To get to the other side!" + ) + assert response.additional_kwargs["total_tokens"] == 25 + + +@pytest.mark.asyncio() +async def test_astream_chat(llm): + messages = [ChatMessage(role=MessageRole.USER, content="Tell me a joke.")] + + async def async_generator(): + response_data = [ + {"choices": [{"delta": {"content": "Why"}}], "usage": {}}, + {"choices": [{"delta": {"content": " did"}}], "usage": {}}, + {"choices": [{"delta": {"content": " the"}}], "usage": {}}, + { + "choices": [{"delta": {"content": " chicken cross the road?"}}], + "usage": {}, + }, + ] + for item in response_data: + yield item + + llm.async_client.chat.return_value = async_generator() + + responses = [] + async for response in await llm.astream_chat(messages): + responses.append(response) + + llm.async_client.chat.assert_called_once() + assert len(responses) == 4 + content = "".join([r.delta for r in responses]) + assert content == "Why did the chicken cross the road?" 
+ assert responses[-1].message.content == content diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py new file mode 100644 index 0000000000000..b926c4039eb9b --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_client.py @@ -0,0 +1,693 @@ +import json +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import httpx +import pytest +from ads.common import auth as authutil +from llama_index.llms.oci_data_science.client import ( + AsyncClient, + BaseClient, + Client, + ExtendedRequestException, + OCIAuth, + _create_retry_decorator, + _retry_decorator, + _should_retry_exception, +) + + +class TestOCIAuth: + """Unit tests for OCIAuth class.""" + + def setup_method(self): + self.signer_mock = Mock() + self.oci_auth = OCIAuth(self.signer_mock) + + def test_auth_flow(self): + """Ensures that the auth_flow signs the request correctly.""" + request = httpx.Request("POST", "https://example.com") + prepared_request_mock = Mock() + prepared_request_mock.headers = {"Authorization": "Signed"} + with patch("requests.Request") as mock_requests_request: + mock_requests_request.return_value = Mock() + mock_requests_request.return_value.prepare.return_value = ( + prepared_request_mock + ) + self.signer_mock.do_request_sign = Mock() + + list(self.oci_auth.auth_flow(request)) + + self.signer_mock.do_request_sign.assert_called() + assert request.headers.get("Authorization") == "Signed" + + +class TestExtendedRequestException: + """Unit tests for ExtendedRequestException.""" + + def test_exception_attributes(self): + """Ensures the exception stores the correct attributes.""" + original_exception = Exception("Original error") + response_text = "Error response text" + message = "Extended error message" + + exception = ExtendedRequestException(message, original_exception, response_text) + + assert str(exception) == message + assert exception.original_exception == original_exception + assert exception.response_text == response_text + + +class TestShouldRetryException: + """Unit tests for _should_retry_exception function.""" + + def test_http_status_error_in_force_list(self): + """Ensures it returns True for HTTPStatusError with status in STATUS_FORCE_LIST.""" + response_mock = Mock() + response_mock.status_code = 500 + original_exception = httpx.HTTPStatusError( + "Error", request=None, response=response_mock + ) + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is True + + def test_http_status_error_not_in_force_list(self): + """Ensures it returns False for HTTPStatusError with status not in STATUS_FORCE_LIST.""" + response_mock = Mock() + response_mock.status_code = 404 + original_exception = httpx.HTTPStatusError( + "Error", request=None, response=response_mock + ) + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is False + + def test_http_request_error(self): + """Ensures it returns True for RequestError.""" + original_exception = httpx.RequestError("Error") + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is True + + 
def test_other_exception(self): + """Ensures it returns False for other exceptions.""" + original_exception = Exception("Some other error") + exception = ExtendedRequestException( + "Message", original_exception, "Response text" + ) + + result = _should_retry_exception(exception) + assert result is False + + +class TestCreateRetryDecorator: + """Unit tests for _create_retry_decorator function.""" + + def test_create_retry_decorator(self): + """Ensures the retry decorator is created with correct parameters.""" + max_retries = 5 + backoff_factor = 2 + random_exponential = False + stop_after_delay_seconds = 100 + min_seconds = 1 + max_seconds = 10 + + retry_decorator = _create_retry_decorator( + max_retries, + backoff_factor, + random_exponential, + stop_after_delay_seconds, + min_seconds, + max_seconds, + ) + + assert callable(retry_decorator) + + +class TestRetryDecorator: + """Unit tests for _retry_decorator function.""" + + def test_retry_decorator_no_retries(self): + """Ensures the function is called directly when retries is 0.""" + + class TestClass: + retries = 0 + backoff_factor = 1 + timeout = 10 + + @_retry_decorator + def test_method(self): + return "Success" + + test_instance = TestClass() + result = test_instance.test_method() + assert result == "Success" + + def test_retry_decorator_with_retries(self): + """Ensures the function retries upon exception.""" + + class TestClass: + retries = 3 + backoff_factor = 0.1 + timeout = 10 + + call_count = 0 + + @_retry_decorator + def test_method(self): + self.call_count += 1 + if self.call_count < 3: + raise ExtendedRequestException( + "Error", + original_exception=httpx.RequestError("Error"), + response_text="test", + ) + return "Success" + + test_instance = TestClass() + result = test_instance.test_method() + assert result == "Success" + assert test_instance.call_count == 3 + + def test_retry_decorator_exceeds_retries(self): + """Ensures the function raises exception after exceeding retries.""" + + class TestClass: + retries = 3 + backoff_factor = 0.1 + timeout = 10 + + call_count = 0 + + @_retry_decorator + def test_method(self): + self.call_count += 1 + raise ExtendedRequestException( + "Error", + original_exception=httpx.RequestError("Error"), + response_text="test", + ) + + test_instance = TestClass() + with pytest.raises(ExtendedRequestException): + test_instance.test_method() + assert test_instance.call_count == 3 # initial attempt + 2 retries + + +class TestBaseClient: + """Unit tests for BaseClient class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 3 + self.backoff_factor = 2 + self.timeout = 30 + + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + self.base_client = BaseClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + + def test_init(self): + """Ensures that the client is initialized correctly.""" + assert self.base_client.endpoint == self.endpoint + assert self.base_client.retries == self.retries + assert self.base_client.backoff_factor == self.backoff_factor + assert self.base_client.timeout == self.timeout + assert isinstance(self.base_client.auth, OCIAuth) + + def test_init_default_auth(self): + """Ensures that default auth is used when auth is None.""" + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + client = BaseClient(endpoint=self.endpoint) + assert client.auth is not None + + def 
test_init_invalid_auth(self): + """Ensures that ValueError is raised when auth signer is invalid.""" + with pytest.raises(ValueError): + BaseClient(endpoint=self.endpoint, auth={"signer": None}) + + def test_prepare_headers(self): + """Ensures that headers are prepared correctly.""" + headers = {"Custom-Header": "Value"} + result = self.base_client._prepare_headers(stream=False, headers=headers) + expected_headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Custom-Header": "Value", + } + assert result == expected_headers + + def test_prepare_headers_stream(self): + """Ensures that headers are prepared correctly for streaming.""" + headers = {"Custom-Header": "Value"} + result = self.base_client._prepare_headers(stream=True, headers=headers) + expected_headers = { + "Content-Type": "application/json", + "Accept": "text/event-stream", + "enable-streaming": "true", + "Custom-Header": "Value", + } + assert result == expected_headers + + def test_parse_streaming_line_valid(self): + """Ensures that a valid streaming line is parsed correctly.""" + line = 'data: {"key": "value"}' + result = self.base_client._parse_streaming_line(line) + assert result == {"key": "value"} + + def test_parse_streaming_line_invalid_json(self): + """Ensures that JSONDecodeError is raised for invalid JSON.""" + line = "data: invalid json" + with pytest.raises(json.JSONDecodeError): + self.base_client._parse_streaming_line(line) + + def test_parse_streaming_line_empty(self): + """Ensures that None is returned for empty or end-of-stream lines.""" + line = "" + result = self.base_client._parse_streaming_line(line) + assert result is None + + line = "[DONE]" + result = self.base_client._parse_streaming_line(line) + assert result is None + + def test_parse_streaming_line_error_object(self): + """Ensures that an exception is raised for error objects in the stream.""" + line = 'data: {"object": "error", "message": "Error message"}' + with pytest.raises(Exception) as exc_info: + self.base_client._parse_streaming_line(line) + assert "Error in streaming response: Error message" in str(exc_info.value) + + +class TestClient: + """Unit tests for Client class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 2 + self.backoff_factor = 0.1 + self.timeout = 10 + + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + self.client = Client( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + # Mock the internal HTTPX client + self.client._client = Mock() + + def test_request_success(self): + """Ensures that _request returns JSON response on success.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = Mock() + response_mock.json.return_value = response_json + response_mock.status_code = 200 + + self.client._client.post.return_value = response_mock + + result = self.client._request(payload) + + assert result == response_json + + def test_request_http_error(self): + """Ensures that _request raises ExtendedRequestException on HTTP error.""" + payload = {"prompt": "Hello"} + response_mock = Mock() + response_mock.status_code = 500 + response_mock.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=None, response=response_mock + ) + response_mock.text = "Internal Server Error" + + self.client._client.post.return_value = response_mock + + with 
pytest.raises(ExtendedRequestException) as exc_info: + self.client._request(payload) + + assert "Request failed" in str(exc_info.value) + assert exc_info.value.response_text == "Internal Server Error" + + def test_stream_success(self): + """Ensures that _stream yields parsed lines on success.""" + payload = {"prompt": "Hello"} + response_mock = Mock() + response_mock.status_code = 200 + response_mock.iter_lines.return_value = [ + b'data: {"key": "value1"}', + b'data: {"key": "value2"}', + b"[DONE]", + ] + # Mock the context manager + stream_cm = MagicMock() + stream_cm.__enter__.return_value = response_mock + self.client._client.stream.return_value = stream_cm + + result = list(self.client._stream(payload)) + + assert result == [{"key": "value1"}, {"key": "value2"}] + + @patch("time.sleep", return_value=None) + def test_stream_retry_on_exception(self, mock_sleep): + """Ensures that _stream retries on exceptions and raises after retries exhausted.""" + payload = {"prompt": "Hello"} + + # Mock the exception to be raised + def side_effect(*args, **kwargs): + raise httpx.RequestError("Connection error") + + # Mock the context manager + self.client._client.stream.side_effect = side_effect + + with pytest.raises(ExtendedRequestException): + list(self.client._stream(payload)) + + assert ( + self.client._client.stream.call_count == self.retries + 1 + ) # initial attempt + retries + + def test_generate_stream(self): + """Ensures that generate method calls _stream when stream=True.""" + payload = {"prompt": "Hello"} + response_mock = Mock() + response_mock.status_code = 200 + response_mock.iter_lines.return_value = [b'data: {"key": "value"}', b"[DONE]"] + # Mock the context manager + stream_cm = MagicMock() + stream_cm.__enter__.return_value = response_mock + self.client._client.stream.return_value = stream_cm + + result = list(self.client.generate(prompt="Hello", stream=True)) + + assert result == [{"key": "value"}] + + def test_generate_request(self): + """Ensures that generate method calls _request when stream=False.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = Mock() + response_mock.json.return_value = response_json + response_mock.status_code = 200 + + self.client._client.post.return_value = response_mock + + result = self.client.generate(prompt="Hello", stream=False) + + assert result == response_json + + def test_chat_stream(self): + """Ensures that chat method calls _stream when stream=True.""" + messages = [{"role": "user", "content": "Hello"}] + response_mock = Mock() + response_mock.status_code = 200 + response_mock.iter_lines.return_value = [b'data: {"key": "value"}', b"[DONE]"] + # Mock the context manager + stream_cm = MagicMock() + stream_cm.__enter__.return_value = response_mock + self.client._client.stream.return_value = stream_cm + + result = list(self.client.chat(messages=messages, stream=True)) + + assert result == [{"key": "value"}] + + def test_chat_request(self): + """Ensures that chat method calls _request when stream=False.""" + messages = [{"role": "user", "content": "Hello"}] + response_json = {"choices": [{"message": {"content": "Hi"}}]} + response_mock = Mock() + response_mock.json.return_value = response_json + response_mock.status_code = 200 + + self.client._client.post.return_value = response_mock + + result = self.client.chat(messages=messages, stream=False) + + assert result == response_json + + def test_close(self): + """Ensures that close method closes the client.""" + self.client._client.close = Mock() + 
self.client.close() + self.client._client.close.assert_called_once() + + def test_is_closed(self): + """Ensures that is_closed returns the client's is_closed status.""" + self.client._client.is_closed = False + assert not self.client.is_closed() + self.client._client.is_closed = True + assert self.client.is_closed() + + def test_context_manager(self): + """Ensures that the client can be used as a context manager.""" + self.client.close = Mock() + with self.client as client_instance: + assert client_instance == self.client + self.client.close.assert_called_once() + + def test_del(self): + """Ensures that __del__ method closes the client.""" + client = Client( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + client.close = Mock() + client.__del__() # Manually invoke __del__ + client.close.assert_called_once() + + +@pytest.mark.asyncio() +class TestAsyncClient: + """Unit tests for AsyncClient class.""" + + def setup_method(self): + self.endpoint = "https://example.com/api" + self.auth_mock = {"signer": Mock()} + self.retries = 2 + self.backoff_factor = 0.1 + self.timeout = 10 + + with patch.object(authutil, "default_signer", return_value=self.auth_mock): + self.client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + # Mock the internal HTTPX client + self.client._client = AsyncMock() + self.client._client.is_closed = False + + def async_iter(self, items): + """Helper function to create an async iterator from a list.""" + + async def generator(): + for item in items: + yield item + + return generator() + + async def test_request_success(self): + """Ensures that _request returns JSON response on success.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = AsyncMock() + response_mock.status_code = 200 + response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + self.client._client.post.return_value = response_mock + result = await self.client._request(payload) + assert await result == response_json + + async def test_request_http_error(self): + """Ensures that _request raises ExtendedRequestException on HTTP error.""" + payload = {"prompt": "Hello"} + response_mock = MagicMock() + response_mock.status_code = 500 + response_mock.raise_for_status.side_effect = httpx.HTTPStatusError( + "Server error", request=None, response=response_mock + ) + response_mock.text = "Internal Server Error" + + self.client._client.post.return_value = response_mock + + with pytest.raises(ExtendedRequestException) as exc_info: + await self.client._request(payload) + + assert "Request failed" in str(exc_info.value) + assert exc_info.value.response_text == "Internal Server Error" + + async def test_stream_success(self): + """Ensures that _stream yields parsed lines on success.""" + payload = {"prompt": "Hello"} + response_mock = MagicMock() + response_mock.status_code = 200 + response_mock.raise_for_status = Mock() + response_mock.aiter_lines.return_value = self.async_iter( + ['data: {"key": "value1"}', 'data: {"key": "value2"}', "[DONE]"] + ) + + # Define an async context manager + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + yield response_mock + + # Mock the stream method to return our context manager + self.client._client.stream = Mock(side_effect=stream_context_manager) + + result = [] + async for 
item in self.client._stream(payload): + result.append(item) + + assert result == [{"key": "value1"}, {"key": "value2"}] + + @patch("asyncio.sleep", return_value=None) + async def test_stream_retry_on_exception(self, mock_sleep): + """Ensures that _stream retries on exceptions and raises after retries exhausted.""" + payload = {"prompt": "Hello"} + + # Define an async context manager that raises an exception + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + raise httpx.RequestError("Connection error") + yield # This is never reached + + # Mock the stream method to use our context manager + self.client._client.stream = Mock(side_effect=stream_context_manager) + + with pytest.raises(ExtendedRequestException): + async for _ in self.client._stream(payload): + pass + + assert ( + self.client._client.stream.call_count == self.retries + 1 + ) # initial attempt + retries + + async def test_generate_stream(self): + """Ensures that generate method calls _stream when stream=True.""" + response_mock = MagicMock() + response_mock.status_code = 200 + response_mock.raise_for_status = Mock() + response_mock.aiter_lines.return_value = self.async_iter( + ['data: {"key": "value"}', "[DONE]"] + ) + + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + yield response_mock + + self.client._client.stream = Mock(side_effect=stream_context_manager) + + result = [] + async for item in await self.client.generate(prompt="Hello", stream=True): + result.append(item) + + assert result == [{"key": "value"}] + + async def test_generate_request(self): + """Ensures that generate method calls _request when stream=False.""" + payload = {"prompt": "Hello"} + response_json = {"choices": [{"text": "Hi"}]} + response_mock = AsyncMock() + response_mock.status_code = 200 + response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + + self.client._client.post.return_value = response_mock + + result = await self.client.generate(prompt="Hello", stream=False) + + assert await result == response_json + + async def test_chat_stream(self): + """Ensures that chat method calls _stream when stream=True.""" + messages = [{"role": "user", "content": "Hello"}] + response_mock = MagicMock() + response_mock.status_code = 200 + response_mock.raise_for_status = Mock() + response_mock.aiter_lines.return_value = self.async_iter( + ['data: {"key": "value"}', "[DONE]"] + ) + + @asynccontextmanager + async def stream_context_manager(*args, **kwargs): + yield response_mock + + self.client._client.stream = Mock(side_effect=stream_context_manager) + + result = [] + async for item in await self.client.chat(messages=messages, stream=True): + result.append(item) + + assert result == [{"key": "value"}] + + async def test_chat_request(self): + """Ensures that chat method calls _request when stream=False.""" + messages = [{"role": "user", "content": "Hello"}] + response_json = {"choices": [{"message": {"content": "Hi"}}]} + response_mock = AsyncMock() + response_mock.status_code = 200 + response_mock.json = AsyncMock(return_value=response_json) + response_mock.raise_for_status = Mock() + + self.client._client.post.return_value = response_mock + + result = await self.client.chat(messages=messages, stream=False) + + assert await result == response_json + + async def test_close(self): + """Ensures that close method closes the client.""" + self.client._client.aclose = AsyncMock() + await self.client.close() + self.client._client.aclose.assert_called_once() + + def 
test_is_closed(self): + """Ensures that is_closed returns the client's is_closed status.""" + self.client._client.is_closed = False + assert not self.client.is_closed() + self.client._client.is_closed = True + assert self.client.is_closed() + + async def test_context_manager(self): + """Ensures that the client can be used as a context manager.""" + self.client.close = AsyncMock() + async with self.client as client_instance: + assert client_instance == self.client + self.client.close.assert_called_once() + + async def test_del(self): + """Ensures that __del__ method closes the client.""" + client = AsyncClient( + endpoint=self.endpoint, + auth=self.auth_mock, + retries=self.retries, + backoff_factor=self.backoff_factor, + timeout=self.timeout, + ) + client.close = AsyncMock() + await client.__aexit__(None, None, None) # Manually invoke __aexit__ + client.close.assert_called_once() diff --git a/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py new file mode 100644 index 0000000000000..7b0c420761ae4 --- /dev/null +++ b/llama-index-integrations/llms/llama-index-llms-oci-data-science/tests/test_oci_data_science_utils.py @@ -0,0 +1,337 @@ +from unittest.mock import patch + +import pytest +from llama_index.core.base.llms.types import ChatMessage, LogProb, MessageRole +from llama_index.llms.oci_data_science.utils import ( + UnsupportedOracleAdsVersionError, + _from_completion_logprobs_dict, + _from_message_dict, + _from_token_logprob_dicts, + _get_response_token_counts, + _resolve_tool_choice, + _to_message_dicts, + _update_tool_calls, + _validate_dependency, +) + + +class TestUnsupportedOracleAdsVersionError: + """Unit tests for UnsupportedOracleAdsVersionError.""" + + def test_exception_message(self): + """Ensures the exception message is formatted correctly.""" + current_version = "2.12.5" + required_version = "2.12.6" + expected_message = ( + f"The `oracle-ads` version {current_version} currently installed is incompatible with " + "the `llama-index-llms-oci-data-science` version in use. 
To resolve this issue, " + f"please upgrade to `oracle-ads:{required_version}` or later using the " + "command: `pip install oracle-ads -U`" + ) + + exception = UnsupportedOracleAdsVersionError(current_version, required_version) + assert str(exception) == expected_message + + +class TestValidateDependency: + """Unit tests for _validate_dependency decorator.""" + + def setup_method(self): + @_validate_dependency + def sample_function(): + return "function executed" + + self.sample_function = sample_function + + @patch("llama_index.llms.oci_data_science.utils.MIN_ADS_VERSION", new="2.12.6") + @patch("ads.__version__", new="2.12.7") + def test_valid_version(self): + """Ensures the function executes when the oracle-ads version is sufficient.""" + result = self.sample_function() + assert result == "function executed" + + @patch("llama_index.llms.oci_data_science.utils.MIN_ADS_VERSION", new="2.12.6") + @patch("ads.__version__", new="2.12.5") + def test_unsupported_version(self): + """Ensures UnsupportedOracleAdsVersionError is raised for insufficient version.""" + with pytest.raises(UnsupportedOracleAdsVersionError) as exc_info: + self.sample_function() + + @patch("llama_index.llms.oci_data_science.utils.MIN_ADS_VERSION", new="2.12.6") + def test_oracle_ads_not_installed(self): + """Ensures ImportError is raised when oracle-ads is not installed.""" + with patch.dict("sys.modules", {"ads": None}): + with pytest.raises(ImportError) as exc_info: + self.sample_function() + assert "Could not import `oracle-ads` Python package." in str( + exc_info.value + ) + + +class TestToMessageDicts: + """Unit tests for _to_message_dicts function.""" + + def test_sequence_conversion(self): + """Ensures a sequence of ChatMessages is converted correctly.""" + messages = [ + ChatMessage(role=MessageRole.USER, content="Hello"), + ChatMessage(role=MessageRole.ASSISTANT, content="Hi there!"), + ] + expected_result = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + result = _to_message_dicts(messages) + assert result == expected_result + + def test_empty_sequence(self): + """Ensures the function works with an empty sequence.""" + messages = [] + expected_result = [] + result = _to_message_dicts(messages) + assert result == expected_result + + def test_drop_none(self): + """Ensures drop_none parameter works correctly for sequences.""" + messages = [ + ChatMessage(role=MessageRole.USER, content=None), + ChatMessage( + role=MessageRole.ASSISTANT, + content="Hi there!", + additional_kwargs={"custom_field": None}, + ), + ] + expected_result = [ + {"role": "user"}, + {"role": "assistant", "content": "Hi there!"}, + ] + result = _to_message_dicts(messages, drop_none=True) + assert result == expected_result + + +class TestFromCompletionLogprobs: + """Unit tests for _from_completion_logprobs_dict function.""" + + def test_conversion(self): + """Ensures completion logprobs are converted correctly.""" + logprobs = { + "tokens": ["Hello", "world"], + "token_logprobs": [-0.1, -0.2], + "top_logprobs": [ + {"Hello": -0.1, "Hi": -1.0}, + {"world": -0.2, "earth": -1.2}, + ], + } + expected_result = [ + [ + LogProb(token="Hello", logprob=-0.1, bytes=[]), + LogProb(token="Hi", logprob=-1.0, bytes=[]), + ], + [ + LogProb(token="world", logprob=-0.2, bytes=[]), + LogProb(token="earth", logprob=-1.2, bytes=[]), + ], + ] + result = _from_completion_logprobs_dict(logprobs) + assert result == expected_result + + def test_empty_logprobs(self): + """Ensures function returns empty list when no 
logprobs are provided.""" + logprobs = {} + expected_result = [] + result = _from_completion_logprobs_dict(logprobs) + assert result == expected_result + + + class TestFromTokenLogprobs: + """Unit tests for _from_token_logprob_dicts function.""" + + def test_conversion(self): + """Ensures multiple token logprobs are converted correctly.""" + token_logprob_dicts = [ + { + "token": "Hello", + "logprob": -0.1, + "top_logprobs": [ + {"token": "Hello", "logprob": -0.1, "bytes": [1, 2, 3]}, + {"token": "Hi", "logprob": -1.0, "bytes": [1, 2, 3]}, + ], + }, + { + "token": "world", + "logprob": -0.2, + "top_logprobs": [ + {"token": "world", "logprob": -0.2, "bytes": [2, 3, 4]}, + {"token": "earth", "logprob": -1.2, "bytes": [2, 3, 4]}, + ], + }, + ] + expected_result = [ + [ + LogProb(token="Hello", logprob=-0.1, bytes=[1, 2, 3]), + LogProb(token="Hi", logprob=-1.0, bytes=[1, 2, 3]), + ], + [ + LogProb(token="world", logprob=-0.2, bytes=[2, 3, 4]), + LogProb(token="earth", logprob=-1.2, bytes=[2, 3, 4]), + ], + ] + result = _from_token_logprob_dicts(token_logprob_dicts) + assert result == expected_result + + def test_empty_input(self): + """Ensures function returns empty list when input is empty.""" + token_logprob_dicts = [] + expected_result = [] + result = _from_token_logprob_dicts(token_logprob_dicts) + assert result == expected_result + + + class TestFromMessage: + """Unit tests for _from_message_dict function.""" + + def test_conversion(self): + """Ensures a message dict is converted to ChatMessage.""" + message_dict = { + "role": "assistant", + "content": "Hello!", + "tool_calls": [{"name": "tool1", "arguments": "arg1"}], + } + expected_result = ChatMessage( + role="assistant", + content="Hello!", + additional_kwargs={"tool_calls": [{"name": "tool1", "arguments": "arg1"}]}, + ) + result = _from_message_dict(message_dict) + assert result == expected_result + + def test_missing_optional_fields(self): + """Ensures function works when optional fields are missing.""" + message_dict = {"role": "user", "content": "Hi!"} + expected_result = ChatMessage( + role="user", content="Hi!", additional_kwargs={"tool_calls": []} + ) + result = _from_message_dict(message_dict) + assert result == expected_result + + + class TestGetResponseTokenCounts: + """Unit tests for _get_response_token_counts function.""" + + def test_with_usage(self): + """Ensures token counts are extracted correctly when usage is present.""" + raw_response = { + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } + } + expected_result = { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + } + result = _get_response_token_counts(raw_response) + assert result == expected_result + + def test_without_usage(self): + """Ensures function returns empty dict when usage is missing.""" + raw_response = {} + expected_result = {} + result = _get_response_token_counts(raw_response) + assert result == expected_result + + def test_missing_token_counts(self): + """Ensures missing token counts default to zero.""" + raw_response = {"usage": {}} + result = _get_response_token_counts(raw_response) + assert result == {} + + raw_response = {"usage": {"prompt_tokens": 10}} + expected_result = { + "prompt_tokens": 10, + "completion_tokens": 0, + "total_tokens": 0, + } + result = _get_response_token_counts(raw_response) + assert result == expected_result + + + class TestUpdateToolCalls: + """Unit tests for _update_tool_calls function.""" + + def test_add_new_call(self): + """Ensures a new tool call is added 
when indices do not match.""" + tool_calls = [{"index": 0, "function": {"name": "tool1", "arguments": "arg1"}}] + tool_calls_delta = [ + {"index": 1, "function": {"name": "tool2", "arguments": "arg2"}} + ] + expected_result = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}}, + {"index": 1, "function": {"name": "tool2", "arguments": "arg2"}}, + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result == expected_result + + def test_update_existing_call(self): + """Ensures the existing tool call is updated when indices match.""" + tool_calls = [{"index": 0, "function": {"name": "tool", "arguments": "arg"}}] + tool_calls_delta = [{"index": 0, "function": {"name": "1", "arguments": "1"}}] + expected_result = [ + { + "index": 0, + "function": {"name": "tool1", "arguments": "arg1"}, + "id": "", + } + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result[0]["function"]["name"] == "tool1" + assert result[0]["function"]["arguments"] == "arg1" + + def test_no_delta(self): + """Ensures the original tool_calls is returned when delta is None.""" + tool_calls = [{"index": 0, "function": {"name": "tool1", "arguments": "arg1"}}] + tool_calls_delta = None + expected_result = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}} + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result == expected_result + + def test_empty_tool_calls(self): + """Ensures tool_calls is initialized when empty.""" + tool_calls = [] + tool_calls_delta = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}} + ] + expected_result = [ + {"index": 0, "function": {"name": "tool1", "arguments": "arg1"}} + ] + result = _update_tool_calls(tool_calls, tool_calls_delta) + assert result == expected_result + + +class TestResolveToolChoice: + """Unit tests for _resolve_tool_choice function.""" + + @pytest.mark.parametrize( + ("input_choice", "expected_output"), + [ + ("auto", "auto"), + ("none", "none"), + ("required", "required"), + ( + "custom_tool", + {"type": "function", "function": {"name": "custom_tool"}}, + ), + ( + {"type": "function", "function": {"name": "custom_tool"}}, + {"type": "function", "function": {"name": "custom_tool"}}, + ), + ], + ) + def test_resolve_tool_choice(self, input_choice, expected_output): + """Ensures tool choices are resolved correctly.""" + result = _resolve_tool_choice(input_choice) + assert result == expected_output
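Taken together, these tests pin down how the utility helpers are meant to compose when handling a streamed chat completion: `_update_tool_calls` merges indexed tool-call deltas (appending new indices and concatenating name/argument fragments for existing ones), `_from_message_dict` lifts the finished message dict into a `ChatMessage`, and `_get_response_token_counts` extracts usage data. The sketch below is editorial and illustrative only, not part of the diff; the delta chunks are invented for the example, and only behaviors asserted in the tests above are assumed.

```python
# Illustrative sketch: combining the tested helpers to post-process a
# hypothetical streamed chat completion. The delta payloads are made up.
from llama_index.llms.oci_data_science.utils import (
    _from_message_dict,
    _get_response_token_counts,
    _update_tool_calls,
)

# Invented delta chunks in the indexed format exercised by TestUpdateToolCalls.
deltas = [
    [{"index": 0, "function": {"name": "get_", "arguments": ""}}],
    [{"index": 0, "function": {"name": "weather", "arguments": '{"city": "Austin"}'}}],
    None,  # a chunk without tool-call deltas leaves the accumulator unchanged
]

tool_calls = []
for delta in deltas:
    # Matching indices are merged (name/argument fragments are concatenated),
    # new indices are appended, and a None delta is a no-op.
    tool_calls = _update_tool_calls(tool_calls, delta)

# Fold the accumulated calls into a ChatMessage, as TestFromMessage shows.
message = _from_message_dict(
    {"role": "assistant", "content": "", "tool_calls": tool_calls}
)
print(message.additional_kwargs["tool_calls"])  # name "get_weather", args '{"city": "Austin"}'

# Token accounting from the raw response, as TestGetResponseTokenCounts shows.
usage = _get_response_token_counts(
    {"usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}}
)
print(usage)  # {'prompt_tokens': 12, 'completion_tokens': 3, 'total_tokens': 15}
```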