Skip to content

Commit

Permalink
fix: output_type parameter (#519)
Browse files Browse the repository at this point in the history
* (tests): fix error in `TestGeneratePythonCodePrompt` where the actual
  prompt's content was confused with the expected prompt's content (which
  caused the tests to fail)
* (refactor): update test method `test_str_with_args()` with
  `parametrize` decorator, remove duplication of code (DRY)
  • Loading branch information
nautics889 committed Sep 16, 2023
1 parent 6cafc12 commit 8ae345f
Showing 1 changed file with 29 additions and 75 deletions.
104 changes: 29 additions & 75 deletions tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys

import pandas as pd
import pytest

from pandasai import SmartDataframe
from pandasai.prompts import GeneratePythonCodePrompt
from pandasai.llm.fake import FakeLLM
Expand All @@ -10,8 +12,27 @@
class TestGeneratePythonCodePrompt:
"""Unit tests for the generate python code prompt class"""

def test_str_with_args(self):
"""Test that the __str__ method is implemented"""
@pytest.mark.parametrize(
"save_charts_path,output_type_hint",
[
("exports/charts", GeneratePythonCodePrompt._output_type_default),
("custom/dir/for/charts", GeneratePythonCodePrompt._output_type_default),
*[
("exports/charts", GeneratePythonCodePrompt._output_type_map[type_])
for type_ in GeneratePythonCodePrompt._output_type_map
],
],
)
def test_str_with_args(self, save_charts_path, output_type_hint):
"""Test casting of prompt to string and interpolation of context.
Parameterized for the following cases:
* `save_charts_path` is "exports/charts", `output_type_hint` is default
* `save_charts_path` is "custom/dir/for/charts", `output_type_hint`
is default
* `save_charts_path` is "exports/charts", `output_type_hint` any of
possible types in `GeneratePythonCodePrompt._output_type_map`
"""

llm = FakeLLM("plt.show()")
dfs = [
Expand All @@ -23,76 +44,10 @@ def test_str_with_args(self):
prompt = GeneratePythonCodePrompt()
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "exports/charts")
output_type_hint = """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
""" # noqa E501
prompt.set_var("output_type_hint", output_type_hint)

expected_prompt_content = '''
You are provided with the following pandas DataFrames:
<dataframe>
Dataframe dfs[0], with 1 rows and 2 columns.
This is the metadata of the dataframe dfs[0]:
a,b
1,4
</dataframe>
<conversation>
Question
</conversation>
This is the initial python code to be updated:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
"""
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in exports/charts/temp_chart.png and do not show the chart.)
4. Output: return a dictionary of:
- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Example output: { "type": "text", "value": "The average loan amount is $15,000." }
"""
```
Using the provided dataframes (`dfs`), update the python code based on the last question in the conversation.
Updated code:
''' # noqa E501
actual_prompt_content = prompt.to_string()
if sys.platform.startswith("win"):
actual_prompt_content = expected_prompt_content.replace("\r\n", "\n")
assert actual_prompt_content == expected_prompt_content

def test_str_with_custom_save_charts_path(self):
"""Test that the __str__ method is implemented"""

llm = FakeLLM("plt.show()")
dfs = [
SmartDataframe(
pd.DataFrame({"a": [1], "b": [4]}),
config={"llm": llm},
)
]

prompt = GeneratePythonCodePrompt()
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "custom_path")
# noqa E501
output_type_hint = """- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
""" # noqa E501
prompt.set_var("save_charts_path", save_charts_path)
prompt.set_var("output_type_hint", output_type_hint)

expected_prompt_content = '''
expected_prompt_content = f'''
You are provided with the following pandas DataFrames:
<dataframe>
Expand All @@ -116,11 +71,10 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
Analyze the data
1. Prepare: Preprocessing and cleaning data if necessary
2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in custom_path/temp_chart.png and do not show the chart.)
3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in {save_charts_path}/temp_chart.png and do not show the chart.)
4. Output: return a dictionary of:
- type (possible values "text", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Example output: { "type": "text", "value": "The average loan amount is $15,000." }
{output_type_hint}
Example output: {{ "type": "text", "value": "The average loan amount is $15,000." }}
"""
```
Expand All @@ -130,5 +84,5 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
''' # noqa E501
actual_prompt_content = prompt.to_string()
if sys.platform.startswith("win"):
actual_prompt_content = expected_prompt_content.replace("\r\n", "\n")
actual_prompt_content = actual_prompt_content.replace("\r\n", "\n")
assert actual_prompt_content == expected_prompt_content

0 comments on commit 8ae345f

Please sign in to comment.