Skip to content

Commit

Permalink
[FEAT][TogertherModel]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 13, 2023
1 parent b1d3aa5 commit 4bef09a
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "swarms"
version = "2.7.9"
version = "2.8.0"
description = "Swarms - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
Expand Down
140 changes: 140 additions & 0 deletions swarms/models/together.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import logging
import os
from typing import Optional

import requests
from dotenv import load_dotenv

from swarms.models.base_llm import AbstractLLM

# Load environment variables
load_dotenv()


def together_api_key_env():
"""Get the API key from the environment."""
return os.getenv("TOGETHER_API_KEY")


class TogetherModel(AbstractLLM):
"""
GPT-4 Vision API
This class is a wrapper for the OpenAI API. It is used to run the GPT-4 Vision model.
Parameters
----------
together_api_key : str
The OpenAI API key. Defaults to the together_api_key environment variable.
max_tokens : int
The maximum number of tokens to generate. Defaults to 300.
Methods
-------
encode_image(img: str)
Encode image to base64.
run(task: str, img: str)
Run the model.
__call__(task: str, img: str)
Run the model.
Examples:
---------
>>> from swarms.models import GPT4VisionAPI
>>> llm = GPT4VisionAPI()
>>> task = "What is the color of the object?"
>>> img = "https://i.imgur.com/2M2ZGwC.jpeg"
>>> llm.run(task, img)
"""

def __init__(
self,
together_api_key: str = together_api_key_env,
model_name: str = "mistralai/Mixtral-8x7B-Instruct-v0.1",
logging_enabled: bool = False,
max_workers: int = 10,
max_tokens: str = 300,
api_endpoint: str = "https://api.together.xyz",
beautify: bool = False,
streaming_enabled: Optional[bool] = False,
meta_prompt: Optional[bool] = False,
system_prompt: Optional[str] = None,
*args,
**kwargs,
):
super(TogetherModel).__init__(*args, **kwargs)
self.together_api_key = together_api_key
self.logging_enabled = logging_enabled
self.model_name = model_name
self.max_workers = max_workers
self.max_tokens = max_tokens
self.api_endpoint = api_endpoint
self.beautify = beautify
self.streaming_enabled = streaming_enabled
self.meta_prompt = meta_prompt
self.system_prompt = system_prompt

if self.logging_enabled:
logging.basicConfig(level=logging.DEBUG)
else:
# Disable debug logs for requests and urllib3
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

if self.meta_prompt:
self.system_prompt = self.meta_prompt_init()

# Function to handle vision tasks
def run(self, task: str = None, *args, **kwargs):
"""Run the model."""
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.together_api_key}",
}
payload = {
"model": self.model_name,
"messages": [
{
"role": "system",
"content": [self.system_prompt],
},
{
"role": "user",
"content": task,
},
],
"max_tokens": self.max_tokens,
**kwargs,
}
response = requests.post(
self.api_endpoint,
headers=headers,
json=payload,
*args,
**kwargs,
)

out = response.json()
if "choices" in out and out["choices"]:
content = (
out["choices"][0]
.get("message", {})
.get("content", None)
)
if self.streaming_enabled:
content = self.stream_response(content)
return content
else:
print("No valid response in 'choices'")
return None

except Exception as error:
print(
f"Error with the request: {error}, make sure you"
" double check input types and positions"
)
return None
5 changes: 2 additions & 3 deletions swarms/prompts/react.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

def react_prompt(task: str = None):
REACT = f"""
PROMPT = f"""
Task Description:
Accomplish the following {task} using the reasoning guidelines below.
Expand Down Expand Up @@ -56,4 +55,4 @@ def react_prompt(task: str = None):
Remember, your goal is to provide a transparent and logical process that leads from observation to effective action. Your responses should demonstrate clear thinking, an understanding of the problem, and a rational approach to solving it. The use of tokens helps to structure your response and clarify the different stages of your reasoning and action.
"""
return REACT
return PROMPT
144 changes: 144 additions & 0 deletions tests/models/test_togther.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import os
import requests
import pytest
from unittest.mock import patch, Mock
from swarms.models.together import TogetherModel
import logging


@pytest.fixture
def mock_api_key(monkeypatch):
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")


def test_init_defaults():
model = TogetherModel()
assert model.together_api_key == "mocked-api-key"
assert model.logging_enabled is False
assert model.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
assert model.max_workers == 10
assert model.max_tokens == 300
assert model.api_endpoint == "https://api.together.xyz"
assert model.beautify is False
assert model.streaming_enabled is False
assert model.meta_prompt is False
assert model.system_prompt is None


def test_init_custom_params(mock_api_key):
model = TogetherModel(
together_api_key="custom-api-key",
logging_enabled=True,
model_name="custom-model",
max_workers=5,
max_tokens=500,
api_endpoint="https://custom-api.together.xyz",
beautify=True,
streaming_enabled=True,
meta_prompt="meta-prompt",
system_prompt="system-prompt",
)
assert model.together_api_key == "custom-api-key"
assert model.logging_enabled is True
assert model.model_name == "custom-model"
assert model.max_workers == 5
assert model.max_tokens == 500
assert model.api_endpoint == "https://custom-api.together.xyz"
assert model.beautify is True
assert model.streaming_enabled is True
assert model.meta_prompt == "meta-prompt"
assert model.system_prompt == "system-prompt"


@patch("swarms.models.together_model.requests.post")
def test_run_success(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Generated response"}}]
}
mock_post.return_value = mock_response

model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)

assert response == "Generated response"


@patch("swarms.models.together_model.requests.post")
def test_run_failure(mock_post, mock_api_key):
mock_post.side_effect = requests.exceptions.RequestException(
"Request failed"
)

model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)

assert response is None


def test_run_with_logging_enabled(caplog, mock_api_key):
model = TogetherModel(logging_enabled=True)
task = "What is the color of the object?"

with caplog.at_level(logging.DEBUG):
model.run(task)

assert "Sending request to" in caplog.text


@pytest.mark.parametrize(
"invalid_input", [None, 123, ["list", "of", "items"]]
)
def test_invalid_task_input(invalid_input, mock_api_key):
model = TogetherModel()
response = model.run(invalid_input)

assert response is None


@patch("swarms.models.together_model.requests.post")
def test_run_streaming_enabled(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {
"choices": [{"message": {"content": "Generated response"}}]
}
mock_post.return_value = mock_response

model = TogetherModel(streaming_enabled=True)
task = "What is the color of the object?"
response = model.run(task)

assert response == "Generated response"


@patch("swarms.models.together_model.requests.post")
def test_run_empty_choices(mock_post, mock_api_key):
mock_response = Mock()
mock_response.json.return_value = {"choices": []}
mock_post.return_value = mock_response

model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)

assert response is None


@patch("swarms.models.together_model.requests.post")
def test_run_with_exception(mock_post, mock_api_key):
mock_post.side_effect = Exception("Test exception")

model = TogetherModel()
task = "What is the color of the object?"
response = model.run(task)

assert response is None


def test_init_logging_disabled(monkeypatch):
monkeypatch.setenv("TOGETHER_API_KEY", "mocked-api-key")
model = TogetherModel()
assert model.logging_enabled is False
assert not model.system_prompt

0 comments on commit 4bef09a

Please sign in to comment.