Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[Agent]: add agent conversation code #584

Merged
merged 15 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Example: multi-turn conversation with the PandasAI Agent.

Builds two small DataFrames (employees and salaries), wires them into an
Agent backed by an OpenAI LLM, then demonstrates the three Agent entry
points: chat(), clarification_questions(), and explain().
"""
import pandas as pd
from pandasai import Agent

from pandasai.llm.openai import OpenAI

# Toy data: the two frames are joinable on "EmployeeID".
employees_data = {
    "EmployeeID": [1, 2, 3, 4, 5],
    "Name": ["John", "Emma", "Liam", "Olivia", "William"],
    "Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
}

salaries_data = {
    "EmployeeID": [1, 2, 3, 4, 5],
    "Salary": [5000, 6000, 4500, 7000, 5500],
}

employees_df = pd.DataFrame(employees_data)
salaries_df = pd.DataFrame(salaries_data)


# Replace the placeholder with a real OpenAI API key.
# memory_size=10 keeps up to 10 conversation turns available as context.
llm = OpenAI("OPEN_API_KEY")
agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)

# Chat with the agent
response = agent.chat("Who gets paid the most?")
print(response)


# Get clarification questions the agent would ask about the conversation
questions = agent.clarification_questions()

for question in questions:
    print(question)

# Explain how the chat response is generated
response = agent.explain()
print(response)
3 changes: 2 additions & 1 deletion pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .callbacks.base import BaseCallback
from .schemas.df_config import Config
from .helpers.cache import Cache
from .agent import Agent

__version__ = importlib.metadata.version(__package__ or __name__)

Expand Down Expand Up @@ -257,4 +258,4 @@ def clear_cache(filename: str = None):
cache.clear()


__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "clear_cache"]
__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "Agent", "clear_cache"]
99 changes: 99 additions & 0 deletions pandasai/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import json
from typing import Union, List, Optional
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.logger import Logger
from pandasai.helpers.memory import Memory
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.prompts.explain_prompt import ExplainPrompt
from pandasai.schemas.df_config import Config
from pandasai.smart_datalake import SmartDatalake


class Agent:
    """
    Agent class to improve the conversational experience in PandasAI.

    Wraps a ``SmartDatalake`` together with a bounded conversation
    ``Memory`` so follow-up questions, clarification questions and
    explanations can refer back to previous turns.
    """

    _lake: SmartDatalake = None
    _logger: Optional[Logger] = None

    def __init__(
        self,
        dfs: Union[DataFrameType, List[DataFrameType]],
        config: Optional[Union[Config, dict]] = None,
        logger: Optional[Logger] = None,
        memory_size: int = 1,
    ):
        """
        Args:
            dfs (Union[DataFrameType, List[DataFrameType]]): dataframe(s) to
                converse about; a single dataframe is wrapped into a list.
            config (Optional[Union[Config, dict]]): configuration forwarded to
                the underlying SmartDatalake (e.g. ``{"llm": llm}``).
            logger (Optional[Logger]): logger to use; when omitted, the
                SmartDatalake's own logger is adopted.
            memory_size (int, optional): number of conversation turns kept in
                memory. Defaults to 1.
        """
        if not isinstance(dfs, list):
            dfs = [dfs]

        self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size))
        self._logger = self._lake.logger

    def chat(self, query: str, output_type: Optional[str] = None):
        """
        Simulate a chat interaction with the assistant on the dataframes.

        Returns the datalake's answer, or a human-readable error message if
        the underlying call fails. The failure is logged first so the error
        context is not silently swallowed.
        """
        try:
            result = self._lake.chat(query, output_type=output_type)
            return result
        except Exception as exception:
            # Log before returning the friendly message so debugging info
            # is preserved (the broad except otherwise hides it).
            self._logger.log(f"Chat failed with error: {exception}")
            return (
                "Unfortunately, I was not able to get your answers, "
                "because of the following error:\n"
                f"\n{exception}\n"
            )

    def clarification_questions(self) -> List[str]:
        """
        Generate up to three clarification questions based on the data
        and the conversation so far.

        Errors propagate to the caller; the previous try/except that only
        re-raised was redundant and has been removed.
        """
        prompt = ClarificationQuestionPrompt(
            self._lake.dfs, self._lake._memory.get_conversation()
        )

        result = self._lake.llm.call(prompt)
        self._logger.log(f"Clarification Questions: {result}")
        # The LLM is asked to return a JSON array of question strings.
        questions: List[str] = json.loads(result)
        return questions[:3]

    def start_new_conversation(self):
        """
        Clears the previous conversation
        """
        self._lake._memory.clear()

    def explain(self) -> str:
        """
        Return a non-technical explanation of how the last generated code
        reached its solution, or a human-readable error message on failure.
        """
        try:
            prompt = ExplainPrompt(
                self._lake._memory.get_conversation(),
                self._lake.last_code_executed,
            )
            response = self._lake.llm.call(prompt)
            # Fixed typo in the log label ("Explaination" -> "Explanation").
            self._logger.log(f"Explanation: {response}")
            return response
        except Exception as exception:
            # Log before returning the friendly message (see chat()).
            self._logger.log(f"Explain failed with error: {exception}")
            return (
                "Unfortunately, I was not able to explain, "
                "because of the following error:\n"
                f"\n{exception}\n"
            )
Comment on lines +92 to +97
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error handling in these methods only returns a string with the error message. It might be better to log the exception and re-raise it after logging. This way, the caller of the method can decide how to handle the exception. Also, this will help in debugging as the stack trace will be available in the logs.

-            return (
-                "Unfortunately, I was not able to get your answers, "
-                "because of the following error:\n"
-                f"\n{exception}\n"
-            )
+            self._logger.log(
+                "Unfortunately, I was not able to get your answers, "
+                "because of the following error:\n"
+                f"\n{exception}\n"
+            )
+            raise

...

-            return (
-                "Unfortunately, I was not able to explain, "
-                "because of the following error:\n"
-                f"\n{exception}\n"
-            )
+            self._logger.log(
+                "Unfortunately, I was not able to explain, "
+                "because of the following error:\n"
+                f"\n{exception}\n"
+            )
+            raise

Comment on lines +81 to +97
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the previous comment, the exception handling here is too broad. Consider catching specific exceptions that you expect might occur during the execution of this block. If you want to catch all exceptions, at least log the full traceback to help with debugging.

11 changes: 9 additions & 2 deletions pandasai/helpers/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ class Memory:
"""Memory class to store the conversations"""

_messages: list
_memory_size: int

def __init__(self):
def __init__(self, memory_size: int = 1):
    """Initialize an empty message store.

    Args:
        memory_size (int): default number of turns returned by
            get_conversation() when no explicit limit is passed.
            NOTE(review): it does not appear to cap how many messages
            are stored — confirm that is intended.
    """
    self._messages = []
    self._memory_size = memory_size
Comment on lines +10 to +12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __init__ method now accepts an optional parameter memory_size which defaults to 1. This is used to limit the number of stored messages in memory. However, there's no logic implemented yet to enforce this limit when adding new messages. Consider adding a check in the add method to remove the oldest message(s) when the limit is exceeded.

     def add(self, message: str, is_user: bool):
         self._messages.append({"message": message, "is_user": is_user})
+         while len(self._messages) > self._memory_size:
+             self._messages.pop(0)


def add(self, message: str, is_user: bool):
    """Append one message to the history.

    Args:
        message (str): the message text.
        is_user (bool): True for a user message, False for an assistant one.

    Note: messages accumulate without bound here; _memory_size is only
    applied when reading via get_conversation().
    """
    self._messages.append({"message": message, "is_user": is_user})
Expand All @@ -21,7 +23,12 @@ def all(self) -> list:
def last(self) -> dict:
return self._messages[-1]

def get_conversation(self, limit: int = 1) -> str:
def get_conversation(self, limit: int = None) -> str:
"""
Returns the conversation messages based on limit parameter
or default memory size
"""
limit = self._memory_size if limit is None else limit
return "\n".join(
[
f"{f'User {i+1}' if message['is_user'] else f'Assistant {i}'}: "
Expand Down
51 changes: 51 additions & 0 deletions pandasai/prompts/clarification_questions_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
""" Prompt to get clarification questions
You are provided with the following pandas DataFrames:

