From 3f9816ab70e109d6f44386debc65d6c7d7d25599 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Wed, 20 Nov 2024 14:45:41 +0100 Subject: [PATCH] feat(VirtualDataFrame): virtual dataframe to load data on demand and enable direct_sql (#1434) * refactor(pandasai): make pandasai v3 work for dataframe * fix(sql): load and work with dataframe * fix: handle invalid data source type * feat(VirtualDataframe): lazy load data from the schema and fetch on demand --- pandasai/__init__.py | 6 +- pandasai/agent/base.py | 49 +- pandasai/{dataframe => data_loader}/loader.py | 103 +++- pandasai/data_loader/query_builder.py | 55 ++ pandasai/data_loader/schema_validator.py | 9 + pandasai/dataframe/virtual_dataframe.py | 41 ++ pandasai/ee/LICENSE | 36 -- .../advanced_security_agent/__init__.py | 32 -- .../pipeline/advanced_security_pipeline.py | 28 - .../advanced_security_prompt_generation.py | 42 -- .../pipeline/llm_call.py | 64 --- .../prompts/advanced_security_agent_prompt.py | 39 -- .../advanced_security_agent_prompt.tmpl | 20 - pandasai/ee/agents/judge_agent/__init__.py | 30 - .../judge_agent/pipeline/judge_pipeline.py | 29 - .../pipeline/judge_prompt_generation.py | 50 -- .../agents/judge_agent/pipeline/llm_call.py | 64 --- .../judge_agent/prompts/judge_agent_prompt.py | 39 -- .../prompts/templates/judge_agent_prompt.tmpl | 11 - pandasai/ee/agents/semantic_agent/__init__.py | 215 ------- .../pipeline/Semantic_prompt_generation.py | 46 -- .../semantic_agent/pipeline/code_generator.py | 231 -------- .../error_correction_pipeline.py | 67 --- .../fix_semantic_json_pipeline.py | 41 -- .../fix_semantic_schema_prompt.py | 61 -- .../semantic_agent/pipeline/llm_call.py | 59 -- .../pipeline/semantic_chat_pipeline.py | 118 ---- .../pipeline/semantic_result_parsing.py | 23 - .../pipeline/validate_pipeline_input.py | 69 --- .../prompts/fix_semantic_json.py | 39 -- .../prompts/generate_df_schema.py | 60 -- .../prompts/semantic_agent_prompt.py | 39 -- .../templates/fix_semantic_json_prompt.tmpl | 13 - .../prompts/templates/generate_df_schema.tmpl | 153 ----- .../templates/semantic_agent_prompt.tmpl | 6 - .../prompts/templates/shared/dataframe.tmpl | 1 - .../templates/shared/vectordb_docs.tmpl | 8 - pandasai/ee/connectors/relations.py | 25 - pandasai/ee/helpers/json_helper.py | 14 - pandasai/ee/helpers/query_builder.py | 533 ------------------ pandasai/helpers/dataframe_serializer.py | 2 +- pandasai/pipelines/chat/code_cleaning.py | 49 +- pandasai/pipelines/chat/code_execution.py | 55 +- .../pipelines/chat/validate_pipeline_input.py | 31 +- tests/unit_tests/dataframe/test_loader.py | 2 +- .../dataframe/test_query_builder.py | 2 +- tests/unit_tests/ee/helpers/schema.py | 88 --- .../test_semantic_agent_query_builder.py | 230 -------- .../smart_datalake/test_code_cleaning.py | 7 +- tests/unit_tests/pipelines/test_pipeline.py | 14 - 50 files changed, 268 insertions(+), 2780 deletions(-) rename pandasai/{dataframe => data_loader}/loader.py (58%) create mode 100644 pandasai/data_loader/query_builder.py create mode 100644 pandasai/data_loader/schema_validator.py create mode 100644 pandasai/dataframe/virtual_dataframe.py delete mode 100644 pandasai/ee/LICENSE delete mode 100644 pandasai/ee/agents/advanced_security_agent/__init__.py delete mode 100644 pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_pipeline.py delete mode 100644 pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_prompt_generation.py delete mode 100644 pandasai/ee/agents/advanced_security_agent/pipeline/llm_call.py delete mode 100644 pandasai/ee/agents/advanced_security_agent/prompts/advanced_security_agent_prompt.py delete mode 100644 pandasai/ee/agents/advanced_security_agent/prompts/templates/advanced_security_agent_prompt.tmpl delete mode 100644 pandasai/ee/agents/judge_agent/__init__.py delete mode 100644 pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py delete mode 100644 pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py delete mode 100644 pandasai/ee/agents/judge_agent/pipeline/llm_call.py delete mode 100644 pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py delete mode 100644 pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl delete mode 100644 pandasai/ee/agents/semantic_agent/__init__.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/Semantic_prompt_generation.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/code_generator.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_schema_prompt.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/llm_call.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/semantic_result_parsing.py delete mode 100644 pandasai/ee/agents/semantic_agent/pipeline/validate_pipeline_input.py delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/fix_semantic_json.py delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/semantic_agent_prompt.py delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/templates/fix_semantic_json_prompt.tmpl delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/templates/shared/dataframe.tmpl delete mode 100644 pandasai/ee/agents/semantic_agent/prompts/templates/shared/vectordb_docs.tmpl delete mode 100644 pandasai/ee/connectors/relations.py delete mode 100644 pandasai/ee/helpers/json_helper.py delete mode 100644 pandasai/ee/helpers/query_builder.py delete mode 100644 tests/unit_tests/ee/helpers/schema.py delete mode 100644 tests/unit_tests/ee/helpers/test_semantic_agent_query_builder.py diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 69a65c5b5..e71352fc6 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -7,7 +7,7 @@ from .agent import Agent from .helpers.cache import Cache from .dataframe.base import DataFrame -from .dataframe.loader import DatasetLoader +from .data_loader.loader import DatasetLoader # Global variable to store the current agent _current_agent = None @@ -61,7 +61,7 @@ def follow_up(query: str): _dataset_loader = DatasetLoader() -def load(dataset_path: str) -> DataFrame: +def load(dataset_path: str, virtualized=False) -> DataFrame: """ Load data based on the provided dataset path. @@ -72,7 +72,7 @@ def load(dataset_path: str) -> DataFrame: DataFrame: A new PandasAI DataFrame instance with loaded data. """ global _dataset_loader - return _dataset_loader.load(dataset_path) + return _dataset_loader.load(dataset_path, virtualized) __all__ = [ diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py index b1cc667f8..7094fc371 100644 --- a/pandasai/agent/base.py +++ b/pandasai/agent/base.py @@ -5,6 +5,8 @@ import pandas as pd from pandasai.agent.base_security import BaseSecurity + +from pandasai.data_loader.schema_validator import is_schema_source_same from pandasai.llm.bamboo_llm import BambooLLM from pandasai.pipelines.chat.chat_pipeline_input import ChatPipelineInput from pandasai.pipelines.chat.code_execution_pipeline_input import ( @@ -62,17 +64,13 @@ def __init__( self.dfs = dfs if isinstance(dfs, list) else [dfs] - # Validate SQL connectors - sql_connectors = [ - df - for df in self.dfs - if hasattr(df, "type") and df.type in ["sql", "postgresql"] - ] - if len(sql_connectors) > 1: - raise InvalidConfigError("Cannot use multiple SQL connectors") - # Instantiate the context self.config = self.get_config(config) + + # Validate df input with configurations + self.validate_input() + + # Initialize the context self.context = PipelineContext( dfs=self.dfs, config=self.config, @@ -106,6 +104,39 @@ def __init__( self.pipeline = None self.security = security + def validate_input(self): + from pandasai.dataframe.virtual_dataframe import VirtualDataFrame + + # Check if all DataFrames are VirtualDataFrame, and set direct_sql accordingly + all_virtual = all(isinstance(df, VirtualDataFrame) for df in self.dfs) + if all_virtual: + self.config.direct_sql = True + + # Validate the configurations based on direct_sql flag all have same source + if self.config.direct_sql and all_virtual: + base_schema_source = self.dfs[0].schema + for df in self.dfs[1:]: + # Ensure all DataFrames have the same source in direct_sql mode + + if not is_schema_source_same(base_schema_source, df.schema): + raise InvalidConfigError( + "Direct SQL requires all connectors to be of the same type, " + "belong to the same datasource, and have the same credentials." + ) + else: + # If not using direct_sql, ensure all DataFrames have the same source + if any(isinstance(df, VirtualDataFrame) for df in self.dfs): + base_schema_source = self.dfs[0].schema + for df in self.dfs[1:]: + if not is_schema_source_same(base_schema_source, df.schema): + raise InvalidConfigError( + "All DataFrames must belong to the same source." + ) + self.config.direct_sql = True + else: + # Means all are none virtual + self.config.direct_sql = False + def configure(self): # Add project root path if save_charts_path is default if ( diff --git a/pandasai/dataframe/loader.py b/pandasai/data_loader/loader.py similarity index 58% rename from pandasai/dataframe/loader.py rename to pandasai/data_loader/loader.py index 4ffd1887a..8fc204c4c 100644 --- a/pandasai/dataframe/loader.py +++ b/pandasai/data_loader/loader.py @@ -1,12 +1,14 @@ +import copy import os import yaml import pandas as pd from datetime import datetime, timedelta import hashlib +from pandasai.dataframe.base import DataFrame +from pandasai.dataframe.virtual_dataframe import VirtualDataFrame from pandasai.exceptions import InvalidDataSourceType from pandasai.helpers.path import find_project_root -from .base import DataFrame import importlib from typing import Any from .query_builder import QueryBuilder @@ -18,27 +20,35 @@ def __init__(self): self.schema = None self.dataset_path = None - def load(self, dataset_path: str, lazy=False) -> DataFrame: + def load(self, dataset_path: str, virtualized=False) -> DataFrame: self.dataset_path = dataset_path self._load_schema() self._validate_source_type() + if not virtualized: + cache_file = self._get_cache_file_path() - cache_file = self._get_cache_file_path() + if self._is_cache_valid(cache_file): + return self._read_cache(cache_file) - if self._is_cache_valid(cache_file): - return self._read_cache(cache_file) + df = self._load_from_source() + df = self._apply_transformations(df) + self._cache_data(df, cache_file) - df = self._load_from_source() - df = self._apply_transformations(df) - self._cache_data(df, cache_file) + table_name = self.schema["source"]["table"] - return DataFrame(df, schema=self.schema) + return DataFrame(df, schema=self.schema, name=table_name) + else: + # Initialize new dataset loader for virtualization + data_loader = self.copy() + table_name = self.schema["source"]["table"] + return VirtualDataFrame( + schema=self.schema, data_loader=data_loader, name=table_name + ) def _load_schema(self): schema_path = os.path.join( find_project_root(), "datasets", self.dataset_path, "schema.yaml" ) - print(schema_path) if not os.path.exists(schema_path): raise FileNotFoundError(f"Schema file not found: {schema_path}") @@ -82,32 +92,67 @@ def _read_cache(self, cache_file: str) -> DataFrame: else: raise ValueError(f"Unsupported cache format: {cache_format}") - def _load_from_source(self) -> pd.DataFrame: - source_type = self.schema["source"]["type"] - connection_info = self.schema["source"].get("connection", {}) - query_builder = QueryBuilder(self.schema) - query = query_builder.build_query() - + def _get_loader_function(self, source_type: str): + """ + Get the loader function for a specified data source type. + """ try: module_name = SUPPORTED_SOURCES[source_type] module = importlib.import_module(module_name) - if source_type in [ + if source_type not in { "mysql", "postgres", "cockroach", "sqlite", "cockroachdb", - ]: - load_function = getattr(module, f"load_from_{source_type}") - return load_function(connection_info, query) - else: - raise InvalidDataSourceType("Invalid data source type") + }: + raise InvalidDataSourceType( + f"Unsupported data source type: {source_type}" + ) + + return getattr(module, f"load_from_{source_type}") + + except KeyError: + raise InvalidDataSourceType(f"Unsupported data source type: {source_type}") except ImportError as e: raise ImportError( f"{source_type.capitalize()} connector not found. " - f"Please install the {module_name} library." + f"Please install the {SUPPORTED_SOURCES[source_type]} library." + ) from e + + def _load_from_source(self) -> pd.DataFrame: + query_builder = QueryBuilder(self.schema) + query = query_builder.build_query() + return self.execute_query(query) + + def load_head(self) -> pd.DataFrame: + query_builder = QueryBuilder(self.schema) + query = query_builder.get_head_query() + return self.execute_query(query) + + def get_row_count(self) -> int: + query_builder = QueryBuilder(self.schema) + query = query_builder.get_row_count() + result = self.execute_query(query) + return result.iloc[0, 0] + + def execute_query(self, query: str) -> pd.DataFrame: + source = self.schema.get("source", {}) + source_type = source.get("type") + connection_info = source.get("connection", {}) + + if not source_type: + raise ValueError("Source type is missing in the schema.") + + load_function = self._get_loader_function(source_type) + + try: + return load_function(connection_info, query) + except Exception as e: + raise RuntimeError( + f"Failed to execute query for source type '{source_type}' with query: {query}" ) from e def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame: @@ -140,3 +185,15 @@ def _cache_data(self, df: pd.DataFrame, cache_file: str): df.to_csv(cache_file, index=False) else: raise ValueError(f"Unsupported cache format: {cache_format}") + + def copy(self) -> "DatasetLoader": + """ + Create a new independent copy of the current DatasetLoader instance. + + Returns: + DatasetLoader: A new instance with the same state. + """ + new_loader = DatasetLoader() + new_loader.schema = copy.deepcopy(self.schema) + new_loader.dataset_path = self.dataset_path + return new_loader diff --git a/pandasai/data_loader/query_builder.py b/pandasai/data_loader/query_builder.py new file mode 100644 index 000000000..5fc548951 --- /dev/null +++ b/pandasai/data_loader/query_builder.py @@ -0,0 +1,55 @@ +from typing import Dict, Any, List, Union + + +class QueryBuilder: + def __init__(self, schema: Dict[str, Any]): + self.schema = schema + + def build_query(self) -> str: + columns = self._get_columns() + table_name = self.schema["source"]["table"] + query = f"SELECT {columns} FROM {table_name}" + + query += self._add_order_by() + query += self._add_limit() + + return query + + def _get_columns(self) -> str: + if "columns" in self.schema: + return ", ".join([col["name"] for col in self.schema["columns"]]) + else: + return "*" + + def _add_order_by(self) -> str: + if "order_by" not in self.schema: + return "" + + order_by = self.schema["order_by"] + order_by_clause = self._format_order_by(order_by) + return f" ORDER BY {order_by_clause}" + + def _format_order_by(self, order_by: Union[List[str], str]) -> str: + return ", ".join(order_by) if isinstance(order_by, list) else order_by + + def _add_limit(self, n=None) -> str: + limit = n if n else (self.schema["limit"] if "limit" in self.schema else "") + return f" LIMIT {self.schema['limit']}" if limit else "" + + def get_head_query(self, n=5): + source = self.schema.get("source", {}) + source_type = source.get("type") + + table_name = self.schema["source"]["table"] + + columns = self._get_columns() + + order_by = "RAND()" + if source_type in {"sqlite", "postgres"}: + order_by = "RANDOM()" + + return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}" + + def get_row_count(self): + table_name = self.schema["source"]["table"] + return f"SELECT COUNT(*) FROM {table_name}" diff --git a/pandasai/data_loader/schema_validator.py b/pandasai/data_loader/schema_validator.py new file mode 100644 index 000000000..9cb3ac2f9 --- /dev/null +++ b/pandasai/data_loader/schema_validator.py @@ -0,0 +1,9 @@ +import json + + +def is_schema_source_same(schema1: dict, schema2: dict) -> bool: + return schema1.get("source").get("type") == schema2.get("source").get( + "type" + ) and json.dumps( + schema1.get("source").get("connection"), sort_keys=True + ) == json.dumps(schema2.get("source").get("connection"), sort_keys=True) diff --git a/pandasai/dataframe/virtual_dataframe.py b/pandasai/dataframe/virtual_dataframe.py new file mode 100644 index 000000000..84b40df4d --- /dev/null +++ b/pandasai/dataframe/virtual_dataframe.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, ClassVar +import pandas as pd +from pandasai.dataframe.base import DataFrame + +if TYPE_CHECKING: + from pandasai.data_loader.loader import DatasetLoader + + +class VirtualDataFrame(DataFrame): + _metadata: ClassVar[list] = [ + "_loader", + "head", + "_head", + "name", + "description", + "schema", + "config", + "_agent", + "_column_hash", + ] + + def __init__(self, *args, **kwargs): + self._loader: DatasetLoader = kwargs.pop("data_loader", None) + if not self._loader: + raise Exception("Data loader is required for virtualization!") + self._head = None + super().__init__(self.get_head(), *args, **kwargs) + + def head(self): + if self._head is None: + self._head = self._loader.load_head() + + return self._head + + @property + def rows_count(self) -> int: + return self._loader.get_row_count() + + def execute_sql_query(self, query: str) -> pd.DataFrame: + return self._loader.execute_query(query) diff --git a/pandasai/ee/LICENSE b/pandasai/ee/LICENSE deleted file mode 100644 index 86060d530..000000000 --- a/pandasai/ee/LICENSE +++ /dev/null @@ -1,36 +0,0 @@ -The PandasAI Enterprise license (the “Enterprise License”) -Copyright (c) 2024 Sinaptik GmbH - -With regard to the PandasAI Software: - -This software and associated documentation files (the "Software") may only be -used in production, if you (and any entity that you represent) have agreed to, -and are in compliance with, the PandasAI Subscription Terms of Service, available -at https://pandas-ai.com/terms (the “Enterprise Terms”), or other -agreement governing the use of the Software, as agreed by you and PandasAI, -and otherwise have a valid PandasAI Enterprise license for the -correct number of user seats. Subject to the foregoing sentence, you are free to -modify this Software and publish patches to the Software. You agree that PandasAI -and/or its licensors (as applicable) retain all right, title and interest in and -to all such modifications and/or patches, and all such modifications and/or -patches may only be used, copied, modified, displayed, distributed, or otherwise -exploited with a valid PandasAI Enterprise license for the correct -number of user seats. Notwithstanding the foregoing, you may copy and modify -the Software for development and testing purposes, without requiring a -subscription. You agree that PandasAI and/or its licensors (as applicable) retain -all right, title and interest in and to all such modifications. You are not -granted any other rights beyond what is expressly stated herein. Subject to the -foregoing, it is forbidden to copy, merge, publish, distribute, sublicense, -and/or sell the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -For all third party components incorporated into the PandasAI Software, those -components are licensed under the original license provided by the owner of the -applicable component. diff --git a/pandasai/ee/agents/advanced_security_agent/__init__.py b/pandasai/ee/agents/advanced_security_agent/__init__.py deleted file mode 100644 index 165ce57ac..000000000 --- a/pandasai/ee/agents/advanced_security_agent/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Optional, Union - -from pandasai.agent.base_security import BaseSecurity -from pandasai.config import load_config_from_json -from pandasai.ee.agents.advanced_security_agent.pipeline.advanced_security_pipeline import ( - AdvancedSecurityPipeline, -) -from pandasai.pipelines.abstract_pipeline import AbstractPipeline -from pandasai.pipelines.pipeline_context import PipelineContext -from pandasai.schemas.df_config import Config - - -class AdvancedSecurityAgent(BaseSecurity): - def __init__( - self, - config: Optional[Union[Config, dict]] = None, - pipeline: AbstractPipeline = None, - ) -> None: - context = None - - if isinstance(config, dict): - config = Config(**load_config_from_json(config)) - elif config is None: - config = Config() - - context = PipelineContext(None, config) - - pipeline = pipeline or AdvancedSecurityPipeline(context=context) - super().__init__(pipeline) - - def evaluate(self, query: str) -> bool: - return self.pipeline.run(query) diff --git a/pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_pipeline.py b/pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_pipeline.py deleted file mode 100644 index 70e97772b..000000000 --- a/pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_pipeline.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Optional - -from pandasai.ee.agents.advanced_security_agent.pipeline.advanced_security_prompt_generation import ( - AdvancedSecurityPromptGeneration, -) -from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall -from pandasai.helpers.logger import Logger -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class AdvancedSecurityPipeline: - def __init__( - self, - context: Optional[PipelineContext] = None, - logger: Optional[Logger] = None, - ): - self.pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - AdvancedSecurityPromptGeneration(), - LLMCall(), - ], - ) - - def run(self, input: str): - return self.pipeline.run(input) diff --git a/pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_prompt_generation.py b/pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_prompt_generation.py deleted file mode 100644 index cb22bee34..000000000 --- a/pandasai/ee/agents/advanced_security_agent/pipeline/advanced_security_prompt_generation.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Any - -from pandasai.ee.agents.advanced_security_agent.prompts.advanced_security_agent_prompt import ( - AdvancedSecurityAgentPrompt, -) -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - - -class AdvancedSecurityPromptGeneration(BaseLogicUnit): - """ - Code Prompt Generation Stage - """ - - pass - - def execute(self, input_query: str, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Last logic unit output - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: LogicUnitOutput(prompt) - """ - self.context = kwargs.get("context") - self.logger: Logger = kwargs.get("logger") - - prompt = AdvancedSecurityAgentPrompt(query=input_query, context=self.context) - self.logger.log(f"Using prompt: {prompt}") - - return LogicUnitOutput( - prompt, - True, - "Prompt Generated Successfully", - {"content_type": "prompt", "value": prompt.to_string()}, - ) diff --git a/pandasai/ee/agents/advanced_security_agent/pipeline/llm_call.py b/pandasai/ee/agents/advanced_security_agent/pipeline/llm_call.py deleted file mode 100644 index 47758b263..000000000 --- a/pandasai/ee/agents/advanced_security_agent/pipeline/llm_call.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Any - -from pandasai.exceptions import InvalidOutputValueMismatch -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class LLMCall(BaseLogicUnit): - """ - LLM Code Generation Stage - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - pipeline_context: PipelineContext = kwargs.get("context") - logger: Logger = kwargs.get("logger") - - retry_count = 0 - while retry_count <= pipeline_context.config.max_retries: - response = pipeline_context.config.llm.call(input, pipeline_context) - - logger.log( - f"""LLM response: - {response} - """ - ) - try: - result = False - if "" in response: - result = True - elif "" in response: - result = False - else: - raise InvalidOutputValueMismatch("Invalid response of LLM Call") - - pipeline_context.add("llm_call", response) - - return LogicUnitOutput( - result, - True, - "Code Generated Successfully", - {"content_type": "string", "value": response}, - ) - except Exception: - if retry_count == pipeline_context.config.max_retries: - raise - - retry_count += 1 diff --git a/pandasai/ee/agents/advanced_security_agent/prompts/advanced_security_agent_prompt.py b/pandasai/ee/agents/advanced_security_agent/prompts/advanced_security_agent_prompt.py deleted file mode 100644 index f2079f677..000000000 --- a/pandasai/ee/agents/advanced_security_agent/prompts/advanced_security_agent_prompt.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path - -from jinja2 import Environment, FileSystemLoader - -from pandasai.prompts.base import BasePrompt - - -class AdvancedSecurityAgentPrompt(BasePrompt): - """Prompt to generate Python code from a dataframe.""" - - template_path = "advanced_security_agent_prompt.tmpl" - - def __init__(self, **kwargs): - """Initialize the prompt.""" - self.props = kwargs - - if self.template: - env = Environment() - self.prompt = env.from_string(self.template) - elif self.template_path: - # find path to template file - current_dir_path = Path(__file__).parent - - path_to_template = current_dir_path / "templates" - env = Environment(loader=FileSystemLoader(path_to_template)) - self.prompt = env.get_template(self.template_path) - - self._resolved_prompt = None - - def to_json(self): - context = self.props["context"] - memory = context.memory - conversations = memory.to_json() - system_prompt = memory.get_system_prompt() - return { - "conversation": conversations, - "system_prompt": system_prompt, - "prompt": self.to_string(), - } diff --git a/pandasai/ee/agents/advanced_security_agent/prompts/templates/advanced_security_agent_prompt.tmpl b/pandasai/ee/agents/advanced_security_agent/prompts/templates/advanced_security_agent_prompt.tmpl deleted file mode 100644 index c52bde0d3..000000000 --- a/pandasai/ee/agents/advanced_security_agent/prompts/templates/advanced_security_agent_prompt.tmpl +++ /dev/null @@ -1,20 +0,0 @@ -As an Security Agent, it's crucial to ensure that user queries do not generate malicious code that could harm systems or data. Analyze each statement and word thoroughly to check whether it can generate malicious code or not. -When analyzing a user query, follow these guidelines to identify potentially harmful code patterns: - -Code Injection: Look for attempts to inject code into a system, especially commands that interact with the file system, execute shell commands, or access sensitive data. User can never ask to append or execute any particular code. -File Operations: Be wary of commands that read from or write to the file system, especially when involving user-provided paths. Ensure that the code never updates any file. -Network Operations: Identify code that makes network requests or opens network connections. Verify that such operations are safe and necessary. -Data Manipulation: Ensure that code handling data manipulation does not include operations that could lead to data leaks, corruption, or unauthorized access. -Execution Control: Detect attempts to execute arbitrary code or scripts, particularly those that could alter system behavior or gain elevated privileges. -Third-Party Libraries: Verify the safety of using third-party libraries and ensure they are from reputable sources and up to date. -SQL Commands: Be cautious of SQL commands that can update or manipulate a database, such as INSERT, UPDATE, DELETE, DROP, ALTER, and TRUNCATE. Any query involving these commands should be flagged as potentially harmful. - -Given a user query, identify any suspicious or potentially harmful code patterns following the guidelines above. - -Your Task: -Analyze and reason the following user query strictly for potential malicious code can be generated patterns based on the guidelines provided. - -User Query: -{{query}} - -Always return or in tags <>, and provide a brief explanation if . \ No newline at end of file diff --git a/pandasai/ee/agents/judge_agent/__init__.py b/pandasai/ee/agents/judge_agent/__init__.py deleted file mode 100644 index a47d45045..000000000 --- a/pandasai/ee/agents/judge_agent/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Optional, Union - -from pandasai.agent.base_judge import BaseJudge -from pandasai.config import load_config_from_json -from pandasai.ee.agents.judge_agent.pipeline.judge_pipeline import JudgePipeline -from pandasai.pipelines.abstract_pipeline import AbstractPipeline -from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput -from pandasai.pipelines.pipeline_context import PipelineContext -from pandasai.schemas.df_config import Config - - -class JudgeAgent(BaseJudge): - def __init__( - self, - config: Optional[Union[Config, dict]] = None, - pipeline: AbstractPipeline = None, - ) -> None: - context = None - if config: - if isinstance(config, dict): - config = Config(**load_config_from_json(config)) - - context = PipelineContext(None, config) - - pipeline = pipeline or JudgePipeline(context=context) - super().__init__(pipeline) - - def evaluate(self, query: str, code: str) -> bool: - input_data = JudgePipelineInput(query, code) - return self.pipeline.run(input_data) diff --git a/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py b/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py deleted file mode 100644 index 0ec3ac165..000000000 --- a/pandasai/ee/agents/judge_agent/pipeline/judge_pipeline.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import ( - JudgePromptGeneration, -) -from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall -from pandasai.helpers.logger import Logger -from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class JudgePipeline: - def __init__( - self, - context: Optional[PipelineContext] = None, - logger: Optional[Logger] = None, - ): - self.pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - JudgePromptGeneration(), - LLMCall(), - ], - ) - - def run(self, input: JudgePipelineInput): - return self.pipeline.run(input) diff --git a/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py b/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py deleted file mode 100644 index a8ab9b565..000000000 --- a/pandasai/ee/agents/judge_agent/pipeline/judge_prompt_generation.py +++ /dev/null @@ -1,50 +0,0 @@ -import datetime -from typing import Any - -from pandasai.ee.agents.judge_agent.prompts.judge_agent_prompt import JudgeAgentPrompt -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput -from pandasai.pipelines.logic_unit_output import LogicUnitOutput - - -class JudgePromptGeneration(BaseLogicUnit): - """ - Code Prompt Generation Stage - """ - - pass - - def execute(self, input_data: JudgePipelineInput, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Last logic unit output - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: LogicUnitOutput(prompt) - """ - self.context = kwargs.get("context") - self.logger: Logger = kwargs.get("logger") - - now = datetime.datetime.now() - human_readable_datetime = now.strftime("%A, %B %d, %Y %I:%M %p") - - prompt = JudgeAgentPrompt( - query=input_data.query, - code=input_data.code, - context=self.context, - date=human_readable_datetime, - ) - self.logger.log(f"Using prompt: {prompt}") - - return LogicUnitOutput( - prompt, - True, - "Prompt Generated Successfully", - {"content_type": "prompt", "value": prompt.to_string()}, - ) diff --git a/pandasai/ee/agents/judge_agent/pipeline/llm_call.py b/pandasai/ee/agents/judge_agent/pipeline/llm_call.py deleted file mode 100644 index 47758b263..000000000 --- a/pandasai/ee/agents/judge_agent/pipeline/llm_call.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Any - -from pandasai.exceptions import InvalidOutputValueMismatch -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class LLMCall(BaseLogicUnit): - """ - LLM Code Generation Stage - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - pipeline_context: PipelineContext = kwargs.get("context") - logger: Logger = kwargs.get("logger") - - retry_count = 0 - while retry_count <= pipeline_context.config.max_retries: - response = pipeline_context.config.llm.call(input, pipeline_context) - - logger.log( - f"""LLM response: - {response} - """ - ) - try: - result = False - if "" in response: - result = True - elif "" in response: - result = False - else: - raise InvalidOutputValueMismatch("Invalid response of LLM Call") - - pipeline_context.add("llm_call", response) - - return LogicUnitOutput( - result, - True, - "Code Generated Successfully", - {"content_type": "string", "value": response}, - ) - except Exception: - if retry_count == pipeline_context.config.max_retries: - raise - - retry_count += 1 diff --git a/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py b/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py deleted file mode 100644 index 91616aaf8..000000000 --- a/pandasai/ee/agents/judge_agent/prompts/judge_agent_prompt.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path - -from jinja2 import Environment, FileSystemLoader - -from pandasai.prompts.base import BasePrompt - - -class JudgeAgentPrompt(BasePrompt): - """Prompt to generate Python code from a dataframe.""" - - template_path = "judge_agent_prompt.tmpl" - - def __init__(self, **kwargs): - """Initialize the prompt.""" - self.props = kwargs - - if self.template: - env = Environment() - self.prompt = env.from_string(self.template) - elif self.template_path: - # find path to template file - current_dir_path = Path(__file__).parent - - path_to_template = current_dir_path / "templates" - env = Environment(loader=FileSystemLoader(path_to_template)) - self.prompt = env.get_template(self.template_path) - - self._resolved_prompt = None - - def to_json(self): - context = self.props["context"] - memory = context.memory - conversations = memory.to_json() - system_prompt = memory.get_system_prompt() - return { - "conversation": conversations, - "system_prompt": system_prompt, - "prompt": self.to_string(), - } diff --git a/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl b/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl deleted file mode 100644 index 315f7057d..000000000 --- a/pandasai/ee/agents/judge_agent/prompts/templates/judge_agent_prompt.tmpl +++ /dev/null @@ -1,11 +0,0 @@ -Today is {{date}} -### QUERY -{{query}} -### GENERATED CODE -{{code}} - -Reason step by step and at the end answer: -1. Explain what the code does -2. Explain what the user query asks for -3. Strictly compare the query with the code that is generated -Always return or if exactly meets the requirements diff --git a/pandasai/ee/agents/semantic_agent/__init__.py b/pandasai/ee/agents/semantic_agent/__init__.py deleted file mode 100644 index 61c3e1efc..000000000 --- a/pandasai/ee/agents/semantic_agent/__init__.py +++ /dev/null @@ -1,215 +0,0 @@ -import json -from typing import List, Optional, Type, Union - -import pandas as pd - -from pandasai.agent.base import BaseAgent -from pandasai.agent.base_judge import BaseJudge -from pandasai.connectors.pandas import PandasConnector -from pandasai.constants import PANDASBI_SETUP_MESSAGE -from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator -from pandasai.ee.agents.semantic_agent.pipeline.semantic_chat_pipeline import ( - SemanticChatPipeline, -) -from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import ( - GenerateDFSchemaPrompt, -) -from pandasai.ee.helpers.json_helper import extract_json_from_json_str -from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson -from pandasai.helpers.cache import Cache -from pandasai.helpers.memory import Memory -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext -from pandasai.schemas.df_config import Config -from pandasai.vectorstores.vectorstore import VectorStore - - -class SemanticAgent(BaseAgent): - """ - Answer Semantic queries - """ - - def __init__( - self, - dfs: Union[pd.DataFrame, List[pd.DataFrame]], - config: Optional[Union[Config, dict]] = None, - schema: Optional[List[dict]] = None, - memory_size: Optional[int] = 10, - pipeline: Optional[Type[GenerateChatPipeline]] = None, - vectorstore: Optional[VectorStore] = None, - description: str = None, - judge: BaseJudge = None, - ): - super().__init__(dfs, config, memory_size, vectorstore, description) - - self._validate_config() - - self._schema_cache = Cache("schema") - self._schema = schema or [] - - if not self._schema: - self._create_schema() - - if self._schema: - self._sort_dfs_according_to_schema() - self.init_duckdb_instance() - - # semantic agent works only with direct sql true - self.config.direct_sql = True - - self.context = PipelineContext( - dfs=self.dfs, - config=self.config, - memory=Memory(memory_size, agent_info=description), - vectorstore=self._vectorstore, - initial_values={"df_schema": self._schema}, - ) - - self.pipeline = ( - pipeline( - self.context, - self.logger, - judge=judge, - on_prompt_generation=self._callbacks.on_prompt_generation, - on_code_generation=self._callbacks.on_code_generation, - before_code_execution=self._callbacks.before_code_execution, - on_result=self._callbacks.on_result, - ) - if pipeline - else SemanticChatPipeline( - self.context, - self.logger, - judge=judge, - on_prompt_generation=self._callbacks.on_prompt_generation, - on_code_generation=self._callbacks.on_code_generation, - before_code_execution=self._callbacks.before_code_execution, - on_result=self._callbacks.on_result, - ) - ) - - def validate_and_convert_json(self, jsons): - json_strs = [] - - try: - for json_data in jsons: - if isinstance(json_data, str): - json.loads(json_data) - json_strs.append(json_data) - elif isinstance(json_data, dict): - json_strs.append(json.dumps(json_data)) - - except Exception as e: - raise InvalidTrainJson("Error validating JSON string") from e - - return json_strs - - def train( - self, - queries: Optional[List[str]] = None, - jsons: Optional[List[Union[dict, str]]] = None, - docs: Optional[List[str]] = None, - ) -> None: - json_strs = self.validate_and_convert_json(jsons) if jsons else None - - super().train(queries=queries, codes=json_strs, docs=docs) - - def query(self, query): - query_pipeline = Pipeline( - context=self.context, - logger=self.logger, - steps=[ - CodeGenerator(), - ], - ) - code = query_pipeline.run(query) - - self.execute_code(code) - - def init_duckdb_instance(self): - for index, tables in enumerate(self._schema): - if isinstance(self.dfs[index], PandasConnector): - self._sync_pandas_dataframe_schema(self.dfs[index], tables) - self.dfs[index].enable_sql_query(tables["table"]) - - def _sync_pandas_dataframe_schema(self, df: PandasConnector, schema: dict): - for dimension in schema["dimensions"]: - if dimension["type"] in ["date", "datetime", "timestamp"]: - column = dimension["sql"] - df.pandas_df[column] = pd.to_datetime(df.pandas_df[column]) - - def _sort_dfs_according_to_schema(self): - if not self._schema: - return - - schema_dict = { - table["table"]: [dim["sql"] for dim in table["dimensions"]] - for table in self._schema - } - sorted_dfs = [] - - for table in self._schema: - matched = False - for df in self.dfs: - df_columns = df.head().columns - if all(column in df_columns for column in schema_dict[table["table"]]): - sorted_dfs.append(df) - matched = True - - if not matched: - raise InvalidSchemaJson( - f"Some sql column of table {table['table']} doesn't match with any dataframe" - ) - - self.dfs = sorted_dfs - - def _create_schema(self): - """ - Generate schema on the initialization of Agent class - """ - if self._schema: - self.logger.log(f"using user provided schema: {self._schema}") - return - - key = self._get_schema_cache_key() - if self.config.enable_cache: - value = self._schema_cache.get(key) - if value is not None: - self._schema = json.loads(value) - self.logger.log(f"using schema: {self._schema}") - return - - prompt = GenerateDFSchemaPrompt(context=self.context) - - result = self.call_llm_with_prompt(prompt) - self.logger.log( - f"""Initializing Schema: {result} - """ - ) - schema_str = result.replace("# SAMPLE SCHEMA", "") - schema_data = extract_json_from_json_str(schema_str) - if isinstance(schema_data, dict): - schema_data = [schema_data] - - self._schema = schema_data or [] - # save schema in the cache - if self.config.enable_cache and self._schema: - self._schema_cache.set(key, json.dumps(self._schema)) - - self.logger.log(f"using schema: {self._schema}") - - def _validate_config(self): - if not isinstance(self.config.llm, BambooLLM) and not isinstance( - self.config.llm, FakeLLM - ): - raise InvalidConfigError(PANDASBI_SETUP_MESSAGE) - - def _get_schema_cache_key(self): - """ - Get the cache key for the schema - """ - return "schema_" + "_".join( - [str(df.head().columns.tolist()) for df in self.dfs] - ) diff --git a/pandasai/ee/agents/semantic_agent/pipeline/Semantic_prompt_generation.py b/pandasai/ee/agents/semantic_agent/pipeline/Semantic_prompt_generation.py deleted file mode 100644 index 23a10c91b..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/Semantic_prompt_generation.py +++ /dev/null @@ -1,46 +0,0 @@ -import json -from typing import Any - -from pandasai.ee.agents.semantic_agent.prompts.semantic_agent_prompt import ( - SemanticAgentPrompt, -) -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class SemanticPromptGeneration(BaseLogicUnit): - """ - Code Prompt Generation Stage - """ - - pass - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Last logic unit output - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: LogicUnitOutput(prompt) - """ - self.context: PipelineContext = kwargs.get("context") - self.logger: Logger = kwargs.get("logger") - - prompt = SemanticAgentPrompt( - context=self.context, schema=json.dumps(self.context.get("df_schema")) - ) - self.logger.log(f"Using prompt: {prompt}") - - return LogicUnitOutput( - prompt, - True, - "Prompt Generated Successfully", - {"content_type": "prompt", "value": prompt.to_string()}, - ) diff --git a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py b/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py deleted file mode 100644 index 0b01e82fd..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/code_generator.py +++ /dev/null @@ -1,231 +0,0 @@ -import traceback -from typing import Any, Callable - -from pandasai.ee.helpers.query_builder import QueryBuilder -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class CodeGenerator(BaseLogicUnit): - """ - LLM Code Generation Stage - """ - - def __init__( - self, - on_code_generation: Callable[[str, Exception], None] = None, - on_failure=None, - **kwargs, - ): - super().__init__(**kwargs) - self.on_code_generation = on_code_generation - self.on_failure = on_failure - - def execute(self, input_data: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input_data: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - pipeline_context: PipelineContext = kwargs.get("context") - logger: Logger = kwargs.get("logger") - schema = pipeline_context.get("df_schema") - query_builder = QueryBuilder(schema) - - retry_count = 0 - while retry_count <= pipeline_context.config.max_retries: - try: - sql_query = query_builder.generate_sql(input_data) - - response_type = self._get_type(input_data) - - gen_code = self._generate_code(response_type, input_data) - - code = f""" -{"import matplotlib.pyplot as plt" if response_type == "plot" else ""} -import pandas as pd - -sql_query="{sql_query}" -data = execute_sql_query(sql_query) - -{gen_code} -""" - - logger.log(f"""Code Generated: {code}""") - - # Implement error handling pipeline here... - - return LogicUnitOutput( - code, - True, - "Code Generated Successfully", - {"content_type": "string", "value": code}, - ) - except Exception: - if ( - retry_count == pipeline_context.config.max_retries - or not self.on_failure - ): - raise - - traceback_errors = traceback.format_exc() - - input_data = self.on_failure(input_data, traceback_errors) - - retry_count += 1 - - def _get_type(self, input: dict) -> bool: - return ( - "plot" - if input["type"] - in ["bar", "line", "histogram", "pie", "scatter", "boxplot"] - else input["type"] - ) - - def _generate_code(self, type, query): - if type == "number": - code = self._generate_code_for_number(query) - return f""" -{code} -result = {{"type": "number","value": total_value}} -""" - elif type == "dataframe": - return """ -result = {"type": "dataframe","value": data} -""" - else: - code = self.generate_matplotlib_code(query) - code += """ - -result = {"type": "plot","value": "charts.png"}""" - return code - - def _generate_code_for_number(self, query: dict) -> str: - value = None - if len(query["measures"]) > 0: - value = query["measures"][0].split(".")[1] - else: - value = query["dimensions"][0].split(".")[1] - - return f'total_value = data["{value}"].sum()\n' - - def generate_matplotlib_code(self, query: dict) -> str: - chart_type = query["type"] - x_label = query.get("options", {}).get("xLabel", None) - y_label = query.get("options", {}).get("yLabel", None) - title = query["options"].get("title", None) - legend_display = {"display": True} - legend_position = "best" - if "legend" in query["options"]: - legend_display = query["options"]["legend"].get("display", None) - legend_position = query["options"]["legend"].get("position", None) - legend_position = ( - legend_position - in [ - "best", - "upper right", - "upper left", - "lower left", - "lower right", - "right", - "center left", - "center right", - "lower center", - "upper center", - "center", - ] - or "best" - ) - - code = "" - - code_generators = { - "bar": self._generate_bar_code, - "line": self._generate_line_code, - "pie": self._generate_pie_code, - "scatter": self._generate_scatter_code, - "hist": self._generate_hist_code, - "histogram": self._generate_hist_code, - "box": self._generate_box_code, - "boxplot": self._generate_box_code, - } - - code_generator = code_generators.get(chart_type, lambda query: "") - code += code_generator(query) - - if x_label: - code += f"plt.xlabel('''{x_label}''')\n" - if y_label: - code += f"plt.ylabel('''{y_label}''')\n" - if title: - code += f"plt.title('''{title}''')\n" - - if legend_display: - code += f"plt.legend(loc='{legend_position}')\n" - - code += """ - -plt.savefig("charts.png")""" - - return code - - def _generate_bar_code(self, query): - x_key = self._get_dimensions_key(query) - plots = "" - for measure in query["measures"]: - if isinstance(measure, str): - field_name = measure.split(".")[1] - label = field_name - else: - field_name = measure["id"].split(".")[1] - label = measure["label"] - - plots += ( - f"""plt.bar(data["{x_key}"], data["{field_name}"], label="{label}")\n""" - ) - - return plots - - def _generate_pie_code(self, query): - dimension = query["dimensions"][0].split(".")[1] - measure = query["measures"][0].split(".")[1] - return f"""plt.pie(data["{measure}"], labels=data["{dimension}"], autopct='%1.1f%%')\n""" - - def _generate_line_code(self, query): - x_key = self._get_dimensions_key(query) - plots = "" - for measure in query["measures"]: - field_name = measure.split(".")[1] - plots += f"""plt.plot(data["{x_key}"], data["{field_name}"])\n""" - - return plots - - def _generate_scatter_code(self, query): - x_key = query["dimensions"][0].split(".")[1] - y_key = query["dimensions"][1].split(".")[1] - return f"plt.scatter(data['{x_key}'], data['{y_key}'])\n" - - def _generate_hist_code(self, query): - y_key = query["measures"][0].split(".")[1] - return f"plt.hist(data['{y_key}'])\n" - - def _generate_box_code(self, query): - y_key = query["measures"][0].split(".")[1] - return f"plt.boxplot(data['{y_key}'])\n" - - def _get_dimensions_key(self, query): - if "dimensions" in query and len(query["dimensions"]) > 0: - return query["dimensions"][0].split(".")[1] - - time_dimension = query["timeDimensions"][0] - dimension = time_dimension["dimension"].split(".")[1] - return f"{dimension}_by_{time_dimension['granularity']}" diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py deleted file mode 100644 index 5aeee2ff7..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/error_correction_pipeline.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Optional - -from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator -from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import ( - FixSemanticJsonPipeline, -) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall -from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( - SemanticPromptGeneration, -) -from pandasai.helpers.logger import Logger -from pandasai.pipelines.chat.code_cleaning import CodeCleaning -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( - ErrorCorrectionPipelineInput, -) -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class ErrorCorrectionPipeline: - """ - Error Correction Pipeline to regenerate prompt and code - """ - - _context: PipelineContext - _logger: Logger - - def __init__( - self, - context: Optional[PipelineContext] = None, - logger: Optional[Logger] = None, - on_prompt_generation=None, - on_code_generation=None, - ): - self.pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - SemanticPromptGeneration( - on_execution=on_prompt_generation, - ), - LLMCall(), - CodeGenerator( - on_execution=on_code_generation, - on_failure=self.on_wrong_semantic_json, - ), - CodeCleaning(), - ], - ) - - self.fix_semantic_json_pipeline = FixSemanticJsonPipeline( - context=context, - logger=logger, - on_code_generation=on_code_generation, - on_prompt_generation=on_prompt_generation, - ) - - self._context = context - self._logger = logger - - def run(self, input: ErrorCorrectionPipelineInput): - self._logger.log(f"Executing Pipeline: {self.__class__.__name__}") - return self.pipeline.run(input) - - def on_wrong_semantic_json(self, code, errors): - correction_input = ErrorCorrectionPipelineInput(code, errors) - return self.fix_semantic_json_pipeline.run(correction_input) diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py deleted file mode 100644 index df074c8c7..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_json_pipeline.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Optional - -from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_schema_prompt import ( - FixSemanticSchemaPrompt, -) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall -from pandasai.helpers.logger import Logger -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( - ErrorCorrectionPipelineInput, -) -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class FixSemanticJsonPipeline: - """ - Error Correction Pipeline to regenerate prompt and code - """ - - _context: PipelineContext - _logger: Logger - - def __init__( - self, - context: Optional[PipelineContext] = None, - logger: Optional[Logger] = None, - on_prompt_generation=None, - on_code_generation=None, - ): - self.pipeline = Pipeline( - context=context, - logger=logger, - steps=[FixSemanticSchemaPrompt(), LLMCall()], - ) - - self._context = context - self._logger = logger - - def run(self, input: ErrorCorrectionPipelineInput): - self._logger.log(f"Executing Pipeline: {self.__class__.__name__}") - return self.pipeline.run(input) diff --git a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_schema_prompt.py b/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_schema_prompt.py deleted file mode 100644 index e7e5425d9..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/error_correction_pipeline/fix_semantic_schema_prompt.py +++ /dev/null @@ -1,61 +0,0 @@ -import json -from typing import Any, Callable - -from pandasai.ee.agents.semantic_agent.prompts.fix_semantic_json import ( - FixSemanticJsonPrompt, -) -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( - ErrorCorrectionPipelineInput, -) -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class FixSemanticSchemaPrompt(BaseLogicUnit): - on_prompt_generation: Callable[[str], None] - - def __init__( - self, - on_prompt_generation=None, - skip_if=None, - on_execution=None, - before_execution=None, - ): - self.on_prompt_generation = on_prompt_generation - super().__init__(skip_if, on_execution, before_execution) - - def execute(self, input: ErrorCorrectionPipelineInput, **kwargs) -> Any: - """ - A method to retry the code execution with error correction framework. - - Args: - code (str): A python code - context (PipelineContext) : Pipeline Context - logger (Logger) : Logger - e (Exception): An exception - dataframes - - Returns (str): A python code - """ - self.context: PipelineContext = kwargs.get("context") - self.logger: Logger = kwargs.get("logger") - - prompt = FixSemanticJsonPrompt( - context=self.context, - generated_json=input.code, - error=input.exception, - schema=json.dumps(self.context.get("df_schema")), - ) - self.logger.log(f"Using prompt: {prompt}") - - return LogicUnitOutput( - prompt, - True, - "Prompt Generated Successfully", - { - "content_type": "prompt", - "value": prompt.to_string(), - }, - ) diff --git a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py b/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py deleted file mode 100644 index af1bd2e18..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/llm_call.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Any - -from pandasai.ee.helpers.json_helper import extract_json_from_json_str -from pandasai.helpers.logger import Logger -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class LLMCall(BaseLogicUnit): - """ - LLM Code Generation Stage - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method will return output according to - Implementation. - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - pipeline_context: PipelineContext = kwargs.get("context") - logger: Logger = kwargs.get("logger") - - retry_count = 0 - while retry_count <= pipeline_context.config.max_retries: - response = pipeline_context.config.llm.call(input, pipeline_context) - - logger.log( - f"""LLM response: - {response} - """ - ) - try: - # Validate is valid Json - response_json = extract_json_from_json_str(response) - - pipeline_context.add("llm_call", response) - - return LogicUnitOutput( - response_json, - True, - "Code Generated Successfully", - {"content_type": "string", "value": response_json}, - ) - except Exception: - if retry_count == pipeline_context.config.max_retries: - raise - - retry_count += 1 diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py deleted file mode 100644 index 1f77c926a..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/semantic_chat_pipeline.py +++ /dev/null @@ -1,118 +0,0 @@ -from typing import Optional - -from pandasai.agent.base_judge import BaseJudge -from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator -from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.error_correction_pipeline import ( - ErrorCorrectionPipeline, -) -from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import ( - FixSemanticJsonPipeline, -) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall -from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( - SemanticPromptGeneration, -) -from pandasai.ee.agents.semantic_agent.pipeline.semantic_result_parsing import ( - SemanticResultParser, -) -from pandasai.ee.agents.semantic_agent.pipeline.validate_pipeline_input import ( - ValidatePipelineInput, -) -from pandasai.helpers.logger import Logger -from pandasai.pipelines.chat.cache_lookup import CacheLookup -from pandasai.pipelines.chat.code_cleaning import CodeCleaning -from pandasai.pipelines.chat.code_execution import CodeExecution -from pandasai.pipelines.chat.error_correction_pipeline.error_correction_pipeline_input import ( - ErrorCorrectionPipelineInput, -) -from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline -from pandasai.pipelines.chat.result_validation import ResultValidation -from pandasai.pipelines.pipeline import Pipeline -from pandasai.pipelines.pipeline_context import PipelineContext - - -class SemanticChatPipeline(GenerateChatPipeline): - code_generation_pipeline = Pipeline - code_execution_pipeline = Pipeline - context: PipelineContext - _logger: Logger - last_error: str - - def __init__( - self, - context: Optional[PipelineContext] = None, - logger: Optional[Logger] = None, - judge: BaseJudge = None, - on_prompt_generation=None, - on_code_generation=None, - before_code_execution=None, - on_result=None, - ): - super().__init__( - context, - logger, - judge=judge, - on_prompt_generation=on_prompt_generation, - on_code_generation=on_code_generation, - before_code_execution=before_code_execution, - on_result=on_result, - ) - - self.code_generation_pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - ValidatePipelineInput(), - CacheLookup(), - SemanticPromptGeneration( - skip_if=self.is_cached, - on_execution=on_prompt_generation, - ), - LLMCall(), - CodeGenerator( - on_execution=on_code_generation, - on_failure=self.on_wrong_semantic_json, - ), - CodeCleaning( - skip_if=self.no_code, - on_retry=self.on_code_retry, - ), - ], - ) - - self.code_execution_pipeline = Pipeline( - context=context, - logger=logger, - steps=[ - CodeExecution( - before_execution=before_code_execution, - on_retry=self.on_code_retry, - ), - ResultValidation(), - SemanticResultParser( - before_execution=on_result, - ), - ], - ) - - self.code_exec_error_pipeline = ErrorCorrectionPipeline( - context=context, - logger=logger, - on_code_generation=on_code_generation, - on_prompt_generation=on_prompt_generation, - ) - - self.fix_semantic_json_pipeline = FixSemanticJsonPipeline( - context=context, - logger=logger, - on_code_generation=on_code_generation, - on_prompt_generation=on_prompt_generation, - ) - - self.context = context - self._logger = logger - self.last_error = None - - def on_wrong_semantic_json(self, code, errors): - correction_input = ErrorCorrectionPipelineInput(code, errors) - return self.fix_semantic_json_pipeline.run(correction_input) diff --git a/pandasai/ee/agents/semantic_agent/pipeline/semantic_result_parsing.py b/pandasai/ee/agents/semantic_agent/pipeline/semantic_result_parsing.py deleted file mode 100644 index f897e81f2..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/semantic_result_parsing.py +++ /dev/null @@ -1,23 +0,0 @@ -from pandasai.pipelines.chat.result_parsing import ResultParsing -from pandasai.pipelines.pipeline_context import PipelineContext - - -class SemanticResultParser(ResultParsing): - """ - Semantic Agent Result Parsing Stage - """ - - pass - - def _add_result_to_memory(self, result: dict, context: PipelineContext): - """ - Add the result to the memory. - - Args: - result (dict): The result to add to the memory - context (PipelineContext) : Pipeline Context - """ - if result is None: - return - - context.memory.add(context.get("llm_call"), False) diff --git a/pandasai/ee/agents/semantic_agent/pipeline/validate_pipeline_input.py b/pandasai/ee/agents/semantic_agent/pipeline/validate_pipeline_input.py deleted file mode 100644 index 1c3f9ac64..000000000 --- a/pandasai/ee/agents/semantic_agent/pipeline/validate_pipeline_input.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Any, List - -from pandasai.connectors.base import BaseConnector -from pandasai.connectors.pandas import PandasConnector -from pandasai.constants import PANDASBI_SETUP_MESSAGE -from pandasai.exceptions import InvalidConfigError -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class ValidatePipelineInput(BaseLogicUnit): - """ - Validates pipeline input - """ - - pass - - def _validate_direct_sql(self, dfs: List[BaseConnector]) -> bool: - """ - Raises error if they don't belong to SQL connectors or have different credentials - Args: - dfs (List[BaseConnector]): list of BaseConnectors - - Raises: - InvalidConfigError: Raise Error in case of config is set but criteria is not met - """ - - if self.context.config.direct_sql: - if all( - ( - hasattr(df, "is_sql_connector") - and df.is_sql_connector - and df.equals(dfs[0]) - ) - for df in dfs - ) or all( - (isinstance(df, PandasConnector) and df.sql_enabled) for df in dfs - ): - return True - else: - raise InvalidConfigError( - "Direct SQL requires all connectors to be SQL connectors and they must belong to the same datasource " - "and have the same credentials" - ) - return False - - def execute(self, input: Any, **kwargs) -> Any: - """ - This method validates pipeline context and configs - - :param input: Your input data. - :param kwargs: A dictionary of keyword arguments. - - 'logger' (any): The logger for logging. - - 'config' (Config): Global configurations for the test - - 'context' (any): The execution context. - - :return: The result of the execution. - """ - self.context: PipelineContext = kwargs.get("context") - if not isinstance(self.context.config.llm, BambooLLM): - raise InvalidConfigError( - f"""Semantic Agent works only with BambooLLM follow instructions for setup:\n {PANDASBI_SETUP_MESSAGE}""" - ) - - self._validate_direct_sql(self.context.dfs) - - return LogicUnitOutput(input, True, "Input Validation Successful") diff --git a/pandasai/ee/agents/semantic_agent/prompts/fix_semantic_json.py b/pandasai/ee/agents/semantic_agent/prompts/fix_semantic_json.py deleted file mode 100644 index b027eb7f1..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/fix_semantic_json.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path - -from jinja2 import Environment, FileSystemLoader - -from pandasai.prompts.base import BasePrompt - - -class FixSemanticJsonPrompt(BasePrompt): - """Prompt to generate Python code from a dataframe.""" - - template_path = "fix_semantic_json_prompt.tmpl" - - def __init__(self, **kwargs): - """Initialize the prompt.""" - self.props = kwargs - - if self.template: - env = Environment() - self.prompt = env.from_string(self.template) - elif self.template_path: - # find path to template file - current_dir_path = Path(__file__).parent - - path_to_template = current_dir_path / "templates" - env = Environment(loader=FileSystemLoader(path_to_template)) - self.prompt = env.get_template(self.template_path) - - self._resolved_prompt = None - - def to_json(self): - context = self.props["context"] - memory = context.memory - conversations = memory.to_json() - system_prompt = memory.get_system_prompt() - return { - "conversation": conversations, - "system_prompt": system_prompt, - "prompt": self.to_string(), - } diff --git a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py b/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py deleted file mode 100644 index 28390f8b7..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/generate_df_schema.py +++ /dev/null @@ -1,60 +0,0 @@ -import json -from pathlib import Path - -from jinja2 import Environment, FileSystemLoader - -from pandasai.ee.helpers.json_helper import extract_json_from_json_str -from pandasai.prompts.base import BasePrompt - - -class GenerateDFSchemaPrompt(BasePrompt): - """Prompt to generate Python code with SQL from a dataframe.""" - - template_path = "generate_df_schema.tmpl" - - def __init__(self, **kwargs): - """Initialize the prompt.""" - self.props = kwargs - - if self.template: - env = Environment() - self.prompt = env.from_string(self.template) - elif self.template_path: - # find path to template file - current_dir_path = Path(__file__).parent - - path_to_template = current_dir_path / "templates" - env = Environment(loader=FileSystemLoader(path_to_template)) - self.prompt = env.get_template(self.template_path) - - self._resolved_prompt = None - - def validate(self, output: str) -> bool: - try: - json_data = extract_json_from_json_str( - output.replace("# SAMPLE SCHEMA", "") - ) - context = self.props["context"] - if isinstance(json_data, dict): - json_data = [json_data] - if isinstance(json_data, list): - for record in json_data: - if not all(key in record for key in ("name", "table")): - return False - - return len(context.dfs) == len(json_data) - - except json.JSONDecodeError: - pass - return False - - def to_json(self): - context = self.props["context"] - memory = context.memory - conversations = memory.to_json() - system_prompt = memory.get_system_prompt() - return { - "conversation": conversations, - "system_prompt": system_prompt, - "prompt": self.to_string(), - } diff --git a/pandasai/ee/agents/semantic_agent/prompts/semantic_agent_prompt.py b/pandasai/ee/agents/semantic_agent/prompts/semantic_agent_prompt.py deleted file mode 100644 index 1e0f7eff0..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/semantic_agent_prompt.py +++ /dev/null @@ -1,39 +0,0 @@ -from pathlib import Path - -from jinja2 import Environment, FileSystemLoader - -from pandasai.prompts.base import BasePrompt - - -class SemanticAgentPrompt(BasePrompt): - """Prompt to generate Python code from a dataframe.""" - - template_path = "semantic_agent_prompt.tmpl" - - def __init__(self, **kwargs): - """Initialize the prompt.""" - self.props = kwargs - - if self.template: - env = Environment() - self.prompt = env.from_string(self.template) - elif self.template_path: - # find path to template file - current_dir_path = Path(__file__).parent - - path_to_template = current_dir_path / "templates" - env = Environment(loader=FileSystemLoader(path_to_template)) - self.prompt = env.get_template(self.template_path) - - self._resolved_prompt = None - - def to_json(self): - context = self.props["context"] - memory = context.memory - conversations = memory.to_json() - system_prompt = memory.get_system_prompt() - return { - "conversation": conversations, - "system_prompt": system_prompt, - "prompt": self.to_string(), - } diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/fix_semantic_json_prompt.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/fix_semantic_json_prompt.tmpl deleted file mode 100644 index b973c53df..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/fix_semantic_json_prompt.tmpl +++ /dev/null @@ -1,13 +0,0 @@ -=== SemanticAgent === -The user asked the following question: -{{context.memory.get_conversation()}} -# SCHEMA -{{schema}} - -You generated this Json: -{{generated_json}} - -It fails with the following error: -{{error}} - -Understand the error in json return the fixed json \ No newline at end of file diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl deleted file mode 100644 index edec51e2d..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/generate_df_schema.tmpl +++ /dev/null @@ -1,153 +0,0 @@ -# SAMPLE SCHEMA -[ - { - "name": "Contracts", - "table": "contracts", - "measures": [ - { - "name": "contract_count", - "type": "count", - "sql": "store_id" - }, - { - "name": "contract_duration", - "type": "number", - "sql": "${contract_end_date} - ${contract_start_date}" - }, - { - "name": "contract_avg_duration", - "type": "avg", - "sql": "${contract_duration}" - } - ], - "dimensions": [ - { - "name": "contract_code", - "type": "string", - "sql": "contract_code", - "samples": ["C12345", "C67890"] - }, - { - "name": "store_id", - "type": "string", - "sql": "store_id", - "samples": ["S12345", "S67890"] - }, - { - "name": "tenant_code", - "type": "string", - "sql": "tenant_code", - "samples": ["T12345", "T67890"] - }, - { - "name": "tenant_name", - "type": "string", - "sql": "tenant_name", - "samples": ["Tenant A", "Tenant B"] - }, - { - "name": "store_brand", - "type": "string", - "sql": "store_brand", - "samples": ["Brand X", "Brand Y"] - }, - { - "name": "branch_segment_1", - "type": "string", - "sql": "branch_segment_1", - "samples": ["Segment 1", "Segment 2"] - }, - { - "name": "branch_segment_2", - "type": "string", - "sql": "branch_segment_2", - "samples": ["Segment A", "Segment B"] - }, - { - "name": "contract_start_date", - "type": "date", - "sql": "contract_start_date", - "samples": ["2023-01-01", "2023-02-01"] - }, - { - "name": "contract_end_date", - "type": "date", - "sql": "contract_end_date", - "samples": ["2024-01-01", "2024-02-01"] - } - ], - "joins": [ - { - "name": "Fee", - "join_type": "left", - "sql": "${Contracts.contract_code} = ${Fees.contract_id}" - } - ] - }, - { - "name": "Fees", - "table": "fees", - "measures": [ - { - "name": "total_taxable", - "type": "sum", - "sql": "imponibile_tot" - }, - { - "name": "total_revenue", - "type": "sum", - "sql": "totale_tot" - } - ], - "dimensions": [ - { - "name": "contract_id", - "type": "string", - "sql": "contract_id", - "samples": ["C12345", "C67890"] - }, - { - "name": "code", - "type": "string", - "sql": "code", - "samples": ["F12345", "F67890"] - }, - { - "name": "station", - "type": "string", - "sql": "station", - "samples": ["Station X", "Station Y"] - }, - { - "name": "tenant_id", - "type": "string", - "sql": "tenant_id", - "samples": ["T12345", "T67890"] - }, - { - "name": "day", - "type": "date", - "sql": "day", - "samples": ["2023-01-01", "2023-02-01"] - }, - { - "name": "store_id", - "type": "string", - "sql": "store_id", - "samples": ["S12345", "S67890"] - } - ], - "joins": [ - { - "name": "Contracts", - "join_type": "right", - "sql": "${Fees.contract_id} = ${Contracts.contract_code}" - } - ] - } -] - -# DATABASE -{% for df in context.dfs %}{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %} - -Take a deep breath and reason step by step. Create one json schema for these tables, similar to the sample provided. Also create joins, if any. diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl deleted file mode 100644 index 06fa68338..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/semantic_agent_prompt.tmpl +++ /dev/null @@ -1,6 +0,0 @@ -=== SemanticAgent === -{% include 'shared/vectordb_docs.tmpl' with context %} -# SCHEMA -{{schema}} - -{{ context.memory.get_last_message() }} \ No newline at end of file diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/shared/dataframe.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/shared/dataframe.tmpl deleted file mode 100644 index 1e9f9785e..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/shared/dataframe.tmpl +++ /dev/null @@ -1 +0,0 @@ -{{ df.to_string(index=index-1, serializer=context.config.dataframe_serializer, enforce_privacy=context.config.enforce_privacy) }} \ No newline at end of file diff --git a/pandasai/ee/agents/semantic_agent/prompts/templates/shared/vectordb_docs.tmpl b/pandasai/ee/agents/semantic_agent/prompts/templates/shared/vectordb_docs.tmpl deleted file mode 100644 index 0fe6be43a..000000000 --- a/pandasai/ee/agents/semantic_agent/prompts/templates/shared/vectordb_docs.tmpl +++ /dev/null @@ -1,8 +0,0 @@ -{% if context.vectorstore %}{% set documents = context.vectorstore.get_relevant_qa_documents(context.memory.get_last_message()) %} -{% if documents|length > 0%}You can utilize these examples as a reference for generating json.{% endif %} -{% for document in documents %} -{{ document}}{% endfor %}{% endif %} -{% if context.vectorstore %}{% set documents = context.vectorstore.get_relevant_docs_documents(context.memory.get_last_message()) %} -{% if documents|length > 0%}Here are additional documents for reference. Feel free to use them to answer.{% endif %} -{% for document in documents %}{{ document}} -{% endfor %}{% endif %} \ No newline at end of file diff --git a/pandasai/ee/connectors/relations.py b/pandasai/ee/connectors/relations.py deleted file mode 100644 index 8c91cc9b8..000000000 --- a/pandasai/ee/connectors/relations.py +++ /dev/null @@ -1,25 +0,0 @@ -from abc import abstractmethod - - -class AbstractRelation: - @abstractmethod - def to_string(self): - raise NotImplementedError - - -class PrimaryKey(AbstractRelation): - def __init__(self, name): - self.name = name - - def to_string(self): - return f"PRIMARY KEY ({self.name})" - - -class ForeignKey(AbstractRelation): - def __init__(self, field, foreign_table, foreign_table_field): - self.field = field - self.foreign_table_field = foreign_table_field - self.foreign_table = foreign_table - - def to_string(self): - return f"FOREIGN KEY ({self.field}) REFERENCES {self.foreign_table}({self.foreign_table_field})" diff --git a/pandasai/ee/helpers/json_helper.py b/pandasai/ee/helpers/json_helper.py deleted file mode 100644 index a7ca0bce2..000000000 --- a/pandasai/ee/helpers/json_helper.py +++ /dev/null @@ -1,14 +0,0 @@ -import json - - -def extract_json_from_json_str(json_str): - start_index = json_str.find("```json") - - end_index = json_str.find("```", start_index) - - if start_index == -1: - return json.loads(json_str) - - json_data = json_str[(start_index + len("```json")) : end_index].strip() - - return json.loads(json_data) diff --git a/pandasai/ee/helpers/query_builder.py b/pandasai/ee/helpers/query_builder.py deleted file mode 100644 index 2a613fa43..000000000 --- a/pandasai/ee/helpers/query_builder.py +++ /dev/null @@ -1,533 +0,0 @@ -import re - -from pandasai.exceptions import InvalidSchemaJson - -MISSING_TABLE_NAME_MESSAGE = "All measures, dimensions, timeDimensions, order and filters must have the format Table_Name.Dimension or Table_Name.Measure" -TABLE_NOT_FOUND_MESSAGE = "Table {0} Doesn't exist" - - -class QueryBuilder: - """ - Creates query from json structure - """ - - def __init__(self, schema): - self.schema = schema - self.supported_aggregations = {"sum", "count", "avg", "min", "max"} - self.supported_granularities = { - "year", - "month", - "day", - "hour", - "minute", - "second", - } - self.supported_date_ranges = { - "last week", - "last month", - "this month", - "this week", - "today", - "this year", - "last year", - } - - def generate_sql(self, query): - self._validate_query(query) - measures = query.get("measures", []) - dimensions = query.get("dimensions", []) - time_dimensions = query.get("timeDimensions", []) - filters = query.get("filters", []) - - columns = self._generate_columns(dimensions, time_dimensions, measures) - - referenced_tables = self._get_referenced_tables( - dimensions, time_dimensions, measures, filters - ) - main_table_entry = self._get_main_table_entry(measures, dimensions) - - if not main_table_entry: - raise ValueError("Table not found in schema.") - - sql = self._build_select_clause(columns) - sql += self._build_from_clause(main_table_entry) - sql += self._build_joins_clause(main_table_entry, referenced_tables) - sql += self._build_where_clause(filters, time_dimensions) - sql += self._build_group_by_clause(dimensions, time_dimensions) - sql += self._build_having_clause(filters) - sql += self._build_order_clause(query) - sql += self._build_limit_clause(query) - - return sql - - def _validate_table(self, value: str): - value_splitted = value.split(".") - if len(value_splitted) == 1: - raise InvalidSchemaJson(MISSING_TABLE_NAME_MESSAGE) - - table = self.find_table(value_splitted[0]) - if not table: - raise InvalidSchemaJson(TABLE_NOT_FOUND_MESSAGE.format(value_splitted[0])) - - def _validate_query(self, query: dict): - for measure in query.get("measures", []): - self._validate_table(measure) - - for dimension in query.get("dimensions", []): - self._validate_table(dimension) - - for dimension in query.get("timeDimensions", []): - self._validate_table(dimension["dimension"]) - - for order in query.get("order", []): - self._validate_table(order["id"]) - - for filter in query.get("filters", []): - self._validate_table(filter["member"]) - - def _generate_columns(self, dimensions, time_dimensions, measures): - all_dimensions = list(dict.fromkeys(dimensions)) - # + [td["dimension"] for td in time_dimensions] - columns = [] - - for dim in all_dimensions: - table = self.find_table(dim.split(".")[0])["table"] - dimension_info = self.find_dimension(dim) - sql_expr = dimension_info.get("sql") - name = dimension_info["name"] - if sql_expr: - columns.append(f"`{table}`.`{sql_expr}` AS {name}") - else: - columns.append(f"{name}") - - for measure in measures: - table = self.find_table(measure.split(".")[0])["table"] - measure_info = self.find_measure(measure) - if measure_info["type"] not in self.supported_aggregations: - raise ValueError( - f"Unsupported aggregation type '{measure_info['type']}' for measure '{measure_info['name']}'. Supported types are: {', '.join(self.supported_aggregations)}" - ) - sql_expr = measure_info.get("sql") or measure_info["name"] - columns.append( - f"{measure_info['type'].upper()}(`{table}`.`{sql_expr}`) AS {measure_info['name']}" - ) - - for time_dimension in time_dimensions: - columns.append(self._generate_time_dimension_column(time_dimension)) - - return list(dict.fromkeys(columns)) # preserve order and return unique columns - - def _validate_and_fix_mapped_measure(self, value): - value_splitted = value.split(".") - if len(value_splitted) == 1: - table_name = self._find_table_name_in_measure_if_not_exists( - value_splitted[0] - ) - if table_name is None: - raise ValueError( - "Measure must have table expected format is TableName.measure" - ) - return f"{table_name}.{value_splitted[0]}" - return value - - def _validate_and_fix_mapped_dimension(self, value): - value_splitted = value.split(".") - if len(value_splitted) == 1: - table_name = self._find_table_name_in_dimension_if_not_exists( - value_splitted[0] - ) - if table_name is None: - raise ValueError( - "Measure must have table expected format is TableName.measure" - ) - return f"{table_name}.{value_splitted[0]}" - return value - - def _validate_and_fix_mapped_order(self, value): - value_splitted = value.split(".") - if len(value_splitted) == 1: - table_name = self._find_table_name_in_orders_if_not_exists( - value_splitted[0] - ) - if table_name is None: - raise ValueError( - "Measure must have table expected format is TableName.measure" - ) - return f"{table_name}.{value_splitted[0]}" - return value - - def _find_table_name_in_filter_if_not_exists(self, filter_name: str): - """ - Find and add table name if not exists in Measure - """ - for table in self.schema: - for dimension in table["dimensions"]: - if dimension["name"] == filter_name: - return table["name"] - - return None - - def _find_table_name_in_measure_if_not_exists(self, measure_name: str): - """ - Find and add table name if not exists in Measure - """ - for table in self.schema: - for measure in table["measures"]: - if measure["name"] == measure_name: - return table["name"] - - return None - - def _find_table_name_in_dimension_if_not_exists(self, dimension_name: str): - """ - Find and add table name if not exists in Measure - """ - for table in self.schema: - for dimension in table["dimensions"]: - if dimension["name"] == dimension_name: - return table["name"] - - return None - - def _find_table_name_in_orders_if_not_exists(self, dimension_name: str): - """ - Find and add table name if not exists in Measure - """ - for table in self.schema: - for dimension in table["dimensions"]: - if dimension["name"] == dimension_name: - return table["name"] - - for measure in table["measures"]: - if measure["name"] == dimension_name: - return table["name"] - - return None - - def _generate_time_dimension_column(self, time_dimension): - dimension = time_dimension["dimension"] - granularity = ( - time_dimension["granularity"] if "granularity" in time_dimension else "day" - ) - - if granularity not in self.supported_granularities: - raise ValueError( - f"Unsupported granularity '{granularity}'. Supported granularities are: {', '.join(self.supported_granularities)}" - ) - - table = self.find_table(dimension.split(".")[0])["table"] - dimension_info = self.find_dimension(dimension) - sql_expr = f"`{table}`.`{dimension_info['sql']}`" - - granularity_sql = { - "year": f"YEAR({sql_expr})", - "month": f"DATE_FORMAT({sql_expr}, '%Y-%m')", - "day": f"DATE_FORMAT({sql_expr}, '%Y-%m-%d')", - "hour": f"HOUR({sql_expr})", - "minute": f"MINUTE({sql_expr})", - "second": f"SECOND({sql_expr})", - } - - if granularity not in granularity_sql: - raise ValueError(f"Unhandled granularity: {granularity}") - - return f"{granularity_sql[granularity]} AS {dimension_info['name']}_by_{granularity}" - - def _get_referenced_tables(self, dimensions, time_dimensions, measures, filters): - return ( - {measure.split(".")[0] for measure in measures} - | {dim.split(".")[0] for dim in dimensions} - | {td["dimension"].split(".")[0] for td in time_dimensions} - | {filter["member"].split(".")[0] for filter in filters} - ) - - def _get_main_table_entry(self, measures, dimensions): - main_table = ( - measures[0].split(".")[0] if measures else dimensions[0].split(".")[0] - ) - return next( - (table for table in self.schema if table["name"] == main_table), None - ) - - def _build_select_clause(self, columns): - return "SELECT " + ", ".join(columns) - - def _build_from_clause(self, main_table_entry): - return f" FROM `{main_table_entry['table']}`" - - def _build_joins_clause(self, main_table_entry, referenced_tables): - sql = "" - main_table = main_table_entry["name"] - - for table_name in referenced_tables: - if table_name != main_table: - table_entry = next( - (table for table in self.schema if table["name"] == table_name), - None, - ) - if not table_entry: - raise ValueError(f"Table '{table_name}' not found in schema.") - if "joins" in table_entry and ( - join := next( - ( - j - for j in table_entry["joins"] - if j["name"] in {main_table, table_name} - ), - None, - ) - ): - join_condition = self.resolve_template_literals(join["sql"]) - sql += f" {join['join_type'].upper()} JOIN `{table_entry['table']}` ON {join_condition}" - - return sql - - def _build_where_clause(self, filters, time_dimensions): - filter_statements = [ - self.process_filter(filter) - for filter in filters - if self.find_dimension(filter["member"]).get("name") is not None - ] - time_dimension_filters = [ - self.resolve_date_range(td) for td in time_dimensions if "dateRange" in td - ] - filter_statements.extend(time_dimension_filters) - - return f" WHERE {' AND '.join(filter_statements)}" if filter_statements else "" - - def _build_group_by_clause(self, dimensions, time_dimensions): - if not (time_dimensions or dimensions): - return "" - - group_by_dimensions = [ - self.find_dimension(dim)["name"] for dim in dimensions - ] + [ - f"{self.find_dimension(td['dimension'])['name']}_by_{td.get('granularity', 'day')}" - for td in time_dimensions - ] - - return " GROUP BY " + ", ".join(group_by_dimensions) - - def _build_having_clause(self, filters): - filter_statements = [ - self.process_filter(filter) - for filter in filters - if self.find_measure(filter["member"]).get("name") is not None - ] - - return f" HAVING {' AND '.join(filter_statements)}" if filter_statements else "" - - def _build_order_clause(self, query): - if "order" not in query or len(query["order"]) == 0: - return "" - - order_clauses = [] - for order in query["order"]: - name = None - if measure := self.find_measure(order["id"]): - name = measure["name"] - - if ( - name is None - and "timeDimensions" in query - and len(query["timeDimensions"]) > 0 - ): - for time_dimension in query["timeDimensions"]: - if ( - dimension - := f"{self.find_dimension(order['id'])['name']}_by_{time_dimension['granularity']}" - ): - name = dimension - - if name is None and "dimensions" in query and len(query["dimensions"]) > 0: - if dimension := self.find_dimension(order["id"]): - name = dimension["name"] - - if name is None: - name = ( - self.find_measure(order["id"]) or self.find_dimension(order["id"]) - )["name"] - - order_clauses.append(f"{name} {order['direction']}") - - return f" ORDER BY {', '.join(order_clauses)}" - - def _build_limit_clause(self, query): - return f" LIMIT {query['limit']}" if "limit" in query else "" - - def resolve_date_range(self, time_dimension): - dimension = time_dimension["dimension"] - date_range = time_dimension["dateRange"] - table_name = dimension.split(".")[0] - dimension_info = self.find_dimension(dimension) - table = self.find_table(table_name) - - if not table or not dimension_info: - raise ValueError(f"Dimension '{dimension}' not found in schema.") - - table_column = f"`{table['table']}`.`{dimension_info['sql']}`" - - if isinstance(date_range, list) and len(date_range) == 2: - start_date, end_date = date_range - return f"{table_column} BETWEEN '{start_date}' AND '{end_date}'" - else: - if isinstance(date_range, list) and len(date_range) == 1: - date_range = date_range[0] - - if date_range not in self.supported_date_ranges: - raise ValueError(f"Unsupported date range: {date_range}") - - if date_range == "last week": - return f"{table_column} >= CURRENT_DATE - INTERVAL '1 week' AND {table_column} < CURRENT_DATE" - elif date_range == "last month": - return f"{table_column} >= CURRENT_DATE - INTERVAL '1 month' AND {table_column} < CURRENT_DATE" - elif date_range == "this month": - return f"{table_column} >= DATE_TRUNC('month', CURRENT_DATE) AND {table_column} < DATE_TRUNC('month', CURRENT_DATE) + INTERVAL '1 month'" - elif date_range == "this week": - return f"{table_column} >= DATE_TRUNC('week', CURRENT_DATE) AND {table_column} < DATE_TRUNC('week', CURRENT_DATE) + INTERVAL '1 week'" - elif date_range == "today": - return f"{table_column} >= DATE_TRUNC('day', CURRENT_DATE) AND {table_column} < DATE_TRUNC('day', CURRENT_DATE) + INTERVAL '1 day'" - elif date_range == "this year": - return f"{table_column} >= DATE_TRUNC('year', CURRENT_DATE) AND {table_column} < DATE_TRUNC('year', CURRENT_DATE) + INTERVAL '1 year'" - elif date_range == "last year": - return f"{table_column} >= DATE_TRUNC('year', CURRENT_DATE - INTERVAL '1 year') AND {table_column} < DATE_TRUNC('year', CURRENT_DATE)" - - def process_filter(self, filter): - required_keys = ["member", "operator", "values"] - - # Check if any required key is missing or if "values" is empty - if any(key not in filter for key in required_keys) or ( - not filter.get("values") - and filter.get("operator", None) not in ["set", "notSet"] - ): - raise ValueError(f"Invalid filter: {filter}") - - table_name = filter["member"].split(".")[0] - dimension = self.find_dimension(filter["member"]) - measure = self.find_measure(filter["member"]) - - if dimension: - table_column = f"`{self.find_table(table_name)['table']}`.`{dimension.get('sql', dimension['name'])}`" - elif measure: - table_column = f"{measure['type'].upper()}(`{self.find_table(table_name)['table']}`.`{measure.get('sql', measure['name'])}`)" - else: - raise ValueError(f"Member '{filter['member']}' not found in schema.") - - operator = filter["operator"] - values = filter["values"] - - single_value_operators = { - "equals": "=", - "notEquals": "!=", - "contains": "LIKE", - "notContains": "NOT LIKE", - "startsWith": "LIKE", - "endsWith": "LIKE", - "gt": ">", - "gte": ">=", - "lt": "<", - "lte": "<=", - "beforeDate": "<", - "afterDate": ">", - "in": "IN", - } - - multi_value_operators = {"equals": "IN", "notEquals": "NOT IN"} - - return self._build_query_condition( - operator, - table_column, - values, - single_value_operators, - multi_value_operators, - ) - - def _build_query_condition( - self, - operator, - table_column, - values, - single_value_operators, - multi_value_operators, - ): - if operator in single_value_operators: - if operator in ["equals", "notEquals", "in"]: - if len(values) == 1: - operator_str = "=" if operator == "equals" else "!=" - return f"{table_column} {operator_str} '{values[0]}'" - else: - operator_str = "IN" if operator in ["equals", "in"] else "NOT IN" - formatted_values = "', '".join(values) - return f"{table_column} {operator_str} ('{formatted_values}')" - - elif operator in ["contains", "notContains", "startsWith", "endsWith"]: - pattern = { - "contains": f"%{values[0]}%", - "notContains": f"%{values[0]}%", - "startsWith": f"{values[0]}%", - "endsWith": f"%{values[0]}", - }[operator] - return f"{table_column} {single_value_operators[operator]} '{pattern}'" - - else: - value = f"'{values[0]}'" if isinstance(values[0], str) else values[0] - return f"{table_column} {single_value_operators[operator]} {value}" - - elif operator in multi_value_operators: - formatted_values = "', '".join(values) - return f"{table_column} {multi_value_operators[operator]} ('{formatted_values}')" - - elif operator == "set": - return f"{table_column} IS NOT NULL" - - elif operator == "notSet": - return f"{table_column} IS NULL" - - elif operator in ["inDateRange", "notInDateRange"]: - if len(values) != 2: - raise ValueError(f"Invalid number of values for '{operator}' operator.") - range_operator = "BETWEEN" if operator == "inDateRange" else "NOT BETWEEN" - return f"{table_column} {range_operator} '{values[0]}' AND '{values[1]}'" - - else: - raise ValueError(f"Unsupported operator: {operator}") - - def resolve_template_literals(self, template): - def replace_column(match): - table, column = match.group(1).split(".") - new_table = self.find_table(table) - if not new_table: - raise ValueError(f"Table '{table}' not found in schema.") - new_column = next( - (dim for dim in new_table["dimensions"] if dim["name"] == column), None - ) - if not new_column: - raise ValueError(f"Column '{column}' not found in schema.") - return f"`{new_table['table']}`.`{new_column['sql']}`" - - return re.sub(r"\$\{([^}]+)\}", replace_column, template) - - def find_table(self, table_name): - return next((table for table in self.schema if table["name"] == table_name), {}) - - def find_dimension(self, dimension): - table_name, dim_name = dimension.split(".") - table = self.find_table(table_name) - dim = next( - (dim for dim in table.get("dimensions", []) if dim.get("name") == dim_name), - {}, - ) - return dim - - def find_measure(self, measure): - table_name, measure_name = measure.split(".") - table = self.find_table(table_name) - meas = next( - ( - meas - for meas in table.get("measures", []) - if meas.get("name") == measure_name - ), - {}, - ) - return meas diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py index 67522a324..33525fd29 100644 --- a/pandasai/helpers/dataframe_serializer.py +++ b/pandasai/helpers/dataframe_serializer.py @@ -55,7 +55,7 @@ def convert_df_to_csv(self, df: pd.DataFrame, extras: dict) -> str: dataframe_info += ">" # Add dataframe details - dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.to_csv(index=False)}" + dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.head().to_csv(index=False)}" # Close the dataframe tag dataframe_info += "\n" diff --git a/pandasai/pipelines/chat/code_cleaning.py b/pandasai/pipelines/chat/code_cleaning.py index 57ad4b757..6d7f0c89c 100644 --- a/pandasai/pipelines/chat/code_cleaning.py +++ b/pandasai/pipelines/chat/code_cleaning.py @@ -1,10 +1,9 @@ -from __future__ import annotations import ast import copy import re import traceback import uuid -from typing import TYPE_CHECKING, Any, List, Union +from typing import Any, List, Union import astor from pandasai.helpers.optional import get_environment @@ -24,9 +23,6 @@ from ..logic_unit_output import LogicUnitOutput from ..pipeline_context import PipelineContext -if TYPE_CHECKING: - from pandasai.dataframe.base import DataFrame - class CodeExecutionContext: def __init__( @@ -235,45 +231,11 @@ def find_function_calls(self, node: ast.AST): def check_direct_sql_func_def_exists(self, node: ast.AST): return ( - self._validate_direct_sql(self._dfs) + self._config.direct_sql and isinstance(node, ast.FunctionDef) and node.name == "execute_sql_query" ) - def _validate_direct_sql(self, dfs: List[DataFrame]) -> bool: - """ - Raises error if they don't belong sqlconnector or have different credentials - Args: - dfs (List[DataFrame]): list of DataFrames - - Raises: - InvalidConfigError: Raise Error in case of config is set but criteria is not met - """ - - return self._config.direct_sql - # if self._config.direct_sql: - # return True - # else: - # return - # TODO - while working on direct sql - # if all( - # ( - # hasattr(df, "is_sql_connector") - # and df.is_sql_connector - # and df.equals(dfs[0]) - # ) - # for df in dfs - # ) or all( - # (isinstance(df, PandasConnector) and df.sql_enabled) for df in dfs - # ): - # return True - # else: - # raise InvalidConfigError( - # "Direct SQL requires all connectors to be SQL connectors and they must belong to the same datasource " - # "and have the same credentials" - # ) - # return False - def _replace_table_names( self, sql_query: str, table_names: list, allowed_table_names: list ): @@ -303,9 +265,10 @@ def _clean_sql_query(self, sql_query: str) -> str: """ sql_query = sql_query.rstrip(";") table_names = extract_table_names(sql_query) - allowed_table_names = {df.name: df.cs_table_name for df in self._dfs} | { - f'"{df.name}"': df.cs_table_name for df in self._dfs + allowed_table_names = {df.name: df.name for df in self._dfs} | { + f'"{df.name}"': df.name for df in self._dfs } + print(allowed_table_names) return self._replace_table_names(sql_query, table_names, allowed_table_names) def _validate_and_make_table_name_case_sensitive(self, node: ast.Assign): @@ -499,7 +462,7 @@ def _clean_code(self, code: str, context: CodeExecutionContext) -> str: # if generated code contain execute_sql_query usage if ( - self._validate_direct_sql(self._dfs) + self._config.direct_sql and "execute_sql_query" in self._function_call_visitor.function_calls ): execute_sql_query_used = True diff --git a/pandasai/pipelines/chat/code_execution.py b/pandasai/pipelines/chat/code_execution.py index a408137c9..f62cb9fa6 100644 --- a/pandasai/pipelines/chat/code_execution.py +++ b/pandasai/pipelines/chat/code_execution.py @@ -15,7 +15,6 @@ from ..base_logic_unit import BaseLogicUnit from ..pipeline_context import PipelineContext from .code_cleaning import CodeExecutionContext -import pandas as pd class CodeExecution(BaseLogicUnit): @@ -151,12 +150,12 @@ def execute_code(self, code: str, context: CodeExecutionContext) -> Any: # if the code does not need them dfs = self._required_dfs(code) environment: dict = get_environment(self._additional_dependencies) - environment["dfs"] = self._get_originals(dfs) + environment["dfs"] = dfs if len(environment["dfs"]) == 1: environment["df"] = environment["dfs"][0] if self._config.direct_sql: - environment["execute_sql_query"] = self._dfs[0].execute_direct_sql_query + environment["execute_sql_query"] = self._dfs[0].execute_sql_query # Execute the code exec(code, environment) @@ -193,31 +192,31 @@ def _required_dfs(self, code: str) -> List[str]: required_dfs.append(None) return required_dfs or self._dfs - def _get_originals(self, dfs): - """ - Get original dfs - - Args: - dfs (list): List of dfs - - Returns: - list: List of dfs - """ - original_dfs = [] - for df in dfs: - # TODO - Check why this None check is there - if df is None: - original_dfs.append(None) - continue - - if isinstance(df, pd.DataFrame): - original_dfs.append(df) - else: - # Execute to fetch only if not dataframe - df.execute() - original_dfs.append(df.pandas_df) - - return original_dfs + # def _get_originals(self, dfs): + # """ + # Get original dfs + + # Args: + # dfs (list): List of dfs + + # Returns: + # list: List of dfs + # """ + # original_dfs = [] + # for df in dfs: + # # TODO - Check why this None check is there + # if df is None: + # original_dfs.append(None) + # continue + + # if isinstance(df, pd.DataFrame): + # original_dfs.append(df) + # else: + # # Execute to fetch only if not dataframe + # df.execute() + # original_dfs.append(df.pandas_df) + + # return original_dfs def _retry_run_code( self, diff --git a/pandasai/pipelines/chat/validate_pipeline_input.py b/pandasai/pipelines/chat/validate_pipeline_input.py index 2868d62b6..bb197fba0 100644 --- a/pandasai/pipelines/chat/validate_pipeline_input.py +++ b/pandasai/pipelines/chat/validate_pipeline_input.py @@ -1,14 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, List -from pandasai.exceptions import InvalidConfigError +from typing import Any from pandasai.pipelines.logic_unit_output import LogicUnitOutput from ..base_logic_unit import BaseLogicUnit from ..pipeline_context import PipelineContext -if TYPE_CHECKING: - from pandasai.dataframe.base import DataFrame - class ValidatePipelineInput(BaseLogicUnit): """ @@ -17,30 +13,6 @@ class ValidatePipelineInput(BaseLogicUnit): pass - def _validate_direct_sql(self, dfs: List[DataFrame]) -> bool: - """ - Validates that all connectors are SQL connectors and belong to the same datasource - when direct_sql is True. - """ - - if not self.context.config.direct_sql: - return False - - if not all(hasattr(df, "is_sql_connector") for df in dfs): - raise InvalidConfigError( - "Direct SQL requires all connectors to be SQLConnectors" - ) - - if len(dfs) > 1: - first_connector = dfs[0] - if not all(connector.equals(first_connector) for connector in dfs[1:]): - raise InvalidConfigError( - "Direct SQL requires all connectors to belong to the same datasource " - "and have the same credentials" - ) - - return True - def execute(self, input: Any, **kwargs) -> Any: """ This method validates pipeline context and configs @@ -54,5 +26,4 @@ def execute(self, input: Any, **kwargs) -> Any: :return: The result of the execution. """ self.context: PipelineContext = kwargs.get("context") - self._validate_direct_sql(self.context.dfs) return LogicUnitOutput(input, True, "Input Validation Successful") diff --git a/tests/unit_tests/dataframe/test_loader.py b/tests/unit_tests/dataframe/test_loader.py index 94dd320eb..68800eeb8 100644 --- a/tests/unit_tests/dataframe/test_loader.py +++ b/tests/unit_tests/dataframe/test_loader.py @@ -2,7 +2,7 @@ from unittest.mock import patch, mock_open import pandas as pd from pandasai.dataframe.base import DataFrame -from pandasai.dataframe.loader import DatasetLoader +from pandasai.data_loader.loader import DatasetLoader from datetime import datetime, timedelta diff --git a/tests/unit_tests/dataframe/test_query_builder.py b/tests/unit_tests/dataframe/test_query_builder.py index 1db13df75..431d4a52a 100644 --- a/tests/unit_tests/dataframe/test_query_builder.py +++ b/tests/unit_tests/dataframe/test_query_builder.py @@ -1,5 +1,5 @@ import pytest -from pandasai.dataframe.query_builder import QueryBuilder +from pandasai.data_loader.query_builder import QueryBuilder class TestQueryBuilder: diff --git a/tests/unit_tests/ee/helpers/schema.py b/tests/unit_tests/ee/helpers/schema.py deleted file mode 100644 index f82ac7365..000000000 --- a/tests/unit_tests/ee/helpers/schema.py +++ /dev/null @@ -1,88 +0,0 @@ -VIZ_QUERY_SCHEMA = [ - { - "name": "Orders", - "table": "orders", - "measures": [ - {"name": "order_count", "type": "count"}, - {"name": "total_freight", "type": "sum", "sql": "freight"}, - ], - "dimensions": [ - {"name": "order_id", "type": "int", "sql": "order_id"}, - {"name": "customer_id", "type": "string", "sql": "customer_id"}, - {"name": "employee_id", "type": "int", "sql": "employee_id"}, - {"name": "order_date", "type": "date", "sql": "order_date"}, - {"name": "required_date", "type": "date", "sql": "required_date"}, - {"name": "shipped_date", "type": "date", "sql": "shipped_date"}, - {"name": "ship_via", "type": "int", "sql": "ship_via"}, - {"name": "ship_name", "type": "string", "sql": "ship_name"}, - {"name": "ship_address", "type": "string", "sql": "ship_address"}, - {"name": "ship_city", "type": "string", "sql": "ship_city"}, - {"name": "ship_region", "type": "string", "sql": "ship_region"}, - {"name": "ship_postal_code", "type": "string", "sql": "ship_postal_code"}, - {"name": "ship_country", "type": "string", "sql": "ship_country"}, - ], - "joins": [], - } -] - -VIZ_QUERY_SCHEMA_STR = '[{"name":"Orders","table":"orders","measures":[{"name":"order_count","type":"count"},{"name":"total_freight","type":"sum","sql":"freight"}],"dimensions":[{"name":"order_id","type":"int","sql":"order_id"},{"name":"customer_id","type":"string","sql":"customer_id"},{"name":"employee_id","type":"int","sql":"employee_id"},{"name":"order_date","type":"date","sql":"order_date"},{"name":"required_date","type":"date","sql":"required_date"},{"name":"shipped_date","type":"date","sql":"shipped_date"},{"name":"ship_via","type":"int","sql":"ship_via"},{"name":"ship_name","type":"string","sql":"ship_name"},{"name":"ship_address","type":"string","sql":"ship_address"},{"name":"ship_city","type":"string","sql":"ship_city"},{"name":"ship_region","type":"string","sql":"ship_region"},{"name":"ship_postal_code","type":"string","sql":"ship_postal_code"},{"name":"ship_country","type":"string","sql":"ship_country"}],"joins":[]}]' -VIZ_QUERY_SCHEMA_OBJ = '{"name":"Orders","table":"orders","measures":[{"name":"order_count","type":"count"},{"name":"total_freight","type":"sum","sql":"freight"}],"dimensions":[{"name":"order_id","type":"int","sql":"order_id"},{"name":"customer_id","type":"string","sql":"customer_id"},{"name":"employee_id","type":"int","sql":"employee_id"},{"name":"order_date","type":"date","sql":"order_date"},{"name":"required_date","type":"date","sql":"required_date"},{"name":"shipped_date","type":"date","sql":"shipped_date"},{"name":"ship_via","type":"int","sql":"ship_via"},{"name":"ship_name","type":"string","sql":"ship_name"},{"name":"ship_address","type":"string","sql":"ship_address"},{"name":"ship_city","type":"string","sql":"ship_city"},{"name":"ship_region","type":"string","sql":"ship_region"},{"name":"ship_postal_code","type":"string","sql":"ship_postal_code"},{"name":"ship_country","type":"string","sql":"ship_country"}],"joins":[]}' - - -STARS_SCHEMA = [ - { - "name": "Users", - "table": "users", - "measures": [{"name": "user_count", "type": "count", "sql": "login"}], - "dimensions": [ - {"name": "login", "type": "string", "sql": "login"}, - {"name": "starred_at", "type": "datetime", "sql": "starredAt"}, - {"name": "profile_url", "type": "string", "sql": "profileUrl"}, - {"name": "location", "type": "string", "sql": "location"}, - {"name": "company", "type": "string", "sql": "company"}, - ], - } -] - - -MULTI_JOIN_SCHEMA = [ - { - "name": "Sales", - "table": "sales", - "measures": [ - {"name": "total_revenue", "type": "sum", "sql": "revenue"}, - {"name": "total_sales", "type": "count", "sql": "id"}, - ], - "dimensions": [ - {"name": "product", "type": "string", "sql": "product"}, - {"name": "region", "type": "string", "sql": "region"}, - {"name": "sales_date", "type": "date", "sql": "sales_date"}, - {"name": "id", "type": "string", "sql": "id"}, - ], - "joins": [ - { - "name": "Engagement", - "join_type": "left", - "sql": "${Sales.id} = ${Engagement.id}", - } - ], - }, - { - "name": "Engagement", - "table": "engagement", - "measures": [{"name": "total_duration", "type": "sum", "sql": "duration"}], - "dimensions": [ - {"name": "id", "type": "string", "sql": "id"}, - {"name": "user_id", "type": "string", "sql": "user_id"}, - {"name": "activity_type", "type": "string", "sql": "activity_type"}, - {"name": "engagement_date", "type": "date", "sql": "engagement_date"}, - ], - "joins": [ - { - "name": "Sales", - "join_type": "right", - "sql": "${Engagement.id} = ${Sales.id}", - } - ], - }, -] diff --git a/tests/unit_tests/ee/helpers/test_semantic_agent_query_builder.py b/tests/unit_tests/ee/helpers/test_semantic_agent_query_builder.py deleted file mode 100644 index 70f4aa2e7..000000000 --- a/tests/unit_tests/ee/helpers/test_semantic_agent_query_builder.py +++ /dev/null @@ -1,230 +0,0 @@ -import unittest - -from pandasai.ee.helpers.query_builder import QueryBuilder -from tests.unit_tests.ee.helpers.schema import MULTI_JOIN_SCHEMA, VIZ_QUERY_SCHEMA - - -class TestSemanticAgentQueryBuilder(unittest.TestCase): - def test_constructor(self): - query_builder = QueryBuilder(VIZ_QUERY_SCHEMA) - assert query_builder.schema == VIZ_QUERY_SCHEMA - assert query_builder.supported_aggregations == { - "sum", - "count", - "avg", - "min", - "max", - } - assert query_builder.supported_granularities == { - "year", - "month", - "day", - "hour", - "minute", - "second", - } - assert query_builder.supported_date_ranges == { - "last week", - "last month", - "this month", - "this week", - "today", - "this year", - "last year", - } - - def test_sql_with_json(self): - query_builder = QueryBuilder(VIZ_QUERY_SCHEMA) - - json_str = { - "type": "bar", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.order_count"], - "timeDimensions": [], - "options": { - "xLabel": "Country", - "yLabel": "Number of Orders", - "title": "Orders Count by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [], - "order": [{"id": "Orders.order_count", "direction": "asc"}], - } - sql_query = query_builder.generate_sql(json_str) - assert sql_query in [ - "SELECT COUNT(`orders`.`order_count`) AS order_count, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country ORDER BY order_count asc", - "SELECT `orders`.`ship_country` AS ship_country, COUNT(`orders`.`order_count`) AS order_count FROM `orders` GROUP BY ship_country ORDER BY order_count asc", - ] - - def test_sql_with_filters_in_json(self): - query_builder = QueryBuilder(VIZ_QUERY_SCHEMA) - - json_str = { - "type": "bar", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.total_freight"], - "timeDimensions": [], - "options": { - "xLabel": "Country", - "yLabel": "Total Freight", - "title": "Total Freight by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [ - {"member": "Orders.total_freight", "operator": "gt", "values": [0]} - ], - "order": [{"id": "Orders.total_freight", "direction": "asc"}], - } - sql_query = query_builder.generate_sql(json_str) - assert sql_query in [ - "SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) > 0 ORDER BY total_freight asc", - "SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) > 0 ORDER BY total_freight asc", - ] - - def test_sql_with_filters_on_dimension(self): - query_builder = QueryBuilder(VIZ_QUERY_SCHEMA) - - json_str = { - "type": "bar", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.total_freight"], - "timeDimensions": [], - "options": { - "xLabel": "Country", - "yLabel": "Total Freight", - "title": "Total Freight by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [ - { - "member": "Orders.ship_country", - "operator": "equals", - "values": ["abc"], - } - ], - "order": [{"id": "Orders.total_freight", "direction": "asc"}], - } - sql_query = query_builder.generate_sql(json_str) - assert sql_query in [ - "SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` WHERE `orders`.`ship_country` = 'abc' GROUP BY ship_country ORDER BY total_freight asc", - "SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` WHERE `orders`.`ship_country` = 'abc' GROUP BY ship_country ORDER BY total_freight asc", - ] - - def test_sql_with_filters_without_order(self): - query_builder = QueryBuilder(VIZ_QUERY_SCHEMA) - - json_str = { - "type": "bar", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.total_freight"], - "timeDimensions": [], - "options": { - "xLabel": "Country", - "yLabel": "Total Freight", - "title": "Total Freight by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [ - { - "member": "Orders.ship_country", - "operator": "equals", - "values": ["abc"], - } - ], - } - sql_query = query_builder.generate_sql(json_str) - assert sql_query in [ - "SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` WHERE `orders`.`ship_country` = 'abc' GROUP BY ship_country", - "SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` WHERE `orders`.`ship_country` = 'abc' GROUP BY ship_country", - ] - - def test_sql_with_filters_with_notset_filter(self): - query_builder = QueryBuilder(VIZ_QUERY_SCHEMA) - - json_str = { - "type": "bar", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.total_freight"], - "timeDimensions": [], - "options": { - "xLabel": "Country", - "yLabel": "Total Freight", - "title": "Total Freight by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [ - {"member": "Orders.total_freight", "operator": "notSet", "values": []} - ], - "order": [{"id": "Orders.total_freight", "direction": "asc"}], - } - sql_query = query_builder.generate_sql(json_str) - assert sql_query in [ - "SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NULL ORDER BY total_freight asc", - "SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NULL ORDER BY total_freight asc", - ] - - def test_sql_with_filters_with_set_filter(self): - query_builder = QueryBuilder(VIZ_QUERY_SCHEMA) - - json_str = { - "type": "bar", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.total_freight"], - "timeDimensions": [], - "options": { - "xLabel": "Country", - "yLabel": "Total Freight", - "title": "Total Freight by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [ - { - "member": "Orders.total_freight", - "operator": "set", - "values": [], - } - ], - "order": [{"id": "Orders.total_freight", "direction": "asc"}], - } - sql_query = query_builder.generate_sql(json_str) - assert sql_query in [ - "SELECT SUM(`orders`.`freight`) AS total_freight, `orders`.`ship_country` AS ship_country FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc", - "SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country HAVING SUM(`orders`.`freight`) IS NOT NULL ORDER BY total_freight asc", - ] - - def test_sql_with_filters_with_join(self): - query_builder = QueryBuilder(MULTI_JOIN_SCHEMA) - - json_str = { - "type": "bar", - "dimensions": ["Engagement.activity_type"], - "measures": ["Sales.total_revenue"], - "timeDimensions": [], - "options": { - "xLabel": "Activity Type", - "yLabel": "Total Revenue", - "title": "Total Revenue Generated from Users who Logged in Before Purchase", - "legend": {"display": True, "position": "top"}, - }, - "joins": [ - { - "name": "Engagement", - "join_type": "right", - "sql": "${Sales.id} = ${Engagement.id}", - } - ], - "filters": [ - { - "member": "Engagement.engagement_date", - "operator": "beforeDate", - "values": ["${Sales.sales_date}"], - } - ], - "order": [{"id": "Sales.total_revenue", "direction": "asc"}], - } - sql_query = query_builder.generate_sql(json_str) - - assert ( - sql_query - == "SELECT `engagement`.`activity_type` AS activity_type, SUM(`sales`.`revenue`) AS total_revenue FROM `sales` RIGHT JOIN `engagement` ON `engagement`.`id` = `sales`.`id` WHERE `engagement`.`engagement_date` < '${Sales.sales_date}' GROUP BY activity_type ORDER BY total_revenue asc" - ) diff --git a/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py b/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py index dfebfd2dc..d33f94fd5 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py @@ -5,7 +5,6 @@ from typing import Optional from unittest.mock import MagicMock, patch -import pandas as pd import pytest from pandasai import Agent @@ -43,7 +42,7 @@ def llm(self, output: Optional[str] = None): @pytest.fixture def sample_df(self): - return pd.DataFrame( + return DataFrame( { "country": [ "United States", @@ -124,12 +123,12 @@ def exec_context(self) -> MagicMock: return CodeExecutionContext(uuid.uuid4()) @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) + @patch("extensions.connectors.sql.pandasai_sql", autospec=True) def sql_connector(self, create_engine): return DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) + @patch("extensions.connectors.sql.pandasai_sql", autospec=True) def pgsql_connector(self, create_engine): return DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) diff --git a/tests/unit_tests/pipelines/test_pipeline.py b/tests/unit_tests/pipelines/test_pipeline.py index d8a0fd488..536e72750 100644 --- a/tests/unit_tests/pipelines/test_pipeline.py +++ b/tests/unit_tests/pipelines/test_pipeline.py @@ -5,11 +5,9 @@ import pytest from pandasai.dataframe.base import DataFrame -from pandasai.ee.agents.judge_agent import JudgeAgent from pandasai.helpers.logger import Logger from pandasai.llm.fake import FakeLLM from pandasai.pipelines.base_logic_unit import BaseLogicUnit -from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline from pandasai.pipelines.pipeline import Pipeline from pandasai.pipelines.pipeline_context import PipelineContext from pandasai.schemas.df_config import Config @@ -163,15 +161,3 @@ def execute(self, data, logger, config, context): result = pipeline_2.run(5) assert result == 8 - - def test_pipeline_constructor_with_judge(self, context): - judge_agent = JudgeAgent() - pipeline = GenerateChatPipeline(context=context, judge=judge_agent) - assert pipeline.judge == judge_agent - assert isinstance(pipeline.context, PipelineContext) - - def test_pipeline_constructor_with_no_judge(self, context): - judge_agent = JudgeAgent() - pipeline = GenerateChatPipeline(context=context, judge=judge_agent) - assert pipeline.judge == judge_agent - assert isinstance(pipeline.context, PipelineContext)