Skip to content

Commit

Permalink
Add more tests. Fix image embed error.
Browse files Browse the repository at this point in the history
  • Loading branch information
jbouder committed Sep 23, 2024
1 parent 1b16056 commit 52db122
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 1 deletion.
1 change: 0 additions & 1 deletion onyxgenai/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ def embed_images(
encoded.append(encoded_image)
else: # assume that it is a PIL image
buffered = BytesIO()
d.save(buffered, format="JPEG")
encoded_image = base64.b64encode(buffered.getvalue())
encoded.append(encoded_image)

Expand Down
74 changes: 74 additions & 0 deletions tests/test_embed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import patch

from onyxgenai.embed import EmbeddingClient


Expand All @@ -6,3 +8,75 @@ def test_base_embedding_client():
client = EmbeddingClient(svc_url)

assert client.svc_url == svc_url


@patch("requests.post")
def test_onyx_embed_text(mock_post):
svc_url = "http://localhost:8000"
client = EmbeddingClient(svc_url)
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = mock_response

data = ["sample text"]
model_name = "test_model"
result = client.embed_text(data, model_name)

assert result == mock_response["embeddings"]


@patch("requests.post")
def test_onyx_embed_image(mock_post):
svc_url = "http://localhost:8000"
client = EmbeddingClient(svc_url)
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = mock_response

data = ["path/to/image.jpg"]
model_name = "test_model"
result = client.embed_images(data, model_name)

assert result == mock_response["embeddings"]


@patch("requests.post")
def test_onyx_vector_search(mock_post):
svc_url = "http://localhost:8000"
client = EmbeddingClient(svc_url)
mock_response = {"results": ["result1", "result2", "result3"]}
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = mock_response

query = "sample query"
collection_name = "test_collection"
result = client.vector_search(query, collection_name)

assert result == mock_response["results"]


@patch("requests.get")
def test_onyx_get_collections(mock_get):
svc_url = "http://localhost:8000"
client = EmbeddingClient(svc_url)
mock_response = ["collection1", "collection2"]
mock_get.return_value.status_code = 200
mock_get.return_value.json.return_value = mock_response

result = client.get_collections()

assert result == mock_response


@patch("requests.delete")
def test_onyx_delete_collection(mock_delete):
svc_url = "http://localhost:8000"
client = EmbeddingClient(svc_url)
mock_response = {"status": "success"}
mock_delete.return_value.status_code = 200
mock_delete.return_value.json.return_value = mock_response

collection_name = "test_collection"
result = client.delete_collection(collection_name)

assert result == mock_response
106 changes: 106 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import Mock, patch

from onyxgenai.model import ModelClient


Expand All @@ -6,3 +8,107 @@ def test_base_model_client():
client = ModelClient(svc_url)

assert client.svc_url == svc_url


@patch("onyxgenai.model.requests.get")
def test_get_models(mock_get):
svc_url = "http://localhost:8000"
client = ModelClient(svc_url)

mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"data": {"models": ["model1", "model2"]}}
mock_get.return_value = mock_response

models = client.get_models()
assert models == ["model1", "model2"]


@patch("onyxgenai.model.requests.get")
def test_get_deployments(mock_get):
svc_url = "http://localhost:8000"
client = ModelClient(svc_url)

mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"model1": {
"status": "running",
"message": "All good",
"last_deployed_time_s": 1234567890,
"deployments": {
"model1": {
"status": "running",
"status_trigger": "manual",
"replica_states": {"RUNNING": 2},
"message": "Deployment successful",
}
},
}
}
mock_get.return_value = mock_response

deployments = client.get_deployments()
assert len(deployments) == 1
assert deployments[0]["model"] == "model1"
assert deployments[0]["status"] == "running"


@patch("onyxgenai.model.requests.post")
def test_embed_text(mock_post):
svc_url = "http://localhost:8000"
client = ModelClient(svc_url)

mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
mock_post.return_value = mock_response

embeddings = client.embed_text("sample text", "model1")
assert embeddings == [0.1, 0.2, 0.3]


@patch("onyxgenai.model.requests.post")
def test_generate_completion(mock_post):
svc_url = "http://localhost:8000"
client = ModelClient(svc_url)

mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"generated_text": [{"content": "Generated text"}]
}
mock_post.return_value = mock_response

generated_text = client.generate_completion(
"sample prompt", "system prompt", "model1"
)
assert generated_text == "Generated text"


@patch("onyxgenai.model.requests.post")
def test_deploy_model(mock_post):
svc_url = "http://localhost:8000"
client = ModelClient(svc_url)

mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"status": "success"}
mock_post.return_value = mock_response

response = client.deploy_model("model1", 1, 1, {})
assert response["status"] == "success"


@patch("onyxgenai.model.requests.post")
def test_delete_deployment(mock_post):
svc_url = "http://localhost:8000"
client = ModelClient(svc_url)

mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"status": "deleted"}
mock_post.return_value = mock_response

response = client.delete_deployment("deployment1")
assert response["status"] == "deleted"

0 comments on commit 52db122

Please sign in to comment.