Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Service-based Client Refactor #6

Merged
merged 3 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions notebooks/embedding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
"metadata": {},
"outputs": [],
"source": [
"client = EmbeddingClient(\"http://embed.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, num_workers=1, collection_name=\"test_collection\")\n",
"embeddings = client.embed_text(sentences)\n"
"client = EmbeddingClient(\"http://embed.onyx-services\")\n",
"embeddings = client.embed_text(sentences, model_name=\"all-MiniLM-L6-v2\", collection_name=\"test_collection\")\n"
]
}
],
Expand Down
28 changes: 17 additions & 11 deletions notebooks/rag_search.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@
"from onyxgenai.embed import EmbeddingClient\n",
"from onyxgenai.model import ModelClient\n",
"\n",
"collection_name = \"test_collection\"\n",
"embedding_client = EmbeddingClient(\"http://embed.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, num_workers=2, collection_name=collection_name)\n",
"embedding_model_client = ModelClient(\"http://store.onyx-services\", model_name=\"all-MiniLM-L6-v2\", model_version=1, replicas=2, deployment_name=\"all-MiniLM-L6-v2\", options={\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})\n",
"llm_client = ModelClient(\"http://store.onyx-services\", model_name=\"Mistral-7B-Instruct-v0.3\", model_version=1, replicas=1, deployment_name=\"Mistral-7B-Instruct-v0.3\", options={})"
"# Set the model and embedding names\n",
"embedding_model_name = \"all-MiniLM-L6-v2\"\n",
"embedding_model_version = \"1\"\n",
"language_model_name = \"Mistral-7B-Instruct-v0.3\"\n",
"language_model_version = \"1\"\n",
"\n",
"embedding_client = EmbeddingClient(\"http://embed.onyx-services\")\n",
"model_client = ModelClient(\"http://store.onyx-services\")"
]
},
{
Expand All @@ -42,7 +46,8 @@
"metadata": {},
"outputs": [],
"source": [
"embedding_model_client.deploy_model()"
"response = model_client.deploy_model(embedding_model_name, embedding_model_version, 2, {\"num_cpus\": 2, \"memory\": 8000 * 1024 * 1024})\n",
"print(response)"
]
},
{
Expand All @@ -58,7 +63,8 @@
"metadata": {},
"outputs": [],
"source": [
"llm_client.deploy_model()"
"response = model_client.deploy_model(language_model_name, language_model_version, 1, {})\n",
"print(response)"
]
},
{
Expand All @@ -76,7 +82,7 @@
"source": [
"query = \"What is the capital of France?\"\n",
"data = [query]\n",
"embeddings = embedding_model_client.embed_text(data)"
"embeddings = model_client.embed_text(data, embedding_model_name)"
]
},
{
Expand All @@ -92,7 +98,7 @@
"metadata": {},
"outputs": [],
"source": [
"vector_data = embedding_client.vector_search(embeddings, collection_name)\n",
"vector_data = embedding_client.vector_search(embeddings, \"test_collection\")\n",
"print(vector_data)"
]
},
Expand Down Expand Up @@ -125,7 +131,7 @@
"Answer:\n",
"\"\"\" # noqa: E501\n",
"\n",
"answer = llm_client.generate_completion(prompt)\n",
"answer = model_client.generate_completion(prompt, model_name=language_model_name)\n",
"print(answer)"
]
},
Expand All @@ -142,8 +148,8 @@
"metadata": {},
"outputs": [],
"source": [
"embedding_model_client.delete_deployment()\n",
"llm_client.delete_deployment()"
"model_client.delete_deployment(embedding_model_name)\n",
"model_client.delete_deployment(language_model_name)"
]
}
],
Expand Down
58 changes: 37 additions & 21 deletions onyxgenai/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,17 @@ class EmbeddingClient:
A client for interacting with the Onyx Embedding Service.
Args:
svc_url (str): The URL of the Onyx Embedding Service
model_name (str): The name of the model to deploy
model_version (int): The version of the model to deploy
num_workers (int): The number of workers to deploy
collection_name (str): The name of the collection
"""

def __init__(self, svc_url) -> None:
    # Only the service endpoint is stored on the client; model name/version,
    # worker count, and collection name are supplied per call to the embed
    # methods rather than fixed at construction time.
    self.svc_url = svc_url

def _onyx_embed(self, batch, media_type):
def _onyx_embed(
self, batch, media_type, model_name, model_version, num_workers, collection_name
):
if media_type == "text":
url = f"{self.svc_url}/embedding/text"
elif media_type == "image":
Expand All @@ -40,10 +30,10 @@ def _onyx_embed(self, batch, media_type):

data = {
"data": batch,
"model_identifier": self.model_name,
"model_version": self.model_version,
"num_workers": self.num_workers,
"collection_name": self.collection_name,
"model_identifier": model_name,
"model_version": model_version,
"num_workers": num_workers,
"collection_name": collection_name,
}

response = requests.post(url, json=data)
Expand Down Expand Up @@ -112,11 +102,24 @@ def batch(self, iterable, batch_size=1):
for ndx in range(0, batch_length, batch_size):
yield iterable[ndx : min(ndx + batch_size, batch_length)]

def embed_text(self, data: list, batch_size=None, return_results=True):
def embed_text(
self,
data: list,
model_name,
model_version=1,
num_workers=1,
collection_name=None,
batch_size=None,
return_results=True,
):
"""
Get the embeddings for the input text
Args:
data (list): The input text
model_name (str): The name of the model
model_version (int): The version of the model
num_workers (int): The number of workers
collection_name (str): The name of the collection
batch_size (int): The size of the batches
return_results (bool): Whether to return the results
Returns:
Expand All @@ -128,13 +131,24 @@ def embed_text(self, data: list, batch_size=None, return_results=True):

