-
-
Notifications
You must be signed in to change notification settings - Fork 494
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye
committed
Dec 13, 2023
1 parent
b1d3aa5
commit 4bef09a
Showing
4 changed files
with
287 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" | |
|
||
[tool.poetry] | ||
name = "swarms" | ||
version = "2.7.9" | ||
version = "2.8.0" | ||
description = "Swarms - Pytorch" | ||
license = "MIT" | ||
authors = ["Kye Gomez <[email protected]>"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import logging | ||
import os | ||
from typing import Optional | ||
|
||
import requests | ||
from dotenv import load_dotenv | ||
|
||
from swarms.models.base_llm import AbstractLLM | ||
|
||
# Load environment variables | ||
load_dotenv() | ||
|
||
|
||
def together_api_key_env(): | ||
"""Get the API key from the environment.""" | ||
return os.getenv("TOGETHER_API_KEY") | ||
|
||
|
||
class TogetherModel(AbstractLLM): | ||
""" | ||
GPT-4 Vision API | ||
This class is a wrapper for the OpenAI API. It is used to run the GPT-4 Vision model. | ||
Parameters | ||
---------- | ||
together_api_key : str | ||
The OpenAI API key. Defaults to the together_api_key environment variable. | ||
max_tokens : int | ||
The maximum number of tokens to generate. Defaults to 300. | ||
Methods | ||
------- | ||
encode_image(img: str) | ||
Encode image to base64. | ||
run(task: str, img: str) | ||
Run the model. | ||
__call__(task: str, img: str) | ||
Run the model. | ||
Examples: | ||
--------- | ||
>>> from swarms.models import GPT4VisionAPI | ||
>>> llm = GPT4VisionAPI() | ||
>>> task = "What is the color of the object?" | ||
>>> img = "https://i.imgur.com/2M2ZGwC.jpeg" | ||
>>> llm.run(task, img) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
together_api_key: str = together_api_key_env, | ||
model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1", | ||
logging_enabled: bool = False, | ||
max_workers: int = 10, | ||
max_tokens: str = 300, | ||
api_endpoint: str = "https://api.together.xyz", | ||
beautify: bool = False, | ||
streaming_enabled: Optional[bool] = False, | ||
meta_prompt: Optional[bool] = False, | ||
system_prompt: Optional[str] = None, | ||
*args, | ||
**kwargs, | ||
): | ||
super(TogetherModel).__init__(*args, **kwargs) | ||
self.together_api_key = together_api_key | ||
self.logging_enabled = logging_enabled | ||
self.model_name = model_name | ||
self.max_workers = max_workers | ||
self.max_tokens = max_tokens | ||
self.api_endpoint = api_endpoint | ||
self.beautify = beautify | ||
self.streaming_enabled = streaming_enabled | ||
self.meta_prompt = meta_prompt | ||
self.system_prompt = system_prompt | ||
|
||
if self.logging_enabled: | ||
logging.basicConfig(level=logging.DEBUG) | ||
else: | ||
# Disable debug logs for requests and urllib3 | ||
logging.getLogger("requests").setLevel(logging.WARNING) | ||
logging.getLogger("urllib3").setLevel(logging.WARNING) | ||
|
||
if self.meta_prompt: | ||
self.system_prompt = self.meta_prompt_init() | ||
|
||
# Function to handle vision tasks | ||
def run(self, task: str = None, *args, **kwargs): | ||
"""Run the model.""" | ||
try: | ||
headers = { | ||
"Content-Type": "application/json", | ||
"Authorization": f"Bearer {self.together_api_key}", | ||
} | ||
payload = { | ||
"model": self.model_name, | ||
"messages": [ | ||
{ | ||
"role": "system", | ||
"content": [self.system_prompt], | ||
}, | ||
{ | ||
"role": "user", | ||
"content": task, | ||
}, | ||
], | ||
"max_tokens": self.max_tokens, | ||
**kwargs, | ||
} | ||
response = requests.post( | ||
self.api_endpoint, | ||
headers=headers, | ||
json=payload, | ||
*args, | ||
**kwargs, | ||
) | ||
|
||
out = response.json() | ||
if "choices" in out and out["choices"]: | ||
content = ( | ||
out["choices"][0] | ||
.get("message", {}) | ||
.get("content", None) | ||
) | ||
if self.streaming_enabled: | ||
content = self.stream_response(content) | ||
return content | ||
else: | ||
print("No valid response in 'choices'") | ||
return None | ||
|
||
except Exception as error: | ||
print( | ||
f"Error with the request: {error}, make sure you" | ||
" double check input types and positions" | ||
) | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import os | ||
import requests | ||
import pytest | ||
from unittest.mock import patch, Mock | ||
from swarms.models.together import TogetherModel | ||
import logging | ||
|
||
|
||
@pytest.fixture | ||
def mock_api_key(monkeypatch): | ||
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key") | ||
|
||
|
||
def test_init_defaults(): | ||
model = TogetherModel() | ||
assert model.together_api_key == "mocked-api-key" | ||
assert model.logging_enabled is False | ||
assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1" | ||
assert model.max_workers == 10 | ||
assert model.max_tokens == 300 | ||
assert model.api_endpoint == "https://api.together.xyz" | ||
assert model.beautify is False | ||
assert model.streaming_enabled is False | ||
assert model.meta_prompt is False | ||
assert model.system_prompt is None | ||
|
||
|
||
def test_init_custom_params(mock_api_key): | ||
model = TogetherModel( | ||
together_api_key="custom-api-key", | ||
logging_enabled=True, | ||
model_name="custom-model", | ||
max_workers=5, | ||
max_tokens=500, | ||
api_endpoint="https://custom-api.together.xyz", | ||
beautify=True, | ||
streaming_enabled=True, | ||
meta_prompt="meta-prompt", | ||
system_prompt="system-prompt", | ||
) | ||
assert model.together_api_key == "custom-api-key" | ||
assert model.logging_enabled is True | ||
assert model.model_name == "custom-model" | ||
assert model.max_workers == 5 | ||
assert model.max_tokens == 500 | ||
assert model.api_endpoint == "https://custom-api.together.xyz" | ||
assert model.beautify is True | ||
assert model.streaming_enabled is True | ||
assert model.meta_prompt == "meta-prompt" | ||
assert model.system_prompt == "system-prompt" | ||
|
||
|
||
@patch("swarms.models.together_model.requests.post") | ||
def test_run_success(mock_post, mock_api_key): | ||
mock_response = Mock() | ||
mock_response.json.return_value = { | ||
"choices": [{"message": {"content": "Generated response"}}] | ||
} | ||
mock_post.return_value = mock_response | ||
|
||
model = TogetherModel() | ||
task = "What is the color of the object?" | ||
response = model.run(task) | ||
|
||
assert response == "Generated response" | ||
|
||
|
||
@patch("swarms.models.together_model.requests.post") | ||
def test_run_failure(mock_post, mock_api_key): | ||
mock_post.side_effect = requests.exceptions.RequestException( | ||
"Request failed" | ||
) | ||
|
||
model = TogetherModel() | ||
task = "What is the color of the object?" | ||
response = model.run(task) | ||
|
||
assert response is None | ||
|
||
|
||
def test_run_with_logging_enabled(caplog, mock_api_key): | ||
model = TogetherModel(logging_enabled=True) | ||
task = "What is the color of the object?" | ||
|
||
with caplog.at_level(logging.DEBUG): | ||
model.run(task) | ||
|
||
assert "Sending request to" in caplog.text | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"invalid_input", [None, 123, ["list", "of", "items"]] | ||
) | ||
def test_invalid_task_input(invalid_input, mock_api_key): | ||
model = TogetherModel() | ||
response = model.run(invalid_input) | ||
|
||
assert response is None | ||
|
||
|
||
@patch("swarms.models.together_model.requests.post") | ||
def test_run_streaming_enabled(mock_post, mock_api_key): | ||
mock_response = Mock() | ||
mock_response.json.return_value = { | ||
"choices": [{"message": {"content": "Generated response"}}] | ||
} | ||
mock_post.return_value = mock_response | ||
|
||
model = TogetherModel(streaming_enabled=True) | ||
task = "What is the color of the object?" | ||
response = model.run(task) | ||
|
||
assert response == "Generated response" | ||
|
||
|
||
@patch("swarms.models.together_model.requests.post") | ||
def test_run_empty_choices(mock_post, mock_api_key): | ||
mock_response = Mock() | ||
mock_response.json.return_value = {"choices": []} | ||
mock_post.return_value = mock_response | ||
|
||
model = TogetherModel() | ||
task = "What is the color of the object?" | ||
response = model.run(task) | ||
|
||
assert response is None | ||
|
||
|
||
@patch("swarms.models.together_model.requests.post") | ||
def test_run_with_exception(mock_post, mock_api_key): | ||
mock_post.side_effect = Exception("Test exception") | ||
|
||
model = TogetherModel() | ||
task = "What is the color of the object?" | ||
response = model.run(task) | ||
|
||
assert response is None | ||
|
||
|
||
def test_init_logging_disabled(monkeypatch): | ||
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key") | ||
model = TogetherModel() | ||
assert model.logging_enabled is False | ||
assert not model.system_prompt |