forked from Sinaptik-AI/pandas-ai
Showing 7 changed files with 95 additions and 24 deletions.
@@ -0,0 +1,6 @@
"""
Constants used in the pandasai package.
"""

START_CODE_TAG = "<startCode>"
END_CODE_TAG = "<endCode>"
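
These tags delimit the code the LLM is asked to emit, so the generated snippet can be pulled out of the surrounding prose. A minimal extraction sketch, assuming the import path pandasai.constants (inferred from the module docstring) and an invented helper name:

import re

# Import path is an assumption based on the docstring "Constants used in the pandasai package".
from pandasai.constants import END_CODE_TAG, START_CODE_TAG


def extract_code(llm_output: str) -> str:
    """Return the text between the first <startCode>/<endCode> pair, or "" if none."""
    pattern = re.escape(START_CODE_TAG) + r"(.*?)" + re.escape(END_CODE_TAG)
    match = re.search(pattern, llm_output, re.DOTALL)
    return match.group(1).strip() if match else ""


print(extract_code("Answer: <startCode>df['a'].sum()<endCode>"))
# -> df['a'].sum()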
@@ -0,0 +1,48 @@
""" Base class to implement a new Hugging Face LLM. """

from typing import Optional

import requests

from .base import LLM


class HuggingFaceLLM(LLM):
    """Base class to implement a new Hugging Face LLM."""

    last_prompt: Optional[str] = None
    api_token: str
    _api_url: str = "https://api-inference.huggingface.co/models/"
    _max_retries: int = 3

    @property
    def type(self) -> str:
        return "huggingface-llm"

    def query(self, payload):
        """Query the API"""

        headers = {"Authorization": f"Bearer {self.api_token}"}

        response = requests.post(
            self._api_url, headers=headers, json=payload, timeout=60
        )

        return response.json()[0]["generated_text"]

    def call(self, instruction: str, value: str) -> str:
        """Call the LLM"""

        payload = instruction + value

        # sometimes the API doesn't return a valid response, so we retry passing the
        # output generated from the previous call as the input
        for _i in range(self._max_retries):
            response = self.query({"inputs": payload})
            payload = response
            if response.count("<endCode>") >= 2:
                break

        # replace instruction + value from the inputs to avoid showing it in the output
        output = response.replace(instruction + value, "")
        return output
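
call() concatenates the instruction and the value, posts the result to the Hugging Face Inference API, and retries up to _max_retries times, feeding each response back as the next input until the text contains two <endCode> tags; the original prompt is then stripped from the output. A minimal sketch of how a concrete subclass could reuse this base class, assuming the import path pandasai.llm.base_hf (from the relative import .base_hf in the next file) and using a placeholder class name and model URL:

# Import path and class name are assumptions for illustration, not part of this commit.
from pandasai.llm.base_hf import HuggingFaceLLM


class CodeGen(HuggingFaceLLM):
    """Hypothetical wrapper around another hosted code model (placeholder URL)."""

    _api_url: str = "https://api-inference.huggingface.co/models/Salesforce/codegen-2B-mono"

    def __init__(self, api_token: str):
        self.api_token = api_token

    @property
    def type(self) -> str:
        return "codegen"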
@@ -0,0 +1,27 @@
""" Starcoder LLM """

import os
from typing import Optional

from dotenv import load_dotenv

from ..exceptions import APIKeyNotFoundError
from .base_hf import HuggingFaceLLM

load_dotenv()


class Starcoder(HuggingFaceLLM):
    """Starcoder LLM"""

    api_token: str
    _api_url: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"

    def __init__(self, api_token: Optional[str] = None):
        self.api_token = api_token or os.getenv("HUGGINGFACE_API_KEY")
        if self.api_token is None:
            raise APIKeyNotFoundError("HuggingFace Hub API key is required")

    @property
    def type(self) -> str:
        return "open-assistant"