From 4c089214883ddb999a78f5a38d08fe2785231553 Mon Sep 17 00:00:00 2001
From: Johnny Bouder
Date: Mon, 9 Sep 2024 13:45:33 -0400
Subject: [PATCH 1/3] Refactor to break out the clients based only on the
 service. Updated methods and notebooks for the client changes.

---
 notebooks/embedding.ipynb  |  4 +--
 notebooks/rag_search.ipynb | 20 ++++++------
 onyxgenai/embed.py         | 58 ++++++++++++++++++++++-------------
 onyxgenai/model.py         | 62 ++++++++++++++------------------------
 4 files changed, 70 insertions(+), 74 deletions(-)

diff --git a/notebooks/embedding.ipynb b/notebooks/embedding.ipynb
index e6082ec..20afcf8 100644
--- a/notebooks/embedding.ipynb
+++ b/notebooks/embedding.ipynb
@@ -55,8 +55,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "client = EmbeddingClient(\"http://embed.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, num_workers=1, collection_name=\"test_collection\")\n",
-    "embeddings = client.embed_text(sentences)\n"
+    "client = EmbeddingClient(\"http://embed.onyx-services\")\n",
+    "embeddings = client.embed_text(sentences, model_name=\"all-MiniLM-L6-v2\", collection_name=\"test_collection\")\n"
    ]
   }
  ],
diff --git a/notebooks/rag_search.ipynb b/notebooks/rag_search.ipynb
index 1e73873..6deecb9 100644
--- a/notebooks/rag_search.ipynb
+++ b/notebooks/rag_search.ipynb
@@ -23,10 +23,8 @@
    "from onyxgenai.embed import EmbeddingClient\n",
    "from onyxgenai.model import ModelClient\n",
    "\n",
-    "collection_name = \"test_collection\"\n",
-    "embedding_client = EmbeddingClient(\"http://embed.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, num_workers=2, collection_name=collection_name)\n",
-    "embedding_model_client = ModelClient(\"http://store.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, replicas=2, deployment_name=\"all-MiniLM-L6-v2\", options={\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})\n",
-    "llm_client = ModelClient(\"http://store.onyx-services\", model_name=\"Mistral-7B-Instruct-v0.3\", model_version=1, replicas=1, deployment_name=\"Mistral-7B-Instruct-v0.3\", options={})"
+    "embedding_client = EmbeddingClient(\"http://embed.onyx-services\")\n",
+    "model_client = ModelClient(\"http://store.onyx-services\")"
    ]
   },
   {
@@ -42,7 +40,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "embedding_model_client.deploy_model()"
+    "embedding_model = model_client.deploy_model(\"all-MiniLM-L6-v2\", 1, 2, {\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})"
    ]
   },
   {
@@ -58,7 +56,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "llm_client.deploy_model()"
+    "language_model = model_client.deploy_model(\"Mistral-7B-Instruct-v0.3\", 1, 1, {})"
    ]
   },
   {
@@ -76,7 +74,7 @@
    "source": [
    "query = \"What is the capital of France?\"\n",
    "data = [query]\n",
-    "embeddings = embedding_model_client.embed_text(data)"
+    "embeddings = model_client.embed_text(data, embedding_model)"
    ]
   },
   {
@@ -92,7 +90,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vector_data = embedding_client.vector_search(embeddings, collection_name)\n",
+    "vector_data = embedding_client.vector_search(embeddings, \"test_collection\")\n",
    "print(vector_data)"
    ]
   },
@@ -125,7 +123,7 @@
    "Answer:\n",
    "\"\"\" # noqa: E501\n",
    "\n",
-    "answer = llm_client.generate_completion(prompt)\n",
+    "answer = model_client.generate_completion(prompt, model_name=\"Mistral-7B-Instruct-v0.3\")\n",
    "print(answer)"
    ]
   },
@@ -142,8 +140,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "embedding_model_client.delete_deployment()\n",
-    "llm_client.delete_deployment()"
+    "model_client.delete_deployment(embedding_model)\n",
"model_client.delete_deployment(language_model)" ] } ], diff --git a/onyxgenai/embed.py b/onyxgenai/embed.py index 5099b58..b1353b0 100644 --- a/onyxgenai/embed.py +++ b/onyxgenai/embed.py @@ -9,27 +9,17 @@ class EmbeddingClient: A client for interacting with the Onyx Embedding Service. Args: svc_url (str): The URL of the Onyx Embedding Service - model_name (str): The name of the model to deploy - model_version (int): The version of the model to deploy - num_workers (int): The number of workers to deploy - collection_name (str): The name of the collection """ def __init__( self, svc_url, - model_name=None, - model_version=1, - num_workers=1, - collection_name=None, ) -> None: self.svc_url = svc_url - self.model_name = model_name - self.model_version = model_version - self.num_workers = num_workers - self.collection_name = collection_name - def _onyx_embed(self, batch, media_type): + def _onyx_embed( + self, batch, media_type, model_name, model_version, num_workers, collection_name + ): if media_type == "text": url = f"{self.svc_url}/embedding/text" elif media_type == "image": @@ -40,10 +30,10 @@ def _onyx_embed(self, batch, media_type): data = { "data": batch, - "model_identifier": self.model_name, - "model_version": self.model_version, - "num_workers": self.num_workers, - "collection_name": self.collection_name, + "model_identifier": model_name, + "model_version": model_version, + "num_workers": num_workers, + "collection_name": collection_name, } response = requests.post(url, json=data) @@ -112,11 +102,24 @@ def batch(self, iterable, batch_size=1): for ndx in range(0, batch_length, batch_size): yield iterable[ndx : min(ndx + batch_size, batch_length)] - def embed_text(self, data: list, batch_size=None, return_results=True): + def embed_text( + self, + data: list, + model_name, + model_version=1, + num_workers=1, + collection_name=None, + batch_size=None, + return_results=True, + ): """ Get the embeddings for the input text Args: data (list): The input text + model_name (str): The name of the model + model_version (int): The version of the model + num_workers (int): The number of workers + collection_name (str): The name of the collection batch_size (int): The size of the batches return_results (bool): Whether to return the results Returns: @@ -128,13 +131,24 @@ def embed_text(self, data: list, batch_size=None, return_results=True): results = [] for b in self.batch(data, batch_size): - result = self._onyx_embed(b, "text") + result = self._onyx_embed( + b, "text", model_name, model_version, num_workers, collection_name + ) if return_results: results.extend(result) return results - def embed_images(self, data: list, batch_size=None, return_results=True): + def embed_images( + self, + data: list, + model_name, + model_version=1, + num_workers=1, + collection_name=None, + batch_size=None, + return_results=True, + ): """ Get the embeddings for the input images Args: @@ -162,7 +176,9 @@ def embed_images(self, data: list, batch_size=None, return_results=True): results = [] for b in self.batch(encoded, batch_size): - result = self._onyx_embed(b, "image") + result = self._onyx_embed( + b, "image", model_name, model_version, num_workers, collection_name + ) if return_results: results.extend(result) diff --git a/onyxgenai/model.py b/onyxgenai/model.py index 6d929fb..d56435b 100644 --- a/onyxgenai/model.py +++ b/onyxgenai/model.py @@ -5,34 +5,13 @@ class ModelClient: """Client for interacting with the Onyx Model Store Service Args: svc_url (str): The URL of the Onyx Model Store Service - model_name 
(str): The name of the model to deploy - model_version (int): The version of the model to deploy - replicas (int): The number of replicas to deploy - deployment_name (str): The name of the deployment - options (dict): The options for deploying the model """ def __init__( self, svc_url, - model_name=None, - model_version=1, - replicas=1, - deployment_name=None, - options=None, ) -> None: self.svc_url = svc_url - self.model_name = model_name - self.model_version = model_version - self.replicas = replicas - self.deployment_name = deployment_name - self.options = options - - def _get_deployment_name(self): - if self.deployment_name: - return self.deployment_name - else: - return self.model_name def _onyx_model_info(self): url = f"{self.svc_url}/info/model_info" @@ -58,10 +37,10 @@ def _onyx_get_deployments(self): print("Failed to get deployment info:", response.status_code, response.text) return None - def _onyx_model_predict(self, data): + def _onyx_model_predict(self, data, model_name): url = f"{self.svc_url}/serve/predict/text" payload = { - "app_name": self._get_deployment_name(), + "app_name": model_name, "data": data, } @@ -79,11 +58,11 @@ def _onyx_model_predict(self, data): return None def _onyx_model_generate( - self, prompt, system_prompt, max_new_tokens, temperature, top_p + self, prompt, system_prompt, model_name, max_new_tokens, temperature, top_p ): url = f"{self.svc_url}/serve/generate/text" payload = { - "app_name": self._get_deployment_name(), + "app_name": model_name, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}, @@ -108,13 +87,13 @@ def _onyx_model_generate( print("Generate Failed:", response.status_code, response.text) return None - def _onyx_model_serve(self): - url = f"{self.svc_url}/serve/deploy/{self.model_name}" + def _onyx_model_serve(self, model_name, model_version, replicas, options): + url = f"{self.svc_url}/serve/deploy/{model_name}" payload = { - "app_name": self._get_deployment_name(), - "model_version": str(self.model_version), - "num_replicas": self.replicas, - "ray_actor_options": self.options, + "app_name": model_name, + "model_version": str(model_version), + "num_replicas": replicas, + "ray_actor_options": options, } response = requests.post(url, json=payload) @@ -126,10 +105,10 @@ def _onyx_model_serve(self): print("Deployment Failed:", response.status_code, response.text) return None - def _onyx_model_cleanup(self): + def _onyx_model_cleanup(self, deployment_name): url = f"{self.svc_url}/serve/cleanup" payload = { - "app_name": self._get_deployment_name(), + "app_name": deployment_name, } response = requests.post(url, json=payload) @@ -159,21 +138,23 @@ def get_deployments(self): result = self._onyx_get_deployments() return result - def embed_text(self, data): + def embed_text(self, data, model_name): """Get the embeddings for the input text Args: data (str): The input text + model_name (str): The name of the model Returns: list: The embeddings for the input text """ - result = self._onyx_model_predict(data) + result = self._onyx_model_predict(data, model_name) return result def generate_completion( self, prompt, system_prompt="", + model_name=None, max_new_tokens=10000, temperature=0.4, top_p=0.9, @@ -182,6 +163,7 @@ def generate_completion( Args: prompt (str): The prompt for completion system_prompt (str): The system prompt for completion + model_name (str): The name of the model max_new_tokens (int): The maximum number of tokens to generate temperature (float): The temperature for sampling top_p 
(float): The top_p value for sampling @@ -190,18 +172,18 @@ def generate_completion( """ result = self._onyx_model_generate( - prompt, system_prompt, max_new_tokens, temperature, top_p + prompt, system_prompt, model_name, max_new_tokens, temperature, top_p ) return result - def deploy_model(self): + def deploy_model(self, model_name, model_version=1, replicas=1, options=None): """Deploy the model to the service""" - result = self._onyx_model_serve() + result = self._onyx_model_serve(model_name, model_version, replicas, options) return result - def delete_deployment(self): + def delete_deployment(self, deployment_name): """Delete the deployment from the service""" - result = self._onyx_model_cleanup() + result = self._onyx_model_cleanup(deployment_name) return result From 2fdf39cb754bce863ccea8402da65ffb574e02c7 Mon Sep 17 00:00:00 2001 From: Johnny Bouder Date: Mon, 9 Sep 2024 15:08:35 -0400 Subject: [PATCH 2/3] Add model name variables to rag search notebook. --- notebooks/rag_search.ipynb | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/notebooks/rag_search.ipynb b/notebooks/rag_search.ipynb index 6deecb9..4ce186b 100644 --- a/notebooks/rag_search.ipynb +++ b/notebooks/rag_search.ipynb @@ -23,6 +23,12 @@ "from onyxgenai.embed import EmbeddingClient\n", "from onyxgenai.model import ModelClient\n", "\n", + "# Set the model and embedding names\n", + "embedding_model_name = \"all-MiniLM-L6-v2\"\n", + "embedding_model_version = \"1\"\n", + "language_model_name = \"Mistral-7B-Instruct-v0.3\"\n", + "language_model_version = \"1\"\n", + "\n", "embedding_client = EmbeddingClient(\"http://embed.onyx-services\")\n", "model_client = ModelClient(\"http://store.onyx-services\")" ] @@ -40,7 +46,8 @@ "metadata": {}, "outputs": [], "source": [ - "embedding_model = model_client.deploy_model(\"all-MiniLM-L6-v2\", 1, 2, {\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})" + "response = model_client.deploy_model(embedding_model_name, embedding_model_version, 2, {\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})\n", + "print(response)" ] }, { @@ -56,7 +63,8 @@ "metadata": {}, "outputs": [], "source": [ - "language_model = model_client.deploy_model(\"Mistral-7B-Instruct-v0.3\", 1, 1, {})" + "response = model_client.deploy_model(language_model_name, language_model_version, 1, {})\n", + "print(response)" ] }, { @@ -74,7 +82,7 @@ "source": [ "query = \"What is the capital of France?\"\n", "data = [query]\n", - "embeddings = model_client.embed_text(data, embedding_model)" + "embeddings = model_client.embed_text(data, embedding_model_name)" ] }, { @@ -123,7 +131,7 @@ "Answer:\n", "\"\"\" # noqa: E501\n", "\n", - "answer = model_client.generate_completion(prompt, model_name=\"Mistral-7B-Instruct-v0.3\")\n", + "answer = model_client.generate_completion(prompt, model_name=language_model_name)\n", "print(answer)" ] }, @@ -140,8 +148,8 @@ "metadata": {}, "outputs": [], "source": [ - "model_client.delete_deployment(embedding_model)\n", - "model_client.delete_deployment(language_model)" + "model_client.delete_deployment(embedding_model_name)\n", + "model_client.delete_deployment(language_model_name)" ] } ], From d8c984ad07ad440b36c8ec52b10141261dc8d20e Mon Sep 17 00:00:00 2001 From: Johnny Bouder Date: Mon, 9 Sep 2024 15:23:25 -0400 Subject: [PATCH 3/3] Fix unit tests. 
---
 tests/test_embed.py | 21 ---------------------
 tests/test_model.py | 24 ------------------------
 2 files changed, 45 deletions(-)

diff --git a/tests/test_embed.py b/tests/test_embed.py
index ea71eb0..c9baa24 100644
--- a/tests/test_embed.py
+++ b/tests/test_embed.py
@@ -6,24 +6,3 @@ def test_base_embedding_client():
     client = EmbeddingClient(svc_url)

     assert client.svc_url == svc_url
-    assert client.model_name is None
-    assert client.model_version == 1
-    assert client.num_workers == 1
-    assert client.collection_name is None
-
-
-def test_full_embedding_client():
-    svc_url = "http://localhost:8000"
-    model_name = "test_model"
-    model_version = 2
-    num_workers = 4
-    collection_name = "test_collection"
-    client = EmbeddingClient(
-        svc_url, model_name, model_version, num_workers, collection_name
-    )
-
-    assert client.svc_url == svc_url
-    assert client.model_name == model_name
-    assert client.model_version == 2
-    assert client.num_workers == 4
-    assert client.collection_name == collection_name
diff --git a/tests/test_model.py b/tests/test_model.py
index e62d21b..78e86f6 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -6,27 +6,3 @@ def test_base_model_client():
     client = ModelClient(svc_url)

     assert client.svc_url == svc_url
-    assert client.model_name is None
-    assert client.model_version == 1
-    assert client.replicas == 1
-    assert client.deployment_name is None
-    assert client.options is None
-
-
-def test_full_model_client():
-    svc_url = "http://localhost:8000"
-    model_name = "test_model"
-    model_version = 2
-    replicas = 4
-    deployment_name = "test_deployment"
-    options = {"option1": "value1"}
-    client = ModelClient(
-        svc_url, model_name, model_version, replicas, deployment_name, options
-    )
-
-    assert client.svc_url == svc_url
-    assert client.model_name == model_name
-    assert client.model_version == 2
-    assert client.replicas == 4
-    assert client.deployment_name == deployment_name
-    assert client.options == options
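
For reference, a minimal usage sketch of the client API after this series: clients
are constructed from the service URL alone, and model details move to per-call
arguments. The service URLs, model names, and resource options below are taken from
the notebooks and are illustrative assumptions, not fixed values.

    from onyxgenai.embed import EmbeddingClient
    from onyxgenai.model import ModelClient

    # Clients now carry only the service URL; no model state in the constructor.
    embedding_client = EmbeddingClient("http://embed.onyx-services")
    model_client = ModelClient("http://store.onyx-services")

    # Deploy a model; name, version, replicas, and Ray actor options are per-call.
    model_client.deploy_model(
        "all-MiniLM-L6-v2",
        model_version=1,
        replicas=2,
        options={"num_cpus": 2, "memory": 8000 * 1024 * 1024},
    )

    # Embed through the deployed model (the app name equals the model name)...
    embeddings = model_client.embed_text(
        ["What is the capital of France?"], "all-MiniLM-L6-v2"
    )

    # ...or through the embedding service, persisting into a collection.
    embedding_client.embed_text(
        ["example sentence"], "all-MiniLM-L6-v2", collection_name="test_collection"
    )

    # Clean up by deployment name, which this refactor pins to the model name.
    model_client.delete_deployment("all-MiniLM-L6-v2")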