diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index 8ee34e0df..3b386c8ce 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -2,10 +2,11 @@ SQL connectors are used to connect to SQL databases in different dialects. """ +import re import os import pandas as pd from .base import BaseConnector, ConnectorConfig -from sqlalchemy import create_engine, sql +from sqlalchemy import create_engine, sql, text, select, asc from functools import cached_property, cache import hashlib from ..helpers.path import find_project_root @@ -70,45 +71,48 @@ def __repr__(self): f"table={self._config.table}>" ) - def _build_query(self, limit: int = None, order: str = None): - """ - Build the SQL query that will be executed. + def _validate_column_name(self, column_name): + regex = r"^[a-zA-Z0-9_]+$" + if not re.match(regex, column_name): + raise ValueError("Invalid column name: {}".format(column_name)) - Args: - limit (int, optional): The number of rows to return. Defaults to None. + def _build_query(self, limit=None, order=None): + base_query = select("*").select_from(text(self._config.table)) + valid_operators = ["=", ">", "<", ">=", "<=", "LIKE", "!=", "IN", "NOT IN"] - Returns: - str: The SQL query that will be executed. - """ + if self._config.where or self._additional_filters: + # conditions is the list of wher + additional filters + conditions = [] + if self._config.where: + conditions += self._config.where + if self._additional_filters: + conditions += self._additional_filters - # Run a SQL query to get all the columns names and 5 random rows - query = f"SELECT * FROM {self._config.table}" - if ( - self._config.where - or self._additional_filters is not None - and len(self._additional_filters) > 0 - ): - query += " WHERE " + query_params = {} + condition_strings = [] + + for i, condition in enumerate(conditions): + if len(condition) == 3: + column_name, operator, value = condition + if operator in valid_operators: + self._validate_column_name(column_name) + + condition_strings.append(f"{column_name} {operator} :value_{i}") + query_params[f"value_{i}"] = value + + if condition_strings: + where_clause = " AND ".join(condition_strings) + base_query = base_query.where( + text(where_clause).bindparams(**query_params) + ) - conditions = [] - if self._config.where is not None: - for condition in self._config.where: - conditions.append(f"{condition[0]} {condition[1]} '{condition[2]}'") - if ( - self._additional_filters is not None - and len(self._additional_filters) > 0 - ): - for condition in self._additional_filters: - conditions.append(f"{condition[0]} {condition[1]} '{condition[2]}'") - - query += " AND ".join(conditions) if order: - query += f" ORDER BY {order}" + base_query = base_query.order_by(asc(text(order))) + if limit: - query += f" LIMIT {limit}" + base_query = base_query.limit(limit) - # Return the query - return sql.text(query) + return base_query @cache def head(self): diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index 80aaf6c38..310aaaca6 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -29,7 +29,7 @@ class CodeManager: _logger: Logger = None _additional_dependencies: List[dict] = [] _ast_comparatos_map: dict = { - ast.Eq: "==", + ast.Eq: "=", ast.NotEq: "!=", ast.Lt: "<", ast.LtE: "<=", @@ -589,7 +589,7 @@ def _extract_comparisons(self, tree: ast.Module) -> dict[str, list]: } """ comparisons = defaultdict(list) - current_df = "dfs0" + current_df = "dfs[0]" visitor = AssignmentVisitor() visitor.visit(tree) diff --git a/pandasai/helpers/df_info.py b/pandasai/helpers/df_info.py index 139bd609b..ac3de91f3 100644 --- a/pandasai/helpers/df_info.py +++ b/pandasai/helpers/df_info.py @@ -21,10 +21,6 @@ def df_type(df: DataFrameType) -> str: Returns: str: Type of the dataframe """ - print("*" * 100) - print(df) - print("*" * 100) - if polars_imported and isinstance(df, pl.DataFrame): return "polars" elif isinstance(df, pd.DataFrame): diff --git a/tests/connectors/test_sql.py b/tests/connectors/test_sql.py index 733be7759..f2152aa60 100644 --- a/tests/connectors/test_sql.py +++ b/tests/connectors/test_sql.py @@ -50,10 +50,11 @@ def test_repr_method(self): def test_build_query_method(self): # Test _build_query method query = self.connector._build_query(limit=5, order="RAND()") - expected_query = ( - "SELECT * FROM your_table WHERE column_name = 'value' " - "ORDER BY RAND() LIMIT 5" - ) + expected_query = """SELECT * +FROM your_table +WHERE column_name = :value_0 ORDER BY RAND() ASC + LIMIT :param_1""" + self.assertEqual(str(query), expected_query) @patch("pandasai.connectors.sql.pd.read_sql", autospec=True) diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index cfe225a02..529a0c3fb 100644 --- a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -313,8 +313,8 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: assert isinstance(filters["dfs[0]"], list) assert len(filters["dfs[0]"]) == 2 - assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF") - assert filters["dfs[0]"][1] == ("Gender", "==", "male") + assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF") + assert filters["dfs[0]"][1] == ("Gender", "=", "male") def test_extract_filters_polars_multiple_df(self, code_manager: CodeManager): code = """ @@ -355,14 +355,14 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: assert len(filters["dfs[0]"]) == 2 assert len(filters["dfs[1]"]) == 2 - assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF") - assert filters["dfs[0]"][1] == ("Gender", "==", "male") + assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF") + assert filters["dfs[0]"][1] == ("Gender", "=", "male") - assert filters["dfs[1]"][0] == ("loan_status", "==", "PENDING") - assert filters["dfs[1]"][1] == ("Gender", "==", "male") + assert filters["dfs[1]"][0] == ("loan_status", "=", "PENDING") + assert filters["dfs[1]"][1] == ("Gender", "=", "male") - assert filters["dfs[2]"][0] == ("loan_status", "==", "PAIDOFF") - assert filters["dfs[2]"][1] == ("Gender", "==", "female") + assert filters["dfs[2]"][0] == ("loan_status", "=", "PAIDOFF") + assert filters["dfs[2]"][1] == ("Gender", "=", "female") @pytest.mark.parametrize("df_name", ["df", "foobar"]) def test_extract_filters_col_index(self, df_name, code_manager: CodeManager): @@ -386,8 +386,8 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: assert isinstance(filters["dfs[0]"], list) assert len(filters["dfs[0]"]) == 2 - assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF") - assert filters["dfs[0]"][1] == ("Gender", "==", "male") + assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF") + assert filters["dfs[0]"][1] == ("Gender", "=", "male") def test_extract_filters_col_index_non_default_name( self, code_manager: CodeManager @@ -413,8 +413,8 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: assert isinstance(filters["dfs[0]"], list) assert len(filters["dfs[0]"]) == 2 - assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF") - assert filters["dfs[0]"][1] == ("Gender", "==", "male") + assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF") + assert filters["dfs[0]"][1] == ("Gender", "=", "male") def test_extract_filters_col_index_multiple_df(self, code_manager: CodeManager): code = """ @@ -455,11 +455,11 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict: assert len(filters["dfs[0]"]) == 2 assert len(filters["dfs[1]"]) == 2 - assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF") - assert filters["dfs[0]"][1] == ("Gender", "==", "male") + assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF") + assert filters["dfs[0]"][1] == ("Gender", "=", "male") - assert filters["dfs[1]"][0] == ("loan_status", "==", "PENDING") - assert filters["dfs[1]"][1] == ("Gender", "==", "male") + assert filters["dfs[1]"][0] == ("loan_status", "=", "PENDING") + assert filters["dfs[1]"][1] == ("Gender", "=", "male") - assert filters["dfs[2]"][0] == ("loan_status", "==", "PAIDOFF") - assert filters["dfs[2]"][1] == ("Gender", "==", "female") + assert filters["dfs[2]"][0] == ("loan_status", "=", "PAIDOFF") + assert filters["dfs[2]"][1] == ("Gender", "=", "female")