Skip to content

Commit

Permalink
feat: add additional step with the agent to figure out if the query i…
Browse files Browse the repository at this point in the history
…s related to the dconversation
  • Loading branch information
gventuri committed Oct 12, 2023
1 parent ac738c4 commit f9facf3
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 52 deletions.
47 changes: 38 additions & 9 deletions pandasai/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import json
from typing import Union, List, Optional
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.logger import Logger
from pandasai.helpers.memory import Memory
from pandasai.prompts.base import AbstractPrompt
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt
from pandasai.prompts.rephase_query_prompt import RephraseQueryPrompt
from pandasai.schemas.df_config import Config
from pandasai.smart_datalake import SmartDatalake
from ..helpers.df_info import DataFrameType
from ..helpers.logger import Logger
from ..helpers.memory import Memory
from ..prompts.base import AbstractPrompt
from ..prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from ..prompts.explain_prompt import ExplainPrompt
from ..prompts.rephase_query_prompt import RephraseQueryPrompt
from ..prompts.check_if_relevant_to_conversation import (
CheckIfRelevantToConversationPrompt,
)
from ..schemas.df_config import Config
from ..smart_datalake import SmartDatalake


class Agent:
Expand Down Expand Up @@ -67,6 +70,7 @@ def chat(self, query: str, output_type: Optional[str] = None):
Simulate a chat interaction with the assistant on Dataframe.
"""
try:
self.check_if_related_to_conversation(query)
result = self._lake.chat(query, output_type=output_type)
return result
except Exception as exception:
Expand All @@ -76,6 +80,31 @@ def chat(self, query: str, output_type: Optional[str] = None):
f"\n{exception}\n"
)

def check_if_related_to_conversation(self, query: str):
"""
Check if the query is related to the previous conversation
"""
if self._lake._memory.count() == 0:
return

prompt = CheckIfRelevantToConversationPrompt(
conversation=self._lake._memory.get_conversation(),
query=query,
)

result = self._call_llm_with_prompt(prompt)

related = False
if "true" in result:
related = True

self._logger.log(
f"""Check if the new message is related to the conversation: {related}"""
)

if not related:
self._lake.clear_memory()

def clarification_questions(self, query: str) -> List[str]:
"""
Generate clarification questions based on the data
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<conversation>
{conversation}
</conversation>

<query>
{query}
</query>

Is the query related to the conversation? Answer only "true" or "false" (lowercase).
9 changes: 9 additions & 0 deletions pandasai/assets/prompt_templates/current_code.tmpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TODO import all the dependencies required
{default_import}

def analyze_data(dfs: list[{engine_df_name}]) -> dict:
"""
{instructions}
At the end, return a dictionary of:
{output_type_hint}
"""
15 changes: 5 additions & 10 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,13 @@ You are provided with the following pandas DataFrames:
{conversation}
</conversation>

This is the initial python code:
```python
# TODO import all the dependencies required
{default_import}

def analyze_data(dfs: list[{engine_df_name}]) -> dict:
"""
{instructions}
At the end, return a dictionary of:
{output_type_hint}
"""
{current_code}
```

Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function.
If the new query from the user is not relevant with the code, rewrite the content of the `analyze_data` function from scratch.
It is very important that you do not change the params that are passed to `analyze_data`.

Return the updated code:
4 changes: 4 additions & 0 deletions pandasai/helpers/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ def get_conversation(self, limit: int = None) -> str:

def clear(self):
self._messages = []

@property
def size(self):
return self._memory_size
15 changes: 14 additions & 1 deletion pandasai/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,21 @@ def set_var(self, var, value):
self._args["dataframes"] = self._generate_dataframes(value)
self._args[var] = value

def set_vars(self, vars):
if self._args is None:
self._args = {}
self._args.update(vars)

def to_string(self):
return self.template.format(**self._args)
prompt_args = {}
for key, value in self._args.items():
if isinstance(value, AbstractPrompt):
value.set_vars({k: v for k, v in self._args.items() if k != key})
prompt_args[key] = value.to_string()
else:
prompt_args[key] = value

return self.template.format_map(prompt_args)

def __str__(self):
return self.to_string()
Expand Down
19 changes: 19 additions & 0 deletions pandasai/prompts/check_if_relevant_to_conversation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
""" Prompt to check if the query is related to the previous conversation
<conversation>
{conversation}
</conversation>
<query>
{query}
</query>
Is the query related to the conversation? Answer only "true" or "false" (lowercase).
"""
from .file_based_prompt import FileBasedPrompt


class CheckIfRelevantToConversationPrompt(FileBasedPrompt):
"""Prompt to check if the query is related to the previous conversation"""

