feat[Agent]: add agent conversation code #584

Merged: 15 commits, Sep 24, 2023
Changes from 2 commits
28 changes: 28 additions & 0 deletions examples/agent.py
@@ -0,0 +1,28 @@
import pandas as pd
from pandasai.agent import Agent
Collaborator:
We can change it to:
from pandasai import Agent


from pandasai.llm.openai import OpenAI

employees_data = {
"EmployeeID": [1, 2, 3, 4, 5],
"Name": ["John", "Emma", "Liam", "Olivia", "William"],
"Department": ["HR", "Sales", "IT", "Marketing", "Finance"],
}

salaries_data = {
"EmployeeID": [1, 2, 3, 4, 5],
"Salary": [5000, 6000, 4500, 7000, 5500],
}

employees_df = pd.DataFrame(employees_data)
salaries_df = pd.DataFrame(salaries_data)


llm = OpenAI("OPEN_API")
agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10)
response = agent.chat("Who gets paid the most?")
print(response)
questions = agent.clarification_questions()
print(questions)
response = agent.chat("Which department he belongs to?")
print(response)
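
As a side note, the example passes the placeholder string "OPEN_API" as the key. A small hedged variant that reads the key from the environment instead (assuming the OpenAI wrapper accepts an api_token argument; the variable name below is just an example):

import os

from pandasai.llm.openai import OpenAI

# Assumption: the wrapper accepts the key as api_token; adjust to the actual signature.
llm = OpenAI(api_token=os.environ["OPENAI_API_KEY"])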
3 changes: 2 additions & 1 deletion pandasai/__init__.py
@@ -44,6 +44,7 @@
from .callbacks.base import BaseCallback
from .schemas.df_config import Config
from .helpers.cache import Cache
from .agent import Agent

__version__ = importlib.metadata.version(__package__ or __name__)

@@ -257,4 +258,4 @@ def clear_cache(filename: str = None):
cache.clear()


__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "clear_cache"]
__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "Agent", "clear_cache"]
94 changes: 94 additions & 0 deletions pandasai/agent/__init__.py
@@ -0,0 +1,94 @@
import json
from typing import Union, List, Optional
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.logger import Logger
from pandasai.helpers.memory import Memory
from pandasai.prompts.clarification_questions_prompt import ClarificationQuestionPrompt
from pandasai.schemas.df_config import Config

from pandasai.smart_datalake import SmartDatalake


class Agent:
"""
Agent class to improve the conversational experience in PandasAI
"""

_memory: Memory
Collaborator:
Here we should rely on the memory from the SmartDatalake instead, since we have access to it.
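
A minimal sketch of that direction, assuming SmartDatalake exposes its Memory instance (the memory property used below is an assumption, not something shown in this diff):

from pandasai.smart_datalake import SmartDatalake


class Agent:
    """Agent that reuses the SmartDatalake's memory instead of keeping its own copy."""

    def __init__(self, dfs, config=None, logger=None):
        if not isinstance(dfs, list):
            dfs = [dfs]
        self._lake = SmartDatalake(dfs, config, logger)
        self.logger = self._lake.logger

    def _get_conversation(self) -> str:
        # Assumption: the lake exposes its Memory instance, e.g. via a memory property,
        # and Memory.all() returns the stored message dicts (as in the diff above).
        return "\n".join(
            f"{'Question' if m['is_user'] else 'Answer'}: {m['message']}"
            for m in self._lake.memory.all()
        )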

_lake: SmartDatalake = None
logger: Logger = None
Contributor:
The logger attribute is initialized as None but it's not checked for None before being used in line 38. This could potentially lead to a NoneType error if the logger is not passed during initialization and later accessed.

-    logger: Logger = None
+    logger: Optional[Logger] = None


Contributor:
The class variables _memory, _lake, and logger are declared but not initialized in the class scope. This could lead to confusion as it might imply that these are static class variables, which they are not. It would be better to declare these variables inside the __init__ method.

class Agent:
    """
    Agent class to improve the conversational experience in PandasAI
    """

- 20:     _memory: Memory
- 21:     _lake: SmartDatalake = None
- 22:     logger: Logger = None
+ 20:     def __init__(
+ 21:         self,
+ 22:         dfs: Union[DataFrameType, List[DataFrameType]],
+ 23:         config: Optional[Union[Config, dict]] = None,
+ 24:         logger: Logger = None,
+ 25:         memory_size: int = 1,
+ 26:     ):
+ 27:         """
+ 28:         Args:
+ 29:             df (Union[SmartDataframe, SmartDatalake]): _description_
+ 30:             memory_size (int, optional): _description_. Defaults to 1.
+ 31:         """
+ 32: 
+ 33:         if not isinstance(dfs, list):
+ 34:             dfs = [dfs]
+ 35: 
+ 36:         self._lake = SmartDatalake(dfs, config, logger)
+ 37:         self.logger = self._lake.logger
+ 38:         # For the conversation multiple the memory size by 2
+ 39:         self._memory = Memory(memory_size * 2)

def __init__(
self,
dfs: Union[DataFrameType, List[DataFrameType]],
config: Optional[Union[Config, dict]] = None,
logger: Logger = None,
memory_size=1,
):
Contributor:
The memory_size parameter in the constructor lacks type hinting, which can lead to confusion about what type of argument should be passed. It's also missing from the docstring. Add an integer type hint and update the docstring accordingly.

-        memory_size=1,
+        memory_size: int = 1,

"""
Args:
df (Union[SmartDataframe, SmartDatalake]): _description_
memory_size (int, optional): _description_. Defaults to 1.
"""

if not isinstance(dfs, list):
dfs = [dfs]

Comment on lines +35 to +37
Contributor:
The code is assuming that if dfs is not a list, it should be converted into a list. This might lead to unexpected behavior if dfs is of an unsupported type. It would be better to explicitly check for supported types and raise an error if an unsupported type is provided.

-        if not isinstance(dfs, list):
-            dfs = [dfs]
+        if isinstance(dfs, DataFrameType):
+            dfs = [dfs]
+        elif not isinstance(dfs, list):
+            raise TypeError("dfs must be a DataFrameType or a list of DataFrameType")
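
One caveat with the suggestion above: if DataFrameType is a typing alias (for example a Union), isinstance() will not accept it directly. A standalone sketch of an explicit check against concrete frame classes (the SUPPORTED_FRAME_TYPES tuple is an assumption for illustration):

import pandas as pd

# Assumption: the concrete classes behind DataFrameType; extend if other frame types are supported.
SUPPORTED_FRAME_TYPES = (pd.DataFrame,)


def normalize_dfs(dfs):
    """Return a list of dataframes, raising a TypeError on unsupported input."""
    if isinstance(dfs, SUPPORTED_FRAME_TYPES):
        return [dfs]
    if isinstance(dfs, list):
        return dfs
    raise TypeError("dfs must be a dataframe or a list of dataframes")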

self._lake = SmartDatalake(dfs, config, logger)
self.logger = self._lake.logger
self._memory = Memory(memory_size * 2)

def _get_conversation(self):
Collaborator:
We already have such a method in the memory.py file. What's the advantage of replicating it instead of using the one from Memory?

Collaborator:
It's important to try to reuse code as much as possible to guarantee maintainability over time!
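
A sketch of that reuse, assuming Memory exposes a get_conversation helper that accepts a limit (a signature implied by the discussion further down, not shown in this diff). It is a drop-in for the _get_conversation method above:

def _get_conversation(self) -> str:
    """Delegate conversation rendering to Memory instead of duplicating it here."""
    # Assumption: Memory.get_conversation(limit=...) formats the stored messages itself.
    return self._memory.get_conversation(limit=self._memory.count())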

"""
Get Conversation from history

"""
return "\n".join(
[
f"{'Question' if message['is_user'] else 'Answer'}: "
f"{message['message']}"
for i, message in enumerate(self._memory.all())
]
)

def chat(self, query: str):
"""
Simulate a chat interaction with the assistant on Dataframe.
"""
self._memory.add(query, True)
conversation = self._get_conversation()
result = self._lake.chat(query, start_conversation=conversation)
self._memory.add(result, False)
return result
Contributor:
The chat method does not handle exceptions. If an exception occurs during the execution of the self._lake.chat() method, it will propagate up the call stack and could potentially crash the application. Consider adding a try/except block to handle exceptions gracefully.
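
A drop-in sketch of that guard for the chat method above; whether to re-raise or return a friendly message is a design choice, and the Logger.log call is an assumption about the project's logger API:

def chat(self, query: str):
    """Simulate a chat interaction, surfacing failures from the underlying lake."""
    try:
        self._memory.add(query, True)
        conversation = self._get_conversation()
        result = self._lake.chat(query, start_conversation=conversation)
        self._memory.add(result, False)
        return result
    except Exception as exception:
        if self.logger is not None:
            self.logger.log(f"Chat failed: {exception}")  # assumed logger API
        raise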


def _get_clarification_prompt(self):
"""
Create a clarification prompt with relevant variables.
"""
prompt = ClarificationQuestionPrompt()
Collaborator:
Let's pass these as arguments instead (both dfs and conversation). Since we'll pass these as arguments, we won't need the _get_clarification_prompt method anymore
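
A sketch of passing both values at construction time, assuming the Prompt base class keeps its set_var mechanism (the constructor signature below is hypothetical):

from pandasai.prompts.base import Prompt


class ClarificationQuestionPrompt(Prompt):
    """Clarification prompt that receives its variables when constructed."""

    def __init__(self, dataframes, conversation):
        super().__init__()
        # Reuse set_var from the base class; no separate _get_clarification_prompt needed.
        self.set_var("dfs", dataframes)
        self.set_var("conversation", conversation)


# In Agent.clarification_questions:
# prompt = ClarificationQuestionPrompt(self._lake.dfs, self._get_conversation())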

prompt.set_var("dfs", self._lake.dfs)
prompt.set_var("conversation", self._get_conversation())
return prompt

def clarification_questions(self):
"""
Generate and return up to three clarification questions based on a given prompt.
"""
try:
prompt = self._get_clarification_prompt()
result = self._lake.llm.generate_code(prompt)
questions = json.loads(result)
except Exception as exception:
return (
"Unfortunately, I was not able to get your clarification questions, "
"because of the following error:\n"
f"\n{exception}\n"
)
Contributor:
While it's good that you're handling exceptions in the clarification_questions method, returning a string with the error message might not be the best approach. It would be more appropriate to log the error and raise the exception again after logging, so that the caller can decide how to handle it.
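
A sketch of the log-and-reraise variant (json is already imported at the top of the module; the Logger.log call is again an assumption about the project's logger API):

def clarification_questions(self):
    """Generate up to three clarification questions, logging failures instead of masking them."""
    try:
        prompt = self._get_clarification_prompt()
        result = self._lake.llm.generate_code(prompt)
        questions = json.loads(result)
    except Exception as exception:
        if self.logger is not None:
            self.logger.log(f"Failed to generate clarification questions: {exception}")
        raise  # let the caller decide how to recover
    return questions[:3]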


return questions[:3]

def start_new_conversation(self):
"""
Clears the previous conversation
"""
self._memory.clear()
Contributor:
The start_new_conversation method clears the memory but does not return any confirmation or status. It would be helpful to return a status message or boolean value indicating whether the operation was successful.
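
A minimal sketch of returning a success flag, if the maintainers decide it adds value:

def start_new_conversation(self) -> bool:
    """Clear the previous conversation and report whether it succeeded."""
    self._memory.clear()
    return True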

7 changes: 6 additions & 1 deletion pandasai/helpers/memory.py
@@ -1,16 +1,21 @@
""" Memory class to store the conversations """
import sys


class Memory:
"""Memory class to store the conversations"""

_messages: list
_max_messages: int
Contributor:
The _max_messages attribute is not defined in the class scope, which might lead to confusion and potential errors. It's better to define it within the __init__ method.


def __init__(self):
def __init__(self, max_messages=sys.maxsize):
self._messages = []
self._max_messages = max_messages
Collaborator:
No need for this (see below)


def add(self, message: str, is_user: bool):
self._messages.append({"message": message, "is_user": is_user})
if len(self._messages) > self._max_messages:
Collaborator:
No need for this, the get_conversation method of memory will automatically limit based on the limit variable that is passed.
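
A standalone sketch of the read-time limiting the reviewer refers to (the exact get_conversation signature is assumed, since it is not part of this diff):

class Memory:
    """Stores conversation messages; trimming happens at read time rather than in add()."""

    def __init__(self):
        self._messages = []

    def add(self, message: str, is_user: bool):
        self._messages.append({"message": message, "is_user": is_user})

    def get_conversation(self, limit: int = 1) -> str:
        # Only the last `limit` messages are rendered, so add() never needs to delete anything.
        return "\n".join(
            f"{'Question' if m['is_user'] else 'Answer'}: {m['message']}"
            for m in self._messages[-limit:]
        )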

del self._messages[:2]

def count(self) -> int:
return len(self._messages)
47 changes: 47 additions & 0 deletions pandasai/prompts/clarification_questions_prompt.py
@@ -0,0 +1,47 @@
""" Prompt to get clarification questions
You are provided with the following pandas DataFrames:

<dataframe>
{dataframe}
</dataframe>
Collaborator:
Let's also change this to

{dataframes}

for consistency!


<conversation>
{conversation}
</conversation>

Based on the conversation, are there any clarification questions that a senior data scientist would ask? These are questions for non technical people, only ask for questions they could ask given low tech expertise and no knowledge about how the dataframes are structured.

Return the JSON array of the clarification questions. If there is no clarification question, return an empty array.

Json:
""" # noqa: E501


from .base import Prompt


class ClarificationQuestionPrompt(Prompt):
"""Prompt to get clarification questions"""

text: str = """
You are provided with the following pandas DataFrames:

<dataframe>
{dataframes}
</dataframe>
Collaborator:
{dataframes} only should be enough, the set_var method in prompts already takes care of wrapping each one individually.
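
An illustrative sketch of the wrapping described; how set_var actually renders each frame is an assumption here:

def _wrap_dataframes(dfs) -> str:
    """Render each dataframe in its own <dataframe> block so the template only needs {dataframes}."""
    return "\n\n".join(f"<dataframe>\n{df}\n</dataframe>" for df in dfs)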


<conversation>
{conversation}
</conversation>

Based on the conversation, are there any clarification questions
that a senior data scientist would ask? These are questions for non technical people,
only ask for questions they could ask given low tech expertise and
no knowledge about how the dataframes are structured.

Return the JSON array of the clarification questions.

If there is no clarification question, return an empty array.

Json:
"""
14 changes: 13 additions & 1 deletion pandasai/smart_datalake/__init__.py
@@ -255,7 +255,12 @@ def _get_cache_key(self) -> str:

return cache_key

def chat(self, query: str, output_type: Optional[str] = None):
def chat(
self,
query: str,
output_type: Optional[str] = None,
start_conversation: Optional[str] = None,
Collaborator:
Let's remove this

Collaborator Author (@ArslanSaleem, Sep 22, 2023):
@gventuri Currently the get_conversation method is called with its default limit=1, which returns only the last message. We would then need to pass the memory size to the SmartDatalake constructor, or rely on the memory for that.
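
A sketch of the first option, with a hypothetical memory_size parameter on SmartDatalake (not part of this diff, and the get_conversation signature is assumed):

from pandasai.helpers.memory import Memory


class SmartDatalake:
    """Hypothetical sketch: let the lake own the conversation length."""

    def __init__(self, dfs, config=None, logger=None, memory_size: int = 1):
        self._memory = Memory()
        self._memory_size = memory_size
        # ... existing initialization would continue here ...

    def _get_conversation(self) -> str:
        # Pass the configured size as the limit instead of the default of 1.
        return self._memory.get_conversation(limit=self._memory_size)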

):
"""
Run a query on the dataframe.

@@ -305,6 +310,9 @@ def chat(self, query: str, output_type: Optional[str] = None):
"save_charts_path": self._config.save_charts_path.rstrip("/"),
"output_type_hint": output_type_helper.template_hint,
}
if start_conversation is not None:
default_values["conversation"] = start_conversation

generate_python_code_instruction = self._get_prompt(
"generate_python_code",
default_prompt=GeneratePythonCodePrompt,
@@ -644,3 +652,7 @@ def last_error(self):
@last_error.setter
def last_error(self, last_error: str):
self._last_error = last_error

@property
def dfs(self):
return self._dfs