Skip to content

Commit

Permalink
feat: clean up the prompt to increase the accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Oct 11, 2023
1 parent fef6dd6 commit ac738c4
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 47 deletions.
6 changes: 2 additions & 4 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

You are provided with the following pandas DataFrames:

{dataframes}
Expand All @@ -7,7 +6,6 @@ You are provided with the following pandas DataFrames:
{conversation}
</conversation>

This is the initial python code to be updated:
```python
# TODO import all the dependencies required
{default_import}
Expand All @@ -20,6 +18,6 @@ def analyze_data(dfs: list[{engine_df_name}]) -> dict:
"""
```

Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.

Updated code:
Return the updated code:
2 changes: 1 addition & 1 deletion pandasai/helpers/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_conversation(self, limit: int = None) -> str:
limit = self._memory_size if limit is None else limit
return "\n".join(
[
f"{f'User {i+1}' if message['is_user'] else f'Assistant {i}'}: "
f"{'User' if message['is_user'] else 'Assistant'}: "
f"{message['message']}"
for i, message in enumerate(self._messages[-limit:])
]
Expand Down
26 changes: 13 additions & 13 deletions pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,31 @@
{conversation}
</conversation>
This is the initial python code to be updated:
```python
# TODO import all the dependencies required
{default_import}
def analyze_data(dfs: list[{engine_df_name}]) -> dict:
\"\"\"
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
- type (possible values "text", "number", "dataframe", "plot")
- type (possible values "string", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Example output: {{ "type": "string", "value": f"The average loan amount is {{average_amount}}" }}
Examples:
{ "type": "string", "value": "The highest salary is $9,000." }
or
{ "type": "number", "value": 125 }
or
{ "type": "dataframe", "value": pd.DataFrame({...}) }
or
{ "type": "plot", "value": "temp_chart.png" }
\"\"\"
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
{conversation}
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Updated code:
""" # noqa: E501
Return the updated code:""" # noqa: E501


from .file_based_prompt import FileBasedPrompt
Expand All @@ -53,9 +55,7 @@ def setup(self, **kwargs) -> None:
else:
self._set_instructions(
"""Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.""" # noqa: E501
)

def _set_instructions(self, instructions: str):
Expand Down
13 changes: 4 additions & 9 deletions tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint):
prompt.set_var("save_charts_path", save_charts_path)
prompt.set_var("output_type_hint", output_type_hint)

expected_prompt_content = f'''
You are provided with the following pandas DataFrames:
expected_prompt_content = f'''You are provided with the following pandas DataFrames:
<dataframe>
Dataframe dfs[0], with 1 rows and 2 columns.
Expand All @@ -66,26 +65,22 @@ def test_str_with_args(self, save_charts_path, output_type_hint):
Question
</conversation>
This is the initial python code to be updated:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
"""
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
{output_type_hint}
"""
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Updated code:
''' # noqa E501
Return the updated code:''' # noqa E501
actual_prompt_content = prompt.to_string()
if sys.platform.startswith("win"):
actual_prompt_content = actual_prompt_content.replace("\r\n", "\n")
Expand Down
30 changes: 10 additions & 20 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def test_run_with_privacy_enforcement(self, llm):
df = SmartDataframe(df, config={"llm": llm, "enable_cache": False})
df.enforce_privacy = True

expected_prompt = """
You are provided with the following pandas DataFrames:
expected_prompt = """You are provided with the following pandas DataFrames:
<dataframe>
Dataframe dfs[0], with 0 rows and 1 columns.
Expand All @@ -213,20 +212,17 @@ def test_run_with_privacy_enforcement(self, llm):
</dataframe>
<conversation>
User 1: How many countries are in the dataframe?
User: How many countries are in the dataframe?
</conversation>
This is the initial python code to be updated:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
\"\"\"
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
- type (possible values "string", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Expand All @@ -241,10 +237,9 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
\"\"\"
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Updated code:
""" # noqa: E501
Return the updated code:""" # noqa: E501
df.chat("How many countries are in the dataframe?")
last_prompt = df.last_prompt
if sys.platform.startswith("win"):
Expand All @@ -265,8 +260,7 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint):
df = pd.DataFrame({"country": []})
df = SmartDataframe(df, config={"llm": llm, "enable_cache": False})

expected_prompt = f'''
You are provided with the following pandas DataFrames:
expected_prompt = f'''You are provided with the following pandas DataFrames:
<dataframe>
Dataframe dfs[0], with 0 rows and 1 columns.
Expand All @@ -275,29 +269,25 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint):
</dataframe>
<conversation>
User 1: How many countries are in the dataframe?
User: How many countries are in the dataframe?
</conversation>
This is the initial python code to be updated:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
"""
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
{output_type_hint}
"""
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Updated code:
''' # noqa: E501
Return the updated code:''' # noqa: E501

df.chat("How many countries are in the dataframe?", output_type=output_type)
last_prompt = df.last_prompt
Expand Down

0 comments on commit ac738c4

Please sign in to comment.