feat[Agent]: add agent conversation code #584
Changes from 5 commits
@@ -0,0 +1,41 @@
import pandas as pd
from pandasai.agent import Agent

from pandasai.llm.openai import OpenAI


employees_data = {
    "EmployeeID": [1, 2, 3, 4, 5],
    "Name": ["John", "Emma", "Liam", "Olivia", "William"],
    "Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
}

salaries_data = {
    "EmployeeID": [1, 2, 3, 4, 5],
    "Salary": [5000, 6000, 4500, 7000, 5500],
}

employees_df = pd.DataFrame(employees_data)
salaries_df = pd.DataFrame(salaries_data)


llm = OpenAI("OPEN_API_KEY")
agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)

# Chat with the agent
response = agent.chat("Who gets paid the most?")
print(response)


# Get Clarification Questions
response = agent.clarification_questions()

if response:
    for question in response.questions:
        print(question)
else:
    print(response.message)


# Explain how the chat response is generated
response = agent.explain()
print(response)
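
The diff also adds a start_new_conversation() method that this example does not exercise; a minimal usage sketch, reusing the agent object from above:

# Clear the stored conversation before switching topics
agent.start_new_conversation()
response = agent.chat("What is the average salary per department?")
print(response)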
@@ -0,0 +1,136 @@
import json
from typing import Union, List, Optional
from pandasai.agent.response import ClarificationResponse
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.logger import Logger
from pandasai.helpers.memory import Memory
from pandasai.prompts.base import Prompt
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt
from pandasai.schemas.df_config import Config

from pandasai.smart_datalake import SmartDatalake


class Agent:
    """
    Agent class to improve the conversational experience in PandasAI
    """

    _memory: Memory

Review comment: Here we should rely on the memory from the `SmartDatalake` instead of keeping a separate one.

    _lake: SmartDatalake = None
    logger: Optional[Logger] = None
Review comment: The class variables should become instance attributes set in `__init__`:

-    _memory: Memory
-    _lake: SmartDatalake = None
-    logger: Optional[Logger] = None
+    def __init__(
+        self,
+        dfs: Union[DataFrameType, List[DataFrameType]],
+        config: Optional[Union[Config, dict]] = None,
+        logger: Optional[Logger] = None,
+        memory_size: int = 1,
+    ):
+        self._memory = None
+        self._lake = None
+        self.logger = logger

Review comment: Let's make logger private.

    def __init__(
        self,
        dfs: Union[DataFrameType, List[DataFrameType]],
        config: Optional[Union[Config, dict]] = None,
        logger: Optional[Logger] = None,
        memory_size: int = 1,
    ):
        """
        Args:
            df (Union[SmartDataframe, SmartDatalake]): _description_
            memory_size (int, optional): _description_. Defaults to 1.
        """

        if not isinstance(dfs, list):
            dfs = [dfs]

Review comment on lines +35 to +37: The code is assuming that if `dfs` is not a list, it is a single dataframe. It would be safer to check the type explicitly:

-        if not isinstance(dfs, list):
-            dfs = [dfs]
+        if isinstance(dfs, DataFrameType):
+            dfs = [dfs]
+        elif not isinstance(dfs, list):
+            raise TypeError("dfs must be a DataFrameType or a list of DataFrameType")

        self._lake = SmartDatalake(dfs, config, logger)
        self.logger = self._lake.logger
        # For the conversation, multiply the memory size by 2
        self._memory = Memory(memory_size * 2)

    def _get_conversation(self):

Review comment: We already have such a method in the `Memory` class.

Review comment: It's important to try to reuse code as much as possible to keep it maintainable over time!

        """
        Get Conversation from history
        """
        return "\n".join(
            [
                f"{'Question' if message['is_user'] else 'Answer'}: "
                f"{message['message']}"
                for i, message in enumerate(self._memory.all())
            ]
        )
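
A minimal sketch of the reuse suggested in the comments above, assuming Memory already exposes a get_conversation(limit=...) helper (an assumption based on the later review discussion, not something shown in this diff):

    def _get_conversation(self):
        """Get the conversation from history."""
        # Assumption: Memory.get_conversation() already formats the
        # question/answer pairs, so the Agent can simply delegate to it.
        return self._memory.get_conversation(limit=self._memory.count())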

    def chat(self, query: str, output_type: Optional[str] = None):
        """
        Simulate a chat interaction with the assistant on the dataframes.
        """
        try:
            self._memory.add(query, True)
            conversation = self._get_conversation()
            result = self._lake.chat(
                query, output_type=output_type, start_conversation=conversation

Review comment: If we use the memory from the `SmartDatalake`, we don't need to pass the conversation in here.

            )
            self._memory.add(result, False)

Review comment: And here we don't need to run it. The SmartDataframe already has its own memory, no need to do duplication.

            return result
        except Exception as exception:
            return (
                "Unfortunately, I was not able to get your answers, "
                "because of the following error:\n"
                f"\n{exception}\n"
            )
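
A rough sketch of chat() if the Agent relies on the SmartDatalake's own memory, as the comments above suggest (assuming SmartDatalake.chat records the question and answer itself):

    def chat(self, query: str, output_type: Optional[str] = None):
        """Simulate a chat interaction with the assistant on the dataframes."""
        # Assumption: the lake keeps its own conversation history, so no
        # duplicate bookkeeping (and no start_conversation argument) is needed.
        return self._lake.chat(query, output_type=output_type)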

    def _get_clarification_prompt(self) -> Prompt:
        """
        Create a clarification prompt with relevant variables.
        """
        prompt = ClarificationQuestionPrompt()

Review comment: Let's pass these as arguments instead (both dfs and conversation). Since we'll pass these as arguments, we won't need the `set_var` calls.

        prompt.set_var("dfs", self._lake.dfs)
        prompt.set_var("conversation", self._get_conversation())
        return prompt
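
A sketch of the constructor-argument approach suggested above; the ClarificationQuestionPrompt(dfs, conversation) signature is an assumption, not part of this diff:

        prompt = ClarificationQuestionPrompt(
            dfs=self._lake.dfs, conversation=self._get_conversation()
        )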

    def clarification_questions(self) -> ClarificationResponse:
        """
        Generate clarification questions based on the data
        """
        try:
            prompt = self._get_clarification_prompt()
            result = self._lake.llm.call(prompt)
            self.logger.log(
                f"""Clarification Questions: {result}
                """
            )
            questions: list[str] = json.loads(result)
            return ClarificationResponse(

Review comment: We are not doing a lot of OOP at the moment. Let's simplify this.

                success=True, questions=questions[:3], message=result
            )
        except Exception as exception:
            return ClarificationResponse(
                False,
                [],
                "Unfortunately, I was not able to get your clarification questions, "
                "because of the following error:\n"
                f"\n{exception}\n",
            )
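
A possible simplification along the lines of the comment above, returning a plain list of questions instead of a ClarificationResponse object (a sketch only; error handling omitted):

    def clarification_questions(self) -> List[str]:
        """Generate up to three clarification questions based on the data."""
        prompt = self._get_clarification_prompt()
        result = self._lake.llm.call(prompt)
        self.logger.log(f"Clarification Questions: {result}")
        questions = json.loads(result)
        return questions[:3]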

    def start_new_conversation(self) -> True:
        """
        Clears the previous conversation
        """

        self._memory.clear()
        return True

    def explain(self) -> str:
        """
        Returns an explanation of how the generated code reached the solution
        """
        try:
            prompt = ExplainPrompt()
            prompt.set_var("code", self._lake.last_code_executed)
            response = self._lake.llm.call(prompt)
            self.logger.log(
                f"""Explanation: {response}
                """
            )
            return response
        except Exception as exception:
            return (
                "Unfortunately, I was not able to explain, "
                "because of the following error:\n"
                f"\n{exception}\n"
            )
Review comment: Similar to the previous points, consider raising a custom exception instead of returning an error message string.

-        except Exception as exception:
-            return (
-                "Unfortunately, I was not able to explain, "
-                "because of the following error:\n"
-                f"\n{exception}\n"
-            )
+        except Exception as exception:
+            raise ExplanationException("Unable to generate explanation") from exception

Review comment on lines +92 to +97: Error handling in these methods only returns a string with the error message. It might be better to log the exception and re-raise it after logging. This way, the caller of the method can decide how to handle the exception. Also, this will help in debugging as the stack trace will be available in the logs.

-            return (
-                "Unfortunately, I was not able to get your answers, "
-                "because of the following error:\n"
-                f"\n{exception}\n"
-            )
+            self._logger.log(
+                "Unfortunately, I was not able to get your answers, "
+                "because of the following error:\n"
+                f"\n{exception}\n"
+            )
+            raise
...
-            return (
-                "Unfortunately, I was not able to explain, "
-                "because of the following error:\n"
-                f"\n{exception}\n"
-            )
+            self._logger.log(
+                "Unfortunately, I was not able to explain, "
+                "because of the following error:\n"
+                f"\n{exception}\n"
+            )
+            raise
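
The suggested ExplanationException is not defined anywhere in this diff; a minimal definition, purely to illustrate the suggestion, could be:

class ExplanationException(Exception):
    """Raised when the agent fails to generate an explanation for the last executed code."""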
@@ -0,0 +1,38 @@
from typing import List


class ClarificationResponse:
    """
    Clarification Response
    """

    def __init__(
        self, success: bool = True, questions: List[str] = None, message: str = ""
    ):

Review comment: The default value for `questions` should be an empty list rather than `None`:

-    def __init__(self, success: bool = True, questions: List[str] = None, message: str = "")
+    def __init__(self, success: bool = True, questions: List[str] = [], message: str = "")

        """
        Args:
            success: Whether the response was generated or not.
            questions: List of questions
        """
        self._success: bool = success
        self._questions: List[str] = questions
        self._message: str = message

    @property
    def questions(self) -> List[str]:
        return self._questions

    @property
    def message(self) -> List[str]:
        return self._message

Review comment: The return type hint for the `message` property is incorrect; it should be `str`:

-    def message(self) -> List[str]:
+    def message(self) -> str:

    @property
    def success(self) -> bool:
        return self._success

    def __bool__(self) -> bool:
        """
        Define the success of response.
        """
        return self._success

Review comment: Great attempt, but I suggest we remove this class for now, as we are not doing much OOP at the moment. Let's try to keep things simple as much as possible :)
@@ -1,17 +1,24 @@
""" Memory class to store the conversations """ | ||
import sys | ||
|
||
|
||
class Memory: | ||
"""Memory class to store the conversations""" | ||
|
||
_messages: list | ||
_max_messages: int | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
||
def __init__(self): | ||
def __init__(self, max_messages: int = sys.maxsize): | ||
self._messages = [] | ||
self._max_messages = max_messages | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for this (see below) |
||
|
||
def add(self, message: str, is_user: bool): | ||
self._messages.append({"message": message, "is_user": is_user}) | ||
|
||
# Delete two entry because of the conversation | ||
if len(self._messages) > self._max_messages: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for this, the |
||
del self._messages[:2] | ||
gventuri marked this conversation as resolved.
Show resolved
Hide resolved
gventuri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def count(self) -> int: | ||
return len(self._messages) | ||
|
||
|
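
A quick illustration of the new max_messages behaviour added above (hypothetical values): once the limit is exceeded, the oldest question/answer pair is dropped.

memory = Memory(max_messages=4)
for i in range(3):
    memory.add(f"question {i}", True)
    memory.add(f"answer {i}", False)

# the oldest question/answer pair was dropped to stay within the limit
assert memory.count() == 4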
@@ -0,0 +1,47 @@
""" Prompt to get clarification questions | ||
You are provided with the following pandas DataFrames: | ||
|
||
<dataframe> | ||
{dataframe} | ||
</dataframe> | ||

Review comment: Let's also change this to `{dataframes}` for consistency!

<conversation>
{conversation}
</conversation>

Based on the conversation, are there any clarification questions that a senior data scientist would ask? These are questions for non technical people, only ask for questions they could ask given low tech expertise and no knowledge about how the dataframes are structured.

Return the JSON array of the clarification questions. If there is no clarification question, return an empty array.

Json:
"""  # noqa: E501


from .base import Prompt


class ClarificationQuestionPrompt(Prompt):
    """Prompt to get clarification questions"""

    text: str = """
You are provided with the following pandas DataFrames:

<dataframe>
{dataframes}
</dataframe>

<conversation>
{conversation}
</conversation>

Based on the conversation, are there any clarification questions
that a senior data scientist would ask? These are questions for non technical people,
only ask for questions they could ask given low tech expertise and
no knowledge about how the dataframes are structured.

Return the JSON array of the clarification questions.

If there is no clarification question, return an empty array.

Json:
"""
@@ -0,0 +1,23 @@
""" Prompt to explain solution generated | ||
Based on the last conversation you generated the code. | ||
Can you explain briefly for non technical person on how you came up with code | ||
without explaining pandas library? | ||
""" | ||
|
||
|
||
from .base import Prompt | ||
|
||
|
||
class ExplainPrompt(Prompt): | ||
"""Prompt to get clarification questions""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The class docstring is misleading. This class is for generating explanation prompts, not clarification questions. Please correct it to avoid confusion. - """Prompt to get clarification questions"""
+ """Prompt to generate explanation for the code""" |
||
|
||
text: str = """ | ||
Based on the last conversation you generated the code. | ||
|
||
gventuri marked this conversation as resolved.
Show resolved
Hide resolved
|
||
<Code> | ||
{code} | ||
</Code | ||
|
||
Can you explain briefly for non technical person on how you came up with code | ||
without explaining pandas library? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's try with "Explain how you came up with code for non-technical people without mentioning technical details or mentioning the libraries used". |
||
""" |
@@ -255,7 +255,12 @@ def _get_cache_key(self) -> str:

        return cache_key

-    def chat(self, query: str, output_type: Optional[str] = None):
+    def chat(
+        self,
+        query: str,
+        output_type: Optional[str] = None,
+        start_conversation: Optional[str] = None,

Review comment: Let's remove this.

Reply: @gventuri Currently the get_conversation method is called with the default, which has limit=1, meaning only the last message is returned. We then need to pass the memory size to the SmartDatalake constructor, or use the memory for that.

+    ):
        """
        Run a query on the dataframe.

@@ -305,6 +310,9 @@ def chat(self, query: str, output_type: Optional[str] = None):
            "save_charts_path": self._config.save_charts_path.rstrip("/"),
            "output_type_hint": output_type_helper.template_hint,
        }
+        if start_conversation is not None:
+            default_values["conversation"] = start_conversation

        generate_python_code_instruction = self._get_prompt(
            "generate_python_code",
            default_prompt=GeneratePythonCodePrompt,

@@ -623,7 +631,7 @@ def last_code_generated(self):

    @last_code_generated.setter
    def last_code_generated(self, last_code_generated: str):
-        self._code_manager._last_code_generated = last_code_generated
+        self._last_code_generated = last_code_generated

    @property
    def last_code_executed(self):

@@ -644,3 +652,7 @@ def last_error(self):
    @last_error.setter
    def last_error(self, last_error: str):
        self._last_error = last_error

+    @property
+    def dfs(self):
+        return self._dfs
Review comment: We can change it to:

from pandasai import Agent
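
Making from pandasai import Agent work would presumably only require re-exporting the class from the package's __init__.py; a sketch of that change (an assumption, not part of this diff):

# pandasai/__init__.py  (sketch)
from pandasai.agent import Agent  # noqa: F401  (re-export so "from pandasai import Agent" works)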