feat(ErrorHandling): retry query multiple times based on the max retries (#592)

* feat[Agent]: add agent conversation code

* feat[Agent]: add test cases for the agent class

* feat: add explain method

* feat: Add Explain functionality in the agent

* fix: refactor types

* chore: fix typings

* chore: improve prompt add conversation

* refactor: remove memory from the agent class

* refactor: import of Agent class in example

* refactor: memory to return conversation according to size

* refactor: remove leftover property

* fix: prompt comment

* fix: redundant try catch

* chore: improve docstring and add example in documentation

* fix: Comment in clarification prompts and add dtypes to the constructors

* feat(RephraseQuery): rephrase user query to get more accurate responses

* chore(agent): add max retries on queries

* feat: improve the prompt to also add information about ambiguous parts

* feat[retry_wrapper]: add basic wrapper for error handling and add prompt validators

* refactor(validation): return False from the validator in case of failure

* fix(RephraseQuery): remove conversation from the prompt if empty

---------

Co-authored-by: Gabriele Venturi <[email protected]>
ArslanSaleem and gventuri authored Sep 25, 2023
1 parent 09872ac commit 722ca3e
Showing 6 changed files with 236 additions and 10 deletions.
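
In short, every agent-side LLM call is now wrapped in a retry loop: the model is re-called until the prompt's validator accepts the response or `config.max_retries` is exhausted, and only while error correction is enabled. A condensed, hypothetical sketch of that pattern (the actual implementation is `Agent._call_llm_with_prompt` in `pandasai/agent/__init__.py` below; the helper name here is illustrative):

```python
# Condensed sketch of the retry-with-validation pattern added by this commit.
# `call_llm_with_retries` is a hypothetical name; the real method is
# Agent._call_llm_with_prompt in the diff below.
def call_llm_with_retries(llm, prompt, max_retries=3, use_error_correction=True):
    retry_count = 0
    while retry_count < max_retries:
        try:
            result = llm.call(prompt)       # ask the model
            if prompt.validate(result):     # prompt-specific check, e.g. "is this a JSON list?"
                return result
            raise Exception("Response validation failed!")
        except Exception:
            # give up when error correction is disabled or retries are exhausted
            if not use_error_correction or retry_count >= max_retries - 1:
                raise
            retry_count += 1
```
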
11 changes: 11 additions & 0 deletions docs/getting-started.md
@@ -162,6 +162,17 @@ print("The answer is", response)
print("The explanation is", explanation)
```

### Rephrase Question

Rephrase the question to get a more accurate and comprehensive response from the model. For example:

```python
rephrased_query = agent.rephrase_query('What is the GDP of the United States?')

print("The answer is", rephrased_query)

```
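
The rephrased query (like clarification questions and explanations) is produced by an LLM call that is retried whenever the response fails validation. The number of attempts is controlled by the `max_retries` and `use_error_correction_framework` settings covered in the Config section below. A minimal sketch, assuming `df` and `llm` are defined as in the earlier examples (the variable names and values here are placeholders):

```python
# Hypothetical configuration values; `max_retries` caps how often the agent
# re-calls the LLM, and `use_error_correction_framework` enables the retries.
agent = Agent(df, {"llm": llm, "max_retries": 5, "use_error_correction_framework": True})

rephrased_query = agent.rephrase_query('What is the GDP of the United States?')
print("The rephrased query is", rephrased_query)
```
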

## Config

When you instantiate a `SmartDataframe`, you can pass a `config` object as the second argument. This object can contain custom settings that will be used by `pandasai` when generating code.
46 changes: 44 additions & 2 deletions pandasai/agent/__init__.py
Expand Up @@ -3,8 +3,10 @@
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.logger import Logger
from pandasai.helpers.memory import Memory
from pandasai.prompts.base import Prompt
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt
from pandasai.prompts.rephase_query_prompt import RephraseQueryPrompt
from pandasai.schemas.df_config import Config
from pandasai.smart_datalake import SmartDatalake

@@ -38,6 +40,28 @@ def __init__(
self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size))
self._logger = self._lake.logger

def _call_llm_with_prompt(self, prompt: Prompt):
"""
Call the LLM with the given prompt, retrying on failure based on the config.
Args:
prompt (Prompt): Prompt to pass to the LLM
"""
retry_count = 0
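# Retry the LLM call until the response passes the prompt's validation or max_retries is reached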
while retry_count < self._lake.config.max_retries:
try:
result: str = self._lake.llm.call(prompt)
if prompt.validate(result):
return result
else:
raise Exception("Response validation failed!")
except Exception:
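# Only retry while error correction is enabled and retries remain; otherwise re-raise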
if (
not self._lake.use_error_correction_framework
or retry_count >= self._lake.config.max_retries - 1
):
raise
retry_count += 1

def chat(self, query: str, output_type: Optional[str] = None):
"""
Simulate a chat interaction with the assistant on Dataframe.
@@ -60,7 +84,7 @@ def clarification_questions(self, query: str) -> List[str]:
self._lake.dfs, self._lake._memory.get_conversation(), query
)

result = self._lake.llm.call(prompt)
result = self._call_llm_with_prompt(prompt)
self._logger.log(
f"""Clarification Questions: {result}
"""
@@ -83,7 +107,7 @@ def explain(self) -> str:
self._lake._memory.get_conversation(),
self._lake.last_code_executed,
)
response = self._lake.llm.call(prompt)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
f"""Explaination: {response}
"""
@@ -95,3 +119,21 @@ def explain(self) -> str:
"because of the following error:\n"
f"\n{exception}\n"
)

def rephrase_query(self, query: str):
try:
prompt = RephraseQueryPrompt(
query, self._lake.dfs, self._lake._memory.get_conversation()
)
response = self._call_llm_with_prompt(prompt)
self._logger.log(
f"""Rephrased Response: {response}
"""
)
return response
except Exception as exception:
return (
"Unfortunately, I was not able to repharse query, "
"because of the following error:\n"
f"\n{exception}\n"
)
3 changes: 3 additions & 0 deletions pandasai/prompts/base.py
@@ -59,3 +59,6 @@ def to_string(self):

def __str__(self):
return self.to_string()

def validate(self, output: str) -> bool:
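# Default validation; prompt subclasses override this to enforce stricter formats (e.g. a JSON list)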
return isinstance(output, str)
8 changes: 8 additions & 0 deletions pandasai/prompts/clarification_questions_prompt.py
@@ -15,6 +15,7 @@
""" # noqa: E501


import json
from typing import List
import pandas as pd
from .base import Prompt
@@ -49,3 +50,10 @@ def __init__(self, dataframes: List[pd.DataFrame], conversation: str, query: str
self.set_var("dfs", dataframes)
self.set_var("conversation", conversation)
self.set_var("query", query)

def validate(self, output) -> bool:
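# Clarification questions must come back as a JSON list; anything else fails validation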
try:
json_data = json.loads(output)
return isinstance(json_data, List)
except json.JSONDecodeError:
return False
48 changes: 48 additions & 0 deletions pandasai/prompts/rephase_query_prompt.py
@@ -0,0 +1,48 @@
""" Prompt to rephrase query to get more accurate responses
You are provided with the following pandas DataFrames:
{dataframes}
{conversation}
Return the rephrased sentence of "{query}" in order to obtain more accurate and
comprehensive responses without any explanations. If something from the original
query is ambiguous, please clarify it in the rephrased query, making assumptions,
if necessary.
"""
from typing import List

import pandas as pd
from .base import Prompt


class RephraseQueryPrompt(Prompt):
"""Prompt to rephrase query to get more accurate responses"""

text: str = """
You are provided with the following pandas DataFrames:
{dataframes}
{conversation}
Return the rephrased sentence of "{query}" in order to obtain more accurate and
comprehensive responses without any explanations. If something from the original
query is ambiguous, please clarify it in the rephrased query, making assumptions,
if necessary.
"""

conversation_text: str = """
And based on our conversation:
<conversation>
{conversation}
</conversation>
"""

def __init__(self, query: str, dataframes: List[pd.DataFrame], conversation: str):
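# Only embed the conversation block when there is prior conversation to show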
conversation_content = (
self.conversation_text.format(conversation=conversation)
if conversation
else ""
)
self.set_var("conversation", conversation_content)
self.set_var("query", query)
self.set_var("dfs", dataframes)
130 changes: 122 additions & 8 deletions tests/test_agent.py
@@ -4,6 +4,8 @@
import pandas as pd
import pytest
from pandasai.llm.fake import FakeLLM
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt

from pandasai.smart_datalake import SmartDatalake

@@ -55,19 +57,31 @@ def sample_df(self):
)

@pytest.fixture
def llm(self, output: Optional[str] = None):
def llm(self, output: Optional[str] = None) -> FakeLLM:
return FakeLLM(output=output)

@pytest.fixture
def config(self, llm: FakeLLM):
def config(self, llm: FakeLLM) -> dict:
return {"llm": llm}

@pytest.fixture
def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent:
return Agent(sample_df, config)

def test_constructor(self, sample_df, config):
agent = Agent(sample_df, config)
assert isinstance(agent._lake, SmartDatalake)
agent_1 = Agent(sample_df, config)
assert isinstance(agent_1._lake, SmartDatalake)

agent_2 = Agent([sample_df], config)
assert isinstance(agent_2._lake, SmartDatalake)

agent = Agent([sample_df], config)
assert isinstance(agent._lake, SmartDatalake)
# test that data does not overlap across multiple agent instances
agent_1._lake._memory.add("Which country has the highest gdp?", True)
memory = agent_1._lake._memory.all()
assert len(memory) == 1

memory = agent_2._lake._memory.all()
assert len(memory) == 0

def test_chat(self, sample_df, config):
# Create an Agent instance for testing
@@ -136,8 +150,7 @@ def test_clarification_questions_max_3(self, sample_df, config):
assert isinstance(questions, list)
assert len(questions) == 3

def test_explain(self, sample_df, config):
agent = Agent(sample_df, config, memory_size=10)
def test_explain(self, agent: Agent):
agent._lake.llm.call = Mock()
clarification_response = """
Combine the Data: To find out who gets paid the most,
@@ -163,3 +176,104 @@ def test_explain(self, sample_df, config):
It's like finding the person who has the most marbles in a game
"""
)

def test_call_prompt_success(self, agent: Agent):
agent._lake.llm.call = Mock()
clarification_response = """
What is expected Salary Increase?
"""
agent._lake.llm.call.return_value = clarification_response
prompt = ExplainPrompt("test conversation", "")
agent._call_llm_with_prompt(prompt)
assert agent._lake.llm.call.call_count == 1

def test_call_prompt_max_retries_exceeds(self, agent: Agent):
# raises exception every time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = Exception("Raise an exception")
with pytest.raises(Exception):
agent._call_llm_with_prompt("Test Prompt")

assert agent._lake.llm.call.call_count == 3

def test_call_prompt_max_retry_on_error(self, agent: Agent):
# test that the LLM call fails twice but succeeds the third time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), Exception(), "LLM Result"]
prompt = ExplainPrompt("test conversation", "")
result = agent._call_llm_with_prompt(prompt)
assert result == "LLM Result"
assert agent._lake.llm.call.call_count == 3

def test_call_prompt_max_retry_twice(self, agent: Agent):
# test that the LLM call fails once but succeeds the second time
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = [Exception(), "LLM Result"]
prompt = ExplainPrompt("test conversation", "")
result = agent._call_llm_with_prompt(prompt)

assert result == "LLM Result"
assert agent._lake.llm.call.call_count == 2

def test_call_llm_with_prompt_no_retry_on_error(self, agent: Agent):
# Test when LLM call raises an exception but retries are disabled

agent._lake.config.use_error_correction_framework = False
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = Exception()
with pytest.raises(Exception):
agent._call_llm_with_prompt("Test Prompt")

assert agent._lake.llm.call.call_count == 1

def test_call_llm_with_prompt_max_retries_check(self, agent: Agent):
# Test that when the LLM call keeps raising an exception, the call function
# is invoked 'max_retries' times

agent._lake.config.max_retries = 5
agent._lake.llm.call = Mock()
agent._lake.llm.call.side_effect = Exception()

with pytest.raises(Exception):
agent._call_llm_with_prompt("Test Prompt")

assert agent._lake.llm.call.call_count == 5

def test_clarification_prompt_validate_output_false_case(self, agent: Agent):
# Test that validation fails when the output is not valid JSON
agent._lake.llm.call = Mock()
agent._lake.llm.call.return_value = "This is not json"

prompt = ClarificationQuestionPrompt(
agent._lake.dfs, "test conversation", "test query"
)
with pytest.raises(Exception):
agent._call_llm_with_prompt(prompt)

def test_clarification_prompt_validate_output_true_case(self, agent: Agent):
# Test that validation succeeds when the output is a valid JSON list
agent._lake.llm.call = Mock()
agent._lake.llm.call.return_value = '["This is test quesiton"]'

prompt = ClarificationQuestionPrompt(
agent._lake.dfs, "test conversation", "test query"
)
result = agent._call_llm_with_prompt(prompt)
# Didn't raise any exception
assert isinstance(result, str)

def test_rephrase(self, sample_df, config):
agent = Agent(sample_df, config, memory_size=10)
agent._lake.llm.call = Mock()
clarification_response = """
How much has the total salary expense increased?
"""
agent._lake.llm.call.return_value = clarification_response

response = agent.rephrase_query("how much has the revenue increased?")

assert response == (
"""
How much has the total salary expense increased?
"""
)