<dataframe>
{dataframe}
</dataframe>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also change this to

{dataframes}

for consistency!


<conversation>
{conversation}
</conversation>

Based on the conversation, are there any clarification questions that a senior data scientist would ask? These are questions for non-technical people; only ask questions they could ask given low technical expertise and no knowledge about how the dataframes are structured.

Return the JSON array of the clarification questions. If there is no clarification question, return an empty array.

Json:
""" # noqa: E501
gventuri marked this conversation as resolved.
Show resolved Hide resolved


from .base import Prompt


class ClarificationQuestionPrompt(Prompt):
    """Prompt asking the LLM for clarification questions about the data.

    The ``{dataframes}`` placeholder is filled by ``Prompt.set_var`` when
    given the ``"dfs"`` key, which wraps each dataframe individually — so
    the template does not wrap it in ``<dataframe>`` tags itself.
    """

    text: str = """
You are provided with the following pandas DataFrames:

{dataframes}

<conversation>
{conversation}
</conversation>

Based on the conversation, are there any clarification questions
that a senior data scientist would ask? These are questions for non technical people,
only ask for questions they could ask given low tech expertise and
no knowledge about how the dataframes are structured.

Return the JSON array of the clarification questions.

If there is no clarification question, return an empty array.

Json:
"""

    def __init__(self, dataframes, conversation):
        # Prompt.set_var special-cases the "dfs" key: it generates the
        # wrapped dataframe text and stores it under "dataframes", the
        # placeholder used in the template above.
        self.set_var("dfs", dataframes)
        self.set_var("conversation", conversation)
gventuri marked this conversation as resolved.
Show resolved Hide resolved
44 changes: 44 additions & 0 deletions pandasai/prompts/explain_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
""" Prompt to explain solution generated
The previous conversation we had

