diff --git a/docs/custom-prompts.md b/docs/custom-prompts.md
index f0c5f24f5..382e48ba1 100644
--- a/docs/custom-prompts.md
+++ b/docs/custom-prompts.md
@@ -24,6 +24,12 @@ class MyCustomPrompt(AbstractPrompt):
     def template(self):
         return """This is your custom text for your prompt with custom {my_custom_value}"""
 
+    def setup(self, **kwargs):
+        # This method is called when the prompt is initialized.
+        # You can use it to set up your prompt and pass any additional
+        # variables to the template.
+        self.set_var("my_custom_value", kwargs["my_custom_value"])
+
 
 df = SmartDataframe("data.csv", {
     "custom_prompts": {
@@ -36,9 +42,11 @@ df = SmartDataframe("data.csv", {
 You can also use `FileBasedPrompt` in case you prefer to store prompt template in a file:
 
 _my_prompt_template.tmpl:_
+
 ```
 This is your custom text for your prompt with custom {my_custom_value}
 ```
+
 _python code:_
 
 ```python
diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py
index 69f8f6e71..3d73ff8de 100644
--- a/pandasai/agent/__init__.py
+++ b/pandasai/agent/__init__.py
@@ -81,7 +81,9 @@ def clarification_questions(self, query: str) -> List[str]:
         Generate clarification questions based on the data
         """
         prompt = ClarificationQuestionPrompt(
-            self._lake.dfs, self._lake._memory.get_conversation(), query
+            dataframes=self._lake.dfs,
+            conversation=self._lake._memory.get_conversation(),
+            query=query,
         )
 
         result = self._call_llm_with_prompt(prompt)
@@ -104,8 +106,8 @@ def explain(self) -> str:
         """
         try:
             prompt = ExplainPrompt(
-                self._lake._memory.get_conversation(),
-                self._lake.last_code_executed,
+                conversation=self._lake._memory.get_conversation(),
+                code=self._lake.last_code_executed,
             )
             response = self._call_llm_with_prompt(prompt)
             self._logger.log(
@@ -123,7 +125,9 @@ def explain(self) -> str:
     def rephrase_query(self, query: str):
         try:
             prompt = RephraseQueryPrompt(
-                query, self._lake.dfs, self._lake._memory.get_conversation()
+                query=query,
+                dataframes=self._lake.dfs,
+                conversation=self._lake._memory.get_conversation(),
             )
             response = self._call_llm_with_prompt(prompt)
             self._logger.log(
diff --git a/pandasai/assets/prompt_templates/explain_prompt.tmpl b/pandasai/assets/prompt_templates/explain_prompt.tmpl
index 51a5ea3dd..cc03c3615 100644
--- a/pandasai/assets/prompt_templates/explain_prompt.tmpl
+++ b/pandasai/assets/prompt_templates/explain_prompt.tmpl
@@ -9,7 +9,7 @@ Based on the last conversation you generated the following code:
 
 {code}
 
-
 Explain how you came up with code for non-technical people without mentioning technical details or mentioning the libraries used?
diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py
index e33341ba0..b5d297fb9 100644
--- a/pandasai/prompts/base.py
+++ b/pandasai/prompts/base.py
@@ -25,8 +25,11 @@ def __init__(self, **kwargs):
         if self._args is None:
             self._args = {}
 
-        if kwargs:
-            self._args = {**kwargs, **self._args}
+        self._args.update(kwargs)
+        self.setup(**kwargs)
+
+    def setup(self, **kwargs):
+        pass
 
     def _generate_dataframes(self, dfs):
         """
diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py
index a38db32c3..a63ea19ab 100644
--- a/pandasai/prompts/clarification_questions_prompt.py
+++ b/pandasai/prompts/clarification_questions_prompt.py
@@ -26,15 +26,11 @@ class ClarificationQuestionPrompt(FileBasedPrompt):
 
     _path_to_template = "assets/prompt_templates/clarification_questions_prompt.tmpl"
 
-    def __init__(
-        self, dataframes: List[pd.DataFrame], conversation: str, query: str, **kwargs
-    ):
+    def setup(self, dataframes: List[pd.DataFrame], conversation: str, query: str):
         self.set_var("dfs", dataframes)
         self.set_var("conversation", conversation)
         self.set_var("query", query)
-
-        super().__init__(**kwargs)
 
     def validate(self, output) -> bool:
         try:
             json_data = json.loads(output)
diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py
index 73f986090..6c20a92bd 100644
--- a/pandasai/prompts/explain_prompt.py
+++ b/pandasai/prompts/explain_prompt.py
@@ -9,7 +9,7 @@
 
 {code}
 
-
 Explain how you came up with code for non-technical people without mentioning technical details or mentioning the libraries used?
@@ -23,8 +23,6 @@ class ExplainPrompt(FileBasedPrompt):
 
     _path_to_template = "assets/prompt_templates/explain_prompt.tmpl"
 
-    def __init__(self, conversation: str, code: str, **kwargs):
+    def setup(self, conversation: str, code: str):
         self.set_var("conversation", conversation)
         self.set_var("code", code)
-
-        super().__init__(**kwargs)
diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py
index 72eb376d9..f19c8b702 100644
--- a/pandasai/prompts/generate_python_code.py
+++ b/pandasai/prompts/generate_python_code.py
@@ -41,7 +41,7 @@ class GeneratePythonCodePrompt(FileBasedPrompt):
 
     _path_to_template = "assets/prompt_templates/generate_python_code.tmpl"
 
-    def __init__(self, **kwargs):
+    def setup(self, **kwargs):
         default_import = "import pandas as pd"
         engine_df_name = "pd.DataFrame"
 
@@ -58,8 +58,6 @@
 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)"""  # noqa: E501
             )
 
-        super().__init__(**kwargs)
-
     def _set_instructions(self, instructions: str):
         lines = instructions.split("\n")
         indented_lines = ["    " + line for line in lines[1:]]
diff --git a/pandasai/prompts/rephase_query_prompt.py b/pandasai/prompts/rephase_query_prompt.py
index de28ea71b..a359af67a 100644
--- a/pandasai/prompts/rephase_query_prompt.py
+++ b/pandasai/prompts/rephase_query_prompt.py
@@ -19,9 +19,7 @@ class RephraseQueryPrompt(FileBasedPrompt):
 
     _path_to_template = "assets/prompt_templates/rephrase_query_prompt.tmpl"
 
-    def __init__(
-        self, query: str, dataframes: List[pd.DataFrame], conversation: str, **kwargs
-    ):
+    def setup(self, query: str, dataframes: List[pd.DataFrame], conversation: str):
         conversation_content = (
             self.conversation_text.format(conversation=conversation)
             if conversation
@@ -30,5 +28,3 @@ def __init__(
         self.set_var("conversation", conversation_content)
         self.set_var("query", query)
         self.set_var("dfs", dataframes)
-
-        super().__init__(**kwargs)
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 3989e8f6c..630a2101b 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -183,7 +183,10 @@ def test_call_prompt_success(self, agent: Agent):
 What is expected Salary Increase?
 """
         agent._lake.llm.call.return_value = clarification_response
-        prompt = ExplainPrompt("test conversation", "")
+        prompt = ExplainPrompt(
+            conversation="test conversation",
+            code="test code",
+        )
         agent._call_llm_with_prompt(prompt)
         assert agent._lake.llm.call.call_count == 1
 
@@ -200,7 +203,7 @@ def test_call_prompt_max_retry_on_error(self, agent: Agent):
         # test the LLM call failed twice but succeed third time
         agent._lake.llm.call = Mock()
         agent._lake.llm.call.side_effect = [Exception(), Exception(), "LLM Result"]
-        prompt = ExplainPrompt("test conversation", "")
+        prompt = ExplainPrompt(conversation="test conversation", code="")
         result = agent._call_llm_with_prompt(prompt)
         assert result == "LLM Result"
         assert agent._lake.llm.call.call_count == 3
@@ -209,7 +212,7 @@ def test_call_prompt_max_retry_twice(self, agent: Agent):
         # test the LLM call failed once but succeed second time
         agent._lake.llm.call = Mock()
         agent._lake.llm.call.side_effect = [Exception(), "LLM Result"]
-        prompt = ExplainPrompt("test conversation", "")
+        prompt = ExplainPrompt(conversation="test conversation", code="")
         result = agent._call_llm_with_prompt(prompt)
         assert result == "LLM Result"
 
@@ -245,7 +248,9 @@ def test_clarification_prompt_validate_output_false_case(self, agent: Agent):
         agent._lake.llm.call.return_value = "This is not json"
 
         prompt = ClarificationQuestionPrompt(
-            agent._lake.dfs, "test conversation", "test query"
+            dataframes=agent._lake.dfs,
+            conversation="test conversation",
+            query="test query",
         )
         with pytest.raises(Exception):
             agent._call_llm_with_prompt(prompt)
@@ -256,7 +261,9 @@ def test_clarification_prompt_validate_output_true_case(self, agent: Agent):
         agent._lake.llm.call.return_value = '["This is test quesiton"]'
 
         prompt = ClarificationQuestionPrompt(
-            agent._lake.dfs, "test conversation", "test query"
+            dataframes=agent._lake.dfs,
+            conversation="test conversation",
+            query="test query",
         )
         result = agent._call_llm_with_prompt(prompt)
         # Didn't raise any exception
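
For reviewers who want to try the new hook locally, here is a minimal sketch of the refactored flow, based only on the `base.py` change above: `__init__` merges keyword arguments into `self._args` and then calls `setup()`, so subclasses override `setup()` instead of `__init__` and no longer chain `super().__init__(**kwargs)`. `GreetingPrompt`, its `greeting` variable, and the import path are hypothetical stand-ins for illustration, not part of this PR.

```python
# Minimal sketch of the setup() hook introduced in this PR.
# GreetingPrompt and "greeting" are hypothetical names; only AbstractPrompt,
# set_var, and the setup() hook come from the changes above.
from pandasai.prompts import AbstractPrompt  # assumed import path


class GreetingPrompt(AbstractPrompt):
    @property
    def template(self):
        return "Say {greeting} to the user."

    def setup(self, **kwargs):
        # Called automatically at the end of AbstractPrompt.__init__,
        # after kwargs have been merged into self._args.
        self.set_var("greeting", kwargs.get("greeting", "hello"))


# Keyword arguments now flow through __init__ straight into setup():
prompt = GreetingPrompt(greeting="good morning")
print(prompt.to_string())  # renders "Say good morning to the user."
```

This mirrors the built-in prompt changes above: each former `__init__` body moves into `setup()` unchanged, and the keyword names at the call sites (`dataframes=`, `conversation=`, `query=`, `code=`) become the `setup()` parameters, as the updated tests exercise.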