From 045691e1a12b7e1aa3ea2a60a71465960209b306 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 10 Oct 2023 12:44:01 +0500 Subject: [PATCH 01/19] feat(QueryTracker): tracks query execution cycle --- pandasai/helpers/query_exec_tracker.py | 165 +++++++++++++++++++++++++ pandasai/smart_datalake/__init__.py | 61 ++++++--- tests/test_codemanager.py | 1 + tests/test_smartdatalake.py | 2 + 4 files changed, 214 insertions(+), 15 deletions(-) create mode 100644 pandasai/helpers/query_exec_tracker.py diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py new file mode 100644 index 000000000..a22c11466 --- /dev/null +++ b/pandasai/helpers/query_exec_tracker.py @@ -0,0 +1,165 @@ +import time +from typing import Any, List, TypedDict + + +class ResponseType(TypedDict): + type: str + value: Any + + +exec_steps = { + "cache_hit": "Cache Hit", + "_get_prompt": "Generate Prompt", + "generate_code": "Generate Code", + "execute_code": "Code Execution", + "_retry_run_code": "Retry Code Generation", +} + + +class QueryExecTracker: + _query_info: str = {} + _dataframes: List = [] + _response: ResponseType = {} + _steps: List = [] + _start_time = None + + def __init__( + self, + conversation_id: str, + query: str, + instance: str, + output_type: str, + ) -> None: + self._start_time = time.time() + self._query_info = { + "conversation_id": str(conversation_id), + "query": query, + "instance": instance, + "output_type": output_type, + } + + def add_dataframes(self, dfs: List) -> None: + """ + Add used dataframes for the query to query exec tracker + Args: + dfs (List[SmartDataFrame]): List of dataframes + """ + for df in dfs: + head = df.head_df + self._dataframes.append( + {"headers": head.columns.tolist(), "rows": head.values.tolist()} + ) + + def add_step(self, step: dict) -> None: + """ + Add Custom Step that is performed for additional information + Args: + step (dict): dictionary containing information + """ + self._steps.append(step) + + def execute_func(self, function, *args, **kwargs) -> Any: + """ + Tracks function executions, calculates execution time and prepare data + Args: + function (function): Function that is to be executed + + Returns: + Any: Response return after function execution + """ + start_time = time.time() + + func_name = kwargs["tag"] if "tag" in kwargs else function.__name__ + + try: + result = function(*args, **kwargs) + + execution_time = time.time() - start_time + if func_name not in exec_steps: + return result + + step_data = self._generate_exec_step(func_name, result) + + step_data["type"] = exec_steps[func_name] + step_data["success"] = True + step_data["execution_time"] = execution_time + + self._steps.append(step_data) + + return result + + except Exception: + execution_time = time.time() - start_time + self._steps.append( + { + "type": exec_steps[func_name], + "success": False, + "execution_time": execution_time, + } + ) + raise + + def _generate_exec_step(self, func_name: str, result: Any) -> dict: + """ + Extracts and Generates result + Args: + func_name (str): function name that is executed + result (Any): function output response + + Returns: + dict: dictionary with information about the function execution + """ + if ( + func_name == "cache_hit" + or func_name == "generate_code" + or func_name == "_retry_run_code" + ): + return {"code_generated": result} + elif func_name == "_get_prompt": + return { + "prompt_class": result.__class__.__name__, + "generated_prompt": result.to_string(), + } + elif func_name == "execute_code": + 
self._response = self._format_response(result) + return {"result": result} + + def _format_response(self, result: ResponseType) -> ResponseType: + """ + Format output response + Args: + result (ResponseType): response returned after execution + + Returns: + ResponseType: formatted response output + """ + formatted_result = {} + if result["type"] == "dataframe": + formatted_result = { + "type": result["type"], + "value": { + "headers": result["value"].columns.tolist(), + "rows": result["value"].values.tolist(), + }, + } + return formatted_result + else: + return result + + def get_summary(self) -> dict: + """ + Returns the summary to steps involved in execution of track + Returns: + dict: summary json + """ + execution_time = time.time() - self._start_time + return { + "query_info": self._query_info, + "dataframes": self._dataframes, + "steps": self._steps, + "response": self._response, + "execution_time": execution_time, + } + + def get_execution_time(self) -> float: + return time.time() - self._start_time diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index a94a1484b..4c0758b34 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -17,13 +17,13 @@ # The average loan amount is $15,000. ``` """ - -import time import uuid import logging import os import traceback +from pandasai.helpers.query_exec_tracker import QueryExecTracker + from ..helpers.output_types import output_type_factory from pandasai.responses.context import Context from pandasai.responses.response_parser import ResponseParser @@ -190,11 +190,6 @@ def add_middlewares(self, *middlewares: Optional[Middleware]): """ self._code_manager.add_middlewares(*middlewares) - def _start_timer(self): - """Start the timer""" - - self._start_time = time.time() - def _assign_prompt_id(self): """Assign a prompt ID""" @@ -279,13 +274,20 @@ def chat(self, query: str, output_type: Optional[str] = None): ValueError: If the query is empty """ - self._start_timer() - self.logger.log(f"Question: {query}") self.logger.log(f"Running PandasAI with {self._llm.type} LLM...") self._assign_prompt_id() + query_exec_tracker = QueryExecTracker( + conversation_id=self._last_prompt_id, + query=query, + instance=self.__class__.__name__, + output_type=output_type, + ) + + query_exec_tracker.add_dataframes(self._dfs) + self._memory.add(query, True) try: @@ -297,7 +299,10 @@ def chat(self, query: str, output_type: Optional[str] = None): and self._cache.get(self._get_cache_key()) ): self.logger.log("Using cached response") - code = self._cache.get(self._get_cache_key()) + code = query_exec_tracker.execute_func( + self._cache.get, self._get_cache_key(), tag="cache_hit" + ) + else: default_values = { # TODO: find a better way to determine the engine, @@ -305,13 +310,16 @@ def chat(self, query: str, output_type: Optional[str] = None): "output_type_hint": output_type_helper.template_hint, } - generate_python_code_instruction = self._get_prompt( + generate_python_code_instruction = query_exec_tracker.execute_func( + self._get_prompt, "generate_python_code", default_prompt=GeneratePythonCodePrompt, default_values=default_values, ) - code = self._llm.generate_code(generate_python_code_instruction) + code = query_exec_tracker.execute_func( + self._llm.generate_code, generate_python_code_instruction + ) if self._config.enable_cache and self._cache: self._cache.set(self._get_cache_key(), code) @@ -334,7 +342,8 @@ def chat(self, query: str, output_type: Optional[str] = None): while retry_count < 
self._config.max_retries: try: # Execute the code - result = self._code_manager.execute_code( + result = query_exec_tracker.execute_func( + self._code_manager.execute_code, code=code_to_run, prompt_id=self._last_prompt_id, ) @@ -355,7 +364,9 @@ def chat(self, query: str, output_type: Optional[str] = None): ) traceback_error = traceback.format_exc() - code_to_run = self._retry_run_code(code, traceback_error) + code_to_run = query_exec_tracker.execute_func( + self._retry_run_code, code, traceback_error + ) if result is not None: if isinstance(result, dict): @@ -364,21 +375,41 @@ def chat(self, query: str, output_type: Optional[str] = None): self.logger.log( "\n".join(validation_logs), level=logging.WARNING ) + query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": False, + "message": "Output Validation Failed", + } + ) + else: + query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": True, + "message": "Output Validation Successful", + } + ) self.last_result = result self.logger.log(f"Answer: {result}") + except Exception as exception: self.last_error = str(exception) + self.logger.log(query_exec_tracker.get_summary()) + return ( "Unfortunately, I was not able to answer your question, " "because of the following error:\n" f"\n{exception}\n" ) - self.logger.log(f"Executed in: {time.time() - self._start_time}s") + self.logger.log(f"Executed in: {query_exec_tracker.get_execution_time()}s") self._add_result_to_memory(result) + self.logger.log(query_exec_tracker.get_summary()) + return self._response_parser.parse(result) def _add_result_to_memory(self, result: dict): diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index b7a21cf6c..b5c39933a 100644 --- a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -147,6 +147,7 @@ def test_exception_handling( code_manager.execute_code = Mock( side_effect=NoCodeFoundError("No code found in the answer.") ) + code_manager.execute_code.__name__ = "execute_code" result = smart_dataframe.chat("How many countries are in the dataframe?") assert result == ( diff --git a/tests/test_smartdatalake.py b/tests/test_smartdatalake.py index 2b02c431b..3a590b0a6 100644 --- a/tests/test_smartdatalake.py +++ b/tests/test_smartdatalake.py @@ -110,6 +110,8 @@ def test_load_llm_with_langchain_llm(self, smart_datalake: SmartDatalake, llm): def test_last_result_is_saved(self, _mocked_method, smart_datalake: SmartDatalake): assert smart_datalake.last_result is None + _mocked_method.__name__ = "execute_code" + smart_datalake.chat("How many countries are in the dataframe?") assert smart_datalake.last_result == { "type": "string", From 947f7a46a6dd0d0d79a5e1d0d808a5577adf410a Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 10 Oct 2023 12:45:58 +0500 Subject: [PATCH 02/19] feat(QueryTracker): Adding test cases --- tests/test_query_tracker.py | 204 ++++++++++++++++++++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 tests/test_query_tracker.py diff --git a/tests/test_query_tracker.py b/tests/test_query_tracker.py new file mode 100644 index 000000000..c6a149502 --- /dev/null +++ b/tests/test_query_tracker.py @@ -0,0 +1,204 @@ +import time +from typing import Optional + +from unittest.mock import Mock, patch + +import pandas as pd +import pytest + +from pandasai.helpers.query_exec_tracker import QueryExecTracker +from pandasai.llm.fake import FakeLLM +from pandasai.smart_dataframe import SmartDataframe + + +class TestQueryExecTracker: + @pytest.fixture + def llm(self, output: Optional[str] = 
None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False}) + + @pytest.fixture + def smart_datalake(self, smart_dataframe: SmartDataframe): + return smart_dataframe.lake + + @pytest.fixture + def tracker(self): + return QueryExecTracker( + conversation_id="123", + query="which country has the highest GDP?", + instance="SmartDatalake", + output_type="json", + ) + + # Define a custom assert_almost_equal function + def assert_almost_equal(self, first, second, places=None, msg=None, delta=None): + if delta is not None: + if abs(float(first) - float(second)) > delta: + raise AssertionError(msg or f"{first} != {second} within {delta}") + else: + assert round(abs(float(second) - float(first)), places) == 0, ( + msg or f"{first} != {second} within {places} places" + ) + + def test_add_dataframes( + self, smart_dataframe: SmartDataframe, tracker: QueryExecTracker + ): + # Add the dataframe to the tracker + tracker._dataframes = [] + tracker.add_dataframes([smart_dataframe]) + + # Check if the dataframe was added correctly + assert len(tracker._dataframes) == 1 + assert len(tracker._dataframes[0]["headers"]) == 3 + assert len(tracker._dataframes[0]["rows"]) == 5 + + def test_add_step(self, tracker: QueryExecTracker): + # Create a sample step + step = {"type": "CustomStep", "description": "This is a custom step."} + + tracker._steps = [] + # Add the step to the tracker + tracker.add_step(step) + + # Check if the step was added correctly + assert len(tracker._steps) == 1 + assert tracker._steps[0] == step + + def test_execute_func_success(self, tracker: QueryExecTracker): + tracker._steps = [] + + # Create a mock function + mock_return_value = Mock() + mock_return_value.to_string = Mock() + mock_return_value.to_string.return_value = "Mock Result" + + mock_func = Mock() + mock_func.return_value = mock_return_value + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="_get_prompt") + + # Check if the result is as expected + assert result.to_string() == "Mock Result" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert step["type"] == "Generate Prompt" + assert step["success"] is True + + def test_execute_func_failure(self, tracker: QueryExecTracker): + # Create a mock function that raises an exception + def mock_function(*args, **kwargs): + raise Exception("Mock Exception") + + # Execute the mock function using execute_func and expect an exception + with pytest.raises(Exception): + tracker.execute_func(mock_function, tag="custom_tag") + + def test_format_response_dataframe( + self, tracker: QueryExecTracker, sample_df: pd.DataFrame + ): + # Create a sample ResponseType for a dataframe + response = {"type": "dataframe", "value": sample_df} + + # Format the response using _format_response + formatted_response = tracker._format_response(response) + + # Check if the 
response is formatted correctly + assert formatted_response["type"] == "dataframe" + assert len(formatted_response["value"]["headers"]) == 3 + assert len(formatted_response["value"]["rows"]) == 10 + + def test_format_response_other_type(self, tracker: QueryExecTracker): + # Create a sample ResponseType for a non-dataframe response + response = { + "type": "other_type", + "value": "SomeValue", + } + + # Format the response using _format_response + formatted_response = tracker._format_response(response) + + # Check if the response is formatted correctly + assert formatted_response["type"] == "other_type" + assert formatted_response["value"] == "SomeValue" + + def test_get_summary(self, tracker: QueryExecTracker): + # Execute a mock function to generate some steps and response + def mock_function(*args, **kwargs): + return "Mock Result" + + tracker.execute_func(mock_function, tag="custom_tag") + + # Get the summary + summary = tracker.get_summary() + + # Check if the summary contains the expected keys + assert "query_info" in summary + assert "dataframes" in summary + assert "steps" in summary + assert "response" in summary + assert "execution_time" in summary + + def test_get_execution_time(self, tracker: QueryExecTracker): + def mock_function(*args, **kwargs): + time.sleep(1) + return "Mock Result" + + # Sleep for a while to simulate execution time + with patch("time.time", return_value=0): + tracker.execute_func(mock_function, tag="cache_hit") + + # Get the execution time + execution_time = tracker.get_execution_time() + + print("Type", execution_time) + + # Check if the execution time is approximately 1 second + self.assert_almost_equal(execution_time, 1.0, delta=0.1) From 45672fe96ea454ae3d8e7603bb306caab89a4c9f Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 10 Oct 2023 12:49:08 +0500 Subject: [PATCH 03/19] feat(ApiLogger): logger to send data to remote api --- examples/custom_logger.py | 32 +++++++++++++++++++++++++++ pandasai/custom_loggers/api_logger.py | 28 +++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 examples/custom_logger.py create mode 100644 pandasai/custom_loggers/api_logger.py diff --git a/examples/custom_logger.py b/examples/custom_logger.py new file mode 100644 index 000000000..11cf478d3 --- /dev/null +++ b/examples/custom_logger.py @@ -0,0 +1,32 @@ +import pandas as pd +from pandasai import Agent +from pandasai.custom_loggers.api_logger import APILogger + +from pandasai.llm.openai import OpenAI + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +llm = OpenAI("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") +agent = Agent( + [employees_df, salaries_df], + config={"llm": llm, "enable_cache": True}, + memory_size=10, + logger=APILogger("SERVER-URL", "USER-ID", "API-KEY"), +) + +# Chat with the agent +response = agent.chat("Who gets paid the most?") +print(response) diff --git a/pandasai/custom_loggers/api_logger.py b/pandasai/custom_loggers/api_logger.py new file mode 100644 index 000000000..3570c1838 --- /dev/null +++ b/pandasai/custom_loggers/api_logger.py @@ -0,0 +1,28 @@ +import logging +import requests +from pandasai.helpers.logger import Logger + + +class APILogger(Logger): + _api_key: str = None + _server_url: 
str = None + _user_id: str = None + + def __init__(self, server_url: str, user_id: str, api_key: str): + self._api_key = api_key + self._server_url = server_url + self._user_id = user_id + + def log(self, message: str, level: int = logging.INFO): + try: + log_data = { + # TODO - Remove user id from the API + "user_id": self._user_id, + "api_key_id": self._api_key, + "json_log": message, + } + response = requests.post(f"{self._server_url}/api/log/add", json=log_data) + if response.status_code != 200: + raise Exception(response.text) + except Exception: + pass From 7c396f7ff981a00b84c801ba73fba06c7378fc2a Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 10 Oct 2023 16:35:00 +0500 Subject: [PATCH 04/19] chore: add more test cases to test the flow --- examples/custom_logger.py | 2 +- pandasai/smart_datalake/__init__.py | 1 - tests/test_query_tracker.py | 173 ++++++++++++++++++++-------- 3 files changed, 128 insertions(+), 48 deletions(-) diff --git a/examples/custom_logger.py b/examples/custom_logger.py index 11cf478d3..0abf0b634 100644 --- a/examples/custom_logger.py +++ b/examples/custom_logger.py @@ -19,7 +19,7 @@ salaries_df = pd.DataFrame(salaries_data) -llm = OpenAI("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") +llm = OpenAI("Your-API-Key") agent = Agent( [employees_df, salaries_df], config={"llm": llm, "enable_cache": True}, diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 4c0758b34..02755cff0 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -285,7 +285,6 @@ def chat(self, query: str, output_type: Optional[str] = None): instance=self.__class__.__name__, output_type=output_type, ) - query_exec_tracker.add_dataframes(self._dfs) self._memory.add(query, True) diff --git a/tests/test_query_tracker.py b/tests/test_query_tracker.py index c6a149502..b784e64dc 100644 --- a/tests/test_query_tracker.py +++ b/tests/test_query_tracker.py @@ -1,14 +1,16 @@ import time from typing import Optional - from unittest.mock import Mock, patch - import pandas as pd import pytest from pandasai.helpers.query_exec_tracker import QueryExecTracker from pandasai.llm.fake import FakeLLM from pandasai.smart_dataframe import SmartDataframe +from unittest import TestCase + + +assert_almost_equal = TestCase().assertAlmostEqual class TestQueryExecTracker: @@ -76,16 +78,6 @@ def tracker(self): output_type="json", ) - # Define a custom assert_almost_equal function - def assert_almost_equal(self, first, second, places=None, msg=None, delta=None): - if delta is not None: - if abs(float(first) - float(second)) > delta: - raise AssertionError(msg or f"{first} != {second} within {delta}") - else: - assert round(abs(float(second) - float(first)), places) == 0, ( - msg or f"{first} != {second} within {places} places" - ) - def test_add_dataframes( self, smart_dataframe: SmartDataframe, tracker: QueryExecTracker ): @@ -110,37 +102,6 @@ def test_add_step(self, tracker: QueryExecTracker): assert len(tracker._steps) == 1 assert tracker._steps[0] == step - def test_execute_func_success(self, tracker: QueryExecTracker): - tracker._steps = [] - - # Create a mock function - mock_return_value = Mock() - mock_return_value.to_string = Mock() - mock_return_value.to_string.return_value = "Mock Result" - - mock_func = Mock() - mock_func.return_value = mock_return_value - - # Execute the mock function using execute_func - result = tracker.execute_func(mock_func, tag="_get_prompt") - - # Check if the result is as expected - assert 
result.to_string() == "Mock Result" - # Check if the step was added correctly - assert len(tracker._steps) == 1 - step = tracker._steps[0] - assert step["type"] == "Generate Prompt" - assert step["success"] is True - - def test_execute_func_failure(self, tracker: QueryExecTracker): - # Create a mock function that raises an exception - def mock_function(*args, **kwargs): - raise Exception("Mock Exception") - - # Execute the mock function using execute_func and expect an exception - with pytest.raises(Exception): - tracker.execute_func(mock_function, tag="custom_tag") - def test_format_response_dataframe( self, tracker: QueryExecTracker, sample_df: pd.DataFrame ): @@ -198,7 +159,127 @@ def mock_function(*args, **kwargs): # Get the execution time execution_time = tracker.get_execution_time() - print("Type", execution_time) - # Check if the execution time is approximately 1 second - self.assert_almost_equal(execution_time, 1.0, delta=0.1) + assert_almost_equal(execution_time, 1.0, delta=0.3) + + def test_execute_func_success(self, tracker: QueryExecTracker): + tracker._steps = [] + + # Create a mock function + mock_return_value = Mock() + mock_return_value.to_string = Mock() + mock_return_value.to_string.return_value = "Mock Result" + + mock_func = Mock() + mock_func.return_value = mock_return_value + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="_get_prompt") + + # Check if the result is as expected + assert result.to_string() == "Mock Result" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert step["type"] == "Generate Prompt" + assert step["success"] is True + + def test_execute_func_failure(self, tracker: QueryExecTracker): + # Create a mock function that raises an exception + def mock_function(*args, **kwargs): + raise Exception("Mock Exception") + + with pytest.raises(Exception): + tracker.execute_func(mock_function, tag="custom_tag") + + def test_execute_func_cache_hit(self, tracker: QueryExecTracker): + tracker._steps = [] + + mock_func = Mock() + mock_func.return_value = "code" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="cache_hit") + + # Check if the result is as expected + assert result == "code" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert "code_generated" in step + assert step["type"] == "Cache Hit" + assert step["success"] is True + + def test_execute_func_generate_code(self, tracker: QueryExecTracker): + tracker._steps = [] + + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="generate_code") + + # Check if the result is as expected + assert result == "code" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert "code_generated" in step + assert step["type"] == "Generate Code" + assert step["success"] is True + + def test_execute_func_re_rerun_code(self, tracker: QueryExecTracker): + tracker._steps = [] + + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="_retry_run_code") + + # Check if the result is as expected + assert result == "code" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] 
+ assert "code_generated" in step + assert step["type"] == "Retry Code Generation" + assert step["success"] is True + + def test_execute_func_execute_code_success( + self, sample_df: pd.DataFrame, tracker: QueryExecTracker + ): + tracker._steps = [] + + mock_func = Mock() + mock_func.return_value = {"type": "dataframe", "value": sample_df} + + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func, tag="execute_code") + + # Check if the result is as expected + assert result["type"] == "dataframe" + # Check if the step was added correctly + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert "result" in step + assert step["type"] == "Code Execution" + assert step["success"] is True + + def test_execute_func_execute_code_fail( + self, sample_df: pd.DataFrame, tracker: QueryExecTracker + ): + tracker._steps = [] + + def mock_function(*args, **kwargs): + raise Exception("Mock Exception") + + with pytest.raises(Exception): + tracker.execute_func(mock_function, tag="execute_code") + + assert len(tracker._steps) == 1 + step = tracker._steps[0] + assert step["type"] == "Code Execution" + assert step["success"] is False From c957373b88fe1140bf7c4491e5d5a3d234657f85 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 10 Oct 2023 17:45:09 +0500 Subject: [PATCH 05/19] Prettify the logged json for current logger --- pandasai/helpers/query_exec_tracker.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index a22c11466..cb4e13e98 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -1,3 +1,4 @@ +import json import time from typing import Any, List, TypedDict @@ -148,7 +149,19 @@ def _format_response(self, result: ResponseType) -> ResponseType: def get_summary(self) -> dict: """ - Returns the summary to steps involved in execution of track + Returns the formatted summary + Returns: + dict: summary json + """ + summary = self.get_summary_dict() + try: + return json.dumps(summary, indent=4) + except Exception: + return summary + + def get_summary_dict(self) -> dict: + """ + Returns the summary in json to steps involved in execution of track Returns: dict: summary json """ From 56985cc12781c539f8a6c565f3d990712d3f14e2 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 10 Oct 2023 21:14:16 +0500 Subject: [PATCH 06/19] fix: remove tag from the PR --- pandasai/helpers/query_exec_tracker.py | 11 ++++++----- tests/test_query_tracker.py | 16 +++++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index cb4e13e98..1047c5dca 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -70,18 +70,19 @@ def execute_func(self, function, *args, **kwargs) -> Any: """ start_time = time.time() - func_name = kwargs["tag"] if "tag" in kwargs else function.__name__ + # Get the tag from kwargs if provided, or use the function name as the default + tag = kwargs.pop("tag", function.__name__) try: result = function(*args, **kwargs) execution_time = time.time() - start_time - if func_name not in exec_steps: + if tag not in exec_steps: return result - step_data = self._generate_exec_step(func_name, result) + step_data = self._generate_exec_step(tag, result) - step_data["type"] = exec_steps[func_name] + step_data["type"] = exec_steps[tag] step_data["success"] = True 
step_data["execution_time"] = execution_time @@ -93,7 +94,7 @@ def execute_func(self, function, *args, **kwargs) -> Any: execution_time = time.time() - start_time self._steps.append( { - "type": exec_steps[func_name], + "type": exec_steps[tag], "success": False, "execution_time": execution_time, } diff --git a/tests/test_query_tracker.py b/tests/test_query_tracker.py index b784e64dc..56c10c74b 100644 --- a/tests/test_query_tracker.py +++ b/tests/test_query_tracker.py @@ -172,6 +172,7 @@ def test_execute_func_success(self, tracker: QueryExecTracker): mock_func = Mock() mock_func.return_value = mock_return_value + mock_func.__name__ = "_get_prompt" # Execute the mock function using execute_func result = tracker.execute_func(mock_func, tag="_get_prompt") @@ -197,6 +198,7 @@ def test_execute_func_cache_hit(self, tracker: QueryExecTracker): mock_func = Mock() mock_func.return_value = "code" + mock_func.__name__ = "get" # Execute the mock function using execute_func result = tracker.execute_func(mock_func, tag="cache_hit") @@ -216,6 +218,7 @@ def test_execute_func_generate_code(self, tracker: QueryExecTracker): # Create a mock function mock_func = Mock() mock_func.return_value = "code" + mock_func.__name__ = "generate_code" # Execute the mock function using execute_func result = tracker.execute_func(mock_func, tag="generate_code") @@ -235,9 +238,10 @@ def test_execute_func_re_rerun_code(self, tracker: QueryExecTracker): # Create a mock function mock_func = Mock() mock_func.return_value = "code" + mock_func.__name__ = "_retry_run_code" # Execute the mock function using execute_func - result = tracker.execute_func(mock_func, tag="_retry_run_code") + result = tracker.execute_func(mock_func) # Check if the result is as expected assert result == "code" @@ -255,9 +259,10 @@ def test_execute_func_execute_code_success( mock_func = Mock() mock_func.return_value = {"type": "dataframe", "value": sample_df} + mock_func.__name__ = "execute_code" # Execute the mock function using execute_func - result = tracker.execute_func(mock_func, tag="execute_code") + result = tracker.execute_func(mock_func) # Check if the result is as expected assert result["type"] == "dataframe" @@ -273,11 +278,12 @@ def test_execute_func_execute_code_fail( ): tracker._steps = [] - def mock_function(*args, **kwargs): - raise Exception("Mock Exception") + mock_func = Mock() + mock_func.side_effect = Exception("Mock Exception") + mock_func.__name__ = "execute_code" with pytest.raises(Exception): - tracker.execute_func(mock_function, tag="execute_code") + tracker.execute_func(mock_func) assert len(tracker._steps) == 1 step = tracker._steps[0] From f32f2a17f385122e8fc07a3bfa17b1098321b550 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 10 Oct 2023 21:42:02 +0500 Subject: [PATCH 07/19] Remove json to string conversion from summary function --- pandasai/helpers/query_exec_tracker.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index 1047c5dca..e8bc8ee3c 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -1,4 +1,3 @@ -import json import time from typing import Any, List, TypedDict @@ -149,18 +148,6 @@ def _format_response(self, result: ResponseType) -> ResponseType: return result def get_summary(self) -> dict: - """ - Returns the formatted summary - Returns: - dict: summary json - """ - summary = self.get_summary_dict() - try: - return json.dumps(summary, indent=4) - except Exception: - 
return summary - - def get_summary_dict(self) -> dict: """ Returns the summary in json to steps involved in execution of track Returns: From 5fcf2d9be9fd818f04786bb86c6dbe5acee7ef9d Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Wed, 11 Oct 2023 23:25:23 +0500 Subject: [PATCH 08/19] chore: add parsing in exec cycle and update APILogger --- examples/custom_logger.py | 10 ++++++--- pandasai/custom_loggers/api_logger.py | 31 +++++++++++++------------- pandasai/helpers/query_exec_tracker.py | 16 ++++++++++++- pandasai/smart_datalake/__init__.py | 7 +++++- 4 files changed, 44 insertions(+), 20 deletions(-) diff --git a/examples/custom_logger.py b/examples/custom_logger.py index 0abf0b634..2ee4539cd 100644 --- a/examples/custom_logger.py +++ b/examples/custom_logger.py @@ -19,14 +19,18 @@ salaries_df = pd.DataFrame(salaries_data) -llm = OpenAI("Your-API-Key") +llm = OpenAI("OPEN-API-KEY") agent = Agent( [employees_df, salaries_df], config={"llm": llm, "enable_cache": True}, memory_size=10, - logger=APILogger("SERVER-URL", "USER-ID", "API-KEY"), + logger=APILogger( + "SERVER-URL", + "API-KEY", + ), ) # Chat with the agent -response = agent.chat("Who gets paid the most?") +response = agent.chat("Plot salary against department?") + print(response) diff --git a/pandasai/custom_loggers/api_logger.py b/pandasai/custom_loggers/api_logger.py index 3570c1838..69f81cba8 100644 --- a/pandasai/custom_loggers/api_logger.py +++ b/pandasai/custom_loggers/api_logger.py @@ -1,4 +1,5 @@ import logging +from typing import Union import requests from pandasai.helpers.logger import Logger @@ -6,23 +7,23 @@ class APILogger(Logger): _api_key: str = None _server_url: str = None - _user_id: str = None - def __init__(self, server_url: str, user_id: str, api_key: str): + def __init__(self, server_url: str, api_key: str): self._api_key = api_key self._server_url = server_url - self._user_id = user_id - def log(self, message: str, level: int = logging.INFO): + def log(self, message: Union[str, dict], level: int = logging.INFO): try: - log_data = { - # TODO - Remove user id from the API - "user_id": self._user_id, - "api_key_id": self._api_key, - "json_log": message, - } - response = requests.post(f"{self._server_url}/api/log/add", json=log_data) - if response.status_code != 200: - raise Exception(response.text) - except Exception: - pass + if isinstance(message, dict): + log_data = { + "api_key_id": self._api_key, + "json_log": message, + } + response = requests.post( + f"{self._server_url}/api/log/add", json=log_data + ) + if response.status_code != 200: + raise Exception(response.text) + + except Exception as e: + print(f"Exception in APILogger: {e}") diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index e8bc8ee3c..c5a267707 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -13,6 +13,7 @@ class ResponseType(TypedDict): "generate_code": "Generate Code", "execute_code": "Code Execution", "_retry_run_code": "Retry Code Generation", + "parse": "Parse Output", } @@ -22,6 +23,7 @@ class QueryExecTracker: _response: ResponseType = {} _steps: List = [] _start_time = None + _success: bool = False def __init__( self, @@ -31,6 +33,7 @@ def __init__( output_type: str, ) -> None: self._start_time = time.time() + self._success = False self._query_info = { "conversation_id": str(conversation_id), "query": query, @@ -123,7 +126,9 @@ def _generate_exec_step(self, func_name: str, result: Any) -> dict: } elif func_name == "execute_code": 
self._response = self._format_response(result) - return {"result": result} + return {"result": self._response} + else: + return {} def _format_response(self, result: ResponseType) -> ResponseType: """ @@ -160,7 +165,16 @@ def get_summary(self) -> dict: "steps": self._steps, "response": self._response, "execution_time": execution_time, + "success": self._success, } def get_execution_time(self) -> float: return time.time() - self._start_time + + @property + def success(self): + return self._success + + @success.setter + def success(self, value): + self._success = value diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 02755cff0..6c4bb8797 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -395,6 +395,7 @@ def chat(self, query: str, output_type: Optional[str] = None): except Exception as exception: self.last_error = str(exception) + query_exec_tracker.success = False self.logger.log(query_exec_tracker.get_summary()) return ( @@ -407,9 +408,13 @@ def chat(self, query: str, output_type: Optional[str] = None): self._add_result_to_memory(result) + result = query_exec_tracker.execute_func(self._response_parser.parse, result) + + query_exec_tracker.success = True + self.logger.log(query_exec_tracker.get_summary()) - return self._response_parser.parse(result) + return result def _add_result_to_memory(self, result: dict): """ From 1797f622d8797219ee8afbfa921f8be8c2ac215b Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Thu, 12 Oct 2023 00:08:51 +0500 Subject: [PATCH 09/19] move authorization to header --- pandasai/custom_loggers/api_logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pandasai/custom_loggers/api_logger.py b/pandasai/custom_loggers/api_logger.py index 69f81cba8..8670dcf11 100644 --- a/pandasai/custom_loggers/api_logger.py +++ b/pandasai/custom_loggers/api_logger.py @@ -16,11 +16,11 @@ def log(self, message: Union[str, dict], level: int = logging.INFO): try: if isinstance(message, dict): log_data = { - "api_key_id": self._api_key, "json_log": message, } + headers = {"Authorization": f"Bearer {self._api_key}"} response = requests.post( - f"{self._server_url}/api/log/add", json=log_data + f"{self._server_url}/api/log/add", json=log_data, headers=headers ) if response.status_code != 200: raise Exception(response.text) From abd6b2518d8a55c087b3588087894ab4c59031c1 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Thu, 12 Oct 2023 15:54:48 +0500 Subject: [PATCH 10/19] refactor: publish query sumary to server from query tracker --- examples/custom_logger.py | 36 ---------- examples/using_pandasai_log_server.py | 59 +++++++++++++++++ pandasai/custom_loggers/api_logger.py | 29 -------- pandasai/helpers/query_exec_tracker.py | 43 +++++++++++- pandasai/schemas/df_config.py | 8 ++- pandasai/smart_datalake/__init__.py | 6 +- tests/test_query_tracker.py | 92 +++++++++++++++++++++++++- 7 files changed, 203 insertions(+), 70 deletions(-) delete mode 100644 examples/custom_logger.py create mode 100644 examples/using_pandasai_log_server.py delete mode 100644 pandasai/custom_loggers/api_logger.py diff --git a/examples/custom_logger.py b/examples/custom_logger.py deleted file mode 100644 index 2ee4539cd..000000000 --- a/examples/custom_logger.py +++ /dev/null @@ -1,36 +0,0 @@ -import pandas as pd -from pandasai import Agent -from pandasai.custom_loggers.api_logger import APILogger - -from pandasai.llm.openai import OpenAI - -employees_data = { - "EmployeeID": [1, 2, 3, 4, 5], - "Name": 
["John", "Emma", "Liam", "Olivia", "William"], - "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], -} - -salaries_data = { - "EmployeeID": [1, 2, 3, 4, 5], - "Salary": [5000, 6000, 4500, 7000, 5500], -} - -employees_df = pd.DataFrame(employees_data) -salaries_df = pd.DataFrame(salaries_data) - - -llm = OpenAI("OPEN-API-KEY") -agent = Agent( - [employees_df, salaries_df], - config={"llm": llm, "enable_cache": True}, - memory_size=10, - logger=APILogger( - "SERVER-URL", - "API-KEY", - ), -) - -# Chat with the agent -response = agent.chat("Plot salary against department?") - -print(response) diff --git a/examples/using_pandasai_log_server.py b/examples/using_pandasai_log_server.py new file mode 100644 index 000000000..04e0857e0 --- /dev/null +++ b/examples/using_pandasai_log_server.py @@ -0,0 +1,59 @@ +import os +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Example 1: Using Environment Variables +os.environ["LOGGING_SERVER_URL"] = "SERVER_URL" +os.environ["LOGGING_SERVER_API_KEY"] = "YOUR_API_KEY" + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent( + [employees_df, salaries_df], + config={ + "llm": llm, + "enable_cache": True, + }, + memory_size=10, +) + +# Chat with the agent +response = agent.chat("Plot salary against department?") +print(response) + + +# Example 2: Using Config +llm = OpenAI("YOUR_API_KEY") +agent = Agent( + [employees_df, salaries_df], + config={ + "llm": llm, + "enable_cache": True, + "log_server": { + "server_url": "SERVER_URL", + "api_key": "YOUR_API_KEY", + }, + }, + memory_size=10, +) + +# Chat with the agent +response = agent.chat("Plot salary against department?") + +print(response) diff --git a/pandasai/custom_loggers/api_logger.py b/pandasai/custom_loggers/api_logger.py deleted file mode 100644 index 8670dcf11..000000000 --- a/pandasai/custom_loggers/api_logger.py +++ /dev/null @@ -1,29 +0,0 @@ -import logging -from typing import Union -import requests -from pandasai.helpers.logger import Logger - - -class APILogger(Logger): - _api_key: str = None - _server_url: str = None - - def __init__(self, server_url: str, api_key: str): - self._api_key = api_key - self._server_url = server_url - - def log(self, message: Union[str, dict], level: int = logging.INFO): - try: - if isinstance(message, dict): - log_data = { - "json_log": message, - } - headers = {"Authorization": f"Bearer {self._api_key}"} - response = requests.post( - f"{self._server_url}/api/log/add", json=log_data, headers=headers - ) - if response.status_code != 200: - raise Exception(response.text) - - except Exception as e: - print(f"Exception in APILogger: {e}") diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index c5a267707..fa002d15f 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -1,5 +1,8 @@ +import os import time -from typing import Any, List, TypedDict +from typing import Any, List, TypedDict, Union + +import requests class ResponseType(TypedDict): @@ -24,6 +27,7 @@ class QueryExecTracker: _steps: List = [] _start_time = None _success: bool = False + _server_config: dict = None def 
__init__( self, @@ -31,9 +35,11 @@ def __init__( query: str, instance: str, output_type: str, + server_config: Union[dict, None] = None, ) -> None: self._start_time = time.time() self._success = False + self._server_config = server_config self._query_info = { "conversation_id": str(conversation_id), "query": query, @@ -171,6 +177,41 @@ def get_summary(self) -> dict: def get_execution_time(self) -> float: return time.time() - self._start_time + def publish(self) -> None: + """ + Publish Query Summary to remote logging server + """ + api_key = None + server_url = None + + if self._server_config is None: + server_url = os.environ.get("LOGGING_SERVER_URL") + api_key = os.environ.get("LOGGING_SERVER_API_KEY") + else: + server_url = self._server_config.get( + "server_url", os.environ.get("LOGGING_SERVER_URL") + ) + api_key = self._server_config.get( + "api_key", os.environ.get("LOGGING_SERVER_API_KEY") + ) + + if api_key is None or server_url is None: + return + + try: + log_data = { + "json_log": self.get_summary(), + } + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.post( + f"{server_url}/api/log/add", json=log_data, headers=headers + ) + if response.status_code != 200: + raise Exception(response.text) + + except Exception as e: + print(f"Exception in APILogger: {e}") + @property def success(self): return self._success diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index 176167e24..04054ef6b 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, validator, Field -from typing import Optional, List, Any, Dict, Type +from typing import Optional, List, Any, Dict, Type, TypedDict from pandasai.responses import ResponseParser from ..middlewares.base import Middleware @@ -8,6 +8,11 @@ from ..exceptions import LLMNotFoundError +class LogServerConfig(TypedDict): + server_url: str + api_key: str + + class Config(BaseModel): save_logs: bool = True verbose: bool = False @@ -26,6 +31,7 @@ class Config(BaseModel): lazy_load_connector: bool = True response_parser: Type[ResponseParser] = None llm: Any = None + log_server: LogServerConfig = None class Config: arbitrary_types_allowed = True diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 6c4bb8797..2f7a95d13 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -284,7 +284,9 @@ def chat(self, query: str, output_type: Optional[str] = None): query=query, instance=self.__class__.__name__, output_type=output_type, + server_config=self._config.log_server, ) + query_exec_tracker.add_dataframes(self._dfs) self._memory.add(query, True) @@ -396,7 +398,7 @@ def chat(self, query: str, output_type: Optional[str] = None): except Exception as exception: self.last_error = str(exception) query_exec_tracker.success = False - self.logger.log(query_exec_tracker.get_summary()) + query_exec_tracker.publish() return ( "Unfortunately, I was not able to answer your question, " @@ -412,7 +414,7 @@ def chat(self, query: str, output_type: Optional[str] = None): query_exec_tracker.success = True - self.logger.log(query_exec_tracker.get_summary()) + query_exec_tracker.publish() return result diff --git a/tests/test_query_tracker.py b/tests/test_query_tracker.py index 56c10c74b..dff244f20 100644 --- a/tests/test_query_tracker.py +++ b/tests/test_query_tracker.py @@ -1,6 +1,7 @@ +import os import time from typing import Optional -from unittest.mock import Mock, patch +from 
unittest.mock import MagicMock, Mock, patch import pandas as pd import pytest @@ -289,3 +290,92 @@ def test_execute_func_execute_code_fail( step = tracker._steps[0] assert step["type"] == "Code Execution" assert step["success"] is False + + def test_publish_method_with_server_key(self, tracker: QueryExecTracker): + # Define a mock summary function + def mock_get_summary(): + return "Test summary data" + + # Mock the server_config + tracker._server_config = { + "server_url": "http://custom-server", + "api_key": "custom-api-key", + } + + # Set the get_summary method to your mock + tracker.get_summary = mock_get_summary + + # Mock the requests.post method + mock_response = MagicMock() + mock_response.status_code = 200 + type(mock_response).text = "Response text" + # Mock the requests.post method + with patch("requests.post", return_value=mock_response) as mock_post: + # Call the publish method + result = tracker.publish() + + # Check that requests.post was called with the expected parameters + mock_post.assert_called_with( + "http://custom-server/api/log/add", + json={"json_log": "Test summary data"}, + headers={"Authorization": "Bearer custom-api-key"}, + ) + + # Check the result + assert result is None # The function should return None + + def test_publish_method_with_no_config(self, tracker: QueryExecTracker): + # Define a mock summary function + def mock_get_summary(): + return "Test summary data" + + tracker._server_config = None + + # Set the get_summary method to your mock + tracker.get_summary = mock_get_summary + + # Mock the requests.post method + mock_response = MagicMock() + mock_response.status_code = 200 + type(mock_response).text = "Response text" + # Mock the requests.post method + with patch("requests.post", return_value=mock_response) as mock_post: + # Call the publish method + result = tracker.publish() + + # Check that requests.post was called with the expected parameters + mock_post.assert_not_called() + + # Check the result + assert result is None # The function should return None + + def test_publish_method_with_os_env(self, tracker: QueryExecTracker): + # Define a mock summary function + def mock_get_summary(): + return "Test summary data" + + # Define a mock environment for testing + os.environ["LOGGING_SERVER_URL"] = "http://test-server" + os.environ["LOGGING_SERVER_API_KEY"] = "test-api-key" + + # Set the get_summary method to your mock + tracker.get_summary = mock_get_summary + + # Mock the requests.post method + mock_response = MagicMock() + mock_response.status_code = 200 + type(mock_response).text = "Response text" + # Mock the requests.post method + with patch("requests.post", return_value=mock_response) as mock_post: + # Call the publish method + result = tracker.publish() + + # Check that requests.post was called with the expected parameters + mock_post.assert_called_with( + "http://test-server/api/log/add", + json={"json_log": "Test summary data"}, + headers={"Authorization": "Bearer test-api-key"}, + ) + + # Check the result + assert result is None # The function should return None From 4be8b6e43e12f20c802d580180a7a1ce54c86b0d Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Wed, 18 Oct 2023 12:28:44 +0200 Subject: [PATCH 11/19] feat: add advanced reasoning feature --- pandasai/agent/__init__.py | 20 ++++++ .../prompt_templates/advanced_reasoning.tmpl | 2 + .../check_if_relevant_to_conversation.tmpl | 2 +- .../generate_python_code.tmpl | 9 +-- pandasai/exceptions.py | 10 +++ pandasai/llm/base.py | 57 +++++++++++++++- pandasai/prompts/base.py | 25 
++++++- pandasai/prompts/generate_python_code.py | 30 ++++++--- pandasai/schemas/df_config.py | 1 + pandasai/smart_dataframe/__init__.py | 8 +++ pandasai/smart_datalake/__init__.py | 39 ++++++++++- .../test_generate_python_code_prompt.py | 65 ++++++++++++++++++- tests/test_smartdataframe.py | 14 ++-- tests/test_smartdatalake.py | 19 ++++++ 14 files changed, 272 insertions(+), 29 deletions(-) create mode 100644 pandasai/assets/prompt_templates/advanced_reasoning.tmpl diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 46d88c736..7e41372bb 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -170,3 +170,23 @@ def rephrase_query(self, query: str): "because of the following error:\n" f"\n{exception}\n" ) + + @property + def last_code_generated(self): + return self._lake.last_code_generated + + @property + def last_code_executed(self): + return self._lake.last_code_executed + + @property + def last_prompt(self): + return self._lake.last_prompt + + @property + def last_reasoning(self): + return self._lake.last_reasoning + + @property + def last_answer(self): + return self._lake.last_answer diff --git a/pandasai/assets/prompt_templates/advanced_reasoning.tmpl b/pandasai/assets/prompt_templates/advanced_reasoning.tmpl new file mode 100644 index 000000000..eff38ce34 --- /dev/null +++ b/pandasai/assets/prompt_templates/advanced_reasoning.tmpl @@ -0,0 +1,2 @@ +- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags; +- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value or the chart itself (it will be calculated later); \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl b/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl index 48f86a72b..058867102 100644 --- a/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl +++ b/pandasai/assets/prompt_templates/check_if_relevant_to_conversation.tmpl @@ -6,4 +6,4 @@ {query} -Is the query related to the conversation? Answer only "true" or "false" (lowercase). \ No newline at end of file +Is the query somehow related to the previous conversation? Answer only "true" or "false" (lowercase). \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 93b52b3b5..594cdf579 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -6,11 +6,12 @@ You are provided with the following pandas DataFrames: {conversation} -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python {current_code} ``` -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. - -Return the updated code: \ No newline at end of file +Take a deep breath and reason step-by-step. Act as a senior data analyst. 
+Based on the last message in the conversation: +{advanced_reasoning} +- return the updated analyze_data function wrapped within ```python ``` \ No newline at end of file diff --git a/pandasai/exceptions.py b/pandasai/exceptions.py index fec70f874..190dc34e8 100644 --- a/pandasai/exceptions.py +++ b/pandasai/exceptions.py @@ -126,3 +126,13 @@ def __init__(self, template_path, prompt_name="Unknown"): f"Unable to find a file with template at '{template_path}' " f"for '{prompt_name}' prompt." ) + + +class AdvancedReasoningDisabledError(Exception): + """ + Raised when one tries to have access to the answer or reasoning without + having use_advanced_reasoning_framework enabled. + + Args: + Exception (Exception): AdvancedReasoningDisabledError + """ diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py index 651edd34b..faa916b97 100644 --- a/pandasai/llm/base.py +++ b/pandasai/llm/base.py @@ -120,6 +120,53 @@ def _extract_code(self, response: str, separator: str = "```") -> str: return code + def _extract_tag_text(self, response: str, tag: str) -> str: + """ + Extracts the text between two tags in the response. + + Args: + response (str): Response + tag (str): Tag name + + Returns: + (str or None): Extracted text from the response + """ + + match = re.search( + f"(<{tag}>)(.*)()", + response, + re.DOTALL | re.MULTILINE, + ) + if match: + return match.group(2) + return None + + def _extract_reasoning(self, response: str) -> str: + """ + Extracts the reasoning from the response (wrapped in tags). + + Args: + response (str): Response + + Returns: + (str or None): Extracted reasoning from the response + """ + + return self._extract_tag_text(response, "reasoning") + + def _extract_answer(self, response: str) -> str: + """ + Extracts the answer from the response (wrapped in tags). + + Args: + response (str): Response + + Returns: + (str or None): Extracted answer from the response + """ + + return self._extract_tag_text(response, "answer") + @abstractmethod def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ @@ -135,7 +182,7 @@ def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: """ raise MethodNotImplementedError("Call method has not been implemented") - def generate_code(self, instruction: AbstractPrompt) -> str: + def generate_code(self, instruction: AbstractPrompt) -> [str, str, str]: """ Generate the code based on the instruction and the given prompt. @@ -146,8 +193,12 @@ def generate_code(self, instruction: AbstractPrompt) -> str: str: A string of Python code. """ - code = self.call(instruction, suffix="") - return self._extract_code(code) + response = self.call(instruction, suffix="") + return [ + self._extract_code(response), + self._extract_reasoning(response), + self._extract_answer(response), + ] class BaseOpenAI(LLM, ABC): diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py index 430ec37eb..7a692a18e 100644 --- a/pandasai/prompts/base.py +++ b/pandasai/prompts/base.py @@ -11,6 +11,7 @@ class AbstractPrompt(ABC): """ _args: dict = None + _config: dict = None def __init__(self, **kwargs): """ @@ -27,6 +28,9 @@ def __init__(self, **kwargs): def setup(self, **kwargs) -> None: pass + def on_prompt_generation(self) -> None: + pass + def _generate_dataframes(self, dfs): """ Generate the dataframes metadata @@ -58,6 +62,17 @@ def _generate_dataframes(self, dfs): def template(self) -> str: ... 
+ def set_config(self, config): + self._config = config + + def get_config(self, key=None): + if self._config is None: + return None + if key is None: + return self._config + if hasattr(self._config, key): + return getattr(self._config, key) + def set_var(self, var, value): if self._args is None: self._args = {} @@ -72,10 +87,18 @@ def set_vars(self, vars): self._args.update(vars) def to_string(self): + self.on_prompt_generation() + prompt_args = {} for key, value in self._args.items(): if isinstance(value, AbstractPrompt): - value.set_vars({k: v for k, v in self._args.items() if k != key}) + value.set_vars( + { + k: v + for k, v in self._args.items() + if k != key and not isinstance(v, AbstractPrompt) + } + ) prompt_args[key] = value.to_string() else: prompt_args[key] = value diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index 0c6ca22eb..3c623069f 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -8,12 +8,13 @@ {conversation} -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. {current_code} -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. +Take a deep breath and reason step-by-step. Act as a senior data analyst. +Based on the last message in the conversation: -Return the updated code:""" # noqa: E501 +- return the updated analyze_data function wrapped within ```python ```""" # noqa: E501 from .file_based_prompt import FileBasedPrompt @@ -25,18 +26,18 @@ class CurrentCodePrompt(FileBasedPrompt): _path_to_template = "assets/prompt_templates/current_code.tmpl" +class AdvancedReasoningPrompt(FileBasedPrompt): + """The current code""" + + _path_to_template = "assets/prompt_templates/advanced_reasoning.tmpl" + + class GeneratePythonCodePrompt(FileBasedPrompt): """Prompt to generate Python code""" _path_to_template = "assets/prompt_templates/generate_python_code.tmpl" def setup(self, **kwargs) -> None: - default_import = "import pandas as pd" - engine_df_name = "pd.DataFrame" - - self.set_var("default_import", default_import) - self.set_var("engine_df_name", engine_df_name) - if "custom_instructions" in kwargs: self._set_instructions(kwargs["custom_instructions"]) else: @@ -52,6 +53,17 @@ def setup(self, **kwargs) -> None: else: self.set_var("current_code", CurrentCodePrompt()) + def on_prompt_generation(self) -> None: + default_import = "import pandas as pd" + engine_df_name = "pd.DataFrame" + + self.set_var("default_import", default_import) + self.set_var("engine_df_name", engine_df_name) + if self.get_config("use_advanced_reasoning_framework"): + self.set_var("advanced_reasoning", AdvancedReasoningPrompt()) + else: + self.set_var("advanced_reasoning", "") + def _set_instructions(self, instructions: str): lines = instructions.split("\n") indented_lines = [" " + line for line in lines[1:]] diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index 176167e24..65428ea8a 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -14,6 +14,7 @@ class Config(BaseModel): enforce_privacy: bool = False enable_cache: bool = True use_error_correction_framework: bool = True + use_advanced_reasoning_framework: bool = False custom_prompts: Dict = Field(default_factory=dict) custom_instructions: Optional[str] = None open_charts: bool = True diff --git a/pandasai/smart_dataframe/__init__.py 
b/pandasai/smart_dataframe/__init__.py index 647846818..6828404c1 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -679,6 +679,14 @@ def sample_head(self): data = StringIO(self._sample_head) return pd.read_csv(data) + @property + def last_reasoning(self): + return self.lake.last_reasoning + + @property + def last_answer(self): + return self.lake.last_answer + @sample_head.setter def sample_head(self, sample_head: pd.DataFrame): self._sample_head = sample_head.to_csv(index=False) diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 02e29bce9..16d2b9678 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -42,6 +42,7 @@ from ..middlewares.base import Middleware from ..helpers.df_info import DataFrameType from ..helpers.path import find_project_root +from ..exceptions import AdvancedReasoningDisabledError class SmartDatalake: @@ -56,6 +57,8 @@ class SmartDatalake: _memory: Memory _last_code_generated: str = None + _last_reasoning: str = None + _last_answer: str = None _last_result: str = None _last_error: str = None @@ -228,6 +231,7 @@ def _get_prompt( prompt = custom_prompt if custom_prompt else default_prompt() # set default values for the prompt + prompt.set_config(self._config) if "dfs" not in default_values: prompt.set_var("dfs", self._dfs) if "conversation" not in default_values: @@ -318,7 +322,12 @@ def chat(self, query: str, output_type: Optional[str] = None): default_values=default_values, ) - code = self._llm.generate_code(generate_python_code_instruction) + [code, reasoning, answer] = self._llm.generate_code( + generate_python_code_instruction + ) + + self.last_reasoning = reasoning + self.last_answer = answer if self._config.enable_cache and self._cache: self._cache.set(self._get_cache_key(), code) @@ -428,7 +437,9 @@ def _retry_run_code(self, code: str, e: Exception): default_values=default_values, ) - code = self._llm.generate_code(error_correcting_instruction) + [code, _reasoning, _answer] = self._llm.generate_code( + error_correcting_instruction + ) if self._config.callback is not None: self._config.callback.on_code(code) return code @@ -595,6 +606,30 @@ def last_code_generated(self, last_code_generated: str): def last_code_executed(self): return self._code_manager.last_code_executed + @property + def last_reasoning(self): + if not self._config.use_advanced_reasoning_framework: + raise AdvancedReasoningDisabledError( + "You need to enable the advanced reasoning framework" + ) + return self._last_reasoning + + @last_reasoning.setter + def last_reasoning(self, last_reasoning: str): + self._last_reasoning = last_reasoning + + @property + def last_answer(self): + if not self._config.use_advanced_reasoning_framework: + raise AdvancedReasoningDisabledError( + "You need to enable the advanced reasoning framework" + ) + return self._last_answer + + @last_answer.setter + def last_answer(self, last_answer: str): + self._last_answer = last_answer + @property def last_result(self): return self._last_result diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index 5140afb53..b056b487c 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -65,7 +65,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint): Question -This is the initial python function. Do not change the params. +This is the initial python function. 
Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required import pandas as pd @@ -81,9 +81,68 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: """ ``` -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. +Take a deep breath and reason step-by-step. Act as a senior data analyst. +Based on the last message in the conversation: -Return the updated code:''' # noqa E501 +- return the updated analyze_data function wrapped within ```python ```''' # noqa E501 + actual_prompt_content = prompt.to_string() + if sys.platform.startswith("win"): + actual_prompt_content = actual_prompt_content.replace("\r\n", "\n") + assert actual_prompt_content == expected_prompt_content + + def test_advanced_reasoning_prompt(self): + """ + Test a prompt with advanced reasoning framework + """ + + llm = FakeLLM("plt.show()") + dfs = [ + SmartDataframe( + pd.DataFrame({"a": [1], "b": [4]}), + config={"llm": llm, "use_advanced_reasoning_framework": True}, + ) + ] + prompt = GeneratePythonCodePrompt() + prompt.set_config(dfs[0]._lake.config) + prompt.set_var("dfs", dfs) + prompt.set_var("conversation", "Question") + prompt.set_var("save_charts_path", "") + prompt.set_var("output_type_hint", "") + + expected_prompt_content = f'''You are provided with the following pandas DataFrames: + + +Dataframe dfs[0], with 1 rows and 2 columns. +This is the metadata of the dataframe dfs[0]: +a,b +1,4 + + + +Question + + +This is the initial python function. Do not change the params. Given the context, use the right dataframes. +```python +# TODO import all the dependencies required +import pandas as pd + +def analyze_data(dfs: list[pd.DataFrame]) -> dict: + """ + Analyze the data, using the provided dataframes (`dfs`). + 1. Prepare: Preprocessing and cleaning data if necessary + 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + At the end, return a dictionary of: + + """ +``` + +Take a deep breath and reason step-by-step. Act as a senior data analyst. +Based on the last message in the conversation: +- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags; +- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value or the chart itself (it will be calculated later); +- return the updated analyze_data function wrapped within ```python ```''' # noqa E501 actual_prompt_content = prompt.to_string() if sys.platform.startswith("win"): actual_prompt_content = actual_prompt_content.replace("\r\n", "\n") diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 60ac31319..c9cb30dc9 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -215,7 +215,7 @@ def test_run_with_privacy_enforcement(self, llm): User: How many countries are in the dataframe? -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required import pandas as pd @@ -240,9 +240,10 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: \"\"\" ``` -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. 
+Take a deep breath and reason step-by-step. Act as a senior data analyst. +Based on the last message in the conversation: -Return the updated code:""" # noqa: E501 +- return the updated analyze_data function wrapped within ```python ```""" # noqa: E501 df.chat("How many countries are in the dataframe?") last_prompt = df.last_prompt if sys.platform.startswith("win"): @@ -275,7 +276,7 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint): User: How many countries are in the dataframe? -This is the initial python function. Do not change the params. +This is the initial python function. Do not change the params. Given the context, use the right dataframes. ```python # TODO import all the dependencies required import pandas as pd @@ -291,9 +292,10 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: """ ``` -Use the provided dataframes (`dfs`) to update the python code within the `analyze_data` function. +Take a deep breath and reason step-by-step. Act as a senior data analyst. +Based on the last message in the conversation: -Return the updated code:''' # noqa: E501 +- return the updated analyze_data function wrapped within ```python ```''' # noqa: E501 df.chat("How many countries are in the dataframe?", output_type=output_type) last_prompt = df.last_prompt diff --git a/tests/test_smartdatalake.py b/tests/test_smartdatalake.py index 2b02c431b..1417d5247 100644 --- a/tests/test_smartdatalake.py +++ b/tests/test_smartdatalake.py @@ -188,3 +188,22 @@ def test_initialize(self, mock_makedirs, smart_datalake: SmartDatalake): cache_dir = os.path.join(os.getcwd(), "cache") mock_makedirs.assert_any_call(cache_dir, mode=0o777, exist_ok=True) + + def test_last_answer_and_reasoning(self, smart_datalake: SmartDatalake): + llm = FakeLLM( + """ + Custom reasoning + Custom answer + ```python +def analyze_data(dfs): + return { 'type': 'text', 'value': "Hello World" } +```""" + ) + smart_datalake._llm = llm + smart_datalake.config.use_advanced_reasoning_framework = True + assert smart_datalake.last_answer is None + assert smart_datalake.last_reasoning is None + + smart_datalake.chat("How many countries are in the dataframe?") + assert smart_datalake.last_answer == "Custom answer" + assert smart_datalake.last_reasoning == "Custom reasoning" From 8421b2c9da120647daa80fa20bc1b54c1f33b1fd Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Wed, 18 Oct 2023 20:50:54 +0200 Subject: [PATCH 12/19] feat: implement advanced reasoning --- .../prompt_templates/advanced_reasoning.tmpl | 5 +++-- .../prompt_templates/generate_python_code.tmpl | 4 ++-- .../assets/prompt_templates/simple_reasoning.tmpl | 1 + pandasai/llm/base.py | 9 ++++++++- pandasai/prompts/generate_python_code.py | 14 ++++++++++---- tests/llms/test_base_llm.py | 15 +++++++++++++++ tests/prompts/test_generate_python_code_prompt.py | 15 ++++++++------- tests/test_smartdataframe.py | 8 ++++---- 8 files changed, 51 insertions(+), 20 deletions(-) create mode 100644 pandasai/assets/prompt_templates/simple_reasoning.tmpl diff --git a/pandasai/assets/prompt_templates/advanced_reasoning.tmpl b/pandasai/assets/prompt_templates/advanced_reasoning.tmpl index eff38ce34..7d132a2e0 100644 --- a/pandasai/assets/prompt_templates/advanced_reasoning.tmpl +++ b/pandasai/assets/prompt_templates/advanced_reasoning.tmpl @@ -1,2 +1,3 @@ -- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags; -- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value 
or the chart itself (it will be calculated later); \ No newline at end of file +- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags. +- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value or the chart itself (it will be calculated later). +- return the updated analyze_data function wrapped within ```python ``` \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 594cdf579..566ee388e 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -12,6 +12,6 @@ This is the initial python function. Do not change the params. Given the context ``` Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. Based on the last message in the conversation: -{advanced_reasoning} -- return the updated analyze_data function wrapped within ```python ``` \ No newline at end of file +{reasoning} \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/simple_reasoning.tmpl b/pandasai/assets/prompt_templates/simple_reasoning.tmpl new file mode 100644 index 000000000..c728ffb0d --- /dev/null +++ b/pandasai/assets/prompt_templates/simple_reasoning.tmpl @@ -0,0 +1 @@ +- return the updated analyze_data function wrapped within ```python ``` \ No newline at end of file diff --git a/pandasai/llm/base.py b/pandasai/llm/base.py index faa916b97..1ff64d57b 100644 --- a/pandasai/llm/base.py +++ b/pandasai/llm/base.py @@ -165,7 +165,14 @@ def _extract_answer(self, response: str) -> str: (str or None): Extracted answer from the response """ - return self._extract_tag_text(response, "answer") + sentences = [ + sentence + for sentence in response.split(". ") + if "temp_chart.png" not in sentence + ] + answer = ". ".join(sentences) + + return self._extract_tag_text(answer, "answer") @abstractmethod def call(self, instruction: AbstractPrompt, suffix: str = "") -> str: diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index 3c623069f..513767d95 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -12,8 +12,8 @@ {current_code} Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. Based on the last message in the conversation: - - return the updated analyze_data function wrapped within ```python ```""" # noqa: E501 @@ -32,6 +32,12 @@ class AdvancedReasoningPrompt(FileBasedPrompt): _path_to_template = "assets/prompt_templates/advanced_reasoning.tmpl" +class SimpleReasoningPrompt(FileBasedPrompt): + """The current code""" + + _path_to_template = "assets/prompt_templates/simple_reasoning.tmpl" + + class GeneratePythonCodePrompt(FileBasedPrompt): """Prompt to generate Python code""" @@ -45,7 +51,7 @@ def setup(self, **kwargs) -> None: """Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501 +3. 
Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 ) if "current_code" in kwargs: @@ -60,9 +66,9 @@ def on_prompt_generation(self) -> None: self.set_var("default_import", default_import) self.set_var("engine_df_name", engine_df_name) if self.get_config("use_advanced_reasoning_framework"): - self.set_var("advanced_reasoning", AdvancedReasoningPrompt()) + self.set_var("reasoning", AdvancedReasoningPrompt()) else: - self.set_var("advanced_reasoning", "") + self.set_var("reasoning", SimpleReasoningPrompt()) def _set_instructions(self, instructions: str): lines = instructions.split("\n") diff --git a/tests/llms/test_base_llm.py b/tests/llms/test_base_llm.py index 8f044adc3..b6af1746f 100644 --- a/tests/llms/test_base_llm.py +++ b/tests/llms/test_base_llm.py @@ -64,3 +64,18 @@ def test_extract_code(self): """ assert LLM()._extract_code(code) == "print('Hello World')" + + def test_extract_answer(self): + llm = LLM() + response = "This is the answer." + expected_answer = "This is the answer." + assert llm._extract_answer(response) == expected_answer + + def test_extract_answer_with_temp_chart(self): + llm = LLM() + response = ( + "This is the answer. It returns a temp_chart.png. " + "But it shouldn't." + ) + expected_answer = "This is the answer. But it shouldn't." + assert llm._extract_answer(response) == expected_answer diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index b056b487c..8cc3a0aa6 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -75,15 +75,15 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) At the end, return a dictionary of: {output_type_hint} """ ``` Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. Based on the last message in the conversation: - - return the updated analyze_data function wrapped within ```python ```''' # noqa E501 actual_prompt_content = prompt.to_string() if sys.platform.startswith("win"): @@ -132,16 +132,17 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) At the end, return a dictionary of: """ ``` Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. 
Based on the last message in the conversation: -- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags; -- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value or the chart itself (it will be calculated later); +- explain your reasoning to implement the last step to the user that asked for it; it should be wrapped between tags. +- answer to the user as you would do as a data analyst; wrap it between tags; do not include the value or the chart itself (it will be calculated later). - return the updated analyze_data function wrapped within ```python ```''' # noqa E501 actual_prompt_content = prompt.to_string() if sys.platform.startswith("win"): @@ -153,7 +154,7 @@ def test_custom_instructions(self): 1. Load: Load the data from a file or database 2. Prepare: Preprocessing and cleaning data if necessary 3. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) -4. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501 +4. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 prompt = GeneratePythonCodePrompt(custom_instructions=custom_instructions) actual_instructions = prompt._args["instructions"] @@ -164,5 +165,5 @@ def test_custom_instructions(self): 1. Load: Load the data from a file or database 2. Prepare: Preprocessing and cleaning data if necessary 3. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 4. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.)""" # noqa: E501 + 4. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.)""" # noqa: E501 ) diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index c9cb30dc9..c00c9735e 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -225,7 +225,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: Analyze the data, using the provided dataframes (`dfs`). 1. Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) At the end, return a dictionary of: - type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) @@ -241,8 +241,8 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: ``` Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. Based on the last message in the conversation: - - return the updated analyze_data function wrapped within ```python ```""" # noqa: E501 df.chat("How many countries are in the dataframe?") last_prompt = df.last_prompt @@ -286,15 +286,15 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: Analyze the data, using the provided dataframes (`dfs`). 1. 
Prepare: Preprocessing and cleaning data if necessary 2. Process: Manipulating data for analysis (grouping, filtering, aggregating, etc.) - 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart save it to an image in temp_chart.png and do not show the chart.) + 3. Analyze: Conducting the actual analysis (if the user asks to plot a chart you must save it as an image in temp_chart.png and not show the chart.) At the end, return a dictionary of: {output_type_hint} """ ``` Take a deep breath and reason step-by-step. Act as a senior data analyst. +In the answer, you must never write the "technical" names of the tables. Based on the last message in the conversation: - - return the updated analyze_data function wrapped within ```python ```''' # noqa: E501 df.chat("How many countries are in the dataframe?", output_type=output_type) From ab9d9b3cb1ca2ab64ef2d93cb9dce10ca8e6a5ba Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Wed, 18 Oct 2023 23:10:23 +0200 Subject: [PATCH 13/19] docs: add use_advanced_reasoning_framework --- docs/getting-started.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/getting-started.md b/docs/getting-started.md index 609eb70cc..9fe7a59b5 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -190,6 +190,7 @@ Settings: - `save_charts_path`: the path where to save the charts. Defaults to `exports/charts/`. You can use this setting to override the default path. - `enable_cache`: whether to enable caching. Defaults to `True`. If set to `True`, PandasAI will cache the results of the LLM to improve the response time. If set to `False`, PandasAI will always call the LLM. - `use_error_correction_framework`: whether to use the error correction framework. Defaults to `True`. If set to `True`, PandasAI will try to correct the errors in the code generated by the LLM with further calls to the LLM. If set to `False`, PandasAI will not try to correct the errors in the code generated by the LLM. +- `use_advanced_reasoning_framework`: whether to use the advanced reasoning framework. Defaults to `False`. If set to `True`, PandasAI will try to use advanced reasoning to improve the results of the LLM and provide an explanation for the results. - `max_retries`: the maximum number of retries to use when using the error correction framework. Defaults to `3`. You can use this setting to override the default number of retries. - `custom_prompts`: the custom prompts to use. Defaults to `{}`. You can use this setting to override the default custom prompts. You can find more information about custom prompts [here](custom-prompts.md). - `custom_whitelisted_dependencies`: the custom whitelisted dependencies to use. Defaults to `{}`. You can use this setting to override the default custom whitelisted dependencies. You can find more information about custom whitelisted dependencies [here](custom-whitelisted-dependencies.md). 
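A minimal usage sketch of the `use_advanced_reasoning_framework` setting described above, assuming the `OpenAI` LLM wrapper and the `last_reasoning` / `last_answer` properties introduced earlier in this series; the DataFrame contents, question, and API key are illustrative only, not part of the patch.

```python
import pandas as pd

from pandasai import SmartDataframe
from pandasai.llm.openai import OpenAI

df = pd.DataFrame({"country": ["Spain", "France"], "gdp": [1.4, 2.9]})

llm = OpenAI("YOUR_API_KEY")
sdf = SmartDataframe(
    df,
    config={"llm": llm, "use_advanced_reasoning_framework": True},
)

response = sdf.chat("Which country has the highest GDP?")
print(response)

# With the flag enabled, the <reasoning> and <answer> sections extracted
# from the LLM response are exposed after the chat call; with the flag at
# its default of False, reading them raises AdvancedReasoningDisabledError.
print(sdf.last_reasoning)
print(sdf.last_answer)
```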
From 710ba3ab2fee64bc5d61d45be7d45d4dad308734 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Thu, 19 Oct 2023 22:15:06 +0500 Subject: [PATCH 14/19] chore(QueryExecTracker): track conversation for agent (#659) --- pandasai/agent/__init__.py | 11 +- pandasai/helpers/query_exec_tracker.py | 99 ++++++++++++------ pandasai/smart_dataframe/__init__.py | 3 + pandasai/smart_datalake/__init__.py | 75 +++++++++----- tests/test_query_tracker.py | 137 +++++++++++++++++++++++-- 5 files changed, 264 insertions(+), 61 deletions(-) diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 7e41372bb..455da534d 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -41,6 +41,10 @@ def __init__( dfs = [dfs] self._lake = SmartDatalake(dfs, config, logger, memory=Memory(memory_size)) + + # set instance type in SmartDataLake + self._lake.set_instance_type(self.__class__.__name__) + self._logger = self._lake.logger def _call_llm_with_prompt(self, prompt: AbstractPrompt): @@ -70,7 +74,8 @@ def chat(self, query: str, output_type: Optional[str] = None): Simulate a chat interaction with the assistant on Dataframe. """ try: - self.check_if_related_to_conversation(query) + is_related = self.check_if_related_to_conversation(query) + self._lake.is_related_query(is_related) result = self._lake.chat(query, output_type=output_type) return result except Exception as exception: @@ -80,7 +85,7 @@ def chat(self, query: str, output_type: Optional[str] = None): f"\n{exception}\n" ) - def check_if_related_to_conversation(self, query: str): + def check_if_related_to_conversation(self, query: str) -> bool: """ Check if the query is related to the previous conversation """ @@ -105,6 +110,8 @@ def check_if_related_to_conversation(self, query: str): if not related: self._lake.clear_memory() + return related + def clarification_questions(self, query: str) -> List[str]: """ Generate clarification questions based on the data diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index fa002d15f..0bce59129 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -3,6 +3,7 @@ from typing import Any, List, TypedDict, Union import requests +from collections import defaultdict class ResponseType(TypedDict): @@ -21,32 +22,63 @@ class ResponseType(TypedDict): class QueryExecTracker: - _query_info: str = {} - _dataframes: List = [] - _response: ResponseType = {} - _steps: List = [] - _start_time = None - _success: bool = False - _server_config: dict = None + _query_info: dict + _dataframes: List + _response: ResponseType + _steps: List + _func_exec_count: dict + _success: bool + _server_config: dict def __init__( self, - conversation_id: str, - query: str, - instance: str, - output_type: str, server_config: Union[dict, None] = None, ) -> None: - self._start_time = time.time() self._success = False + self._start_time = None self._server_config = server_config + self._query_info = {} + self._is_related_query = True + + def set_related_query(self, flag: bool): + """ + Set Related Query Parameter whether new query is related to the conversation + or not + Args: + flag (bool): boolean to set true if related else false + """ + self._is_related_query = flag + + def add_query_info( + self, conversation_id: str, instance: str, query: str, output_type: str + ): + """ + Adds query information for new track + Args: + conversation_id (str): conversation id + instance (str): instance like Agent or SmartDataframe + query (str): chat query 
given by user + output_type (str): output type expected by user + """ self._query_info = { "conversation_id": str(conversation_id), - "query": query, "instance": instance, + "query": query, "output_type": output_type, + "is_related_query": self._is_related_query, } + def start_new_track(self): + """ + Resets tracking variables to start new track + """ + self._start_time = time.time() + self._dataframes: List = [] + self._response: ResponseType = {} + self._steps: List = [] + self._query_info = {} + self._func_exec_count: dict = defaultdict(int) + def add_dataframes(self, dfs: List) -> None: """ Add used dataframes for the query to query exec tracker @@ -90,7 +122,6 @@ def execute_func(self, function, *args, **kwargs) -> Any: step_data = self._generate_exec_step(tag, result) - step_data["type"] = exec_steps[tag] step_data["success"] = True step_data["execution_time"] = execution_time @@ -119,22 +150,29 @@ def _generate_exec_step(self, func_name: str, result: Any) -> dict: Returns: dict: dictionary with information about the function execution """ - if ( - func_name == "cache_hit" - or func_name == "generate_code" - or func_name == "_retry_run_code" - ): - return {"code_generated": result} + + step = {"type": exec_steps[func_name]} + + if func_name == "cache_hit" or func_name == "generate_code": + step["code_generated"] = result + elif func_name == "_retry_run_code": + self._func_exec_count["_retry_run_code"] += 1 + + step[ + "type" + ] = f"{exec_steps[func_name]} ({self._func_exec_count['_retry_run_code']})" + + step["code_generated"] = result + elif func_name == "_get_prompt": - return { - "prompt_class": result.__class__.__name__, - "generated_prompt": result.to_string(), - } + step["prompt_class"] = result.__class__.__name__ + step["generated_prompt"] = result.to_string() + elif func_name == "execute_code": self._response = self._format_response(result) - return {"result": self._response} - else: - return {} + step["result"] = self._response + + return step def _format_response(self, result: ResponseType) -> ResponseType: """ @@ -164,6 +202,9 @@ def get_summary(self) -> dict: Returns: dict: summary json """ + if self._start_time is None: + raise RuntimeError("[QueryExecTracker]: Tracking not started") + execution_time = time.time() - self._start_time return { "query_info": self._query_info, @@ -213,9 +254,9 @@ def publish(self) -> None: print(f"Exception in APILogger: {e}") @property - def success(self): + def success(self) -> bool: return self._success @success.setter - def success(self, value): + def success(self, value: bool): self._success = value diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 6828404c1..4a024d3bc 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -302,6 +302,9 @@ def __init__( self._table_description = description self._lake = SmartDatalake([self], config, logger) + # set instance type in SmartDataLake + self._lake.set_instance_type(self.__class__.__name__) + # If no name is provided, use the fallback name provided the connector if self._table_name is None and self.connector: self._table_name = self.connector.fallback_name diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index d432a028f..672dbed7c 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -53,8 +53,11 @@ class SmartDatalake: _logger: Logger _start_time: float _last_prompt_id: uuid.UUID + _conversation_id: uuid.UUID _code_manager: 
CodeManager _memory: Memory + _instance: str + _query_exec_tracker: QueryExecTracker _last_code_generated: str = None _last_reasoning: str = None @@ -113,6 +116,20 @@ def __init__( else: self._response_parser = ResponseParser(context) + self._conversation_id = uuid.uuid4() + + self._instance = self.__class__.__name__ + + self._query_exec_tracker = QueryExecTracker( + server_config=self._config.log_server, + ) + + def set_instance_type(self, type: str): + self._instance = type + + def is_related_query(self, flag: bool): + self._query_exec_tracker.set_related_query(flag) + def initialize(self): """Initialize the SmartDatalake""" @@ -278,20 +295,18 @@ def chat(self, query: str, output_type: Optional[str] = None): ValueError: If the query is empty """ + self._query_exec_tracker.start_new_track() + self.logger.log(f"Question: {query}") self.logger.log(f"Running PandasAI with {self._llm.type} LLM...") self._assign_prompt_id() - query_exec_tracker = QueryExecTracker( - conversation_id=self._last_prompt_id, - query=query, - instance=self.__class__.__name__, - output_type=output_type, - server_config=self._config.log_server, + self._query_exec_tracker.add_query_info( + self._conversation_id, self._instance, query, output_type ) - query_exec_tracker.add_dataframes(self._dfs) + self._query_exec_tracker.add_dataframes(self._dfs) self._memory.add(query, True) @@ -304,7 +319,7 @@ def chat(self, query: str, output_type: Optional[str] = None): and self._cache.get(self._get_cache_key()) ): self.logger.log("Using cached response") - code = query_exec_tracker.execute_func( + code = self._query_exec_tracker.execute_func( self._cache.get, self._get_cache_key(), tag="cache_hit" ) @@ -322,14 +337,16 @@ def chat(self, query: str, output_type: Optional[str] = None): ): default_values["current_code"] = self._last_code_generated - generate_python_code_instruction = query_exec_tracker.execute_func( - self._get_prompt, - "generate_python_code", - default_prompt=GeneratePythonCodePrompt, - default_values=default_values, + generate_python_code_instruction = ( + self._query_exec_tracker.execute_func( + self._get_prompt, + "generate_python_code", + default_prompt=GeneratePythonCodePrompt, + default_values=default_values, + ) ) - [code, reasoning, answer] = query_exec_tracker.execute_func( + [code, reasoning, answer] = self._query_exec_tracker.execute_func( self._llm.generate_code, generate_python_code_instruction ) @@ -357,11 +374,12 @@ def chat(self, query: str, output_type: Optional[str] = None): while retry_count < self._config.max_retries: try: # Execute the code - result = query_exec_tracker.execute_func( + result = self._query_exec_tracker.execute_func( self._code_manager.execute_code, code=code_to_run, prompt_id=self._last_prompt_id, ) + break except Exception as e: if ( @@ -379,7 +397,7 @@ def chat(self, query: str, output_type: Optional[str] = None): ) traceback_error = traceback.format_exc() - code_to_run = query_exec_tracker.execute_func( + code_to_run = self._query_exec_tracker.execute_func( self._retry_run_code, code, traceback_error ) @@ -390,7 +408,7 @@ def chat(self, query: str, output_type: Optional[str] = None): self.logger.log( "\n".join(validation_logs), level=logging.WARNING ) - query_exec_tracker.add_step( + self._query_exec_tracker.add_step( { "type": "Validating Output", "success": False, @@ -398,7 +416,7 @@ def chat(self, query: str, output_type: Optional[str] = None): } ) else: - query_exec_tracker.add_step( + self._query_exec_tracker.add_step( { "type": "Validating Output", "success": True, @@ 
-411,8 +429,8 @@ def chat(self, query: str, output_type: Optional[str] = None): except Exception as exception: self.last_error = str(exception) - query_exec_tracker.success = False - query_exec_tracker.publish() + self._query_exec_tracker.success = False + self._query_exec_tracker.publish() return ( "Unfortunately, I was not able to answer your question, " @@ -420,15 +438,19 @@ def chat(self, query: str, output_type: Optional[str] = None): f"\n{exception}\n" ) - self.logger.log(f"Executed in: {query_exec_tracker.get_execution_time()}s") + self.logger.log( + f"Executed in: {self._query_exec_tracker.get_execution_time()}s" + ) self._add_result_to_memory(result) - result = query_exec_tracker.execute_func(self._response_parser.parse, result) + result = self._query_exec_tracker.execute_func( + self._response_parser.parse, result + ) - query_exec_tracker.success = True + self._query_exec_tracker.success = True - query_exec_tracker.publish() + self._query_exec_tracker.publish() return result @@ -484,6 +506,7 @@ def clear_memory(self): Clears the memory """ self._memory.clear() + self._conversation_id = uuid.uuid4() @property def engine(self): @@ -688,3 +711,7 @@ def dfs(self): @property def memory(self): return self._memory + + @property + def instance(self): + return self._instance diff --git a/tests/test_query_tracker.py b/tests/test_query_tracker.py index dff244f20..95f862d3b 100644 --- a/tests/test_query_tracker.py +++ b/tests/test_query_tracker.py @@ -72,12 +72,15 @@ def smart_datalake(self, smart_dataframe: SmartDataframe): @pytest.fixture def tracker(self): - return QueryExecTracker( + tracker = QueryExecTracker() + tracker.start_new_track() + tracker.add_query_info( conversation_id="123", - query="which country has the highest GDP?", instance="SmartDatalake", + query="which country has the highest GDP?", output_type="json", ) + return tracker def test_add_dataframes( self, smart_dataframe: SmartDataframe, tracker: QueryExecTracker @@ -131,22 +134,61 @@ def test_format_response_other_type(self, tracker: QueryExecTracker): assert formatted_response["type"] == "other_type" assert formatted_response["value"] == "SomeValue" - def test_get_summary(self, tracker: QueryExecTracker): + def test_get_summary(self): # Execute a mock function to generate some steps and response def mock_function(*args, **kwargs): return "Mock Result" - tracker.execute_func(mock_function, tag="custom_tag") + tracker = QueryExecTracker() + + tracker.start_new_track() + + tracker.add_query_info( + conversation_id="123", + instance="SmartDatalake", + query="which country has the highest GDP?", + output_type="json", + ) # Get the summary summary = tracker.get_summary() + tracker.execute_func(mock_function, tag="custom_tag") + # Check if the summary contains the expected keys assert "query_info" in summary assert "dataframes" in summary assert "steps" in summary assert "response" in summary assert "execution_time" in summary + assert "is_related_query" in summary["query_info"] + + def test_related_query_in_summary(self): + # Execute a mock function to generate some steps and response + def mock_function(*args, **kwargs): + return "Mock Result" + + tracker = QueryExecTracker() + + tracker.set_related_query(False) + + tracker.start_new_track() + + tracker.add_query_info( + conversation_id="123", + instance="SmartDatalake", + query="which country has the highest GDP?", + output_type="json", + ) + + # Get the summary + summary = tracker.get_summary() + + tracker.execute_func(mock_function, tag="custom_tag") + + # Check if the 
summary contains the expected keys + assert "is_related_query" in summary["query_info"] + assert not summary["query_info"]["is_related_query"] def test_get_execution_time(self, tracker: QueryExecTracker): def mock_function(*args, **kwargs): @@ -244,15 +286,24 @@ def test_execute_func_re_rerun_code(self, tracker: QueryExecTracker): # Execute the mock function using execute_func result = tracker.execute_func(mock_func) + # Execute the mock function using execute_func + result = tracker.execute_func(mock_func) + # Check if the result is as expected assert result == "code" # Check if the step was added correctly - assert len(tracker._steps) == 1 + assert len(tracker._steps) == 2 step = tracker._steps[0] assert "code_generated" in step - assert step["type"] == "Retry Code Generation" + assert step["type"] == "Retry Code Generation (1)" assert step["success"] is True + # Check second step as well + step2 = tracker._steps[1] + assert "code_generated" in step2 + assert step2["type"] == "Retry Code Generation (2)" + assert step2["success"] is True + def test_execute_func_execute_code_success( self, sample_df: pd.DataFrame, tracker: QueryExecTracker ): @@ -379,3 +430,77 @@ def mock_get_summary(): # Check the result assert result is None # The function should return None + + def test_multiple_instance_of_tracker(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + mock_func.__name__ = "generate_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="generate_code") + + tracker2 = QueryExecTracker() + tracker2.start_new_track() + tracker2.add_query_info( + conversation_id="12345", + instance="SmartDatalake", + query="which country has the highest GDP?", + output_type="json", + ) + + assert len(tracker._steps) == 1 + assert len(tracker2._steps) == 0 + + # Execute code with tracker 2 + tracker2.execute_func(mock_func, tag="generate_code") + assert len(tracker._steps) == 1 + assert len(tracker2._steps) == 1 + + # Create a mock function + mock_func2 = Mock() + mock_func2.return_value = "code" + mock_func2.__name__ = "_retry_run_code" + tracker2.execute_func(mock_func2, tag="_retry_run_code") + assert len(tracker._steps) == 1 + assert len(tracker2._steps) == 2 + + assert ( + tracker._query_info["conversation_id"] + != tracker2._query_info["conversation_id"] + ) + + def test_conversation_id_in_different_tracks(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = "code" + mock_func.__name__ = "generate_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="generate_code") + + summary = tracker.get_summary() + + tracker.start_new_track() + + tracker.add_query_info( + conversation_id="123", + instance="SmartDatalake", + query="Plot the GDP's?", + output_type="json", + ) + + # Create a mock function + mock_func2 = Mock() + mock_func2.return_value = "code" + mock_func2.__name__ = "_retry_run_code" + + tracker.execute_func(mock_func2, tag="_retry_run_code") + + summary2 = tracker.get_summary() + + assert ( + summary["query_info"]["conversation_id"] + == summary2["query_info"]["conversation_id"] + ) + assert len(tracker._steps) == 1 From 60a8fd9d89a0d97b4e1ae6bf09ddaa5de748c3b3 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Fri, 20 Oct 2023 19:07:43 +0500 Subject: [PATCH 15/19] chore(logs): add reasoning and answers in the logs (#665) --- pandasai/helpers/query_exec_tracker.py | 13 +++++++--- 
pandasai/smart_datalake/__init__.py | 17 +++++++------ tests/test_query_tracker.py | 35 ++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 10 deletions(-) diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index 0bce59129..f4fcd323b 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -153,16 +153,23 @@ def _generate_exec_step(self, func_name: str, result: Any) -> dict: step = {"type": exec_steps[func_name]} - if func_name == "cache_hit" or func_name == "generate_code": + if func_name == "cache_hit": step["code_generated"] = result + + elif func_name == "generate_code": + step["code_generated"] = result[0] + step["reasoning"] = result[1] + step["answer"] = result[2] + elif func_name == "_retry_run_code": self._func_exec_count["_retry_run_code"] += 1 step[ "type" ] = f"{exec_steps[func_name]} ({self._func_exec_count['_retry_run_code']})" - - step["code_generated"] = result + step["code_generated"] = result[0] + step["reasoning"] = result[1] + step["answer"] = result[2] elif func_name == "_get_prompt": step["prompt_class"] = result.__class__.__name__ diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 672dbed7c..447249469 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -397,7 +397,11 @@ def chat(self, query: str, output_type: Optional[str] = None): ) traceback_error = traceback.format_exc() - code_to_run = self._query_exec_tracker.execute_func( + [ + code_to_run, + reasoning, + answer, + ] = self._query_exec_tracker.execute_func( self._retry_run_code, code, traceback_error ) @@ -469,7 +473,7 @@ def _add_result_to_memory(self, result: dict): elif result["type"] == "dataframe" or result["type"] == "plot": self._memory.add("Ok here it is", False) - def _retry_run_code(self, code: str, e: Exception): + def _retry_run_code(self, code: str, e: Exception) -> List: """ A method to retry the code execution with error correction framework. 
@@ -494,12 +498,11 @@ def _retry_run_code(self, code: str, e: Exception): default_values=default_values, ) - [code, _reasoning, _answer] = self._llm.generate_code( - error_correcting_instruction - ) + result = self._llm.generate_code(error_correcting_instruction) if self._config.callback is not None: - self._config.callback.on_code(code) - return code + self._config.callback.on_code(result[0]) + + return result def clear_memory(self): """ diff --git a/tests/test_query_tracker.py b/tests/test_query_tracker.py index 95f862d3b..342ee4bda 100644 --- a/tests/test_query_tracker.py +++ b/tests/test_query_tracker.py @@ -504,3 +504,38 @@ def test_conversation_id_in_different_tracks(self, tracker: QueryExecTracker): == summary2["query_info"]["conversation_id"] ) assert len(tracker._steps) == 1 + + def test_reasoning_answer_in_code_section(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = ["code", "reason", "answer"] + mock_func.__name__ = "generate_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="generate_code") + + summary = tracker.get_summary() + + step = summary["steps"][0] + + assert "reasoning" in step + assert "answer" in step + assert step["reasoning"] == "reason" + assert step["answer"] == "answer" + + def test_reasoning_answer_in_rerun_code(self, tracker: QueryExecTracker): + # Create a mock function + mock_func = Mock() + mock_func.return_value = ["code", "reason", "answer"] + mock_func.__name__ = "_retry_run_code" + + # Execute the mock function using execute_func + tracker.execute_func(mock_func, tag="_retry_run_code") + + summary = tracker.get_summary() + + step = summary["steps"][0] + assert "reasoning" in step + assert "answer" in step + assert step["reasoning"] == "reason" + assert step["answer"] == "answer" From ba2c604606b7e33158bd1207058f6c121c8f165f Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Sat, 21 Oct 2023 01:46:25 +0500 Subject: [PATCH 16/19] feat(logging): store plot image as base64 in database (#666) --- pandasai/helpers/query_exec_tracker.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pandasai/helpers/query_exec_tracker.py b/pandasai/helpers/query_exec_tracker.py index f4fcd323b..76585498b 100644 --- a/pandasai/helpers/query_exec_tracker.py +++ b/pandasai/helpers/query_exec_tracker.py @@ -1,3 +1,4 @@ +import base64 import os import time from typing import Any, List, TypedDict, Union @@ -199,6 +200,19 @@ def _format_response(self, result: ResponseType) -> ResponseType: "rows": result["value"].values.tolist(), }, } + return formatted_result + elif result["type"] == "plot": + with open(result["value"], "rb") as image_file: + image_data = image_file.read() + # Encode the image data to Base64 + base64_image = ( + f"data:image/png;base64,{base64.b64encode(image_data).decode()}" + ) + formatted_result = { + "type": result["type"], + "value": base64_image, + } + return formatted_result else: return result From e711c89a2c62b2ce0be18d382d9ce2339203d01e Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Sat, 21 Oct 2023 12:43:27 +0500 Subject: [PATCH 17/19] feat(skills): add skills to the pandas-ai library (#653) --- docs/examples.md | 54 +++ docs/skills.md | 113 ++++++ examples/skills_example.py | 47 +++ mkdocs.yml | 1 + pandasai/__init__.py | 10 +- pandasai/agent/__init__.py | 8 + .../generate_python_code.tmpl | 2 +- pandasai/helpers/code_manager.py | 64 +++- pandasai/helpers/skills_manager.py | 103 ++++++ pandasai/skills/__init__.py | 32 ++ 
pandasai/smart_dataframe/__init__.py | 7 + pandasai/smart_datalake/__init__.py | 26 +- .../test_generate_python_code_prompt.py | 2 + tests/skills/test_skills.py | 347 ++++++++++++++++++ tests/test_codemanager.py | 70 ++-- 15 files changed, 846 insertions(+), 40 deletions(-) create mode 100644 docs/skills.md create mode 100644 examples/skills_example.py create mode 100644 pandasai/helpers/skills_manager.py create mode 100644 pandasai/skills/__init__.py create mode 100644 tests/skills/test_skills.py diff --git a/docs/examples.md b/docs/examples.md index a3c07abcf..424afaf23 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -260,3 +260,57 @@ for question in questions: response = agent.explain() print(response) ``` + +## Add Skills to the Agent + +You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations. + +``` +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +@skill( + name="Display employee salary", + description="Plots the employee salaries against names", + usage="Displays the plot having name on x axis and salaries on y axis", +) +def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) + +``` diff --git a/docs/skills.md b/docs/skills.md new file mode 100644 index 000000000..ae6488219 --- /dev/null +++ b/docs/skills.md @@ -0,0 +1,113 @@ +# Skills + +You can add customs functions for the agent to use, allowing the agent to expand its capabilities. These custom functions can be seamlessly integrated with the agent's skills, enabling a wide range of user-defined operations. 
+ +## Example Usage + +```python + +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Function docstring to give more context to the model to use this skill +@skill +def plot_salaries(name: list[str], salaries: list[int]): + """ + Displays the bar chart having name on x axis and salaries on y axis + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + # plot bars + import matplotlib.pyplot as plt + + plt.bar(name, salaries) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") + + +``` + +## Add Streamlit Skill + +```python +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill +import streamlit as st + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + +# Function docstring to give more context to the model to use this skill +@skill +def plot_salaries(name: list[str], salary: list[int]): + """ + Displays the bar chart having name on x axis and salaries on y axis using streamlit + Args: + name (list[str]): Employee name + salary (list[int]): Salaries + """ + import matplotlib.pyplot as plt + + plt.bar(name, salary) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + fig = plt.gcf() + st.pyplot(fig) + + +llm = OpenAI("YOUR_API_KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) +``` diff --git a/examples/skills_example.py b/examples/skills_example.py new file mode 100644 index 000000000..e1df24d99 --- /dev/null +++ b/examples/skills_example.py @@ -0,0 +1,47 @@ +import pandas as pd +from pandasai import Agent + +from pandasai.llm.openai import OpenAI +from pandasai.skills import skill + +employees_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Name": ["John", "Emma", "Liam", "Olivia", "William"], + "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], +} + +salaries_data = { + "EmployeeID": [1, 2, 3, 4, 5], + "Salary": [5000, 6000, 4500, 7000, 5500], +} + +employees_df = pd.DataFrame(employees_data) +salaries_df = pd.DataFrame(salaries_data) + + +# Add function docstring to give more context to model +@skill +def plot_salaries(name: list[str], salary: list[int]) -> str: + """ + Displays the bar chart having name on x axis and salaries on y axis using 
streamlit + Args: + name (list[str]): Employee name + salaries (list[int]): Salaries + """ + import matplotlib.pyplot as plt + + plt.bar(name, salary) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + + +llm = OpenAI("YOUR-API-KEY") +agent = Agent([employees_df, salaries_df], config={"llm": llm}, memory_size=10) + +agent.add_skills(plot_salaries) + +# Chat with the agent +response = agent.chat("Plot the employee salaries against names") +print(response) diff --git a/mkdocs.yml b/mkdocs.yml index 86d71cd87..6808e430d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,6 +26,7 @@ nav: - callbacks.md - custom-instructions.md - custom-prompts.md + - skills.md - custom-whitelisted-dependencies.md - Examples: - examples.md diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 5e1a08ea3..2c71d8116 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -45,6 +45,7 @@ from .schemas.df_config import Config from .helpers.cache import Cache from .agent import Agent +from .skills import skill __version__ = importlib.metadata.version(__package__ or __name__) @@ -257,4 +258,11 @@ def clear_cache(filename: str = None): cache.clear() -__all__ = ["PandasAI", "SmartDataframe", "SmartDatalake", "Agent", "clear_cache"] +__all__ = [ + "PandasAI", + "SmartDataframe", + "SmartDatalake", + "Agent", + "clear_cache", + "skill", +] diff --git a/pandasai/agent/__init__.py b/pandasai/agent/__init__.py index 455da534d..52034bac0 100644 --- a/pandasai/agent/__init__.py +++ b/pandasai/agent/__init__.py @@ -1,5 +1,7 @@ import json from typing import Union, List, Optional + +from pandasai.skills import skill from ..helpers.df_info import DataFrameType from ..helpers.logger import Logger from ..helpers.memory import Memory @@ -47,6 +49,12 @@ def __init__( self._logger = self._lake.logger + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self._lake.add_skills(*skills) + def _call_llm_with_prompt(self, prompt: AbstractPrompt): """ Call LLM with prompt using error handling to retry based on config diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 566ee388e..b0c337e44 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -10,7 +10,7 @@ This is the initial python function. Do not change the params. Given the context ```python {current_code} ``` - +{skills} Take a deep breath and reason step-by-step. Act as a senior data analyst. In the answer, you must never write the "technical" names of the tables. 
Based on the last message in the conversation: diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index 6c5e65eee..d015d6dac 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -6,6 +6,8 @@ import astor import pandas as pd +from pandasai.helpers.skills_manager import SkillsManager + from .node_visitors import AssignmentVisitor, CallVisitor from .save_chart import add_save_chart from .optional import import_dependency @@ -23,6 +25,29 @@ import traceback +class CodeExecutionContext: + _prompt_id: uuid.UUID = None + _skills_manager: SkillsManager = None + + def __init__(self, prompt_id: uuid.UUID, skills_manager: SkillsManager): + """ + Additional Context for code execution + Args: + prompt_id (uuid.UUID): prompt unique id + skill (List): list[functions] of skills added + """ + self._skills_manager = skills_manager + self._prompt_id = prompt_id + + @property + def prompt_id(self): + return self._prompt_id + + @property + def skills_manager(self): + return self._skills_manager + + class CodeManager: _dfs: List _middlewares: List[Middleware] = [ChartsMiddleware()] @@ -180,11 +205,7 @@ def _required_dfs(self, code: str) -> List[str]: required_dfs.append(None) return required_dfs - def execute_code( - self, - code: str, - prompt_id: uuid.UUID, - ) -> Any: + def execute_code(self, code: str, context: CodeExecutionContext) -> Any: """ Execute the python code generated by LLMs to answer the question about the input dataframe. Run the code in the current context and return the @@ -192,7 +213,8 @@ def execute_code( Args: code (str): Python code to execute. - prompt_id (uuid.UUID): UUID of the request. + context (CodeExecutionContext): Code Execution Context + with prompt id and skills. Returns: Any: The result of the code execution. The type of the result depends @@ -209,12 +231,15 @@ def execute_code( code = add_save_chart( code, logger=self._logger, - file_name=str(prompt_id), + file_name=str(context.prompt_id), save_charts_path=self._config.save_charts_path, ) + # Reset used skills + context.skills_manager.used_skills = [] + # Get the code to run removing unsafe imports and df overwrites - code_to_run = self._clean_code(code) + code_to_run = self._clean_code(code, context) self.last_code_executed = code_to_run self._logger.log( f""" @@ -228,6 +253,13 @@ def execute_code( # if the code does not need them dfs = self._required_dfs(code_to_run) environment: dict = self._get_environment() + + # Add Skills in the env + if len(context.skills_manager.used_skills) > 0: + for skill_func_name in context.skills_manager.used_skills: + skill = context.skills_manager.get_skill_by_func_name(skill_func_name) + environment[skill_func_name] = skill + environment["dfs"] = self._get_samples(dfs) caught_error = self._execute_catching_errors(code_to_run, environment) @@ -293,7 +325,6 @@ def _get_environment(self) -> dict: Returns (dict): A dictionary of environment variables """ - return { "pd": pd, **{ @@ -377,7 +408,7 @@ def _sanitize_analyze_data(self, analyze_data_node: ast.stmt) -> ast.stmt: analyze_data_node.body = sanitized_analyze_data return analyze_data_node - def _clean_code(self, code: str) -> str: + def _clean_code(self, code: str, context: CodeExecutionContext) -> str: """ A method to clean the code to prevent malicious code execution. 
@@ -400,11 +431,24 @@ def _clean_code(self, code: str) -> str: if isinstance(node, (ast.Import, ast.ImportFrom)): self._check_imports(node) continue + if isinstance(node, ast.FunctionDef) and node.name == "analyze_data": analyze_data_node = node sanitized_analyze_data = self._sanitize_analyze_data(analyze_data_node) + + # Walk inside the function def for used skills + if len(context.skills_manager.skills) > 0: + for node in ast.walk(analyze_data_node): + # Checks for function to get skill name + if isinstance(node, ast.Call) and isinstance( + node.func, ast.Name + ): + function_name = node.func.id + context.skills_manager.add_used_skill(function_name) + new_body.append(sanitized_analyze_data) continue + new_body.append(node) new_tree = ast.Module(body=new_body) diff --git a/pandasai/helpers/skills_manager.py b/pandasai/helpers/skills_manager.py new file mode 100644 index 000000000..c63c5ff6f --- /dev/null +++ b/pandasai/helpers/skills_manager.py @@ -0,0 +1,103 @@ +from typing import List + +# from pandasai.skills import skill + + +class SkillsManager: + """ + Manages Custom added Skills and tracks used skills for the query + """ + + _skills: List + _used_skills: List[str] + + def __init__(self) -> None: + self._skills = [] + self._used_skills = [] + + def add_skills(self, *skills): + """ + Add skills to the list of skills. If a skill with the same name + already exists, raise an error. + + Args: + *skills: Variable number of skill objects to add. + """ + for skill in skills: + if any( + existing_skill.name == skill.name for existing_skill in self._skills + ): + raise ValueError(f"Skill with name '{skill.name}' already exists.") + + self._skills.extend(skills) + + def skill_exists(self, name: str): + """ + Check if a skill with the given name exists in the list of skills. + + Args: + name (str): The name of the skill to check. + + Returns: + bool: True if a skill with the given name exists, False otherwise. + """ + return any(skill.name == name for skill in self._skills) + + def get_skill_by_func_name(self, name: str): + """ + Get a skill by its name. + + Args: + name (str): The name of the skill to retrieve. + + Returns: + Skill or None: The skill with the given name, or None if not found. 
+ """ + for skill in self._skills: + if skill.name == name: + return skill + + return None + + def add_used_skill(self, skill: str): + if self.skill_exists(skill): + self._used_skills.append(skill) + + def __str__(self) -> str: + """ + Present all skills + Returns: + str: _description_ + """ + skills_repr = "" + for skill in self._skills: + skills_repr = skills_repr + skill.print + + return skills_repr + + def prompt_display(self) -> str: + """ + Displays skills for prompt + """ + if len(self._skills) == 0: + return + + return ( + """ +You can also use the following functions, if relevant: + +""" + + self.__str__() + ) + + @property + def used_skills(self): + return self._used_skills + + @used_skills.setter + def used_skills(self, value): + self._used_skills = value + + @property + def skills(self): + return self._skills diff --git a/pandasai/skills/__init__.py b/pandasai/skills/__init__.py new file mode 100644 index 000000000..cc82c5b2f --- /dev/null +++ b/pandasai/skills/__init__.py @@ -0,0 +1,32 @@ +import inspect + + +def skill(skill_function): + def wrapped_function(*args, **kwargs): + return skill_function(*args, **kwargs) + + wrapped_function.name = skill_function.__name__ + wrapped_function.func_def = ( + """def pandasai.skills.{funcion_name}{signature}""".format( + funcion_name=wrapped_function.name, + signature=str(inspect.signature(skill_function)), + ) + ) + + doc_string = skill_function.__doc__ + + wrapped_function.print = ( + """ + +{signature} +{doc_string} + +""" + ).format( + signature=wrapped_function.func_def, + doc_string=""" \"\"\"{0}\n \"\"\"""".format(doc_string) + if doc_string is not None + else "", + ) + + return wrapped_function diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 4a024d3bc..28724c1e4 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -26,6 +26,7 @@ import pydantic from pandasai.helpers.df_validator import DfValidator +from pandasai.skills import skill from ..smart_datalake import SmartDatalake from ..schemas.df_config import Config @@ -322,6 +323,12 @@ def add_middlewares(self, *middlewares: Optional[Middleware]): """ self.lake.add_middlewares(*middlewares) + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self.lake.add_skills(*skills) + def chat(self, query: str, output_type: Optional[str] = None): """ Run a query on the dataframe. 
diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 447249469..34285b5a1 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -21,6 +21,9 @@ import logging import os import traceback +from pandasai.helpers.skills_manager import SkillsManager + +from pandasai.skills import skill from pandasai.helpers.query_exec_tracker import QueryExecTracker @@ -38,7 +41,7 @@ from ..prompts.correct_error_prompt import CorrectErrorPrompt from ..prompts.generate_python_code import GeneratePythonCodePrompt from typing import Union, List, Any, Type, Optional -from ..helpers.code_manager import CodeManager +from ..helpers.code_manager import CodeExecutionContext, CodeManager from ..middlewares.base import Middleware from ..helpers.df_info import DataFrameType from ..helpers.path import find_project_root @@ -51,11 +54,11 @@ class SmartDatalake: _llm: LLM _cache: Cache = None _logger: Logger - _start_time: float _last_prompt_id: uuid.UUID _conversation_id: uuid.UUID _code_manager: CodeManager _memory: Memory + _skills: SkillsManager _instance: str _query_exec_tracker: QueryExecTracker @@ -104,6 +107,8 @@ def __init__( logger=self.logger, ) + self._skills = SkillsManager() + if cache: self._cache = cache elif self._config.enable_cache: @@ -210,6 +215,12 @@ def add_middlewares(self, *middlewares: Optional[Middleware]): """ self._code_manager.add_middlewares(*middlewares) + def add_skills(self, *skills: List[skill]): + """ + Add Skills to PandasAI + """ + self._skills.add_skills(*skills) + def _assign_prompt_id(self): """Assign a prompt ID""" @@ -248,6 +259,11 @@ def _get_prompt( prompt.set_var("dfs", self._dfs) if "conversation" not in default_values: prompt.set_var("conversation", self._memory.get_conversation()) + + # Adds the skills to prompt if exist else display nothing + skills_prompt = self._skills.prompt_display() + prompt.set_var("skills", skills_prompt if skills_prompt is not None else "") + for key, value in default_values.items(): prompt.set_var(key, value) @@ -374,10 +390,10 @@ def chat(self, query: str, output_type: Optional[str] = None): while retry_count < self._config.max_retries: try: # Execute the code - result = self._query_exec_tracker.execute_func( - self._code_manager.execute_code, + context = CodeExecutionContext(self._last_prompt_id, self._skills) + result = self._code_manager.execute_code( code=code_to_run, - prompt_id=self._last_prompt_id, + context=context, ) break diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index 8cc3a0aa6..05ddb764e 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -51,6 +51,7 @@ def test_str_with_args(self, save_charts_path, output_type_hint): prompt.set_var("conversation", "Question") prompt.set_var("save_charts_path", save_charts_path) prompt.set_var("output_type_hint", output_type_hint) + prompt.set_var("skills", "") expected_prompt_content = f'''You are provided with the following pandas DataFrames: @@ -108,6 +109,7 @@ def test_advanced_reasoning_prompt(self): prompt.set_var("conversation", "Question") prompt.set_var("save_charts_path", "") prompt.set_var("output_type_hint", "") + prompt.set_var("skills", "") expected_prompt_content = f'''You are provided with the following pandas DataFrames: diff --git a/tests/skills/test_skills.py b/tests/skills/test_skills.py new file mode 100644 index 000000000..ed979951c --- /dev/null +++ 
b/tests/skills/test_skills.py @@ -0,0 +1,347 @@ +from typing import Optional +from unittest.mock import MagicMock, Mock, patch +import uuid +import pandas as pd + +import pytest +from pandasai.agent import Agent +from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager + +from pandasai.helpers.skills_manager import SkillsManager +from pandasai.llm.fake import FakeLLM +from pandasai.skills import skill +from pandasai.smart_dataframe import SmartDataframe + + +class TestSkills: + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": False}) + + @pytest.fixture + def code_manager(self, smart_dataframe: SmartDataframe): + return smart_dataframe.lake._code_manager + + @pytest.fixture + def exec_context(self) -> MagicMock: + context = MagicMock(spec=CodeExecutionContext) + return context + + @pytest.fixture + def agent(self, llm, sample_df): + return Agent(sample_df, config={"llm": llm, "enable_cache": False}) + + def test_add_skills(self): + skills_manager = SkillsManager() + skill1 = Mock(name="SkillA", print="SkillA Print") + skill2 = Mock(name="SkillB", print="SkillB Print") + skills_manager.add_skills(skill1, skill2) + + # Ensure that skills are added + assert skill1 in skills_manager.skills + assert skill2 in skills_manager.skills + + # Test that adding a skill with the same name raises an error + try: + skills_manager.add_skills(skill1) + except ValueError as e: + assert str(e) == f"Skill with name '{skill1.name}' already exists." 
+ else: + assert False, "Expected ValueError" + + def test_skill_exists(self): + skills_manager = SkillsManager() + skill1 = MagicMock() + skill2 = MagicMock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + assert skills_manager.skill_exists("SkillA") + assert skills_manager.skill_exists("SkillB") + + # Test that a non-existing skill is not found + assert not skills_manager.skill_exists("SkillC") + + def test_get_skill_by_func_name(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test that you can retrieve a skill by its function name + retrieved_skill = skills_manager.get_skill_by_func_name("SkillA") + assert retrieved_skill == skill1 + + # Test that a non-existing skill returns None + retrieved_skill = skills_manager.get_skill_by_func_name("SkillC") + assert retrieved_skill is None + + def test_add_used_skill(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test adding used skills + skills_manager.add_used_skill("SkillA") + skills_manager.add_used_skill("SkillB") + + # Ensure used skills are added to the used_skills list + assert "SkillA" in skills_manager.used_skills + assert "SkillB" in skills_manager.used_skills + + def test_prompt_display(self): + skills_manager = SkillsManager() + skill1 = Mock() + skill2 = Mock() + skill1.name = "SkillA" + skill2.name = "SkillB" + skill1.print = "SkillA" + skill2.print = "SkillB" + skills_manager.add_skills(skill1, skill2) + + # Test prompt_display method when skills exist + prompt = skills_manager.prompt_display() + assert "You can also use the following functions" in prompt + + # Test prompt_display method when no skills exist + skills_manager._skills = [] + prompt = skills_manager.prompt_display() + assert prompt is None + + @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)") + def test_skill_decorator(self, mock_inspect_signature): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + # Test the wrapped functions + assert skill_a() == "SkillA Result" + assert skill_b() == "SkillB Result" + + # Test the additional attributes added by the decorator + assert skill_a.name == "skill_a" + assert skill_b.name == "skill_b" + + assert skill_a.func_def == "def pandasai.skills.skill_a(a, b, c)" + assert skill_b.func_def == "def pandasai.skills.skill_b(a, b, c)" + + assert ( + skill_a.print + == """\n\ndef pandasai.skills.skill_a(a, b, c)\n\n\n""" # noqa: E501 + ) + assert ( + skill_b.print + == """\n\ndef pandasai.skills.skill_b(a, b, c)\n\n\n""" # noqa: E501 + ) + + @patch("pandasai.skills.inspect.signature", return_value="(a, b, c)") + def test_skill_decorator_test_codc(self, llm): + df = pd.DataFrame({"country": []}) + df = SmartDataframe(df, config={"llm": llm, "enable_cache": False}) + + # Define skills using the decorator + @skill + def plot_salaries(*args, **kwargs): + """ + Test skill A + Args: + arg(str) + """ + return "SkillA Result" + + function_def = """ + Test skill A + Args: + arg(str) +""" # noqa: E501 + + assert function_def in plot_salaries.print + + def test_add_skills_with_agent(self, agent: Agent): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA 
Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + agent.add_skills(skill_a) + assert len(agent._lake._skills.skills) == 1 + + agent._lake._skills._skills = [] + agent.add_skills(skill_a, skill_b) + assert len(agent._lake._skills.skills) == 2 + + def test_add_skills_with_smartDataframe(self, smart_dataframe: SmartDataframe): + # Define skills using the decorator + @skill + def skill_a(*args, **kwargs): + return "SkillA Result" + + @skill + def skill_b(*args, **kwargs): + return "SkillB Result" + + smart_dataframe.add_skills(skill_a) + assert len(smart_dataframe._lake._skills.skills) == 1 + + smart_dataframe._lake._skills._skills = [] + smart_dataframe.add_skills(skill_a, skill_b) + assert len(smart_dataframe._lake._skills.skills) == 2 + + def test_run_prompt(self, llm): + df = pd.DataFrame({"country": []}) + df = SmartDataframe(df, config={"llm": llm, "enable_cache": False}) + + function_def = """ + +def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str + + +""" # noqa: E501 + + @skill + def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + df.add_skills(plot_salaries) + + df.chat("How many countries are in the dataframe?") + last_prompt = df.last_prompt + assert function_def in last_prompt + + def test_run_prompt_agent(self, agent): + function_def = """ + +def pandasai.skills.plot_salaries(merged_df: pandas.core.frame.DataFrame) -> str + + +""" # noqa: E501 + + @skill + def plot_salaries(merged_df: pd.DataFrame) -> str: + import matplotlib.pyplot as plt + + plt.bar(merged_df["Name"], merged_df["Salary"]) + plt.xlabel("Employee Name") + plt.ylabel("Salary") + plt.title("Employee Salaries") + plt.xticks(rotation=45) + plt.savefig("temp_chart.png") + plt.close() + + agent.add_skills(plot_salaries) + + agent.chat("How many countries are in the dataframe?") + last_prompt = agent._lake.last_prompt + + assert function_def in last_prompt + + def test_run_prompt_without_skills(self, agent): + agent.chat("How many countries are in the dataframe?") + + last_prompt = agent._lake.last_prompt + + assert "" not in last_prompt + assert "" not in last_prompt + assert ( + "You can also use the following functions, if relevant:" not in last_prompt + ) + + def test_code_exec_with_skills_no_use( + self, code_manager: CodeManager, exec_context: MagicMock + ): + code = """def analyze_data(dfs): + return {'type': 'number', 'value': 1 + 1}""" + skill1 = MagicMock() + skill1.name = "SkillA" + exec_context._skills_manager._skills = [skill1] + code_manager.execute_code(code, exec_context) + assert len(exec_context._skills_manager.used_skills) == 0 + + def test_code_exec_with_skills(self, code_manager: CodeManager): + code = """def analyze_data(dfs): + plot_salaries() + return {'type': 'number', 'value': 1 + 1}""" + + @skill + def plot_salaries() -> str: + return "plot_salaries" + + code_manager._middlewares = [] + + sm = SkillsManager() + sm.add_skills(plot_salaries) + exec_context = CodeExecutionContext(uuid.uuid4(), sm) + code_manager.execute_code(code, exec_context) + + assert len(exec_context._skills_manager.used_skills) == 1 + assert exec_context._skills_manager.used_skills[0] == "plot_salaries" diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index b5c39933a..0df494429 100644 --- 
a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -1,7 +1,6 @@ """Unit tests for the CodeManager class""" -import uuid from typing import Optional -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pandas as pd import pytest @@ -11,7 +10,7 @@ from pandasai.smart_dataframe import SmartDataframe -from pandasai.helpers.code_manager import CodeManager +from pandasai.helpers.code_manager import CodeExecutionContext, CodeManager class TestCodeManager: @@ -72,23 +71,33 @@ def smart_dataframe(self, llm, sample_df): def code_manager(self, smart_dataframe: SmartDataframe): return smart_dataframe.lake._code_manager - def test_run_code_for_calculations(self, code_manager: CodeManager): + @pytest.fixture + def exec_context(self) -> MagicMock: + context = MagicMock(spec=CodeExecutionContext) + return context + + def test_run_code_for_calculations( + self, code_manager: CodeManager, exec_context: MagicMock + ): code = """def analyze_data(dfs): return {'type': 'number', 'value': 1 + 1}""" - - assert code_manager.execute_code(code, uuid.uuid4())["value"] == 2 + assert code_manager.execute_code(code, exec_context)["value"] == 2 assert code_manager.last_code_executed == code - def test_run_code_invalid_code(self, code_manager: CodeManager): + def test_run_code_invalid_code( + self, code_manager: CodeManager, exec_context: MagicMock + ): with pytest.raises(Exception): # noinspection PyStatementEffect - code_manager.execute_code("1+ ", uuid.uuid4())["value"] + code_manager.execute_code("1+ ", exec_context)["value"] - def test_clean_code_remove_builtins(self, code_manager: CodeManager): + def test_clean_code_remove_builtins( + self, code_manager: CodeManager, exec_context: MagicMock + ): builtins_code = """import set def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" - assert code_manager.execute_code(builtins_code, uuid.uuid4())["value"] == { + assert code_manager.execute_code(builtins_code, exec_context)["value"] == { 1, 2, 3, @@ -99,44 +108,57 @@ def analyze_data(dfs): return {'type': 'number', 'value': set([1, 2, 3])}""" ) - def test_clean_code_removes_jailbreak_code(self, code_manager: CodeManager): + def test_clean_code_removes_jailbreak_code( + self, code_manager: CodeManager, exec_context: MagicMock + ): malicious_code = """def analyze_data(dfs): __builtins__['str'].__class__.__mro__[-1].__subclasses__()[140].__init__.__globals__['system']('ls') print('hello world')""" assert ( - code_manager._clean_code(malicious_code) + code_manager._clean_code(malicious_code, exec_context) == """def analyze_data(dfs): print('hello world')""" ) - def test_clean_code_remove_environment_defaults(self, code_manager: CodeManager): + def test_clean_code_remove_environment_defaults( + self, code_manager: CodeManager, exec_context: MagicMock + ): pandas_code = """import pandas as pd print('hello world') """ - assert code_manager._clean_code(pandas_code) == "print('hello world')" + assert ( + code_manager._clean_code(pandas_code, exec_context) + == "print('hello world')" + ) - def test_clean_code_whitelist_import(self, code_manager: CodeManager): + def test_clean_code_whitelist_import( + self, code_manager: CodeManager, exec_context: MagicMock + ): """Test that an installed whitelisted library is added to the environment.""" safe_code = """ import numpy as np np.array() """ - assert code_manager._clean_code(safe_code) == "np.array()" + assert code_manager._clean_code(safe_code, exec_context) == "np.array()" - def 
test_clean_code_raise_bad_import_error(self, code_manager: CodeManager): + def test_clean_code_raise_bad_import_error( + self, code_manager: CodeManager, exec_context: MagicMock + ): malicious_code = """ import os print(os.listdir()) """ with pytest.raises(BadImportError): - code_manager.execute_code(malicious_code, uuid.uuid4()) + code_manager.execute_code(malicious_code, exec_context) - def test_remove_dfs_overwrites(self, code_manager: CodeManager): + def test_remove_dfs_overwrites( + self, code_manager: CodeManager, exec_context: MagicMock + ): hallucinated_code = """def analyze_data(dfs): dfs = [pd.DataFrame([1,2,3])] print(dfs)""" assert ( - code_manager._clean_code(hallucinated_code) + code_manager._clean_code(hallucinated_code, exec_context) == """def analyze_data(dfs): print(dfs)""" ) @@ -157,7 +179,9 @@ def test_exception_handling( ) assert smart_dataframe.last_error == "No code found in the answer." - def test_custom_whitelisted_dependencies(self, code_manager: CodeManager, llm): + def test_custom_whitelisted_dependencies( + self, code_manager: CodeManager, llm, exec_context: MagicMock + ): code = """ import my_custom_library def analyze_data(dfs: list): @@ -166,11 +190,11 @@ def analyze_data(dfs: list): llm._output = code with pytest.raises(BadImportError): - code_manager._clean_code(code) + code_manager._clean_code(code, exec_context) code_manager._config.custom_whitelisted_dependencies = ["my_custom_library"] assert ( - code_manager._clean_code(code) + code_manager._clean_code(code, exec_context) == """def analyze_data(dfs: list): my_custom_library.do_something()""" ) From 5569df04e97334e1f7ff5b1cda9b06c737674792 Mon Sep 17 00:00:00 2001 From: Tanmay patil <77950208+Tanmaypatil123@users.noreply.github.com> Date: Wed, 25 Oct 2023 04:00:31 +0530 Subject: [PATCH 18/19] feat: add support for sqlite connectors (#680) * Added support for sqlite connectors * Removed linting error * Added tests for sqlite connector * chore: style the code --------- Co-authored-by: Gabriele Venturi --- docs/connectors.md | 22 +++++++++ examples/from_sql.py | 15 +++++- pandasai/connectors/__init__.py | 2 + pandasai/connectors/base.py | 9 ++++ pandasai/connectors/sql.py | 86 ++++++++++++++++++++++++++++++++- tests/connectors/test_sqlite.py | 85 ++++++++++++++++++++++++++++++++ 6 files changed, 216 insertions(+), 3 deletions(-) create mode 100644 tests/connectors/test_sqlite.py diff --git a/docs/connectors.md b/docs/connectors.md index 1850b9a7c..58d11a970 100644 --- a/docs/connectors.md +++ b/docs/connectors.md @@ -89,6 +89,28 @@ df = SmartDataframe(mysql_connector) df.chat('What is the total amount of loans in the last year?') ``` +### Sqlite connector + +Similarly to the PostgreSQL and MySQL connectors, the Sqlite connector allows you to connect to a local Sqlite database file. It is designed to be easy to use, even if you are not familiar with Sqlite or with PandasAI. + +To use the Sqlite connector, you only need to import it into your Python code and pass it to a `SmartDataframe` or `SmartDatalake` object: + +```python +from pandasai.connectors import SqliteConnector + +connector = SqliteConnector(config={ + "database" : "PATH_TO_DB", + "table" : "actor", + "where" :[ + ["first_name","=","PENELOPE"] + ] +}) + +df = SmartDataframe(connector) +df.chat('How many records are there ?') +``` + + ### Generic SQL connector The generic SQL connector allows you to connect to any SQL database that is supported by SQLAlchemy. 
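To make the generic connector concrete, here is a minimal sketch. It assumes `SQLConnector` is importable from `pandasai.connectors` like the other connectors above, and that its config accepts the same kind of keys as the PostgreSQL/MySQL examples (`dialect`, `driver`, credentials, `table`, and an optional `where` filter); all values below are placeholders.

```python
from pandasai import SmartDataframe
from pandasai.connectors import SQLConnector

# Placeholder configuration; swap in the dialect/driver and credentials
# for whichever SQLAlchemy-supported database you are targeting.
sql_connector = SQLConnector(
    config={
        "dialect": "mysql",
        "driver": "pymysql",
        "host": "localhost",
        "port": 3306,
        "username": "root",
        "password": "root",
        "database": "mydb",
        "table": "loans",
        "where": [["loan_status", "=", "PAIDOFF"]],  # optional filter
    }
)

df = SmartDataframe(sql_connector)
df.chat("How many loans have been paid off?")
```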
diff --git a/examples/from_sql.py b/examples/from_sql.py
index c8d435559..626df3ff3 100644
--- a/examples/from_sql.py
+++ b/examples/from_sql.py
@@ -2,7 +2,7 @@
 from pandasai import SmartDatalake
 from pandasai.llm import OpenAI
-from pandasai.connectors import MySQLConnector, PostgreSQLConnector
+from pandasai.connectors import MySQLConnector, PostgreSQLConnector, SqliteConnector
 
 # With a MySQL database
 loan_connector = MySQLConnector(
@@ -38,8 +38,19 @@
     }
 )
+# With a Sqlite database
+
+invoice_connector = SqliteConnector(
+    config={
+        "database": "local_path_to_db",
+        "table": "invoices",
+        "where": [["status", "=", "pending"]],
+    }
+)
 llm = OpenAI()
-df = SmartDatalake([loan_connector, payment_connector], config={"llm": llm})
+df = SmartDatalake(
+    [loan_connector, payment_connector, invoice_connector], config={"llm": llm}
+)
 response = df.chat("How many people from the United states?")
 print(response)
 # Output: 247 loans have been paid off by men.
diff --git a/pandasai/connectors/__init__.py b/pandasai/connectors/__init__.py
index fb80c8628..a484835ff 100644
--- a/pandasai/connectors/__init__.py
+++ b/pandasai/connectors/__init__.py
@@ -10,6 +10,7 @@
 from .databricks import DatabricksConnector
 from .yahoo_finance import YahooFinanceConnector
 from .airtable import AirtableConnector
+from .sql import SqliteConnector
 
 __all__ = [
     "BaseConnector",
@@ -20,4 +21,5 @@
     "SnowFlakeConnector",
     "DatabricksConnector",
     "AirtableConnector",
+    "SqliteConnector",
 ]
diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py
index 4de547486..3193616bf 100644
--- a/pandasai/connectors/base.py
+++ b/pandasai/connectors/base.py
@@ -39,6 +39,15 @@ class SQLBaseConnectorConfig(BaseConnectorConfig):
     dialect: Optional[str] = None
 
 
+class SqliteConnectorConfig(SQLBaseConnectorConfig):
+    """
+    Connector configurations for sqlite db.
+    """
+
+    table: str
+    database: str
+
+
 class YahooFinanceConnectorConfig(BaseConnectorConfig):
     """
     Connector configuration for Yahoo Finance.
diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py
index 215da08f4..8222bed21 100644
--- a/pandasai/connectors/sql.py
+++ b/pandasai/connectors/sql.py
@@ -5,7 +5,7 @@
 import re
 import os
 import pandas as pd
-from .base import BaseConnector, SQLConnectorConfig
+from .base import BaseConnector, SQLConnectorConfig, SqliteConnectorConfig
 from .base import BaseConnectorConfig
 from sqlalchemy import create_engine, text, select, asc
 from sqlalchemy.engine import Connection
@@ -364,6 +364,90 @@ def fallback_name(self):
         return self._config.table
 
 
+class SqliteConnector(SQLConnector):
+    """
+    Sqlite connectors are used to connect to Sqlite databases.
+    """
+
+    def __init__(self, config: Union[SqliteConnectorConfig, dict]):
+        """
+        Initialize the Sqlite connector with the given configuration.
+
+        Args:
+            config (SqliteConnectorConfig): The configuration for the Sqlite connector.
+ """ + config["dialect"] = "sqlite" + if isinstance(config, dict): + sqlite_env_vars = {"database": "SQLITE_DB_PATH", "table": "TABLENAME"} + config = self._populate_config_from_env(config, sqlite_env_vars) + + super().__init__(config) + + def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): + """ + Loads passed Configuration to object + + Args: + config (BaseConnectorConfig): Construct config in structure + + Returns: + config: BaseConenctorConfig + """ + return SqliteConnectorConfig(**config) + + def _init_connection(self, config: SqliteConnectorConfig): + """ + Initialize Database Connection + + Args: + config (SQLConnectorConfig): Configurations to load database + + """ + self._engine = create_engine(f"{config.dialect}:///{config.database}") + self._connection = self._engine.connect() + + def __del__(self): + """ + Close the connection to the SQL database. + """ + self._connection.close() + + @cache + def head(self): + """ + Return the head of the data source that the connector is connected to. + This information is passed to the LLM to provide the schema of the data source. + + Returns: + DataFrame: The head of the data source. + """ + + if self.logger: + self.logger.log( + f"Getting head of {self._config.table} " + f"using dialect {self._config.dialect}" + ) + + # Run a SQL query to get all the columns names and 5 random rows + query = self._build_query(limit=5, order="RANDOM()") + + # Return the head of the data source + return pd.read_sql(query, self._connection) + + def __repr__(self): + """ + Return the string representation of the SQL connector. + + Returns: + str: The string representation of the SQL connector. + """ + return ( + f"<{self.__class__.__name__} dialect={self._config.dialect} " + f"database={self._config.database} " + f"table={self._config.table}>" + ) + + class MySQLConnector(SQLConnector): """ MySQL connectors are used to connect to MySQL databases. 
diff --git a/tests/connectors/test_sqlite.py b/tests/connectors/test_sqlite.py
new file mode 100644
index 000000000..f55d48011
--- /dev/null
+++ b/tests/connectors/test_sqlite.py
@@ -0,0 +1,85 @@
+import unittest
+import pandas as pd
+from unittest.mock import Mock, patch
+from pandasai.connectors.base import SqliteConnectorConfig
+from pandasai.connectors import SqliteConnector
+
+class TestSqliteConnector(unittest.TestCase):
+    @patch("pandasai.connectors.sql.create_engine", autospec=True)
+    def setUp(self, mock_create_engine) -> None:
+        self.mock_engine = Mock()
+        self.mock_connection = Mock()
+        self.mock_engine.connect.return_value = self.mock_connection
+        mock_create_engine.return_value = self.mock_engine
+
+        self.config = SqliteConnectorConfig(
+            dialect="sqlite",
+            database="path_todb.db",
+            table="yourtable"
+        ).dict()
+
+        self.connector = SqliteConnector(self.config)
+
+    @patch("pandasai.connectors.SqliteConnector._load_connector_config")
+    @patch("pandasai.connectors.SqliteConnector._init_connection")
+    def test_constructor_and_properties(
+        self, mock_load_connector_config, mock_init_connection
+    ):
+        # Test constructor and properties
+        self.assertEqual(self.connector._config, self.config)
+        self.assertEqual(self.connector._engine, self.mock_engine)
+        self.assertEqual(self.connector._connection, self.mock_connection)
+        self.assertEqual(self.connector._cache_interval, 600)
+        SqliteConnector(self.config)
+        mock_load_connector_config.assert_called()
+        mock_init_connection.assert_called()
+
+    def test_repr_method(self):
+        # Test __repr__ method
+        expected_repr = (
+            "<SqliteConnector dialect=sqlite "
+            "database=path_todb.db table=yourtable>"
+        )
+        self.assertEqual(repr(self.connector), expected_repr)
+
+    @patch("pandasai.connectors.sql.pd.read_sql", autospec=True)
+    def test_head_method(self, mock_read_sql):
+        expected_data = pd.DataFrame({"Column1": [1, 2, 3], "Column2": [4, 5, 6]})
+        mock_read_sql.return_value = expected_data
+        head_data = self.connector.head()
+        pd.testing.assert_frame_equal(head_data, expected_data)
+
+    def test_rows_count_property(self):
+        # Test rows_count property
+        self.connector._rows_count = None
+        self.mock_connection.execute.return_value.fetchone.return_value = (
+            50,
+        )  # Sample rows count
+        rows_count = self.connector.rows_count
+        self.assertEqual(rows_count, 50)
+
+    def test_columns_count_property(self):
+        # Test columns_count property
+        self.connector._columns_count = None
+        mock_df = Mock()
+        mock_df.columns = ["Column1", "Column2"]
+        self.connector.head = Mock(return_value=mock_df)
+        columns_count = self.connector.columns_count
+        self.assertEqual(columns_count, 2)
+
+    def test_column_hash_property(self):
+        # Test column_hash property
+        mock_df = Mock()
+        mock_df.columns = ["Column1", "Column2"]
+        self.connector.head = Mock(return_value=mock_df)
+        column_hash = self.connector.column_hash
+        self.assertIsNotNone(column_hash)
+        self.assertEqual(
+            column_hash,
+            "0d045cff164deef81e24b0ed165b7c9c2789789f013902115316cde9d214fe63",
+        )
+
+    def test_fallback_name_property(self):
+        # Test fallback_name property
+        fallback_name = self.connector.fallback_name
+        self.assertEqual(fallback_name, "yourtable")
\ No newline at end of file

From 70c6fca99bcbe19e85896ce1fab75d18c56afede Mon Sep 17 00:00:00 2001
From: Gabriele Venturi
Date: Wed, 25 Oct 2023 00:32:06 +0200
Subject: [PATCH 19/19] fix: remove hallucination of hardcoded / non-factual information when returning strings

---
 pandasai/helpers/output_types/_output_types.py | 2 +-
 tests/test_smartdataframe.py                   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git 
a/pandasai/helpers/output_types/_output_types.py b/pandasai/helpers/output_types/_output_types.py index 7aa78889b..18b9eff42 100644 --- a/pandasai/helpers/output_types/_output_types.py +++ b/pandasai/helpers/output_types/_output_types.py @@ -140,7 +140,7 @@ def template_hint(self): return """- type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) Examples: - { "type": "string", "value": "The highest salary is $9,000." } + { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index c00c9735e..fada5ed04 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -230,7 +230,7 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: - type (possible values "string", "number", "dataframe", "plot") - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) Examples: - { "type": "string", "value": "The highest salary is $9,000." } + { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or