From 70b0da8b7e3869aca0e95391e4e0f2629cf24a29 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Thu, 21 Sep 2023 19:55:54 +0500 Subject: [PATCH 01/21] feat[Agent]: add agent conversation code --- examples/agent.py | 28 ++++++ pandasai/__init__.py | 3 +- pandasai/agent/__init__.py | 93 +++++++++++++++++++ pandasai/helpers/memory.py | 7 +- .../prompts/clarification_questions_prompt.py | 47 ++++++++++ pandasai/smart_datalake/__init__.py | 14 ++- 6 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 examples/agent.py create mode 100644 pandasai/agent/__init__.py create mode 100644 pandasai/prompts/clarification_questions_prompt.py diff --git a/examples/agent.py b/examples/agent.py new file mode 100644 index 000000000..6ea63db67 --- /dev/null +++ b/examples/agent.py @@ -0,0 +1,28 @@ +import pandas as pd +from pandasai.agent import Agent + +from pandasai.llm.openai import OpenAI + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +llm = OpenAI("OPEN_API") +dl = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) +response = dl.chat("Who gets paid the most?") +print(response) +response = dl.clarification_questions() + +response = dl.chat("Which department does he belongs to?") +print(response) diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 6a63677b1..d7b22b2ec 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -44,6 +44,7 @@ from .callbacks.base import BaseCallback from .schemas.df_config import Config from .helpers.cache import Cache +from .agent import Agent __version__ = importlib.metadata.version(__package__ or __name__) @@ -257,4 +258,4 @@ def clear_cache(filename: str = None): 
cache.clear() -__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "clear_cache"] +__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "Agent", "clear_cache"] diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py new file mode 100644 index 000000000..0b6469318 --- /dev/null +++ b/pandasai/agent/__init__.py @@ -0,0 +1,93 @@ +import json +from typing import Union, List, Optional +from pandasai.helpers.df_info import DataFrameType +from pandasai.helpers.logger import Logger +from pandasai.helpers.memory import Memory +from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt +from pandasai.schemas.df_config import Config + +from pandasai.smart_datalake import SmartDatalake + + +class Agent: + """ + Agent class to improve the conversational experience in PandasAI + """ + + _memory: Memory + _lake: SmartDatalake = None + logger: Logger = None + + def __init__( + self, + dfs: Union[DataFrameType, List[DataFrameType]], + config: Optional[Union[Config, dict]] = None, + logger: Logger = None, + memory_size=1, + ): + """ + Args: + df (Union[SmartDataframe, SmartDatalake]): _description_ + memory_size (int, optional): _description_. Defaults to 1. + """ + + if not isinstance(dfs, list): + dfs = [dfs] + + self._lake = SmartDatalake(dfs, config, logger) + self.logger = self._lake.logger + self._memory = Memory(memory_size * 2) + + def _get_conversation(self): + """ + Get Conversation from history + + """ + return "\n".join( + [ + f"{'Question: ' if message['is_user'] else 'Answer:'}: " + f"{message['message']}" + for i, message in enumerate(self._memory.all()) + ] + ) + + def chat(self, query: str): + """ + Simulate a chat interaction with the assistant on Dataframe. 
+ """ + self._memory.add(query, True) + conversation = self._get_conversation() + result = self._lake.chat(query, start_conversation=conversation) + self._memory.add(result, False) + return result + + def _get_clarification_prompt(self): + """ + Create a clarification prompt with relevant variables. + """ + prompt = ClarificationQuestionPrompt() + prompt.set_var("dfs", self._lake.dfs) + prompt.set_var("conversation", self._get_conversation()) + return prompt + + def clarification_questions(self): + """ + Generate and return up to three clarification questions based on a given prompt. + """ + try: + prompt = self._get_clarification_prompt() + result = self._lake.llm.generate_code(prompt) + except Exception as exception: + return ( + "Unfortunately, I was not able to get your clarification questions, " + "because of the following error:\n" + f"\n{exception}\n" + ) + questions = json.loads(result) + return questions[:3] + + def start_new_conversation(self): + """ + Clears the previous conversation + """ + self._memory.clear() diff --git a/pandasai/helpers/memory.py b/pandasai/helpers/memory.py index 5c7e01c8e..ad7478fd9 100644 --- a/pandasai/helpers/memory.py +++ b/pandasai/helpers/memory.py @@ -1,16 +1,21 @@ """ Memory class to store the conversations """ +import sys class Memory: """Memory class to store the conversations""" _messages: list + _max_messages: int - def __init__(self): + def __init__(self, max_messages=sys.maxsize): self._messages = [] + self._max_messages = max_messages def add(self, message: str, is_user: bool): self._messages.append({"message": message, "is_user": is_user}) + if len(self._messages) > self._max_messages: + del self._messages[:2] def count(self) -> int: return len(self._messages) diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py new file mode 100644 index 000000000..fc969629f --- /dev/null +++ b/pandasai/prompts/clarification_questions_prompt.py @@ -0,0 +1,47 @@ +""" 
Prompt to get clarification questions +You are provided with the following pandas DataFrames: + + +{dataframe} + + + +{conversation} + + +Based on the conversation, are there any clarification questions that a senior data scientist would ask? These are questions for non technical people, only ask for questions they could ask given low tech expertise and no knowledge about how the dataframes are structured. + +Return the JSON array of the clarification questions. If there is no clarification question, return an empty array. + +Json: +""" # noqa: E501 + + +from .base import Prompt + + +class ClarificationQuestionPrompt(Prompt): + """Prompt to get clarification questions""" + + text: str = """ +You are provided with the following pandas DataFrames: + + +{dataframes} + + + +{conversation} + + +Based on the conversation, are there any clarification questions +that a senior data scientist would ask? These are questions for non technical people, +only ask for questions they could ask given low tech expertise and +no knowledge about how the dataframes are structured. + +Return the JSON array of the clarification questions. + +If there is no clarification question, return an empty array. + +Json: +""" diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 77397b3b5..7f8dd64aa 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -255,7 +255,12 @@ def _get_cache_key(self) -> str: return cache_key - def chat(self, query: str, output_type: Optional[str] = None): + def chat( + self, + query: str, + output_type: Optional[str] = None, + start_conversation: Optional[str] = None, + ): """ Run a query on the dataframe. 
@@ -305,6 +310,9 @@ def chat(self, query: str, output_type: Optional[str] = None): "save_charts_path": self._config.save_charts_path.rstrip("/"), "output_type_hint": output_type_helper.template_hint, } + if start_conversation is not None: + default_values["conversation"] = start_conversation + generate_python_code_instruction = self._get_prompt( "generate_python_code", default_prompt=GeneratePythonCodePrompt, @@ -644,3 +652,7 @@ def last_error(self): @last_error.setter def last_error(self, last_error: str): self._last_error = last_error + + @property + def dfs(self): + return self._dfs From 1b51727a9c2e4455edbdd2e53a6f92792257cb3a Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Fri, 22 Sep 2023 14:08:12 +0500 Subject: [PATCH 02/21] feat[Agent]: add test cases for the agent class --- examples/agent.py | 10 +- pandasai/agent/__init__.py | 5 +- tests/test_agent.py | 197 +++++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+), 7 deletions(-) create mode 100644 tests/test_agent.py diff --git a/examples/agent.py b/examples/agent.py index 6ea63db67..fb89bd2b6 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -19,10 +19,10 @@ llm = OpenAI("OPEN_API") -dl = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) -response = dl.chat("Who gets paid the most?") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) +response = agent.chat("Who gets paid the most?") print(response) -response = dl.clarification_questions() - -response = dl.chat("Which department does he belongs to?") +questions = agent.clarification_questions() +print(questions) +response = agent.chat("Which department he belongs to?") print(response) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 0b6469318..f5e54815e 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -45,7 +45,7 @@ def _get_conversation(self): """ return "\n".join( [ - f"{'Question: ' if message['is_user'] else 'Answer:'}: " 
+ f"{'Question' if message['is_user'] else 'Answer'}: " f"{message['message']}" for i, message in enumerate(self._memory.all()) ] @@ -77,13 +77,14 @@ def clarification_questions(self): try: prompt = self._get_clarification_prompt() result = self._lake.llm.generate_code(prompt) + questions = json.loads(result) except Exception as exception: return ( "Unfortunately, I was not able to get your clarification questions, " "because of the following error:\n" f"\n{exception}\n" ) - questions = json.loads(result) + return questions[:3] def start_new_conversation(self): diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 000000000..7a6fc4e14 --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,197 @@ +from typing import Optional +from unittest.mock import Mock +from pandasai.agent import Agent +import pandas as pd +import pytest +from pandasai.llm.fake import FakeLLM + +from pandasai.smart_datalake import SmartDatalake + + +class TestAgent: + "Unit tests for Agent class" + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def config(self, llm: FakeLLM): + return {"llm": llm} + + def test_constructor(self, sample_df, config): + agent = Agent(sample_df, config) + assert isinstance(agent._lake, SmartDatalake) + + agent = Agent([sample_df], config) + assert isinstance(agent._lake, SmartDatalake) + + def test_chat(self, sample_df, config): + # Create an Agent 
instance for testing + agent = Agent(sample_df, config) + agent._lake.chat = Mock() + agent._lake.chat.return_value = "United States has the highest gdp" + # Test the chat function + response = agent.chat("Which country has the highest gdp?") + assert agent._lake.chat.called + assert isinstance(response, str) + assert response == "United States has the highest gdp" + + def test_chat_memory(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.chat = Mock() + agent._lake.chat.return_value = "United States has the highest gdp" + + # Test the chat function + agent.chat("Which country has the highest gdp?") + + memory = agent._memory.all() + assert len(memory) == 2 + assert memory[0]["message"] == "Which country has the highest gdp?" + assert memory[1]["message"] == "United States has the highest gdp" + + # Add another conversation + agent._lake.chat.return_value = "United Kingdom has the second highest gdp" + agent.chat("Which country has the second highest gdp?") + + memory = agent._memory.all() + assert len(memory) == 4 + assert memory[0]["message"] == "Which country has the highest gdp?" + assert memory[1]["message"] == "United States has the highest gdp" + assert memory[2]["message"] == "Which country has the second highest gdp?" + assert memory[3]["message"] == "United Kingdom has the second highest gdp" + + def test_chat_memory_rollup(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=1) + agent._lake.chat = Mock() + agent._lake.chat.return_value = "United States has the highest gdp" + + # Test the chat function + agent.chat("Which country has the highest gdp?") + + memory = agent._memory.all() + assert len(memory) == 2 + assert memory[0]["message"] == "Which country has the highest gdp?" 
+ assert memory[1]["message"] == "United States has the highest gdp" + + # Add another conversation + agent._lake.chat.return_value = "United Kingdom has the second highest gdp" + agent.chat("Which country has the second highest gdp?") + + memory = agent._memory.all() + assert len(memory) == 2 + assert memory[0]["message"] == "Which country has the second highest gdp?" + assert memory[1]["message"] == "United Kingdom has the second highest gdp" + + def test_chat_get_conversation(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.chat = Mock() + agent._lake.chat.return_value = "United States has the highest gdp" + + agent.chat("Which country has the highest gdp?") + + conversation = agent._get_conversation() + + assert conversation == ( + "Question: Which country has the highest gdp?\n" + "Answer: United States has the highest gdp" + ) + + # Add another conversation + agent._lake.chat.return_value = "United Kingdom has the second highest gdp" + agent.chat("Which country has the second highest gdp?") + + conversation = agent._get_conversation() + assert conversation == ( + "Question: Which country has the highest gdp?\n" + "Answer: United States has the highest gdp" + "\nQuestion: Which country has the second highest gdp?\n" + "Answer: United Kingdom has the second highest gdp" + ) + + def test_start_new_conversation(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.chat = Mock() + agent._lake.chat.return_value = "United States has the highest gdp" + + agent.chat("Which country has the highest gdp?") + + memory = agent._memory.all() + assert len(memory) == 2 + + agent.start_new_conversation() + memory = agent._memory.all() + assert len(memory) == 0 + + conversation = agent._get_conversation() + assert conversation == "" + + def test_clarification_questions(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.llm.generate_code = Mock() + 
clarification_response = ( + '["What is happiest index for you?", "What is unit of measure for gdp?"]' + ) + agent._lake.llm.generate_code.return_value = clarification_response + + questions = agent.clarification_questions() + assert len(questions) == 2 + assert questions[0] == "What is happiest index for you?" + assert questions[1] == "What is unit of measure for gdp?" + + def test_clarification_questions_max_3(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.llm.generate_code = Mock() + clarification_response = ( + '["What is happiest index for you", ' + '"What is unit of measure for gdp", ' + '"How many countries are involved in the survey", ' + '"How do you want this data to be represented"]' + ) + agent._lake.llm.generate_code.return_value = clarification_response + + questions = agent.clarification_questions() + assert len(questions) == 3 From 70244c3775dc902975e37259b960a9683c8dc0ac Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Fri, 22 Sep 2023 15:08:40 +0500 Subject: [PATCH 03/21] feat: add explain method --- pandasai/agent/__init__.py | 6 ++++ pandasai/prompts/explain_prompt.py | 47 ++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 pandasai/prompts/explain_prompt.py diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index f5e54815e..4707ee48c 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -92,3 +92,9 @@ def start_new_conversation(self): Clears the previous conversation """ self._memory.clear() + + def explain(self): + """ + Returns the explanation of the code how it reached to the solution + """ + pass diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py new file mode 100644 index 000000000..fc969629f --- /dev/null +++ b/pandasai/prompts/explain_prompt.py @@ -0,0 +1,47 @@ +""" Prompt to get clarification questions +You are provided with the following pandas DataFrames: + + +{dataframe} + + + 
+{conversation} + + +Based on the conversation, are there any clarification questions that a senior data scientist would ask? These are questions for non technical people, only ask for questions they could ask given low tech expertise and no knowledge about how the dataframes are structured. + +Return the JSON array of the clarification questions. If there is no clarification question, return an empty array. + +Json: +""" # noqa: E501 + + +from .base import Prompt + + +class ClarificationQuestionPrompt(Prompt): + """Prompt to get clarification questions""" + + text: str = """ +You are provided with the following pandas DataFrames: + + +{dataframes} + + + +{conversation} + + +Based on the conversation, are there any clarification questions +that a senior data scientist would ask? These are questions for non technical people, +only ask for questions they could ask given low tech expertise and +no knowledge about how the dataframes are structured. + +Return the JSON array of the clarification questions. + +If there is no clarification question, return an empty array. 
+ +Json: +""" From f7150358f8c938f54f41d7e9e894b5fbbad648af Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Fri, 22 Sep 2023 19:58:45 +0500 Subject: [PATCH 04/21] feat: Add Explain functionality in the agent --- examples/agent.py | 21 ++++++-- pandasai/agent/__init__.py | 74 ++++++++++++++++++++-------- pandasai/agent/response.py | 38 +++++++++++++++ pandasai/helpers/memory.py | 4 +- pandasai/prompts/explain_prompt.py | 48 +++++------------- pandasai/smart_datalake/__init__.py | 2 +- tests/test_agent.py | 75 +++++++++++++++++++++++++---- 7 files changed, 191 insertions(+), 71 deletions(-) create mode 100644 pandasai/agent/response.py diff --git a/examples/agent.py b/examples/agent.py index fb89bd2b6..ae5727fc8 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -18,11 +18,24 @@ salaries_df = pd.DataFrame(salaries_data) -llm = OpenAI("OPEN_API") +llm = OpenAI("sk-lyDyNVyBwnykr1lJ4Yc7T3BlbkFJtJNyJlKTAvUa2E2D5Wdb44") agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +# Chat with the agent response = agent.chat("Who gets paid the most?") print(response) -questions = agent.clarification_questions() -print(questions) -response = agent.chat("Which department he belongs to?") + + +# Get Clarification Questions +response = agent.clarification_questions() + +if response: + for question in response.questions: + print(question) +else: + print(response.message) + + +# Explain how the chat response is generated +response = agent.explain() print(response) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 4707ee48c..6ae93b42d 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -1,9 +1,12 @@ import json from typing import Union, List, Optional +from pandasai.agent.response import ClarificationResponse from pandasai.helpers.df_info import DataFrameType from pandasai.helpers.logger import Logger from pandasai.helpers.memory import Memory +from pandasai.prompts.base import Prompt from 
pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt +from pandasai.prompts.explain_prompt import ExplainPrompt from pandasai.schemas.df_config import Config from pandasai.smart_datalake import SmartDatalake @@ -23,7 +26,7 @@ def __init__( dfs: Union[DataFrameType, List[DataFrameType]], config: Optional[Union[Config, dict]] = None, logger: Logger = None, - memory_size=1, + memory_size: int = 1, ): """ Args: @@ -36,6 +39,7 @@ def __init__( self._lake = SmartDatalake(dfs, config, logger) self.logger = self._lake.logger + # For the conversation multiple the memory size by 2 self._memory = Memory(memory_size * 2) def _get_conversation(self): @@ -51,17 +55,26 @@ def _get_conversation(self): ] ) - def chat(self, query: str): + def chat(self, query: str, output_type: Optional[str] = None): """ Simulate a chat interaction with the assistant on Dataframe. """ - self._memory.add(query, True) - conversation = self._get_conversation() - result = self._lake.chat(query, start_conversation=conversation) - self._memory.add(result, False) - return result + try: + self._memory.add(query, True) + conversation = self._get_conversation() + result = self._lake.chat( + query, output_type=output_type, start_conversation=conversation + ) + self._memory.add(result, False) + return result + except Exception as exception: + return ( + "Unfortunately, I was not able to get your answers, " + "because of the following error:\n" + f"\n{exception}\n" + ) - def _get_clarification_prompt(self): + def _get_clarification_prompt(self) -> Prompt: """ Create a clarification prompt with relevant variables. """ @@ -70,31 +83,54 @@ def _get_clarification_prompt(self): prompt.set_var("conversation", self._get_conversation()) return prompt - def clarification_questions(self): + def clarification_questions(self) -> ClarificationResponse: """ - Generate and return up to three clarification questions based on a given prompt. 
+ Generate clarification questions based on the data """ try: prompt = self._get_clarification_prompt() - result = self._lake.llm.generate_code(prompt) - questions = json.loads(result) + result = self._lake.llm.call(prompt) + self.logger.log( + f"""Clarification Questions: {result} + """ + ) + questions: list[str] = json.loads(result) + return ClarificationResponse( + success=True, questions=questions[:3], message=result + ) except Exception as exception: - return ( + return ClarificationResponse( + False, + [], "Unfortunately, I was not able to get your clarification questions, " "because of the following error:\n" - f"\n{exception}\n" + f"\n{exception}\n", ) - return questions[:3] - - def start_new_conversation(self): + def start_new_conversation(self) -> True: """ Clears the previous conversation """ + self._memory.clear() + return True - def explain(self): + def explain(self) -> str: """ Returns the explanation of the code how it reached to the solution """ - pass + try: + prompt = ExplainPrompt() + prompt.set_var("code", self._lake.last_code_executed) + response = self._lake.llm.call(prompt) + self.logger.log( + f"""Explaination: {response} + """ + ) + return response + except Exception as exception: + return ( + "Unfortunately, I was not able to explain, " + "because of the following error:\n" + f"\n{exception}\n" + ) diff --git a/pandasai/agent/response.py b/pandasai/agent/response.py new file mode 100644 index 000000000..1aff4423f --- /dev/null +++ b/pandasai/agent/response.py @@ -0,0 +1,38 @@ +from typing import List + + +class ClarificationResponse: + """ + Clarification Response + + """ + + def __init__( + self, success: bool = True, questions: List[str] = None, message: str = "" + ): + """ + Args: + success: Whether the response generated or not. 
+ questions: List of questions + """ + self._success: bool = success + self._questions: List[str] = questions + self._message: str = message + + @property + def questions(self) -> List[str]: + return self._questions + + @property + def message(self) -> List[str]: + return self._message + + @property + def success(self) -> bool: + return self._success + + def __bool__(self) -> bool: + """ + Define the success of response. + """ + return self._success diff --git a/pandasai/helpers/memory.py b/pandasai/helpers/memory.py index ad7478fd9..568c2bbf1 100644 --- a/pandasai/helpers/memory.py +++ b/pandasai/helpers/memory.py @@ -8,12 +8,14 @@ class Memory: _messages: list _max_messages: int - def __init__(self, max_messages=sys.maxsize): + def __init__(self, max_messages: int = sys.maxsize): self._messages = [] self._max_messages = max_messages def add(self, message: str, is_user: bool): self._messages.append({"message": message, "is_user": is_user}) + + # Delete two entry because of the conversation if len(self._messages) > self._max_messages: del self._messages[:2] diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py index fc969629f..9f4612470 100644 --- a/pandasai/prompts/explain_prompt.py +++ b/pandasai/prompts/explain_prompt.py @@ -1,47 +1,23 @@ -""" Prompt to get clarification questions -You are provided with the following pandas DataFrames: - - -{dataframe} - - - -{conversation} - - -Based on the conversation, are there any clarification questions that a senior data scientist would ask? These are questions for non technical people, only ask for questions they could ask given low tech expertise and no knowledge about how the dataframes are structured. - -Return the JSON array of the clarification questions. If there is no clarification question, return an empty array. - -Json: -""" # noqa: E501 +""" Prompt to explain solution generated +Based on the last conversation you generated the code. 
+Can you explain briefly for non technical person on how you came up with code +without explaining pandas library? +""" from .base import Prompt -class ClarificationQuestionPrompt(Prompt): +class ExplainPrompt(Prompt): """Prompt to get clarification questions""" text: str = """ -You are provided with the following pandas DataFrames: - - -{dataframes} - - - -{conversation} - - -Based on the conversation, are there any clarification questions -that a senior data scientist would ask? These are questions for non technical people, -only ask for questions they could ask given low tech expertise and -no knowledge about how the dataframes are structured. - -Return the JSON array of the clarification questions. +Based on the last conversation you generated the code. -If there is no clarification question, return an empty array. + +{code} + Date: Fri, 22 Sep 2023 20:03:42 +0500 Subject: [PATCH 05/21] fix: refactor types --- examples/agent.py | 2 +- pandasai/agent/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/agent.py b/examples/agent.py index ae5727fc8..158537a19 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -18,7 +18,7 @@ salaries_df = pd.DataFrame(salaries_data) -llm = OpenAI("sk-lyDyNVyBwnykr1lJ4Yc7T3BlbkFJtJNyJlKTAvUa2E2D5Wdb44") +llm = OpenAI("OPEN_API_KEY") agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) # Chat with the agent diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 6ae93b42d..ef2c798dc 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -19,13 +19,13 @@ class Agent: _memory: Memory _lake: SmartDatalake = None - logger: Logger = None + logger: Optional[Logger] = None def __init__( self, dfs: Union[DataFrameType, List[DataFrameType]], config: Optional[Union[Config, dict]] = None, - logger: Logger = None, + logger: Optional[Logger] = None, memory_size: int = 1, ): """ From 6736c44545873ad1c91c85b777e9919c43126dea Mon Sep 17 
00:00:00 2001 From: ArslanSaleem Date: Fri, 22 Sep 2023 20:11:50 +0500 Subject: [PATCH 06/21] chore: fix typings --- pandasai/agent/__init__.py | 3 +-- pandasai/agent/response.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index ef2c798dc..242310d21 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -107,13 +107,12 @@ def clarification_questions(self) -> ClarificationResponse: f"\n{exception}\n", ) - def start_new_conversation(self) -> True: + def start_new_conversation(self): """ Clears the previous conversation """ self._memory.clear() - return True def explain(self) -> str: """ diff --git a/pandasai/agent/response.py b/pandasai/agent/response.py index 1aff4423f..772941667 100644 --- a/pandasai/agent/response.py +++ b/pandasai/agent/response.py @@ -8,7 +8,7 @@ class ClarificationResponse: """ def __init__( - self, success: bool = True, questions: List[str] = None, message: str = "" + self, success: bool = True, questions: List[str] = [], message: str = "" ): """ Args: @@ -24,7 +24,7 @@ def questions(self) -> List[str]: return self._questions @property - def message(self) -> List[str]: + def message(self) -> str: return self._message @property From cdeec68f25078ed1fada05a66f2b28d1d3e6691e Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Fri, 22 Sep 2023 20:27:40 +0500 Subject: [PATCH 07/21] chore: improve prompt add conversation --- pandasai/agent/__init__.py | 1 + pandasai/prompts/explain_prompt.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 242310d21..a76a16267 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -121,6 +121,7 @@ def explain(self) -> str: try: prompt = ExplainPrompt() prompt.set_var("code", self._lake.last_code_executed) + prompt.set_var("conversation", self._get_conversation()) response = self._lake.llm.call(prompt) self.logger.log( 
f"""Explaination: {response} diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py index 9f4612470..80b0e10b9 100644 --- a/pandasai/prompts/explain_prompt.py +++ b/pandasai/prompts/explain_prompt.py @@ -12,6 +12,12 @@ class ExplainPrompt(Prompt): """Prompt to get clarification questions""" text: str = """ +The previous conversation we had + + +{conversation} + + Based on the last conversation you generated the code. @@ -20,4 +26,5 @@ class ExplainPrompt(Prompt): Can you explain briefly for non technical person on how you came up with code without explaining pandas library? + """ From 9025f4e063fbac5df94e79320b7d72473b1765b4 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sat, 23 Sep 2023 00:30:27 +0500 Subject: [PATCH 08/21] refactor: remove memory from the agent class --- examples/agent.py | 12 +- pandasai/agent/__init__.py | 78 ++++-------- pandasai/agent/response.py | 38 ------ pandasai/helpers/memory.py | 9 +- .../prompts/clarification_questions_prompt.py | 4 + pandasai/prompts/explain_prompt.py | 10 +- pandasai/smart_datalake/__init__.py | 4 + tests/test_agent.py | 119 +++--------------- 8 files changed, 60 insertions(+), 214 deletions(-) delete mode 100644 pandasai/agent/response.py diff --git a/examples/agent.py b/examples/agent.py index 158537a19..38f419b7c 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -26,15 +26,11 @@ print(response) -# Get Clarification Questions -response = agent.clarification_questions() - -if response: - for question in response.questions: - print(question) -else: - print(response.message) +# # Get Clarification Questions +questions = agent.clarification_questions() +for question in questions: + print(question) # Explain how the chat response is generated response = agent.explain() diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index a76a16267..a3de23e70 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -1,14 +1,10 @@ import json from typing 
import Union, List, Optional -from pandasai.agent.response import ClarificationResponse from pandasai.helpers.df_info import DataFrameType from pandasai.helpers.logger import Logger -from pandasai.helpers.memory import Memory -from pandasai.prompts.base import Prompt from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt from pandasai.prompts.explain_prompt import ExplainPrompt from pandasai.schemas.df_config import Config - from pandasai.smart_datalake import SmartDatalake @@ -17,9 +13,9 @@ class Agent: Agent class to improve the conversational experience in PandasAI """ - _memory: Memory _lake: SmartDatalake = None - logger: Optional[Logger] = None + _logger: Optional[Logger] = None + _memory_size: int = None def __init__( self, @@ -38,34 +34,21 @@ def __init__( dfs = [dfs] self._lake = SmartDatalake(dfs, config, logger) - self.logger = self._lake.logger - # For the conversation multiple the memory size by 2 - self._memory = Memory(memory_size * 2) - - def _get_conversation(self): - """ - Get Conversation from history - - """ - return "\n".join( - [ - f"{'Question' if message['is_user'] else 'Answer'}: " - f"{message['message']}" - for i, message in enumerate(self._memory.all()) - ] - ) + self._logger = self._lake.logger + self._memory_size = memory_size def chat(self, query: str, output_type: Optional[str] = None): """ Simulate a chat interaction with the assistant on Dataframe. 
""" try: - self._memory.add(query, True) - conversation = self._get_conversation() result = self._lake.chat( - query, output_type=output_type, start_conversation=conversation + query, + output_type=output_type, + start_conversation=self._lake._memory.get_conversation( + self._memory_size + ), ) - self._memory.add(result, False) return result except Exception as exception: return ( @@ -74,56 +57,43 @@ def chat(self, query: str, output_type: Optional[str] = None): f"\n{exception}\n" ) - def _get_clarification_prompt(self) -> Prompt: - """ - Create a clarification prompt with relevant variables. - """ - prompt = ClarificationQuestionPrompt() - prompt.set_var("dfs", self._lake.dfs) - prompt.set_var("conversation", self._get_conversation()) - return prompt - - def clarification_questions(self) -> ClarificationResponse: + def clarification_questions(self) -> List[str]: """ Generate clarification questions based on the data """ try: - prompt = self._get_clarification_prompt() + prompt = ClarificationQuestionPrompt( + self._lake.dfs, self._lake._memory.get_conversation(self._memory_size) + ) + result = self._lake.llm.call(prompt) - self.logger.log( + self._logger.log( f"""Clarification Questions: {result} """ ) questions: list[str] = json.loads(result) - return ClarificationResponse( - success=True, questions=questions[:3], message=result - ) + return questions[:3] + except Exception as exception: - return ClarificationResponse( - False, - [], - "Unfortunately, I was not able to get your clarification questions, " - "because of the following error:\n" - f"\n{exception}\n", - ) + raise exception def start_new_conversation(self): """ Clears the previous conversation """ - - self._memory.clear() + self._lake._memory.clear() def explain(self) -> str: """ Returns the explanation of the code how it reached to the solution """ try: - prompt = ExplainPrompt() - prompt.set_var("code", self._lake.last_code_executed) - prompt.set_var("conversation", self._get_conversation()) + prompt 
= ExplainPrompt( + self._lake._memory.get_conversation(self._memory_size), + self._lake.last_code_executed, + ) response = self._lake.llm.call(prompt) - self.logger.log( + self._logger.log( f"""Explaination: {response} """ ) diff --git a/pandasai/agent/response.py b/pandasai/agent/response.py deleted file mode 100644 index 772941667..000000000 --- a/pandasai/agent/response.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List - - -class ClarificationResponse: - """ - Clarification Response - - """ - - def __init__( - self, success: bool = True, questions: List[str] = [], message: str = "" - ): - """ - Args: - success: Whether the response generated or not. - questions: List of questions - """ - self._success: bool = success - self._questions: List[str] = questions - self._message: str = message - - @property - def questions(self) -> List[str]: - return self._questions - - @property - def message(self) -> str: - return self._message - - @property - def success(self) -> bool: - return self._success - - def __bool__(self) -> bool: - """ - Define the success of response. 
- """ - return self._success diff --git a/pandasai/helpers/memory.py b/pandasai/helpers/memory.py index 568c2bbf1..5c7e01c8e 100644 --- a/pandasai/helpers/memory.py +++ b/pandasai/helpers/memory.py @@ -1,24 +1,17 @@ """ Memory class to store the conversations """ -import sys class Memory: """Memory class to store the conversations""" _messages: list - _max_messages: int - def __init__(self, max_messages: int = sys.maxsize): + def __init__(self): self._messages = [] - self._max_messages = max_messages def add(self, message: str, is_user: bool): self._messages.append({"message": message, "is_user": is_user}) - # Delete two entry because of the conversation - if len(self._messages) > self._max_messages: - del self._messages[:2] - def count(self) -> int: return len(self._messages) diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py index fc969629f..bdef252a5 100644 --- a/pandasai/prompts/clarification_questions_prompt.py +++ b/pandasai/prompts/clarification_questions_prompt.py @@ -45,3 +45,7 @@ class ClarificationQuestionPrompt(Prompt): Json: """ + + def __init__(self, dataframes, conversation): + self.set_var("dataframes", dataframes) + self.set_var("conversation", conversation) diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py index 80b0e10b9..25a4f7043 100644 --- a/pandasai/prompts/explain_prompt.py +++ b/pandasai/prompts/explain_prompt.py @@ -18,13 +18,17 @@ class ExplainPrompt(Prompt): {conversation} -Based on the last conversation you generated the code. 
+Based on the last conversation you generated the following code: {code} Date: Sat, 23 Sep 2023 00:33:57 +0500 Subject: [PATCH 09/21] refactor: import of Agent class in example --- examples/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/agent.py b/examples/agent.py index 38f419b7c..501c710f7 100644 --- a/examples/agent.py +++ b/examples/agent.py @@ -1,5 +1,5 @@ import pandas as pd -from pandasai.agent import Agent +from pandasai import Agent from pandasai.llm.openai import OpenAI From 49d872001b5f5d34bbce05c1303168d269a7a53e Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sat, 23 Sep 2023 01:29:58 +0500 Subject: [PATCH 10/21] refactor: memory to return conversation according to size --- pandasai/agent/__init__.py | 16 +++++----------- pandasai/helpers/memory.py | 11 +++++++++-- pandasai/smart_datalake/__init__.py | 9 +-------- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index a3de23e70..6706f4ed6 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -2,6 +2,7 @@ from typing import Union, List, Optional from pandasai.helpers.df_info import DataFrameType from pandasai.helpers.logger import Logger +from pandasai.helpers.memory import Memory from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt from pandasai.prompts.explain_prompt import ExplainPrompt from pandasai.schemas.df_config import Config @@ -33,22 +34,15 @@ def __init__( if not isinstance(dfs, list): dfs = [dfs] - self._lake = SmartDatalake(dfs, config, logger) + self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size)) self._logger = self._lake.logger - self._memory_size = memory_size def chat(self, query: str, output_type: Optional[str] = None): """ Simulate a chat interaction with the assistant on Dataframe. 
""" try: - result = self._lake.chat( - query, - output_type=output_type, - start_conversation=self._lake._memory.get_conversation( - self._memory_size - ), - ) + result = self._lake.chat(query, output_type=output_type) return result except Exception as exception: return ( @@ -63,7 +57,7 @@ def clarification_questions(self) -> List[str]: """ try: prompt = ClarificationQuestionPrompt( - self._lake.dfs, self._lake._memory.get_conversation(self._memory_size) + self._lake.dfs, self._lake._memory.get_conversation() ) result = self._lake.llm.call(prompt) @@ -89,7 +83,7 @@ def explain(self) -> str: """ try: prompt = ExplainPrompt( - self._lake._memory.get_conversation(self._memory_size), + self._lake._memory.get_conversation(), self._lake.last_code_executed, ) response = self._lake.llm.call(prompt) diff --git a/pandasai/helpers/memory.py b/pandasai/helpers/memory.py index 5c7e01c8e..072542d3e 100644 --- a/pandasai/helpers/memory.py +++ b/pandasai/helpers/memory.py @@ -5,9 +5,11 @@ class Memory: """Memory class to store the conversations""" _messages: list + _memory_size: int - def __init__(self): + def __init__(self, memory_size: int = 1): self._messages = [] + self._memory_size = memory_size def add(self, message: str, is_user: bool): self._messages.append({"message": message, "is_user": is_user}) @@ -21,7 +23,12 @@ def all(self) -> list: def last(self) -> dict: return self._messages[-1] - def get_conversation(self, limit: int = 1) -> str: + def get_conversation(self, limit: int = None) -> str: + """ + Returns the conversation messages based on limit parameter + or default memory size + """ + limit = self._memory_size if limit is None else limit return "\n".join( [ f"{f'User {i+1}' if message['is_user'] else f'Assistant {i}'}: " diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index ce74d7809..7b5040d0a 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -255,12 +255,7 @@ def 
_get_cache_key(self) -> str: return cache_key - def chat( - self, - query: str, - output_type: Optional[str] = None, - start_conversation: Optional[str] = None, - ): + def chat(self, query: str, output_type: Optional[str] = None): """ Run a query on the dataframe. @@ -310,8 +305,6 @@ def chat( "save_charts_path": self._config.save_charts_path.rstrip("/"), "output_type_hint": output_type_helper.template_hint, } - if start_conversation is not None: - default_values["conversation"] = start_conversation generate_python_code_instruction = self._get_prompt( "generate_python_code", From b92fb39243b890c4c2c17cdd5e8b90e909178d83 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sat, 23 Sep 2023 01:33:16 +0500 Subject: [PATCH 11/21] refactor: remove leftover property --- pandasai/agent/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 6706f4ed6..7c4cc5b4d 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -16,7 +16,6 @@ class Agent: _lake: SmartDatalake = None _logger: Optional[Logger] = None - _memory_size: int = None def __init__( self, From 7f17af85ccae24e37e405fa3b8642b7c4c9f36ad Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sat, 23 Sep 2023 01:39:27 +0500 Subject: [PATCH 12/21] fix: prompt comment --- pandasai/prompts/explain_prompt.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py index 25a4f7043..eaf5b909d 100644 --- a/pandasai/prompts/explain_prompt.py +++ b/pandasai/prompts/explain_prompt.py @@ -1,10 +1,20 @@ """ Prompt to explain solution generated -Based on the last conversation you generated the code. -Can you explain briefly for non technical person on how you came up with code -without explaining pandas library? 
-""" +The previous conversation we had + + +{conversation} + +Based on the last conversation you generated the following code: + +{code} + Date: Sat, 23 Sep 2023 12:14:30 +0500 Subject: [PATCH 13/21] fix: redundant try catch --- pandasai/agent/__init__.py | 24 ++++++++----------- .../prompts/clarification_questions_prompt.py | 4 +--- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 7c4cc5b4d..2681401e1 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -54,21 +54,17 @@ def clarification_questions(self) -> List[str]: """ Generate clarification questions based on the data """ - try: - prompt = ClarificationQuestionPrompt( - self._lake.dfs, self._lake._memory.get_conversation() - ) + prompt = ClarificationQuestionPrompt( + self._lake.dfs, self._lake._memory.get_conversation() + ) - result = self._lake.llm.call(prompt) - self._logger.log( - f"""Clarification Questions: {result} - """ - ) - questions: list[str] = json.loads(result) - return questions[:3] - - except Exception as exception: - raise exception + result = self._lake.llm.call(prompt) + self._logger.log( + f"""Clarification Questions: {result} + """ + ) + questions: list[str] = json.loads(result) + return questions[:3] def start_new_conversation(self): """ diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py index bdef252a5..94851b227 100644 --- a/pandasai/prompts/clarification_questions_prompt.py +++ b/pandasai/prompts/clarification_questions_prompt.py @@ -26,9 +26,7 @@ class ClarificationQuestionPrompt(Prompt): text: str = """ You are provided with the following pandas DataFrames: - {dataframes} - {conversation} @@ -47,5 +45,5 @@ class ClarificationQuestionPrompt(Prompt): """ def __init__(self, dataframes, conversation): - self.set_var("dataframes", dataframes) + self.set_var("dfs", dataframes) self.set_var("conversation", conversation) From 
7a554a5df7b545c7556ade51bfa3289ff152a101 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sat, 23 Sep 2023 14:23:57 +0500 Subject: [PATCH 14/21] chore: improve docstring and add example in documentation --- docs/examples.md | 53 ++++++++++++++++++++++++++++++++++++++ pandasai/agent/__init__.py | 6 +++-- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/docs/examples.md b/docs/examples.md index 8ab0be765..acb98b50c 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -206,3 +206,56 @@ print(paid_from_males_df) # [247 rows x 11 columns] ``` + +## Working with Agent + +With the chat agent, you can engage in dynamic conversations where the agent retains context throughout the discussion. This enables you to have more interactive and meaningful exchanges. + +**Key Features** + +- **Context Retention:** The agent remembers the conversation history, allowing for seamless, context-aware interactions. + +- **Clarification Questions:** You can use the `clarification_questions` method to request clarification on any aspect of the conversation. This helps ensure you fully understand the information provided. + +- **Explanation:** The `explain` method is available to obtain detailed explanations of how the agent arrived at a particular solution or response. It offers transparency and insights into the agent's decision-making process. + +Feel free to initiate conversations, seek clarifications, and explore explanations to enhance your interactions with the chat agent! 
+ +``` +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +llm = OpenAI("OpenAI_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +# Chat with the agent +response = agent.chat("Who gets paid the most?") +print(response) + +# Get Clarification Questions +questions = agent.clarification_questions() + +for question in questions: + print(question) + +# Explain how the chat response is generated +response = agent.explain() +print(response) +``` diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 2681401e1..0236480f3 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -26,8 +26,10 @@ def __init__( ): """ Args: - df (Union[SmartDataframe, SmartDatalake]): _description_ - memory_size (int, optional): _description_. Defaults to 1. + df (Union[DataFrameType, List[DataFrameType]]): DataFrame can be Pandas, + Polars or Database connectors + memory_size (int, optional): Conversation history to use during chat. + Defaults to 1. 
""" if not isinstance(dfs, list): From f7e4d9837228cdb75ef02a565253b24c385115fe Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Sun, 24 Sep 2023 19:56:00 +0500 Subject: [PATCH 15/21] fix: Comment in clarification prompts and add dtyps to the constructors --- pandasai/prompts/clarification_questions_prompt.py | 8 ++++---- pandasai/prompts/explain_prompt.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py index 94851b227..f236a26de 100644 --- a/pandasai/prompts/clarification_questions_prompt.py +++ b/pandasai/prompts/clarification_questions_prompt.py @@ -1,9 +1,7 @@ """ Prompt to get clarification questions You are provided with the following pandas DataFrames: - -{dataframe} - +{dataframes} {conversation} @@ -17,6 +15,8 @@ """ # noqa: E501 +from typing import List +import pandas as pd from .base import Prompt @@ -44,6 +44,6 @@ class ClarificationQuestionPrompt(Prompt): Json: """ - def __init__(self, dataframes, conversation): + def __init__(self, dataframes: List[pd.DataFrame], conversation: str): self.set_var("dfs", dataframes) self.set_var("conversation", conversation) diff --git a/pandasai/prompts/explain_prompt.py b/pandasai/prompts/explain_prompt.py index eaf5b909d..727202f8b 100644 --- a/pandasai/prompts/explain_prompt.py +++ b/pandasai/prompts/explain_prompt.py @@ -1,4 +1,4 @@ -""" Prompt to explain solution generated +""" Prompt to explain code generation by the LLM The previous conversation we had @@ -19,7 +19,7 @@ class ExplainPrompt(Prompt): - """Prompt to get clarification questions""" + """Prompt to explain code generation by the LLM""" text: str = """ The previous conversation we had @@ -39,6 +39,6 @@ class ExplainPrompt(Prompt): """ - def __init__(self, conversation, code): + def __init__(self, conversation: str, code: str): self.set_var("conversation", conversation) self.set_var("code", code) From 
21f5bd8e9afb62666918313b0307916a83c1fd03 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 25 Sep 2023 12:35:06 +0500 Subject: [PATCH 16/21] feat(RephraseQuery): rephrase user query to get more accurate responses --- pandasai/agent/__init__.py | 19 ++++++++++ pandasai/prompts/rephase_query_prompt.py | 44 ++++++++++++++++++++++++ tests/test_agent.py | 32 ++++++++++++++--- 3 files changed, 91 insertions(+), 4 deletions(-) create mode 100644 pandasai/prompts/rephase_query_prompt.py diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 0236480f3..7266c3d7f 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -5,6 +5,7 @@ from pandasai.helpers.memory import Memory from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt from pandasai.prompts.explain_prompt import ExplainPrompt +from pandasai.prompts.rephase_query_prompt import RephraseQueryPrompt from pandasai.schemas.df_config import Config from pandasai.smart_datalake import SmartDatalake @@ -95,3 +96,21 @@ def explain(self) -> str: "because of the following error:\n" f"\n{exception}\n" ) + + def rephrase_query(self, query: str): + try: + prompt = RephraseQueryPrompt( + query, self._lake.dfs, self._lake._memory.get_conversation() + ) + response = self._lake.llm.call(prompt) + self._logger.log( + f"""Rephrased Response: {response} + """ + ) + return response + except Exception as exception: + return ( + "Unfortunately, I was not able to repharse query, " + "because of the following error:\n" + f"\n{exception}\n" + ) diff --git a/pandasai/prompts/rephase_query_prompt.py b/pandasai/prompts/rephase_query_prompt.py new file mode 100644 index 000000000..7c803f150 --- /dev/null +++ b/pandasai/prompts/rephase_query_prompt.py @@ -0,0 +1,44 @@ +""" Prompt to rephrase query to get more accurate responses +You are provided with the following pandas DataFrames: + +{dataframes} + +and based on our conversation: + + +{conversation} + + +Return the 
rephrased sentence of "{query}” in order to obtain more accurate and +comprehensive responses without any explanations. + +""" +from typing import List + +import pandas as pd +from .base import Prompt + + +class RephraseQueryPrompt(Prompt): + """Prompt to rephrase query to get more accurate responses""" + + text: str = """ +You are provided with the following pandas DataFrames: + +{dataframes} + +And based on our conversation: + + +{conversation} + + +Return the rephrased sentence of "{query}” in order to obtain more accurate and +comprehensive responses without any explanations. + +""" + + def __init__(self, query: str, dataframes: List[pd.DataFrame], conversation: str): + self.set_var("query", query) + self.set_var("conversation", conversation) + self.set_var("dfs", dataframes) diff --git a/tests/test_agent.py b/tests/test_agent.py index d70e91af8..917ec228e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -63,11 +63,19 @@ def config(self, llm: FakeLLM): return {"llm": llm} def test_constructor(self, sample_df, config): - agent = Agent(sample_df, config) - assert isinstance(agent._lake, SmartDatalake) + agent_1 = Agent(sample_df, config) + assert isinstance(agent_1._lake, SmartDatalake) + + agent_2 = Agent([sample_df], config) + assert isinstance(agent_2._lake, SmartDatalake) - agent = Agent([sample_df], config) - assert isinstance(agent._lake, SmartDatalake) + # test multiple agents instances + agent_1._lake._memory.add("Which country has the highest gdp?", True) + memory = agent_1._lake._memory.all() + assert len(memory) == 1 + + memory = agent_2._lake._memory.all() + assert len(memory) == 0 def test_chat(self, sample_df, config): # Create an Agent instance for testing @@ -163,3 +171,19 @@ def test_explain(self, sample_df, config): It's like finding the person who has the most marbles in a game """ ) + + def test_rephrase(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.llm.call = Mock() + 
clarification_response = """ +How much has the total salary expense increased? + """ + agent._lake.llm.call.return_value = clarification_response + + response = agent.rephrase_query("how much has the revenue increased?") + + assert response == ( + """ +How much has the total salary expense increased? + """ + ) From adfc86a6fd5a606b25b8d6400f3012c11429f599 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 25 Sep 2023 14:57:44 +0500 Subject: [PATCH 17/21] chore(agent): add max retries on queries --- pandasai/agent/__init__.py | 23 ++++++++++++-- tests/test_agent.py | 65 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 4 deletions(-) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 0236480f3..dbb9db26c 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -3,6 +3,7 @@ from pandasai.helpers.df_info import DataFrameType from pandasai.helpers.logger import Logger from pandasai.helpers.memory import Memory +from pandasai.prompts.base import Prompt from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt from pandasai.prompts.explain_prompt import ExplainPrompt from pandasai.schemas.df_config import Config @@ -38,6 +39,24 @@ def __init__( self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size)) self._logger = self._lake.logger + def _call_llm_with_prompt(self, prompt: Prompt): + """ + Call LLM with prompt using error handling to retry based on config + Args: + prompt (Prompt): Prompt to pass to LLM's + """ + retry_count = 0 + while retry_count < self._lake.config.max_retries: + try: + return self._lake.llm.call(prompt) + except Exception: + if ( + not self._lake.use_error_correction_framework + or retry_count >= self._lake.config.max_retries - 1 + ): + raise + retry_count += 1 + def chat(self, query: str, output_type: Optional[str] = None): """ Simulate a chat interaction with the assistant on Dataframe. 
@@ -60,7 +79,7 @@ def clarification_questions(self) -> List[str]: self._lake.dfs, self._lake._memory.get_conversation() ) - result = self._lake.llm.call(prompt) + result = self._call_llm_with_prompt(prompt) self._logger.log( f"""Clarification Questions: {result} """ @@ -83,7 +102,7 @@ def explain(self) -> str: self._lake._memory.get_conversation(), self._lake.last_code_executed, ) - response = self._lake.llm.call(prompt) + response = self._call_llm_with_prompt(prompt) self._logger.log( f"""Explaination: {response} """ diff --git a/tests/test_agent.py b/tests/test_agent.py index d70e91af8..8fefbd251 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -55,13 +55,17 @@ def sample_df(self): ) @pytest.fixture - def llm(self, output: Optional[str] = None): + def llm(self, output: Optional[str] = None) -> FakeLLM: return FakeLLM(output=output) @pytest.fixture - def config(self, llm: FakeLLM): + def config(self, llm: FakeLLM) -> dict: return {"llm": llm} + @pytest.fixture + def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent: + return Agent(sample_df, config) + def test_constructor(self, sample_df, config): agent = Agent(sample_df, config) assert isinstance(agent._lake, SmartDatalake) @@ -163,3 +167,60 @@ def test_explain(self, sample_df, config): It's like finding the person who has the most marbles in a game """ ) + + def test_call_prompt_success(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.llm.call = Mock() + clarification_response = """ +What is expected Salary Increase? 
+ """ + agent._lake.llm.call.return_value = clarification_response + agent._call_llm_with_prompt("Test Prompt") + assert agent._lake.llm.call.call_count == 1 + + def test_call_prompt_max_retries_exceeds(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.llm.call = Mock() + agent._lake.llm.call.side_effect = Exception("Raise an exception") + with pytest.raises(Exception): + agent._call_llm_with_prompt("Test Prompt") + + assert agent._lake.llm.call.call_count == 3 + + def test_call_prompt_max_retry_on_error(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.llm.call = Mock() + agent._lake.llm.call.side_effect = [Exception(), Exception(), "LLM Result"] + result = agent._call_llm_with_prompt("Test Prompt") + assert result == "LLM Result" + assert agent._lake.llm.call.call_count == 3 + + def test_call_prompt_max_retry_twice(self, sample_df, config): + agent = Agent(sample_df, config, memory_size=10) + agent._lake.llm.call = Mock() + agent._lake.llm.call.side_effect = [Exception(), "LLM Result"] + result = agent._call_llm_with_prompt("Test Prompt") + assert result == "LLM Result" + assert agent._lake.llm.call.call_count == 2 + + def test_call_llm_with_prompt_no_retry_on_error(self, agent: Agent): + # Test when LLM call raises an exception but retries are disabled + + agent._lake.config.use_error_correction_framework = False + agent._lake.llm.call = Mock() + agent._lake.llm.call.side_effect = Exception() + with pytest.raises(Exception): + agent._call_llm_with_prompt("Test Prompt") + + assert agent._lake.llm.call.call_count == 1 + + def test_call_llm_with_prompt_max_retries_check(self, agent: Agent): + # Test when LLM call raises an exception but retries are disabled + + agent._lake.config.max_retries = 5 + agent._lake.llm.call = Mock() + agent._lake.llm.call.side_effect = Exception() + with pytest.raises(Exception): + agent._call_llm_with_prompt("Test Prompt") + + assert 
agent._lake.llm.call.call_count == 5 From bf9667ba3d0d2888c60c5efd902d8ad34afc2643 Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Mon, 25 Sep 2023 12:28:26 +0200 Subject: [PATCH 18/21] feat: improve the prompt to also add information about ambiguous parts --- pandasai/prompts/rephase_query_prompt.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pandasai/prompts/rephase_query_prompt.py b/pandasai/prompts/rephase_query_prompt.py index 7c803f150..7015f4c37 100644 --- a/pandasai/prompts/rephase_query_prompt.py +++ b/pandasai/prompts/rephase_query_prompt.py @@ -34,8 +34,9 @@ class RephraseQueryPrompt(Prompt): Return the rephrased sentence of "{query}” in order to obtain more accurate and -comprehensive responses without any explanations. - +comprehensive responses without any explanations. If something from the original +query is ambiguous, please clarify it in the rephrased query, making assumptions, +if necessary. """ def __init__(self, query: str, dataframes: List[pd.DataFrame], conversation: str): From cccee44b1868b847352d7b82f1124cbdb639b98d Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 25 Sep 2023 16:30:31 +0500 Subject: [PATCH 19/21] feat[retry_wrapper]: add basic wrapper for error handling and add prompt validators --- pandasai/agent/__init__.py | 6 ++- pandasai/prompts/base.py | 3 ++ .../prompts/clarification_questions_prompt.py | 8 +++ tests/test_agent.py | 53 ++++++++++++++----- 4 files changed, 55 insertions(+), 15 deletions(-) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index dbb9db26c..b8a9eb0df 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -48,7 +48,11 @@ def _call_llm_with_prompt(self, prompt: Prompt): retry_count = 0 while retry_count < self._lake.config.max_retries: try: - return self._lake.llm.call(prompt) + result: str = self._lake.llm.call(prompt) + if prompt.validate(result): + return result + else: + raise Exception("Response validation failed!") except 
Exception: if ( not self._lake.use_error_correction_framework diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py index f53ac9082..a31f80252 100644 --- a/pandasai/prompts/base.py +++ b/pandasai/prompts/base.py @@ -59,3 +59,6 @@ def to_string(self): def __str__(self): return self.to_string() + + def validate(self, output: str) -> bool: + return isinstance(output, str) diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py index f236a26de..43a3671af 100644 --- a/pandasai/prompts/clarification_questions_prompt.py +++ b/pandasai/prompts/clarification_questions_prompt.py @@ -15,6 +15,7 @@ """ # noqa: E501 +import json from typing import List import pandas as pd from .base import Prompt @@ -47,3 +48,10 @@ class ClarificationQuestionPrompt(Prompt): def __init__(self, dataframes: List[pd.DataFrame], conversation: str): self.set_var("dfs", dataframes) self.set_var("conversation", conversation) + + def validate(self, output) -> bool: + try: + json_data = json.loads(output) + return isinstance(json_data, List) + except Exception: + raise diff --git a/tests/test_agent.py b/tests/test_agent.py index 8fefbd251..9ee0f4595 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -4,6 +4,8 @@ import pandas as pd import pytest from pandasai.llm.fake import FakeLLM +from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt +from pandasai.prompts.explain_prompt import ExplainPrompt from pandasai.smart_datalake import SmartDatalake @@ -140,8 +142,7 @@ def test_clarification_questions_max_3(self, sample_df, config): assert isinstance(questions, list) assert len(questions) == 3 - def test_explain(self, sample_df, config): - agent = Agent(sample_df, config, memory_size=10) + def test_explain(self, agent: Agent): agent._lake.llm.call = Mock() clarification_response = """ Combine the Data: To find out who gets paid the most, @@ -168,18 +169,18 @@ def test_explain(self, sample_df, 
config): """ ) - def test_call_prompt_success(self, sample_df, config): - agent = Agent(sample_df, config, memory_size=10) + def test_call_prompt_success(self, agent: Agent): agent._lake.llm.call = Mock() clarification_response = """ What is expected Salary Increase? """ agent._lake.llm.call.return_value = clarification_response - agent._call_llm_with_prompt("Test Prompt") + prompt = ExplainPrompt("test conversation", "") + agent._call_llm_with_prompt(prompt) assert agent._lake.llm.call.call_count == 1 - def test_call_prompt_max_retries_exceeds(self, sample_df, config): - agent = Agent(sample_df, config, memory_size=10) + def test_call_prompt_max_retries_exceeds(self, agent: Agent): + # raises exception every time agent._lake.llm.call = Mock() agent._lake.llm.call.side_effect = Exception("Raise an exception") with pytest.raises(Exception): @@ -187,19 +188,22 @@ def test_call_prompt_max_retries_exceeds(self, sample_df, config): assert agent._lake.llm.call.call_count == 3 - def test_call_prompt_max_retry_on_error(self, sample_df, config): - agent = Agent(sample_df, config, memory_size=10) + def test_call_prompt_max_retry_on_error(self, agent: Agent): + # test the LLM call failed twice but succeed third time agent._lake.llm.call = Mock() agent._lake.llm.call.side_effect = [Exception(), Exception(), "LLM Result"] - result = agent._call_llm_with_prompt("Test Prompt") + prompt = ExplainPrompt("test conversation", "") + result = agent._call_llm_with_prompt(prompt) assert result == "LLM Result" assert agent._lake.llm.call.call_count == 3 - def test_call_prompt_max_retry_twice(self, sample_df, config): - agent = Agent(sample_df, config, memory_size=10) + def test_call_prompt_max_retry_twice(self, agent: Agent): + # test the LLM call failed once but succeed second time agent._lake.llm.call = Mock() agent._lake.llm.call.side_effect = [Exception(), "LLM Result"] - result = agent._call_llm_with_prompt("Test Prompt") + prompt = ExplainPrompt("test conversation", "") + result = 
agent._call_llm_with_prompt(prompt) + assert result == "LLM Result" assert agent._lake.llm.call.call_count == 2 @@ -215,12 +219,33 @@ def test_call_llm_with_prompt_no_retry_on_error(self, agent: Agent): assert agent._lake.llm.call.call_count == 1 def test_call_llm_with_prompt_max_retries_check(self, agent: Agent): - # Test when LLM call raises an exception but retries are disabled + # Test that when every LLM call raises an exception, the call function + # is invoked exactly 'max_retries' times agent._lake.config.max_retries = 5 agent._lake.llm.call = Mock() agent._lake.llm.call.side_effect = Exception() + with pytest.raises(Exception): agent._call_llm_with_prompt("Test Prompt") assert agent._lake.llm.call.call_count == 5 + + def test_clarification_prompt_validate_output_false_case(self, agent: Agent): + # Test whether the output is json or not + agent._lake.llm.call = Mock() + agent._lake.llm.call.return_value = "This is not json" + + prompt = ClarificationQuestionPrompt(agent._lake.dfs, "test conversation") + with pytest.raises(Exception): + agent._call_llm_with_prompt(prompt) + + def test_clarification_prompt_validate_output_true_case(self, agent: Agent): + # Test whether the output is json or not + agent._lake.llm.call = Mock() + agent._lake.llm.call.return_value = '["This is test quesiton"]' + + prompt = ClarificationQuestionPrompt(agent._lake.dfs, "test conversation") + result = agent._call_llm_with_prompt(prompt) + # Didn't raise any exception + assert isinstance(result, str) From 6fa9c1d36b3b1c13ee8d1a970bf6c2313e519bbb Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 25 Sep 2023 16:48:31 +0500 Subject: [PATCH 20/21] refactor(validation): return False from the validator in case of failure --- pandasai/prompts/clarification_questions_prompt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandasai/prompts/clarification_questions_prompt.py b/pandasai/prompts/clarification_questions_prompt.py index 43a3671af..5c70dd648 100644 --- 
a/pandasai/prompts/clarification_questions_prompt.py +++ b/pandasai/prompts/clarification_questions_prompt.py @@ -53,5 +53,5 @@ def validate(self, output) -> bool: try: json_data = json.loads(output) return isinstance(json_data, List) - except Exception: - raise + except json.JSONDecodeError: + return False From e9b2342d3fc8e9b67df6b467b4d85c27867842fa Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Mon, 25 Sep 2023 18:47:28 +0500 Subject: [PATCH 21/21] fix(RephraseQuery): remove conversation from the prompt if empty --- docs/getting-started.md | 11 +++++++++ pandasai/prompts/rephase_query_prompt.py | 31 +++++++++++++----------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/docs/getting-started.md b/docs/getting-started.md index cd28c6f22..609eb70cc 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -162,6 +162,17 @@ print("The answer is", response) print("The explanation is", explanation) ``` +### Rephrase Question + +Rephrase a question to get a more accurate and comprehensive response from the model. For example: + +```python +rephrased_query = agent.rephrase_query('What is the GDP of the United States?') + +print("The answer is", rephrased_query) + +``` + ## Config When you instantiate a `SmartDataframe`, you can pass a `config` object as the second argument. This object can contain custom settings that will be used by `pandasai` when generating code. diff --git a/pandasai/prompts/rephase_query_prompt.py b/pandasai/prompts/rephase_query_prompt.py index 7015f4c37..23c86373b 100644 --- a/pandasai/prompts/rephase_query_prompt.py +++ b/pandasai/prompts/rephase_query_prompt.py @@ -2,15 +2,11 @@ You are provided with the following pandas DataFrames: {dataframes} - -and based on our conversation: - - {conversation} - - -Return the rephrased sentence of "{query}” in order to obtain more accurate and -comprehensive responses without any explanations. 
+Return the rephrased sentence of "{query}” in order to obtain more accurate and +comprehensive responses without any explanations. If something from the original +query is ambiguous, please clarify it in the rephrased query, making assumptions, +if necessary. """ from typing import List @@ -26,20 +22,27 @@ class RephraseQueryPrompt(Prompt): You are provided with the following pandas DataFrames: {dataframes} +{conversation} +Return the rephrased sentence of "{query}” in order to obtain more accurate and +comprehensive responses without any explanations. If something from the original +query is ambiguous, please clarify it in the rephrased query, making assumptions, +if necessary. +""" + conversation_text: str = """ And based on our conversation: {conversation} - -Return the rephrased sentence of "{query}” in order to obtain more accurate and -comprehensive responses without any explanations. If something from the original -query is ambiguous, please clarify it in the rephrased query, making assumptions, -if necessary. """ def __init__(self, query: str, dataframes: List[pd.DataFrame], conversation: str): + conversation_content = ( + self.conversation_text.format(conversation=conversation) + if conversation + else "" + ) + self.set_var("conversation", conversation_content) self.set_var("query", query) - self.set_var("conversation", conversation) self.set_var("dfs", dataframes)