Skip to content

Commit

Permalink
Merge pull request #6 from MetroStar/client-refactor
Browse files Browse the repository at this point in the history
Service-based Client Refactor
  • Loading branch information
jbouder authored Sep 9, 2024
2 parents 3cf63ab + d8c984a commit b18ca69
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 119 deletions.
4 changes: 2 additions & 2 deletions notebooks/embedding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
"metadata": {},
"outputs": [],
"source": [
"client = EmbeddingClient(\"http://embed.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, num_workers=1, collection_name=\"test_collection\")\n",
"embeddings = client.embed_text(sentences)\n"
"client = EmbeddingClient(\"http://embed.onyx-services\")\n",
"embeddings = client.embed_text(sentences, model_name=\"all-MiniLM-L6-v2\", collection_name=\"test_collection\")\n"
]
}
],
Expand Down
28 changes: 17 additions & 11 deletions notebooks/rag_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@
"from onyxgenai.embed import EmbeddingClient\n",
"from onyxgenai.model import ModelClient\n",
"\n",
"collection_name = \"test_collection\"\n",
"embedding_client = EmbeddingClient(\"http://embed.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, num_workers=2, collection_name=collection_name)\n",
"embedding_model_client = ModelClient(\"http://store.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, replicas=2, deployment_name=\"all-MiniLM-L6-v2\", options={\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})\n",
"llm_client = ModelClient(\"http://store.onyx-services\", model_name=\"Mistral-7B-Instruct-v0.3\", model_version=1, replicas=1, deployment_name=\"Mistral-7B-Instruct-v0.3\", options={})"
"# Set the model and embedding names\n",
"embedding_model_name = \"all-MiniLM-L6-v2\"\n",
"embedding_model_version = \"1\"\n",
"language_model_name = \"Mistral-7B-Instruct-v0.3\"\n",
"language_model_version = \"1\"\n",
"\n",
"embedding_client = EmbeddingClient(\"http://embed.onyx-services\")\n",
"model_client = ModelClient(\"http://store.onyx-services\")"
]
},
{
Expand All @@ -42,7 +46,8 @@
"metadata": {},
"outputs": [],
"source": [
"embedding_model_client.deploy_model()"
"response = model_client.deploy_model(embedding_model_name, embedding_model_version, 2, {\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})\n",
"print(response)"
]
},
{
Expand All @@ -58,7 +63,8 @@
"metadata": {},
"outputs": [],
"source": [
"llm_client.deploy_model()"
"response = model_client.deploy_model(language_model_name, language_model_version, 1, {})\n",
"print(response)"
]
},
{
Expand All @@ -76,7 +82,7 @@
"source": [
"query = \"What is the capital of France?\"\n",
"data = [query]\n",
"embeddings = embedding_model_client.embed_text(data)"
"embeddings = model_client.embed_text(data, embedding_model_name)"
]
},
{
Expand All @@ -92,7 +98,7 @@
"metadata": {},
"outputs": [],
"source": [
"vector_data = embedding_client.vector_search(embeddings, collection_name)\n",
"vector_data = embedding_client.vector_search(embeddings, \"test_collection\")\n",
"print(vector_data)"
]
},
Expand Down Expand Up @@ -125,7 +131,7 @@
"Answer:\n",
"\"\"\" # noqa: E501\n",
"\n",
"answer = llm_client.generate_completion(prompt)\n",
"answer = model_client.generate_completion(prompt, model_name=language_model_name)\n",
"print(answer)"
]
},
Expand All @@ -142,8 +148,8 @@
"metadata": {},
"outputs": [],
"source": [
"embedding_model_client.delete_deployment()\n",
"llm_client.delete_deployment()"
"model_client.delete_deployment(embedding_model_name)\n",
"model_client.delete_deployment(language_model_name)"
]
}
],
Expand Down
58 changes: 37 additions & 21 deletions onyxgenai/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,17 @@ class EmbeddingClient:
A client for interacting with the Onyx Embedding Service.
Args:
svc_url (str): The URL of the Onyx Embedding Service
model_name (str): The name of the model to deploy
model_version (int): The version of the model to deploy
num_workers (int): The number of workers to deploy
collection_name (str): The name of the collection
"""

def __init__(
self,
svc_url,
model_name=None,
model_version=1,
num_workers=1,
collection_name=None,
) -> None:
self.svc_url = svc_url
self.model_name = model_name
self.model_version = model_version
self.num_workers = num_workers
self.collection_name = collection_name

def _onyx_embed(self, batch, media_type):
def _onyx_embed(
self, batch, media_type, model_name, model_version, num_workers, collection_name
):
if media_type == "text":
url = f"{self.svc_url}/embedding/text"
elif media_type == "image":
Expand All @@ -40,10 +30,10 @@ def _onyx_embed(self, batch, media_type):

data = {
"data": batch,
"model_identifier": self.model_name,
"model_version": self.model_version,
"num_workers": self.num_workers,
"collection_name": self.collection_name,
"model_identifier": model_name,
"model_version": model_version,
"num_workers": num_workers,
"collection_name": collection_name,
}

response = requests.post(url, json=data)
Expand Down Expand Up @@ -112,11 +102,24 @@ def batch(self, iterable, batch_size=1):
for ndx in range(0, batch_length, batch_size):
yield iterable[ndx : min(ndx + batch_size, batch_length)]

def embed_text(self, data: list, batch_size=None, return_results=True):
def embed_text(
self,
data: list,
model_name,
model_version=1,
num_workers=1,
collection_name=None,
batch_size=None,
return_results=True,
):
"""
Get the embeddings for the input text
Args:
data (list): The input text
model_name (str): The name of the model
model_version (int): The version of the model
num_workers (int): The number of workers
collection_name (str): The name of the collection
batch_size (int): The size of the batches
return_results (bool): Whether to return the results
Returns:
Expand All @@ -128,13 +131,24 @@ def embed_text(self, data: list, batch_size=None, return_results=True):