results = []
for b in self.batch(data, batch_size):
result = self._onyx_embed(b, "text")
result = self._onyx_embed(
b, "text", model_name, model_version, num_workers, collection_name
)
if return_results:
results.extend(result)

return results

def embed_images(self, data: list, batch_size=None, return_results=True):
def embed_images(
self,
data: list,
model_name,
model_version=1,
num_workers=1,
collection_name=None,
batch_size=None,
return_results=True,
):
"""
Get the embeddings for the input images
Args:
Expand Down Expand Up @@ -162,7 +176,9 @@ def embed_images(self, data: list, batch_size=None, return_results=True):

results = []
for b in self.batch(encoded, batch_size):
result = self._onyx_embed(b, "image")
result = self._onyx_embed(
b, "image", model_name, model_version, num_workers, collection_name
)
if return_results:
results.extend(result)

Expand Down
62 changes: 22 additions & 40 deletions onyxgenai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,13 @@ class ModelClient:
"""Client for interacting with the Onyx Model Store Service
Args:
svc_url (str): The URL of the Onyx Model Store Service
model_name (str): The name of the model to deploy
model_version (int): The version of the model to deploy
replicas (int): The number of replicas to deploy
deployment_name (str): The name of the deployment
options (dict): The options for deploying the model
"""

def __init__(
self,
svc_url,
model_name=None,
model_version=1,
replicas=1,
deployment_name=None,
options=None,
) -> None:
self.svc_url = svc_url
self.model_name = model_name
self.model_version = model_version
self.replicas = replicas
self.deployment_name = deployment_name
self.options = options

def _get_deployment_name(self):
if self.deployment_name:
return self.deployment_name
else:
return self.model_name

def _onyx_model_info(self):
url = f"{self.svc_url}/info/model_info"
Expand All @@ -58,10 +37,10 @@ def _onyx_get_deployments(self):
print("Failed to get deployment info:", response.status_code, response.text)
return None

def _onyx_model_predict(self, data):
def _onyx_model_predict(self, data, model_name):
url = f"{self.svc_url}/serve/predict/text"
payload = {
"app_name": self._get_deployment_name(),
"app_name": model_name,
"data": data,
}

Expand All @@ -79,11 +58,11 @@ def _onyx_model_predict(self, data):
return None

def _onyx_model_generate(
self, prompt, system_prompt, max_new_tokens, temperature, top_p
self, prompt, system_prompt, model_name, max_new_tokens, temperature, top_p
):
url = f"{self.svc_url}/serve/generate/text"
payload = {
"app_name": self._get_deployment_name(),
"app_name": model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
Expand All @@ -108,13 +87,13 @@ def _onyx_model_generate(
print("Generate Failed:", response.status_code, response.text)
return None

def _onyx_model_serve(self):
url = f"{self.svc_url}/serve/deploy/{self.model_name}"
def _onyx_model_serve(self, model_name, model_version, replicas, options):
url = f"{self.svc_url}/serve/deploy/{model_name}"
payload = {
"app_name": self._get_deployment_name(),
"model_version": str(self.model_version),
"num_replicas": self.replicas,
"ray_actor_options": self.options,
"app_name": model_name,
"model_version": str(model_version),
"num_replicas": replicas,
"ray_actor_options": options,
}

response = requests.post(url, json=payload)
Expand All @@ -126,10 +105,10 @@ def _onyx_model_serve(self):
print("Deployment Failed:", response.status_code, response.text)
return None

def _onyx_model_cleanup(self):
def _onyx_model_cleanup(self, deployment_name):
url = f"{self.svc_url}/serve/cleanup"
payload = {
"app_name": self._get_deployment_name(),
"app_name": deployment_name,
}

response = requests.post(url, json=payload)
Expand Down Expand Up @@ -159,21 +138,23 @@ def get_deployments(self):
result = self._onyx_get_deployments()
return result

def embed_text(self, data, model_name):
    """Get the embeddings for the input text.

    Args:
        data (str): The input text
        model_name (str): The name of the model (used as the serving app_name)
    Returns:
        list: The embeddings for the input text, or None if the request failed
    """
    # Thin delegation to the predict endpoint; no local state is touched.
    return self._onyx_model_predict(data, model_name)

def generate_completion(
self,
prompt,
system_prompt="",
model_name=None,
max_new_tokens=10000,
temperature=0.4,
top_p=0.9,
Expand All @@ -182,6 +163,7 @@ def generate_completion(
Args:
prompt (str): The prompt for completion
system_prompt (str): The system prompt for completion
model_name (str): The name of the model
max_new_tokens (int): The maximum number of tokens to generate
temperature (float): The temperature for sampling
top_p (float): The top_p value for sampling
Expand All @@ -190,18 +172,18 @@ def generate_completion(
"""

