feat(RephraseQuery): rephrase user query to get more accurate responses #592

Merged: 26 commits, Sep 25, 2023

Changes from all commits (26 commits):
70b0da8
feat[Agent]: add agent conversation code
ArslanSaleem Sep 21, 2023
1b51727
feat[Agent]: add test cases for the agent class
ArslanSaleem Sep 22, 2023
70244c3
feat: add explain method
ArslanSaleem Sep 22, 2023
f715035
feat: Add Explain functionality in the agent
ArslanSaleem Sep 22, 2023
2da890a
fix: refactor types
ArslanSaleem Sep 22, 2023
6736c44
chore: fix typings
ArslanSaleem Sep 22, 2023
cdeec68
chore: improve prompt add conversation
ArslanSaleem Sep 22, 2023
9025f4e
refactor: remove memory from the agent class
ArslanSaleem Sep 22, 2023
d1b8e61
refactor: import of Agent class in example
ArslanSaleem Sep 22, 2023
49d8720
refactor: memory to return conversation according to size
ArslanSaleem Sep 22, 2023
b92fb39
refactor: remove leftover property
ArslanSaleem Sep 22, 2023
7f17af8
fix: prompt comment
ArslanSaleem Sep 22, 2023
2e4c902
fix: redundant try catch
ArslanSaleem Sep 23, 2023
7a554a5
chore: improve docstring and add example in documentation
ArslanSaleem Sep 23, 2023
f7e4d98
fix: Comment in clarification prompts and add dtyps to the constructors
ArslanSaleem Sep 24, 2023
21f5bd8
feat(RephraseQuery): rephrase user query to get more accurate responses
ArslanSaleem Sep 25, 2023
393f2f2
Merge branch 'feature/v1.3' into agent_rephrase_query
ArslanSaleem Sep 25, 2023
adfc86a
chore(agent): add max retries on queries
ArslanSaleem Sep 25, 2023
bf9667b
feat: improve the prompt to also add information about ambiguous parts
gventuri Sep 25, 2023
cccee44
feat[retry_wrapper]: add basic wrapper for error handling and add pro…
ArslanSaleem Sep 25, 2023
2d87306
Merge branch 'agent_rephrase_query' into retry_wrapper
ArslanSaleem Sep 25, 2023
995b90d
Merge branch 'agent_rephrase_query' into retry_wrapper
ArslanSaleem Sep 25, 2023
6fa9c1d
refactor(validation): return False from the validator in case of failure
ArslanSaleem Sep 25, 2023
30765ac
Merge branch 'feature/v1.3' into agent_rephrase_query
ArslanSaleem Sep 25, 2023
e9b2342
fix(RephraseQuery): remove conversation from the prompt if empty
ArslanSaleem Sep 25, 2023
0c3c997
Merge branch 'agent_rephrase_query' into retry_wrapper
ArslanSaleem Sep 25, 2023
11 changes: 11 additions & 0 deletions docs/getting-started.md
@@ -162,6 +162,17 @@
print("The answer is", response)
print("The explanation is", explanation)
```

### Rephrase Question

Rephrase the user's question to get a more accurate and comprehensive response from the model. For example:

```python
rephrased_query = agent.rephrase_query('What is the GDP of the United States?')

print("The rephrased query is", rephrased_query)
```

## Config

When you instantiate a `SmartDataframe`, you can pass a `config` object as the second argument. This object can contain custom settings that will be used by `pandasai` when generating code.
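For instance, a minimal sketch of passing such a config (the values are illustrative; `max_retries` and `use_error_correction_framework` are the settings exercised by this PR):

```python
import pandas as pd

from pandasai import SmartDataframe
from pandasai.llm.fake import FakeLLM

df = pd.DataFrame({"country": ["US", "DE"], "gdp": [26.9, 4.4]})

# Custom settings passed at construction time; illustrative values only.
sdf = SmartDataframe(
    df,
    config={
        "llm": FakeLLM(),
        "max_retries": 3,
        "use_error_correction_framework": True,
    },
)
```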
46 changes: 44 additions & 2 deletions pandasai/agent/__init__.py
@@ -3,8 +3,10 @@
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.logger import Logger
from pandasai.helpers.memory import Memory
from pandasai.prompts.base import Prompt
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt
from pandasai.prompts.rephase_query_prompt import RephraseQueryPrompt
from pandasai.schemas.df_config import Config
from pandasai.smart_datalake import SmartDatalake

@@ -38,6 +40,28 @@ def __init__(
self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size))
self._logger = self._lake.logger

def _call_llm_with_prompt(self, prompt: Prompt):
"""
Call LLM with prompt using error handling to retry based on config
Args:
prompt (Prompt): Prompt to pass to LLM's
"""
retry_count = 0
while retry_count < self._lake.config.max_retries:
try:
result: str = self._lake.llm.call(prompt)
if prompt.validate(result):
return result
else:
raise Exception("Response validation failed!")
except Exception:
if (
not self._lake.use_error_correction_framework
or retry_count >= self._lake.config.max_retries - 1
):
raise
retry_count += 1

Comment on lines +43 to +64 (Contributor):
The _call_llm_with_prompt method is a good addition for handling retries and error management. However, it's important to note that the exception handling here is quite broad. It catches all exceptions without distinguishing between different types of errors. This could potentially hide unexpected issues and make debugging more difficult. Consider refining the exception handling to be more specific or at least log the exception details before retrying.
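A minimal sketch of that refinement, reusing the attributes from the diff above (the `LLMResponseValidationError` class and the log message are illustrative additions, not part of this PR):

```python
class LLMResponseValidationError(Exception):
    """Illustrative exception for responses that fail prompt validation."""


def _call_llm_with_prompt(self, prompt: Prompt):
    """Variant of the method above that logs each failure before retrying."""
    retry_count = 0
    while retry_count < self._lake.config.max_retries:
        try:
            result: str = self._lake.llm.call(prompt)
            if prompt.validate(result):
                return result
            raise LLMResponseValidationError("Response validation failed!")
        except Exception as exception:
            # Log the exception details so retries don't hide the root cause.
            self._logger.log(
                f"LLM call failed on attempt {retry_count + 1}: {exception}"
            )
            if (
                not self._lake.use_error_correction_framework
                or retry_count >= self._lake.config.max_retries - 1
            ):
                raise
            retry_count += 1
```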

def chat(self, query: str, output_type: Optional[str] = None):
"""
Simulate a chat interaction with the assistant on Dataframe.
@@ -60,7 +84,7 @@ def clarification_questions(self, query: str) -> List[str]:
self._lake.dfs, self._lake._memory.get_conversation(), query
)

-    result = self._lake.llm.call(prompt)
+    result = self._call_llm_with_prompt(prompt)
self._logger.log(
f"""Clarification Questions: {result}
"""
@@ -83,7 +107,7 @@ def explain(self) -> str:
self._lake._memory.get_conversation(),
self._lake.last_code_executed,
)
-    response = self._lake.llm.call(prompt)
+    response = self._call_llm_with_prompt(prompt)
self._logger.log(
f"""Explanation: {response}
"""
@@ -95,3 +119,21 @@
"because of the following error:\n"
f"\n{exception}\n"
)

def rephrase_query(self, query: str):
try:
prompt = RephraseQueryPrompt(
query, self._lake.dfs, self._lake._memory.get_conversation()
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
f"""Rephrased Response: {response}
"""
)
return response
except Exception as exception:
return (
    "Unfortunately, I was not able to rephrase the query, "
    "because of the following error:\n"
    f"\n{exception}\n"
)
Comment on lines +123 to +139 (Contributor):
The rephrase_query method is a new feature that rephrases a given query using the LLM. The implementation looks correct and follows the same pattern as other methods in the class. However, similar to the previous comment, the exception handling is too broad and could hide unexpected issues. Consider refining the exception handling to be more specific or at least log the exception details when an error occurs.
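Along the same lines, a sketch of how the failure path could log details before returning the fallback message (the extra log call is illustrative, not the merged behavior):

```python
def rephrase_query(self, query: str):
    try:
        prompt = RephraseQueryPrompt(
            query, self._lake.dfs, self._lake._memory.get_conversation()
        )
        response = self._call_llm_with_prompt(prompt)
        self._logger.log(f"Rephrased Response: {response}")
        return response
    except Exception as exception:
        # Surface the exception details in the log instead of only
        # embedding them in the returned string.
        self._logger.log(f"rephrase_query failed: {exception}")
        return (
            "Unfortunately, I was not able to rephrase the query, "
            "because of the following error:\n"
            f"\n{exception}\n"
        )
```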

3 changes: 3 additions & 0 deletions pandasai/prompts/base.py
@@ -59,3 +59,6 @@ def to_string(self):

def __str__(self):
return self.to_string()

def validate(self, output: str) -> bool:
return isinstance(output, str)
8 changes: 8 additions & 0 deletions pandasai/prompts/clarification_questions_prompt.py
@@ -15,6 +15,7 @@
""" # noqa: E501


import json
from typing import List
import pandas as pd
from .base import Prompt
@@ -49,3 +50,10 @@ def __init__(self, dataframes: List[pd.DataFrame], conversation: str, query: str
self.set_var("dfs", dataframes)
self.set_var("conversation", conversation)
self.set_var("query", query)

def validate(self, output) -> bool:
try:
json_data = json.loads(output)
return isinstance(json_data, List)
except json.JSONDecodeError:
return False
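For illustration, a standalone sketch of how this validator behaves, using the built-in `list` rather than `typing.List` for the `isinstance` check:

```python
import json


def validate(output: str) -> bool:
    # Accept only strings that parse as a JSON list, mirroring
    # ClarificationQuestionPrompt.validate.
    try:
        return isinstance(json.loads(output), list)
    except json.JSONDecodeError:
        return False


assert validate('["Which year?", "Which country?"]') is True
assert validate('{"not": "a list"}') is False
assert validate("plain text") is False
```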
48 changes: 48 additions & 0 deletions pandasai/prompts/rephase_query_prompt.py
@@ -0,0 +1,48 @@
""" Prompt to rephrase query to get more accurate responses
You are provided with the following pandas DataFrames:

{dataframes}
{conversation}
Return the rephrased sentence of "{query}” in order to obtain more accurate and
comprehensive responses without any explanations. If something from the original
query is ambiguous, please clarify it in the rephrased query, making assumptions,
if necessary.

"""
from typing import List

import pandas as pd
from .base import Prompt


class RephraseQueryPrompt(Prompt):
"""Prompt to rephrase query to get more accurate responses"""

text: str = """
You are provided with the following pandas DataFrames:

{dataframes}
{conversation}
Return the rephrased sentence of "{query}” in order to obtain more accurate and
comprehensive responses without any explanations. If something from the original
query is ambiguous, please clarify it in the rephrased query, making assumptions,
if necessary.
"""
Comment on lines +21 to +30 (Contributor):
The same text is repeated twice in the class RephraseQueryPrompt. This repetition could be avoided by defining the text once and reusing it. This would make the code more maintainable and easier to modify in the future.

-    text: str = """
-    You are provided with the following pandas DataFrames:
-    
-    {dataframes}
-    {conversation}
-    Return the rephrased sentence of "{query}” in order to obtain more accurate and 
-    comprehensive responses without any explanations. If something from the original
-    query is ambiguous, please clarify it in the rephrased query, making assumptions,
-    if necessary.
-    """
-
-    conversation_text: str = """
-    And based on our conversation:
-    
-    <conversation>
-    {conversation}
-    </conversation>
-    """
+    PROMPT_TEXT: str = """
+    You are provided with the following pandas DataFrames:
+    
+    {dataframes}
+    {conversation}
+    Return the rephrased sentence of "{query}” in order to obtain more accurate and 
+    comprehensive responses without any explanations. If something from the original
+    query is ambiguous, please clarify it in the rephrased query, making assumptions,
+    if necessary.
+    """
+
+    CONVERSATION_TEXT: str = """
+    And based on our conversation:
+    
+    <conversation>
+    {conversation}
+    </conversation>
+    """


conversation_text: str = """
And based on our conversation:

<conversation>
{conversation}
</conversation>
"""

def __init__(self, query: str, dataframes: List[pd.DataFrame], conversation: str):
conversation_content = (
self.conversation_text.format(conversation=conversation)
if conversation
else ""
)
self.set_var("conversation", conversation_content)
self.set_var("query", query)
self.set_var("dfs", dataframes)
Comment on lines +40 to +48 (Contributor):
There's no validation for the input parameters query, dataframes, and conversation. It's a good practice to validate function arguments to ensure they meet certain conditions before proceeding with the function execution. For instance, you could check if query is a non-empty string, dataframes is a list of pandas DataFrames, and conversation is a string.
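One possible shape for that validation, as a sketch (the checks and error messages are illustrative, not part of the merged code):

```python
from typing import List

import pandas as pd


def __init__(self, query: str, dataframes: List[pd.DataFrame], conversation: str):
    # Illustrative argument checks along the lines the reviewer suggests.
    if not isinstance(query, str) or not query.strip():
        raise ValueError("query must be a non-empty string")
    if not isinstance(dataframes, list) or not all(
        isinstance(df, pd.DataFrame) for df in dataframes
    ):
        raise TypeError("dataframes must be a list of pandas DataFrames")
    if not isinstance(conversation, str):
        raise TypeError("conversation must be a string")

    conversation_content = (
        self.conversation_text.format(conversation=conversation)
        if conversation
        else ""
    )
    self.set_var("conversation", conversation_content)
    self.set_var("query", query)
    self.set_var("dfs", dataframes)
```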

130 changes: 122 additions & 8 deletions tests/test_agent.py
@@ -4,6 +4,8 @@
import pandas as pd
import pytest
from pandasai.llm.fake import FakeLLM
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt

from pandasai.smart_datalake import SmartDatalake

@@ -55,19 +57,31 @@ def sample_df(self):
)

@pytest.fixture
-    def llm(self, output: Optional[str] = None):
+    def llm(self, output: Optional[str] = None) -> FakeLLM:
return FakeLLM(output=output)

@pytest.fixture
-    def config(self, llm: FakeLLM):
+    def config(self, llm: FakeLLM) -> dict:
return {"llm": llm}

@pytest.fixture
def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent:
return Agent(sample_df, config)

def test_constructor(self, sample_df, config):
-    agent = Agent(sample_df, config)
-    assert isinstance(agent._lake, SmartDatalake)
+    agent_1 = Agent(sample_df, config)
+    assert isinstance(agent_1._lake, SmartDatalake)

+    agent_2 = Agent([sample_df], config)
+    assert isinstance(agent_2._lake, SmartDatalake)

-    agent = Agent([sample_df], config)
-    assert isinstance(agent._lake, SmartDatalake)
+    # test that memory does not overlap across multiple agent instances
+    agent_1._lake._memory.add("Which country has the highest gdp?", True)
+    memory = agent_1._lake._memory.all()
+    assert len(memory) == 1
+
+    memory = agent_2._lake._memory.all()
+    assert len(memory) == 0

def test_chat(self, sample_df, config):
# Create an Agent instance for testing
@@ -136,8 +150,7 @@ def test_clarification_questions_max_3(self, sample_df, config):
assert isinstance(questions, list)
assert len(questions) == 3

-    def test_explain(self, sample_df, config):
-        agent = Agent(sample_df, config, memory_size=10)
+    def test_explain(self, agent: Agent):
agent._lake.llm.call = Mock()
clarification_response = """
Combine the Data: To find out who gets paid the most,
@@ -163,3 +176,104 @@
It's like finding the person who has the most marbles in a game
"""
)

def test_call_prompt_success(self, agent: Agent):
agent._lake.llm.call = Mock()
clarification_response = """
What is expected Salary Increase?
"""
agent._lake.llm.call.return_value = clarification_response
prompt = ExplainPrompt("test conversation", "")
agent._call_llm_with_prompt(prompt)
assert agent._lake.llm.call.call_count == 1

def test_call_prompt_max_retries_exceeds(self, agent: Agent):
# raises exception every time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = Exception("Raise an exception")
with pytest.raises(Exception):
agent._call_llm_with_prompt("Test Prompt")

assert agent._lake.llm.call.call_count == 3

def test_call_prompt_max_retry_on_error(self, agent: Agent):
# test the LLM call failed twice but succeed third time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), Exception(), "LLM Result"]
prompt = ExplainPrompt("test conversation", "")
result = agent._call_llm_with_prompt(prompt)
assert result == "LLM Result"
assert agent._lake.llm.call.call_count == 3

def test_call_prompt_max_retry_twice(self, agent: Agent):
# test the LLM call failed once but succeed second time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), "LLM Result"]
prompt = ExplainPrompt("test conversation", "")
result = agent._call_llm_with_prompt(prompt)

assert result == "LLM Result"
assert agent._lake.llm.call.call_count == 2

def test_call_llm_with_prompt_no_retry_on_error(self, agent: Agent):
# Test when LLM call raises an exception but retries are disabled

agent._lake.config.use_error_correction_framework = False
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = Exception()
with pytest.raises(Exception):
agent._call_llm_with_prompt("Test Prompt")

assert agent._lake.llm.call.call_count == 1

def test_call_llm_with_prompt_max_retries_check(self, agent: Agent):
# Test when LLM call raises an exception, but called call function
# 'max_retries' time

agent._lake.config.max_retries = 5
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = Exception()

with pytest.raises(Exception):
agent._call_llm_with_prompt("Test Prompt")

assert agent._lake.llm.call.call_count == 5

def test_clarification_prompt_validate_output_false_case(self, agent: Agent):
# Test whether the output is json or not
agent._lake.llm.call = Mock()
agent._lake.llm.call.return_value = "This is not json"

prompt = ClarificationQuestionPrompt(
agent._lake.dfs, "test conversation", "test query"
)
with pytest.raises(Exception):
agent._call_llm_with_prompt(prompt)

def test_clarification_prompt_validate_output_true_case(self, agent: Agent):
# Test whether the output is json or not
agent._lake.llm.call = Mock()
agent._lake.llm.call.return_value = '["This is a test question"]'

prompt = ClarificationQuestionPrompt(
agent._lake.dfs, "test conversation", "test query"
)
result = agent._call_llm_with_prompt(prompt)
# Didn't raise any exception
assert isinstance(result, str)

def test_rephrase(self, sample_df, config):
agent = Agent(sample_df, config, memory_size=10)
agent._lake.llm.call = Mock()
clarification_response = """
How much has the total salary expense increased?
"""
agent._lake.llm.call.return_value = clarification_response

response = agent.rephrase_query("how much has the revenue increased?")

assert response == (
"""
How much has the total salary expense increased?
"""
)