results = []
for b in self.batch(data, batch_size):
result = self._onyx_embed(b, "text")
result = self._onyx_embed(
b, "text", model_name, model_version, num_workers, collection_name
)
if return_results:
results.extend(result)

return results

def embed_images(self, data: list, batch_size=None, return_results=True):
def embed_images(
self,
data: list,
model_name,
model_version=1,
num_workers=1,
collection_name=None,
batch_size=None,
return_results=True,
):
"""
Get the embeddings for the input images
Args:
Expand Down Expand Up @@ -162,7 +176,9 @@ def embed_images(self, data: list, batch_size=None, return_results=True):

results = []
for b in self.batch(encoded, batch_size):
result = self._onyx_embed(b, "image")
result = self._onyx_embed(
b, "image", model_name, model_version, num_workers, collection_name
)
if return_results:
results.extend(result)

Expand Down
62 changes: 22 additions & 40 deletions onyxgenai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,13 @@ class ModelClient:
"""Client for interacting with the Onyx Model Store Service
Args:
svc_url (str): The URL of the Onyx Model Store Service
model_name (str): The name of the model to deploy
model_version (int): The version of the model to deploy
replicas (int): The number of replicas to deploy
deployment_name (str): The name of the deployment
options (dict): The options for deploying the model
"""

def __init__(
self,
svc_url,
model_name=None,
model_version=1,
replicas=1,
deployment_name=None,
options=None,
) -> None:
self.svc_url = svc_url
self.model_name = model_name
self.model_version = model_version
self.replicas = replicas
self.deployment_name = deployment_name
self.options = options

def _get_deployment_name(self):
if self.deployment_name:
return self.deployment_name
else:
return self.model_name

def _onyx_model_info(self):
url = f"{self.svc_url}/info/model_info"
Expand All @@ -58,10 +37,10 @@ def _onyx_get_deployments(self):
print("Failed to get deployment info:", response.status_code, response.text)
return None

def _onyx_model_predict(self, data):
def _onyx_model_predict(self, data, model_name):
url = f"{self.svc_url}/serve/predict/text"
payload = {
"app_name": self._get_deployment_name(),
"app_name": model_name,
"data": data,
}

Expand All @@ -79,11 +58,11 @@ def _onyx_model_predict(self, data):
return None

def _onyx_model_generate(
self, prompt, system_prompt, max_new_tokens, temperature, top_p
self, prompt, system_prompt, model_name, max_new_tokens, temperature, top_p
):
url = f"{self.svc_url}/serve/generate/text"
payload = {
"app_name": self._get_deployment_name(),
"app_name": model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
Expand All @@ -108,13 +87,13 @@ def _onyx_model_generate(
print("Generate Failed:", response.status_code, response.text)
return None

def _onyx_model_serve(self):
url = f"{self.svc_url}/serve/deploy/{self.model_name}"
def _onyx_model_serve(self, model_name, model_version, replicas, options):
url = f"{self.svc_url}/serve/deploy/{model_name}"
payload = {
"app_name": self._get_deployment_name(),
"model_version": str(self.model_version),
"num_replicas": self.replicas,
"ray_actor_options": self.options,
"app_name": model_name,
"model_version": str(model_version),
"num_replicas": replicas,
"ray_actor_options": options,
}

response = requests.post(url, json=payload)
Expand All @@ -126,10 +105,10 @@ def _onyx_model_serve(self):
print("Deployment Failed:", response.status_code, response.text)
return None

def _onyx_model_cleanup(self):
def _onyx_model_cleanup(self, deployment_name):
url = f"{self.svc_url}/serve/cleanup"
payload = {
"app_name": self._get_deployment_name(),
"app_name": deployment_name,
}

response = requests.post(url, json=payload)
Expand Down Expand Up @@ -159,21 +138,23 @@ def get_deployments(self):
result = self._onyx_get_deployments()
return result

def embed_text(self, data):
def embed_text(self, data, model_name):
"""Get the embeddings for the input text
Args:
data (str): The input text
model_name (str): The name of the model
Returns:
list: The embeddings for the input text
"""

result = self._onyx_model_predict(data)
result = self._onyx_model_predict(data, model_name)
return result

def generate_completion(
self,
prompt,
system_prompt="",
model_name=None,
max_new_tokens=10000,
temperature=0.4,
top_p=0.9,
Expand All @@ -182,6 +163,7 @@ def generate_completion(
Args:
prompt (str): The prompt for completion
system_prompt (str): The system prompt for completion
model_name (str): The name of the model
max_new_tokens (int): The maximum number of tokens to generate
temperature (float): The temperature for sampling
top_p (float): The top_p value for sampling
Expand All @@ -190,18 +172,18 @@ def generate_completion(
"""