result = self._onyx_model_generate(
prompt, system_prompt, max_new_tokens, temperature, top_p
prompt, system_prompt, model_name, max_new_tokens, temperature, top_p
)
return result

def deploy_model(self, model_name, model_version=1, replicas=1, options=None):
    """Deploy a model from the store as a serving endpoint.

    Args:
        model_name (str): The name of the model to deploy (also used as app_name)
        model_version (int): The version of the model to deploy
        replicas (int): The number of replicas to deploy
        options (dict): Ray actor options (e.g. num_cpus, memory); defaults to {}
    Returns:
        The service response on success, or None on failure
    """
    # Normalize None to an empty dict so the payload's ray_actor_options is
    # sent as {} rather than JSON null, matching what callers previously
    # passed explicitly (and avoiding a mutable default argument).
    if options is None:
        options = {}
    return self._onyx_model_serve(model_name, model_version, replicas, options)

def delete_deployment(self, deployment_name):
    """Delete a serving deployment from the service.

    Args:
        deployment_name (str): The app_name of the deployment to remove
    Returns:
        The service response from the cleanup endpoint
    """
    # Thin delegation to the cleanup endpoint; no local state is touched.
    return self._onyx_model_cleanup(deployment_name)
21 changes: 0 additions & 21 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,3 @@ def test_base_embedding_client():
client = EmbeddingClient(svc_url)

assert client.svc_url == svc_url
assert client.model_name is None
assert client.model_version == 1
assert client.num_workers == 1
assert client.collection_name is None


def test_full_embedding_client():
svc_url = "http://localhost:8000"
model_name = "test_model"
model_version = 2
num_workers = 4
collection_name = "test_collection"
client = EmbeddingClient(
svc_url, model_name, model_version, num_workers, collection_name
)

assert client.svc_url == svc_url
assert client.model_name == model_name
assert client.model_version == 2
assert client.num_workers == 4
assert client.collection_name == collection_name
Loading
Loading