diff --git a/assets/prompt-templates/correct_error_prompt.tmpl b/assets/prompt-templates/correct_error_prompt.tmpl new file mode 100644 index 000000000..c73687279 --- /dev/null +++ b/assets/prompt-templates/correct_error_prompt.tmpl @@ -0,0 +1,15 @@ + +You are provided with the following {engine} DataFrames with the following metadata: + +{dataframes} + +The user asked the following question: +{conversation} + +You generated this python code: +{code} + +It fails with the following error: +{error_returned} + +Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again. diff --git a/assets/prompt-templates/generate_python_code.tmpl b/assets/prompt-templates/generate_python_code.tmpl new file mode 100644 index 000000000..10a057c58 --- /dev/null +++ b/assets/prompt-templates/generate_python_code.tmpl @@ -0,0 +1,30 @@ + +You are provided with the following pandas DataFrames: + +{dataframes} + + +{conversation} + + +This is the initial python code to be updated: +```python +# TODO import all the dependencies required +{default_import} + +def analyze_data(dfs: list[{engine_df_name}]) -> dict: + """ + Analyze the data + 1. Prepare: Preprocessing and cleaning data if necessary + 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.) + 4. Output: return a dictionary of: + - type (possible values "text", "number", "dataframe", "plot") + - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) + Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }} + """ +``` + +Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation. 
+ +Updated code: diff --git a/docs/custom-prompts.md b/docs/custom-prompts.md index 9a64a3968..b375ada83 100644 --- a/docs/custom-prompts.md +++ b/docs/custom-prompts.md @@ -16,14 +16,17 @@ To create your custom prompt create a new CustomPromptClass inherited from base ```python from pandasai import SmartDataframe -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt + + +class MyCustomPrompt(AbstractPrompt): + template = """This is your custom text for your prompt with custom {my_custom_value}""" -class MyCustomPrompt(Prompt): - text = """This is your custom text for your prompt with custom {my_custom_value}""" df = SmartDataframe("data.csv", { "custom_prompts": { - "generate_python_code": MyCustomPrompt(my_custom_value="my custom value") + "generate_python_code": MyCustomPrompt( + my_custom_value="my custom value") } }) ``` @@ -36,15 +39,17 @@ You can directly access the default prompt variables (for example dfs, conversat ```python from pandasai import SmartDataframe -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt -class MyCustomPrompt(Prompt): - text = """You are given a dataframe with number if rows equal to {dfs[0].shape[0]} and number of columns equal to {dfs[0].shape[1]} + +class MyCustomPrompt(AbstractPrompt): + template = """You are given a dataframe with number of rows equal to {dfs[0].shape[0]} and number of columns equal to {dfs[0].shape[1]} Here's the conversation: {conversation} """ + df = SmartDataframe("data.csv", { "custom_prompts": { "generate_python_code": MyCustomPrompt() diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 6a63677b1..15d6ea229 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -40,7 +40,7 @@ import pandas as pd from .smart_dataframe import SmartDataframe from .smart_datalake import SmartDatalake -from .prompts.base import Prompt +from .prompts.base import AbstractPrompt from .callbacks.base import BaseCallback from .schemas.df_config import Config from .helpers.cache import Cache @@ -112,7 +112,7 @@ def __init__( middlewares=None, custom_whitelisted_dependencies=None, enable_logging=True, - non_default_prompts: Optional[Dict[str, Type[Prompt]]] = None, + non_default_prompts: Optional[Dict[str, Type[AbstractPrompt]]] = None, callback: Optional[BaseCallback] = None, ): """ diff --git a/pandasai/llm/azure_openai.py b/pandasai/llm/azure_openai.py index 32c543a82..ffe722146 100644 --- a/pandasai/llm/azure_openai.py +++ b/pandasai/llm/azure_openai.py @@ -18,7 +18,7 @@ from ..helpers import load_dotenv from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError -from ..prompts.base import Prompt +from ..prompts.base import AbstractPrompt from .base import BaseOpenAI load_dotenv() @@ -105,12 +105,12 @@ def _default_params(self) -> Dict[str, Any]: """ return {**super()._default_params, "engine": self.engine} - def call(self, instruction: Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ Call the Azure OpenAI LLM. Args: - instruction (Prompt): A prompt object with instruction for LLM.
+ instruction (AbstractPrompt): A prompt object with instruction for LLM. suffix (str): Suffix to pass. Returns: diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py index 2b6fcad9a..5b2e72012 100644 --- a/pandasai/llm/base.py +++ b/pandasai/llm/base.py @@ -29,7 +29,7 @@ class CustomLLM(BaseOpenAI): NoCodeFoundError, ) from ..helpers.openai_info import openai_callback_var -from ..prompts.base import Prompt +from ..prompts.base import AbstractPrompt class LLM: @@ -120,12 +120,12 @@ def _extract_code(self, response: str, separator: str = "```") -> str: return code @abstractmethod - def call(self, instruction: Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ Execute the LLM with given prompt. Args: - instruction (Prompt): A prompt object with instruction for LLM. + instruction (AbstractPrompt): A prompt object with instruction for LLM. suffix (str, optional): Suffix. Defaults to "". Raises: @@ -134,12 +134,12 @@ def call(self, instruction: Prompt, suffix: str = "") -> str: """ raise MethodNotImplementedError("Call method has not been implemented") - def generate_code(self, instruction: Prompt) -> str: + def generate_code(self, instruction: AbstractPrompt) -> str: """ Generate the code based on the instruction and the given prompt. Args: - instruction (Prompt): Prompt with instruction for LLM. + instruction (AbstractPrompt): Prompt with instruction for LLM. Returns: str: A string of Python code. @@ -334,11 +334,11 @@ def query(self, payload) -> str: return response.json()[0]["generated_text"] - def call(self, instruction: Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ A call method of HuggingFaceLLM class. Args: - instruction (Prompt): A prompt object with instruction for LLM. + instruction (AbstractPrompt): A prompt object with instruction for LLM. suffix (str): A string representing the suffix to be truncated from the generated response. @@ -429,12 +429,12 @@ def _generate_text(self, prompt: str) -> str: """ raise MethodNotImplementedError("method has not been implemented") - def call(self, instruction: Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ Call the Google LLM. Args: - instruction (Prompt): Instruction to pass. + instruction (AbstractPrompt): Instruction to pass. suffix (str): Suffix to pass. Defaults to an empty string (""). 
Returns: diff --git a/pandasai/llm/fake.py b/pandasai/llm/fake.py index f92166290..8be3b9b98 100644 --- a/pandasai/llm/fake.py +++ b/pandasai/llm/fake.py @@ -2,7 +2,7 @@ from typing import Optional -from ..prompts.base import Prompt +from ..prompts.base import AbstractPrompt from .base import LLM @@ -16,7 +16,7 @@ def __init__(self, output: Optional[str] = None): if output is not None: self._output = output - def call(self, instruction: Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: self.last_prompt = instruction.to_string() + suffix return self._output diff --git a/pandasai/llm/huggingface_text_gen.py b/pandasai/llm/huggingface_text_gen.py index 4bbbbbbbf..fd14fc2de 100644 --- a/pandasai/llm/huggingface_text_gen.py +++ b/pandasai/llm/huggingface_text_gen.py @@ -2,7 +2,7 @@ from .base import LLM from ..helpers import load_dotenv -from ..prompts.base import Prompt +from ..prompts.base import AbstractPrompt load_dotenv() @@ -14,7 +14,7 @@ class HuggingFaceTextGen(LLM): top_k: Optional[int] = None top_p: Optional[float] = 0.8 typical_p: Optional[float] = 0.8 - temperature: float = 1E-3 # must be strictly positive + temperature: float = 1e-3 # must be strictly positive repetition_penalty: Optional[float] = None truncate: Optional[int] = None stop_sequences: List[str] = None @@ -29,7 +29,7 @@ def __init__(self, inference_server_url: str, **kwargs): try: import text_generation - for (key, val) in kwargs.items(): + for key, val in kwargs.items(): if key in self.__annotations__: setattr(self, key, val) @@ -60,14 +60,14 @@ def _default_params(self) -> Dict[str, Any]: "seed": self.seed, } - def call(self, instruction: Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: prompt = instruction.to_string() + suffix params = self._default_params if self.streaming: completion = "" for chunk in self.client.generate_stream(prompt, **params): completion += chunk.text return completion res = self.client.generate(prompt, **params) @@ -76,8 +76,8 @@ def call(self, instruction: Prompt, suffix: str = "") -> str: for stop_seq in self.stop_sequences: if stop_seq in res.generated_text: res.generated_text = res.generated_text[ - :res.generated_text.index(stop_seq) - ] + : res.generated_text.index(stop_seq) + ] return res.generated_text @property diff --git a/pandasai/llm/langchain.py b/pandasai/llm/langchain.py index 7364c54d4..d6e74a97e 100644 --- a/pandasai/llm/langchain.py +++ b/pandasai/llm/langchain.py @@ -1,4 +1,4 @@ -from pandasai.prompts.base import Prompt +from pandasai.prompts.base import AbstractPrompt from .base import LLM @@ -13,7 +13,7 @@ class LangchainLLM(LLM): def __init__(self, langchain_llm): self._langchain_llm = langchain_llm - def call(self, instruction: Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: prompt = instruction.to_string() + suffix return self._langchain_llm.predict(prompt) diff --git a/pandasai/llm/openai.py b/pandasai/llm/openai.py index 8bc7d779c..9a6fb6066 100644 --- a/pandasai/llm/openai.py +++ b/pandasai/llm/openai.py @@ -15,7 +15,7 @@ from ..helpers import load_dotenv from ..exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError -from ..prompts.base import Prompt +from ..prompts.base import AbstractPrompt from .base import BaseOpenAI load_dotenv() @@ -85,12 +85,12 @@ def _default_params(self) -> Dict[str, Any]: "model": self.model, } - def call(self, instruction:
Prompt, suffix: str = "") -> str: + def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ Call the OpenAI LLM. Args: - instruction (Prompt): A prompt object with instruction for LLM. + instruction (AbstractPrompt): A prompt object with instruction for LLM. suffix (str): Suffix to pass. Raises: diff --git a/pandasai/prompts/__init__.py b/pandasai/prompts/__init__.py index 7d37a6699..5fc9f4a87 100644 --- a/pandasai/prompts/__init__.py +++ b/pandasai/prompts/__init__.py @@ -1,9 +1,9 @@ -from .base import Prompt -from .correct_error_prompt import CorrectErrorPrompt -from .generate_python_code import GeneratePythonCodePrompt +from .base import AbstractPrompt +from .correct_error_prompt import CorrectErrorAbstractPrompt +from .generate_python_code import GeneratePythonCodeAbstractPrompt __all__ = [ - "Prompt", - "CorrectErrorPrompt", - "GeneratePythonCodePrompt", + "AbstractPrompt", + "CorrectErrorAbstractPrompt", + "GeneratePythonCodeAbstractPrompt", ] diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py index f53ac9082..93d0f4710 100644 --- a/pandasai/prompts/base.py +++ b/pandasai/prompts/base.py @@ -1,14 +1,13 @@ """ Base class to implement a new Prompt In order to better handle the instructions, this prompt module is written. """ +import os +from abc import ABC, abstractmethod -from pandasai.exceptions import MethodNotImplementedError - -class Prompt: +class AbstractPrompt(ABC): """Base class to implement a new Prompt""" - text = None _args = {} def __init__(self, **kwargs): @@ -46,16 +45,38 @@ def _generate_dataframes(self, dfs): return "\n\n".join(dataframes) + @property + @abstractmethod + def template(self): + ... + def set_var(self, var, value): if var == "dfs": self._args["dataframes"] = self._generate_dataframes(value) self._args[var] = value def to_string(self): - if self.text is None: - raise MethodNotImplementedError - - return self.text.format(**self._args) + return self.template.format(**self._args) def __str__(self): return self.to_string() + + +class FileBasedPrompt(AbstractPrompt): + _path_to_template: str + + def __init__(self, **kwargs): + if (template_path := kwargs.pop("path_to_template", None)) is not None: + self._path_to_template = template_path + + super().__init__(**kwargs) + + @property + def template(self): + if not os.path.exists(self._path_to_template): + raise FileNotFoundError( + f"Unable to find a file with template at '{self._path_to_template}' " + f"for '{self.__class__.__name__}' prompt." + ) + with open(self._path_to_template) as fp: + return fp.read() diff --git a/pandasai/prompts/correct_error_prompt.py b/pandasai/prompts/correct_error_prompt.py index 3a961761c..07b6ee8b5 100644 --- a/pandasai/prompts/correct_error_prompt.py +++ b/pandasai/prompts/correct_error_prompt.py @@ -16,25 +16,10 @@ Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again. 
""" # noqa: E501 -from .base import Prompt +from .base import FileBasedPrompt -class CorrectErrorPrompt(Prompt): +class CorrectErrorAbstractPrompt(FileBasedPrompt): """Prompt to Correct Python code on Error""" - text: str = """ -You are provided with the following {engine} DataFrames with the following metadata: - -{dataframes} - -The user asked the following question: -{conversation} - -You generated this python code: -{code} - -It fails with the following error: -{error_returned} - -Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again. -""" # noqa: E501 + _path_to_template = "assets/prompt-templates/correct_error_prompt.tmpl" diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index baa39e3ad..5c369b8cf 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -33,47 +33,19 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict: """ # noqa: E501 -from .base import Prompt +from .base import FileBasedPrompt -class GeneratePythonCodePrompt(Prompt): +class GeneratePythonCodeAbstractPrompt(FileBasedPrompt): """Prompt to generate Python code""" - text: str = """ -You are provided with the following pandas DataFrames: - -{dataframes} - - -{conversation} - - -This is the initial python code to be updated: -```python -# TODO import all the dependencies required -{default_import} + _path_to_template = "assets/prompt-templates/generate_python_code.tmpl" -def analyze_data(dfs: list[{engine_df_name}]) -> dict: - \"\"\" - Analyze the data - 1. Prepare: Preprocessing and cleaning data if necessary - 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.) - 4. Output: return a dictionary of: - - type (possible values "text", "number", "dataframe", "plot") - - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) - Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }} - \"\"\" -``` - -Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation. 
- -Updated code: -""" # noqa: E501 - - def __init__(self): + def __init__(self, **kwargs): default_import = "import pandas as pd" engine_df_name = "pd.DataFrame" self.set_var("default_import", default_import) self.set_var("engine_df_name", engine_df_name) + + super().__init__(**kwargs) diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 7ae01e5ed..9cac86650 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -32,9 +32,9 @@ from ..helpers.memory import Memory from ..schemas.df_config import Config from ..config import load_config -from ..prompts.base import Prompt -from ..prompts.correct_error_prompt import CorrectErrorPrompt -from ..prompts.generate_python_code import GeneratePythonCodePrompt +from ..prompts.base import AbstractPrompt +from ..prompts.correct_error_prompt import CorrectErrorAbstractPrompt +from ..prompts.generate_python_code import GeneratePythonCodeAbstractPrompt from typing import Union, List, Any, Type, Optional from ..helpers.code_manager import CodeManager from ..middlewares.base import Middleware @@ -204,20 +204,20 @@ def _is_running_in_console(self) -> bool: def _get_prompt( self, key: str, - default_prompt: Type[Prompt], + default_prompt: Type[AbstractPrompt], default_values: Optional[dict] = None, - ) -> Prompt: + ) -> AbstractPrompt: """ Return a prompt by key. Args: key (str): The key of the prompt - default_prompt (Type[Prompt]): The default prompt to use + default_prompt (Type[AbstractPrompt]): The default prompt to use default_values (Optional[dict], optional): The default values to use for the prompt. Defaults to None. Returns: - Prompt: The prompt + AbstractPrompt: The prompt """ if default_values is None: default_values = {} @@ -287,7 +287,7 @@ def chat(self, query: str): } generate_python_code_instruction = self._get_prompt( "generate_python_code", - default_prompt=GeneratePythonCodePrompt, + default_prompt=GeneratePythonCodeAbstractPrompt, default_values=default_values, ) @@ -436,7 +436,7 @@ def _retry_run_code(self, code: str, e: Exception): } error_correcting_instruction = self._get_prompt( "correct_error", - default_prompt=CorrectErrorPrompt, + default_prompt=CorrectErrorAbstractPrompt, default_values=default_values, ) diff --git a/tests/connectors/test_base.py b/tests/connectors/test_base.py index f80ba5b11..8533dd91c 100644 --- a/tests/connectors/test_base.py +++ b/tests/connectors/test_base.py @@ -14,7 +14,6 @@ def __init__(self, host, port, database, table): # Mock subclass of BaseConnector for testing class MockConnector(BaseConnector): - def _load_connector_config(self, config: BaseConnectorConfig): pass diff --git a/tests/llms/test_base_hf.py b/tests/llms/test_base_hf.py index c09430c32..1965a7cd9 100644 --- a/tests/llms/test_base_hf.py +++ b/tests/llms/test_base_hf.py @@ -4,7 +4,7 @@ import requests from pandasai.llm.base import HuggingFaceLLM -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt class TestBaseHfLLM: @@ -16,10 +16,10 @@ def api_response(self): @pytest.fixture def prompt(self): - class MockPrompt(Prompt): - text: str = "instruction" + class MockAbstractPrompt(AbstractPrompt): + template: str = "instruction" - return MockPrompt() + return MockAbstractPrompt() def test_type(self): assert HuggingFaceLLM(api_token="test_token").type == "huggingface-llm" @@ -62,10 +62,10 @@ def test_call(self, mocker, prompt): def test_call_removes_original_prompt(self, mocker): huggingface = HuggingFaceLLM(api_token="test_token") - class 
MockPrompt(Prompt): - text: str = "instruction " + class MockAbstractPrompt(AbstractPrompt): + template: str = "instruction " - instruction = MockPrompt() + instruction = MockAbstractPrompt() suffix = "suffix " mocker.patch.object( diff --git a/tests/llms/test_google_palm.py b/tests/llms/test_google_palm.py index 1d03534fa..91162819b 100644 --- a/tests/llms/test_google_palm.py +++ b/tests/llms/test_google_palm.py @@ -6,7 +6,7 @@ from pandasai.exceptions import APIKeyNotFoundError from pandasai.llm import GooglePalm -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt class MockedCompletion: @@ -19,10 +19,10 @@ class TestGooglePalm: @pytest.fixture def prompt(self): - class MockPrompt(Prompt): - text: str = "Hello" + class MockAbstractPrompt(AbstractPrompt): + template: str = "Hello" - return MockPrompt() + return MockAbstractPrompt() def test_type_without_token(self): with pytest.raises(APIKeyNotFoundError): diff --git a/tests/llms/test_huggingface_text_gen.py b/tests/llms/test_huggingface_text_gen.py index 1407662b3..1925b1561 100644 --- a/tests/llms/test_huggingface_text_gen.py +++ b/tests/llms/test_huggingface_text_gen.py @@ -1,10 +1,10 @@ """Unit tests for the LLaMa2TextGen LLM class""" -from pandasai import Prompt +from pandasai import AbstractPrompt from pandasai.llm import HuggingFaceTextGen -class MockPrompt(Prompt): - text: str = "instruction." +class MockAbstractPrompt(AbstractPrompt): + template: str = "instruction." class MockResponse: @@ -19,9 +19,8 @@ class TestHuggingFaceTextGen: def test_type_with_token(self): assert ( - HuggingFaceTextGen( - inference_server_url="http://127.0.0.1:8080" - ).type == "huggingface-text-generation" + HuggingFaceTextGen(inference_server_url="http://127.0.0.1:8080").type + == "huggingface-text-generation" ) def test_params_setting(self): @@ -30,7 +29,7 @@ def test_params_setting(self): max_new_tokens=1024, top_p=0.8, typical_p=0.8, - temperature=1E-3, + temperature=1e-3, stop_sequences=["\n"], seed=0, do_sample=False, @@ -53,11 +52,9 @@ def test_completion(self, mocker): expected_text = "This is the generated text." 
tgi_mock.return_value = MockResponse(expected_text) - llm = HuggingFaceTextGen( - inference_server_url="http://127.0.0.1:8080" - ) + llm = HuggingFaceTextGen(inference_server_url="http://127.0.0.1:8080") - instruction = MockPrompt() + instruction = MockAbstractPrompt() result = llm.call(instruction) tgi_mock.assert_called_once_with( diff --git a/tests/llms/test_langchain_llm.py b/tests/llms/test_langchain_llm.py index e9859ed07..a93e95e77 100644 --- a/tests/llms/test_langchain_llm.py +++ b/tests/llms/test_langchain_llm.py @@ -4,7 +4,7 @@ import pytest from pandasai.llm import LangchainLLM -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt from unittest.mock import Mock @@ -25,10 +25,10 @@ def __call__(self, _prompt, stop=None, callbacks=None, **kwargs): @pytest.fixture def prompt(self): - class MockPrompt(Prompt): - text: str = "Hello" + class MockAbstractPrompt(AbstractPrompt): + template: str = "Hello" - return MockPrompt() + return MockAbstractPrompt() def test_langchain_llm_type(self, langchain_llm): langchain_wrapper = LangchainLLM(langchain_llm) diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 65444a079..dacac2ada 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -4,7 +4,7 @@ from pandasai.exceptions import APIKeyNotFoundError, UnsupportedOpenAIModelError from pandasai.llm import OpenAI -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt from openai.openai_object import OpenAIObject @@ -13,10 +13,10 @@ class TestOpenAILLM: @pytest.fixture def prompt(self): - class MockPrompt(Prompt): - text: str = "instruction" + class MockAbstractPrompt(AbstractPrompt): + template: str = "instruction" - return MockPrompt() + return MockAbstractPrompt() def test_type_without_token(self): with pytest.raises(APIKeyNotFoundError): diff --git a/tests/prompts/test_base_prompt.py b/tests/prompts/test_base_prompt.py index 9e0ff6904..6ad67c3fd 100644 --- a/tests/prompts/test_base_prompt.py +++ b/tests/prompts/test_base_prompt.py @@ -2,56 +2,10 @@ import pytest -from pandasai.exceptions import MethodNotImplementedError -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt class TestBasePrompt: - """Unit tests for the base prompt class""" - - def test_text(self): - """Test that the text attribute is required""" - with pytest.raises(MethodNotImplementedError): - print(Prompt()) - - def test_str(self): - """Test that the __str__ method is implemented""" - - class TestPrompt(Prompt): - """Test prompt""" - - text = "Test prompt" - - assert str(TestPrompt()) == "Test prompt" - - def test_str_not_implemented(self): - """Test that the __str__ method is implemented""" - - class TestPrompt(Prompt): - """Test prompt""" - - pass - - with pytest.raises(MethodNotImplementedError): - str(TestPrompt()) - - def test_str_with_args(self): - """Test that the __str__ method is implemented""" - - class TestPrompt(Prompt): - """Test prompt""" - - text = "Test prompt {arg1} {arg2}" - - assert str(TestPrompt(arg1="arg1", arg2="arg2")) == "Test prompt arg1 arg2" - - def test_str_with_args_not_implemented(self): - """Test that the __str__ method is implemented""" - - class TestPrompt(Prompt): - """Test prompt""" - - text = "Test prompt {arg1} {arg2}" - - with pytest.raises(KeyError): - str(TestPrompt()) + def test_instantiate_without_template(self): + with pytest.raises(TypeError): + AbstractPrompt() diff --git a/tests/prompts/test_correct_error_prompt.py b/tests/prompts/test_correct_error_prompt.py 
index aaa29f66f..2ce8303e2 100644 --- a/tests/prompts/test_correct_error_prompt.py +++ b/tests/prompts/test_correct_error_prompt.py @@ -1,8 +1,9 @@ """Unit tests for the correct error prompt class""" +import sys import pandas as pd from pandasai import SmartDataframe -from pandasai.prompts import CorrectErrorPrompt +from pandasai.prompts import CorrectErrorAbstractPrompt from pandasai.llm.fake import FakeLLM @@ -19,14 +20,17 @@ def test_str_with_args(self): config={"llm": llm}, ) ] - prompt = CorrectErrorPrompt( + prompt = CorrectErrorAbstractPrompt( engine="pandas", code="df.head()", error_returned="Error message" ) prompt.set_var("dfs", dfs) prompt.set_var("conversation", "What is the correct code?") + prompt_content = prompt.to_string() + if sys.platform.startswith("win"): + prompt_content = prompt_content.replace("\r\n", "\n") assert ( - prompt.to_string() + prompt_content == """ You are provided with the following pandas DataFrames with the following metadata: diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index abad90140..7bf4a1673 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -1,9 +1,9 @@ """Unit tests for the generate python code prompt class""" - +import sys import pandas as pd from pandasai import SmartDataframe -from pandasai.prompts import GeneratePythonCodePrompt +from pandasai.prompts import GeneratePythonCodeAbstractPrompt from pandasai.llm.fake import FakeLLM @@ -20,12 +20,17 @@ def test_str_with_args(self): config={"llm": llm}, ) ] - prompt = GeneratePythonCodePrompt() + prompt = GeneratePythonCodeAbstractPrompt() prompt.set_var("dfs", dfs) prompt.set_var("conversation", "Question") prompt.set_var("save_charts_path", "exports/charts") + + prompt_content = prompt.to_string() + if sys.platform.startswith("win"): + prompt_content = prompt_content.replace("\r\n", "\n") + assert ( - prompt.to_string() + prompt_content == """ You are provided with the following pandas DataFrames: @@ -75,13 +80,17 @@ def test_str_with_custom_save_charts_path(self): ) ] - prompt = GeneratePythonCodePrompt() + prompt = GeneratePythonCodeAbstractPrompt() prompt.set_var("dfs", dfs) prompt.set_var("conversation", "Question") prompt.set_var("save_charts_path", "custom_path") + prompt_content = prompt.to_string() + if sys.platform.startswith("win"): + prompt_content = prompt_content.replace("\r\n", "\n") + assert ( - prompt.to_string() + prompt_content == """ You are provided with the following pandas DataFrames: diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index dced66ee5..bd1880152 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -17,7 +17,7 @@ from pandasai.llm.fake import FakeLLM from pandasai.middlewares import Middleware from pandasai.callbacks import StdoutCallback -from pandasai.prompts import Prompt +from pandasai.prompts import AbstractPrompt from pandasai.helpers.cache import Cache import logging @@ -377,13 +377,13 @@ def test_shortcut(self, smart_dataframe: SmartDataframe): smart_dataframe.chat.assert_called_once() def test_replace_generate_code_prompt(self, llm): - class CustomPrompt(Prompt): - text: str = """{test} || {dfs[0].shape[1]} || {conversation}""" + class CustomAbstractPrompt(AbstractPrompt): + template: str = """{test} || {dfs[0].shape[1]} || {conversation}""" def __init__(self, **kwargs): super().__init__(**kwargs) - replacement_prompt = CustomPrompt(test="test value") + 
replacement_prompt = CustomAbstractPrompt(test="test value") df = SmartDataframe( pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}), config={ @@ -399,8 +399,10 @@ def __init__(self, **kwargs): assert llm.last_prompt == expected_last_prompt def test_replace_correct_error_prompt(self, llm): - class ReplacementPrompt(Prompt): - text = "Custom prompt" + class ReplacementPrompt(AbstractPrompt): + @property + def template(self): + return "Custom prompt" replacement_prompt = ReplacementPrompt() df = SmartDataframe( @@ -510,7 +512,10 @@ def test_updates_configs_with_setters(self, smart_dataframe: SmartDataframe): smart_dataframe.use_error_correction_framework = False assert smart_dataframe.use_error_correction_framework is False - smart_dataframe.custom_prompts = {"generate_python_code": Prompt()} + class ConcretePrompt(AbstractPrompt): + template: str = "custom prompt" + + smart_dataframe.custom_prompts = {"generate_python_code": ConcretePrompt()} assert smart_dataframe.custom_prompts != {} smart_dataframe.save_charts = True
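
Usage note on the new `FileBasedPrompt` API: the bundled subclasses point `_path_to_template` at `assets/prompt-templates/*.tmpl` as a plain relative path, so the templates only resolve when the process runs with the repository root as its working directory. The `path_to_template` keyword popped in `FileBasedPrompt.__init__` gives callers an escape hatch. Below is a minimal sketch of that pattern; `my_generate_python_code.tmpl` is a hypothetical user-supplied template and must use the same placeholders as the default one (e.g. `{dataframes}`, `{conversation}`, `{default_import}`, `{engine_df_name}`, `{save_charts_path}`).

```python
import os

from pandasai import SmartDataframe
from pandasai.prompts import GeneratePythonCodeAbstractPrompt

# Resolve the template relative to this file rather than the current
# working directory, so the prompt loads no matter where the process starts.
template_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "my_generate_python_code.tmpl",  # hypothetical custom template
)

# FileBasedPrompt.__init__ pops `path_to_template` from kwargs and uses it
# to override the class-level `_path_to_template` default.
df = SmartDataframe("data.csv", {
    "custom_prompts": {
        "generate_python_code": GeneratePythonCodeAbstractPrompt(
            path_to_template=template_path
        ),
    }
})
```

The same keyword works for `CorrectErrorAbstractPrompt` via the `"correct_error"` key in `custom_prompts`.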