diff --git a/docs/custom-prompts.md b/docs/custom-prompts.md
index f0c5f24f5..382e48ba1 100644
--- a/docs/custom-prompts.md
+++ b/docs/custom-prompts.md
@@ -24,6 +24,12 @@ class MyCustomPrompt(AbstractPrompt):
def template(self):
return """This is your custom text for your prompt with custom {my_custom_value}"""
+ def setup(self, **kwargs):
+ # This method is called when the prompt is initialized.
+ # You can use it to set up your prompt and to pass any
+ # additional variables to the template.
+ self.set_var("my_custom_value", kwargs["my_custom_value"])
+
df = SmartDataframe("data.csv", {
"custom_prompts": {
@@ -36,9 +42,11 @@ df = SmartDataframe("data.csv", {
You can also use `FileBasedPrompt` in case you prefer to store the prompt template in a file:
_my_prompt_template.tmpl:_
+
```
This is your custom text for your prompt with custom {my_custom_value}
```
+
_python code:_
```python
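# A hedged sketch of the file-based variant; illustrative only, not the
# documented original. The import path and the "generate_python_code" key
# below are assumptions.
from pandasai import SmartDataframe
from pandasai.prompts import FileBasedPrompt

class MyCustomFilePrompt(FileBasedPrompt):
    _path_to_template = "my_prompt_template.tmpl"

    def setup(self, **kwargs):
        # Called automatically by the base class __init__.
        self.set_var("my_custom_value", kwargs["my_custom_value"])

df = SmartDataframe("data.csv", {
    "custom_prompts": {
        "generate_python_code": MyCustomFilePrompt(my_custom_value="example"),
    }
})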
diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py
index 69f8f6e71..3d73ff8de 100644
--- a/pandasai/agent/__init__.py
+++ b/pandasai/agent/__init__.py
@@ -81,7 +81,9 @@ def clarification_questions(self, query: str) -> List[str]:
Generate clarification questions based on the data
"""
prompt = ClarificationQuestionPrompt(
- self._lake.dfs, self._lake._memory.get_conversation(), query
+ dataframes=self._lake.dfs,
+ conversation=self._lake._memory.get_conversation(),
+ query=query,
)
result = self._call_llm_with_prompt(prompt)
@@ -104,8 +106,8 @@ def explain(self) -> str:
"""
try:
prompt = ExplainPrompt(
- self._lake._memory.get_conversation(),
- self._lake.last_code_executed,
+ conversation=self._lake._memory.get_conversation(),
+ code=self._lake.last_code_executed,
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
@@ -123,7 +125,9 @@ def explain(self) -> str:
def rephrase_query(self, query: str):
try:
prompt = RephraseQueryPrompt(
- query, self._lake.dfs, self._lake._memory.get_conversation()
+ query=query,
+ dataframes=self._lake.dfs,
+ conversation=self._lake._memory.get_conversation(),
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
diff --git a/pandasai/assets/prompt_templates/explain_prompt.tmpl b/pandasai/assets/prompt_templates/explain_prompt.tmpl
index 51a5ea3dd..cc03c3615 100644
--- a/pandasai/assets/prompt_templates/explain_prompt.tmpl
+++ b/pandasai/assets/prompt_templates/explain_prompt.tmpl
@@ -9,7 +9,7 @@ Based on the last conversation you generated the following code:
{code}
-
+
Explain how you came up with the code for non-technical people, without
mentioning technical details or the libraries used.
diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py
index e33341ba0..b5d297fb9 100644
--- a/pandasai/prompts/base.py
+++ b/pandasai/prompts/base.py
@@ -25,8 +25,11 @@ def __init__(self, **kwargs):
if self._args is None:
self._args = {}
- if kwargs:
- self._args = {**kwargs, **self._args}
+ self._args.update(kwargs)
+ self.setup(**kwargs)
+
+ def setup(self, **kwargs):
+ pass
def _generate_dataframes(self, dfs):
"""
diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py
index a38db32c3..a63ea19ab 100644
--- a/pandasai/prompts/clarification_questions_prompt.py
+++ b/pandasai/prompts/clarification_questions_prompt.py
@@ -26,15 +26,11 @@ class ClarificationQuestionPrompt(FileBasedPrompt):
_path_to_template = "assets/prompt_templates/clarification_questions_prompt.tmpl"
- def __init__(
- self, dataframes: List[pd.DataFrame], conversation: str, query: str, **kwargs
- ):
+ def setup(self, dataframes: List[pd.DataFrame], conversation: str, query: str):
self.set_var("dfs", dataframes)
self.set_var("conversation", conversation)
self.set_var("query", query)
- super().__init__(**kwargs)
-
def validate(self, output) -> bool:
try:
json_data = json.loads(output)
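With the explicit constructor removed, callers now build the prompt with keyword arguments, exactly as the `Agent.clarification_questions` hunk above does. A usage sketch (`dfs`, `memory`, and `llm_output` are placeholders):

```python
prompt = ClarificationQuestionPrompt(
    dataframes=dfs,                          # dataframes under analysis
    conversation=memory.get_conversation(),  # prior chat history
    query="What is the expected salary increase?",
)

# validate() parses the LLM output as JSON before trusting it,
# per the truncated body above.
if not prompt.validate(llm_output):
    raise ValueError("LLM did not return valid JSON")
```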
diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py
index 73f986090..6c20a92bd 100644
--- a/pandasai/prompts/explain_prompt.py
+++ b/pandasai/prompts/explain_prompt.py
@@ -9,7 +9,7 @@
{code}
-
+
Explain how you came up with the code for non-technical people, without
mentioning technical details or the libraries used.
@@ -23,8 +23,6 @@ class ExplainPrompt(FileBasedPrompt):
_path_to_template = "assets/prompt_templates/explain_prompt.tmpl"
- def __init__(self, conversation: str, code: str, **kwargs):
+ def setup(self, conversation: str, code: str):
self.set_var("conversation", conversation)
self.set_var("code", code)
-
- super().__init__(**kwargs)
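`ExplainPrompt` follows the same migration, so any downstream code that constructed it positionally has to switch to keywords (the variable names below are placeholders):

```python
# Before this change: positional arguments consumed by __init__.
# prompt = ExplainPrompt(conversation_text, code_text)

# After this change: keyword arguments consumed by setup().
prompt = ExplainPrompt(conversation=conversation_text, code=code_text)
```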
diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py
index 72eb376d9..f19c8b702 100644
--- a/pandasai/prompts/generate_python_code.py
+++ b/pandasai/prompts/generate_python_code.py
@@ -41,7 +41,7 @@ class GeneratePythonCodePrompt(FileBasedPrompt):
_path_to_template = "assets/prompt_templates/generate_python_code.tmpl"
- def __init__(self, **kwargs):
+ def setup(self, **kwargs):
default_import = "import pandas as pd"
engine_df_name = "pd.DataFrame"
@@ -58,8 +58,6 @@ def __init__(self, **kwargs):
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501
)
- super().__init__(**kwargs)
-
def _set_instructions(self, instructions: str):
lines = instructions.split("\n")
indented_lines = [" " + line for line in lines[1:]]
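The `_set_instructions` helper is cut off above; it indents every line after the first, presumably so multi-line custom instructions stay aligned inside the numbered list in the template. A plausible completion under that assumption (the rejoin and `set_var` steps are guesses, not the repository's code):

```python
def _set_instructions(self, instructions: str):
    lines = instructions.split("\n")
    indented_lines = ["    " + line for line in lines[1:]]
    # Guess: rejoin with the untouched first line and expose to the template.
    self.set_var("instructions", "\n".join([lines[0]] + indented_lines))
```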
diff --git a/pandasai/prompts/rephase_query_prompt.py b/pandasai/prompts/rephase_query_prompt.py
index de28ea71b..a359af67a 100644
--- a/pandasai/prompts/rephase_query_prompt.py
+++ b/pandasai/prompts/rephase_query_prompt.py
@@ -19,9 +19,7 @@ class RephraseQueryPrompt(FileBasedPrompt):
_path_to_template = "assets/prompt_templates/rephrase_query_prompt.tmpl"
- def __init__(
- self, query: str, dataframes: List[pd.DataFrame], conversation: str, **kwargs
- ):
+ def setup(self, query: str, dataframes: List[pd.DataFrame], conversation: str):
conversation_content = (
self.conversation_text.format(conversation=conversation)
if conversation
@@ -30,5 +28,3 @@ def __init__(
self.set_var("conversation", conversation_content)
self.set_var("query", query)
self.set_var("dfs", dataframes)
-
- super().__init__(**kwargs)
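`RephraseQueryPrompt` gets the same treatment. One wrinkle: `setup()` only formats `conversation_text` when a conversation exists, and the empty-conversation fallback branch is elided in the hunk above. A usage sketch with placeholder objects:

```python
prompt = RephraseQueryPrompt(
    query="plot the salaries",
    dataframes=dfs,                          # dataframes under analysis
    conversation=memory.get_conversation(),  # may be "" on a fresh session
)
```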
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 3989e8f6c..630a2101b 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -183,7 +183,10 @@ def test_call_prompt_success(self, agent: Agent):
What is expected Salary Increase?
"""
agent._lake.llm.call.return_value = clarification_response
- prompt = ExplainPrompt("test conversation", "")
+ prompt = ExplainPrompt(
+ conversation="test conversation",
+ code="test code",
+ )
agent._call_llm_with_prompt(prompt)
assert agent._lake.llm.call.call_count == 1
@@ -200,7 +203,7 @@ def test_call_prompt_max_retry_on_error(self, agent: Agent):
# test the LLM call failed twice but succeed third time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), Exception(), "LLM Result"]
- prompt = ExplainPrompt("test conversation", "")
+ prompt = ExplainPrompt(conversation="test conversation", code="")
result = agent._call_llm_with_prompt(prompt)
assert result == "LLM Result"
assert agent._lake.llm.call.call_count == 3
@@ -209,7 +212,7 @@ def test_call_prompt_max_retry_twice(self, agent: Agent):
# test the LLM call failed once but succeed second time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), "LLM Result"]
- prompt = ExplainPrompt("test conversation", "")
+ prompt = ExplainPrompt(conversation="test conversation", code="")
result = agent._call_llm_with_prompt(prompt)
assert result == "LLM Result"
@@ -245,7 +248,9 @@ def test_clarification_prompt_validate_output_false_case(self, agent: Agent):
agent._lake.llm.call.return_value = "This is not json"
prompt = ClarificationQuestionPrompt(
- agent._lake.dfs, "test conversation", "test query"
+ dataframes=agent._lake.dfs,
+ conversation="test conversation",
+ query="test query",
)
with pytest.raises(Exception):
agent._call_llm_with_prompt(prompt)
@@ -256,7 +261,9 @@ def test_clarification_prompt_validate_output_true_case(self, agent: Agent):
agent._lake.llm.call.return_value = '["This is test question"]'
prompt = ClarificationQuestionPrompt(
- agent._lake.dfs, "test conversation", "test query"
+ dataframes=agent._lake.dfs,
+ conversation="test conversation",
+ query="test query",
)
result = agent._call_llm_with_prompt(prompt)
# Didn't raise any exception