diff --git a/lionagi/integrations/config/__init__.py b/lionagi/integrations/config/__init__.py
index 504075415..cf65f69e2 100644
--- a/lionagi/integrations/config/__init__.py
+++ b/lionagi/integrations/config/__init__.py
@@ -1,4 +1,6 @@
 from .oai_configs import oai_schema
 from .openrouter_configs import openrouter_schema
+from .groq_configs import groq_schema
+from .cerebras_configs import cerebras_schema

-__all__ = ["oai_schema", "openrouter_schema"]
+__all__ = ["oai_schema", "openrouter_schema", "groq_schema", "cerebras_schema"]
diff --git a/lionagi/integrations/config/cerebras_configs.py b/lionagi/integrations/config/cerebras_configs.py
new file mode 100644
index 000000000..557363536
--- /dev/null
+++ b/lionagi/integrations/config/cerebras_configs.py
@@ -0,0 +1,76 @@
+# Default configs for the Cerebras API
+
+API_key_schema = ("CEREBRAS_API_KEY",)
+
+cerebras_chat_llmconfig = {
+    "model": "llama3.1-70b",
+    "frequency_penalty": 0,
+    "max_tokens": None,
+    "num": 1,
+    "presence_penalty": 0,
+    "response_format": {"type": "text"},
+    "seed": None,
+    "stop": None,
+    "stream": False,
+    "temperature": 0.1,
+    "top_p": 1,
+    "tools": None,
+    "tool_choice": "none",
+    "user": None,
+    "logprobs": False,
+    "top_logprobs": None,
+}
+
+cerebras_chat_schema = {
+    "required": [
+        "model",
+        "frequency_penalty",
+        "num",
+        "presence_penalty",
+        "response_format",
+        "temperature",
+        "top_p",
+    ],
+    "optional": [
+        "seed",
+        "stop",
+        "stream",
+        "tools",
+        "tool_choice",
+        "user",
+        "max_tokens",
+        "logprobs",
+        "top_logprobs",
+    ],
+    "input_": "messages",
+    "config": cerebras_chat_llmconfig,
+    "token_encoding_name": "cl100k_base",
+    "token_limit": 128_000,
+    "interval_tokens": 10_000,
+    "interval_requests": 100,
+    "interval": 60,
+}
+
+cerebras_finetune_llmconfig = {
+    "model": "llama3.1-8b",
+    "hyperparameters": {
+        "batch_size": "auto",
+        "learning_rate_multiplier": "auto",
+        "n_epochs": "auto",
+    },
+    "suffix": None,
+    "training_file": None,
+}
+
+cerebras_finetune_schema = {
+    "required": ["model", "training_file"],
+    "optional": ["hyperparameters", "suffix", "validate_file"],
+    "input_": ["training_file"],
+    "config": cerebras_finetune_llmconfig,
+}
+
+cerebras_schema = {
+    "chat/completions": cerebras_chat_schema,
+    "finetune": cerebras_finetune_schema,
+    "API_key_schema": API_key_schema,
+}
diff --git a/lionagi/integrations/config/groq_configs.py b/lionagi/integrations/config/groq_configs.py
new file mode 100644
index 000000000..93d742d37
--- /dev/null
+++ b/lionagi/integrations/config/groq_configs.py
@@ -0,0 +1,76 @@
+# Default configs for the Groq API
+
+API_key_schema = ("GROQ_API_KEY",)
+
+groq_chat_llmconfig = {
+    "model": "llama3-70b-8192",
+    "frequency_penalty": 0,
+    "max_tokens": None,
+    "num": 1,
+    "presence_penalty": 0,
+    "response_format": {"type": "text"},
+    "seed": None,
+    "stop": None,
+    "stream": False,
+    "temperature": 0.1,
+    "top_p": 1,
+    "tools": None,
+    "tool_choice": "none",
+    "user": None,
+    "logprobs": False,
+    "top_logprobs": None,
+}
+
+groq_chat_schema = {
+    "required": [
+        "model",
+        "frequency_penalty",
+        "num",
+        "presence_penalty",
+        "response_format",
+        "temperature",
+        "top_p",
+    ],
+    "optional": [
+        "seed",
+        "stop",
+        "stream",
+        "tools",
+        "tool_choice",
+        "user",
+        "max_tokens",
+        "logprobs",
+        "top_logprobs",
+    ],
+    "input_": "messages",
+    "config": groq_chat_llmconfig,
+    "token_encoding_name": "cl100k_base",
+    "token_limit": 128_000,
+    "interval_tokens": 10_000,
+    "interval_requests": 100,
+    "interval": 60,
+}
+
+groq_finetune_llmconfig = {
+    "model": "mixtral-8x7b-32768",
+    "hyperparameters": {
+        "batch_size": "auto",
+        "learning_rate_multiplier": "auto",
+        "n_epochs": "auto",
+    },
+    "suffix": None,
+    "training_file": None,
+}
+
+groq_finetune_schema = {
+    "required": ["model", "training_file"],
+    "optional": ["hyperparameters", "suffix", "validate_file"],
+    "input_": ["training_file"],
+    "config": groq_finetune_llmconfig,
+}
+
+groq_schema = {
+    "chat/completions": groq_chat_schema,
+    "finetune": groq_finetune_schema,
+    "API_key_schema": API_key_schema,
+}
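Both new config modules follow the same endpoint-keyed layout as the existing oai_configs and openrouter_configs. A minimal sketch of how the new schema dicts can be read once this patch is applied (values come from the defaults above):

    from lionagi.integrations.config.groq_configs import groq_schema

    # Schemas are keyed by endpoint; each entry bundles the default llmconfig
    # under "config" plus token/rate-limit metadata used by the service layer.
    chat_schema = groq_schema["chat/completions"]
    print(chat_schema["config"]["model"])  # llama3-70b-8192
    print(chat_schema["token_limit"])      # 128000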
"mixtral-8x7b-32768", + "hyperparameters": { + "batch_size": "auto", + "learning_rate_multiplier": "auto", + "n_epochs": "auto", + }, + "suffix": None, + "training_file": None, +} + +groq_finetune_schema = { + "required": ["model", "training_file"], + "optional": ["hyperparameters", "suffix", "validate_file"], + "input_": ["training_file"], + "config": groq_finetune_llmconfig, +} + +groq_schema = { + "chat/completions": groq_chat_schema, + "finetune": groq_finetune_schema, + "API_key_schema": API_key_schema, +} diff --git a/lionagi/integrations/provider/_mapping.py b/lionagi/integrations/provider/_mapping.py index 08eaa49be..b60703372 100644 --- a/lionagi/integrations/provider/_mapping.py +++ b/lionagi/integrations/provider/_mapping.py @@ -4,8 +4,12 @@ from .transformers import TransformersService from .litellm import LiteLLMService from .mlx_service import MLXService +from .groq import GroqService +from .cerebras import CerebrasService from lionagi.integrations.config.oai_configs import oai_schema from lionagi.integrations.config.openrouter_configs import openrouter_schema +from lionagi.integrations.config.groq_configs import groq_schema +from lionagi.integrations.config.cerebras_configs import cerebras_schema SERVICE_PROVIDERS_MAPPING = { "openai": { @@ -38,6 +42,16 @@ "schema": {"model": "mlx-community/OLMo-7B-hf-4bit-mlx"}, "default_model": "mlx-community/OLMo-7B-hf-4bit-mlx", }, + "groq": { + "service": GroqService, + "schema": groq_schema, + "default_model": "llama3-70b-8192", + }, + "cerebras": { + "service": CerebrasService, + "schema": cerebras_schema, + "default_model": "llama3.1-70b", + }, } # TODO diff --git a/lionagi/integrations/provider/cerebras.py b/lionagi/integrations/provider/cerebras.py new file mode 100644 index 000000000..eab94ca9b --- /dev/null +++ b/lionagi/integrations/provider/cerebras.py @@ -0,0 +1,141 @@ +# lionagi/integrations/provider/cerebras.py + +from os import getenv +from lionagi.integrations.config.cerebras_configs import cerebras_schema +from lionagi.libs.ln_api import BaseService, PayloadPackage + +allowed_kwargs = [ + "model", + "frequency_penalty", + "n", + "presence_penalty", + "response_format", + "temperature", + "top_p", + "seed", + "stop", + "stream", + "stream_options", + "tools", + "tool_choice", + "user", + "max_tokens", + "logprobs", + "top_logprobs", + "logit_bias", +] + + +class CerebrasService(BaseService): + base_url = "https://api.cerebras.ai/v1" + available_endpoints = ["chat/completions"] + schema = cerebras_schema + key_scheme = "CEREBRAS_API_KEY" + token_encoding_name = "cl100k_base" + + def __init__( + self, + api_key=None, + key_scheme=None, + schema=None, + token_encoding_name: str = "cl100k_base", + **kwargs, + ): + key_scheme = key_scheme or self.key_scheme + super().__init__( + api_key=api_key or getenv(key_scheme), + schema=schema or self.schema, + token_encoding_name=token_encoding_name, + **kwargs, + ) + self.active_endpoint = [] + self.allowed_kwargs = allowed_kwargs + + async def serve(self, input_, endpoint="chat/completions", method="post", **kwargs): + """ + Serves the input using the specified endpoint and method. + + Args: + input_: The input text to be processed. + endpoint: The API endpoint to use for processing. + method: The HTTP method to use for the request. + **kwargs: Additional keyword arguments to pass to the payload creation. + + Returns: + A tuple containing the payload and the completion assistant_response from the API. + + Raises: + ValueError: If the specified endpoint is not supported. 
diff --git a/lionagi/integrations/provider/groq.py b/lionagi/integrations/provider/groq.py
new file mode 100644
index 000000000..4b0af384e
--- /dev/null
+++ b/lionagi/integrations/provider/groq.py
@@ -0,0 +1,141 @@
+# lionagi/integrations/provider/groq.py
+
+from os import getenv
+from lionagi.integrations.config.groq_configs import groq_schema
+from lionagi.libs.ln_api import BaseService, PayloadPackage
+
+allowed_kwargs = [
+    "model",
+    "frequency_penalty",
+    "n",
+    "presence_penalty",
+    "response_format",
+    "temperature",
+    "top_p",
+    "seed",
+    "stop",
+    "stream",
+    "stream_options",
+    "tools",
+    "tool_choice",
+    "user",
+    "max_tokens",
+    "logprobs",
+    "top_logprobs",
+    "logit_bias",
+]
+
+
+class GroqService(BaseService):
+    base_url = "https://api.groq.com/openai/v1"
+    available_endpoints = ["chat/completions"]
+    schema = groq_schema
+    key_scheme = "GROQ_API_KEY"
+    token_encoding_name = "cl100k_base"
+
+    def __init__(
+        self,
+        api_key=None,
+        key_scheme=None,
+        schema=None,
+        token_encoding_name: str = "cl100k_base",
+        **kwargs,
+    ):
+        key_scheme = key_scheme or self.key_scheme
+        super().__init__(
+            api_key=api_key or getenv(key_scheme),
+            schema=schema or self.schema,
+            token_encoding_name=token_encoding_name,
+            **kwargs,
+        )
+        self.active_endpoint = []
+        self.allowed_kwargs = allowed_kwargs
+
+    async def serve(self, input_, endpoint="chat/completions", method="post", **kwargs):
+        """
+        Serves the input using the specified endpoint and method.
+
+        Args:
+            input_: The input text to be processed.
+            endpoint: The API endpoint to use for processing.
+            method: The HTTP method to use for the request.
+            **kwargs: Additional keyword arguments to pass to the payload creation.
+
+        Returns:
+            A tuple containing the payload and the completion assistant_response from the API.
+
+        Raises:
+            ValueError: If the specified endpoint is not supported.
+
+        Examples:
+            >>> service = GroqService(api_key="your_api_key")
+            >>> asyncio.run(service.serve("Hello, world!", "chat/completions"))
+            (payload, completion)
+
+            >>> service = GroqService()
+            >>> asyncio.run(service.serve("Convert this text to speech.", "audio_speech"))
+            ValueError: 'audio_speech' is currently not supported
+        """
+        if endpoint not in self.active_endpoint:
+            await self.init_endpoint(endpoint)
+        if endpoint == "chat/completions":
+            return await self.serve_chat(input_, **kwargs)
+        else:
+            raise ValueError(f"{endpoint} is currently not supported")
+
+    async def serve_chat(self, messages, required_tokens=None, **kwargs):
+        """
+        Serves the chat completion request with the given messages.
+
+        Args:
+            messages: The messages to be included in the chat completion.
+            **kwargs: Additional keyword arguments for payload creation.
+
+        Returns:
+            A tuple containing the payload and the completion assistant_response from the API.
+
+        Raises:
+            Exception: If the API call fails.
+        """
+        if "chat/completions" not in self.active_endpoint:
+            await self.init_endpoint("chat/completions")
+            self.active_endpoint.append("chat/completions")
+
+        msgs = []
+
+        for msg in messages:
+            if isinstance(msg, dict):
+                content = msg.get("content")
+                if isinstance(content, (dict, str)):
+                    msgs.append({"role": msg["role"], "content": content})
+                elif isinstance(content, list):
+                    _content = []
+                    for i in content:
+                        if "text" in i:
+                            _content.append({"type": "text", "text": str(i["text"])})
+                        elif "image_url" in i:
+                            _content.append(
+                                {
+                                    "type": "image_url",
+                                    "image_url": {
+                                        "url": f"{i['image_url'].get('url')}",
+                                        "detail": i["image_url"].get("detail", "low"),
+                                    },
+                                }
+                            )
+                    msgs.append({"role": msg["role"], "content": _content})
+
+        payload = PayloadPackage.chat_completion(
+            msgs,
+            self.endpoints["chat/completions"].config,
+            self.schema["chat/completions"],
+            **kwargs,
+        )
+        try:
+            completion = await self.call_api(
+                payload, "chat/completions", "post", required_tokens=required_tokens
+            )
+            return payload, completion
+        except Exception as e:
+            self.status_tracker.num_tasks_failed += 1
+            raise e
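GroqService mirrors CerebrasService apart from the base URL, schema, and key name; both share the same serve_chat normalization loop, which accepts either plain-string content or the OpenAI-style list-of-parts form. A sketch of the two message shapes it handles (hypothetical values):

    messages = [
        # Plain string content is forwarded unchanged.
        {"role": "user", "content": "Describe the image below."},
        # List content is rebuilt into text / image_url parts.
        {
            "role": "user",
            "content": [
                {"text": "What is shown here?"},
                {"image_url": {"url": "https://example.com/cat.png", "detail": "low"}},
            ],
        },
    ]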
diff --git a/lionagi/integrations/provider/services.py b/lionagi/integrations/provider/services.py
index 15747084a..38ac064a7 100644
--- a/lionagi/integrations/provider/services.py
+++ b/lionagi/integrations/provider/services.py
@@ -134,3 +134,31 @@ def MLX(**kwargs):
         from lionagi.integrations.provider.mlx_service import MLXService

         return MLXService(**kwargs)
+
+    @staticmethod
+    def Groq(**kwargs):
+        """
+        A provider to interact with the Groq API.
+
+        Attributes:
+            model (str): name of the model to use
+            kwargs (Optional[Any]): additional kwargs for calling the model
+        """
+
+        from lionagi.integrations.provider.groq import GroqService
+
+        return GroqService(**kwargs)
+
+    @staticmethod
+    def Cerebras(**kwargs):
+        """
+        A provider to interact with the Cerebras API.
+
+        Attributes:
+            model (str): name of the model to use
+            kwargs (Optional[Any]): additional kwargs for calling the model
+        """
+
+        from lionagi.integrations.provider.cerebras import CerebrasService
+
+        return CerebrasService(**kwargs)
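With the services.py additions, both providers are also reachable through the existing convenience constructors. A sketch (the enclosing class is not shown in this hunk; Services is assumed here to match the MLX entry it extends):

    from lionagi.integrations.provider.services import Services

    groq_service = Services.Groq()          # reads GROQ_API_KEY from the environment
    cerebras_service = Services.Cerebras()  # reads CEREBRAS_API_KEY from the environment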