Skip to content

Commit

Permalink
chore(direct_sql): Add Sanity Check in code generation (Sinaptik-AI#864)
Browse files Browse the repository at this point in the history
* chore(direct_sql): Add Sanity Check in code generation

* update comment

* update comments

* chore: correct typo in the comment

---------

Co-authored-by: Gabriele Venturi <[email protected]>
  • Loading branch information
ArslanSaleem and gventuri authored Jan 11, 2024
1 parent b3254f6 commit 9a06b02
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
13 changes: 13 additions & 0 deletions pandasai/helpers/code_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,13 @@ def find_function_calls(self, node: ast.AST, context: CodeExecutionContext):
for child_node in ast.iter_child_nodes(node):
self.find_function_calls(child_node, context)

def check_direct_sql_func_def_exists(self, node: ast.AST):
return (
self._validate_direct_sql(self._dfs)
and isinstance(node, ast.FunctionDef)
and node.name == "execute_sql_query"
)

def _clean_code(self, code: str, context: CodeExecutionContext) -> str:
"""
A method to clean the code to prevent malicious code execution.
Expand Down Expand Up @@ -350,7 +357,13 @@ def _clean_code(self, code: str, context: CodeExecutionContext) -> str:
):
continue

# if generated code contain execute_sql_query def remove it
# function already defined
if self.check_direct_sql_func_def_exists(node):
continue

self.find_function_calls(node, context)

new_body.append(node)

new_tree = ast.Module(body=new_body)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_codemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,52 @@ def test_validate_true_direct_sql_with_two_different_connector(
config={"llm": FakeLLM(output="")},
)
code_manager._validate_direct_sql([df1, df2])

def test_clean_code_direct_sql_code(
self, pgsql_connector: PostgreSQLConnector, exec_context: MagicMock
):
"""Test that the direct SQL function definition is removed when 'direct_sql' is True"""
df = SmartDataframe(
pgsql_connector,
config={"llm": FakeLLM(output=""), "direct_sql": True},
)
code_manager = df.lake._code_manager
safe_code = """
import numpy as np
def execute_sql_query(sql_query: str) -> pd.DataFrame:
# code to connect to the database and execute the query
# ...
# return the result as a dataframe
return pd.DataFrame()
np.array()
"""
assert code_manager._clean_code(safe_code, exec_context) == "np.array()"

def test_clean_code_direct_sql_code_false(
self, pgsql_connector: PostgreSQLConnector, exec_context: MagicMock
):
"""Test that the direct SQL function definition is removed when 'direct_sql' is False"""
df = SmartDataframe(
pgsql_connector,
config={"llm": FakeLLM(output=""), "direct_sql": False},
)
code_manager = df.lake._code_manager

safe_code = """
import numpy as np
def execute_sql_query(sql_query: str) -> pd.DataFrame:
# code to connect to the database and execute the query
# ...
# return the result as a dataframe
return pd.DataFrame()
np.array()
"""
print(code_manager._clean_code(safe_code, exec_context))
assert (
code_manager._clean_code(safe_code, exec_context)
== """def execute_sql_query(sql_query: str) ->pd.DataFrame:
return pd.DataFrame()
np.array()"""
)

0 comments on commit 9a06b02

Please sign in to comment.