refactor: replace init method with setup in prompts
gventuri committed Oct 4, 2023
1 parent 4ce05f0 commit b379d03
Showing 9 changed files with 39 additions and 29 deletions.
8 changes: 8 additions & 0 deletions docs/custom-prompts.md
@@ -24,6 +24,12 @@ class MyCustomPrompt(AbstractPrompt):
def template(self):
return """This is your custom text for your prompt with custom {my_custom_value}"""

+ def setup(self, **kwargs):
+ # This method is called before the prompt is initialized.
+ # You can use it to set up your prompt and pass any additional
+ # variables to the template.
+ self.set_var("my_custom_value", kwargs["my_custom_value"])


df = SmartDataframe("data.csv", {
"custom_prompts": {
@@ -36,9 +42,11 @@ df = SmartDataframe("data.csv", {
You can also use `FileBasedPrompt` in case you prefer to store prompt template in a file:

_my_prompt_template.tmpl:_

```
This is your custom text for your prompt with custom {my_custom_value}
```

_python code:_

```python
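
The Python example for the file-based variant is cut off in this view. Purely as an illustration of the new `setup` hook (not the elided docs content), a file-based custom prompt might be wired up as sketched below; the `FileBasedPrompt` import path, the `generate_python_code` key, and passing an already-constructed prompt instance are assumptions, not verified against the library.

```python
# Hedged sketch only: the import path, config key, and instance-based wiring
# are assumptions inferred from this diff, not verified library API.
from pandasai import SmartDataframe
from pandasai.prompts.base import FileBasedPrompt  # assumed location


class MyCustomFilePrompt(FileBasedPrompt):
    # Template file containing, e.g.:
    # "This is your custom text for your prompt with custom {my_custom_value}"
    _path_to_template = "my_prompt_template.tmpl"

    def setup(self, **kwargs):
        # Called by the base __init__ before the prompt is used; forward any
        # extra constructor arguments to the template as variables.
        if "my_custom_value" in kwargs:
            self.set_var("my_custom_value", kwargs["my_custom_value"])


df = SmartDataframe("data.csv", {
    "custom_prompts": {
        # Assumed prompt key; the surrounding docs show the same dict shape.
        "generate_python_code": MyCustomFilePrompt(my_custom_value=42),
    }
})
```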
12 changes: 8 additions & 4 deletions pandasai/agent/__init__.py
@@ -81,7 +81,9 @@ def clarification_questions(self, query: str) -> List[str]:
Generate clarification questions based on the data
"""
prompt = ClarificationQuestionPrompt(
- self._lake.dfs, self._lake._memory.get_conversation(), query
+ dataframes=self._lake.dfs,
+ conversation=self._lake._memory.get_conversation(),
+ query=query,
)

result = self._call_llm_with_prompt(prompt)
@@ -104,8 +106,8 @@ def explain(self) -> str:
"""
try:
prompt = ExplainPrompt(
- self._lake._memory.get_conversation(),
- self._lake.last_code_executed,
+ conversation=self._lake._memory.get_conversation(),
+ code=self._lake.last_code_executed,
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
@@ -123,7 +125,9 @@ def explain(self) -> str:
def rephrase_query(self, query: str):
try:
prompt = RephraseQueryPrompt(
- query, self._lake.dfs, self._lake._memory.get_conversation()
+ query=query,
+ dataframes=self._lake.dfs,
+ conversation=self._lake._memory.get_conversation(),
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
2 changes: 1 addition & 1 deletion pandasai/assets/prompt_templates/explain_prompt.tmpl
@@ -9,7 +9,7 @@ Based on the last conversation you generated the following code:

<Code>
{code}
- </Code
+ </Code>

Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?
7 changes: 5 additions & 2 deletions pandasai/prompts/base.py
@@ -25,8 +25,11 @@ def __init__(self, **kwargs):
if self._args is None:
self._args = {}

- if kwargs:
- self._args = {**kwargs, **self._args}
+ self._args.update(kwargs)
+ self.setup(**kwargs)

+ def setup(self, **kwargs):
+ pass

def _generate_dataframes(self, dfs):
"""
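
For clarity, this hunk is the core of the refactor: the base `__init__` now merges all keyword arguments into `_args` and then hands them to the new `setup()` hook, so subclasses override `setup()` instead of `__init__()` and no longer call `super().__init__(**kwargs)`. The classes below are minimal, self-contained stand-ins reconstructed from the visible hunks, not the actual library source.

```python
# Minimal stand-ins for illustration; reconstructed from the hunks above,
# not copied from the pandasai source.


class AbstractPromptSketch:
    _args = None

    def __init__(self, **kwargs):
        if self._args is None:
            self._args = {}
        self._args.update(kwargs)  # keyword arguments become template variables
        self.setup(**kwargs)       # new hook; a no-op unless a subclass overrides it

    def setup(self, **kwargs):
        pass

    def set_var(self, var, value):
        self._args[var] = value

    @property
    def template(self) -> str:
        raise NotImplementedError

    def to_string(self) -> str:
        return self.template.format(**self._args)


class ExplainPromptSketch(AbstractPromptSketch):
    """Mirrors ExplainPrompt below: override setup(), not __init__()."""

    @property
    def template(self) -> str:
        return "Conversation:\n{conversation}\n\n<Code>\n{code}\n</Code>"

    def setup(self, conversation: str, code: str):
        self.set_var("conversation", conversation)
        self.set_var("code", code)


prompt = ExplainPromptSketch(conversation="test conversation", code="df.head()")
print(prompt.to_string())
```

One practical consequence, visible in the agent and test changes in this commit: because the base `__init__` only accepts `**kwargs`, prompts are now constructed with keyword arguments (for example `ExplainPrompt(conversation=..., code=...)`) rather than positionally.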
6 changes: 1 addition & 5 deletions pandasai/prompts/clarification_questions_prompt.py
@@ -26,15 +26,11 @@ class ClarificationQuestionPrompt(FileBasedPrompt):

_path_to_template = "assets/prompt_templates/clarification_questions_prompt.tmpl"

- def __init__(
- self, dataframes: List[pd.DataFrame], conversation: str, query: str, **kwargs
- ):
+ def setup(self, dataframes: List[pd.DataFrame], conversation: str, query: str):
self.set_var("dfs", dataframes)
self.set_var("conversation", conversation)
self.set_var("query", query)

- super().__init__(**kwargs)

def validate(self, output) -> bool:
try:
json_data = json.loads(output)
6 changes: 2 additions & 4 deletions pandasai/prompts/explain_prompt.py
@@ -9,7 +9,7 @@
<Code>
{code}
- </Code
+ </Code>
Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?
@@ -23,8 +23,6 @@ class ExplainPrompt(FileBasedPrompt):

_path_to_template = "assets/prompt_templates/explain_prompt.tmpl"

- def __init__(self, conversation: str, code: str, **kwargs):
+ def setup(self, conversation: str, code: str):
self.set_var("conversation", conversation)
self.set_var("code", code)

- super().__init__(**kwargs)
4 changes: 1 addition & 3 deletions pandasai/prompts/generate_python_code.py
@@ -41,7 +41,7 @@ class GeneratePythonCodePrompt(FileBasedPrompt):

_path_to_template = "assets/prompt_templates/generate_python_code.tmpl"

- def __init__(self, **kwargs):
+ def setup(self, **kwargs):
default_import = "import pandas as pd"
engine_df_name = "pd.DataFrame"

@@ -58,8 +58,6 @@ def __init__(self, **kwargs):
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501
)

- super().__init__(**kwargs)

def _set_instructions(self, instructions: str):
lines = instructions.split("\n")
indented_lines = [" " + line for line in lines[1:]]
6 changes: 1 addition & 5 deletions pandasai/prompts/rephase_query_prompt.py
@@ -19,9 +19,7 @@ class RephraseQueryPrompt(FileBasedPrompt):

_path_to_template = "assets/prompt_templates/rephrase_query_prompt.tmpl"

- def __init__(
- self, query: str, dataframes: List[pd.DataFrame], conversation: str, **kwargs
- ):
+ def setup(self, query: str, dataframes: List[pd.DataFrame], conversation: str):
conversation_content = (
self.conversation_text.format(conversation=conversation)
if conversation
@@ -30,5 +28,3 @@ def __init__(
self.set_var("conversation", conversation_content)
self.set_var("query", query)
self.set_var("dfs", dataframes)

- super().__init__(**kwargs)
17 changes: 12 additions & 5 deletions tests/test_agent.py
@@ -183,7 +183,10 @@ def test_call_prompt_success(self, agent: Agent):
What is expected Salary Increase?
"""
agent._lake.llm.call.return_value = clarification_response
- prompt = ExplainPrompt("test conversation", "")
+ prompt = ExplainPrompt(
+ conversation="test conversation",
+ code="test code",
+ )
agent._call_llm_with_prompt(prompt)
assert agent._lake.llm.call.call_count == 1

@@ -200,7 +203,7 @@ def test_call_prompt_max_retry_on_error(self, agent: Agent):
# test the LLM call failed twice but succeed third time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), Exception(), "LLM Result"]
- prompt = ExplainPrompt("test conversation", "")
+ prompt = ExplainPrompt(conversation="test conversation", code="")
result = agent._call_llm_with_prompt(prompt)
assert result == "LLM Result"
assert agent._lake.llm.call.call_count == 3
@@ -209,7 +212,7 @@ def test_call_prompt_max_retry_twice(self, agent: Agent):
# test the LLM call failed once but succeed second time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), "LLM Result"]
- prompt = ExplainPrompt("test conversation", "")
+ prompt = ExplainPrompt(conversation="test conversation", code="")
result = agent._call_llm_with_prompt(prompt)

assert result == "LLM Result"
@@ -245,7 +248,9 @@ def test_clarification_prompt_validate_output_false_case(self, agent: Agent):
agent._lake.llm.call.return_value = "This is not json"

prompt = ClarificationQuestionPrompt(
- agent._lake.dfs, "test conversation", "test query"
+ dataframes=agent._lake.dfs,
+ conversation="test conversation",
+ query="test query",
)
with pytest.raises(Exception):
agent._call_llm_with_prompt(prompt)
@@ -256,7 +261,9 @@ def test_clarification_prompt_validate_output_true_case(self, agent: Agent):
agent._lake.llm.call.return_value = '["This is test quesiton"]'

prompt = ClarificationQuestionPrompt(
- agent._lake.dfs, "test conversation", "test query"
+ dataframes=agent._lake.dfs,
+ conversation="test conversation",
+ query="test query",
)
result = agent._call_llm_with_prompt(prompt)
# Didn't raise any exception