Skip to content

Commit

Permalink
feat: add Starcoder LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed May 5, 2023
1 parent 60fd835 commit abf9ab1
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 24 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ As an alternative, you can also pass the environment variables directly to the c
# OpenAI
llm = OpenAI(api_token="YOUR_OPENAI_API_KEY")

# OpenAssistant
llm = OpenAssistant(api_token="YOUR_HF_API_KEY")
# Starcoder
llm = Starcoder(api_token="YOUR_HF_API_KEY")
```

## License
Expand Down
9 changes: 7 additions & 2 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pandas as pd

from .constants import END_CODE_TAG, START_CODE_TAG
from .exceptions import LLMNotFoundError
from .helpers.notebook import Notebook
from .llm.base import LLM
Expand All @@ -20,7 +21,7 @@ class PandasAI:
This is the result of `print(df.head({rows_to_display}))`:
{df_head}.
Return the python code (do not import anything) and make sure to prefix the python code with <startCode> exactly and suffix the code with <endCode> exactly
Return the python code (do not import anything) and make sure to prefix the python code with {START_CODE_TAG} exactly and suffix the code with {END_CODE_TAG} exactly
to get the answer to the following question :
"""
_response_instruction: str = """
Expand All @@ -38,7 +39,7 @@ class PandasAI:
and this fails with the following error:
{error_returned}
Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error.
Make sure to prefix the python code with <startCode> exactly and suffix the code with <endCode> exactly.
Make sure to prefix the python code with {START_CODE_TAG} exactly and suffix the code with {END_CODE_TAG} exactly.
"""
_llm: LLM
_verbose: bool = False
Expand Down Expand Up @@ -96,13 +97,17 @@ def run(
self._task_instruction.format(
df_head=data_frame.head(rows_to_display),
rows_to_display=rows_to_display,
START_CODE_TAG=START_CODE_TAG,
END_CODE_TAG=END_CODE_TAG,
),
prompt,
)
self._original_instruction_and_prompt = (
self._task_instruction.format(
df_head=data_frame.head(rows_to_display),
rows_to_display=rows_to_display,
START_CODE_TAG=START_CODE_TAG,
END_CODE_TAG=END_CODE_TAG,
)
+ prompt
)
Expand Down
6 changes: 6 additions & 0 deletions pandasai/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
Constants used in the pandasai package.
"""

START_CODE_TAG = "<startCode>"
END_CODE_TAG = "<endCode>"
3 changes: 2 additions & 1 deletion pandasai/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import astor

from ..constants import END_CODE_TAG, START_CODE_TAG
from ..exceptions import (
APIKeyNotFoundError,
MethodNotImplementedError,
Expand Down Expand Up @@ -84,7 +85,7 @@ def _extract_code(self, response: str, separator: str = "```") -> str:
code = response
if len(response.split(separator)) > 1:
code = response.split(separator)[1]
match = re.search(r"<startCode>(.*?)<endCode>", code, re.DOTALL)
match = re.search(rf"{START_CODE_TAG}(.*){END_CODE_TAG}", code, re.DOTALL)
if match:
code = match.group(1).strip()
code = self._polish_code(code)
Expand Down
48 changes: 48 additions & 0 deletions pandasai/llm/base_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
""" Base class to implement a new Hugging Face LLM. """

from typing import Optional

import requests

from ..constants import END_CODE_TAG
from .base import LLM


