From 52db122bef400f9f8ff6cd339031c0e0883f34b3 Mon Sep 17 00:00:00 2001 From: Johnny Bouder Date: Mon, 23 Sep 2024 14:11:27 -0400 Subject: [PATCH] Add more tests. Fix image embed error. --- onyxgenai/embed.py | 1 - tests/test_embed.py | 74 +++++++++++++++++++++++++++++++ tests/test_model.py | 106 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 1 deletion(-) diff --git a/onyxgenai/embed.py b/onyxgenai/embed.py index b1353b0..ef790b6 100644 --- a/onyxgenai/embed.py +++ b/onyxgenai/embed.py @@ -170,7 +170,6 @@ def embed_images( encoded.append(encoded_image) else: # assume that it is a PIL image buffered = BytesIO() - d.save(buffered, format="JPEG") encoded_image = base64.b64encode(buffered.getvalue()) encoded.append(encoded_image) diff --git a/tests/test_embed.py b/tests/test_embed.py index c9baa24..fb0c0af 100644 --- a/tests/test_embed.py +++ b/tests/test_embed.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + from onyxgenai.embed import EmbeddingClient @@ -6,3 +8,75 @@ def test_base_embedding_client(): client = EmbeddingClient(svc_url) assert client.svc_url == svc_url + + +@patch("requests.post") +def test_onyx_embed_text(mock_post): + svc_url = "http://localhost:8000" + client = EmbeddingClient(svc_url) + mock_response = {"embeddings": [[0.1, 0.2, 0.3]]} + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_response + + data = ["sample text"] + model_name = "test_model" + result = client.embed_text(data, model_name) + + assert result == mock_response["embeddings"] + + +@patch("requests.post") +def test_onyx_embed_image(mock_post): + svc_url = "http://localhost:8000" + client = EmbeddingClient(svc_url) + mock_response = {"embeddings": [[0.1, 0.2, 0.3]]} + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_response + + data = ["path/to/image.jpg"] + model_name = "test_model" + result = client.embed_images(data, model_name) + + assert result == mock_response["embeddings"] + + +@patch("requests.post") +def test_onyx_vector_search(mock_post): + svc_url = "http://localhost:8000" + client = EmbeddingClient(svc_url) + mock_response = {"results": ["result1", "result2", "result3"]} + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = mock_response + + query = "sample query" + collection_name = "test_collection" + result = client.vector_search(query, collection_name) + + assert result == mock_response["results"] + + +@patch("requests.get") +def test_onyx_get_collections(mock_get): + svc_url = "http://localhost:8000" + client = EmbeddingClient(svc_url) + mock_response = ["collection1", "collection2"] + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = mock_response + + result = client.get_collections() + + assert result == mock_response + + +@patch("requests.delete") +def test_onyx_delete_collection(mock_delete): + svc_url = "http://localhost:8000" + client = EmbeddingClient(svc_url) + mock_response = {"status": "success"} + mock_delete.return_value.status_code = 200 + mock_delete.return_value.json.return_value = mock_response + + collection_name = "test_collection" + result = client.delete_collection(collection_name) + + assert result == mock_response diff --git a/tests/test_model.py b/tests/test_model.py index 78e86f6..4135ece 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock, patch + from onyxgenai.model import ModelClient @@ -6,3 +8,107 @@ def test_base_model_client(): client = ModelClient(svc_url) assert client.svc_url == svc_url + + +@patch("onyxgenai.model.requests.get") +def test_get_models(mock_get): + svc_url = "http://localhost:8000" + client = ModelClient(svc_url) + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"data": {"models": ["model1", "model2"]}} + mock_get.return_value = mock_response + + models = client.get_models() + assert models == ["model1", "model2"] + + +@patch("onyxgenai.model.requests.get") +def test_get_deployments(mock_get): + svc_url = "http://localhost:8000" + client = ModelClient(svc_url) + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "model1": { + "status": "running", + "message": "All good", + "last_deployed_time_s": 1234567890, + "deployments": { + "model1": { + "status": "running", + "status_trigger": "manual", + "replica_states": {"RUNNING": 2}, + "message": "Deployment successful", + } + }, + } + } + mock_get.return_value = mock_response + + deployments = client.get_deployments() + assert len(deployments) == 1 + assert deployments[0]["model"] == "model1" + assert deployments[0]["status"] == "running" + + +@patch("onyxgenai.model.requests.post") +def test_embed_text(mock_post): + svc_url = "http://localhost:8000" + client = ModelClient(svc_url) + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"embeddings": [[0.1, 0.2, 0.3]]} + mock_post.return_value = mock_response + + embeddings = client.embed_text("sample text", "model1") + assert embeddings == [0.1, 0.2, 0.3] + + +@patch("onyxgenai.model.requests.post") +def test_generate_completion(mock_post): + svc_url = "http://localhost:8000" + client = ModelClient(svc_url) + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "generated_text": [{"content": "Generated text"}] + } + mock_post.return_value = mock_response + + generated_text = client.generate_completion( + "sample prompt", "system prompt", "model1" + ) + assert generated_text == "Generated text" + + +@patch("onyxgenai.model.requests.post") +def test_deploy_model(mock_post): + svc_url = "http://localhost:8000" + client = ModelClient(svc_url) + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"status": "success"} + mock_post.return_value = mock_response + + response = client.deploy_model("model1", 1, 1, {}) + assert response["status"] == "success" + + +@patch("onyxgenai.model.requests.post") +def test_delete_deployment(mock_post): + svc_url = "http://localhost:8000" + client = ModelClient(svc_url) + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"status": "deleted"} + mock_post.return_value = mock_response + + response = client.delete_deployment("deployment1") + assert response["status"] == "deleted"