result = self._onyx_model_generate(
prompt, system_prompt, max_new_tokens, temperature, top_p
prompt, system_prompt, model_name, max_new_tokens, temperature, top_p
)
return result

def deploy_model(self):
def deploy_model(self, model_name, model_version=1, replicas=1, options=None):
"""Deploy the model to the service"""

result = self._onyx_model_serve()
result = self._onyx_model_serve(model_name, model_version, replicas, options)
return result

def delete_deployment(self):
def delete_deployment(self, deployment_name):
"""Delete the deployment from the service"""

result = self._onyx_model_cleanup()
result = self._onyx_model_cleanup(deployment_name)
return result
21 changes: 0 additions & 21 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,3 @@ def test_base_embedding_client():
client = EmbeddingClient(svc_url)

assert client.svc_url == svc_url
assert client.model_name is None
assert client.model_version == 1
assert client.num_workers == 1
assert client.collection_name is None


def test_full_embedding_client():
svc_url = "http://localhost:8000"
model_name = "test_model"
model_version = 2
num_workers = 4
collection_name = "test_collection"
client = EmbeddingClient(
svc_url, model_name, model_version, num_workers, collection_name
)

assert client.svc_url == svc_url
assert client.model_name == model_name
assert client.model_version == 2
assert client.num_workers == 4
assert client.collection_name == collection_name
Loading

0 comments on commit b18ca69

Please sign in to comment.