class HuggingFaceLLM(LLM):
    """Base class to implement a new Hugging Face LLM.

    Subclasses point ``_api_url`` at a concrete model on the Hugging Face
    Inference API and supply an ``api_token``.
    """

    # Prompt sent in the most recent call, if any.
    last_prompt: Optional[str] = None
    # Hugging Face Hub API token, sent as a Bearer credential.
    api_token: str
    # Base Inference API URL; subclasses append the model path,
    # e.g. ".../models/bigcode/starcoder".
    _api_url: str = "https://api-inference.huggingface.co/models/"
    # Maximum number of completion rounds attempted in `call`.
    _max_retries: int = 3

    @property
    def type(self) -> str:
        """Return the identifier of this LLM family."""
        return "huggingface-llm"

    def query(self, payload) -> str:
        """Query the Inference API and return the generated text.

        NOTE(review): if the API responds with an error object instead of a
        list of generations, the indexing below raises (TypeError/KeyError)
        rather than surfacing the error message — confirm desired behavior.
        """

        headers = {"Authorization": f"Bearer {self.api_token}"}

        response = requests.post(
            self._api_url, headers=headers, json=payload, timeout=60
        )

        return response.json()[0]["generated_text"]

    def call(self, instruction: str, value: str) -> str:
        """Call the LLM and return the generated text without the prompt.

        Args:
            instruction: Task instruction prepended to the user value.
            value: The user prompt appended to the instruction.

        Returns:
            The model output with the echoed input removed.
        """

        payload = instruction + value

        # Sometimes the API doesn't return a valid (complete) response, so we
        # retry, passing the output generated by the previous call back in as
        # the input to continue the completion.
        for _i in range(self._max_retries):
            response = self.query({"inputs": payload})
            payload = response
            # Use the shared END_CODE_TAG constant (was a hardcoded
            # "<endCode>" literal, inconsistent with constants.py/base.py).
            # Requiring two occurrences presumably counts one from the echoed
            # instruction plus one closing the generated code — TODO confirm.
            if response.count(END_CODE_TAG) >= 2:
                break

        # Replace instruction + value from the inputs to avoid showing the
        # echoed prompt in the output.
        output = response.replace(instruction + value, "")
        return output
22 changes: 3 additions & 19 deletions pandasai/llm/open_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,29 @@
import os
from typing import Optional

import requests
from dotenv import load_dotenv

from ..exceptions import APIKeyNotFoundError
from .base import LLM
from .base_hf import HuggingFaceLLM

load_dotenv()


class OpenAssistant(LLM):
class OpenAssistant(HuggingFaceLLM):
"""Open Assistant LLM"""

api_token: str
_api_url: str = (
"https://api-inference.huggingface.co/models/"
"OpenAssistant/oasst-sft-1-pythia-12b"
)
_max_retries: int = 10

def __init__(self, api_token: Optional[str] = None):
self.api_token = api_token or os.getenv("HUGGINGFACE_API_KEY")
if self.api_token is None:
raise APIKeyNotFoundError("HuggingFace Hub API key is required")

def query(self, payload):
"""Query the API"""

headers = {"Authorization": f"Bearer {self.api_token}"}

response = requests.post(
self._api_url, headers=headers, json=payload, timeout=60
)
return response.json()

def call(self, instruction: str, value: str) -> str:
output = self.query(
{"inputs": "<|prompter|>" + instruction + value + "<|endoftext|>"}
)
return output[0]["generated_text"]

@property
def type(self) -> str:
return "open-assistant"
27 changes: 27 additions & 0 deletions pandasai/llm/starcoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
""" Starcoder LLM """

import os
from typing import Optional

from dotenv import load_dotenv

from ..exceptions import APIKeyNotFoundError
from .base_hf import HuggingFaceLLM

load_dotenv()


class Starcoder(HuggingFaceLLM):
    """Starcoder LLM (bigcode/starcoder via the Hugging Face Inference API)."""

    # Hugging Face Hub API token used to authenticate requests.
    api_token: str
    _api_url: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"

    def __init__(self, api_token: Optional[str] = None):
        """Initialize the Starcoder LLM.

        Args:
            api_token: Hugging Face Hub API key. Falls back to the
                ``HUGGINGFACE_API_KEY`` environment variable when omitted.

        Raises:
            APIKeyNotFoundError: If no API key is supplied or found in the
                environment.
        """
        self.api_token = api_token or os.getenv("HUGGINGFACE_API_KEY")
        if self.api_token is None:
            raise APIKeyNotFoundError("HuggingFace Hub API key is required")

    @property
    def type(self) -> str:
        """Return the identifier of this LLM.

        Bug fix: previously returned "open-assistant", copy-pasted from the
        OpenAssistant class; this class is Starcoder.
        """
        return "starcoder"

0 comments on commit abf9ab1

Please sign in to comment.