_path_to_template = "assets/prompt_templates/check_if_relevant_to_conversation.tmpl"
40 changes: 17 additions & 23 deletions pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,25 @@
{conversation}
</conversation>
```python
# TODO import all the dependencies required
{default_import}
def analyze_data(dfs: list[{engine_df_name}]) -> dict:
\"\"\"
Analyze the data
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
- type (possible values "string", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Examples:
{ "type": "string", "value": "The highest salary is $9,000." }
or
{ "type": "number", "value": 125 }
or
{ "type": "dataframe", "value": pd.DataFrame({...}) }
or
{ "type": "plot", "value": "temp_chart.png" }
\"\"\"
```
This is the initial python code:
{current_code}
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function.
If the new query from the user is not relevant with the code, rewrite the content of the `analyze_data` function from scratch.
It is very important that you do not change the params that are passed to `analyze_data`.
Return the updated code:""" # noqa: E501


from .file_based_prompt import FileBasedPrompt


class CurrentCodePrompt(FileBasedPrompt):
"""The current code"""

_path_to_template = "assets/prompt_templates/current_code.tmpl"


class GeneratePythonCodePrompt(FileBasedPrompt):
"""Prompt to generate Python code"""

Expand All @@ -54,10 +43,15 @@ def setup(self, **kwargs) -> None:
self._set_instructions(kwargs["custom_instructions"])
else:
self._set_instructions(
"""Analyze the data
"""Analyze the data.
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.""" # noqa: E501
)

if "current_code" in kwargs:
self.set_var("current_code", kwargs["current_code"])
else:
self.set_var("current_code", CurrentCodePrompt())

def _set_instructions(self, instructions: str):
lines = instructions.split("\n")
indented_lines = [" " + line for line in lines[1:]]
Expand Down
9 changes: 8 additions & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class SmartDatalake:
_code_manager: CodeManager
_memory: Memory

_last_code_generated: str
_last_code_generated: str = None
_last_result: str = None
_last_error: str = None

Expand Down Expand Up @@ -305,6 +305,13 @@ def chat(self, query: str, output_type: Optional[str] = None):
"output_type_hint": output_type_helper.template_hint,
}

if (
self.memory.size > 1
and self.memory.count() > 1
and self._last_code_generated
):
default_values["current_code"] = self._last_code_generated

generate_python_code_instruction = self._get_prompt(
"generate_python_code",
default_prompt=GeneratePythonCodePrompt,
Expand Down
11 changes: 7 additions & 4 deletions tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,23 @@ def test_str_with_args(self, save_charts_path, output_type_hint):
Question
</conversation>
This is the initial python code:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
"""
Analyze the data
Analyze the data.
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
{output_type_hint}
"""
```
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function.
If the new query from the user is not relevant with the code, rewrite the content of the `analyze_data` function from scratch.
It is very important that you do not change the params that are passed to `analyze_data`.
Return the updated code:''' # noqa E501
actual_prompt_content = prompt.to_string()
Expand All @@ -87,7 +90,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
assert actual_prompt_content == expected_prompt_content

def test_custom_instructions(self):
custom_instructions = """Analyze the data
custom_instructions = """Analyze the data.
1. Load: Load the data from a file or database
2. Prepare: Preprocessing and cleaning data if necessary
3. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
Expand All @@ -98,7 +101,7 @@ def test_custom_instructions(self):

assert (
actual_instructions
== """Analyze the data
== """Analyze the data.
1. Load: Load the data from a file or database
2. Prepare: Preprocessing and cleaning data if necessary
3. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.)
Expand Down
14 changes: 10 additions & 4 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,14 @@ def test_run_with_privacy_enforcement(self, llm):
User: How many countries are in the dataframe?
</conversation>
This is the initial python code:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
\"\"\"
Analyze the data
Analyze the data.
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
- type (possible values "string", "number", "dataframe", "plot")
Expand All @@ -237,7 +238,9 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
\"\"\"
```
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function.
If the new query from the user is not relevant with the code, rewrite the content of the `analyze_data` function from scratch.
It is very important that you do not change the params that are passed to `analyze_data`.
Return the updated code:""" # noqa: E501
df.chat("How many countries are in the dataframe?")
Expand Down Expand Up @@ -272,20 +275,23 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint):
User: How many countries are in the dataframe?
</conversation>
This is the initial python code:
```python
# TODO import all the dependencies required
import pandas as pd
def analyze_data(dfs: list[pd.DataFrame]) -> dict:
"""
Analyze the data
Analyze the data.
If the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.
At the end, return a dictionary of:
{output_type_hint}
"""
```
Use the provided dataframes (`dfs`) and update the python code to answer the last question in the conversation.
Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function.
If the new query from the user is not relevant with the code, rewrite the content of the `analyze_data` function from scratch.
It is very important that you do not change the params that are passed to `analyze_data`.
Return the updated code:''' # noqa: E501

Expand Down

0 comments on commit f9facf3

Please sign in to comment.