diff --git a/assets/prompt-templates/correct_error_prompt.tmpl b/assets/prompt-templates/correct_error_prompt.tmpl
new file mode 100644
index 000000000..c73687279
--- /dev/null
+++ b/assets/prompt-templates/correct_error_prompt.tmpl
@@ -0,0 +1,15 @@
+
+You are provided with the following {engine} DataFrames with the following metadata:
+
+{dataframes}
+
+The user asked the following question:
+{conversation}
+
+You generated this python code:
+{code}
+
+It fails with the following error:
+{error_returned}
+
+Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again.
diff --git a/assets/prompt-templates/generate_python_code.tmpl b/assets/prompt-templates/generate_python_code.tmpl
new file mode 100644
index 000000000..10a057c58
--- /dev/null
+++ b/assets/prompt-templates/generate_python_code.tmpl
@@ -0,0 +1,30 @@
+
+You are provided with the following pandas DataFrames:
+
+{dataframes}
+
+
+{conversation}
+
+
+This is the initial python code to be updated:
+```python
+# TODO import all the dependencies required
+{default_import}
+
+def analyze_data(dfs: list[{engine_df_name}]) -> dict:
+ """
+ Analyze the data
+ 1. Prepare: Preprocessing and cleaning data if necessary
+ 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
+ 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.)
+ 4. Output: return a dictionary of:
+ - type (possible values "text", "number", "dataframe", "plot")
+ - value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
+ Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }}
+ """
+```
+
+Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
+
+Updated code:
diff --git a/docs/custom-prompts.md b/docs/custom-prompts.md
index 9a64a3968..b375ada83 100644
--- a/docs/custom-prompts.md
+++ b/docs/custom-prompts.md
@@ -16,14 +16,38 @@ To create your custom prompt create a new CustomPromptClass inherited from base
```python
from pandasai import SmartDataframe
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt
+
+
+class MyCustomPrompt(AbstractPrompt):
+ template = """This is your custom text for your prompt with custom {my_custom_value}"""
-class MyCustomPrompt(Prompt):
-    text = """This is your custom text for your prompt with custom {my_custom_value}"""
+
df = SmartDataframe("data.csv", {
"custom_prompts": {
- "generate_python_code": MyCustomPrompt(my_custom_value="my custom value")
+ "generate_python_code": MyCustomPrompt(
+ my_custom_value="my custom value")
}
})
```
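+
+With the new `FileBasedPrompt` base class you can also keep the prompt text in
+a template file instead of an inline string. A minimal sketch (the template
+path is an example):
+
+```python
+from pandasai import SmartDataframe
+from pandasai.prompts.base import FileBasedPrompt
+
+
+class MyFilePrompt(FileBasedPrompt):
+    _path_to_template = "path/to/my_prompt.tmpl"
+
+
+df = SmartDataframe("data.csv", {
+    "custom_prompts": {
+        "generate_python_code": MyFilePrompt()
+    }
+})
+```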
@@ -36,15 +39,17 @@ You can directly access the default prompt variables (for example dfs, conversat
```python
from pandasai import SmartDataframe
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt
-class MyCustomPrompt(Prompt):
- text = """You are given a dataframe with number if rows equal to {dfs[0].shape[0]} and number of columns equal to {dfs[0].shape[1]}
+
+class MyCustomPrompt(AbstractPrompt):
+ template = """You are given a dataframe with number if rows equal to {dfs[0].shape[0]} and number of columns equal to {dfs[0].shape[1]}
Here's the conversation:
{conversation}
"""
+
df = SmartDataframe("data.csv", {
"custom_prompts": {
"generate_python_code": MyCustomPrompt()
diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index 6a63677b1..15d6ea229 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -40,7 +40,7 @@
import pandas as pd
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake
-from .prompts.base import Prompt
+from .prompts.base import AbstractPrompt
from .callbacks.base import BaseCallback
from .schemas.df_config import Config
from .helpers.cache import Cache
@@ -112,7 +112,7 @@ def __init__(
middlewares=None,
custom_whitelisted_dependencies=None,
enable_logging=True,
- non_default_prompts: Optional[Dict[str, Type[Prompt]]] = None,
+ non_default_prompts: Optional[Dict[str, Type[AbstractPrompt]]] = None,
callback: Optional[BaseCallback] = None,
):
"""
diff --git a/pandasai/llm/azure_openai.py b/pandasai/llm/azure_openai.py
index 32c543a82..ffe722146 100644
--- a/pandasai/llm/azure_openai.py
+++ b/pandasai/llm/azure_openai.py
@@ -18,7 +18,7 @@
from ..helpers import load_dotenv
from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
-from ..prompts.base import Prompt
+from ..prompts.base import AbstractPrompt
from .base import BaseOpenAI
load_dotenv()
@@ -105,12 +105,12 @@ def _default_params(self) -> Dict[str, Any]:
"""
return {**super()._default_params, "engine": self.engine}
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the Azure OpenAI LLM.
Args:
- instruction (Prompt): A prompt object with instruction for LLM.
+ instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): Suffix to pass.
Returns:
diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py
index 2b6fcad9a..5b2e72012 100644
--- a/pandasai/llm/base.py
+++ b/pandasai/llm/base.py
@@ -29,7 +29,7 @@ class CustomLLM(BaseOpenAI):
NoCodeFoundError,
)
from ..helpers.openai_info import openai_callback_var
-from ..prompts.base import Prompt
+from ..prompts.base import AbstractPrompt
class LLM:
@@ -120,12 +120,12 @@ def _extract_code(self, response: str, separator: str = "```") -> str:
return code
@abstractmethod
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Execute the LLM with given prompt.
Args:
- instruction (Prompt): A prompt object with instruction for LLM.
+ instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str, optional): Suffix. Defaults to "".
Raises:
@@ -134,12 +134,12 @@ def call(self, instruction: Prompt, suffix: str = "") -> str:
"""
raise MethodNotImplementedError("Call method has not been implemented")
- def generate_code(self, instruction: Prompt) -> str:
+ def generate_code(self, instruction: AbstractPrompt) -> str:
"""
Generate the code based on the instruction and the given prompt.
Args:
- instruction (Prompt): Prompt with instruction for LLM.
+ instruction (AbstractPrompt): Prompt with instruction for LLM.
Returns:
str: A string of Python code.
@@ -334,11 +334,11 @@ def query(self, payload) -> str:
return response.json()[0]["generated_text"]
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
A call method of HuggingFaceLLM class.
Args:
- instruction (Prompt): A prompt object with instruction for LLM.
+ instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): A string representing the suffix to be truncated
from the generated response.
@@ -429,12 +429,12 @@ def _generate_text(self, prompt: str) -> str:
"""
raise MethodNotImplementedError("method has not been implemented")
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the Google LLM.
Args:
- instruction (Prompt): Instruction to pass.
+ instruction (AbstractPrompt): Instruction to pass.
suffix (str): Suffix to pass. Defaults to an empty string ("").
Returns:
diff --git a/pandasai/llm/fake.py b/pandasai/llm/fake.py
index f92166290..8be3b9b98 100644
--- a/pandasai/llm/fake.py
+++ b/pandasai/llm/fake.py
@@ -2,7 +2,7 @@
from typing import Optional
-from ..prompts.base import Prompt
+from ..prompts.base import AbstractPrompt
from .base import LLM
@@ -16,7 +16,7 @@ def __init__(self, output: Optional[str] = None):
if output is not None:
self._output = output
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
self.last_prompt = instruction.to_string() + suffix
return self._output
diff --git a/pandasai/llm/huggingface_text_gen.py b/pandasai/llm/huggingface_text_gen.py
index 4bbbbbbbf..fd14fc2de 100644
--- a/pandasai/llm/huggingface_text_gen.py
+++ b/pandasai/llm/huggingface_text_gen.py
@@ -2,7 +2,7 @@
from .base import LLM
from ..helpers import load_dotenv
-from ..prompts.base import Prompt
+from ..prompts.base import AbstractPrompt
load_dotenv()
@@ -14,7 +14,7 @@ class HuggingFaceTextGen(LLM):
top_k: Optional[int] = None
top_p: Optional[float] = 0.8
typical_p: Optional[float] = 0.8
- temperature: float = 1E-3 # must be strictly positive
+ temperature: float = 1e-3 # must be strictly positive
repetition_penalty: Optional[float] = None
truncate: Optional[int] = None
stop_sequences: List[str] = None
@@ -29,7 +29,7 @@ def __init__(self, inference_server_url: str, **kwargs):
try:
import text_generation
- for (key, val) in kwargs.items():
+ for key, val in kwargs.items():
if key in self.__annotations__:
setattr(self, key, val)
@@ -60,14 +60,14 @@ def _default_params(self) -> Dict[str, Any]:
"seed": self.seed,
}
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
prompt = instruction.to_string() + suffix
params = self._default_params
if self.streaming:
completion = ""
for chunk in self.client.generate_stream(prompt, **params):
completion += chunk.text
return completion
res = self.client.generate(prompt, **params)
@@ -76,8 +76,8 @@ def call(self, instruction: Prompt, suffix: str = "") -> str:
for stop_seq in self.stop_sequences:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
- :res.generated_text.index(stop_seq)
- ]
+ : res.generated_text.index(stop_seq)
+ ]
return res.generated_text
@property
diff --git a/pandasai/llm/langchain.py b/pandasai/llm/langchain.py
index 7364c54d4..d6e74a97e 100644
--- a/pandasai/llm/langchain.py
+++ b/pandasai/llm/langchain.py
@@ -1,4 +1,4 @@
-from pandasai.prompts.base import Prompt
+from pandasai.prompts.base import AbstractPrompt
from .base import LLM
@@ -13,7 +13,7 @@ class LangchainLLM(LLM):
def __init__(self, langchain_llm):
self._langchain_llm = langchain_llm
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
prompt = instruction.to_string() + suffix
return self._langchain_llm.predict(prompt)
diff --git a/pandasai/llm/openai.py b/pandasai/llm/openai.py
index 8bc7d779c..9a6fb6066 100644
--- a/pandasai/llm/openai.py
+++ b/pandasai/llm/openai.py
@@ -15,7 +15,7 @@
from ..helpers import load_dotenv
from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
-from ..prompts.base import Prompt
+from ..prompts.base import AbstractPrompt
from .base import BaseOpenAI
load_dotenv()
@@ -85,12 +85,12 @@ def _default_params(self) -> Dict[str, Any]:
"model": self.model,
}
- def call(self, instruction: Prompt, suffix: str = "") -> str:
+ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str:
"""
Call the OpenAI LLM.
Args:
- instruction (Prompt): A prompt object with instruction for LLM.
+ instruction (AbstractPrompt): A prompt object with instruction for LLM.
suffix (str): Suffix to pass.
Raises:
diff --git a/pandasai/prompts/__init__.py b/pandasai/prompts/__init__.py
index 7d37a6699..5fc9f4a87 100644
--- a/pandasai/prompts/__init__.py
+++ b/pandasai/prompts/__init__.py
@@ -1,9 +1,9 @@
-from .base import Prompt
+from .base import AbstractPrompt
from .correct_error_prompt import CorrectErrorPrompt
from .generate_python_code import GeneratePythonCodePrompt
__all__ = [
- "Prompt",
- "CorrectErrorPrompt",
- "GeneratePythonCodePrompt",
+ "AbstractPrompt",
+ "CorrectErrorAbstractPrompt",
+ "GeneratePythonCodeAbstractPrompt",
]
diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py
index f53ac9082..93d0f4710 100644
--- a/pandasai/prompts/base.py
+++ b/pandasai/prompts/base.py
@@ -1,14 +1,13 @@
""" Base class to implement a new Prompt
In order to better handle the instructions, this prompt module is written.
"""
+import os
+from abc import ABC, abstractmethod
-from pandasai.exceptions import MethodNotImplementedError
-
-class Prompt:
+class AbstractPrompt(ABC):
"""Base class to implement a new Prompt"""
- text = None
_args = {}
def __init__(self, **kwargs):
@@ -46,16 +45,49 @@ def _generate_dataframes(self, dfs):
return "\n\n".join(dataframes)
+ @property
+ @abstractmethod
+ def template(self):
+ ...
+
def set_var(self, var, value):
if var == "dfs":
self._args["dataframes"] = self._generate_dataframes(value)
self._args[var] = value
def to_string(self):
- if self.text is None:
- raise MethodNotImplementedError
-
- return self.text.format(**self._args)
+ return self.template.format(**self._args)
def __str__(self):
return self.to_string()
+
+
+class FileBasedPrompt(AbstractPrompt):
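+    """Prompt whose template is read from a file.
+
+    Sketch of intended use: a concrete subclass sets `_path_to_template`
+    (or the caller passes `path_to_template` to the constructor):
+
+        class MyPrompt(FileBasedPrompt):
+            _path_to_template = "assets/prompt-templates/my_prompt.tmpl"
+
+    Relative paths are resolved against the current working directory.
+    """
+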
+ _path_to_template: str
+
+ def __init__(self, **kwargs):
+ if (template_path := kwargs.pop("path_to_template", None)) is not None:
+ self._path_to_template = template_path
+
+ super().__init__(**kwargs)
+
+ @property
+ def template(self):
+ if not os.path.exists(self._path_to_template):
+ raise FileNotFoundError(
+ f"Unable to find a file with template at '{self._path_to_template}' "
+ f"for '{self.__class__.__name__}' prompt."
+ )
+ with open(self._path_to_template) as fp:
+ return fp.read()
diff --git a/pandasai/prompts/correct_error_prompt.py b/pandasai/prompts/correct_error_prompt.py
index 3a961761c..07b6ee8b5 100644
--- a/pandasai/prompts/correct_error_prompt.py
+++ b/pandasai/prompts/correct_error_prompt.py
@@ -16,25 +16,10 @@
Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again.
""" # noqa: E501
-from .base import Prompt
+from .base import FileBasedPrompt
-class CorrectErrorPrompt(Prompt):
+class CorrectErrorPrompt(FileBasedPrompt):
"""Prompt to Correct Python code on Error"""
- text: str = """
-You are provided with the following {engine} DataFrames with the following metadata:
-
-{dataframes}
-
-The user asked the following question:
-{conversation}
-
-You generated this python code:
-{code}
-
-It fails with the following error:
-{error_returned}
-
-Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again.
-""" # noqa: E501
+ _path_to_template = "assets/prompt-templates/correct_error_prompt.tmpl"
diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py
index baa39e3ad..5c369b8cf 100644
--- a/pandasai/prompts/generate_python_code.py
+++ b/pandasai/prompts/generate_python_code.py
@@ -33,47 +33,25 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict:
""" # noqa: E501
-from .base import Prompt
+from .base import FileBasedPrompt
-class GeneratePythonCodePrompt(Prompt):
+class GeneratePythonCodePrompt(FileBasedPrompt):
"""Prompt to generate Python code"""
- text: str = """
-You are provided with the following pandas DataFrames:
-
-{dataframes}
-
-
-{conversation}
-
-
-This is the initial python code to be updated:
-```python
-# TODO import all the dependencies required
-{default_import}
+ _path_to_template = "assets/prompt-templates/generate_python_code.tmpl"
-def analyze_data(dfs: list[{engine_df_name}]) -> dict:
- \"\"\"
- Analyze the data
- 1. Prepare: Preprocessing and cleaning data if necessary
- 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
- 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.)
- 4. Output: return a dictionary of:
- - type (possible values "text", "number", "dataframe", "plot")
- - value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
- Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }}
- \"\"\"
-```
-
-Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
-
-Updated code:
-""" # noqa: E501
-
- def __init__(self):
+ def __init__(self, **kwargs):
default_import = "import pandas as pd"
engine_df_name = "pd.DataFrame"
self.set_var("default_import", default_import)
self.set_var("engine_df_name", engine_df_name)
+
+ super().__init__(**kwargs)
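+
+
+# Example (sketch): the bundled template can be swapped out per instance via
+# the `path_to_template` keyword argument accepted by FileBasedPrompt, e.g.
+#
+#     GeneratePythonCodePrompt(path_to_template="my/custom/template.tmpl")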
diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py
index 7ae01e5ed..9cac86650 100644
--- a/pandasai/smart_datalake/__init__.py
+++ b/pandasai/smart_datalake/__init__.py
@@ -32,9 +32,9 @@
from ..helpers.memory import Memory
from ..schemas.df_config import Config
from ..config import load_config
-from ..prompts.base import Prompt
+from ..prompts.base import AbstractPrompt
from ..prompts.correct_error_prompt import CorrectErrorPrompt
from ..prompts.generate_python_code import GeneratePythonCodePrompt
from typing import Union, List, Any, Type, Optional
from ..helpers.code_manager import CodeManager
from ..middlewares.base import Middleware
@@ -204,20 +204,20 @@ def _is_running_in_console(self) -> bool:
def _get_prompt(
self,
key: str,
- default_prompt: Type[Prompt],
+ default_prompt: Type[AbstractPrompt],
default_values: Optional[dict] = None,
- ) -> Prompt:
+ ) -> AbstractPrompt:
"""
Return a prompt by key.
Args:
key (str): The key of the prompt
- default_prompt (Type[Prompt]): The default prompt to use
+ default_prompt (Type[AbstractPrompt]): The default prompt to use
default_values (Optional[dict], optional): The default values to use for the
prompt. Defaults to None.
Returns:
- Prompt: The prompt
+ AbstractPrompt: The prompt
"""
if default_values is None:
default_values = {}
diff --git a/tests/connectors/test_base.py b/tests/connectors/test_base.py
index f80ba5b11..8533dd91c 100644
--- a/tests/connectors/test_base.py
+++ b/tests/connectors/test_base.py
@@ -14,7 +14,6 @@ def __init__(self, host, port, database, table):
# Mock subclass of BaseConnector for testing
class MockConnector(BaseConnector):
-
def _load_connector_config(self, config: BaseConnectorConfig):
pass
diff --git a/tests/llms/test_base_hf.py b/tests/llms/test_base_hf.py
index c09430c32..1965a7cd9 100644
--- a/tests/llms/test_base_hf.py
+++ b/tests/llms/test_base_hf.py
@@ -4,7 +4,7 @@
import requests
from pandasai.llm.base import HuggingFaceLLM
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt
class TestBaseHfLLM:
@@ -16,10 +16,10 @@ def api_response(self):
@pytest.fixture
def prompt(self):
- class MockPrompt(Prompt):
- text: str = "instruction"
+        class MockPrompt(AbstractPrompt):
+            template: str = "instruction"
return MockPrompt()
def test_type(self):
assert HuggingFaceLLM(api_token="test_token").type == "huggingface-llm"
@@ -62,10 +62,10 @@ def test_call(self, mocker, prompt):
def test_call_removes_original_prompt(self, mocker):
huggingface = HuggingFaceLLM(api_token="test_token")
- class MockPrompt(Prompt):
- text: str = "instruction "
+        class MockPrompt(AbstractPrompt):
+            template: str = "instruction "
instruction = MockPrompt()
suffix = "suffix "
mocker.patch.object(
diff --git a/tests/llms/test_google_palm.py b/tests/llms/test_google_palm.py
index 1d03534fa..91162819b 100644
--- a/tests/llms/test_google_palm.py
+++ b/tests/llms/test_google_palm.py
@@ -6,7 +6,7 @@
from pandasai.exceptions import APIKeyNotFoundError
from pandasai.llm import GooglePalm
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt
class MockedCompletion:
@@ -19,10 +19,10 @@ class TestGooglePalm:
@pytest.fixture
def prompt(self):
- class MockPrompt(Prompt):
- text: str = "Hello"
+        class MockPrompt(AbstractPrompt):
+            template: str = "Hello"
return MockPrompt()
def test_type_without_token(self):
with pytest.raises(APIKeyNotFoundError):
diff --git a/tests/llms/test_huggingface_text_gen.py b/tests/llms/test_huggingface_text_gen.py
index 1407662b3..1925b1561 100644
--- a/tests/llms/test_huggingface_text_gen.py
+++ b/tests/llms/test_huggingface_text_gen.py
@@ -1,10 +1,10 @@
"""Unit tests for the LLaMa2TextGen LLM class"""
-from pandasai import Prompt
+from pandasai import AbstractPrompt
from pandasai.llm import HuggingFaceTextGen
-class MockPrompt(Prompt):
- text: str = "instruction."
+class MockPrompt(AbstractPrompt):
+    template: str = "instruction."
class MockResponse:
@@ -19,9 +19,8 @@ class TestHuggingFaceTextGen:
def test_type_with_token(self):
assert (
- HuggingFaceTextGen(
- inference_server_url="http://127.0.0.1:8080"
- ).type == "huggingface-text-generation"
+ HuggingFaceTextGen(inference_server_url="http://127.0.0.1:8080").type
+ == "huggingface-text-generation"
)
def test_params_setting(self):
@@ -30,7 +29,7 @@ def test_params_setting(self):
max_new_tokens=1024,
top_p=0.8,
typical_p=0.8,
- temperature=1E-3,
+ temperature=1e-3,
stop_sequences=["\n"],
seed=0,
do_sample=False,
@@ -53,11 +52,9 @@ def test_completion(self, mocker):
expected_text = "This is the generated text."
tgi_mock.return_value = MockResponse(expected_text)
- llm = HuggingFaceTextGen(
- inference_server_url="http://127.0.0.1:8080"
- )
+ llm = HuggingFaceTextGen(inference_server_url="http://127.0.0.1:8080")
instruction = MockPrompt()
result = llm.call(instruction)
tgi_mock.assert_called_once_with(
diff --git a/tests/llms/test_langchain_llm.py b/tests/llms/test_langchain_llm.py
index e9859ed07..a93e95e77 100644
--- a/tests/llms/test_langchain_llm.py
+++ b/tests/llms/test_langchain_llm.py
@@ -4,7 +4,7 @@
import pytest
from pandasai.llm import LangchainLLM
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt
from unittest.mock import Mock
@@ -25,10 +25,10 @@ def __call__(self, _prompt, stop=None, callbacks=None, **kwargs):
@pytest.fixture
def prompt(self):
- class MockPrompt(Prompt):
- text: str = "Hello"
+        class MockPrompt(AbstractPrompt):
+            template: str = "Hello"
return MockPrompt()
def test_langchain_llm_type(self, langchain_llm):
langchain_wrapper = LangchainLLM(langchain_llm)
diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py
index 65444a079..dacac2ada 100644
--- a/tests/llms/test_openai.py
+++ b/tests/llms/test_openai.py
@@ -4,7 +4,7 @@
from pandasai.exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError
from pandasai.llm import OpenAI
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt
from openai.openai_object import OpenAIObject
@@ -13,10 +13,10 @@ class TestOpenAILLM:
@pytest.fixture
def prompt(self):
- class MockPrompt(Prompt):
- text: str = "instruction"
+        class MockPrompt(AbstractPrompt):
+            template: str = "instruction"
return MockPrompt()
def test_type_without_token(self):
with pytest.raises(APIKeyNotFoundError):
diff --git a/tests/prompts/test_base_prompt.py b/tests/prompts/test_base_prompt.py
index 9e0ff6904..6ad67c3fd 100644
--- a/tests/prompts/test_base_prompt.py
+++ b/tests/prompts/test_base_prompt.py
@@ -2,56 +2,19 @@
import pytest
-from pandasai.exceptions import MethodNotImplementedError
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt
class TestBasePrompt:
- """Unit tests for the base prompt class"""
-
- def test_text(self):
- """Test that the text attribute is required"""
- with pytest.raises(MethodNotImplementedError):
- print(Prompt())
-
- def test_str(self):
- """Test that the __str__ method is implemented"""
-
- class TestPrompt(Prompt):
- """Test prompt"""
-
- text = "Test prompt"
-
- assert str(TestPrompt()) == "Test prompt"
-
- def test_str_not_implemented(self):
- """Test that the __str__ method is implemented"""
-
- class TestPrompt(Prompt):
- """Test prompt"""
-
- pass
-
- with pytest.raises(MethodNotImplementedError):
- str(TestPrompt())
-
- def test_str_with_args(self):
- """Test that the __str__ method is implemented"""
-
- class TestPrompt(Prompt):
- """Test prompt"""
-
- text = "Test prompt {arg1} {arg2}"
-
- assert str(TestPrompt(arg1="arg1", arg2="arg2")) == "Test prompt arg1 arg2"
-
- def test_str_with_args_not_implemented(self):
- """Test that the __str__ method is implemented"""
-
- class TestPrompt(Prompt):
- """Test prompt"""
-
- text = "Test prompt {arg1} {arg2}"
-
- with pytest.raises(KeyError):
- str(TestPrompt())
+ def test_instantiate_without_template(self):
+ with pytest.raises(TypeError):
+ AbstractPrompt()
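+
+    def test_instantiate_with_template(self):
+        """Sketch of the positive case: a concrete subclass that defines
+        `template` can be instantiated and rendered with its args."""
+
+        class TestPrompt(AbstractPrompt):
+            template = "Test prompt {arg}"
+
+        assert str(TestPrompt(arg="value")) == "Test prompt value"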
diff --git a/tests/prompts/test_correct_error_prompt.py b/tests/prompts/test_correct_error_prompt.py
index aaa29f66f..2ce8303e2 100644
--- a/tests/prompts/test_correct_error_prompt.py
+++ b/tests/prompts/test_correct_error_prompt.py
@@ -1,8 +1,9 @@
"""Unit tests for the correct error prompt class"""
+import sys
import pandas as pd
from pandasai import SmartDataframe
from pandasai.prompts import CorrectErrorPrompt
from pandasai.llm.fake import FakeLLM
@@ -19,14 +20,17 @@ def test_str_with_args(self):
config={"llm": llm},
)
]
prompt = CorrectErrorPrompt(
engine="pandas", code="df.head()", error_returned="Error message"
)
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "What is the correct code?")
+ prompt_content = prompt.to_string()
+ if sys.platform.startswith("win"):
+ prompt_content = prompt_content.replace("\r\n", "\n")
assert (
- prompt.to_string()
+ prompt_content
== """
You are provided with the following pandas DataFrames with the following metadata:
diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py
index abad90140..7bf4a1673 100644
--- a/tests/prompts/test_generate_python_code_prompt.py
+++ b/tests/prompts/test_generate_python_code_prompt.py
@@ -1,9 +1,9 @@
"""Unit tests for the generate python code prompt class"""
-
+import sys
import pandas as pd
from pandasai import SmartDataframe
from pandasai.prompts import GeneratePythonCodePrompt
from pandasai.llm.fake import FakeLLM
@@ -20,12 +20,17 @@ def test_str_with_args(self):
config={"llm": llm},
)
]
prompt = GeneratePythonCodePrompt()
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "exports/charts")
+
+ prompt_content = prompt.to_string()
+ if sys.platform.startswith("win"):
+ prompt_content = prompt_content.replace("\r\n", "\n")
+
assert (
- prompt.to_string()
+ prompt_content
== """
You are provided with the following pandas DataFrames:
@@ -75,13 +80,17 @@ def test_str_with_custom_save_charts_path(self):
)
]
prompt = GeneratePythonCodePrompt()
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "custom_path")
+ prompt_content = prompt.to_string()
+ if sys.platform.startswith("win"):
+ prompt_content = prompt_content.replace("\r\n", "\n")
+
assert (
- prompt.to_string()
+ prompt_content
== """
You are provided with the following pandas DataFrames:
diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py
index dced66ee5..bd1880152 100644
--- a/tests/test_smartdataframe.py
+++ b/tests/test_smartdataframe.py
@@ -17,7 +17,7 @@
from pandasai.llm.fake import FakeLLM
from pandasai.middlewares import Middleware
from pandasai.callbacks import StdoutCallback
-from pandasai.prompts import Prompt
+from pandasai.prompts import AbstractPrompt, GeneratePythonCodePrompt
from pandasai.helpers.cache import Cache
import logging
@@ -377,13 +377,13 @@ def test_shortcut(self, smart_dataframe: SmartDataframe):
smart_dataframe.chat.assert_called_once()
def test_replace_generate_code_prompt(self, llm):
- class CustomPrompt(Prompt):
- text: str = """{test} || {dfs[0].shape[1]} || {conversation}"""
+        class CustomPrompt(AbstractPrompt):
+            template: str = """{test} || {dfs[0].shape[1]} || {conversation}"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
replacement_prompt = CustomPrompt(test="test value")
df = SmartDataframe(
pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}),
config={
@@ -399,8 +399,10 @@ def __init__(self, **kwargs):
assert llm.last_prompt == expected_last_prompt
def test_replace_correct_error_prompt(self, llm):
- class ReplacementPrompt(Prompt):
- text = "Custom prompt"
+ class ReplacementPrompt(AbstractPrompt):
+ @property
+ def template(self):
+ return "Custom prompt"
replacement_prompt = ReplacementPrompt()
df = SmartDataframe(
@@ -510,7 +512,9 @@ def test_updates_configs_with_setters(self, smart_dataframe: SmartDataframe):
smart_dataframe.use_error_correction_framework = False
assert smart_dataframe.use_error_correction_framework is False
- smart_dataframe.custom_prompts = {"generate_python_code": Prompt()}
+        smart_dataframe.custom_prompts = {
+            "generate_python_code": GeneratePythonCodePrompt()
+        }
assert smart_dataframe.custom_prompts != {}
smart_dataframe.save_charts = True