Added Groq and Cerebras API calls #461

Open · wants to merge 18 commits into base: main
4 changes: 3 additions & 1 deletion lionagi/integrations/config/__init__.py
@@ -1,4 +1,6 @@
 from .oai_configs import oai_schema
 from .openrouter_configs import openrouter_schema
+from .groq_configs import groq_schema
+from .cerebras_configs import cerebras_schema

-__all__ = ["oai_schema", "openrouter_schema"]
+__all__ = ["oai_schema", "openrouter_schema", "groq_schema", "cerebras_schema"]
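With this change both new schemas are exported alongside the existing ones, so downstream code can import them from the config package directly. A quick sanity check (hypothetical snippet, not part of the diff):

```python
# Both new schemas should now be importable from the config package root.
from lionagi.integrations.config import groq_schema, cerebras_schema

assert "chat/completions" in groq_schema
assert "chat/completions" in cerebras_schema
```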
76 changes: 76 additions & 0 deletions lionagi/integrations/config/cerebras_configs.py
@@ -0,0 +1,76 @@
# Default configs for the Cerebras API

API_key_schema = ("CEREBRAS_API_KEY",)

cerebras_chat_llmconfig = {
    "model": "llama3.1-70b",
    "frequency_penalty": 0,
    "max_tokens": None,
    "num": 1,
    "presence_penalty": 0,
    "response_format": {"type": "text"},
    "seed": None,
    "stop": None,
    "stream": False,
    "temperature": 0.1,
    "top_p": 1,
    "tools": None,
    "tool_choice": "none",
    "user": None,
    "logprobs": False,
    "top_logprobs": None,
}

cerebras_chat_schema = {
    "required": [
        "model",
        "frequency_penalty",
        "num",
        "presence_penalty",
        "response_format",
        "temperature",
        "top_p",
    ],
    "optional": [
        "seed",
        "stop",
        "stream",
        "tools",
        "tool_choice",
        "user",
        "max_tokens",
        "logprobs",
        "top_logprobs",
    ],
    "input_": "messages",
    "config": cerebras_chat_llmconfig,
    "token_encoding_name": "cl100k_base",
    "token_limit": 128_000,
    "interval_tokens": 10_000,
    "interval_requests": 100,
    "interval": 60,
}

cerebras_finetune_llmconfig = {
    "model": "llama3.1-8b",
    "hyperparameters": {
        "batch_size": "auto",
        "learning_rate_multiplier": "auto",
        "n_epochs": "auto",
    },
    "suffix": None,
    "training_file": None,
}

cerebras_finetune_schema = {
    "required": ["model", "training_file"],
    "optional": ["hyperparameters", "suffix", "validate_file"],
    "input_": ["training_file"],
    "config": cerebras_finetune_llmconfig,
}

cerebras_schema = {
    "chat/completions": cerebras_chat_schema,
    "finetune": cerebras_finetune_schema,
    "API_key_schema": API_key_schema,
}
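For orientation, here is a minimal sketch of how a chat payload could be assembled from this schema. The real assembly happens in `PayloadPackage.chat_completion`; `build_payload` below is a hypothetical helper, not part of the diff:

```python
from lionagi.integrations.config.cerebras_configs import cerebras_chat_schema

def build_payload(messages, **kwargs):
    # Overlay caller kwargs on the schema defaults, then keep only keys the
    # schema declares and drop optional values that were left unset.
    config = {**cerebras_chat_schema["config"], **kwargs}
    allowed = set(cerebras_chat_schema["required"]) | set(cerebras_chat_schema["optional"])
    payload = {k: v for k, v in config.items() if k in allowed and v is not None}
    payload[cerebras_chat_schema["input_"]] = messages
    return payload

payload = build_payload([{"role": "user", "content": "Hello"}], temperature=0.7)
```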
76 changes: 76 additions & 0 deletions lionagi/integrations/config/groq_configs.py
@@ -0,0 +1,76 @@
# Default configs for the Groq API

API_key_schema = ("GROQ_API_KEY",)

groq_chat_llmconfig = {
    "model": "llama3-70b-8192",
    "frequency_penalty": 0,
    "max_tokens": None,
    "num": 1,
    "presence_penalty": 0,
    "response_format": {"type": "text"},
    "seed": None,
    "stop": None,
    "stream": False,
    "temperature": 0.1,
    "top_p": 1,
    "tools": None,
    "tool_choice": "none",
    "user": None,
    "logprobs": False,
    "top_logprobs": None,
}

groq_chat_schema = {
    "required": [
        "model",
        "frequency_penalty",
        "num",
        "presence_penalty",
        "response_format",
        "temperature",
        "top_p",
    ],
    "optional": [
        "seed",
        "stop",
        "stream",
        "tools",
        "tool_choice",
        "user",
        "max_tokens",
        "logprobs",
        "top_logprobs",
    ],
    "input_": "messages",
    "config": groq_chat_llmconfig,
    "token_encoding_name": "cl100k_base",
    "token_limit": 128_000,
    "interval_tokens": 10_000,
    "interval_requests": 100,
    "interval": 60,
}

groq_finetune_llmconfig = {
    "model": "mixtral-8x7b-32768",
    "hyperparameters": {
        "batch_size": "auto",
        "learning_rate_multiplier": "auto",
        "n_epochs": "auto",
    },
    "suffix": None,
    "training_file": None,
}

groq_finetune_schema = {
    "required": ["model", "training_file"],
    "optional": ["hyperparameters", "suffix", "validate_file"],
    "input_": ["training_file"],
    "config": groq_finetune_llmconfig,
}

groq_schema = {
    "chat/completions": groq_chat_schema,
    "finetune": groq_finetune_schema,
    "API_key_schema": API_key_schema,
}
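The `interval_tokens`, `interval_requests`, and `interval` fields describe a rate budget: 10,000 tokens and 100 requests per 60-second window. A minimal sketch of how such a budget could be enforced, assuming a fixed-window interpretation (lionagi's actual limiter lives in its rate-limited service machinery, so this class is purely illustrative):

```python
import time

from lionagi.integrations.config.groq_configs import groq_chat_schema

class FixedWindowLimiter:
    """Illustrative fixed-window limiter driven by the schema's interval fields."""

    def __init__(self, schema):
        self.max_tokens = schema["interval_tokens"]
        self.max_requests = schema["interval_requests"]
        self.window = schema["interval"]
        self._reset()

    def _reset(self):
        self.window_start = time.monotonic()
        self.tokens_used = 0
        self.requests_made = 0

    def allow(self, tokens: int) -> bool:
        # Start a fresh window once the interval has elapsed.
        if time.monotonic() - self.window_start >= self.window:
            self._reset()
        if (self.tokens_used + tokens > self.max_tokens
                or self.requests_made + 1 > self.max_requests):
            return False
        self.tokens_used += tokens
        self.requests_made += 1
        return True

limiter = FixedWindowLimiter(groq_chat_schema)
```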
14 changes: 14 additions & 0 deletions lionagi/integrations/provider/_mapping.py
@@ -4,8 +4,12 @@
 from .transformers import TransformersService
 from .litellm import LiteLLMService
 from .mlx_service import MLXService
+from .groq import GroqService
+from .cerebras import CerebrasService
 from lionagi.integrations.config.oai_configs import oai_schema
 from lionagi.integrations.config.openrouter_configs import openrouter_schema
+from lionagi.integrations.config.groq_configs import groq_schema
+from lionagi.integrations.config.cerebras_configs import cerebras_schema

 SERVICE_PROVIDERS_MAPPING = {
     "openai": {
@@ -38,6 +42,16 @@
"schema": {"model": "mlx-community/OLMo-7B-hf-4bit-mlx"},
"default_model": "mlx-community/OLMo-7B-hf-4bit-mlx",
},
"groq": {
"service": GroqService,
"schema": groq_schema,
"default_model": "llama3-70b-8192",
},
"cerebras": {
"service": CerebrasService,
"schema": cerebras_schema,
"default_model": "llama3.1-70b",
},
}

# TODO
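These entries make both providers resolvable by name. A hedged sketch of how the mapping might be consumed, assuming the caller supplies its own API key; the lookup below follows the dictionary shape shown in the diff, though the wiring inside lionagi itself may differ:

```python
from lionagi.integrations.provider._mapping import SERVICE_PROVIDERS_MAPPING

def resolve_service(provider: str, api_key: str):
    # Look up the provider entry and instantiate its service class,
    # returning it together with the default model recorded in the mapping.
    entry = SERVICE_PROVIDERS_MAPPING[provider]
    service = entry["service"](api_key=api_key)
    return service, entry["default_model"]

service, model = resolve_service("groq", api_key="gsk-...")  # placeholder key
```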
141 changes: 141 additions & 0 deletions lionagi/integrations/provider/cerebras.py
@@ -0,0 +1,141 @@
# lionagi/integrations/provider/cerebras.py

from os import getenv
from lionagi.integrations.config.cerebras_configs import cerebras_schema
from lionagi.libs.ln_api import BaseService, PayloadPackage

allowed_kwargs = [
    "model",
    "frequency_penalty",
    "n",
    "presence_penalty",
    "response_format",
    "temperature",
    "top_p",
    "seed",
    "stop",
    "stream",
    "stream_options",
    "tools",
    "tool_choice",
    "user",
    "max_tokens",
    "logprobs",
    "top_logprobs",
    "logit_bias",
]


class CerebrasService(BaseService):
    base_url = "https://api.cerebras.ai/v1"
    available_endpoints = ["chat/completions"]
    schema = cerebras_schema
    key_scheme = "CEREBRAS_API_KEY"
    token_encoding_name = "cl100k_base"

    def __init__(
        self,
        api_key=None,
        key_scheme=None,
        schema=None,
        token_encoding_name: str = "cl100k_base",
        **kwargs,
    ):
        key_scheme = key_scheme or self.key_scheme
        super().__init__(
            api_key=api_key or getenv(key_scheme),
            schema=schema or self.schema,
            token_encoding_name=token_encoding_name,
            **kwargs,
        )
        self.active_endpoint = []
        self.allowed_kwargs = allowed_kwargs

    async def serve(self, input_, endpoint="chat/completions", method="post", **kwargs):
        """
        Serves the input using the specified endpoint and method.

        Args:
            input_: The input text to be processed.
            endpoint: The API endpoint to use for processing.
            method: The HTTP method to use for the request.
            **kwargs: Additional keyword arguments to pass to the payload creation.

        Returns:
            A tuple containing the payload and the completion assistant_response from the API.

        Raises:
            ValueError: If the specified endpoint is not supported.

        Examples:
            >>> service = CerebrasService(api_key="your_api_key")
            >>> asyncio.run(service.serve("Hello, world!", "chat/completions"))
            (payload, completion)

            >>> service = CerebrasService()
            >>> asyncio.run(service.serve("Convert this text to speech.", "audio_speech"))
            ValueError: 'audio_speech' is currently not supported
        """
        if endpoint not in self.active_endpoint:
            await self.init_endpoint(endpoint)
        if endpoint == "chat/completions":
            return await self.serve_chat(input_, **kwargs)
        else:
            raise ValueError(f"'{endpoint}' is currently not supported")

    async def serve_chat(self, messages, required_tokens=None, **kwargs):
        """
        Serves the chat completion request with the given messages.

        Args:
            messages: The messages to be included in the chat completion.
            required_tokens: Optional token count used for rate limiting.
            **kwargs: Additional keyword arguments for payload creation.

        Returns:
            A tuple containing the payload and the completion assistant_response from the API.

        Raises:
            Exception: If the API call fails.
        """
        if "chat/completions" not in self.active_endpoint:
            await self.init_endpoint("chat/completions")
            self.active_endpoint.append("chat/completions")

        msgs = []

        for msg in messages:
            if isinstance(msg, dict):
                content = msg.get("content")
                if isinstance(content, (dict, str)):
                    msgs.append({"role": msg["role"], "content": content})
                elif isinstance(content, list):
                    _content = []
                    for i in content:
                        if "text" in i:
                            _content.append({"type": "text", "text": str(i["text"])})
                        elif "image_url" in i:
                            _content.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"{i['image_url'].get('url')}",
                                        "detail": i["image_url"].get("detail", "low"),
                                    },
                                }
                            )
                    msgs.append({"role": msg["role"], "content": _content})

        payload = PayloadPackage.chat_completion(
            msgs,
            self.endpoints["chat/completions"].config,
            self.schema["chat/completions"],
            **kwargs,
        )
        try:
            completion = await self.call_api(
                payload, "chat/completions", "post", required_tokens=required_tokens
            )
            return payload, completion
        except Exception as e:
            self.status_tracker.num_tasks_failed += 1
            raise e
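End to end, the service is used much as the docstring examples suggest. A minimal async sketch, assuming `CEREBRAS_API_KEY` is set in the environment and that `serve` receives a list of message dicts (which is what `serve_chat` iterates over):

```python
import asyncio

from lionagi.integrations.provider.cerebras import CerebrasService

async def main():
    service = CerebrasService()  # reads CEREBRAS_API_KEY from the environment
    messages = [{"role": "user", "content": "Hello, world!"}]
    payload, completion = await service.serve(messages, "chat/completions")
    print(completion)

asyncio.run(main())
```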