From 091a33d7e96f04fd60365d3bef51be7edfa61cf4 Mon Sep 17 00:00:00 2001
From: yashbonde
Date: Sat, 14 Dec 2024 21:47:56 +0530
Subject: [PATCH] [0.6.0] Turbo with distributed_chat

---
 docs/changelog.rst              |   6 ++
 pyproject.toml                  |   2 +-
 tuneapi/apis/__init__.py        |   1 +
 tuneapi/apis/model_anthropic.py |  18 ++++
 tuneapi/apis/model_gemini.py    |  20 +++-
 tuneapi/apis/model_groq.py      |  18 ++++
 tuneapi/apis/model_mistral.py   |  30 ++++--
 tuneapi/apis/model_openai.py    |  18 ++++
 tuneapi/apis/model_tune.py      |  18 ++++
 tuneapi/apis/turbo.py           | 172 ++++++++++++++++++++++++++++++++
 tuneapi/types/chats.py          |   3 +
 tuneapi/utils/misc.py           |   1 +
 12 files changed, 298 insertions(+), 9 deletions(-)
 create mode 100644 tuneapi/apis/turbo.py

diff --git a/docs/changelog.rst b/docs/changelog.rst
index e28fcb0..81b78bd 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -7,6 +7,12 @@
 minor versions.
 
 All relevant steps to be taken will be mentioned here.
 
+0.6.0
+-----
+
+- Add ``distributed_chat`` functionality in ``tuneapi.apis.turbo``. Every API class now exposes a
+  ``model.distributed_chat()`` method. This enables **fault tolerant LLM API calls**.
+
 0.5.13
 ------
 
diff --git a/pyproject.toml b/pyproject.toml
index 34004c8..54bd707 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tuneapi"
-version = "0.5.13"
+version = "0.6.0"
 description = "Tune AI APIs."
 authors = ["Frello Technology Private Limited "]
 license = "MIT"
diff --git a/tuneapi/apis/__init__.py b/tuneapi/apis/__init__.py
index a7e368f..9775425 100644
--- a/tuneapi/apis/__init__.py
+++ b/tuneapi/apis/__init__.py
@@ -7,3 +7,4 @@
 from tuneapi.apis.model_groq import Groq
 from tuneapi.apis.model_mistral import Mistral
 from tuneapi.apis.model_gemini import Gemini
+from tuneapi.apis.turbo import distributed_chat
diff --git a/tuneapi/apis/model_anthropic.py b/tuneapi/apis/model_anthropic.py
index 497bad5..df3f8d6 100644
--- a/tuneapi/apis/model_anthropic.py
+++ b/tuneapi/apis/model_anthropic.py
@@ -11,6 +11,7 @@
 import tuneapi.utils as tu
 import tuneapi.types as tt
+from tuneapi.apis.turbo import distributed_chat
 
 
 class Anthropic(tt.ModelInterface):
@@ -236,6 +237,23 @@ def stream_chat(
                 break
         return
 
+    def distributed_chat(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+    ):
+        return distributed_chat(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+        )
+
 
 # helper methods
diff --git a/tuneapi/apis/model_gemini.py b/tuneapi/apis/model_gemini.py
index a7f93f4..efc64ed 100644
--- a/tuneapi/apis/model_gemini.py
+++ b/tuneapi/apis/model_gemini.py
@@ -7,10 +7,11 @@
 import json
 import requests
 
-from typing import Optional, Any, Dict
+from typing import Optional, Any, Dict, List
 
 import tuneapi.utils as tu
 import tuneapi.types as tt
+from tuneapi.apis.turbo import distributed_chat
 
 
 class Gemini(tt.ModelInterface):
@@ -276,3 +277,20 @@ def stream_chat(
                     fn_call["arguments"] = fn_call.pop("args")
                     yield fn_call
                 block_lines = ""
+
+    def distributed_chat(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+    ):
+        return distributed_chat(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+        )
diff --git a/tuneapi/apis/model_groq.py b/tuneapi/apis/model_groq.py
index 7e5f7dd..13f4274 100644
--- a/tuneapi/apis/model_groq.py
+++ b/tuneapi/apis/model_groq.py
@@ -10,6 +10,7 @@
 import tuneapi.utils as tu
 import tuneapi.types as tt
+from tuneapi.apis.turbo import distributed_chat
 
 
 class Groq(tt.ModelInterface):
@@ -190,3 +191,20 @@ def stream_chat(
                 fn_call["arguments"] = tu.from_json(fn_call["arguments"])
                 yield fn_call
         return
+
+    def distributed_chat(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+    ):
+        return distributed_chat(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+        )
diff --git a/tuneapi/apis/model_mistral.py b/tuneapi/apis/model_mistral.py
index fdfce2a..1f141a3 100644
--- a/tuneapi/apis/model_mistral.py
+++ b/tuneapi/apis/model_mistral.py
@@ -10,8 +10,7 @@
 import tuneapi.utils as tu
 import tuneapi.types as tt
 
-from tuneapi.utils import ENV, SimplerTimes as stime, from_json, to_json
-from tuneapi.types import Thread, human, Message
+from tuneapi.apis.turbo import distributed_chat
 
 
 class Mistral(tt.ModelInterface):
@@ -23,7 +22,7 @@ def __init__(
     ):
         self.model_id = id
         self.base_url = base_url
-        self.api_token = ENV.MISTRAL_TOKEN("")
+        self.api_token = tu.ENV.MISTRAL_TOKEN("")
         self.extra_headers = extra_headers
 
     def set_api_token(self, token: str) -> None:
@@ -95,7 +94,7 @@ def _process_input(self, chats, token: Optional[str] = None):
 
     def chat(
         self,
-        chats: Thread | str,
+        chats: tt.Thread | str,
         model: Optional[str] = None,
         max_tokens: int = 1024,
         temperature: float = 1,
@@ -124,7 +123,7 @@ def chat(
 
     def stream_chat(
         self,
-        chats: Thread | str,
+        chats: tt.Thread | str,
         model: Optional[str] = None,
         max_tokens: int = 1024,
         temperature: float = 1,
@@ -135,7 +134,7 @@ def stream_chat(
         extra_headers: Optional[Dict[str, str]] = None,
     ):
         tools = []
-        if isinstance(chats, Thread):
+        if isinstance(chats, tt.Thread):
             tools = [{"type": "function", "function": x.to_dict()} for x in chats.tools]
         headers, messages = self._process_input(chats, token)
         extra_headers = extra_headers or self.extra_headers
@@ -191,6 +190,23 @@ def stream_chat(
             except:
                 break
         if fn_call:
-            fn_call["arguments"] = from_json(fn_call["arguments"])
+            fn_call["arguments"] = tu.from_json(fn_call["arguments"])
             yield fn_call
         return
+
+    def distributed_chat(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+    ):
+        return distributed_chat(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+        )
diff --git a/tuneapi/apis/model_openai.py b/tuneapi/apis/model_openai.py
index 04dcbb6..13bf8b4 100644
--- a/tuneapi/apis/model_openai.py
+++ b/tuneapi/apis/model_openai.py
@@ -11,6 +11,7 @@
 import tuneapi.utils as tu
 import tuneapi.types as tt
+from tuneapi.apis.turbo import distributed_chat
 
 
 class Openai(tt.ModelInterface):
@@ -190,6 +191,23 @@ def stream_chat(
             yield fn_call
         return
 
+    def distributed_chat(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+    ):
+        return distributed_chat(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+        )
+
     def embedding(
         self,
         chats: tt.Thread | List[str] | str,
diff --git a/tuneapi/apis/model_tune.py b/tuneapi/apis/model_tune.py
index d486a4f..8af80af 100644
--- a/tuneapi/apis/model_tune.py
+++ b/tuneapi/apis/model_tune.py
@@ -10,6 +10,7 @@
 import tuneapi.utils as tu
 import tuneapi.types as tt
+from tuneapi.apis.turbo import distributed_chat
 
 
 class TuneModel(tt.ModelInterface):
@@ -217,3 +218,20 @@ def stream_chat(
                 fn_call["arguments"] = tu.from_json(fn_call["arguments"])
                 yield fn_call
         return
+
+    def distributed_chat(
+        self,
+        prompts: List[tt.Thread],
+        post_logic: Optional[callable] = None,
+        max_threads: int = 10,
+        retry: int = 3,
+        pbar=True,
+    ):
+        return distributed_chat(
+            self,
+            prompts=prompts,
+            post_logic=post_logic,
+            max_threads=max_threads,
+            retry=retry,
+            pbar=pbar,
+        )
diff --git a/tuneapi/apis/turbo.py b/tuneapi/apis/turbo.py
new file mode 100644
index 0000000..0114976
--- /dev/null
+++ b/tuneapi/apis/turbo.py
@@ -0,0 +1,172 @@
+# Copyright © 2024- Frello Technology Private Limited
+
+import queue
+import threading
+from tqdm import trange
+from typing import Any, List, Optional
+from dataclasses import dataclass
+
+from tuneapi.types import Thread, ModelInterface, human, system
+
+
+def distributed_chat(
+    model: ModelInterface,
+    prompts: List[Thread],
+    post_logic: Optional[callable] = None,
+    max_threads: int = 10,
+    retry: int = 3,
+    pbar=True,
+):
+    """
+    Distributes multiple chat prompts across a thread pool for parallel processing.
+
+    This function creates a pool of worker threads to process multiple chat prompts concurrently. It handles retry
+    logic for failed requests and maintains the order of responses corresponding to the input prompts.
+
+    Args:
+        model (ModelInterface): The base model instance to clone for each worker thread. Each thread gets its own model
+            instance to ensure thread safety.
+
+        prompts (List[Thread]): A list of chat prompts to process. The order of responses will match the order of these
+            prompts.
+
+        post_logic (Optional[callable], default=None): A function to process each chat response before storing. If None,
+            raw responses are stored. Function signature should be: f(chat_response) -> processed_response
+
+        max_threads (int, default=10): Maximum number of concurrent worker threads. Adjust based on API rate limits and
+            system capabilities.
+
+        retry (int, default=3): Number of retry attempts for failed requests. Set to 0 to disable retries.
+
+        pbar (bool, default=True): Whether to display a progress bar.
+
+    Returns:
+        List[Any]: A list of responses or errors, maintaining the same order as input prompts.
+            Successful responses will be either raw or processed (if post_logic provided).
+            Failed requests (after retries) will contain the last error encountered.
+
+    Raises:
+        ValueError: If max_threads < 1 or retry < 0
+        TypeError: If model is not an instance of ModelInterface
+
+    Example:
+        >>> model = ChatModel(api_token="...")
+        >>> prompts = [
+        ...     Thread([Message("What is 2+2?")]),
+        ...     Thread([Message("What is Python?")])
+        ... ]
+        >>> responses = distributed_chat(model, prompts, max_threads=5)
+        >>> for prompt, response in zip(prompts, responses):
+        ...     print(f"Q: {prompt}\nA: {response}\n")
+
+    Note:
+        - Each worker thread gets its own model instance to prevent sharing state
+        - Progress bar shows both initial processing and retries
+        - The function maintains thread safety through message passing channels
+    """
+    task_channel = queue.Queue()
+    result_channel = queue.Queue()
+
+    # Initialize results container
+    results = [None for _ in range(len(prompts))]
+
+    def worker():
+        while True:
+            try:
+                task: _Task = task_channel.get(timeout=1)
+                if task is None:  # Poison pill
+                    break
+
+                try:
+                    out = task.model.chat(task.prompt)
+                    if post_logic:
+                        out = post_logic(out)
+                    result_channel.put(_Result(task.index, out, True))
+                except Exception as e:
+                    if task.retry_count < retry:
+                        # Create new model instance for retry
+                        nm = model.__class__(
+                            id=model.model_id,
+                            base_url=model.base_url,
+                            extra_headers=model.extra_headers,
+                        )
+                        nm.set_api_token(model.api_token)
+                        # Increment retry count and requeue
+                        task_channel.put(
+                            _Task(task.index, nm, task.prompt, task.retry_count + 1)
+                        )
+                    else:
+                        # If we've exhausted retries, store the error
+                        result_channel.put(_Result(task.index, e, False, e))
+                finally:
+                    task_channel.task_done()
+            except queue.Empty:
+                continue
+
+    # Create and start worker threads
+    workers = []
+    for _ in range(max_threads):
+        t = threading.Thread(target=worker)
+        t.start()
+        workers.append(t)
+
+    # Initialize progress bar
+    _pbar = trange(len(prompts), desc="Processing", unit=" input") if pbar else None
+
+    # Queue initial tasks
+    for i, p in enumerate(prompts):
+        nm = model.__class__(
+            id=model.model_id,
+            base_url=model.base_url,
+            extra_headers=model.extra_headers,
+        )
+        nm.set_api_token(model.api_token)
+        task_channel.put(_Task(i, nm, p))
+
+    # Process results
+    completed = 0
+    while completed < len(prompts):
+        try:
+            result = result_channel.get(timeout=1)
+            results[result.index] = result.data if result.success else result.error
+            if _pbar:
+                _pbar.update(1)
+            completed += 1
+            result_channel.task_done()
+        except queue.Empty:
+            continue
+
+    # Cleanup
+    for _ in workers:
+        task_channel.put(None)  # Send poison pills
+    for w in workers:
+        w.join()
+
+    if _pbar:
+        _pbar.close()
+
+    return results
+
+
+# helpers
+
+
+@dataclass
+class _Task:
+    """Individual Task"""
+
+    index: int
+    model: ModelInterface
+    prompt: Thread
+    retry_count: int = 0
+
+
+@dataclass
+class _Result:
+    """The Result object"""
+
+    index: int
+    data: Any
+    success: bool
+    error: Optional[Exception] = None
diff --git a/tuneapi/types/chats.py b/tuneapi/types/chats.py
index 84f504d..07b15d8 100644
--- a/tuneapi/types/chats.py
+++ b/tuneapi/types/chats.py
@@ -308,6 +308,9 @@ class ModelInterface:
     extra_headers: Dict[str, Any]
     """This is the placeholder for any extra headers to be passed during request"""
 
+    base_url: str
+    """This is the default URL that is pinged for this model. It need not be the exact REST endpoint URL."""
+
     def set_api_token(self, token: str) -> None:
         """This are used to set the API token for the model"""
         raise NotImplementedError("This model has no operation for this.")
diff --git a/tuneapi/utils/misc.py b/tuneapi/utils/misc.py
index 328f19b..b80c2a1 100644
--- a/tuneapi/utils/misc.py
+++ b/tuneapi/utils/misc.py
@@ -4,6 +4,7 @@
 import os
 import base64
 import hashlib
+
 from cryptography.fernet import Fernet
 from cryptography.hazmat.primitives import hashes
 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
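
A minimal usage sketch of the new API follows (illustrative only, not part of the patch). It assumes ``Openai`` can be constructed with its defaults and reads its token from the environment, and that a single-turn ``Thread`` can be built with the ``human`` helper from ``tuneapi.types``; the model choice, prompt texts, and worker counts are placeholders.

# Sketch only: Openai() is assumed to pick up its API token from the environment,
# and Thread(human("...")) is assumed to build a single-turn thread; adjust to the
# actual constructors if they differ.
from tuneapi.apis.model_openai import Openai
from tuneapi.apis.turbo import distributed_chat
from tuneapi.types import Thread, human

model = Openai()

# One single-turn thread per question; responses come back in the same order.
prompts = [Thread(human(f"What is {i} + {i}?")) for i in range(20)]

# Module-level helper: fans the prompts out over a pool of worker threads,
# retrying each failed call up to `retry` times before recording the error.
responses = distributed_chat(
    model,
    prompts,
    post_logic=None,  # optional callable applied to each raw response before storing
    max_threads=5,
    retry=3,
    pbar=True,
)

# Equivalent method added to every model class in this patch:
# responses = model.distributed_chat(prompts, max_threads=5)

for r in responses:
    # Prompts that still fail after all retries come back as the last Exception raised.
    print(r if not isinstance(r, Exception) else f"failed: {r!r}")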