From 237c67bfd2ada36f8720cb388fe77cd76eaa2ec7 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 19 Nov 2024 12:03:47 +0100 Subject: [PATCH] fix(sql): load and work with dataframe --- .../connectors/sql/pandasai_sql/__init__.py | 1 - extensions/connectors/sql/pandasai_sql/sql.py | 657 ------------------ pandasai/__init__.py | 4 - pandasai/agent/agent.py | 12 +- pandasai/connectors/__init__.py | 7 - pandasai/connectors/base.py | 315 --------- pandasai/connectors/pandas.py | 204 ------ pandasai/dataframe/loader.py | 7 +- pandasai/dataframe/query_builder.py | 5 +- pandasai/helpers/dataframe_serializer.py | 33 +- pandasai/pipelines/chat/code_cleaning.py | 55 +- .../pipelines/chat/validate_pipeline_input.py | 10 +- pandasai/pipelines/pipeline.py | 10 +- pandasai/smart_dataframe/__init__.py | 230 ------ pandasai/smart_datalake/__init__.py | 182 ----- tests/unit_tests/agent/test_base_agent.py | 5 +- tests/unit_tests/connectors/__init__.py | 0 tests/unit_tests/connectors/test_base.py | 93 --- tests/unit_tests/connectors/test_pandas.py | 75 -- .../ee/judge_agent/test_judge_agent.py | 229 ------ .../ee/judge_agent/test_judge_llm_call.py | 179 ----- .../ee/judge_agent/test_judge_prompt_gen.py | 178 ----- .../ee/security_agent/test_security_agent.py | 229 ------ .../security_agent/test_security_llm_call.py | 179 ----- .../test_security_prompt_gen.py | 179 ----- .../test__semantic_code_generator.py | 510 -------------- .../ee/semantic_agent/test_semantic_agent.py | 162 ----- .../semantic_agent/test_semantic_llm_call.py | 208 ------ .../test_semantic_semantic_prompt_gen.py | 163 ----- .../test_semantic_validate_pipeline_input.py | 221 ------ .../helpers/test_dataframe_serializer.py | 30 +- .../smart_datalake/test_code_cleaning.py | 79 +-- .../smart_datalake/test_code_generator.py | 4 +- .../test_error_prompt_generation.py | 4 +- .../smart_datalake/test_prompt_generation.py | 9 +- .../smart_datalake/test_result_parsing.py | 4 +- .../smart_datalake/test_result_validation.py | 4 +- .../test_validate_pipeline_input.py | 123 ++-- tests/unit_tests/pipelines/test_pipeline.py | 8 +- .../prompts/test_correct_error_prompt.py | 10 +- .../test_generate_python_code_prompt.py | 52 +- 41 files changed, 233 insertions(+), 4436 deletions(-) delete mode 100644 extensions/connectors/sql/pandasai_sql/sql.py delete mode 100644 pandasai/connectors/__init__.py delete mode 100644 pandasai/connectors/base.py delete mode 100644 pandasai/connectors/pandas.py delete mode 100644 pandasai/smart_dataframe/__init__.py delete mode 100644 pandasai/smart_datalake/__init__.py delete mode 100644 tests/unit_tests/connectors/__init__.py delete mode 100644 tests/unit_tests/connectors/test_base.py delete mode 100644 tests/unit_tests/connectors/test_pandas.py delete mode 100644 tests/unit_tests/ee/judge_agent/test_judge_agent.py delete mode 100644 tests/unit_tests/ee/judge_agent/test_judge_llm_call.py delete mode 100644 tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py delete mode 100644 tests/unit_tests/ee/security_agent/test_security_agent.py delete mode 100644 tests/unit_tests/ee/security_agent/test_security_llm_call.py delete mode 100644 tests/unit_tests/ee/security_agent/test_security_prompt_gen.py delete mode 100644 tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py delete mode 100644 tests/unit_tests/ee/semantic_agent/test_semantic_agent.py delete mode 100644 tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py delete mode 100644 tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py delete mode 100644 tests/unit_tests/ee/semantic_agent/test_semantic_validate_pipeline_input.py diff --git a/extensions/connectors/sql/pandasai_sql/__init__.py b/extensions/connectors/sql/pandasai_sql/__init__.py index 0ef962765..cc9f68f56 100644 --- a/extensions/connectors/sql/pandasai_sql/__init__.py +++ b/extensions/connectors/sql/pandasai_sql/__init__.py @@ -1,4 +1,3 @@ -from .sql import SQLConnector, SqliteConnector, SQLConnectorConfig import pandas as pd diff --git a/extensions/connectors/sql/pandasai_sql/sql.py b/extensions/connectors/sql/pandasai_sql/sql.py deleted file mode 100644 index 42c8e4935..000000000 --- a/extensions/connectors/sql/pandasai_sql/sql.py +++ /dev/null @@ -1,657 +0,0 @@ -""" -SQL connectors are used to connect to SQL databases in different dialects. -""" - -import hashlib -import os -import re -import time -from functools import cache, cached_property -from typing import Optional, Union - -import sqlglot -from sqlalchemy import asc, create_engine, select, text -from sqlalchemy.engine import Connection - -import pandas as pd -from pandasai.exceptions import MaliciousQueryError -from pandasai.helpers.path import find_project_root - -from pandasai.constants import DEFAULT_FILE_PERMISSIONS -from pandasai.connectors.base import BaseConnector, BaseConnectorConfig - - -class SQLBaseConnectorConfig(BaseConnectorConfig): - """ - Base Connector configuration. - """ - - driver: Optional[str] = None - dialect: Optional[str] = None - - -class SqliteConnectorConfig(SQLBaseConnectorConfig): - """ - Connector configurations for sqlite db. - """ - - table: str - database: str - - -class SQLConnectorConfig(SQLBaseConnectorConfig): - """ - Connector configuration. - """ - - host: str - port: int - username: str - password: str - - -class SQLConnector(BaseConnector): - """ - SQL connectors are used to connect to SQL databases in different dialects. - """ - - is_sql_connector = True - _engine = None - _connection: Connection = None - _rows_count: int = None - _columns_count: int = None - _cache_interval: int = 600 # 10 minutes - - def __init__( - self, - config: Union[BaseConnectorConfig, dict], - cache_interval: int = 600, - **kwargs, - ): - """ - Initialize the SQL connector with the given configuration. - - Args: - config (ConnectorConfig): The configuration for the SQL connector. - """ - config = self._load_connector_config(config) - super().__init__(config, **kwargs) - - if config.dialect is None: - raise Exception("SQL dialect must be specified") - - self._init_connection(config) - - self._cache_interval = cache_interval - - # Table to equal to table name for sql connectors - self.name = self.fallback_name - - def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): - """ - Loads passed Configuration to object - - Args: - config (BaseConnectorConfig): Construct config in structure - - Returns: - config: BaseConenctorConfig - """ - return SQLConnectorConfig(**config) - - def _init_connection(self, config: SQLConnectorConfig): - """ - Initialize Database Connection - - Args: - config (SQLConnectorConfig): Configurations to load database - - """ - - if config.driver: - self._engine = create_engine( - f"{config.dialect}+{config.driver}://{config.username}:{config.password}" - f"@{config.host}:{str(config.port)}/{config.database}", - connect_args=config.connect_args, - ) - else: - self._engine = create_engine( - f"{config.dialect}://{config.username}:{config.password}@{config.host}" - f":{str(config.port)}/{config.database}", - connect_args=config.connect_args, - ) - - self._connection = self._engine.connect() - - def __del__(self): - """ - Close the connection to the SQL database. - """ - if self._connection: - self._connection.close() - - def __repr__(self): - """ - Return the string representation of the SQL connector. - - Returns: - str: The string representation of the SQL connector. - """ - return ( - f"<{self.__class__.__name__} dialect={self.config.dialect} " - f"driver={self.config.driver} host={self.config.host} " - f"port={str(self.config.port)} database={self.config.database} " - f"table={self.config.table}>" - ) - - def _validate_column_name(self, column_name): - regex = r"^[a-zA-Z0-9_]+$" - if not re.match(regex, column_name): - raise ValueError(f"Invalid column name: {column_name}") - - def _build_query(self, limit=None, order=None): - base_query = select("*").select_from(text(self.cs_table_name)) - if self.config.where or self._additional_filters: - # conditions is the list of where + additional filters - conditions = [] - if self.config.where: - conditions += self.config.where - if self._additional_filters: - conditions += self._additional_filters - - query_params = {} - condition_strings = [] - - valid_operators = ["=", ">", "<", ">=", "<=", "LIKE", "!=", "IN", "NOT IN"] - - for i, condition in enumerate(conditions): - if len(condition) == 3: - column_name, operator, value = condition - if operator in valid_operators: - self._validate_column_name(column_name) - - condition_strings.append(f"{column_name} {operator} :value_{i}") - query_params[f"value_{i}"] = value - - if condition_strings: - where_clause = " AND ".join(condition_strings) - base_query = base_query.where( - text(where_clause).bindparams(**query_params) - ) - - if order: - base_query = base_query.order_by(asc(text(order))) - - if limit: - base_query = base_query.limit(limit) - - return base_query - - @cache - def head(self, n: int = 5) -> pd.DataFrame: - """ - Return the head of the data source that the connector is connected to. - This information is passed to the LLM to provide the schema of the data source. - - Returns: - DataFrame: The head of the data source. - """ - - if self.logger: - self.logger.log( - f"Getting head of {self.config.table} " - f"using dialect {self.config.dialect}" - ) - - # Run a SQL query to get all the columns names and 5 random rows - query = self._build_query(limit=n, order="RAND()") - - # Return the head of the data source - return pd.read_sql(query, self._connection) - - def _get_cache_path(self, include_additional_filters: bool = False): - """ - Return the path of the cache file. - - Args: - include_additional_filters (bool, optional): Whether to include the - additional filters in when calling `_get_column_hash()`. - Defaults to False. - - Returns: - str: The path of the cache file. - """ - try: - cache_dir = os.path.join((find_project_root()), "cache") - except ValueError: - cache_dir = os.path.join(os.getcwd(), "cache") - - os.makedirs(cache_dir, mode=DEFAULT_FILE_PERMISSIONS, exist_ok=True) - - filename = ( - self._get_column_hash(include_additional_filters=include_additional_filters) - + ".parquet" - ) - path = os.path.join(cache_dir, filename) - - return path - - def _cached(self, include_additional_filters: bool = False) -> Union[str, bool]: - """ - Return the cached data if it exists and is not older than the cache interval. - - Args: - include_additional_filters (bool, optional): Whether to include the - additional filters in when calling `_get_column_hash()`. - Defaults to False. - - Returns: - DataFrame|bool: The name of the file containing cached data if it exists - and is not older than the cache interval, False otherwise. - """ - filename = self._get_cache_path( - include_additional_filters=include_additional_filters - ) - if not os.path.exists(filename): - return False - - # If the file is older than 1 day, delete it - if os.path.getmtime(filename) < time.time() - self._cache_interval: - if self.logger: - self.logger.log(f"Deleting expired cached data from {filename}") - os.remove(filename) - return False - - if self.logger: - self.logger.log(f"Loading cached data from {filename}") - - return filename - - def _save_cache(self, df): - """ - Save the given DataFrame to the cache. - - Args: - df (DataFrame): The DataFrame to save to the cache. - """ - - filename = self._get_cache_path( - include_additional_filters=self._additional_filters is not None - and len(self._additional_filters) > 0 - ) - - df.to_csv(filename, index=False) - - def execute(self): - """ - Execute the SQL query and return the result. - - Returns: - DataFrame: The result of the SQL query. - """ - - if cached := self._cached() or self._cached(include_additional_filters=True): - return pd.read_csv(cached) - - if self.logger: - self.logger.log( - f"Loading the table {self.config.table} " - f"using dialect {self.config.dialect}" - ) - - # Run a SQL query to get all the results - query = self._build_query() - - # Get the result of the query - result = pd.read_sql(query, self._connection) - - # Save the result to the cache - self._save_cache(result) - - # Return the result - return result - - @cached_property - def rows_count(self): - """ - Return the number of rows in the SQL table. - - Returns: - int: The number of rows in the SQL table. - """ - - if self._rows_count is not None: - return self._rows_count - - if self.logger: - self.logger.log( - "Getting the number of rows in the table " - f"{self.config.table} using dialect " - f"{self.config.dialect}" - ) - - # Run a SQL query to get the number of rows - query = select(text("COUNT(*)")).select_from(text(self.cs_table_name)) - - # Return the number of rows - self._rows_count = self._connection.execute(query).fetchone()[0] - return self._rows_count - - @cached_property - def columns_count(self): - """ - Return the number of columns in the SQL table. - - Returns: - int: The number of columns in the SQL table. - """ - - if self._columns_count is not None: - return self._columns_count - - if self.logger: - self.logger.log( - "Getting the number of columns in the table " - f"{self.config.table} using dialect " - f"{self.config.dialect}" - ) - - self._columns_count = len(self.head().columns) - return self._columns_count - - def _get_column_hash(self, include_additional_filters: bool = False): - """ - Return the hash of the SQL table columns. - - Args: - include_additional_filters (bool, optional): Whether to include the - additional filters in the hash. Defaults to False. - - Returns: - str: The hash of the SQL table columns. - """ - - # Return the hash of the columns and the where clause - columns_str = "".join(self.head().columns) - if ( - self.config.where - or include_additional_filters - and self._additional_filters is not None - ): - columns_str += "WHERE" - if self.config.where: - # where clause is a list of lists - for condition in self.config.where: - columns_str += f"{condition[0]} {condition[1]} {condition[2]}" - if include_additional_filters and self._additional_filters: - for condition in self._additional_filters: - columns_str += f"{condition[0]} {condition[1]} {condition[2]}" - - hash_object = hashlib.sha256(columns_str.encode()) - return hash_object.hexdigest() - - @cached_property - def column_hash(self): - """ - Return the hash of the SQL table columns. - - Returns: - str: The hash of the SQL table columns. - """ - return self._get_column_hash() - - @property - def fallback_name(self): - return self.config.table - - @property - def pandas_df(self): - return self.execute() - - def equals(self, other): - if isinstance(other, self.__class__): - return ( - self.config.dialect, - self.config.driver, - self.config.host, - self.config.port, - ) == ( - other.config.dialect, - other.config.driver, - other.config.host, - other.config.port, - ) - return False - - def _is_sql_query_safe(self, query: str): - infected_keywords = [ - r"\bINSERT\b", - r"\bUPDATE\b", - r"\bDELETE\b", - r"\bDROP\b", - r"\bEXEC\b", - r"\bALTER\b", - r"\bCREATE\b", - ] - - return not any( - re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords - ) - - def execute_direct_sql_query(self, sql_query): - if not self._is_sql_query_safe(sql_query): - raise MaliciousQueryError("Malicious query is generated in code") - - return pd.read_sql(sql_query, self._connection) - - @property - def cs_table_name(self): - return self.config.table - - @property - def type(self): - return self.config.dialect - - -class SqliteConnector(SQLConnector): - """ - Sqlite connector are used to connect to Sqlite databases. - """ - - def __init__( - self, - config: Union[SqliteConnectorConfig, dict], - **kwargs, - ): - """ - Initialize the Sqlite connector with the given configuration. - - Args: - config (ConnectorConfig) : The configuration for the MySQL connector. - """ - config["dialect"] = "sqlite" - if isinstance(config, dict): - sqlite_env_vars = {"database": "SQLITE_DB_PATH", "table": "TABLENAME"} - config = self._populate_config_from_env(config, sqlite_env_vars) - - super().__init__(config, **kwargs) - - def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): - """ - Loads passed Configuration to object - - Args: - config (BaseConnectorConfig): Construct config in structure - - Returns: - config: BaseConenctorConfig - """ - return SqliteConnectorConfig(**config) - - def _init_connection(self, config: SqliteConnectorConfig): - """ - Initialize Database Connection - - Args: - config (SQLConnectorConfig): Configurations to load database - - """ - self._engine = create_engine(f"{config.dialect}:///{config.database}") - self._connection = self._engine.connect() - - def __del__(self): - """ - Close the connection to the SQL database. - """ - if self._connection: - self._connection.close() - - @cache - def head(self, n: int = 5) -> pd.DataFrame: - """ - Return the head of the data source that the connector is connected to. - This information is passed to the LLM to provide the schema of the data source. - - Returns: - DataFrame: The head of the data source. - """ - - if self.logger: - self.logger.log( - f"Getting head of {self.config.table} " - f"using dialect {self.config.dialect}" - ) - - # Run a SQL query to get all the columns names and 5 random rows - query = self._build_query(limit=n, order="RANDOM()") - - # Return the head of the data source - return pd.read_sql(query, self._connection) - - @property - def cs_table_name(self): - return f'"{self.config.table}"' - - def __repr__(self): - """ - Return the string representation of the SQL connector. - - Returns: - str: The string representation of the SQL connector. - """ - return ( - f"<{self.__class__.__name__} dialect={self.config.dialect} " - f"database={self.config.database} " - f"table={self.config.table}>" - ) - - def equals(self, other): - if isinstance(other, self.__class__): - print(self.config.database) - print(other.config.database) - return ( - self.config.dialect, - self.config.driver, - self.config.database, - ) == ( - other.config.dialect, - other.config.driver, - other.config.database, - ) - return False - - -class MySQLConnector(SQLConnector): - """ - MySQL connectors are used to connect to MySQL databases. - """ - - def __init__( - self, - config: Union[SQLConnectorConfig, dict], - **kwargs, - ): - """ - Initialize the MySQL connector with the given configuration. - - Args: - config (ConnectorConfig): The configuration for the MySQL connector. - """ - config["dialect"] = "mysql" - config["driver"] = "pymysql" - - if isinstance(config, dict): - mysql_env_vars = { - "host": "MYSQL_HOST", - "port": "MYSQL_PORT", - "database": "MYSQL_DATABASE", - "username": "MYSQL_USERNAME", - "password": "MYSQL_PASSWORD", - } - config = self._populate_config_from_env(config, mysql_env_vars) - - super().__init__(config, **kwargs) - - -class PostgreSQLConnector(SQLConnector): - """ - PostgreSQL connectors are used to connect to PostgreSQL databases. - """ - - def __init__( - self, - config: Union[SQLConnectorConfig, dict], - **kwargs, - ): - """ - Initialize the PostgreSQL connector with the given configuration. - - Args: - config (ConnectorConfig): The configuration for the PostgreSQL connector. - """ - if "dialect" not in config: - config["dialect"] = "postgresql" - - config["driver"] = "psycopg2" - - if isinstance(config, dict): - postgresql_env_vars = { - "host": "POSTGRESQL_HOST", - "port": "POSTGRESQL_PORT", - "database": "POSTGRESQL_DATABASE", - "username": "POSTGRESQL_USERNAME", - "password": "POSTGRESQL_PASSWORD", - } - config = self._populate_config_from_env(config, postgresql_env_vars) - - super().__init__(config, **kwargs) - - @cache - def head(self, n: int = 5) -> pd.DataFrame: - """ - Return the head of the data source that the connector is connected to. - This information is passed to the LLM to provide the schema of the data source. - - Returns: - DataFrame: The head of the data source. - """ - - if self.logger: - self.logger.log( - f"Getting head of {self.config.table} " - f"using dialect {self.config.dialect}" - ) - - # Run a SQL query to get all the columns names and 5 random rows - query = self._build_query(limit=n, order="RANDOM()") - - # Return the head of the data source - return pd.read_sql(query, self._connection) - - @property - def cs_table_name(self): - return f'"{self.config.table}"' - - def execute_direct_sql_query(self, sql_query): - sql_query = sqlglot.transpile(sql_query, read="mysql", write="postgres")[0] - return super().execute_direct_sql_query(sql_query) diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 4d673a30b..69a65c5b5 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -4,8 +4,6 @@ """ from typing import List -from pandasai.smart_dataframe import SmartDataframe -from pandasai.smart_datalake import SmartDatalake from .agent import Agent from .helpers.cache import Cache from .dataframe.base import DataFrame @@ -81,8 +79,6 @@ def load(dataset_path: str) -> DataFrame: "Agent", "clear_cache", "pandas", - "SmartDataframe", - "SmartDatalake", "DataFrame", "chat", "follow_up", diff --git a/pandasai/agent/agent.py b/pandasai/agent/agent.py index 280c139fe..8be3bb428 100644 --- a/pandasai/agent/agent.py +++ b/pandasai/agent/agent.py @@ -1,22 +1,22 @@ -from typing import List, Optional, Type, Union +from __future__ import annotations +from typing import TYPE_CHECKING, List, Optional, Type, Union -import pandas as pd from pandasai.agent.base import BaseAgent from pandasai.agent.base_judge import BaseJudge from pandasai.agent.base_security import BaseSecurity -from pandasai.connectors.base import BaseConnector from pandasai.pipelines.chat.generate_chat_pipeline import GenerateChatPipeline from pandasai.schemas.df_config import Config from pandasai.vectorstores.vectorstore import VectorStore +if TYPE_CHECKING: + from pandasai.dataframe import DataFrame + class Agent(BaseAgent): def __init__( self, - dfs: Union[ - pd.DataFrame, BaseConnector, List[Union[pd.DataFrame, BaseConnector]] - ], + dfs: Union[DataFrame, List[DataFrame]], config: Optional[Union[Config, dict]] = None, memory_size: Optional[int] = 10, pipeline: Optional[Type[GenerateChatPipeline]] = None, diff --git a/pandasai/connectors/__init__.py b/pandasai/connectors/__init__.py deleted file mode 100644 index 3ad976db5..000000000 --- a/pandasai/connectors/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base import BaseConnector -from .pandas import PandasConnector - -__all__ = [ - "BaseConnector", - "PandasConnector", -] diff --git a/pandasai/connectors/base.py b/pandasai/connectors/base.py deleted file mode 100644 index 1196b25fd..000000000 --- a/pandasai/connectors/base.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Base connector class to be extended by all connectors. -""" - -import json -import os -from abc import ABC, abstractmethod -from functools import cache -from typing import TYPE_CHECKING, List, Optional, Union - -import pandas as pd -from pandasai.helpers.dataframe_serializer import ( - DataframeSerializer, - DataframeSerializerType, -) -from pydantic import BaseModel - -from ..helpers.logger import Logger - -if TYPE_CHECKING: - from pandasai.ee.connectors.relations import AbstractRelation - - -class BaseConnectorConfig(BaseModel): - """ - Base Connector configuration. - """ - - database: str - table: str - where: list[list[str]] = [] - connect_args: Optional[dict] = {} - - -class BaseConnector(ABC): - """ - Base connector class to be extended by all connectors. - """ - - _logger: Logger = None - _additional_filters: list[list[str]] = None - - def __init__( - self, - config: Union[BaseConnectorConfig, dict], - name: str = None, - description: str = None, - custom_head: pd.DataFrame = None, - field_descriptions: dict = None, - connector_relations: List["AbstractRelation"] = None, - ): - """ - Initialize the connector with the given configuration. - - Args: - config (dict): The configuration for the connector. - """ - if isinstance(config, dict): - config = self._load_connector_config(config) - - self.config = config - self.name = name - self.description = description - self.custom_head = custom_head - self.field_descriptions = field_descriptions - self.connector_relations = connector_relations - - def _load_connector_config(self, config: Union[BaseConnectorConfig, dict]): - """Loads passed Configuration to object - - Args: - config (BaseConnectorConfig): Construct config in structure - - Returns: - config: BaseConnectorConfig - """ - pass - - def _populate_config_from_env(self, config: dict, envs_mapping: dict): - """ - Populate the configuration dictionary with values from environment variables - if not exists in the config. - - Args: - config (dict): The configuration dictionary to be populated. - envs_mapping (dict): The dictionary representing a map of config's keys - and according names of the environment variables. - - Returns: - dict: The populated configuration dictionary. - """ - - for key, env_var in envs_mapping.items(): - if key not in config and os.getenv(env_var): - config[key] = os.getenv(env_var) - - return config - - def _init_connection(self, config: BaseConnectorConfig): - """ - make connection to database - """ - pass - - @abstractmethod - def head(self, n: int = 3) -> pd.DataFrame: - """ - Return the head of the data source that the connector is connected to. - This information is passed to the LLM to provide the schema of the - data source. - """ - pass - - @abstractmethod - def execute(self) -> pd.DataFrame: - """ - Execute the given query on the data source that the connector is - connected to. - """ - pass - - def set_additional_filters(self, filters: dict): - """ - Add additional filters to the connector. - - Args: - filters (dict): The additional filters to add to the connector. - """ - self._additional_filters = filters or [] - - @property - def rows_count(self): - """ - Return the number of rows in the data source that the connector is - connected to. - """ - raise NotImplementedError - - @property - def columns_count(self): - """ - Return the number of columns in the data source that the connector is - connected to. - """ - raise NotImplementedError - - @property - def column_hash(self): - """ - Return the hash code that is unique to the columns of the data source - that the connector is connected to. - """ - raise NotImplementedError - - @property - def path(self): - """ - Return the path of the data source that the connector is connected to. - """ - # JDBC string - path = f"{self.__class__.__name__}://{self.config.host}:" - if hasattr(self.config, "port"): - path += str(self.config.port) - path += f"/{self.config.database}/{self.config.table}" - return path - - @property - def logger(self): - """ - Return the logger for the connector. - """ - return self._logger - - @logger.setter - def logger(self, logger: Logger): - """ - Set the logger for the connector. - - Args: - logger (Logger): The logger for the connector. - """ - self._logger = logger - - @property - def fallback_name(self): - """ - Return the name of the table that the connector is connected to. - """ - raise NotImplementedError - - @property - def pandas_df(self): - """ - Returns the pandas dataframe - """ - raise NotImplementedError - - @property - def type(self): - return "pd.DataFrame" - - def equals(self, other): - return self.__dict__ == other.__dict__ - - @cache - def get_head(self, n: int = 3) -> pd.DataFrame: - """ - Return the head of the data source that the connector is connected to. - This information is passed to the LLM to provide the schema of the - data source. - - Args: - n (int, optional): The number of rows to return. Defaults to 5. - - Returns: - pd.DataFrame: The head of the data source that the connector is - connected to. - """ - return self.custom_head if self.custom_head is not None else self.head(n) - - def head_with_truncate_columns(self, max_size=25) -> pd.DataFrame: - """ - Truncate the columns of the dataframe to a maximum of 20 characters. - - Args: - df (pd.DataFrame): The dataframe to truncate the columns of. - - Returns: - pd.DataFrame: The dataframe with truncated columns. - """ - df_trunc = self.get_head().copy() - - for col in df_trunc.columns: - if df_trunc[col].dtype == "object": - first_val = df_trunc[col].iloc[0] - if isinstance(first_val, str) and len(first_val) > max_size: - df_trunc[col] = f"{df_trunc[col].str.slice(0, max_size - 3)}..." - - return df_trunc - - @cache - def get_schema(self) -> pd.DataFrame: - """ - A sample of the dataframe. - - Returns: - pd.DataFrame: A sample of the dataframe. - """ - if self.get_head() is None: - return None - - if len(self.get_head()) > 0: - return self.head_with_truncate_columns() - - return self.get_head() - - @cache - def to_csv(self) -> str: - """ - A proxy-call to the dataframe's `.to_csv()`. - - Returns: - str: The dataframe as a CSV string. - """ - return self.get_head().to_csv(index=False) - - @cache - def to_string( - self, - index: int = 0, - is_direct_sql: bool = False, - serializer: DataframeSerializerType = None, - enforce_privacy: bool = False, - ) -> str: - """ - Convert dataframe to string - Returns: - str: dataframe string - """ - # If field descriptions are added always use YML. Other formats don't support field descriptions yet - if self.field_descriptions or self.connector_relations: - serializer = DataframeSerializerType.YML - - return DataframeSerializer().serialize( - self, - extras={ - "index": index, - "type": "pd.DataFrame", - "is_direct_sql": is_direct_sql, - "enforce_privacy": enforce_privacy, - }, - type_=serializer, - ) - - @cache - def to_json(self): - df_head = self.get_head() - - return { - "name": self.name, - "description": self.description, - "head": json.loads(df_head.to_json(orient="records", date_format="iso")), - } - - def serialize_dataframe( - self, - index: int, - is_direct_sql: bool, - serializer_type: DataframeSerializerType, - enforce_privacy: bool, - ) -> str: - """ - Serialize DataFrame to string representation. - """ - return self.to_string(index, is_direct_sql, serializer_type, enforce_privacy) diff --git a/pandasai/connectors/pandas.py b/pandasai/connectors/pandas.py deleted file mode 100644 index ce46ede6a..000000000 --- a/pandasai/connectors/pandas.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Pandas connector class to handle csv, parquet, xlsx files and pandas dataframes. -""" - -import hashlib -from functools import cache, cached_property -from typing import Union - -try: - import duckdb -except ImportError: - duckdb = None -import sqlglot -from pydantic import BaseModel - -import pandas as pd -from pandasai.exceptions import PandasConnectorTableNotFound - -from ..helpers.data_sampler import DataSampler -from ..helpers.file_importer import FileImporter -from ..helpers.logger import Logger -from .base import BaseConnector - - -class PandasConnectorConfig(BaseModel): - """ - Pandas Connector configuration. - """ - - original_df: Union[pd.DataFrame, pd.Series, str, list, dict] - - class Config: - arbitrary_types_allowed = True - - -class PandasConnector(BaseConnector): - """ - Pandas connector class to handle csv, parquet, xlsx files and pandas dataframes. - """ - - pandas_df = pd.DataFrame - _logger: Logger = None - _additional_filters: list[list[str]] = None - - def __init__( - self, - config: Union[PandasConnectorConfig, dict], - **kwargs, - ): - """ - Initialize the Pandas connector with the given configuration. - - Args: - config (PandasConnectorConfig): The configuration for the Pandas connector. - """ - super().__init__(config, **kwargs) - - self._load_df(self.config.original_df) - self.sql_enabled = False - - def _load_df(self, df: Union[pd.DataFrame, pd.Series, str, list, dict]): - """ - Load the dataframe from a file or pandas dataframe. - - Args: - df (Union[pd.DataFrame, pd.Series, str, list, dict]): The dataframe to load. - """ - if isinstance(df, pd.Series): - self.pandas_df = df.to_frame() - elif isinstance(df, pd.DataFrame): - self.pandas_df = df - elif isinstance(df, (list, dict)): - try: - self.pandas_df = pd.DataFrame(df) - except Exception as e: - raise ValueError( - "Invalid input data. We cannot convert it to a dataframe." - ) from e - elif isinstance(df, str): - self.pandas_df = FileImporter.import_from_file(df) - else: - raise ValueError("Invalid input data. We cannot convert it to a dataframe.") - - def _load_connector_config( - self, config: Union[PandasConnectorConfig, dict] - ) -> PandasConnectorConfig: - """ - Loads passed Configuration to object - - Args: - config (PandasConnectorConfig): Construct config in structure - - Returns: - config: PandasConnectorConfig - """ - return PandasConnectorConfig(**config) - - @cache - def head(self, n: int = 5) -> pd.DataFrame: - """ - Return the head of the data source that the connector is connected to. - This information is passed to the LLM to provide the schema of the - data source. - """ - sampler = DataSampler(self.pandas_df) - return sampler.sample(n) - - @cache - def execute(self) -> pd.DataFrame: - """ - Execute the given query on the data source that the connector is - connected to. - """ - return self.pandas_df - - @cached_property - def rows_count(self): - """ - Return the number of rows in the data source that the connector is - connected to. - """ - return len(self.pandas_df) - - @cached_property - def columns_count(self): - """ - Return the number of columns in the data source that the connector is - connected to. - """ - return len(self.pandas_df.columns) - - @property - def column_hash(self): - """ - Return the hash code that is unique to the columns of the data source - that the connector is connected to. - """ - columns_str = "".join(self.pandas_df.columns) - hash_object = hashlib.sha256(columns_str.encode()) - return hash_object.hexdigest() - - @cached_property - def path(self): - """ - Return the path of the data source that the connector is connected to. - """ - pass - - @property - def fallback_name(self): - """ - Return the name of the table that the connector is connected to. - """ - pass - - @property - def type(self): - return "pd.DataFrame" - - def equals(self, other: BaseConnector): - """ - Return whether the data source that the connector is connected to is - equal to the other data source. - """ - return self._original_df.equals(other._original_df) - - def enable_sql_query(self, table_name=None): - if duckdb is None: - raise ImportError( - "DuckDB is not installed. Please install it to use SQL queries." - ) - - if not table_name and not self.name: - raise PandasConnectorTableNotFound("Table name not found!") - - table = table_name or self.name - - # Check if the table already exists in DuckDB - existing_tables = duckdb.query("SHOW TABLES").fetchall() - - # If the table already exists, drop it - if table in [t[0] for t in existing_tables]: - duckdb.query(f"DROP TABLE {table}") - - duckdb_relation = duckdb.from_df(self.pandas_df) - duckdb_relation.create(table) - self.sql_enabled = True - self.name = table - - def execute_direct_sql_query(self, sql_query): - if duckdb is None: - raise ImportError( - "DuckDB is not installed. Please install it to use SQL queries." - ) - - if not self.sql_enabled: - self.enable_sql_query() - - sql_query = sqlglot.transpile(sql_query, read="mysql", write="duckdb")[0] - return duckdb.query(sql_query).df() - - @property - def cs_table_name(self): - return self.name diff --git a/pandasai/dataframe/loader.py b/pandasai/dataframe/loader.py index f19d64e7a..e47ea7580 100644 --- a/pandasai/dataframe/loader.py +++ b/pandasai/dataframe/loader.py @@ -3,6 +3,8 @@ import pandas as pd from datetime import datetime, timedelta import hashlib + +from pandasai.helpers.path import find_project_root from .base import DataFrame import importlib from typing import Any @@ -32,7 +34,10 @@ def load(self, dataset_path: str, lazy=False) -> DataFrame: return DataFrame(df, schema=self.schema) def _load_schema(self): - schema_path = os.path.join("datasets", self.dataset_path, "schema.yaml") + schema_path = os.path.join( + find_project_root(), "datasets", self.dataset_path, "schema.yaml" + ) + print(schema_path) if not os.path.exists(schema_path): raise FileNotFoundError(f"Schema file not found: {schema_path}") diff --git a/pandasai/dataframe/query_builder.py b/pandasai/dataframe/query_builder.py index 1cb072f04..8bc8c1e50 100644 --- a/pandasai/dataframe/query_builder.py +++ b/pandasai/dataframe/query_builder.py @@ -16,7 +16,10 @@ def build_query(self) -> str: return query def _get_columns(self) -> str: - return ", ".join([col["name"] for col in self.schema["columns"]]) + if "columns" in self.schema: + return ", ".join([col["name"] for col in self.schema["columns"]]) + else: + return "*" def _add_order_by(self) -> str: if "order_by" not in self.schema: diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py index dbc10f516..67522a324 100644 --- a/pandasai/helpers/dataframe_serializer.py +++ b/pandasai/helpers/dataframe_serializer.py @@ -55,7 +55,7 @@ def convert_df_to_csv(self, df: pd.DataFrame, extras: dict) -> str: dataframe_info += ">" # Add dataframe details - dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.to_csv()}" + dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.to_csv(index=False)}" # Close the dataframe tag dataframe_info += "\n" @@ -96,7 +96,7 @@ def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict: # Create a dictionary representing the data structure df_info = { "name": df.name, - "description": df.description, + "description": None, "type": ( df.type if "is_direct_sql" in extras and extras["is_direct_sql"] @@ -122,20 +122,21 @@ def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict: col_info["samples"] = df_head[col_name].head().tolist() # Add column description if available - if df.field_descriptions and isinstance(df.field_descriptions, dict): - if col_description := df.field_descriptions.get(col_name, None): - col_info["description"] = col_description - - if df.connector_relations: - for relation in df.connector_relations: - from pandasai.ee.connectors.relations import ForeignKey, PrimaryKey - - if ( - isinstance(relation, PrimaryKey) and relation.name == col_name - ) or ( - isinstance(relation, ForeignKey) and relation.field == col_name - ): - col_info["constraints"] = relation.to_string() + # TODO - Fix or remove this later! + # if df.field_descriptions and isinstance(df.field_descriptions, dict): + # if col_description := df.field_descriptions.get(col_name, None): + # col_info["description"] = col_description + + # if df.connector_relations: + # for relation in df.connector_relations: + # from pandasai.ee.connectors.relations import ForeignKey, PrimaryKey + + # if ( + # isinstance(relation, PrimaryKey) and relation.name == col_name + # ) or ( + # isinstance(relation, ForeignKey) and relation.field == col_name + # ): + # col_info["constraints"] = relation.to_string() data["schema"]["fields"].append(col_info) diff --git a/pandasai/pipelines/chat/code_cleaning.py b/pandasai/pipelines/chat/code_cleaning.py index 355b6b16e..57ad4b757 100644 --- a/pandasai/pipelines/chat/code_cleaning.py +++ b/pandasai/pipelines/chat/code_cleaning.py @@ -1,23 +1,20 @@ +from __future__ import annotations import ast import copy import re import traceback import uuid -from typing import Any, List, Union +from typing import TYPE_CHECKING, Any, List, Union import astor - -from pandasai.connectors.pandas import PandasConnector from pandasai.helpers.optional import get_environment from pandasai.helpers.path import find_project_root from pandasai.helpers.sql import extract_table_names -from ...connectors import BaseConnector from ...constants import WHITELISTED_BUILTINS, WHITELISTED_LIBRARIES from ...exceptions import ( BadImportError, ExecuteSQLQueryNotUsed, - InvalidConfigError, MaliciousQueryError, ) from ...helpers.logger import Logger @@ -27,6 +24,9 @@ from ..logic_unit_output import LogicUnitOutput from ..pipeline_context import PipelineContext +if TYPE_CHECKING: + from pandasai.dataframe.base import DataFrame + class CodeExecutionContext: def __init__( @@ -240,34 +240,39 @@ def check_direct_sql_func_def_exists(self, node: ast.AST): and node.name == "execute_sql_query" ) - def _validate_direct_sql(self, dfs: List[BaseConnector]) -> bool: + def _validate_direct_sql(self, dfs: List[DataFrame]) -> bool: """ Raises error if they don't belong sqlconnector or have different credentials Args: - dfs (List[BaseConnector]): list of BaseConnectors + dfs (List[DataFrame]): list of DataFrames Raises: InvalidConfigError: Raise Error in case of config is set but criteria is not met """ - if self._config.direct_sql: - if all( - ( - hasattr(df, "is_sql_connector") - and df.is_sql_connector - and df.equals(dfs[0]) - ) - for df in dfs - ) or all( - (isinstance(df, PandasConnector) and df.sql_enabled) for df in dfs - ): - return True - else: - raise InvalidConfigError( - "Direct SQL requires all connectors to be SQL connectors and they must belong to the same datasource " - "and have the same credentials" - ) - return False + return self._config.direct_sql + # if self._config.direct_sql: + # return True + # else: + # return + # TODO - while working on direct sql + # if all( + # ( + # hasattr(df, "is_sql_connector") + # and df.is_sql_connector + # and df.equals(dfs[0]) + # ) + # for df in dfs + # ) or all( + # (isinstance(df, PandasConnector) and df.sql_enabled) for df in dfs + # ): + # return True + # else: + # raise InvalidConfigError( + # "Direct SQL requires all connectors to be SQL connectors and they must belong to the same datasource " + # "and have the same credentials" + # ) + # return False def _replace_table_names( self, sql_query: str, table_names: list, allowed_table_names: list diff --git a/pandasai/pipelines/chat/validate_pipeline_input.py b/pandasai/pipelines/chat/validate_pipeline_input.py index b640c489b..2868d62b6 100644 --- a/pandasai/pipelines/chat/validate_pipeline_input.py +++ b/pandasai/pipelines/chat/validate_pipeline_input.py @@ -1,12 +1,14 @@ -from typing import Any, List - +from __future__ import annotations +from typing import TYPE_CHECKING, Any, List from pandasai.exceptions import InvalidConfigError from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from ...connectors import BaseConnector from ..base_logic_unit import BaseLogicUnit from ..pipeline_context import PipelineContext +if TYPE_CHECKING: + from pandasai.dataframe.base import DataFrame + class ValidatePipelineInput(BaseLogicUnit): """ @@ -15,7 +17,7 @@ class ValidatePipelineInput(BaseLogicUnit): pass - def _validate_direct_sql(self, dfs: List[BaseConnector]) -> bool: + def _validate_direct_sql(self, dfs: List[DataFrame]) -> bool: """ Validates that all connectors are SQL connectors and belong to the same datasource when direct_sql is True. diff --git a/pandasai/pipelines/pipeline.py b/pandasai/pipelines/pipeline.py index d843d232e..c23ca76fb 100644 --- a/pandasai/pipelines/pipeline.py +++ b/pandasai/pipelines/pipeline.py @@ -1,17 +1,21 @@ +from __future__ import annotations import logging -from typing import Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union from pandasai.config import load_config_from_json + from pandasai.exceptions import PipelineConcatenationError, UnSupportedLogicUnit from pandasai.helpers.logger import Logger from pandasai.pipelines.base_logic_unit import BaseLogicUnit from pandasai.pipelines.logic_unit_output import LogicUnitOutput from pandasai.pipelines.pipeline_context import PipelineContext -from ..connectors import BaseConnector from ..schemas.df_config import Config from .abstract_pipeline import AbstractPipeline +if TYPE_CHECKING: + from pandasai.dataframe.base import DataFrame + class Pipeline(AbstractPipeline): """ @@ -24,7 +28,7 @@ class Pipeline(AbstractPipeline): def __init__( self, - context: Union[List[BaseConnector], PipelineContext], + context: Union[List[DataFrame], PipelineContext], config: Optional[Union[Config, dict]] = None, steps: Optional[List] = None, logger: Optional[Logger] = None, diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py deleted file mode 100644 index 81474c345..000000000 --- a/pandasai/smart_dataframe/__init__.py +++ /dev/null @@ -1,230 +0,0 @@ -import uuid -from functools import cached_property -from io import StringIO -from typing import Any, List, Optional, Union - -import pandas as pd -from pandasai.agent import Agent -from pandasai.connectors.pandas import PandasConnector - -from ..connectors.base import BaseConnector -from ..helpers.logger import Logger -from ..schemas.df_config import Config - - -class SmartDataframe: - _table_name: str - _table_description: str - _custom_head: str = None - _original_import: any - - def __init__( - self, - df: Union[pd.DataFrame, BaseConnector], - name: str = None, - description: str = None, - custom_head: pd.DataFrame = None, - config: Config = None, - ): - print("\n" + "*" * 80) - print("\033[1;33mDEPRECATION WARNING:\033[0m") - print("SmartDataframe will be deprecated soon. Use df.chat() instead.") - print("*" * 80 + "\n") - - self._original_import = df - - self._agent = Agent([df], config=config) - - self.dataframe = self._agent.context.dfs[0] - - self._table_description = description - self._table_name = name - - if custom_head is not None: - self._custom_head = custom_head.to_csv(index=False) - - def load_dfs(self, df, name: str, description: str, custom_head: pd.DataFrame): - if isinstance(df, (pd.DataFrame, pd.Series, list, dict, str)): - df = PandasConnector( - {"original_df": df}, - name=name, - description=description, - custom_head=custom_head, - ) - else: - raise ValueError("Invalid input data. We cannot convert it to a dataframe.") - return df - - def chat(self, query: str, output_type: Optional[str] = None): - """ - Run a query on the dataframe. - - Args: - query (str): Query to run on the dataframe - output_type (Optional[str]): Add a hint for LLM of which - type should be returned by `analyze_data()` in generated - code. Possible values: "number", "dataframe", "plot", "string": - * number - specifies that user expects to get a number - as a response object - * dataframe - specifies that user expects to get - pandas dataframe as a response object - * plot - specifies that user expects LLM to build - a plot - * string - specifies that user expects to get text - as a response object - - Raises: - ValueError: If the query is empty - """ - return self._agent.chat(query, output_type) - - @cached_property - def head_df(self): - """ - Get the head of the dataframe as a dataframe. - - Returns: - pd.DataFrame: Pandas dataframe - """ - return self.dataframe.get_head() - - @cached_property - def head_csv(self): - """ - Get the head of the dataframe as a CSV string. - - Returns: - str: CSV string - """ - df_head = self.dataframe.get_head() - return df_head.to_csv(index=False) - - @property - def last_prompt(self): - return self._agent.last_prompt - - @property - def last_prompt_id(self) -> uuid.UUID: - return self._agent.last_prompt_id - - @property - def last_code_generated(self): - return self._agent.last_code_executed - - @property - def last_code_executed(self): - return self._agent.last_code_executed - - def original_import(self): - return self._original_import - - @property - def logger(self): - return self._agent.logger - - @logger.setter - def logger(self, logger: Logger): - self._agent.logger = logger - - @property - def logs(self): - return self._agent.context.config.logs - - @property - def verbose(self): - return self._agent.context.config.verbose - - @verbose.setter - def verbose(self, verbose: bool): - self._agent.context.config.verbose = verbose - - @property - def save_logs(self): - return self._agent.context.config.save_logs - - @save_logs.setter - def save_logs(self, save_logs: bool): - self._agent.context.config.save_logs = save_logs - - @property - def enforce_privacy(self): - return self._agent.context.config.enforce_privacy - - @enforce_privacy.setter - def enforce_privacy(self, enforce_privacy: bool): - self._agent.context.config.enforce_privacy = enforce_privacy - - @property - def enable_cache(self): - return self._agent.context.config.enable_cache - - @enable_cache.setter - def enable_cache(self, enable_cache: bool): - self._agent.context.config.enable_cache = enable_cache - - @property - def save_charts(self): - return self._agent.context.config.save_charts - - @save_charts.setter - def save_charts(self, save_charts: bool): - self._agent.context.config.save_charts = save_charts - - @property - def save_charts_path(self): - return self._agent.context.config.save_charts_path - - @save_charts_path.setter - def save_charts_path(self, save_charts_path: str): - self._agent.context.config.save_charts_path = save_charts_path - - @property - def table_name(self): - return self._table_name - - @property - def table_description(self): - return self._table_description - - @property - def custom_head(self): - data = StringIO(self._custom_head) - return pd.read_csv(data) - - def __len__(self): - return len(self.dataframe) - - def __eq__(self, other): - return self.dataframe.equals(other.dataframe) - - def __getattr__(self, name): - if name in self.dataframe.__dir__(): - return getattr(self.dataframe, name) - else: - return self.__getattribute__(name) - - def __getitem__(self, key): - return self.dataframe.__getitem__(key) - - def __setitem__(self, key, value): - return self.dataframe.__setitem__(key, value) - - -def load_smartdataframes( - dfs: List[Union[pd.DataFrame, Any]], config: Config -) -> List[SmartDataframe]: - """ - Load all the dataframes to be used in the smart datalake. - - Args: - dfs (List[Union[pd.DataFrame, Any]]): List of dataframes to be used - """ - - smart_dfs = [] - for df in dfs: - if not isinstance(df, SmartDataframe): - smart_dfs.append(SmartDataframe(df, config=config)) - else: - smart_dfs.append(df) - - return smart_dfs diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py deleted file mode 100644 index 728a43236..000000000 --- a/pandasai/smart_datalake/__init__.py +++ /dev/null @@ -1,182 +0,0 @@ -import uuid -import pandas as pd -from typing import List, Optional, Union - -from pandasai.agent import Agent - -from ..helpers.cache import Cache -from ..schemas.df_config import Config -from ..connectors.base import BaseConnector - - -class SmartDatalake: - def __init__( - self, - dfs: List[Union[pd.DataFrame, BaseConnector]], - config: Optional[Union[Config, dict]] = None, - ): - print("\n" + "*" * 80) - print("\033[1;33mDEPRECATION WARNING:\033[0m") - print("SmartDatalake will be deprecated soon. Use pai.chat() instead.") - print("*" * 80 + "\n") - - self._agent = Agent(dfs, config=config) - - def chat(self, query: str, output_type: Optional[str] = None): - """ - Run a query on the dataframe. - - Args: - query (str): Query to run on the dataframe - output_type (Optional[str]): Add a hint for LLM which - type should be returned by `analyze_data()` in generated - code. Possible values: "number", "dataframe", "plot", "string": - * number - specifies that user expects to get a number - as a response object - * dataframe - specifies that user expects to get - pandas dataframe as a response object - * plot - specifies that user expects LLM to build - a plot - * string - specifies that user expects to get text - as a response object - If none `output_type` is specified, the type can be any - of the above or "text". - - Raises: - ValueError: If the query is empty - """ - return self._agent.chat(query, output_type) - - def clear_memory(self): - """ - Clears the memory - """ - self._agent.clear_memory() - - @property - def last_prompt(self): - return self._agent.last_prompt - - @property - def last_prompt_id(self) -> uuid.UUID: - """Return the id of the last prompt that was run.""" - if self._agent.last_prompt_id is None: - raise ValueError("Pandas AI has not been run yet.") - return self._agent.last_prompt_id - - @property - def logs(self): - return self._agent.logger.logs - - @property - def logger(self): - return self._agent.logger - - @logger.setter - def logger(self, logger): - self._agent.logger = logger - - @property - def config(self): - return self._agent.context.config - - @property - def cache(self): - return self._agent.context.cache - - @property - def verbose(self): - return self._agent.context.config.verbose - - @verbose.setter - def verbose(self, verbose: bool): - self._agent.context.config.verbose = verbose - self._agent.logger.verbose = verbose - - @property - def save_logs(self): - return self._agent.context.config.save_logs - - @save_logs.setter - def save_logs(self, save_logs: bool): - self._agent.context.config.save_logs = save_logs - self._agent.logger.save_logs = save_logs - - @property - def enforce_privacy(self): - return self._agent.context.config.enforce_privacy - - @enforce_privacy.setter - def enforce_privacy(self, enforce_privacy: bool): - self._agent.context.config.enforce_privacy = enforce_privacy - - @property - def enable_cache(self): - return self._agent.context.config.enable_cache - - @enable_cache.setter - def enable_cache(self, enable_cache: bool): - self._agent.context.config.enable_cache = enable_cache - if enable_cache: - if self.cache is None: - self._cache = Cache() - else: - self._cache = None - - @property - def use_error_correction_framework(self): - return self._agent.context.config.use_error_correction_framework - - @use_error_correction_framework.setter - def use_error_correction_framework(self, use_error_correction_framework: bool): - self._agent.context.config.use_error_correction_framework = ( - use_error_correction_framework - ) - - @property - def custom_prompts(self): - return self._agent.context.config.custom_prompts - - @custom_prompts.setter - def custom_prompts(self, custom_prompts: dict): - self._agent.context.config.custom_prompts = custom_prompts - - @property - def save_charts(self): - return self._agent.context.config.save_charts - - @save_charts.setter - def save_charts(self, save_charts: bool): - self._agent.context.config.save_charts = save_charts - - @property - def save_charts_path(self): - return self._agent.context.config.save_charts_path - - @save_charts_path.setter - def save_charts_path(self, save_charts_path: str): - self._agent.context.config.save_charts_path = save_charts_path - - @property - def last_code_generated(self): - return self._agent.last_code_generated - - @property - def last_code_executed(self): - return self._agent.last_code_executed - - @property - def last_result(self): - return self._agent.last_result - - @property - def last_error(self): - return self._agent.last_error - - @property - def dfs(self): - return self._agent.context.dfs - - @property - def memory(self): - return self._agent.context.memory diff --git a/tests/unit_tests/agent/test_base_agent.py b/tests/unit_tests/agent/test_base_agent.py index 239b4dca7..95b0d9c34 100644 --- a/tests/unit_tests/agent/test_base_agent.py +++ b/tests/unit_tests/agent/test_base_agent.py @@ -1,10 +1,9 @@ +from pandasai.dataframe.base import DataFrame from pandasai.llm.fake import FakeLLM import pytest -import pandas as pd from unittest.mock import Mock, patch, MagicMock from pandasai.agent.base import BaseAgent from pandasai.pipelines.chat.chat_pipeline_input import ChatPipelineInput -from pandasai.connectors import PandasConnector class TestBaseAgent: @@ -17,7 +16,7 @@ def mock_bamboo_llm(self): @pytest.fixture def mock_agent(self): # Create a mock DataFrame - mock_df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + mock_df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) fake_llm = FakeLLM() agent = BaseAgent([mock_df], config={"llm": fake_llm}) agent.pipeline = MagicMock() diff --git a/tests/unit_tests/connectors/__init__.py b/tests/unit_tests/connectors/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/unit_tests/connectors/test_base.py b/tests/unit_tests/connectors/test_base.py deleted file mode 100644 index 7f58ac909..000000000 --- a/tests/unit_tests/connectors/test_base.py +++ /dev/null @@ -1,93 +0,0 @@ -import pytest - -from pandasai.connectors import BaseConnector -from pandasai.connectors.base import BaseConnectorConfig -from pandasai.helpers import Logger - - -class MockConfig: - def __init__(self, host, port, database, table): - self.host = host - self.port = port - self.database = database - self.table = table - - -# Mock subclass of BaseConnector for testing -class MockConnector(BaseConnector): - def _load_connector_config(self, config: BaseConnectorConfig): - pass - - def _init_connection(self, config: BaseConnectorConfig): - pass - - def head(self, n: int = 5): - pass - - def execute(self): - pass - - @property - def rows_count(self): - return 100 - - @property - def columns_count(self): - return 5 - - @property - def column_hash(self): - return "some_hash_value" - - @property - def fallback_name(self): - return "fallback_table_name" - - -# Mock Logger class for testing -class MockLogger(Logger): - def __init__(self): - pass - - -# Create a fixture for the configuration -@pytest.fixture -def mock_config(): - return MockConfig("localhost", 5432, "test_db", "test_table") - - -# Create a fixture for the connector with the configuration -@pytest.fixture -def mock_connector(mock_config): - return MockConnector(mock_config) - - -def test_base_connector_initialization(mock_config, mock_connector): - assert mock_connector.config == mock_config - - -def test_base_connector_path_property(mock_connector): - expected_path = "MockConnector://localhost:5432/test_db/test_table" - assert mock_connector.path == expected_path - - -def test_base_connector_logger_property(mock_connector): - logger = MockLogger() - mock_connector.logger = logger - assert mock_connector.logger == logger - - -def test_base_connector_rows_count_property(mock_connector): - assert mock_connector.rows_count == 100 - - -def test_base_connector_columns_count_property(mock_connector): - assert mock_connector.columns_count == 5 - - -def test_base_connector_column_hash_property(mock_connector): - assert mock_connector.column_hash == "some_hash_value" - - -def test_base_connector_fallback_name_property(mock_connector): - assert mock_connector.fallback_name == "fallback_table_name" diff --git a/tests/unit_tests/connectors/test_pandas.py b/tests/unit_tests/connectors/test_pandas.py deleted file mode 100644 index 849f67a7b..000000000 --- a/tests/unit_tests/connectors/test_pandas.py +++ /dev/null @@ -1,75 +0,0 @@ -import pandas as pd -import pytest - -from pandasai.connectors import PandasConnector - - -class TestPandasConnector: - def test_load_dataframe_from_list(self): - input_data = [ - {"column1": 1, "column2": 4}, - {"column1": 2, "column2": 5}, - {"column1": 3, "column2": 6}, - ] - connector = PandasConnector({"original_df": input_data}) - assert isinstance(connector.execute(), pd.DataFrame) - - def test_load_dataframe_from_dict(self): - input_data = {"column1": [1, 2, 3], "column2": [4, 5, 6]} - connector = PandasConnector({"original_df": input_data}) - assert isinstance(connector.execute(), pd.DataFrame) - - def test_load_dataframe_from_pandas_dataframe(self): - input_data = pd.DataFrame({"column1": [1, 2, 3], "column2": [4, 5, 6]}) - connector = PandasConnector({"original_df": input_data}) - assert isinstance(connector.execute(), pd.DataFrame) - - def test_import_pandas_series(self): - input_data = pd.Series([1, 2, 3]) - connector = PandasConnector({"original_df": input_data}) - assert isinstance(connector.execute(), pd.DataFrame) - - def test_to_json(self): - input_data = pd.DataFrame( - { - "EmployeeID": [1, 2, 3, 4, 5], - "Name": ["John", "Emma", "Liam", "Olivia", "William"], - "Department": ["HR", "Sales", "IT", "Marketing", "Finance"], - } - ) - connector = PandasConnector({"original_df": input_data}) - data = connector.to_json() - - assert isinstance(data, dict) - assert "name" in data - assert "description" in data - assert "head" in data - assert isinstance(data["head"], list) - - def test_type_name_property(self): - input_data = [ - {"column1": 1, "column2": 4}, - {"column1": 2, "column2": 5}, - {"column1": 3, "column2": 6}, - ] - connector = PandasConnector({"original_df": input_data}) - assert connector.type == "pd.DataFrame" - - def test_cs_table_name(self): - input_data = [ - {"column1": 1, "column2": 4}, - {"column1": 2, "column2": 5}, - {"column1": 3, "column2": 6}, - ] - connector = PandasConnector({"original_df": input_data}, name="test_name") - assert connector.cs_table_name == "test_name" - - def test_enable_sql_query(self): - input_data = [ - {"column1": 1, "column2": 4}, - {"column1": 2, "column2": 5}, - {"column1": 3, "column2": 6}, - ] - connector = PandasConnector({"original_df": input_data}) - with pytest.raises(Exception): - connector.enable_sql_query() diff --git a/tests/unit_tests/ee/judge_agent/test_judge_agent.py b/tests/unit_tests/ee/judge_agent/test_judge_agent.py deleted file mode 100644 index a59e610fa..000000000 --- a/tests/unit_tests/ee/judge_agent/test_judge_agent.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import Optional -from unittest.mock import MagicMock, patch - -import pandas as pd -import pytest - -from pandasai.agent import Agent -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.judge_agent import JudgeAgent -from pandasai.helpers.dataframe_serializer import DataframeSerializerType -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from tests.unit_tests.ee.helpers.schema import ( - VIZ_QUERY_SCHEMA_STR, -) - - -class MockBambooLLM(BambooLLM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.call = MagicMock(return_value=VIZ_QUERY_SCHEMA_STR) - - -class TestJudgeAgent: - "Unit tests for Agent class" - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "order_id": [ - 10248, - 10249, - 10250, - 10251, - 10252, - 10253, - 10254, - 10255, - 10256, - 10257, - ], - "customer_id": [ - "VINET", - "TOMSP", - "HANAR", - "VICTE", - "SUPRD", - "HANAR", - "CHOPS", - "RICSU", - "WELLI", - "HILAA", - ], - "employee_id": [5, 6, 4, 3, 4, 3, 4, 7, 3, 4], - "order_date": pd.to_datetime( - [ - "1996-07-04", - "1996-07-05", - "1996-07-08", - "1996-07-08", - "1996-07-09", - "1996-07-10", - "1996-07-11", - "1996-07-12", - "1996-07-15", - "1996-07-16", - ] - ), - "required_date": pd.to_datetime( - [ - "1996-08-01", - "1996-08-16", - "1996-08-05", - "1996-08-05", - "1996-08-06", - "1996-08-07", - "1996-08-08", - "1996-08-09", - "1996-08-12", - "1996-08-13", - ] - ), - "shipped_date": pd.to_datetime( - [ - "1996-07-16", - "1996-07-10", - "1996-07-12", - "1996-07-15", - "1996-07-11", - "1996-07-16", - "1996-07-23", - "1996-07-26", - "1996-07-17", - "1996-07-22", - ] - ), - "ship_via": [3, 1, 2, 1, 2, 2, 2, 3, 2, 1], - "ship_name": [ - "Vins et alcools Chevalier", - "Toms Spezialitäten", - "Hanari Carnes", - "Victuailles en stock", - "Suprêmes délices", - "Hanari Carnes", - "Chop-suey Chinese", - "Richter Supermarkt", - "Wellington Importadora", - "HILARION-Abastos", - ], - "ship_address": [ - "59 rue de l'Abbaye", - "Luisenstr. 48", - "Rua do Paço, 67", - "2, rue du Commerce", - "Boulevard Tirou, 255", - "Rua do Paço, 67", - "Hauptstr. 31", - "Starenweg 5", - "Rua do Mercado, 12", - "Carrera 22 con Ave. Carlos Soublette #8-35", - ], - "ship_city": [ - "Reims", - "Münster", - "Rio de Janeiro", - "Lyon", - "Charleroi", - "Rio de Janeiro", - "Bern", - "Genève", - "Resende", - "San Cristóbal", - ], - "ship_region": [ - "CJ", - None, - "RJ", - "RH", - None, - "RJ", - None, - None, - "SP", - "Táchira", - ], - "ship_postal_code": [ - "51100", - "44087", - "05454-876", - "69004", - "B-6000", - "05454-876", - "3012", - "1204", - "08737-363", - "5022", - ], - "ship_country": [ - "France", - "Germany", - "Brazil", - "France", - "Belgium", - "Brazil", - "Switzerland", - "Switzerland", - "Brazil", - "Venezuela", - ], - } - ) - - @pytest.fixture - def llm(self, output: Optional[str] = None) -> FakeLLM: - return FakeLLM(output=output) - - @pytest.fixture - def config(self, llm: FakeLLM) -> dict: - return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV} - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def agent(self) -> Agent: - return JudgeAgent() - - def test_contruct_with_pipeline(self, sample_df): - JudgeAgent(pipeline=MagicMock()) diff --git a/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py b/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py deleted file mode 100644 index 01055354b..000000000 --- a/tests/unit_tests/ee/judge_agent/test_judge_llm_call.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Optional -from unittest.mock import MagicMock, patch - -import pandas as pd -import pytest - -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.judge_agent.pipeline.llm_call import LLMCall -from pandasai.exceptions import InvalidOutputValueMismatch -from pandasai.helpers.logger import Logger -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA_STR - - -class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return VIZ_QUERY_SCHEMA_STR - - -class TestJudgeLLMCall: - "Unit test for Validate Pipeline Input" - - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - @pytest.fixture - def logger(self): - return Logger(True, False) - - def test_init(self, context, config): - # Test the initialization of the CodeGenerator - code_generator = LLMCall() - assert isinstance(code_generator, LLMCall) - - def test_llm_call(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="") - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert isinstance(result, LogicUnitOutput) - assert result.output is True - - def test_llm_call_no(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="") - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert isinstance(result, LogicUnitOutput) - assert result.output is False - - def test_llm_call_(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="") - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert isinstance(result, LogicUnitOutput) - assert result.output is False - - def test_llm_call_with_no_tags(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="yes") - - context = PipelineContext([sample_df], config) - - with pytest.raises(InvalidOutputValueMismatch): - input_validator.execute(input="test", context=context, logger=logger) diff --git a/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py b/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py deleted file mode 100644 index 74cbe78cb..000000000 --- a/tests/unit_tests/ee/judge_agent/test_judge_prompt_gen.py +++ /dev/null @@ -1,178 +0,0 @@ -import re -from typing import Optional -from unittest.mock import patch - -import pandas as pd -import pytest - -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.judge_agent.pipeline.judge_prompt_generation import ( - JudgePromptGeneration, -) -from pandasai.helpers.logger import Logger -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput -from pandasai.pipelines.pipeline_context import PipelineContext -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA, VIZ_QUERY_SCHEMA_STR - - -class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return VIZ_QUERY_SCHEMA_STR - - -class TestJudgePromptGeneration: - "Unit test for Validate Pipeline Input" - - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - @pytest.fixture - def logger(self): - return Logger(True, False) - - def test_init(self, context, config): - # Test the initialization of the CodeGenerator - code_generator = JudgePromptGeneration() - assert isinstance(code_generator, JudgePromptGeneration) - - def test_validate_input_semantic_prompt(self, sample_df, context, logger): - semantic_prompter = JudgePromptGeneration() - - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df], config) - - context.memory.add("hello word!", True) - - context.add("df_schema", VIZ_QUERY_SCHEMA) - - input_data = JudgePipelineInput( - query="What is test?", code="print('Code Data')" - ) - - response = semantic_prompter.execute( - input_data=input_data, context=context, logger=logger - ) - - match = re.search( - r"Today is ([A-Za-z]+, [A-Za-z]+ \d{1,2}, \d{4} \d{2}:\d{2} [APM]{2})", - response.output.to_string(), - ) - datetime_str = match.group(1) - - assert ( - response.output.to_string() - == f"""Today is {datetime_str} -### QUERY -What is test? -### GENERATED CODE -print('Code Data') - -Reason step by step and at the end answer: -1. Explain what the code does -2. Explain what the user query asks for -3. Strictly compare the query with the code that is generated -Always return or if exactly meets the requirements""" - ) diff --git a/tests/unit_tests/ee/security_agent/test_security_agent.py b/tests/unit_tests/ee/security_agent/test_security_agent.py deleted file mode 100644 index 22dffbc00..000000000 --- a/tests/unit_tests/ee/security_agent/test_security_agent.py +++ /dev/null @@ -1,229 +0,0 @@ -from typing import Optional -from unittest.mock import MagicMock, patch - -import pandas as pd -import pytest - -from pandasai.agent import Agent -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.advanced_security_agent import AdvancedSecurityAgent -from pandasai.helpers.dataframe_serializer import DataframeSerializerType -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from tests.unit_tests.ee.helpers.schema import ( - VIZ_QUERY_SCHEMA_STR, -) - - -class MockBambooLLM(BambooLLM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.call = MagicMock(return_value=VIZ_QUERY_SCHEMA_STR) - - -class TestSecurityAgent: - "Unit tests for Agent class" - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "order_id": [ - 10248, - 10249, - 10250, - 10251, - 10252, - 10253, - 10254, - 10255, - 10256, - 10257, - ], - "customer_id": [ - "VINET", - "TOMSP", - "HANAR", - "VICTE", - "SUPRD", - "HANAR", - "CHOPS", - "RICSU", - "WELLI", - "HILAA", - ], - "employee_id": [5, 6, 4, 3, 4, 3, 4, 7, 3, 4], - "order_date": pd.to_datetime( - [ - "1996-07-04", - "1996-07-05", - "1996-07-08", - "1996-07-08", - "1996-07-09", - "1996-07-10", - "1996-07-11", - "1996-07-12", - "1996-07-15", - "1996-07-16", - ] - ), - "required_date": pd.to_datetime( - [ - "1996-08-01", - "1996-08-16", - "1996-08-05", - "1996-08-05", - "1996-08-06", - "1996-08-07", - "1996-08-08", - "1996-08-09", - "1996-08-12", - "1996-08-13", - ] - ), - "shipped_date": pd.to_datetime( - [ - "1996-07-16", - "1996-07-10", - "1996-07-12", - "1996-07-15", - "1996-07-11", - "1996-07-16", - "1996-07-23", - "1996-07-26", - "1996-07-17", - "1996-07-22", - ] - ), - "ship_via": [3, 1, 2, 1, 2, 2, 2, 3, 2, 1], - "ship_name": [ - "Vins et alcools Chevalier", - "Toms Spezialitäten", - "Hanari Carnes", - "Victuailles en stock", - "Suprêmes délices", - "Hanari Carnes", - "Chop-suey Chinese", - "Richter Supermarkt", - "Wellington Importadora", - "HILARION-Abastos", - ], - "ship_address": [ - "59 rue de l'Abbaye", - "Luisenstr. 48", - "Rua do Paço, 67", - "2, rue du Commerce", - "Boulevard Tirou, 255", - "Rua do Paço, 67", - "Hauptstr. 31", - "Starenweg 5", - "Rua do Mercado, 12", - "Carrera 22 con Ave. Carlos Soublette #8-35", - ], - "ship_city": [ - "Reims", - "Münster", - "Rio de Janeiro", - "Lyon", - "Charleroi", - "Rio de Janeiro", - "Bern", - "Genève", - "Resende", - "San Cristóbal", - ], - "ship_region": [ - "CJ", - None, - "RJ", - "RH", - None, - "RJ", - None, - None, - "SP", - "Táchira", - ], - "ship_postal_code": [ - "51100", - "44087", - "05454-876", - "69004", - "B-6000", - "05454-876", - "3012", - "1204", - "08737-363", - "5022", - ], - "ship_country": [ - "France", - "Germany", - "Brazil", - "France", - "Belgium", - "Brazil", - "Switzerland", - "Switzerland", - "Brazil", - "Venezuela", - ], - } - ) - - @pytest.fixture - def llm(self, output: Optional[str] = None) -> FakeLLM: - return FakeLLM(output=output) - - @pytest.fixture - def config(self, llm: FakeLLM) -> dict: - return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV} - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def agent(self) -> Agent: - return AdvancedSecurityAgent() - - def test_contruct_with_pipeline(self, sample_df): - AdvancedSecurityAgent(pipeline=MagicMock()) diff --git a/tests/unit_tests/ee/security_agent/test_security_llm_call.py b/tests/unit_tests/ee/security_agent/test_security_llm_call.py deleted file mode 100644 index ddb28ed3f..000000000 --- a/tests/unit_tests/ee/security_agent/test_security_llm_call.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Optional -from unittest.mock import MagicMock, patch - -import pandas as pd -import pytest - -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.advanced_security_agent.pipeline.llm_call import LLMCall -from pandasai.exceptions import InvalidOutputValueMismatch -from pandasai.helpers.logger import Logger -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA_STR - - -class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return VIZ_QUERY_SCHEMA_STR - - -class TestSecurityLLMCall: - "Unit test for Validate Pipeline Input" - - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - @pytest.fixture - def logger(self): - return Logger(True, False) - - def test_init(self, context, config): - # Test the initialization of the CodeGenerator - code_generator = LLMCall() - assert isinstance(code_generator, LLMCall) - - def test_llm_call(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="") - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert isinstance(result, LogicUnitOutput) - assert result.output is True - - def test_llm_call_no(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="") - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert isinstance(result, LogicUnitOutput) - assert result.output is False - - def test_llm_call_(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="") - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert isinstance(result, LogicUnitOutput) - assert result.output is False - - def test_llm_call_with_no_tags(self, sample_df, context, logger, config): - input_validator = LLMCall() - - config["llm"].call = MagicMock(return_value="yes") - - context = PipelineContext([sample_df], config) - - with pytest.raises(InvalidOutputValueMismatch): - input_validator.execute(input="test", context=context, logger=logger) diff --git a/tests/unit_tests/ee/security_agent/test_security_prompt_gen.py b/tests/unit_tests/ee/security_agent/test_security_prompt_gen.py deleted file mode 100644 index 10d25f651..000000000 --- a/tests/unit_tests/ee/security_agent/test_security_prompt_gen.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import Optional -from unittest.mock import patch - -import pandas as pd -import pytest - -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.advanced_security_agent.pipeline.advanced_security_prompt_generation import ( - AdvancedSecurityPromptGeneration, -) -from pandasai.helpers.logger import Logger -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.judge.judge_pipeline_input import JudgePipelineInput -from pandasai.pipelines.pipeline_context import PipelineContext -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA, VIZ_QUERY_SCHEMA_STR - - -class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return VIZ_QUERY_SCHEMA_STR - - -class TestSecurityPromptGeneration: - "Unit test for Validate Pipeline Input" - - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - @pytest.fixture - def logger(self): - return Logger(True, False) - - def test_init(self, context, config): - # Test the initialization of the CodeGenerator - code_generator = AdvancedSecurityPromptGeneration() - assert isinstance(code_generator, AdvancedSecurityPromptGeneration) - - def test_validate_input_semantic_prompt(self, sample_df, context, logger): - semantic_prompter = AdvancedSecurityPromptGeneration() - - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df], config) - - context.memory.add("hello word!", True) - - context.add("df_schema", VIZ_QUERY_SCHEMA) - - input_data = JudgePipelineInput( - query="What is test?", code="print('Code Data')" - ) - - response = semantic_prompter.execute(input_data, context=context, logger=logger) - - print(response.output.to_string()) - assert ( - response.output.to_string() - == """As an Security Agent, it's crucial to ensure that user queries do not generate malicious code that could harm systems or data. Analyze each statement and word thoroughly to check whether it can generate malicious code or not. -When analyzing a user query, follow these guidelines to identify potentially harmful code patterns: - -Code Injection: Look for attempts to inject code into a system, especially commands that interact with the file system, execute shell commands, or access sensitive data. User can never ask to append or execute any particular code. -File Operations: Be wary of commands that read from or write to the file system, especially when involving user-provided paths. Ensure that the code never updates any file. -Network Operations: Identify code that makes network requests or opens network connections. Verify that such operations are safe and necessary. -Data Manipulation: Ensure that code handling data manipulation does not include operations that could lead to data leaks, corruption, or unauthorized access. -Execution Control: Detect attempts to execute arbitrary code or scripts, particularly those that could alter system behavior or gain elevated privileges. -Third-Party Libraries: Verify the safety of using third-party libraries and ensure they are from reputable sources and up to date. -SQL Commands: Be cautious of SQL commands that can update or manipulate a database, such as INSERT, UPDATE, DELETE, DROP, ALTER, and TRUNCATE. Any query involving these commands should be flagged as potentially harmful. - -Given a user query, identify any suspicious or potentially harmful code patterns following the guidelines above. - -Your Task: -Analyze and reason the following user query strictly for potential malicious code can be generated patterns based on the guidelines provided. - -User Query: -JudgePipelineInput(query='What is test?', code="print('Code Data')") - -Always return or in tags <>, and provide a brief explanation if .""" - ) diff --git a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py b/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py deleted file mode 100644 index fb1f261b6..000000000 --- a/tests/unit_tests/ee/semantic_agent/test__semantic_code_generator.py +++ /dev/null @@ -1,510 +0,0 @@ -from typing import Optional - -import pandas as pd -import pytest - -from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator -from pandasai.helpers.logger import Logger -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext -from pandasai.schemas.df_config import Config -from tests.unit_tests.ee.helpers.schema import STARS_SCHEMA, VIZ_QUERY_SCHEMA - - -class TestSemanticCodeGenerator: - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - def logger(self): - return Logger() - - @pytest.fixture - def config_with_direct_sql(self): - return Config( - llm=FakeLLM(output=""), - enable_cache=False, - direct_sql=True, - ) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - def test_generate_matplolib_par_code( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", VIZ_QUERY_SCHEMA) - json_str = { - "type": "bar", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.order_count"], - "timeDimensions": [], - "options": { - "xLabel": "Country", - "yLabel": "Number of Orders", - "title": "Orders Count by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [], - "order": [{"id": "Orders.order_count", "direction": "asc"}], - } - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT `orders`.`ship_country` AS ship_country, COUNT(`orders`.`order_count`) AS order_count FROM `orders` GROUP BY ship_country ORDER BY order_count asc" -data = execute_sql_query(sql_query) - -plt.bar(data["ship_country"], data["order_count"], label="order_count") -plt.xlabel('''Country''') -plt.ylabel('''Number of Orders''') -plt.title('''Orders Count by Country''') -plt.legend(loc='best') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_matplolib_pie_chart_code( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", VIZ_QUERY_SCHEMA) - json_str = { - "type": "pie", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.order_count"], - "timeDimensions": [], - "options": { - "title": "Orders Count by Country", - "legend": {"display": True, "position": "top"}, - }, - "filters": [], - } - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT `orders`.`ship_country` AS ship_country, COUNT(`orders`.`order_count`) AS order_count FROM `orders` GROUP BY ship_country" -data = execute_sql_query(sql_query) - -plt.pie(data["order_count"], labels=data["ship_country"], autopct='%1.1f%%') -plt.title('''Orders Count by Country''') -plt.legend(loc='best') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_matplolib_line_chart_code( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", VIZ_QUERY_SCHEMA) - json_str = { - "type": "line", - "dimensions": ["Orders.order_date"], - "measures": ["Orders.order_count"], - "timeDimensions": [], - "options": { - "xLabel": "Order Date", - "yLabel": "Number of Orders", - "title": "Orders Over Time", - "legend": {"display": True, "position": "top"}, - }, - "filters": [], - "order": [], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT `orders`.`order_date` AS order_date, COUNT(`orders`.`order_count`) AS order_count FROM `orders` GROUP BY order_date" -data = execute_sql_query(sql_query) - -plt.plot(data["order_date"], data["order_count"]) -plt.xlabel('''Order Date''') -plt.ylabel('''Number of Orders''') -plt.title('''Orders Over Time''') -plt.legend(loc='best') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_matplolib_scatter_chart_code( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", VIZ_QUERY_SCHEMA) - json_str = { - "type": "scatter", - "dimensions": ["Orders.order_date", "Orders.ship_via"], - "measures": [], - "timeDimensions": [], - "options": {"title": "Total Freight by Order Date"}, - "filters": [], - "order": [], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT `orders`.`order_date` AS order_date, `orders`.`ship_via` AS ship_via FROM `orders` GROUP BY order_date, ship_via" -data = execute_sql_query(sql_query) - -plt.scatter(data['order_date'], data['ship_via']) -plt.title('''Total Freight by Order Date''') -plt.legend(loc='best') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_matplolib_histogram_chart_code( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", VIZ_QUERY_SCHEMA) - json_str = { - "type": "histogram", - "dimensions": [], - "measures": ["Orders.total_freight"], - "timeDimensions": [], - "options": { - "xLabel": "Total Freight", - "yLabel": "Frequency", - "title": "Distribution of Total Freight", - "legend": {"display": False}, - "bins": 30, - }, - "filters": [], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT SUM(`orders`.`freight`) AS total_freight FROM `orders`" -data = execute_sql_query(sql_query) - -plt.hist(data['total_freight']) -plt.xlabel('''Total Freight''') -plt.ylabel('''Frequency''') -plt.title('''Distribution of Total Freight''') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_matplolib_boxplot_chart_code( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", VIZ_QUERY_SCHEMA) - json_str = { - "type": "boxplot", - "dimensions": ["Orders.ship_country"], - "measures": ["Orders.total_freight"], - "timeDimensions": [], - "options": { - "xLabel": "Shipping Country", - "yLabel": "Total Freight", - "title": "Distribution of Total Freight by Shipping Country", - "legend": {"display": False}, - }, - "filters": [], - "order": [], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT `orders`.`ship_country` AS ship_country, SUM(`orders`.`freight`) AS total_freight FROM `orders` GROUP BY ship_country" -data = execute_sql_query(sql_query) - -plt.boxplot(data['total_freight']) -plt.xlabel('''Shipping Country''') -plt.ylabel('''Total Freight''') -plt.title('''Distribution of Total Freight by Shipping Country''') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_matplolib_number_type( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", VIZ_QUERY_SCHEMA) - json_str = { - "type": "number", - "measures": ["Orders.order_count"], - "timeDimensions": [], - "options": {"title": "Total Orders Count"}, - "filters": [], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - print(logic_unit.output) - assert ( - logic_unit.output - == """ - -import pandas as pd - -sql_query="SELECT COUNT(`orders`.`order_count`) AS order_count FROM `orders`" -data = execute_sql_query(sql_query) - - -total_value = data["order_count"].sum() - -result = {"type": "number","value": total_value} - -""" - ) - - def test_generate_timedimension_query( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", STARS_SCHEMA) - json_str = { - "type": "line", - "measures": ["Users.user_count"], - "timeDimensions": [ - { - "dimension": "Users.starred_at", - "dateRange": ["2022-01-01", "2023-03-31"], - "granularity": "month", - } - ], - "options": { - "xLabel": "Month", - "yLabel": "Number of Stars", - "title": "Stars Count per Month", - "legend": {"display": True, "position": "bottom"}, - }, - "filters": [], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - print(logic_unit.output) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2022-01-01' AND '2023-03-31' GROUP BY starred_at_by_month" -data = execute_sql_query(sql_query) - -plt.plot(data["starred_at_by_month"], data["user_count"]) -plt.xlabel('''Month''') -plt.ylabel('''Number of Stars''') -plt.title('''Stars Count per Month''') -plt.legend(loc='best') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_timedimension_for_year( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", STARS_SCHEMA) - json_str = { - "type": "line", - "measures": ["Users.user_count"], - "timeDimensions": [ - { - "dimension": "Users.starred_at", - "dateRange": ["this year"], - "granularity": "month", - } - ], - "options": { - "xLabel": "Time Period", - "yLabel": "Stars Count", - "title": "Stars Count Per Month This Year", - "legend": {"display": True, "position": "bottom"}, - }, - "filters": [], - "order": [{"id": "Users.starred_at", "direction": "asc"}], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - print(logic_unit.output) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` >= DATE_TRUNC('year', CURRENT_DATE) AND `users`.`starredAt` < DATE_TRUNC('year', CURRENT_DATE) + INTERVAL '1 year' GROUP BY starred_at_by_month ORDER BY starred_at_by_month asc" -data = execute_sql_query(sql_query) - -plt.plot(data["starred_at_by_month"], data["user_count"]) -plt.xlabel('''Time Period''') -plt.ylabel('''Stars Count''') -plt.title('''Stars Count Per Month This Year''') -plt.legend(loc='best') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) - - def test_generate_timedimension_histogram_for_year( - self, context: PipelineContext, logger: Logger - ): - code_gen = CodeGenerator() - context.add("df_schema", STARS_SCHEMA) - json_str = { - "type": "histogram", - "dimensions": ["Users.starred_at"], - "measures": ["Users.user_count"], - "timeDimensions": [ - { - "dimension": "Users.starred_at", - "dateRange": ["2023-01-01", "2023-12-31"], - "granularity": "month", - } - ], - "options": { - "xLabel": "Starred Month", - "yLabel": "Number of Users", - "title": "Distribution of Stars per Month in 2023", - "legend": {"display": False}, - }, - "filters": [], - "order": [{"id": "Users.starred_at", "direction": "asc"}], - } - - logic_unit = code_gen.execute(json_str, context=context, logger=logger) - assert isinstance(logic_unit, LogicUnitOutput) - assert ( - logic_unit.output - == """ -import matplotlib.pyplot as plt -import pandas as pd - -sql_query="SELECT `users`.`starredAt` AS starred_at, COUNT(`users`.`login`) AS user_count, DATE_FORMAT(`users`.`starredAt`, '%Y-%m') AS starred_at_by_month FROM `users` WHERE `users`.`starredAt` BETWEEN '2023-01-01' AND '2023-12-31' GROUP BY starred_at, starred_at_by_month ORDER BY starred_at_by_month asc" -data = execute_sql_query(sql_query) - -plt.hist(data['user_count']) -plt.xlabel('''Starred Month''') -plt.ylabel('''Number of Users''') -plt.title('''Distribution of Stars per Month in 2023''') - - -plt.savefig("charts.png") - -result = {"type": "plot","value": "charts.png"} -""" - ) diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_agent.py b/tests/unit_tests/ee/semantic_agent/test_semantic_agent.py deleted file mode 100644 index c30868949..000000000 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_agent.py +++ /dev/null @@ -1,162 +0,0 @@ -from unittest.mock import MagicMock, patch, PropertyMock - -import pandasai as pai -import pandas as pd -import pytest -import os - -from pandasai.agent import Agent -from pandasai.agent.base import BaseAgent -from pandasai.ee.agents.semantic_agent import SemanticAgent -from pandasai.exceptions import InvalidTrainJson -from pandasai.llm.fake import FakeLLM -from tests.unit_tests.ee.helpers.schema import ( - VIZ_QUERY_SCHEMA, - VIZ_QUERY_SCHEMA_STR, -) -from pandasai.dataframe.base import DataFrame - - -class TestSemanticAgent: - "Unit tests for Agent class" - - @pytest.fixture - def sample_df(self): - df = pai.DataFrame( - { - "order_id": [10248, 10249, 10250], - "customer_id": ["VINET", "TOMSP", "HANAR"], - "employee_id": [5, 6, 4], - "order_date": pd.to_datetime( - ["1996-07-04", "1996-07-05", "1996-07-08"] - ), - "required_date": pd.to_datetime( - ["1996-08-01", "1996-08-16", "1996-08-05"] - ), - "shipped_date": pd.to_datetime( - ["1996-07-16", "1996-07-10", "1996-07-12"] - ), - "ship_via": [3, 1, 2], - "freight": [32.38, 11.61, 65.83], - "ship_name": [ - "Vins et alcools Chevalier", - "Toms Spezialitäten", - "Hanari Carnes", - ], - "ship_address": [ - "59 rue de l'Abbaye", - "Luisenstr. 48", - "Rua do Paço, 67", - ], - "ship_city": ["Reims", "Münster", "Rio de Janeiro"], - "ship_region": ["CJ", None, "RJ"], - "ship_postal_code": ["51100", "44087", "05454-876"], - "ship_country": ["France", "Germany", "Brazil"], - } - ) - return DataFrame(df) - - @pytest.fixture - def llm(self) -> FakeLLM: - return FakeLLM(output=VIZ_QUERY_SCHEMA_STR) - - @pytest.fixture - def agent(self, sample_df: DataFrame, llm: FakeLLM) -> Agent: - with patch.dict(os.environ, {"PANDASAI_API_KEY": "test_key"}), patch( - "pandasai.ee.agents.semantic_agent.SemanticAgent._create_schema" - ) as mock_create_schema: - mock_create_schema.return_value = None - return SemanticAgent( - sample_df, config={"llm": llm}, vectorstore=MagicMock() - ) - - def test_base_agent_construct(self, sample_df, llm): - BaseAgent(sample_df, {"llm": llm}, vectorstore=MagicMock()) - - def test_base_agent_log_id_register_agent(self, sample_df, llm): - with patch.dict(os.environ, {"PANDASAI_API_KEY": "test_key"}), patch( - "pandasai.ee.agents.semantic_agent.SemanticAgent._create_schema" - ) as mock_create_schema, patch("uuid.uuid4") as mock_uuid: - mock_create_schema.return_value = None - mock_uuid.return_value = "test-uuid" - agent = SemanticAgent( - sample_df, {"llm": llm, "enable_cache": False}, vectorstore=MagicMock() - ) - agent.context.config.__dict__["log_id"] = "test-uuid" - assert agent.context.config.__dict__["log_id"] == "test-uuid" - - def test_constructor_with_no_bamboo(self, sample_df): - non_bamboo_llm = FakeLLM(output=VIZ_QUERY_SCHEMA_STR, type="fake") - with pytest.raises(Exception): - SemanticAgent( - sample_df, - {"llm": non_bamboo_llm, "enable_cache": False}, - vectorstore=MagicMock(), - ) - - def test_constructor(self, sample_df, llm): - with patch.dict(os.environ, {"PANDASAI_API_KEY": "test_key"}), patch( - "pandasai.ee.agents.semantic_agent.SemanticAgent._create_schema" - ) as mock_create_schema: - mock_create_schema.return_value = None - agent = SemanticAgent( - sample_df, {"llm": llm, "enable_cache": False}, vectorstore=MagicMock() - ) - assert agent.context.config.llm == llm - - def test_last_error(self, sample_df, llm): - with patch.dict(os.environ, {"PANDASAI_API_KEY": "test_key"}), patch( - "pandasai.ee.agents.semantic_agent.SemanticAgent._create_schema" - ) as mock_create_schema, patch.object( - BaseAgent, "last_error", new_callable=PropertyMock - ) as mock_last_error: - mock_create_schema.return_value = None - mock_last_error.return_value = None - agent = SemanticAgent(sample_df, {"llm": llm}, vectorstore=MagicMock()) - assert agent.last_error is None - - @patch("pandasai.helpers.cache.Cache.get") - def test_cache_of_schema(self, mock_cache_get, sample_df, llm): - mock_cache_get.return_value = VIZ_QUERY_SCHEMA_STR - - agent = SemanticAgent(sample_df, {"llm": llm}, vectorstore=MagicMock()) - - assert not llm.called - assert agent._schema == VIZ_QUERY_SCHEMA - - def test_train_method_with_qa(self, agent): - queries = ["query1"] - jsons = ['{"name": "test"}'] - agent.train(queries=queries, jsons=jsons) - - agent._vectorstore.add_docs.assert_not_called() - agent._vectorstore.add_question_answer.assert_called_once_with(queries, jsons) - - def test_train_method_with_docs(self, agent): - docs = ["doc1"] - agent.train(docs=docs) - - agent._vectorstore.add_question_answer.assert_not_called() - agent._vectorstore.add_docs.assert_called_once() - agent._vectorstore.add_docs.assert_called_once_with(docs) - - def test_train_method_with_docs_and_qa(self, agent): - docs = ["doc1"] - queries = ["query1"] - jsons = ['{"name": "test"}'] - agent.train(queries, jsons, docs=docs) - - agent._vectorstore.add_question_answer.assert_called_once() - agent._vectorstore.add_question_answer.assert_called_once_with(queries, jsons) - agent._vectorstore.add_docs.assert_called_once() - agent._vectorstore.add_docs.assert_called_once_with(docs) - - def test_train_method_with_queries_but_no_code(self, agent): - queries = ["query1", "query2"] - with pytest.raises(ValueError): - agent.train(queries) - - def test_train_method_with_code_but_no_queries(self, agent): - jsons = ["code1", "code2"] - with pytest.raises(InvalidTrainJson): - agent.train(jsons=jsons) diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py b/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py deleted file mode 100644 index 3707f61c4..000000000 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_llm_call.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import Optional -from unittest.mock import patch - -import pandas as pd -import pytest - -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall -from pandasai.helpers.logger import Logger -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.pipeline_context import PipelineContext -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA_STR - - -class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return VIZ_QUERY_SCHEMA_STR - - -class TestSemanticLLMCall: - "Unit test for Validate Pipeline Input" - - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - @pytest.fixture - def logger(self): - return Logger(True, False) - - def test_init(self, context, config): - # Test the initialization of the CodeGenerator - code_generator = LLMCall() - assert isinstance(code_generator, LLMCall) - - def test_validate_input_llm_call(self, sample_df, context, logger): - input_validator = LLMCall() - - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df], config) - - input_validator.execute(input="test", context=context, logger=logger) - - def test_validate_input_with_direct_sql_false_and_non_connector( - self, sample_df, logger - ): - input_validator = LLMCall() - - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert result.output == [ - { - "name": "Orders", - "table": "orders", - "measures": [ - {"name": "order_count", "type": "count"}, - {"name": "total_freight", "type": "sum", "sql": "freight"}, - ], - "dimensions": [ - {"name": "order_id", "type": "int", "sql": "order_id"}, - {"name": "customer_id", "type": "string", "sql": "customer_id"}, - {"name": "employee_id", "type": "int", "sql": "employee_id"}, - {"name": "order_date", "type": "date", "sql": "order_date"}, - {"name": "required_date", "type": "date", "sql": "required_date"}, - {"name": "shipped_date", "type": "date", "sql": "shipped_date"}, - {"name": "ship_via", "type": "int", "sql": "ship_via"}, - {"name": "ship_name", "type": "string", "sql": "ship_name"}, - {"name": "ship_address", "type": "string", "sql": "ship_address"}, - {"name": "ship_city", "type": "string", "sql": "ship_city"}, - {"name": "ship_region", "type": "string", "sql": "ship_region"}, - { - "name": "ship_postal_code", - "type": "string", - "sql": "ship_postal_code", - }, - {"name": "ship_country", "type": "string", "sql": "ship_country"}, - ], - "joins": [], - } - ] - - def test_validate_input_llm_call_raise_exception(self, sample_df, context, logger): - input_validator = LLMCall() - - class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return "Hello World!" - - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df], config) - - with pytest.raises(Exception): - input_validator.execute(input="test", context=context, logger=logger) diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py b/tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py deleted file mode 100644 index 594141834..000000000 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_semantic_prompt_gen.py +++ /dev/null @@ -1,163 +0,0 @@ -from typing import Optional -from unittest.mock import patch - -import pandas as pd -import pytest - -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import ( - SemanticPromptGeneration, -) -from pandasai.helpers.logger import Logger -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.pipeline_context import PipelineContext -from tests.unit_tests.ee.helpers.schema import VIZ_QUERY_SCHEMA, VIZ_QUERY_SCHEMA_STR - - -class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return VIZ_QUERY_SCHEMA_STR - - -class TestSemanticPromptGeneration: - "Unit test for Validate Pipeline Input" - - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - @pytest.fixture - def logger(self): - return Logger(True, False) - - def test_init(self, context, config): - # Test the initialization of the CodeGenerator - code_generator = SemanticPromptGeneration() - assert isinstance(code_generator, SemanticPromptGeneration) - - def test_validate_input_semantic_prompt(self, sample_df, context, logger): - semantic_prompter = SemanticPromptGeneration() - - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df], config) - - context.memory.add("hello word!", True) - - context.add("df_schema", VIZ_QUERY_SCHEMA) - - response = semantic_prompter.execute( - input="test", context=context, logger=logger - ) - - assert ( - response.output.to_string() - == """=== SemanticAgent === - - -# SCHEMA -[{"name": "Orders", "table": "orders", "measures": [{"name": "order_count", "type": "count"}, {"name": "total_freight", "type": "sum", "sql": "freight"}], "dimensions": [{"name": "order_id", "type": "int", "sql": "order_id"}, {"name": "customer_id", "type": "string", "sql": "customer_id"}, {"name": "employee_id", "type": "int", "sql": "employee_id"}, {"name": "order_date", "type": "date", "sql": "order_date"}, {"name": "required_date", "type": "date", "sql": "required_date"}, {"name": "shipped_date", "type": "date", "sql": "shipped_date"}, {"name": "ship_via", "type": "int", "sql": "ship_via"}, {"name": "ship_name", "type": "string", "sql": "ship_name"}, {"name": "ship_address", "type": "string", "sql": "ship_address"}, {"name": "ship_city", "type": "string", "sql": "ship_city"}, {"name": "ship_region", "type": "string", "sql": "ship_region"}, {"name": "ship_postal_code", "type": "string", "sql": "ship_postal_code"}, {"name": "ship_country", "type": "string", "sql": "ship_country"}], "joins": []}] - -### QUERY - hello word!""" - ) diff --git a/tests/unit_tests/ee/semantic_agent/test_semantic_validate_pipeline_input.py b/tests/unit_tests/ee/semantic_agent/test_semantic_validate_pipeline_input.py deleted file mode 100644 index 9206e3a17..000000000 --- a/tests/unit_tests/ee/semantic_agent/test_semantic_validate_pipeline_input.py +++ /dev/null @@ -1,221 +0,0 @@ -from typing import Optional -from unittest.mock import patch - -import pandas as pd -import pytest - -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) -from pandasai.ee.agents.semantic_agent.pipeline.validate_pipeline_input import ( - ValidatePipelineInput, -) -from pandasai.exceptions import InvalidConfigError -from pandasai.helpers.logger import Logger -from pandasai.llm.bamboo_llm import BambooLLM -from pandasai.llm.fake import FakeLLM -from pandasai.pipelines.logic_unit_output import LogicUnitOutput -from pandasai.pipelines.pipeline_context import PipelineContext - - -class MockBambooLLM(BambooLLM): - def __init__(self): - pass - - def call(self, *args, **kwargs): - return "Mock llm" - - -class TestSemanticValidatePipelineInput: - "Unit test for Validate Pipeline Input" - - @pytest.fixture - def llm(self, output: Optional[str] = None): - return FakeLLM(output=output) - - @pytest.fixture - def sample_df(self): - return pd.DataFrame( - { - "country": [ - "United States", - "United Kingdom", - "France", - "Germany", - "Italy", - "Spain", - "Canada", - "Australia", - "Japan", - "China", - ], - "gdp": [ - 19294482071552, - 2891615567872, - 2411255037952, - 3435817336832, - 1745433788416, - 1181205135360, - 1607402389504, - 1490967855104, - 4380756541440, - 14631844184064, - ], - "happiness_index": [ - 6.94, - 7.16, - 6.66, - 7.07, - 6.38, - 6.4, - 7.23, - 7.22, - 5.87, - 5.12, - ], - } - ) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) - - @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) - - @pytest.fixture - def config(self, llm): - return {"llm": llm, "enable_cache": True} - - @pytest.fixture - def context(self, sample_df, config): - return PipelineContext([sample_df], config) - - @pytest.fixture - def logger(self): - return Logger(True, False) - - def test_init(self, context, config): - # Test the initialization of the CodeGenerator - code_generator = ValidatePipelineInput() - assert isinstance(code_generator, ValidatePipelineInput) - - def test_validate_input_without_bamboo_llm(self, context, logger): - input_validator = ValidatePipelineInput() - - with pytest.raises(InvalidConfigError): - input_validator.execute(input="test", context=context, logger=logger) - - def test_validate_input_with_direct_sql_false_and_non_connector( - self, sample_df, logger - ): - input_validator = ValidatePipelineInput() - - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df], config) - - result = input_validator.execute(input="test", context=context, logger=logger) - - assert result.output == "test" - - def test_validate_input_with_direct_sql_true_and_non_connector( - self, sample_df, llm, logger - ): - input_validator = ValidatePipelineInput() - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": True} - - context = PipelineContext([sample_df], config) - with pytest.raises(InvalidConfigError): - input_validator.execute(input="test", context=context, logger=logger) - - def test_validate_input_with_direct_sql_false_and_connector( - self, sample_df, llm, logger, sql_connector - ): - input_validator = ValidatePipelineInput() - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": False} - - context = PipelineContext([sample_df, sql_connector], config) - result = input_validator.execute(input="test", context=context, logger=logger) - assert isinstance(result, LogicUnitOutput) - assert result.output == "test" - - def test_validate_input_with_direct_sql_true_and_connector( - self, sample_df, llm, logger, sql_connector - ): - input_validator = ValidatePipelineInput() - llm = MockBambooLLM() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": True} - - context = PipelineContext([sql_connector], config) - result = input_validator.execute(input="test", context=context, logger=logger) - assert isinstance(result, LogicUnitOutput) - assert result.output == "test" - - def test_validate_input_with_direct_sql_true_and_connector_pandasdf( - self, sample_df, llm, logger, sql_connector - ): - input_validator = ValidatePipelineInput() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": True} - - context = PipelineContext([sample_df, sql_connector], config) - with pytest.raises(InvalidConfigError): - input_validator.execute(input="test", context=context, logger=logger) - - def test_validate_input_with_direct_sql_true_and_different_type_connector( - self, pgsql_connector, llm, logger, sql_connector - ): - input_validator = ValidatePipelineInput() - - # context for true config - config = {"llm": llm, "enable_cache": True, "direct_sql": True} - - context = PipelineContext([pgsql_connector, sql_connector], config) - with pytest.raises(InvalidConfigError): - input_validator.execute(input="test", context=context, logger=logger) diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py index 3cf64b7df..5407b5a89 100644 --- a/tests/unit_tests/helpers/test_dataframe_serializer.py +++ b/tests/unit_tests/helpers/test_dataframe_serializer.py @@ -1,8 +1,6 @@ import unittest -import pandas as pd - -from pandasai.connectors import PandasConnector +from pandasai.dataframe.base import DataFrame from pandasai.helpers.dataframe_serializer import ( DataframeSerializer, DataframeSerializerType, @@ -16,15 +14,27 @@ def setUp(self): def test_convert_df_to_yml(self): # Test convert df to yml data = {"name": ["en_name", "中文_名称"]} - connector = PandasConnector( - {"original_df": pd.DataFrame(data)}, - name="en_table_name", - description="中文_描述", - field_descriptions={k: k for k in data}, - ) + connector = DataFrame(data, name="en_table_name", description="中文_描述") result = self.serializer.serialize( connector, type_=DataframeSerializerType.YML, extras={"index": 0, "type": "pd.Dataframe"}, ) - self.assertIn("中文_描述", result) + print(result) + self.assertIn( + """dfs[0]: + name: en_table_name + description: null + type: pd.Dataframe + rows: 2 + columns: 1 + schema: + fields: + - name: name + type: object + samples: + - en_name + - 中文_名称 +""", + result, + ) diff --git a/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py b/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py index 8302f13d4..dfebfd2dc 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py @@ -9,12 +9,7 @@ import pytest from pandasai import Agent -from pandasai.connectors.pandas import PandasConnector -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) +from pandasai.dataframe.base import DataFrame from pandasai.exceptions import ( BadImportError, InvalidConfigError, @@ -114,7 +109,7 @@ def agent(self, llm, sample_df): return Agent([sample_df], config={"llm": llm, "enable_cache": False}) @pytest.fixture - def agent_with_connector(self, llm, pgsql_connector: PostgreSQLConnector): + def agent_with_connector(self, llm, pgsql_connector: DataFrame): return Agent( [pgsql_connector], config={"llm": llm, "enable_cache": False, "direct_sql": True}, @@ -131,40 +126,12 @@ def exec_context(self) -> MagicMock: @pytest.fixture @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) + return DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) @pytest.fixture @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config, name="your_table") + return DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) def test_run_code_for_calculations( self, @@ -513,7 +480,7 @@ def test_check_is_query_using_relevant_table_multiple_tables_one_unknown( def test_clean_code_using_correct_sql_table( self, - pgsql_connector: PostgreSQLConnector, + pgsql_connector: DataFrame, context: PipelineContext, logger: Logger, ): @@ -536,7 +503,7 @@ def test_clean_code_using_correct_sql_table( def test_clean_code_with_no_execute_sql_query_usage_script( self, - pgsql_connector: PostgreSQLConnector, + pgsql_connector: DataFrame, context: PipelineContext, logger: Logger, ): @@ -554,7 +521,7 @@ def test_clean_code_with_no_execute_sql_query_usage_script( def test_clean_code_using_incorrect_sql_table( self, - pgsql_connector: PostgreSQLConnector, + pgsql_connector: DataFrame, context: PipelineContext, logger, ): @@ -571,7 +538,7 @@ def test_clean_code_using_incorrect_sql_table( def test_clean_code_using_multi_incorrect_sql_table( self, - pgsql_connector: PostgreSQLConnector, + pgsql_connector: DataFrame, context: PipelineContext, logger: Logger, ): @@ -585,11 +552,8 @@ def test_clean_code_using_multi_incorrect_sql_table( assert str(excinfo.value) == ("Query uses unauthorized table: table1.") - @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_redeclarations(self, mock_head, context: PipelineContext): - df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - mock_head.return_value = df - pandas_connector = PandasConnector({"original_df": df}) + pandas_connector = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] @@ -608,13 +572,10 @@ def test_fix_dataframe_redeclarations(self, mock_head, context: PipelineContext) assert isinstance(output, ast.Assign) - @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_multiline_redeclarations( self, mock_head, context: PipelineContext ): - df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - mock_head.return_value = df - pandas_connector = PandasConnector({"original_df": df}) + pandas_connector = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] @@ -642,11 +603,8 @@ def test_fix_dataframe_multiline_redeclarations( assert isinstance(outputs[1], ast.Assign) assert outputs[2] is None - @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_no_redeclarations(self, mock_head, context: PipelineContext): - df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - mock_head.return_value = df - pandas_connector = PandasConnector({"original_df": df}) + pandas_connector = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] @@ -665,13 +623,10 @@ def test_fix_dataframe_no_redeclarations(self, mock_head, context: PipelineConte assert output is None - @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_redeclarations_with_subscript( self, mock_head, context: PipelineContext ): - df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - mock_head.return_value = df - pandas_connector = PandasConnector({"original_df": df}) + pandas_connector = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] @@ -690,7 +645,6 @@ def test_fix_dataframe_redeclarations_with_subscript( assert isinstance(output, ast.Assign) - @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_redeclarations_with_subscript_and_data_variable( self, mock_head, context: PipelineContext ): @@ -698,9 +652,7 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable( "country": ["China", "United States", "Japan", "Germany", "United Kingdom"], "sales": [8000, 6000, 4000, 3500, 3000], } - df = pd.DataFrame(data) - mock_head.return_value = df - pandas_connector = PandasConnector({"original_df": df}) + pandas_connector = DataFrame(data) code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] @@ -724,7 +676,6 @@ def test_fix_dataframe_redeclarations_with_subscript_and_data_variable( assert isinstance(output, ast.Assign) - @patch("pandasai.connectors.pandas.PandasConnector.head") def test_fix_dataframe_redeclarations_and_data_variable( self, mock_head, context: PipelineContext ): @@ -732,9 +683,7 @@ def test_fix_dataframe_redeclarations_and_data_variable( "country": ["China", "United States", "Japan", "Germany", "United Kingdom"], "sales": [8000, 6000, 4000, 3500, 3000], } - df = pd.DataFrame(data) - mock_head.return_value = df - pandas_connector = PandasConnector({"original_df": df}) + pandas_connector = DataFrame(data) code_cleaning = CodeCleaning() code_cleaning._dfs = [pandas_connector] diff --git a/tests/unit_tests/pipelines/smart_datalake/test_code_generator.py b/tests/unit_tests/pipelines/smart_datalake/test_code_generator.py index 627d458f1..3415f386d 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_code_generator.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_code_generator.py @@ -1,9 +1,9 @@ from typing import Optional from unittest.mock import Mock, patch -import pandas as pd import pytest +from pandasai.dataframe.base import DataFrame from pandasai.helpers.logger import Logger from pandasai.llm.fake import FakeLLM from pandasai.pipelines.chat.code_generator import CodeGenerator @@ -20,7 +20,7 @@ def llm(self, output: Optional[str] = None): @pytest.fixture def sample_df(self): - return pd.DataFrame( + return DataFrame( { "country": [ "United States", diff --git a/tests/unit_tests/pipelines/smart_datalake/test_error_prompt_generation.py b/tests/unit_tests/pipelines/smart_datalake/test_error_prompt_generation.py index 1cddaca99..3f253b4a7 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_error_prompt_generation.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_error_prompt_generation.py @@ -1,9 +1,9 @@ from typing import Optional from unittest.mock import MagicMock -import pandas as pd import pytest +from pandasai.dataframe.base import DataFrame from pandasai.exceptions import InvalidLLMOutputType from pandasai.llm.fake import FakeLLM from pandasai.pipelines.chat.error_correction_pipeline.error_prompt_generation import ( @@ -25,7 +25,7 @@ def llm(self, output: Optional[str] = None): @pytest.fixture def sample_df(self): - return pd.DataFrame( + return DataFrame( { "country": [ "United States", diff --git a/tests/unit_tests/pipelines/smart_datalake/test_prompt_generation.py b/tests/unit_tests/pipelines/smart_datalake/test_prompt_generation.py index 6a41fbc66..ddc7222a3 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_prompt_generation.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_prompt_generation.py @@ -3,7 +3,7 @@ import pandas as pd import pytest -from pandasai.connectors import PandasConnector +from pandasai.dataframe.base import DataFrame from pandasai.helpers.dataframe_serializer import DataframeSerializerType from pandasai.llm.fake import FakeLLM from pandasai.pipelines.chat.prompt_generation import PromptGeneration @@ -66,7 +66,7 @@ def sample_df(self): @pytest.fixture def dataframe(self, sample_df): - return PandasConnector({"original_df": sample_df}) + return DataFrame(sample_df) @pytest.fixture def config(self, llm): @@ -118,11 +118,10 @@ def test_get_chat_prompt_enforce_privacy_true_custom_head(self, context, sample_ # Test case 1: direct_sql is True prompt_generation = PromptGeneration() context.config.enforce_privacy = True - context.config.dataframe_serializer = DataframeSerializerType.YML + context.config.dataframe_serializer = DataframeSerializerType.CSV - dataframe = PandasConnector({"original_df": sample_df}, custom_head=sample_df) + dataframe = DataFrame(sample_df) context.dfs = [dataframe] gen_prompt = prompt_generation.get_chat_prompt(context) assert isinstance(gen_prompt, GeneratePythonCodePrompt) - assert "samples" in gen_prompt.to_string() diff --git a/tests/unit_tests/pipelines/smart_datalake/test_result_parsing.py b/tests/unit_tests/pipelines/smart_datalake/test_result_parsing.py index 8190ff63b..cb0195422 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_result_parsing.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_result_parsing.py @@ -1,9 +1,9 @@ from typing import Optional from unittest.mock import Mock -import pandas as pd import pytest +from pandasai.dataframe.base import DataFrame from pandasai.helpers.logger import Logger from pandasai.llm.fake import FakeLLM from pandasai.pipelines.chat.result_parsing import ResultParsing @@ -21,7 +21,7 @@ def llm(self, output: Optional[str] = None): @pytest.fixture def sample_df(self): - return pd.DataFrame( + return DataFrame( { "country": [ "United States", diff --git a/tests/unit_tests/pipelines/smart_datalake/test_result_validation.py b/tests/unit_tests/pipelines/smart_datalake/test_result_validation.py index 541226bad..396cad90d 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_result_validation.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_result_validation.py @@ -1,9 +1,9 @@ from typing import Optional from unittest.mock import Mock -import pandas as pd import pytest +from pandasai.dataframe.base import DataFrame from pandasai.helpers.logger import Logger from pandasai.llm.fake import FakeLLM from pandasai.pipelines.chat.result_validation import ResultValidation @@ -21,7 +21,7 @@ def llm(self, output: Optional[str] = None): @pytest.fixture def sample_df(self): - return pd.DataFrame( + return DataFrame( { "country": [ "United States", diff --git a/tests/unit_tests/pipelines/smart_datalake/test_validate_pipeline_input.py b/tests/unit_tests/pipelines/smart_datalake/test_validate_pipeline_input.py index f3f9c334e..426e7629a 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_validate_pipeline_input.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_validate_pipeline_input.py @@ -1,14 +1,9 @@ from typing import Optional -from unittest.mock import patch import pandas as pd import pytest -from extensions.connectors.sql.pandasai_sql.sql import ( - PostgreSQLConnector, - SQLConnector, - SQLConnectorConfig, -) +from pandasai.dataframe.base import DataFrame from pandasai.exceptions import InvalidConfigError from pandasai.helpers.logger import Logger from pandasai.llm.fake import FakeLLM @@ -70,42 +65,90 @@ def sample_df(self): ) @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def sql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="mysql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return SQLConnector(self.config) + def sql_connector(self): + return DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) @pytest.fixture - @patch("extensions.connectors.sql.pandasai_sql.sql.create_engine", autospec=True) - def pgsql_connector(self, create_engine): - # Define your ConnectorConfig instance here - self.config = SQLConnectorConfig( - dialect="pgsql", - driver="pymysql", - username="your_username", - password="your_password", - host="your_host", - port=443, - database="your_database", - table="your_table", - where=[["column_name", "=", "value"]], - ).dict() - - # Create an instance of SQLConnector - return PostgreSQLConnector(self.config) + def pgsql_connector(self): + return DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) @pytest.fixture def config(self, llm): diff --git a/tests/unit_tests/pipelines/test_pipeline.py b/tests/unit_tests/pipelines/test_pipeline.py index 3e60e77a1..d8a0fd488 100644 --- a/tests/unit_tests/pipelines/test_pipeline.py +++ b/tests/unit_tests/pipelines/test_pipeline.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from pandasai.connectors import BaseConnector, PandasConnector +from pandasai.dataframe.base import DataFrame from pandasai.ee.agents.judge_agent import JudgeAgent from pandasai.helpers.logger import Logger from pandasai.llm.fake import FakeLLM @@ -70,7 +70,7 @@ def sample_df(self): @pytest.fixture def dataframe(self, sample_df): - return PandasConnector({"original_df": sample_df}) + return DataFrame(sample_df) @pytest.fixture def config(self, llm): @@ -97,14 +97,14 @@ def test_init_with_agent(self, dataframe, config): pipeline = Pipeline([dataframe], config=config) assert isinstance(pipeline, Pipeline) assert len(pipeline._context.dfs) == 1 - assert isinstance(pipeline._context.dfs[0], BaseConnector) + assert isinstance(pipeline._context.dfs[0], DataFrame) def test_init_with_dfs(self, dataframe, config): # Test the initialization of the Pipeline pipeline = Pipeline([dataframe], config=config) assert isinstance(pipeline, Pipeline) assert len(pipeline._context.dfs) == 1 - assert isinstance(pipeline._context.dfs[0], BaseConnector) + assert isinstance(pipeline._context.dfs[0], DataFrame) def test_add_step(self, context, config): # Test the add_step method diff --git a/tests/unit_tests/prompts/test_correct_error_prompt.py b/tests/unit_tests/prompts/test_correct_error_prompt.py index d39a54404..917d4d2e1 100644 --- a/tests/unit_tests/prompts/test_correct_error_prompt.py +++ b/tests/unit_tests/prompts/test_correct_error_prompt.py @@ -2,10 +2,8 @@ import sys -import pandas as pd - from pandasai import Agent -from pandasai.connectors import PandasConnector +from pandasai.dataframe.base import DataFrame from pandasai.helpers.dataframe_serializer import DataframeSerializerType from pandasai.llm.fake import FakeLLM from pandasai.prompts import CorrectErrorPrompt @@ -19,7 +17,7 @@ def test_str_with_args(self): llm = FakeLLM() agent = Agent( - dfs=[PandasConnector({"original_df": pd.DataFrame()})], + dfs=[DataFrame()], config={"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}, ) prompt = CorrectErrorPrompt( @@ -54,7 +52,7 @@ def test_to_json(self): llm = FakeLLM() agent = Agent( - dfs=[PandasConnector({"original_df": pd.DataFrame()})], + dfs=[DataFrame()], config={"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}, ) prompt = CorrectErrorPrompt( @@ -62,7 +60,7 @@ def test_to_json(self): ) assert prompt.to_json() == { - "datasets": [{"name": None, "description": None, "head": []}], + "datasets": ["{}"], "conversation": [], "system_prompt": None, "error": { diff --git a/tests/unit_tests/prompts/test_generate_python_code_prompt.py b/tests/unit_tests/prompts/test_generate_python_code_prompt.py index adf2e8961..c01cd111c 100644 --- a/tests/unit_tests/prompts/test_generate_python_code_prompt.py +++ b/tests/unit_tests/prompts/test_generate_python_code_prompt.py @@ -4,12 +4,10 @@ import sys from unittest.mock import patch -import pandas as pd import pytest from pandasai import Agent -from pandasai.connectors import PandasConnector -from pandasai.ee.connectors.relations import PrimaryKey +from pandasai.dataframe.base import DataFrame from pandasai.helpers.dataframe_serializer import DataframeSerializerType from pandasai.llm.fake import FakeLLM from pandasai.prompts import GeneratePythonCodePrompt @@ -61,7 +59,7 @@ def test_str_with_args(self, output_type, output_type_template): llm = FakeLLM() agent = Agent( - PandasConnector({"original_df": pd.DataFrame({"a": [1], "b": [4]})}), + DataFrame({"a": [1], "b": [4]}), config={"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}, ) prompt = GeneratePythonCodePrompt( @@ -150,7 +148,7 @@ def test_str_with_train_qa(self, chromadb_mock, output_type, output_type_templat chromadb_instance.get_relevant_qa_documents.return_value = [["query1"]] llm = FakeLLM() agent = Agent( - PandasConnector({"original_df": pd.DataFrame({"a": [1], "b": [4]})}), + DataFrame({"a": [1], "b": [4]}), config={"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}, ) agent.train(["query1"], ["code1"]) @@ -245,8 +243,9 @@ def test_str_with_train_docs( chromadb_instance.get_relevant_docs_documents.return_value = [["query1"]] llm = FakeLLM() agent = Agent( - PandasConnector({"original_df": pd.DataFrame({"a": [1], "b": [4]})}), + DataFrame({"a": [1], "b": [4]}), config={"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}, + vectorstore=chromadb_instance, ) agent.train(docs=["document1"]) prompt = GeneratePythonCodePrompt( @@ -344,8 +343,9 @@ def test_str_with_train_docs_and_qa( chromadb_instance.get_relevant_qa_documents.return_value = [["query1"]] llm = FakeLLM() agent = Agent( - PandasConnector({"original_df": pd.DataFrame({"a": [1], "b": [4]})}), + DataFrame({"a": [1], "b": [4]}), config={"llm": llm}, + vectorstore=chromadb_instance, ) agent.train(queries=["query1"], codes=["code1"], docs=["document1"]) prompt = GeneratePythonCodePrompt( @@ -412,8 +412,9 @@ def test_str_geenerate_code_prompt_to_json(self, chromadb_mock): chromadb_instance.get_relevant_qa_documents.return_value = [["query1"]] llm = FakeLLM() agent = Agent( - PandasConnector({"original_df": pd.DataFrame({"a": [1], "b": [4]})}), + DataFrame({"a": [1], "b": [4]}), config={"llm": llm}, + vectorstore=chromadb_instance, ) agent.train(queries=["query1"], codes=["code1"], docs=["document1"]) prompt = GeneratePythonCodePrompt( @@ -424,9 +425,7 @@ def test_str_geenerate_code_prompt_to_json(self, chromadb_mock): prompt_json["prompt"] = prompt_json["prompt"].replace("\r\n", "\n") assert prompt_json == { - "datasets": [ - {"name": None, "description": None, "head": [{"a": 1, "b": 4}]} - ], + "datasets": ['{"a":{"0":1},"b":{"0":4}}'], "conversation": [], "system_prompt": None, "prompt": '\ndfs[0]:1x2\na,b\n1,4\n\n\n\n\nUpdate this initial code:\n```python\n# TODO: import the required dependencies\nimport pandas as pd\n\n# Write code here\n\n# Declare result var: \ntype (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }\n\n```\n\n\nYou can utilize these examples as a reference for generating code.\n\n[\'query1\']\n\nHere are additional documents for reference. Feel free to use them to answer.\n[\'documents1\']\n\n\n\nVariable `dfs: list[pd.DataFrame]` is already declared.\n\nAt the end, declare "result" variable as a dictionary of type and value.\n\n\nGenerate python code and return full updated code:', @@ -479,11 +478,9 @@ def test_str_relations(self, chromadb_mock, output_type, output_type_template): chromadb_instance.get_relevant_qa_documents.return_value = [["query1"]] llm = FakeLLM() agent = Agent( - PandasConnector( - {"original_df": pd.DataFrame({"a": [1], "b": [4]})}, - connector_relations=[PrimaryKey("a")], - ), + DataFrame({"a": [1], "b": [4]}), config={"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}, + vectorstore=chromadb_instance, ) agent.train(["query1"], ["code1"]) prompt = GeneratePythonCodePrompt( @@ -491,23 +488,11 @@ def test_str_relations(self, chromadb_mock, output_type, output_type_template): output_type=output_type, ) - expected_prompt_content = f"""dfs[0]: - name: null - description: null - type: pd.DataFrame - rows: 1 - columns: 2 - schema: - fields: - - name: a - type: int64 - samples: - - 1 - constraints: PRIMARY KEY (a) - - name: b - type: int64 - samples: - - 4 + expected_prompt_content = f""" +dfs[0]:1x2 +a,b +1,4 + @@ -539,9 +524,8 @@ def test_str_relations(self, chromadb_mock, output_type, output_type_template): Generate python code and return full updated code:""" # noqa E501 actual_prompt_content = prompt.to_string() + if sys.platform.startswith("win"): actual_prompt_content = actual_prompt_content.replace("\r\n", "\n") - print(actual_prompt_content) - assert actual_prompt_content == expected_prompt_content