<Conversation>
{conversation}
</Conversation>

Based on the last conversation you generated the following code:

<Code>
{code}
</Code

Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?

"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring at the top of the file seems to be a copy of the prompt text. It's not providing any useful information about the module or its contents. Consider replacing it with a more informative docstring that describes the purpose and functionality of the module.

from .base import Prompt


class ExplainPrompt(Prompt):
    """Prompt to generate a non-technical explanation of the generated code."""

    text: str = """
The previous conversation we had

<Conversation>
{conversation}
</Conversation>

Based on the last conversation you generated the following code:

<Code>
{code}
</Code>

Explain how you came up with code for non-technical people without
mentioning technical details or mentioning the libraries used?

"""

    def __init__(self, conversation, code):
        # Fill the two template placeholders; closing </Code> tag fixed
        # (the original template was missing the ">").
        self.set_var("conversation", conversation)
        self.set_var("code", code)
gventuri marked this conversation as resolved.
Show resolved Hide resolved
11 changes: 10 additions & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def chat(self, query: str, output_type: Optional[str] = None):
"save_charts_path": self._config.save_charts_path.rstrip("/"),
"output_type_hint": output_type_helper.template_hint,
}

generate_python_code_instruction = self._get_prompt(
"generate_python_code",
default_prompt=GeneratePythonCodePrompt,
Expand Down Expand Up @@ -623,7 +624,7 @@ def last_code_generated(self):

@last_code_generated.setter
def last_code_generated(self, last_code_generated: str):
self._code_manager._last_code_generated = last_code_generated
self._last_code_generated = last_code_generated
gventuri marked this conversation as resolved.
Show resolved Hide resolved

@property
def last_code_executed(self):
Expand All @@ -644,3 +645,11 @@ def last_error(self):
@last_error.setter
def last_error(self, last_error: str):
self._last_error = last_error

@property
def dfs(self):
    """Read-only accessor for the dataframes held by this datalake."""
    return self._dfs

@property
def memory(self):
    """Read-only accessor for the conversation Memory instance."""
    return self._memory
gventuri marked this conversation as resolved.
Show resolved Hide resolved
Loading