Skip to content

Commit

Permalink
fix: prevent sql injections
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Sep 7, 2023
1 parent 9955000 commit 8951ebe
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 57 deletions.
70 changes: 37 additions & 33 deletions pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
SQL connectors are used to connect to SQL databases in different dialects.
"""

import re
import os
import pandas as pd
from .base import BaseConnector, ConnectorConfig
from sqlalchemy import create_engine, sql
from sqlalchemy import create_engine, sql, text, select, asc
from functools import cached_property, cache
import hashlib
from ..helpers.path import find_project_root
Expand Down Expand Up @@ -70,45 +71,48 @@ def __repr__(self):
f"table={self._config.table}>"
)

def _build_query(self, limit: int = None, order: str = None):
"""
Build the SQL query that will be executed.
def _validate_column_name(self, column_name):
regex = r"^[a-zA-Z0-9_]+$"
if not re.match(regex, column_name):
raise ValueError("Invalid column name: {}".format(column_name))

Args:
limit (int, optional): The number of rows to return. Defaults to None.
def _build_query(self, limit=None, order=None):
base_query = select("*").select_from(text(self._config.table))
valid_operators = ["=", ">", "<", ">=", "<=", "LIKE", "!=", "IN", "NOT IN"]

Returns:
str: The SQL query that will be executed.
"""
if self._config.where or self._additional_filters:
# conditions is the list of where + additional filters
conditions = []
if self._config.where:
conditions += self._config.where
if self._additional_filters:
conditions += self._additional_filters

# Run a SQL query to get all the columns names and 5 random rows
query = f"SELECT * FROM {self._config.table}"
if (
self._config.where
or self._additional_filters is not None
and len(self._additional_filters) > 0
):
query += " WHERE "
query_params = {}
condition_strings = []

for i, condition in enumerate(conditions):
if len(condition) == 3:
column_name, operator, value = condition
if operator in valid_operators:
self._validate_column_name(column_name)

condition_strings.append(f"{column_name} {operator} :value_{i}")
query_params[f"value_{i}"] = value

if condition_strings:
where_clause = " AND ".join(condition_strings)
base_query = base_query.where(
text(where_clause).bindparams(**query_params)
)

conditions = []
if self._config.where is not None:
for condition in self._config.where:
conditions.append(f"{condition[0]} {condition[1]} '{condition[2]}'")
if (
self._additional_filters is not None
and len(self._additional_filters) > 0
):
for condition in self._additional_filters:
conditions.append(f"{condition[0]} {condition[1]} '{condition[2]}'")

query += " AND ".join(conditions)
if order:
query += f" ORDER BY {order}"
base_query = base_query.order_by(asc(text(order)))

if limit:
query += f" LIMIT {limit}"
base_query = base_query.limit(limit)

# Return the query
return sql.text(query)
return base_query

@cache
def head(self):
Expand Down
4 changes: 2 additions & 2 deletions pandasai/helpers/code_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class CodeManager:
_logger: Logger = None
_additional_dependencies: List[dict] = []
_ast_comparatos_map: dict = {
ast.Eq: "==",
ast.Eq: "=",
ast.NotEq: "!=",
ast.Lt: "<",
ast.LtE: "<=",
Expand Down Expand Up @@ -589,7 +589,7 @@ def _extract_comparisons(self, tree: ast.Module) -> dict[str, list]:
}
"""
comparisons = defaultdict(list)
current_df = "dfs0"
current_df = "dfs[0]"

visitor = AssignmentVisitor()
visitor.visit(tree)
Expand Down
9 changes: 5 additions & 4 deletions tests/connectors/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ def test_repr_method(self):
def test_build_query_method(self):
# Test _build_query method
query = self.connector._build_query(limit=5, order="RAND()")
expected_query = (
"SELECT * FROM your_table WHERE column_name = 'value' "
"ORDER BY RAND() LIMIT 5"
)
expected_query = """SELECT *
FROM your_table
WHERE column_name = :value_0 ORDER BY RAND() ASC
LIMIT :param_1"""

self.assertEqual(str(query), expected_query)

@patch("pandasai.connectors.sql.pd.read_sql", autospec=True)
Expand Down
36 changes: 18 additions & 18 deletions tests/test_codemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
assert isinstance(filters["dfs[0]"], list)
assert len(filters["dfs[0]"]) == 2

assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "==", "male")
assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "=", "male")

def test_extract_filters_polars_multiple_df(self, code_manager: CodeManager):
code = """
Expand Down Expand Up @@ -355,14 +355,14 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
assert len(filters["dfs[0]"]) == 2
assert len(filters["dfs[1]"]) == 2

assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "==", "male")
assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "=", "male")

assert filters["dfs[1]"][0] == ("loan_status", "==", "PENDING")
assert filters["dfs[1]"][1] == ("Gender", "==", "male")
assert filters["dfs[1]"][0] == ("loan_status", "=", "PENDING")
assert filters["dfs[1]"][1] == ("Gender", "=", "male")

assert filters["dfs[2]"][0] == ("loan_status", "==", "PAIDOFF")
assert filters["dfs[2]"][1] == ("Gender", "==", "female")
assert filters["dfs[2]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[2]"][1] == ("Gender", "=", "female")

@pytest.mark.parametrize("df_name", ["df", "foobar"])
def test_extract_filters_col_index(self, df_name, code_manager: CodeManager):
Expand All @@ -386,8 +386,8 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
assert isinstance(filters["dfs[0]"], list)
assert len(filters["dfs[0]"]) == 2

assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "==", "male")
assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "=", "male")

def test_extract_filters_col_index_non_default_name(
self, code_manager: CodeManager
Expand All @@ -413,8 +413,8 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
assert isinstance(filters["dfs[0]"], list)
assert len(filters["dfs[0]"]) == 2

assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "==", "male")
assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "=", "male")

def test_extract_filters_col_index_multiple_df(self, code_manager: CodeManager):
code = """
Expand Down Expand Up @@ -455,11 +455,11 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
assert len(filters["dfs[0]"]) == 2
assert len(filters["dfs[1]"]) == 2

assert filters["dfs[0]"][0] == ("loan_status", "==", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "==", "male")
assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "=", "male")

assert filters["dfs[1]"][0] == ("loan_status", "==", "PENDING")
assert filters["dfs[1]"][1] == ("Gender", "==", "male")
assert filters["dfs[1]"][0] == ("loan_status", "=", "PENDING")
assert filters["dfs[1]"][1] == ("Gender", "=", "male")

assert filters["dfs[2]"][0] == ("loan_status", "==", "PAIDOFF")
assert filters["dfs[2]"][1] == ("Gender", "==", "female")
assert filters["dfs[2]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[2]"][1] == ("Gender", "=", "female")

0 comments on commit 8951ebe

Please sign in to comment.