diff --git a/docs/connectors.md b/docs/connectors.md index 1850b9a7c..58d11a970 100644 --- a/docs/connectors.md +++ b/docs/connectors.md @@ -89,6 +89,28 @@ df = SmartDataframe(mysql_connector) df.chat('What is the total amount of loans in the last year?') ``` +### Sqlite connector + +Similarly to the PostgreSQL and MySQL connectors, the Sqlite connector allows you to connect to a local Sqlite database file. It is designed to be easy to use, even if you are not familiar with Sqlite or with PandasAI. + +To use the Sqlite connector, you only need to import it into your Python code and pass it to a `SmartDataframe` or `SmartDatalake` object: + +```python +from pandasai.connectors import SqliteConnector + +connector = SqliteConnector(config={ + "database" : "PATH_TO_DB", + "table" : "actor", + "where" :[ + ["first_name","=","PENELOPE"] + ] +}) + +df = SmartDataframe(connector) +df.chat('How many records are there ?') +``` + + ### Generic SQL connector The generic SQL connector allows you to connect to any SQL database that is supported by SQLAlchemy. diff --git a/docs/examples.md b/docs/examples.md index a3c07abcf..424afaf23 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -260,3 +260,57 @@ for question in questions: response = agent.explain() print(response) ``` + +## Add Skills to the Agent + +You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations. + +``` +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +@skill( + name="Display employee salary", + description="Plots the employee salaries against names", + usage="Displays the plot having name on x axis and salaries on y axis", +) +def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) + +``` diff --git a/docs/getting-started.md b/docs/getting-started.md index 609eb70cc..9fe7a59b5 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -190,6 +190,7 @@ Settings: - `save_charts_path`: the path where to save the charts. Defaults to `exports/charts/`. You can use this setting to override the default path. - `enable_cache`: whether to enable caching. Defaults to `True`. If set to `True`, PandasAI will cache the results of the LLM to improve the response time. If set to `False`, PandasAI will always call the LLM. - `use_error_correction_framework`: whether to use the error correction framework. Defaults to `True`. 
If set to `True`, PandasAI will try to correct the errors in the code generated by the LLM with further calls to the LLM. If set to `False`, PandasAI will not try to correct the errors in the code generated by the LLM. +- `use_advanced_reasoning_framework`: whether to use the advanced reasoning framework. Defaults to `False`. If set to `True`, PandasAI will try to use advanced reasoning to improve the results of the LLM and provide an explanation for the results. - `max_retries`: the maximum number of retries to use when using the error correction framework. Defaults to `3`. You can use this setting to override the default number of retries. - `custom_prompts`: the custom prompts to use. Defaults to `{}`. You can use this setting to override the default custom prompts. You can find more information about custom prompts [here](custom-prompts.md). - `custom_whitelisted_dependencies`: the custom whitelisted dependencies to use. Defaults to `{}`. You can use this setting to override the default custom whitelisted dependencies. You can find more information about custom whitelisted dependencies [here](custom-whitelisted-dependencies.md). diff --git a/docs/skills.md b/docs/skills.md new file mode 100644 index 000000000..ae6488219 --- /dev/null +++ b/docs/skills.md @@ -0,0 +1,113 @@ +# Skills + +You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations. + +## Example Usage + +```python + +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Function doc string to give more context to the model for use this skill +@skill +def plot_salaries(name: list[str], salaries: list[int]): + """ + Displays the bar chart having name on x axis and salaries on y axis + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + # plot bars + import matplotlib.pyplot as plt + + plt.bar(name, salaries) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") + + +``` + +## Add Streamlit Skill + +```python +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill +import streamlit as st + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Function doc string to give more context to the model for use this skill +@skill +def plot_salaries(name: list[str], salary: list[int]): + """ + Displays the bar chart having name on 
x axis and salaries on y axis using streamlit + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + import matplotlib.pyplot as plt + + plt.bar(name, salary) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + fig = plt.gcf() + st.pyplot(fig) + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries_using_streamlit) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) +``` diff --git a/examples/from_sql.py b/examples/from_sql.py index c8d435559..626df3ff3 100644 --- a/examples/from_sql.py +++ b/examples/from_sql.py @@ -2,7 +2,7 @@ from pandasai import SmartDatalake from pandasai.llm import OpenAI -from pandasai.connectors import MySQLConnector, PostgreSQLConnector +from pandasai.connectors import MySQLConnector, PostgreSQLConnector, SqliteConnector # With a MySQL database loan_connector = MySQLConnector( @@ -38,8 +38,19 @@ } ) +# With a Sqlite databse + +invoice_connector = SqliteConnector( + config={ + "database": "local_path_to_db", + "table": "invoices", + "where": [["status", "=", "pending"]], + } +) llm = OpenAI() -df = SmartDatalake([loan_connector, payment_connector], config={"llm": llm}) +df = SmartDatalake( + [loan_connector, payment_connector, invoice_connector], config={"llm": llm} +) response = df.chat("How many people from the United states?") print(response) # Output: 247 loans have been paid off by men. diff --git a/examples/skills_example.py b/examples/skills_example.py new file mode 100644 index 000000000..e1df24d99 --- /dev/null +++ b/examples/skills_example.py @@ -0,0 +1,47 @@ +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +# Add function docstring to give more context to model +@skill +def plot_salaries(name: list[str], salary: list[int]) -> str: + """ + Displays the bar chart having name on x axis and salaries on y axis using streamlit + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + import matplotlib.pyplot as plt + + plt.bar(name, salary) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + + +llm = OpenAI("YOUR-API-KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) diff --git a/examples/using_pandasai_log_server.py b/examples/using_pandasai_log_server.py new file mode 100644 index 000000000..04e0857e0 --- /dev/null +++ b/examples/using_pandasai_log_server.py @@ -0,0 +1,59 @@ +import os +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": 
[1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Example 1: Using Environment Variables +os.environ["LOGGING_SERVER_URL"] = "SERVER_URL" +os.environ["LOGGING_SERVER_API_KEY"] = "YOUR_API_KEY" + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent( + [employees_df, salaries_df], + config={ + "llm": llm, + "enable_cache": True, + }, + memory_size=10, +) + +# Chat with the agent +response = agent.chat("Plot salary against department?") +print(response) + + +# Example 2: Using Config +llm = OpenAI("YOUR_API_KEY") +agent = Agent( + [employees_df, salaries_df], + config={ + "llm": llm, + "enable_cache": True, + "log_server": { + "server_url": "SERVER_URL", + "api_key": "YOUR_API_KEY", + }, + }, + memory_size=10, +) + +# Chat with the agent +response = agent.chat("Plot salary against department?") + +print(response) diff --git a/mkdocs.yml b/mkdocs.yml index 2aa148616..28b6b6701 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -27,6 +27,7 @@ nav: - callbacks.md - custom-instructions.md - custom-prompts.md + - skills.md - custom-whitelisted-dependencies.md - Examples: - examples.md diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 5e1a08ea3..2c71d8116 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -45,6 +45,7 @@ from .schemas.df_config import Config from .helpers.cache import Cache from .agent import Agent +from .skills import skill __version__ = importlib.metadata.version(__package__ or __name__) @@ -257,4 +258,11 @@ def clear_cache(filename: str = None): cache.clear() -__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "Agent", "clear_cache"] +__all__ = [ + "PandasAI", + "SmartDataframe", + "SmartDatalake", + "Agent", + "clear_cache", + "skill", +] diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 46d88c736..52034bac0 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -1,5 +1,7 @@ import json from typing import Union, List, Optional + +from pandasai.skills import skill from ..helpers.df_info import DataFrameType from ..helpers.logger import Logger from ..helpers.memory import Memory @@ -41,8 +43,18 @@ def __init__( dfs = [dfs] self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size)) + + # set instance type in SmartDataLake + self._lake.set_instance_type(self.__class__.__name__) + self._logger = self._lake.logger + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self._lake.add_skills(*skills) + def _call_llm_with_prompt(self, prompt: AbstractPrompt): """ Call LLM with prompt using error handling to retry based on config @@ -70,7 +82,8 @@ def chat(self, query: str, output_type: Optional[str] = None): Simulate a chat interaction with the assistant on Dataframe. 
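        Args:
            query (str): The natural-language question to run against the dataframes.
            output_type (Optional[str]): Optional hint for the expected result type
                (for example "string", "number", "dataframe" or "plot").

        Returns:
            The parsed answer produced by the underlying SmartDatalake, or an error
            message string if the query could not be answered.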
""" try: - self.check_if_related_to_conversation(query) + is_related = self.check_if_related_to_conversation(query) + self._lake.is_related_query(is_related) result = self._lake.chat(query, output_type=output_type) return result except Exception as exception: @@ -80,7 +93,7 @@ def chat(self, query: str, output_type: Optional[str] = None): f"\n{exception}\n" ) - def check_if_related_to_conversation(self, query: str): + def check_if_related_to_conversation(self, query: str) -> bool: """ Check if the query is related to the previous conversation """ @@ -105,6 +118,8 @@ def check_if_related_to_conversation(self, query: str): if not related: self._lake.clear_memory() + return related + def clarification_questions(self, query: str) -> List[str]: """ Generate clarification questions based on the data @@ -170,3 +185,23 @@ def rephrase_query(self, query: str): "because of the following error:\n" f"\n{exception}\n" ) + + @property + def last_code_generated(self): + return self._lake.last_code_generated + + @property + def last_code_executed(self): + return self._lake.last_code_executed + + @property + def last_prompt(self): + return self._lake.last_prompt + + @property + def last_reasoning(self): + return self._lake.last_reasoning + + @property + def last_answer(self): + return self._lake.last_answer diff --git a/pandasai/assets/prompt_templates/advanced_reasoning.tmpl b/pandasai/assets/prompt_templates/advanced_reasoning.tmpl new file mode 100644 index 000000000..7d132a2e0 --- /dev/null +++ b/pandasai/assets/prompt_templates/advanced_reasoning.tmpl @@ -0,0 +1,3 @@ +- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags. +- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value or the chart itself (it will be calculated later). +- return the updated analyze_data function wrapped within ```python ``` \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl b/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl index 48f86a72b..058867102 100644 --- a/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl +++ b/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl @@ -6,4 +6,4 @@ {query} -Is the query related to the conversation? Answer only "true" or "false" (lowercase). \ No newline at end of file +Is the query somehow related to the previous conversation? Answer only "true" or "false" (lowercase). \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 93b52b3b5..b0c337e44 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -6,11 +6,12 @@ You are provided with the following pandas DataFrames: {conversation} -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python {current_code} ``` - -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. - -Return the updated code: \ No newline at end of file +{skills} +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. 
+Based on the last message in the conversation: +{reasoning} \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/simple_reasoning.tmpl b/pandasai/assets/prompt_templates/simple_reasoning.tmpl new file mode 100644 index 000000000..c728ffb0d --- /dev/null +++ b/pandasai/assets/prompt_templates/simple_reasoning.tmpl @@ -0,0 +1 @@ +- return the updated analyze_data function wrapped within ```python ``` \ No newline at end of file diff --git a/pandasai/connectors/__init__.py b/pandasai/connectors/__init__.py index fb80c8628..a484835ff 100644 --- a/pandasai/connectors/__init__.py +++ b/pandasai/connectors/__init__.py @@ -10,6 +10,7 @@ from .databricks import DatabricksConnector from .yahoo_finance import YahooFinanceConnector from .airtable import AirtableConnector +from .sql import SqliteConnector __all__ = [ "BaseConnector", @@ -20,4 +21,5 @@ "SnowFlakeConnector", "DatabricksConnector", "AirtableConnector", + "SqliteConnector", ] diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py index 4de547486..3193616bf 100644 --- a/pandasai/connectors/base.py +++ b/pandasai/connectors/base.py @@ -39,6 +39,15 @@ class SQLBaseConnectorConfig(BaseConnectorConfig): dialect: Optional[str] = None +class SqliteConnectorConfig(SQLBaseConnectorConfig): + """ + Connector configurations for sqlite db. + """ + + table: str + database: str + + class YahooFinanceConnectorConfig(BaseConnectorConfig): """ Connector configuration for Yahoo Finance. diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index 215da08f4..8222bed21 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -5,7 +5,7 @@ import re import os import pandas as pd -from .base import BaseConnector, SQLConnectorConfig +from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig from .base import BaseConnectorConfig from sqlalchemy import create_engine, text, select, asc from sqlalchemy.engine import Connection @@ -364,6 +364,90 @@ def fallback_name(self): return self._config.table +class SqliteConnector(SQLConnector): + """ + Sqlite connector are used to connect to Sqlite databases. + """ + + def __init__(self, config: Union[SqliteConnectorConfig, dict]): + """ + Intialize the Sqlite connector with the given configuration. + + Args: + config (ConnectorConfig) : The configuration for the MySQL connector. + """ + config["dialect"] = "sqlite" + if isinstance(config, dict): + sqlite_env_vars = {"database": "SQLITE_DB_PATH", "table": "TABLENAME"} + config = self._populate_config_from_env(config, sqlite_env_vars) + + super().__init__(config) + + def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): + """ + Loads passed Configuration to object + + Args: + config (BaseConnectorConfig): Construct config in structure + + Returns: + config: BaseConenctorConfig + """ + return SqliteConnectorConfig(**config) + + def _init_connection(self, config: SqliteConnectorConfig): + """ + Initialize Database Connection + + Args: + config (SQLConnectorConfig): Configurations to load database + + """ + self._engine = create_engine(f"{config.dialect}:///{config.database}") + self._connection = self._engine.connect() + + def __del__(self): + """ + Close the connection to the SQL database. + """ + self._connection.close() + + @cache + def head(self): + """ + Return the head of the data source that the connector is connected to. + This information is passed to the LLM to provide the schema of the data source. + + Returns: + DataFrame: The head of the data source. 
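+        Only five rows, selected in random order, are fetched, so the preview
+        stays lightweight even for large tables.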
+ """ + + if self.logger: + self.logger.log( + f"Getting head of {self._config.table} " + f"using dialect {self._config.dialect}" + ) + + # Run a SQL query to get all the columns names and 5 random rows + query = self._build_query(limit=5, order="RANDOM()") + + # Return the head of the data source + return pd.read_sql(query, self._connection) + + def __repr__(self): + """ + Return the string representation of the SQL connector. + + Returns: + str: The string representation of the SQL connector. + """ + return ( + f"<{self.__class__.__name__} dialect={self._config.dialect} " + f"database={self._config.database} " + f"table={self._config.table}>" + ) + + class MySQLConnector(SQLConnector): """ MySQL connectors are used to connect to MySQL databases. diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index a165abb4b..56551ff34 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -136,3 +136,13 @@ def __init__(self, template_path, prompt_name="Unknown"): f"Unable to find a file with template at '{template_path}' " f"for '{prompt_name}' prompt." ) + + +class AdvancedReasoningDisabledError(Exception): + """ + Raised when one tries to have access to the answer or reasoning without + having use_advanced_reasoning_framework enabled. + + Args: + Exception (Exception): AdvancedReasoningDisabledError + """ diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index fe8a6721d..cb47fe2e3 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -6,6 +6,8 @@ import astor import pandas as pd +from pandasai.helpers.skills_manager import SkillsManager + from .node_visitors import AssignmentVisitor, CallVisitor from .save_chart import add_save_chart from .optional import import_dependency @@ -23,6 +25,29 @@ import traceback +class CodeExecutionContext: + _prompt_id: uuid.UUID = None + _skills_manager: SkillsManager = None + + def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager): + """ + Additional Context for code execution + Args: + prompt_id (uuid.UUID): prompt unique id + skill (List): list[functions] of skills added + """ + self._skills_manager = skills_manager + self._prompt_id = prompt_id + + @property + def prompt_id(self): + return self._prompt_id + + @property + def skills_manager(self): + return self._skills_manager + + class CodeManager: _dfs: List _middlewares: List[Middleware] = [ChartsMiddleware()] @@ -180,11 +205,7 @@ def _required_dfs(self, code: str) -> List[str]: required_dfs.append(None) return required_dfs - def execute_code( - self, - code: str, - prompt_id: uuid.UUID, - ) -> Any: + def execute_code(self, code: str, context: CodeExecutionContext) -> Any: """ Execute the python code generated by LLMs to answer the question about the input dataframe. Run the code in the current context and return the @@ -192,7 +213,8 @@ def execute_code( Args: code (str): Python code to execute. - prompt_id (uuid.UUID): UUID of the request. + context (CodeExecutionContext): Code Execution Context + with prompt id and skills. Returns: Any: The result of the code execution. 
The type of the result depends @@ -209,12 +231,15 @@ def execute_code( code = add_save_chart( code, logger=self._logger, - file_name=str(prompt_id), + file_name=str(context.prompt_id), save_charts_path_str=self._config.save_charts_path, ) + # Reset used skills + context.skills_manager.used_skills = [] + # Get the code to run removing unsafe imports and df overwrites - code_to_run = self._clean_code(code) + code_to_run = self._clean_code(code, context) self.last_code_executed = code_to_run self._logger.log( f""" @@ -228,6 +253,13 @@ def execute_code( # if the code does not need them dfs = self._required_dfs(code_to_run) environment: dict = self._get_environment() + + # Add Skills in the env + if len(context.skills_manager.used_skills) > 0: + for skill_func_name in context.skills_manager.used_skills: + skill = context.skills_manager.get_skill_by_func_name(skill_func_name) + environment[skill_func_name] = skill + environment["dfs"] = self._get_samples(dfs) caught_error = self._execute_catching_errors(code_to_run, environment) @@ -293,7 +325,6 @@ def _get_environment(self) -> dict: Returns (dict): A dictionary of environment variables """ - return { "pd": pd, **{ @@ -377,7 +408,7 @@ def _sanitize_analyze_data(self, analyze_data_node: ast.stmt) -> ast.stmt: analyze_data_node.body = sanitized_analyze_data return analyze_data_node - def _clean_code(self, code: str) -> str: + def _clean_code(self, code: str, context: CodeExecutionContext) -> str: """ A method to clean the code to prevent malicious code execution. @@ -400,11 +431,24 @@ def _clean_code(self, code: str) -> str: if isinstance(node, (ast.Import, ast.ImportFrom)): self._check_imports(node) continue + if isinstance(node, ast.FunctionDef) and node.name == "analyze_data": analyze_data_node = node sanitized_analyze_data = self._sanitize_analyze_data(analyze_data_node) + + # Walk inside the function def for used skills + if len(context.skills_manager.skills) > 0: + for node in ast.walk(analyze_data_node): + # Checks for function to get skill name + if isinstance(node, ast.Call) and isinstance( + node.func, ast.Name + ): + function_name = node.func.id + context.skills_manager.add_used_skill(function_name) + new_body.append(sanitized_analyze_data) continue + new_body.append(node) new_tree = ast.Module(body=new_body) diff --git a/pandasai/helpers/output_types/_output_types.py b/pandasai/helpers/output_types/_output_types.py index 7aa78889b..18b9eff42 100644 --- a/pandasai/helpers/output_types/_output_types.py +++ b/pandasai/helpers/output_types/_output_types.py @@ -140,7 +140,7 @@ def template_hint(self): return """- type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) Examples: - { "type": "string", "value": "The highest salary is $9,000." } + { "type": "string", "value": f"The highest salary is {highest_salary}." 
} or { "type": "number", "value": 125 } or diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py new file mode 100644 index 000000000..76585498b --- /dev/null +++ b/pandasai/helpers/query_exec_tracker.py @@ -0,0 +1,283 @@ +import base64 +import os +import time +from typing import Any, List, TypedDict, Union + +import requests +from collections import defaultdict + + +class ResponseType(TypedDict): + type: str + value: Any + + +exec_steps = { + "cache_hit": "Cache Hit", + "_get_prompt": "Generate Prompt", + "generate_code": "Generate Code", + "execute_code": "Code Execution", + "_retry_run_code": "Retry Code Generation", + "parse": "Parse Output", +} + + +class QueryExecTracker: + _query_info: dict + _dataframes: List + _response: ResponseType + _steps: List + _func_exec_count: dict + _success: bool + _server_config: dict + + def __init__( + self, + server_config: Union[dict, None] = None, + ) -> None: + self._success = False + self._start_time = None + self._server_config = server_config + self._query_info = {} + self._is_related_query = True + + def set_related_query(self, flag: bool): + """ + Set Related Query Parameter whether new query is related to the conversation + or not + Args: + flag (bool): boolean to set true if related else false + """ + self._is_related_query = flag + + def add_query_info( + self, conversation_id: str, instance: str, query: str, output_type: str + ): + """ + Adds query information for new track + Args: + conversation_id (str): conversation id + instance (str): instance like Agent or SmartDataframe + query (str): chat query given by user + output_type (str): output type expected by user + """ + self._query_info = { + "conversation_id": str(conversation_id), + "instance": instance, + "query": query, + "output_type": output_type, + "is_related_query": self._is_related_query, + } + + def start_new_track(self): + """ + Resets tracking variables to start new track + """ + self._start_time = time.time() + self._dataframes: List = [] + self._response: ResponseType = {} + self._steps: List = [] + self._query_info = {} + self._func_exec_count: dict = defaultdict(int) + + def add_dataframes(self, dfs: List) -> None: + """ + Add used dataframes for the query to query exec tracker + Args: + dfs (List[SmartDataFrame]): List of dataframes + """ + for df in dfs: + head = df.head_df + self._dataframes.append( + {"headers": head.columns.tolist(), "rows": head.values.tolist()} + ) + + def add_step(self, step: dict) -> None: + """ + Add Custom Step that is performed for additional information + Args: + step (dict): dictionary containing information + """ + self._steps.append(step) + + def execute_func(self, function, *args, **kwargs) -> Any: + """ + Tracks function executions, calculates execution time and prepare data + Args: + function (function): Function that is to be executed + + Returns: + Any: Response return after function execution + """ + start_time = time.time() + + # Get the tag from kwargs if provided, or use the function name as the default + tag = kwargs.pop("tag", function.__name__) + + try: + result = function(*args, **kwargs) + + execution_time = time.time() - start_time + if tag not in exec_steps: + return result + + step_data = self._generate_exec_step(tag, result) + + step_data["success"] = True + step_data["execution_time"] = execution_time + + self._steps.append(step_data) + + return result + + except Exception: + execution_time = time.time() - start_time + self._steps.append( + { + "type": exec_steps[tag], + 
"success": False, + "execution_time": execution_time, + } + ) + raise + + def _generate_exec_step(self, func_name: str, result: Any) -> dict: + """ + Extracts and Generates result + Args: + func_name (str): function name that is executed + result (Any): function output response + + Returns: + dict: dictionary with information about the function execution + """ + + step = {"type": exec_steps[func_name]} + + if func_name == "cache_hit": + step["code_generated"] = result + + elif func_name == "generate_code": + step["code_generated"] = result[0] + step["reasoning"] = result[1] + step["answer"] = result[2] + + elif func_name == "_retry_run_code": + self._func_exec_count["_retry_run_code"] += 1 + + step[ + "type" + ] = f"{exec_steps[func_name]} ({self._func_exec_count['_retry_run_code']})" + step["code_generated"] = result[0] + step["reasoning"] = result[1] + step["answer"] = result[2] + + elif func_name == "_get_prompt": + step["prompt_class"] = result.__class__.__name__ + step["generated_prompt"] = result.to_string() + + elif func_name == "execute_code": + self._response = self._format_response(result) + step["result"] = self._response + + return step + + def _format_response(self, result: ResponseType) -> ResponseType: + """ + Format output response + Args: + result (ResponseType): response returned after execution + + Returns: + ResponseType: formatted response output + """ + formatted_result = {} + if result["type"] == "dataframe": + formatted_result = { + "type": result["type"], + "value": { + "headers": result["value"].columns.tolist(), + "rows": result["value"].values.tolist(), + }, + } + return formatted_result + elif result["type"] == "plot": + with open(result["value"], "rb") as image_file: + image_data = image_file.read() + # Encode the image data to Base64 + base64_image = ( + f"data:image/png;base64,{base64.b64encode(image_data).decode()}" + ) + formatted_result = { + "type": result["type"], + "value": base64_image, + } + + return formatted_result + else: + return result + + def get_summary(self) -> dict: + """ + Returns the summary in json to steps involved in execution of track + Returns: + dict: summary json + """ + if self._start_time is None: + raise RuntimeError("[QueryExecTracker]: Tracking not started") + + execution_time = time.time() - self._start_time + return { + "query_info": self._query_info, + "dataframes": self._dataframes, + "steps": self._steps, + "response": self._response, + "execution_time": execution_time, + "success": self._success, + } + + def get_execution_time(self) -> float: + return time.time() - self._start_time + + def publish(self) -> None: + """ + Publish Query Summary to remote logging server + """ + api_key = None + server_url = None + + if self._server_config is None: + server_url = os.environ.get("LOGGING_SERVER_URL") + api_key = os.environ.get("LOGGING_SERVER_API_KEY") + else: + server_url = self._server_config.get( + "server_url", os.environ.get("LOGGING_SERVER_URL") + ) + api_key = self._server_config.get( + "api_key", os.environ.get("LOGGING_SERVER_API_KEY") + ) + + if api_key is None or server_url is None: + return + + try: + log_data = { + "json_log": self.get_summary(), + } + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.post( + f"{server_url}/api/log/add", json=log_data, headers=headers + ) + if response.status_code != 200: + raise Exception(response.text) + + except Exception as e: + print(f"Exception in APILogger: {e}") + + @property + def success(self) -> bool: + return self._success + + @success.setter + def 
success(self, value: bool): + self._success = value diff --git a/pandasai/helpers/skills_manager.py b/pandasai/helpers/skills_manager.py new file mode 100644 index 000000000..c63c5ff6f --- /dev/null +++ b/pandasai/helpers/skills_manager.py @@ -0,0 +1,103 @@ +from typing import List + +# from pandasai.skills import skill + + +class SkillsManager: + """ + Manages Custom added Skills and tracks used skills for the query + """ + + _skills: List + _used_skills: List[str] + + def __init__(self) -> None: + self._skills = [] + self._used_skills = [] + + def add_skills(self, *skills): + """ + Add skills to the list of skills. If a skill with the same name + already exists, raise an error. + + Args: + *skills: Variable number of skill objects to add. + """ + for skill in skills: + if any( + existing_skill.name == skill.name for existing_skill in self._skills + ): + raise ValueError(f"Skill with name '{skill.name}' already exists.") + + self._skills.extend(skills) + + def skill_exists(self, name: str): + """ + Check if a skill with the given name exists in the list of skills. + + Args: + name (str): The name of the skill to check. + + Returns: + bool: True if a skill with the given name exists, False otherwise. + """ + return any(skill.name == name for skill in self._skills) + + def get_skill_by_func_name(self, name: str): + """ + Get a skill by its name. + + Args: + name (str): The name of the skill to retrieve. + + Returns: + Skill or None: The skill with the given name, or None if not found. + """ + for skill in self._skills: + if skill.name == name: + return skill + + return None + + def add_used_skill(self, skill: str): + if self.skill_exists(skill): + self._used_skills.append(skill) + + def __str__(self) -> str: + """ + Present all skills + Returns: + str: _description_ + """ + skills_repr = "" + for skill in self._skills: + skills_repr = skills_repr + skill.print + + return skills_repr + + def prompt_display(self) -> str: + """ + Displays skills for prompt + """ + if len(self._skills) == 0: + return + + return ( + """ +You can also use the following functions, if relevant: + +""" + + self.__str__() + ) + + @property + def used_skills(self): + return self._used_skills + + @used_skills.setter + def used_skills(self, value): + self._used_skills = value + + @property + def skills(self): + return self._skills diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py index 651edd34b..1ff64d57b 100644 --- a/pandasai/llm/base.py +++ b/pandasai/llm/base.py @@ -120,6 +120,60 @@ def _extract_code(self, response: str, separator: str = "```") -> str: return code + def _extract_tag_text(self, response: str, tag: str) -> str: + """ + Extracts the text between two tags in the response. + + Args: + response (str): Response + tag (str): Tag name + + Returns: + (str or None): Extracted text from the response + """ + + match = re.search( + f"(<{tag}>)(.*)()", + response, + re.DOTALL | re.MULTILINE, + ) + if match: + return match.group(2) + return None + + def _extract_reasoning(self, response: str) -> str: + """ + Extracts the reasoning from the response (wrapped in tags). + + Args: + response (str): Response + + Returns: + (str or None): Extracted reasoning from the response + """ + + return self._extract_tag_text(response, "reasoning") + + def _extract_answer(self, response: str) -> str: + """ + Extracts the answer from the response (wrapped in tags). 
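+        Sentences that mention the temporary chart file (temp_chart.png) are
+        dropped from the response first, so chart paths do not end up in the
+        extracted answer.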
+ + Args: + response (str): Response + + Returns: + (str or None): Extracted answer from the response + """ + + sentences = [ + sentence + for sentence in response.split(". ") + if "temp_chart.png" not in sentence + ] + answer = ". ".join(sentences) + + return self._extract_tag_text(answer, "answer") + @abstractmethod def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ @@ -135,7 +189,7 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ raise MethodNotImplementedError("Call method has not been implemented") - def generate_code(self, instruction: AbstractPrompt) -> str: + def generate_code(self, instruction: AbstractPrompt) -> [str, str, str]: """ Generate the code based on the instruction and the given prompt. @@ -146,8 +200,12 @@ def generate_code(self, instruction: AbstractPrompt) -> str: str: A string of Python code. """ - code = self.call(instruction, suffix="") - return self._extract_code(code) + response = self.call(instruction, suffix="") + return [ + self._extract_code(response), + self._extract_reasoning(response), + self._extract_answer(response), + ] class BaseOpenAI(LLM, ABC): diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py index 430ec37eb..7a692a18e 100644 --- a/pandasai/prompts/base.py +++ b/pandasai/prompts/base.py @@ -11,6 +11,7 @@ class AbstractPrompt(ABC): """ _args: dict = None + _config: dict = None def __init__(self, **kwargs): """ @@ -27,6 +28,9 @@ def __init__(self, **kwargs): def setup(self, **kwargs) -> None: pass + def on_prompt_generation(self) -> None: + pass + def _generate_dataframes(self, dfs): """ Generate the dataframes metadata @@ -58,6 +62,17 @@ def _generate_dataframes(self, dfs): def template(self) -> str: ... + def set_config(self, config): + self._config = config + + def get_config(self, key=None): + if self._config is None: + return None + if key is None: + return self._config + if hasattr(self._config, key): + return getattr(self._config, key) + def set_var(self, var, value): if self._args is None: self._args = {} @@ -72,10 +87,18 @@ def set_vars(self, vars): self._args.update(vars) def to_string(self): + self.on_prompt_generation() + prompt_args = {} for key, value in self._args.items(): if isinstance(value, AbstractPrompt): - value.set_vars({k: v for k, v in self._args.items() if k != key}) + value.set_vars( + { + k: v + for k, v in self._args.items() + if k != key and not isinstance(v, AbstractPrompt) + } + ) prompt_args[key] = value.to_string() else: prompt_args[key] = value diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index 0c6ca22eb..513767d95 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -8,12 +8,13 @@ {conversation} -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. {current_code} -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. - -Return the updated code:""" # noqa: E501 +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. 
+Based on the last message in the conversation: +- return the updated analyze_data function wrapped within ```python ```""" # noqa: E501 from .file_based_prompt import FileBasedPrompt @@ -25,18 +26,24 @@ class CurrentCodePrompt(FileBasedPrompt): _path_to_template = "assets/prompt_templates/current_code.tmpl" +class AdvancedReasoningPrompt(FileBasedPrompt): + """The current code""" + + _path_to_template = "assets/prompt_templates/advanced_reasoning.tmpl" + + +class SimpleReasoningPrompt(FileBasedPrompt): + """The current code""" + + _path_to_template = "assets/prompt_templates/simple_reasoning.tmpl" + + class GeneratePythonCodePrompt(FileBasedPrompt): """Prompt to generate Python code""" _path_to_template = "assets/prompt_templates/generate_python_code.tmpl" def setup(self, **kwargs) -> None: - default_import = "import pandas as pd" - engine_df_name = "pd.DataFrame" - - self.set_var("default_import", default_import) - self.set_var("engine_df_name", engine_df_name) - if "custom_instructions" in kwargs: self._set_instructions(kwargs["custom_instructions"]) else: @@ -44,7 +51,7 @@ def setup(self, **kwargs) -> None: """Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501 +3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 ) if "current_code" in kwargs: @@ -52,6 +59,17 @@ def setup(self, **kwargs) -> None: else: self.set_var("current_code", CurrentCodePrompt()) + def on_prompt_generation(self) -> None: + default_import = "import pandas as pd" + engine_df_name = "pd.DataFrame" + + self.set_var("default_import", default_import) + self.set_var("engine_df_name", engine_df_name) + if self.get_config("use_advanced_reasoning_framework"): + self.set_var("reasoning", AdvancedReasoningPrompt()) + else: + self.set_var("reasoning", SimpleReasoningPrompt()) + def _set_instructions(self, instructions: str): lines = instructions.split("\n") indented_lines = [" " + line for line in lines[1:]] diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index 176167e24..7b89d3898 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, validator, Field -from typing import Optional, List, Any, Dict, Type +from typing import Optional, List, Any, Dict, Type, TypedDict from pandasai.responses import ResponseParser from ..middlewares.base import Middleware @@ -8,12 +8,18 @@ from ..exceptions import LLMNotFoundError +class LogServerConfig(TypedDict): + server_url: str + api_key: str + + class Config(BaseModel): save_logs: bool = True verbose: bool = False enforce_privacy: bool = False enable_cache: bool = True use_error_correction_framework: bool = True + use_advanced_reasoning_framework: bool = False custom_prompts: Dict = Field(default_factory=dict) custom_instructions: Optional[str] = None open_charts: bool = True @@ -26,6 +32,7 @@ class Config(BaseModel): lazy_load_connector: bool = True response_parser: Type[ResponseParser] = None llm: Any = None + log_server: LogServerConfig = None class Config: arbitrary_types_allowed = True diff --git a/pandasai/skills/__init__.py b/pandasai/skills/__init__.py new file mode 100644 
index 000000000..cc82c5b2f --- /dev/null +++ b/pandasai/skills/__init__.py @@ -0,0 +1,32 @@ +import inspect + + +def skill(skill_function): + def wrapped_function(*args, **kwargs): + return skill_function(*args, **kwargs) + + wrapped_function.name = skill_function.__name__ + wrapped_function.func_def = ( + """def pandasai.skills.{funcion_name}{signature}""".format( + funcion_name=wrapped_function.name, + signature=str(inspect.signature(skill_function)), + ) + ) + + doc_string = skill_function.__doc__ + + wrapped_function.print = ( + """ + +{signature} +{doc_string} + +""" + ).format( + signature=wrapped_function.func_def, + doc_string=""" \"\"\"{0}\n \"\"\"""".format(doc_string) + if doc_string is not None + else "", + ) + + return wrapped_function diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 3672e7d01..0e59cf9d6 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -26,6 +26,7 @@ import pydantic from pandasai.helpers.df_validator import DfValidator +from pandasai.skills import skill from ..smart_datalake import SmartDatalake from ..schemas.df_config import Config @@ -302,6 +303,9 @@ def __init__( self._table_description = description self._lake = SmartDatalake([self], config, logger) + # set instance type in SmartDataLake + self._lake.set_instance_type(self.__class__.__name__) + # If no name is provided, use the fallback name provided the connector if self._table_name is None and self.connector: self._table_name = self.connector.fallback_name @@ -319,6 +323,12 @@ def add_middlewares(self, *middlewares: Optional[Middleware]): """ self.lake.add_middlewares(*middlewares) + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self.lake.add_skills(*skills) + def chat(self, query: str, output_type: Optional[str] = None): """ Run a query on the dataframe. @@ -682,6 +692,14 @@ def sample_head(self): data = StringIO(self._sample_head) return pd.read_csv(data) + @property + def last_reasoning(self): + return self.lake.last_reasoning + + @property + def last_answer(self): + return self.lake.last_answer + @sample_head.setter def sample_head(self, sample_head: pd.DataFrame): self._sample_head = sample_head.to_csv(index=False) diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 02e29bce9..34285b5a1 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -17,12 +17,15 @@ # The average loan amount is $15,000. 
``` """ - -import time import uuid import logging import os import traceback +from pandasai.helpers.skills_manager import SkillsManager + +from pandasai.skills import skill + +from pandasai.helpers.query_exec_tracker import QueryExecTracker from ..helpers.output_types import output_type_factory from pandasai.responses.context import Context @@ -38,10 +41,11 @@ from ..prompts.correct_error_prompt import CorrectErrorPrompt from ..prompts.generate_python_code import GeneratePythonCodePrompt from typing import Union, List, Any, Type, Optional -from ..helpers.code_manager import CodeManager +from ..helpers.code_manager import CodeExecutionContext, CodeManager from ..middlewares.base import Middleware from ..helpers.df_info import DataFrameType from ..helpers.path import find_project_root +from ..exceptions import AdvancedReasoningDisabledError class SmartDatalake: @@ -50,12 +54,17 @@ class SmartDatalake: _llm: LLM _cache: Cache = None _logger: Logger - _start_time: float _last_prompt_id: uuid.UUID + _conversation_id: uuid.UUID _code_manager: CodeManager _memory: Memory + _skills: SkillsManager + _instance: str + _query_exec_tracker: QueryExecTracker _last_code_generated: str = None + _last_reasoning: str = None + _last_answer: str = None _last_result: str = None _last_error: str = None @@ -98,6 +107,8 @@ def __init__( logger=self.logger, ) + self._skills = SkillsManager() + if cache: self._cache = cache elif self._config.enable_cache: @@ -110,6 +121,20 @@ def __init__( else: self._response_parser = ResponseParser(context) + self._conversation_id = uuid.uuid4() + + self._instance = self.__class__.__name__ + + self._query_exec_tracker = QueryExecTracker( + server_config=self._config.log_server, + ) + + def set_instance_type(self, type: str): + self._instance = type + + def is_related_query(self, flag: bool): + self._query_exec_tracker.set_related_query(flag) + def initialize(self): """Initialize the SmartDatalake""" @@ -190,10 +215,11 @@ def add_middlewares(self, *middlewares: Optional[Middleware]): """ self._code_manager.add_middlewares(*middlewares) - def _start_timer(self): - """Start the timer""" - - self._start_time = time.time() + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self._skills.add_skills(*skills) def _assign_prompt_id(self): """Assign a prompt ID""" @@ -228,10 +254,16 @@ def _get_prompt( prompt = custom_prompt if custom_prompt else default_prompt() # set default values for the prompt + prompt.set_config(self._config) if "dfs" not in default_values: prompt.set_var("dfs", self._dfs) if "conversation" not in default_values: prompt.set_var("conversation", self._memory.get_conversation()) + + # Adds the skills to prompt if exist else display nothing + skills_prompt = self._skills.prompt_display() + prompt.set_var("skills", skills_prompt if skills_prompt is not None else "") + for key, value in default_values.items(): prompt.set_var(key, value) @@ -279,13 +311,19 @@ def chat(self, query: str, output_type: Optional[str] = None): ValueError: If the query is empty """ - self._start_timer() + self._query_exec_tracker.start_new_track() self.logger.log(f"Question: {query}") self.logger.log(f"Running PandasAI with {self._llm.type} LLM...") self._assign_prompt_id() + self._query_exec_tracker.add_query_info( + self._conversation_id, self._instance, query, output_type + ) + + self._query_exec_tracker.add_dataframes(self._dfs) + self._memory.add(query, True) try: @@ -297,7 +335,10 @@ def chat(self, query: str, output_type: Optional[str] = None): and 
self._cache.get(self._get_cache_key()) ): self.logger.log("Using cached response") - code = self._cache.get(self._get_cache_key()) + code = self._query_exec_tracker.execute_func( + self._cache.get, self._get_cache_key(), tag="cache_hit" + ) + else: default_values = { # TODO: find a better way to determine the engine, @@ -312,13 +353,21 @@ def chat(self, query: str, output_type: Optional[str] = None): ): default_values["current_code"] = self._last_code_generated - generate_python_code_instruction = self._get_prompt( - "generate_python_code", - default_prompt=GeneratePythonCodePrompt, - default_values=default_values, + generate_python_code_instruction = ( + self._query_exec_tracker.execute_func( + self._get_prompt, + "generate_python_code", + default_prompt=GeneratePythonCodePrompt, + default_values=default_values, + ) + ) + + [code, reasoning, answer] = self._query_exec_tracker.execute_func( + self._llm.generate_code, generate_python_code_instruction ) - code = self._llm.generate_code(generate_python_code_instruction) + self.last_reasoning = reasoning + self.last_answer = answer if self._config.enable_cache and self._cache: self._cache.set(self._get_cache_key(), code) @@ -341,10 +390,12 @@ def chat(self, query: str, output_type: Optional[str] = None): while retry_count < self._config.max_retries: try: # Execute the code + context = CodeExecutionContext(self._last_prompt_id, self._skills) result = self._code_manager.execute_code( code=code_to_run, - prompt_id=self._last_prompt_id, + context=context, ) + break except Exception as e: if ( @@ -362,7 +413,13 @@ def chat(self, query: str, output_type: Optional[str] = None): ) traceback_error = traceback.format_exc() - code_to_run = self._retry_run_code(code, traceback_error) + [ + code_to_run, + reasoning, + answer, + ] = self._query_exec_tracker.execute_func( + self._retry_run_code, code, traceback_error + ) if result is not None: if isinstance(result, dict): @@ -371,22 +428,51 @@ def chat(self, query: str, output_type: Optional[str] = None): self.logger.log( "\n".join(validation_logs), level=logging.WARNING ) + self._query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": False, + "message": "Output Validation Failed", + } + ) + else: + self._query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": True, + "message": "Output Validation Successful", + } + ) self.last_result = result self.logger.log(f"Answer: {result}") + except Exception as exception: self.last_error = str(exception) + self._query_exec_tracker.success = False + self._query_exec_tracker.publish() + return ( "Unfortunately, I was not able to answer your question, " "because of the following error:\n" f"\n{exception}\n" ) - self.logger.log(f"Executed in: {time.time() - self._start_time}s") + self.logger.log( + f"Executed in: {self._query_exec_tracker.get_execution_time()}s" + ) self._add_result_to_memory(result) - return self._response_parser.parse(result) + result = self._query_exec_tracker.execute_func( + self._response_parser.parse, result + ) + + self._query_exec_tracker.success = True + + self._query_exec_tracker.publish() + + return result def _add_result_to_memory(self, result: dict): """ @@ -403,7 +489,7 @@ def _add_result_to_memory(self, result: dict): elif result["type"] == "dataframe" or result["type"] == "plot": self._memory.add("Ok here it is", False) - def _retry_run_code(self, code: str, e: Exception): + def _retry_run_code(self, code: str, e: Exception) -> List: """ A method to retry the code execution with error correction 
framework. @@ -428,16 +514,18 @@ def _retry_run_code(self, code: str, e: Exception): default_values=default_values, ) - code = self._llm.generate_code(error_correcting_instruction) + result = self._llm.generate_code(error_correcting_instruction) if self._config.callback is not None: - self._config.callback.on_code(code) - return code + self._config.callback.on_code(result[0]) + + return result def clear_memory(self): """ Clears the memory """ self._memory.clear() + self._conversation_id = uuid.uuid4() @property def engine(self): @@ -595,6 +683,30 @@ def last_code_generated(self, last_code_generated: str): def last_code_executed(self): return self._code_manager.last_code_executed + @property + def last_reasoning(self): + if not self._config.use_advanced_reasoning_framework: + raise AdvancedReasoningDisabledError( + "You need to enable the advanced reasoning framework" + ) + return self._last_reasoning + + @last_reasoning.setter + def last_reasoning(self, last_reasoning: str): + self._last_reasoning = last_reasoning + + @property + def last_answer(self): + if not self._config.use_advanced_reasoning_framework: + raise AdvancedReasoningDisabledError( + "You need to enable the advanced reasoning framework" + ) + return self._last_answer + + @last_answer.setter + def last_answer(self, last_answer: str): + self._last_answer = last_answer + @property def last_result(self): return self._last_result @@ -618,3 +730,7 @@ def dfs(self): @property def memory(self): return self._memory + + @property + def instance(self): + return self._instance diff --git a/tests/connectors/test_sqlite.py b/tests/connectors/test_sqlite.py new file mode 100644 index 000000000..f55d48011 --- /dev/null +++ b/tests/connectors/test_sqlite.py @@ -0,0 +1,85 @@ +import unittest +import pandas as pd +from unittest.mock import Mock,patch +from pandasai.connectors.base import SqliteConnectorConfig +from pandasai.connectors import SqliteConnector + +class TestSqliteConnector(unittest.TestCase): + @patch("pandasai.connectors.sql.create_engine",autospec=True) + def setUp(self,mock_create_engine) -> None: + self.mock_engine = Mock() + self.mock_connection = Mock() + self.mock_engine.connect.return_value = self.mock_connection + mock_create_engine.return_value = self.mock_engine + + self.config = SqliteConnectorConfig( + dialect="sqlite", + database="path_todb.db", + table="yourtable" + ).dict() + + self.connector = SqliteConnector(self.config) + + @patch("pandasai.connectors.SqliteConnector._load_connector_config") + @patch("pandasai.connectors.SqliteConnector._init_connection") + def test_constructor_and_properties( + self, mock_load_connector_config, mock_init_connection + ): + # Test constructor and properties + self.assertEqual(self.connector._config, self.config) + self.assertEqual(self.connector._engine, self.mock_engine) + self.assertEqual(self.connector._connection, self.mock_connection) + self.assertEqual(self.connector._cache_interval, 600) + SqliteConnector(self.config) + mock_load_connector_config.assert_called() + mock_init_connection.assert_called() + + def test_repr_method(self): + # Test __repr__ method + expected_repr = ( + "" + ) + self.assertEqual(repr(self.connector), expected_repr) + + @patch("pandasai.connectors.sql.pd.read_sql", autospec=True) + def test_head_method(self, mock_read_sql): + expected_data = pd.DataFrame({"Column1": [1, 2, 3], "Column2": [4, 5, 6]}) + mock_read_sql.return_value = expected_data + head_data = self.connector.head() + pd.testing.assert_frame_equal(head_data, expected_data) + + def 
test_rows_count_property(self): + # Test rows_count property + self.connector._rows_count = None + self.mock_connection.execute.return_value.fetchone.return_value = ( + 50, + ) # Sample rows count + rows_count = self.connector.rows_count + self.assertEqual(rows_count, 50) + + def test_columns_count_property(self): + # Test columns_count property + self.connector._columns_count = None + mock_df = Mock() + mock_df.columns = ["Column1", "Column2"] + self.connector.head = Mock(return_value=mock_df) + columns_count = self.connector.columns_count + self.assertEqual(columns_count, 2) + + def test_column_hash_property(self): + # Test column_hash property + mock_df = Mock() + mock_df.columns = ["Column1", "Column2"] + self.connector.head = Mock(return_value=mock_df) + column_hash = self.connector.column_hash + self.assertIsNotNone(column_hash) + self.assertEqual( + column_hash, + "0d045cff164deef81e24b0ed165b7c9c2789789f013902115316cde9d214fe63", + ) + + def test_fallback_name_property(self): + # Test fallback_name property + fallback_name = self.connector.fallback_name + self.assertEqual(fallback_name, "yourtable") \ No newline at end of file diff --git a/tests/llms/test_base_llm.py b/tests/llms/test_base_llm.py index 8f044adc3..b6af1746f 100644 --- a/tests/llms/test_base_llm.py +++ b/tests/llms/test_base_llm.py @@ -64,3 +64,18 @@ def test_extract_code(self): """ assert LLM()._extract_code(code) == "print('Hello World')" + + def test_extract_answer(self): + llm = LLM() + response = "This is the answer." + expected_answer = "This is the answer." + assert llm._extract_answer(response) == expected_answer + + def test_extract_answer_with_temp_chart(self): + llm = LLM() + response = ( + "This is the answer. It returns a temp_chart.png. " + "But it shouldn't." + ) + expected_answer = "This is the answer. But it shouldn't." + assert llm._extract_answer(response) == expected_answer diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index 5140afb53..05ddb764e 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -51,6 +51,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint): prompt.set_var("conversation", "Question") prompt.set_var("save_charts_path", save_charts_path) prompt.set_var("output_type_hint", output_type_hint) + prompt.set_var("skills", "") expected_prompt_content = f'''You are provided with the following pandas DataFrames: @@ -65,7 +66,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint): Question -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required import pandas as pd @@ -75,15 +76,76 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) 
At the end, return a dictionary of: {output_type_hint} """ ``` -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. +Based on the last message in the conversation: +- return the updated analyze_data function wrapped within ```python ```''' # noqa E501 + actual_prompt_content = prompt.to_string() + if sys.platform.startswith("win"): + actual_prompt_content = actual_prompt_content.replace("\r\n", "\n") + assert actual_prompt_content == expected_prompt_content + + def test_advanced_reasoning_prompt(self): + """ + Test a prompt with advanced reasoning framework + """ + + llm = FakeLLM("plt.show()") + dfs = [ + SmartDataframe( + pd.DataFrame({"a": [1], "b": [4]}), + config={"llm": llm, "use_advanced_reasoning_framework": True}, + ) + ] + prompt = GeneratePythonCodePrompt() + prompt.set_config(dfs[0]._lake.config) + prompt.set_var("dfs", dfs) + prompt.set_var("conversation", "Question") + prompt.set_var("save_charts_path", "") + prompt.set_var("output_type_hint", "") + prompt.set_var("skills", "") + + expected_prompt_content = f'''You are provided with the following pandas DataFrames: + + +Dataframe dfs[0], with 1 rows and 2 columns. +This is the metadata of the dataframe dfs[0]: +a,b +1,4 + + + +Question + + +This is the initial python function. Do not change the params. Given the context, use the right dataframes. +```python +# TODO import all the dependencies required +import pandas as pd + +def analyze_data(dfs: list[pd.DataFrame]) -> dict: + """ + Analyze the data, using the provided dataframes (`dfs`). + 1. Prepare: Preprocessing and cleaning data if necessary + 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) + At the end, return a dictionary of: + + """ +``` -Return the updated code:''' # noqa E501 +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. +Based on the last message in the conversation: +- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags. +- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value or the chart itself (it will be calculated later). +- return the updated analyze_data function wrapped within ```python ```''' # noqa E501 actual_prompt_content = prompt.to_string() if sys.platform.startswith("win"): actual_prompt_content = actual_prompt_content.replace("\r\n", "\n") @@ -94,7 +156,7 @@ def test_custom_instructions(self): 1. Load: Load the data from a file or database 2. Prepare: Preprocessing and cleaning data if necessary 3. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -4. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501 +4. 
Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 prompt = GeneratePythonCodePrompt(custom_instructions=custom_instructions) actual_instructions = prompt._args["instructions"] @@ -105,5 +167,5 @@ def test_custom_instructions(self): 1. Load: Load the data from a file or database 2. Prepare: Preprocessing and cleaning data if necessary 3. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 4. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501 + 4. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 ) diff --git a/tests/skills/test_skills.py b/tests/skills/test_skills.py new file mode 100644 index 000000000..ed979951c --- /dev/null +++ b/tests/skills/test_skills.py @@ -0,0 +1,347 @@ +from typing import Optional +from unittest.mock import MagicMock, Mock, patch +import uuid +import pandas as pd + +import pytest +from pandasai.agent import Agent +from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager + +from pandasai.helpers.skills_manager import SkillsManager +from pandasai.llm.fake import FakeLLM +from pandasai.skills import skill +from pandasai.smart_dataframe import SmartDataframe + + +class TestSkills: + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False}) + + @pytest.fixture + def code_manager(self, smart_dataframe: SmartDataframe): + return smart_dataframe.lake._code_manager + + @pytest.fixture + def exec_context(self) -> MagicMock: + context = MagicMock(spec=CodeExecutionContext) + return context + + @pytest.fixture + def agent(self, llm, sample_df): + return Agent(sample_df, config={"llm": llm, "enable_cache": False}) + + def test_add_skills(self): + skills_manager = SkillsManager() + skill1 = Mock(name="SkillA", print="SkillA Print") + skill2 = Mock(name="SkillB", print="SkillB Print") + skills_manager.add_skills(skill1, skill2) + + # Ensure that skills are added + assert skill1 in skills_manager.skills + assert skill2 in skills_manager.skills + + # Test that adding a skill with the same name raises an error + try: + skills_manager.add_skills(skill1) + except ValueError as e: + assert str(e) == f"Skill with name '{skill1.name}' already exists." 
+ else: + assert False, "Expected ValueError" + + def test_skill_exists(self): + skills_manager = SkillsManager() + skill1 = MagicMock() + skill2 = MagicMock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + assert skills_manager.skill_exists("SkillA") + assert skills_manager.skill_exists("SkillB") + + # Test that a non-existing skill is not found + assert not skills_manager.skill_exists("SkillC") + + def test_get_skill_by_func_name(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test that you can retrieve a skill by its function name + retrieved_skill = skills_manager.get_skill_by_func_name("SkillA") + assert retrieved_skill == skill1 + + # Test that a non-existing skill returns None + retrieved_skill = skills_manager.get_skill_by_func_name("SkillC") + assert retrieved_skill is None + + def test_add_used_skill(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test adding used skills + skills_manager.add_used_skill("SkillA") + skills_manager.add_used_skill("SkillB") + + # Ensure used skills are added to the used_skills list + assert "SkillA" in skills_manager.used_skills + assert "SkillB" in skills_manager.used_skills + + def test_prompt_display(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skill1.print = "SkillA" + skill2.print = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test prompt_display method when skills exist + prompt = skills_manager.prompt_display() + assert "You can also use the following functions" in prompt + + # Test prompt_display method when no skills exist + skills_manager._skills = [] + prompt = skills_manager.prompt_display() + assert prompt is None + + @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)") + def test_skill_decorator(self, mock_inspect_signature): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + # Test the wrapped functions + assert skill_a() == "SkillA Result" + assert skill_b() == "SkillB Result" + + # Test the additional attributes added by the decorator + assert skill_a.name == "skill_a" + assert skill_b.name == "skill_b" + + assert skill_a.func_def == "def pandasai.skills.skill_a(a, b, c)" + assert skill_b.func_def == "def pandasai.skills.skill_b(a, b, c)" + + assert ( + skill_a.print + == """\n\ndef pandasai.skills.skill_a(a, b, c)\n\n\n""" # noqa: E501 + ) + assert ( + skill_b.print + == """\n\ndef pandasai.skills.skill_b(a, b, c)\n\n\n""" # noqa: E501 + ) + + @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)") + def test_skill_decorator_test_codc(self, llm): + df = pd.DataFrame({"country": []}) + df = SmartDataframe(df, config={"llm": llm, "enable_cache": False}) + + # Define skills using the decorator + @skill + def plot_salaries(*args, **kwargs): + """ + Test skill A + Args: + arg(str) + """ + return "SkillA Result" + + function_def = """ + Test skill A + Args: + arg(str) +""" # noqa: E501 + + assert function_def in plot_salaries.print + + def test_add_skills_with_agent(self, agent: Agent): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA 
Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + agent.add_skills(skill_a) + assert len(agent._lake._skills.skills) == 1 + + agent._lake._skills._skills = [] + agent.add_skills(skill_a, skill_b) + assert len(agent._lake._skills.skills) == 2 + + def test_add_skills_with_smartDataframe(self, smart_dataframe: SmartDataframe): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + smart_dataframe.add_skills(skill_a) + assert len(smart_dataframe._lake._skills.skills) == 1 + + smart_dataframe._lake._skills._skills = [] + smart_dataframe.add_skills(skill_a, skill_b) + assert len(smart_dataframe._lake._skills.skills) == 2 + + def test_run_prompt(self, llm): + df = pd.DataFrame({"country": []}) + df = SmartDataframe(df, config={"llm": llm, "enable_cache": False}) + + function_def = """ + +def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str + + +""" # noqa: E501 + + @skill + def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + df.add_skills(plot_salaries) + + df.chat("How many countries are in the dataframe?") + last_prompt = df.last_prompt + assert function_def in last_prompt + + def test_run_prompt_agent(self, agent): + function_def = """ + +def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str + + +""" # noqa: E501 + + @skill + def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + agent.add_skills(plot_salaries) + + agent.chat("How many countries are in the dataframe?") + last_prompt = agent._lake.last_prompt + + assert function_def in last_prompt + + def test_run_prompt_without_skills(self, agent): + agent.chat("How many countries are in the dataframe?") + + last_prompt = agent._lake.last_prompt + + assert "" not in last_prompt + assert "" not in last_prompt + assert ( + "You can also use the following functions, if relevant:" not in last_prompt + ) + + def test_code_exec_with_skills_no_use( + self, code_manager: CodeManager, exec_context: MagicMock + ): + code = """def analyze_data(dfs): + return {'type': 'number', 'value': 1 + 1}""" + skill1 = MagicMock() + skill1.name = "SkillA" + exec_context._skills_manager._skills = [skill1] + code_manager.execute_code(code, exec_context) + assert len(exec_context._skills_manager.used_skills) == 0 + + def test_code_exec_with_skills(self, code_manager: CodeManager): + code = """def analyze_data(dfs): + plot_salaries() + return {'type': 'number', 'value': 1 + 1}""" + + @skill + def plot_salaries() -> str: + return "plot_salaries" + + code_manager._middlewares = [] + + sm = SkillsManager() + sm.add_skills(plot_salaries) + exec_context = CodeExecutionContext(uuid.uuid4(), sm) + code_manager.execute_code(code, exec_context) + + assert len(exec_context._skills_manager.used_skills) == 1 + assert exec_context._skills_manager.used_skills[0] == "plot_salaries" diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index b7a21cf6c..0df494429 100644 --- 
a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -1,7 +1,6 @@ """Unit tests for the CodeManager class""" -import uuid from typing import Optional -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pandas as pd import pytest @@ -11,7 +10,7 @@ from pandasai.smart_dataframe import SmartDataframe -from pandasai.helpers.code_manager import CodeManager +from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager class TestCodeManager: @@ -72,23 +71,33 @@ def smart_dataframe(self, llm, sample_df): def code_manager(self, smart_dataframe: SmartDataframe): return smart_dataframe.lake._code_manager - def test_run_code_for_calculations(self, code_manager: CodeManager): + @pytest.fixture + def exec_context(self) -> MagicMock: + context = MagicMock(spec=CodeExecutionContext) + return context + + def test_run_code_for_calculations( + self, code_manager: CodeManager, exec_context: MagicMock + ): code = """def analyze_data(dfs): return {'type': 'number', 'value': 1 + 1}""" - - assert code_manager.execute_code(code, uuid.uuid4())["value"] == 2 + assert code_manager.execute_code(code, exec_context)["value"] == 2 assert code_manager.last_code_executed == code - def test_run_code_invalid_code(self, code_manager: CodeManager): + def test_run_code_invalid_code( + self, code_manager: CodeManager, exec_context: MagicMock + ): with pytest.raises(Exception): # noinspection PyStatementEffect - code_manager.execute_code("1+ ", uuid.uuid4())["value"] + code_manager.execute_code("1+ ", exec_context)["value"] - def test_clean_code_remove_builtins(self, code_manager: CodeManager): + def test_clean_code_remove_builtins( + self, code_manager: CodeManager, exec_context: MagicMock + ): builtins_code = """import set def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" - assert code_manager.execute_code(builtins_code, uuid.uuid4())["value"] == { + assert code_manager.execute_code(builtins_code, exec_context)["value"] == { 1, 2, 3, @@ -99,44 +108,57 @@ def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" ) - def test_clean_code_removes_jailbreak_code(self, code_manager: CodeManager): + def test_clean_code_removes_jailbreak_code( + self, code_manager: CodeManager, exec_context: MagicMock + ): malicious_code = """def analyze_data(dfs): __builtins__['str'].__class__.__mro__[-1].__subclasses__()[140].__init__.__globals__['system']('ls') print('hello world')""" assert ( - code_manager._clean_code(malicious_code) + code_manager._clean_code(malicious_code, exec_context) == """def analyze_data(dfs): print('hello world')""" ) - def test_clean_code_remove_environment_defaults(self, code_manager: CodeManager): + def test_clean_code_remove_environment_defaults( + self, code_manager: CodeManager, exec_context: MagicMock + ): pandas_code = """import pandas as pd print('hello world') """ - assert code_manager._clean_code(pandas_code) == "print('hello world')" + assert ( + code_manager._clean_code(pandas_code, exec_context) + == "print('hello world')" + ) - def test_clean_code_whitelist_import(self, code_manager: CodeManager): + def test_clean_code_whitelist_import( + self, code_manager: CodeManager, exec_context: MagicMock + ): """Test that an installed whitelisted library is added to the environment.""" safe_code = """ import numpy as np np.array() """ - assert code_manager._clean_code(safe_code) == "np.array()" + assert code_manager._clean_code(safe_code, exec_context) == "np.array()" - def 
test_clean_code_raise_bad_import_error(self, code_manager: CodeManager): + def test_clean_code_raise_bad_import_error( + self, code_manager: CodeManager, exec_context: MagicMock + ): malicious_code = """ import os print(os.listdir()) """ with pytest.raises(BadImportError): - code_manager.execute_code(malicious_code, uuid.uuid4()) + code_manager.execute_code(malicious_code, exec_context) - def test_remove_dfs_overwrites(self, code_manager: CodeManager): + def test_remove_dfs_overwrites( + self, code_manager: CodeManager, exec_context: MagicMock + ): hallucinated_code = """def analyze_data(dfs): dfs = [pd.DataFrame([1,2,3])] print(dfs)""" assert ( - code_manager._clean_code(hallucinated_code) + code_manager._clean_code(hallucinated_code, exec_context) == """def analyze_data(dfs): print(dfs)""" ) @@ -147,6 +169,7 @@ def test_exception_handling( code_manager.execute_code = Mock( side_effect=NoCodeFoundError("No code found in the answer.") ) + code_manager.execute_code.__name__ = "execute_code" result = smart_dataframe.chat("How many countries are in the dataframe?") assert result == ( @@ -156,7 +179,9 @@ def test_exception_handling( ) assert smart_dataframe.last_error == "No code found in the answer." - def test_custom_whitelisted_dependencies(self, code_manager: CodeManager, llm): + def test_custom_whitelisted_dependencies( + self, code_manager: CodeManager, llm, exec_context: MagicMock + ): code = """ import my_custom_library def analyze_data(dfs: list): @@ -165,11 +190,11 @@ def analyze_data(dfs: list): llm._output = code with pytest.raises(BadImportError): - code_manager._clean_code(code) + code_manager._clean_code(code, exec_context) code_manager._config.custom_whitelisted_dependencies = ["my_custom_library"] assert ( - code_manager._clean_code(code) + code_manager._clean_code(code, exec_context) == """def analyze_data(dfs: list): my_custom_library.do_something()""" ) diff --git a/tests/test_query_tracker.py b/tests/test_query_tracker.py new file mode 100644 index 000000000..342ee4bda --- /dev/null +++ b/tests/test_query_tracker.py @@ -0,0 +1,541 @@ +import os +import time +from typing import Optional +from unittest.mock import MagicMock, Mock, patch +import pandas as pd +import pytest + +from pandasai.helpers.query_exec_tracker import QueryExecTracker +from pandasai.llm.fake import FakeLLM +from pandasai.smart_dataframe import SmartDataframe +from unittest import TestCase + + +assert_almost_equal = TestCase().assertAlmostEqual + + +class TestQueryExecTracker: + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False}) + + @pytest.fixture + def smart_datalake(self, smart_dataframe: SmartDataframe): + return smart_dataframe.lake + + @pytest.fixture + def tracker(self): + tracker = QueryExecTracker() + tracker.start_new_track() + tracker.add_query_info( + conversation_id="123", + 
instance="SmartDatalake", + query="which country has the highest GDP?", + output_type="json", + ) + return tracker + + def test_add_dataframes( + self, smart_dataframe: SmartDataframe, tracker: QueryExecTracker + ): + # Add the dataframe to the tracker + tracker._dataframes = [] + tracker.add_dataframes([smart_dataframe]) + + # Check if the dataframe was added correctly + assert len(tracker._dataframes) == 1 + assert len(tracker._dataframes[0]["headers"]) == 3 + assert len(tracker._dataframes[0]["rows"]) == 5 + + def test_add_step(self, tracker: QueryExecTracker): + # Create a sample step + step = {"type": "CustomStep", "description": "This is a custom step."} + + tracker._steps = [] + # Add the step to the tracker + tracker.add_step(step) + + # Check if the step was added correctly + assert len(tracker._steps) == 1 + assert tracker._steps[0] == step + + def test_format_response_dataframe( + self, tracker: QueryExecTracker, sample_df: pd.DataFrame + ): + # Create a sample ResponseType for a dataframe + response = {"type": "dataframe", "value": sample_df} + + # Format the response using _format_response + formatted_response = tracker._format_response(response) + + # Check if the response is formatted correctly + assert formatted_response["type"] == "dataframe" + assert len(formatted_response["value"]["headers"]) == 3 + assert len(formatted_response["value"]["rows"]) == 10 + + def test_format_response_other_type(self, tracker: QueryExecTracker): + # Create a sample ResponseType for a non-dataframe response + response = { + "type": "other_type", + "value": "SomeValue", + } + + # Format the response using _format_response + formatted_response = tracker._format_response(response) + + # Check if the response is formatted correctly + assert formatted_response["type"] == "other_type" + assert formatted_response["value"] == "SomeValue" + + def test_get_summary(self): + # Execute a mock function to generate some steps and response + def mock_function(*args, **kwargs): + return "Mock Result" + + tracker = QueryExecTracker() + + tracker.start_new_track() + + tracker.add_query_info( + conversation_id="123", + instance="SmartDatalake", + query="which country has the highest GDP?", + output_type="json", + ) + + # Get the summary + summary = tracker.get_summary() + + tracker.execute_func(mock_function, tag="custom_tag") + + # Check if the summary contains the expected keys + assert "query_info" in summary + assert "dataframes" in summary + assert "steps" in summary + assert "response" in summary + assert "execution_time" in summary + assert "is_related_query" in summary["query_info"] + + def test_related_query_in_summary(self): + # Execute a mock function to generate some steps and response + def mock_function(*args, **kwargs): + return "Mock Result" + + tracker = QueryExecTracker() + + tracker.set_related_query(False) + + tracker.start_new_track() + + tracker.add_query_info( + conversation_id="123", + instance="SmartDatalake", + query="which country has the highest GDP?", + output_type="json", + ) + + # Get the summary + summary = tracker.get_summary() + + tracker.execute_func(mock_function, tag="custom_tag") + + # Check if the summary contains the expected keys + assert "is_related_query" in summary["query_info"] + assert not summary["query_info"]["is_related_query"] + + def test_get_execution_time(self, tracker: QueryExecTracker): + def mock_function(*args, **kwargs): + time.sleep(1) + return "Mock Result" + + # Sleep for a while to simulate execution time + with patch("time.time", return_value=0): + 
tracker.execute_func(mock_function, tag="cache_hit") + + # Get the execution time + execution_time = tracker.get_execution_time() + + # Check if the execution time is approximately 1 second + assert_almost_equal(execution_time, 1.0, delta=0.3) + + def test_execute_func_success(self, tracker: QueryExecTracker): + tracker._steps = [] + + # Create a mock function + mock_return_value = Mock() + mock_return_value.to_string = Mock() + mock_return_value.to_string.return_value = "Mock Result" + + mock_func = Mock() + mock_func.return_value = mock_return_value + mock_func.__name__ = "_get_prompt" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="_get_prompt") + + # Check if the result is as expected + assert result.to_string() == "Mock Result" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert step["type"] == "Generate Prompt" + assert step["success"] is True + + def test_execute_func_failure(self, tracker: QueryExecTracker): + # Create a mock function that raises an exception + def mock_function(*args, **kwargs): + raise Exception("Mock Exception") + + with pytest.raises(Exception): + tracker.execute_func(mock_function, tag="custom_tag") + + def test_execute_func_cache_hit(self, tracker: QueryExecTracker): + tracker._steps = [] + + mock_func = Mock() + mock_func.return_value = "code" + mock_func.__name__ = "get" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="cache_hit") + + # Check if the result is as expected + assert result == "code" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert "code_generated" in step + assert step["type"] == "Cache Hit" + assert step["success"] is True + + def test_execute_func_generate_code(self, tracker: QueryExecTracker): + tracker._steps = [] + + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + mock_func.__name__ = "generate_code" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="generate_code") + + # Check if the result is as expected + assert result == "code" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert "code_generated" in step + assert step["type"] == "Generate Code" + assert step["success"] is True + + def test_execute_func_re_rerun_code(self, tracker: QueryExecTracker): + tracker._steps = [] + + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + mock_func.__name__ = "_retry_run_code" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func) + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func) + + # Check if the result is as expected + assert result == "code" + # Check if the step was added correctly + assert len(tracker._steps) == 2 + step = tracker._steps[0] + assert "code_generated" in step + assert step["type"] == "Retry Code Generation (1)" + assert step["success"] is True + + # Check second step as well + step2 = tracker._steps[1] + assert "code_generated" in step2 + assert step2["type"] == "Retry Code Generation (2)" + assert step2["success"] is True + + def test_execute_func_execute_code_success( + self, sample_df: pd.DataFrame, tracker: QueryExecTracker + ): + tracker._steps = [] + + mock_func = Mock() + mock_func.return_value = {"type": "dataframe", "value": sample_df} + 
mock_func.__name__ = "execute_code" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func) + + # Check if the result is as expected + assert result["type"] == "dataframe" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert "result" in step + assert step["type"] == "Code Execution" + assert step["success"] is True + + def test_execute_func_execute_code_fail( + self, sample_df: pd.DataFrame, tracker: QueryExecTracker + ): + tracker._steps = [] + + mock_func = Mock() + mock_func.side_effect = Exception("Mock Exception") + mock_func.__name__ = "execute_code" + + with pytest.raises(Exception): + tracker.execute_func(mock_func) + + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert step["type"] == "Code Execution" + assert step["success"] is False + + def test_publish_method_with_server_key(self, tracker: QueryExecTracker): + # Define a mock summary function + def mock_get_summary(): + return "Test summary data" + + # Mock the server_config + tracker._server_config = { + "server_url": "http://custom-server", + "api_key": "custom-api-key", + } + + # Set the get_summary method to your mock + tracker.get_summary = mock_get_summary + + # Mock the requests.post method + mock_response = MagicMock() + mock_response.status_code = 200 + type(mock_response).text = "Response text" + # Mock the requests.post method + with patch("requests.post", return_value=mock_response) as mock_post: + # Call the publish method + result = tracker.publish() + + # Check that requests.post was called with the expected parameters + mock_post.assert_called_with( + "http://custom-server/api/log/add", + json={"json_log": "Test summary data"}, + headers={"Authorization": "Bearer custom-api-key"}, + ) + + # Check the result + assert result is None # The function should return None + + def test_publish_method_with_no_config(self, tracker: QueryExecTracker): + # Define a mock summary function + def mock_get_summary(): + return "Test summary data" + + tracker._server_config = None + + # Set the get_summary method to your mock + tracker.get_summary = mock_get_summary + + # Mock the requests.post method + mock_response = MagicMock() + mock_response.status_code = 200 + type(mock_response).text = "Response text" + # Mock the requests.post method + with patch("requests.post", return_value=mock_response) as mock_post: + # Call the publish method + result = tracker.publish() + + # Check that requests.post was called with the expected parameters + mock_post.assert_not_called() + + # Check the result + assert result is None # The function should return None + + def test_publish_method_with_os_env(self, tracker: QueryExecTracker): + # Define a mock summary function + def mock_get_summary(): + return "Test summary data" + + # Define a mock environment for testing + os.environ["LOGGING_SERVER_URL"] = "http://test-server" + os.environ["LOGGING_SERVER_API_KEY"] = "test-api-key" + + # Set the get_summary method to your mock + tracker.get_summary = mock_get_summary + + # Mock the requests.post method + mock_response = MagicMock() + mock_response.status_code = 200 + type(mock_response).text = "Response text" + # Mock the requests.post method + with patch("requests.post", return_value=mock_response) as mock_post: + # Call the publish method + result = tracker.publish() + + # Check that requests.post was called with the expected parameters + mock_post.assert_called_with( + "http://test-server/api/log/add", + json={"json_log": "Test 
summary data"}, + headers={"Authorization": "Bearer test-api-key"}, + ) + + # Check the result + assert result is None # The function should return None + + def test_multiple_instance_of_tracker(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + mock_func.__name__ = "generate_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="generate_code") + + tracker2 = QueryExecTracker() + tracker2.start_new_track() + tracker2.add_query_info( + conversation_id="12345", + instance="SmartDatalake", + query="which country has the highest GDP?", + output_type="json", + ) + + assert len(tracker._steps) == 1 + assert len(tracker2._steps) == 0 + + # Execute code with tracker 2 + tracker2.execute_func(mock_func, tag="generate_code") + assert len(tracker._steps) == 1 + assert len(tracker2._steps) == 1 + + # Create a mock function + mock_func2 = Mock() + mock_func2.return_value = "code" + mock_func2.__name__ = "_retry_run_code" + tracker2.execute_func(mock_func2, tag="_retry_run_code") + assert len(tracker._steps) == 1 + assert len(tracker2._steps) == 2 + + assert ( + tracker._query_info["conversation_id"] + != tracker2._query_info["conversation_id"] + ) + + def test_conversation_id_in_different_tracks(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + mock_func.__name__ = "generate_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="generate_code") + + summary = tracker.get_summary() + + tracker.start_new_track() + + tracker.add_query_info( + conversation_id="123", + instance="SmartDatalake", + query="Plot the GDP's?", + output_type="json", + ) + + # Create a mock function + mock_func2 = Mock() + mock_func2.return_value = "code" + mock_func2.__name__ = "_retry_run_code" + + tracker.execute_func(mock_func2, tag="_retry_run_code") + + summary2 = tracker.get_summary() + + assert ( + summary["query_info"]["conversation_id"] + == summary2["query_info"]["conversation_id"] + ) + assert len(tracker._steps) == 1 + + def test_reasoning_answer_in_code_section(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = ["code", "reason", "answer"] + mock_func.__name__ = "generate_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="generate_code") + + summary = tracker.get_summary() + + step = summary["steps"][0] + + assert "reasoning" in step + assert "answer" in step + assert step["reasoning"] == "reason" + assert step["answer"] == "answer" + + def test_reasoning_answer_in_rerun_code(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = ["code", "reason", "answer"] + mock_func.__name__ = "_retry_run_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="_retry_run_code") + + summary = tracker.get_summary() + + step = summary["steps"][0] + assert "reasoning" in step + assert "answer" in step + assert step["reasoning"] == "reason" + assert step["answer"] == "answer" diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 60ac31319..fada5ed04 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -215,7 +215,7 @@ def test_run_with_privacy_enforcement(self, llm): User: How many countries are in the dataframe? -This is the initial python function. Do not change the params. 
+This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required import pandas as pd @@ -225,12 +225,12 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) At the end, return a dictionary of: - type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) Examples: - { "type": "string", "value": "The highest salary is $9,000." } + { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or @@ -240,9 +240,10 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: \"\"\" ``` -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. - -Return the updated code:""" # noqa: E501 +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. +Based on the last message in the conversation: +- return the updated analyze_data function wrapped within ```python ```""" # noqa: E501 df.chat("How many countries are in the dataframe?") last_prompt = df.last_prompt if sys.platform.startswith("win"): @@ -275,7 +276,7 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint): User: How many countries are in the dataframe? -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required import pandas as pd @@ -285,15 +286,16 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) At the end, return a dictionary of: {output_type_hint} """ ``` -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. - -Return the updated code:''' # noqa: E501 +Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. 
+Based on the last message in the conversation: +- return the updated analyze_data function wrapped within ```python ```''' # noqa: E501 df.chat("How many countries are in the dataframe?", output_type=output_type) last_prompt = df.last_prompt diff --git a/tests/test_smartdatalake.py b/tests/test_smartdatalake.py index 2b02c431b..f3159726a 100644 --- a/tests/test_smartdatalake.py +++ b/tests/test_smartdatalake.py @@ -110,6 +110,8 @@ def test_load_llm_with_langchain_llm(self, smart_datalake: SmartDatalake, llm): def test_last_result_is_saved(self, _mocked_method, smart_datalake: SmartDatalake): assert smart_datalake.last_result is None + _mocked_method.__name__ = "execute_code" + smart_datalake.chat("How many countries are in the dataframe?") assert smart_datalake.last_result == { "type": "string", @@ -188,3 +190,22 @@ def test_initialize(self, mock_makedirs, smart_datalake: SmartDatalake): cache_dir = os.path.join(os.getcwd(), "cache") mock_makedirs.assert_any_call(cache_dir, mode=0o777, exist_ok=True) + + def test_last_answer_and_reasoning(self, smart_datalake: SmartDatalake): + llm = FakeLLM( + """ + Custom reasoning + Custom answer + ```python +def analyze_data(dfs): + return { 'type': 'text', 'value': "Hello World" } +```""" + ) + smart_datalake._llm = llm + smart_datalake.config.use_advanced_reasoning_framework = True + assert smart_datalake.last_answer is None + assert smart_datalake.last_reasoning is None + + smart_datalake.chat("How many countries are in the dataframe?") + assert smart_datalake.last_answer == "Custom answer" + assert smart_datalake.last_reasoning == "Custom reasoning"
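
The `test_last_answer_and_reasoning` test above exercises the new advanced reasoning flow end to end. As a minimal usage sketch (assuming an OpenAI API key placeholder and a toy dataframe used only for illustration, not a definitive API reference), enabling `use_advanced_reasoning_framework` in the config lets you read the extracted reasoning and conversational answer from the underlying `SmartDatalake` after a `chat` call; with the flag disabled, accessing these properties raises `AdvancedReasoningDisabledError`:

```python
import pandas as pd

from pandasai.llm.openai import OpenAI
from pandasai.smart_dataframe import SmartDataframe

# Toy data, for illustration only
df = pd.DataFrame(
    {
        "country": ["United States", "France", "Japan"],
        "gdp": [19294482071552, 2411255037952, 4380756541440],
    }
)

llm = OpenAI("YOUR_API_KEY")
sdf = SmartDataframe(
    df,
    config={"llm": llm, "use_advanced_reasoning_framework": True},
)

response = sdf.chat("Which country has the highest GDP?")
print(response)

# The underlying SmartDatalake keeps the reasoning and the natural-language
# answer extracted from the LLM response; accessing these properties with the
# flag disabled raises AdvancedReasoningDisabledError.
print(sdf.lake.last_reasoning)
print(sdf.lake.last_answer)
```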