Skip to content

Commit

Permalink
refactor: move prompt templates to files (#593)
Browse files Browse the repository at this point in the history
* (refactor): define base prompt class as an abstract
* (refactor): attribute `text` renamed to `template` (since it is
  essentially a template)
* (feat): add FileBasedPrompt
  • Loading branch information
nautics889 committed Sep 26, 2023
1 parent 57c5259 commit b9f4c37
Show file tree
Hide file tree
Showing 26 changed files with 201 additions and 208 deletions.
15 changes: 15 additions & 0 deletions assets/prompt-templates/correct_error_prompt.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

You are provided with the following {engine} DataFrames with the following metadata:

{dataframes}

The user asked the following question:
{conversation}

You generated this python code:
{code}

It fails with the following error:
{error_returned}

Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again.
30 changes: 30 additions & 0 deletions assets/prompt-templates/generate_python_code.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

You are provided with the following pandas DataFrames:

{dataframes}

<conversation>
{conversation}
</conversation>

This is the initial python code to be updated:
```python
# TODO import all the dependencies required
{default_import}

def analyze_data(dfs: list[{engine_df_name}]) -> dict:
"""
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.)
4. Output: return a dictionary of:
- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }}
"""
```

Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.

Updated code:
19 changes: 12 additions & 7 deletions docs/custom-prompts.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ To create your custom prompt create a new CustomPromptClass inherited from base

```python
from pandasai import SmartDataframe
from pandasai.prompts import Prompt
from pandasai.prompts import AbstractPrompt


class MyCustomPrompt(AbstractPrompt):
template = """This is your custom text for your prompt with custom {my_custom_value}"""

class MyCustomPrompt(Prompt):
text = """This is your custom text for your prompt with custom {my_custom_value}"""

df = SmartDataframe("data.csv", {
"custom_prompts": {
"generate_python_code": MyCustomPrompt(my_custom_value="my custom value")
"generate_python_code": MyCustomPrompt(
my_custom_value="my custom value")
}
})
```
Expand All @@ -36,15 +39,17 @@ You can directly access the default prompt variables (for example dfs, conversat

```python
from pandasai import SmartDataframe
from pandasai.prompts import Prompt
from pandasai.prompts import AbstractPrompt

class MyCustomPrompt(Prompt):
text = """You are given a dataframe with number if rows equal to {dfs[0].shape[0]} and number of columns equal to {dfs[0].shape[1]}

class MyCustomPrompt(AbstractPrompt):
template = """You are given a dataframe with number if rows equal to {dfs[0].shape[0]} and number of columns equal to {dfs[0].shape[1]}
Here's the conversation:
{conversation}
"""


df = SmartDataframe("data.csv", {
"custom_prompts": {
"generate_python_code": MyCustomPrompt()
Expand Down
4 changes: 2 additions & 2 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import pandas as pd
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake
from .prompts.base import Prompt
from .prompts.base import AbstractPrompt
from .callbacks.base import BaseCallback
from .schemas.df_config import Config
from .helpers.cache import Cache
Expand Down Expand Up @@ -112,7 +112,7 @@ def __init__(
middlewares=None,
custom_whitelisted_dependencies=None,
enable_logging=True,
non_default_prompts: Optional[Dict[str, Type[Prompt]]] = None,
non_default_prompts: Optional[Dict[str, Type[AbstractPrompt]]] = None,
callback: Optional[BaseCallback] = None,
):
"""
Expand Down
2 changes: 1 addition & 1 deletion pandasai/helpers/from_google_sheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get_google_sheet(src) -> list:
cols = row.find_all("td")
clean_row = []
for col in cols:
clean_row.append(col.text)
clean_row.append(col.text)
grid.append(clean_row)
return grid

Expand Down
6 changes: 3 additions & 3 deletions pandasai/llm/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..helpers import load_dotenv

from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
from ..prompts.base import Prompt
from ..prompts.base import AbstractPrompt
from .base import BaseOpenAI

load_dotenv()
Expand Down Expand Up @@ -105,12 +105,12 @@ def _default_params(self) -> Dict[str, Any]:
"""
return {**super()._default_params, "engine": self.engine}

def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the Azure OpenAI LLM.
Args:
instruction (Prompt): A prompt object with instruction for LLM.
instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): Suffix to pass.
Returns:
Expand Down
18 changes: 9 additions & 9 deletions pandasai/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class CustomLLM(BaseOpenAI):
NoCodeFoundError,
)
from ..helpers.openai_info import openai_callback_var
from ..prompts.base import Prompt
from ..prompts.base import AbstractPrompt


class LLM:
Expand Down Expand Up @@ -120,12 +120,12 @@ def _extract_code(self, response: str, separator: str = "```") -> str:
return code

@abstractmethod
def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Execute the LLM with given prompt.
Args:
instruction (Prompt): A prompt object with instruction for LLM.
instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str, optional): Suffix. Defaults to "".
Raises:
Expand All @@ -134,12 +134,12 @@ def call(self, instruction: Prompt, suffix: str = "") -> str:
"""
raise MethodNotImplementedError("Call method has not been implemented")

def generate_code(self, instruction: Prompt) -> str:
def generate_code(self, instruction: AbstractPrompt) -> str:
"""
Generate the code based on the instruction and the given prompt.
Args:
instruction (Prompt): Prompt with instruction for LLM.
instruction (AbstractPrompt): Prompt with instruction for LLM.
Returns:
str: A string of Python code.
Expand Down Expand Up @@ -334,11 +334,11 @@ def query(self, payload) -> str:

return response.json()[0]["generated_text"]

def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
A call method of HuggingFaceLLM class.
Args:
instruction (Prompt): A prompt object with instruction for LLM.
instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): A string representing the suffix to be truncated
from the generated response.
Expand Down Expand Up @@ -429,12 +429,12 @@ def _generate_text(self, prompt: str) -> str:
"""
raise MethodNotImplementedError("method has not been implemented")

def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the Google LLM.
Args:
instruction (Prompt): Instruction to pass.
instruction (AbstractPrompt): Instruction to pass.
suffix (str): Suffix to pass. Defaults to an empty string ("").
Returns:
Expand Down
4 changes: 2 additions & 2 deletions pandasai/llm/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Optional

from ..prompts.base import Prompt
from ..prompts.base import AbstractPrompt
from .base import LLM


Expand All @@ -16,7 +16,7 @@ def __init__(self, output: Optional[str] = None):
if output is not None:
self._output = output

def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
self.last_prompt = instruction.to_string() + suffix
return self._output

Expand Down
14 changes: 7 additions & 7 deletions pandasai/llm/huggingface_text_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .base import LLM
from ..helpers import load_dotenv
from ..prompts.base import Prompt
from ..prompts.base import AbstractPrompt

load_dotenv()

Expand All @@ -14,7 +14,7 @@ class HuggingFaceTextGen(LLM):
top_k: Optional[int] = None
top_p: Optional[float] = 0.8
typical_p: Optional[float] = 0.8
temperature: float = 1E-3 # must be strictly positive
temperature: float = 1e-3 # must be strictly positive
repetition_penalty: Optional[float] = None
truncate: Optional[int] = None
stop_sequences: List[str] = None
Expand All @@ -29,7 +29,7 @@ def __init__(self, inference_server_url: str, **kwargs):
try:
import text_generation

for (key, val) in kwargs.items():
for key, val in kwargs.items():
if key in self.__annotations__:
setattr(self, key, val)

Expand Down Expand Up @@ -60,14 +60,14 @@ def _default_params(self) -> Dict[str, Any]:
"seed": self.seed,
}

def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
prompt = instruction.to_string() + suffix

params = self._default_params
if self.streaming:
completion = ""
for chunk in self.client.generate_stream(prompt, **params):
completion += chunk.text
completion += chunk.text
return completion

res = self.client.generate(prompt, **params)
Expand All @@ -76,8 +76,8 @@ def call(self, instruction: Prompt, suffix: str = "") -> str:
for stop_seq in self.stop_sequences:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
:res.generated_text.index(stop_seq)
]
: res.generated_text.index(stop_seq)
]
return res.generated_text

@property
Expand Down
4 changes: 2 additions & 2 deletions pandasai/llm/langchain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pandasai.prompts.base import Prompt
from pandasai.prompts.base import AbstractPrompt
from .base import LLM


Expand All @@ -13,7 +13,7 @@ class LangchainLLM(LLM):
def __init__(self, langchain_llm):
self._langchain_llm = langchain_llm

def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
prompt = instruction.to_string() + suffix
return self._langchain_llm.predict(prompt)

Expand Down
6 changes: 3 additions & 3 deletions pandasai/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..helpers import load_dotenv

from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
from ..prompts.base import Prompt
from ..prompts.base import AbstractPrompt
from .base import BaseOpenAI

load_dotenv()
Expand Down Expand Up @@ -85,12 +85,12 @@ def _default_params(self) -> Dict[str, Any]:
"model": self.model,
}

def call(self, instruction: Prompt, suffix: str = "") -> str:
def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the OpenAI LLM.
Args:
instruction (Prompt): A prompt object with instruction for LLM.
instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): Suffix to pass.
Raises:
Expand Down
12 changes: 6 additions & 6 deletions pandasai/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .base import Prompt
from .correct_error_prompt import CorrectErrorPrompt
from .generate_python_code import GeneratePythonCodePrompt
from .base import AbstractPrompt
from .correct_error_prompt import CorrectErrorPrompt
from .generate_python_code import GeneratePythonCodePrompt

__all__ = [
"Prompt",
"CorrectErrorPrompt",
"GeneratePythonCodePrompt",
"AbstractPrompt",
"CorrectErrorAbstractPrompt",
"GeneratePythonCodeAbstractPrompt",
]
37 changes: 29 additions & 8 deletions pandasai/prompts/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
""" Base class to implement a new Prompt
In order to better handle the instructions, this prompt module is written.
"""
import os
from abc import ABC, abstractmethod

from pandasai.exceptions import MethodNotImplementedError


class Prompt:
class AbstractPrompt(ABC):
"""Base class to implement a new Prompt"""

text = None
_args = {}

def __init__(self, **kwargs):
Expand Down Expand Up @@ -46,16 +45,38 @@ def _generate_dataframes(self, dfs):

return "\n\n".join(dataframes)

@property
@abstractmethod
def template(self):
...

def set_var(self, var, value):
if var == "dfs":
self._args["dataframes"] = self._generate_dataframes(value)
self._args[var] = value

def to_string(self):
if self.text is None:
raise MethodNotImplementedError

return self.text.format(**self._args)
return self.template.format(**self._args)

def __str__(self):
return self.to_string()


class FileBasedPrompt(AbstractPrompt):
    """A prompt whose template text is stored in a file on disk.

    The template location is supplied either by setting the
    ``_path_to_template`` class attribute in a subclass or by passing a
    ``path_to_template`` keyword argument to the constructor (the keyword
    argument takes precedence).
    """

    # Filesystem path of the template file; typically set by subclasses,
    # optionally overridden per instance via `path_to_template`.
    _path_to_template: str

    def __init__(self, **kwargs):
        # Pop the path override before delegating so the base class only
        # receives template substitution variables.
        if (template_path := kwargs.pop("path_to_template", None)) is not None:
            self._path_to_template = template_path

        super().__init__(**kwargs)

    @property
    def template(self) -> str:
        """Read and return the template text from the backing file.

        Raises:
            FileNotFoundError: if the configured template path does not exist.
        """
        if not os.path.exists(self._path_to_template):
            raise FileNotFoundError(
                f"Unable to find a file with template at '{self._path_to_template}' "
                f"for '{self.__class__.__name__}' prompt."
            )
        # Use an explicit encoding so templates decode identically on every
        # platform (the default is locale-dependent, e.g. cp1252 on Windows).
        with open(self._path_to_template, encoding="utf-8") as fp:
            return fp.read()
Loading

0 comments on commit b9f4c37

Please sign in to comment.