diff --git a/pandasai/helpers/code_manager.py b/pandasai/helpers/code_manager.py index e2706231c..e5b6f8979 100644 --- a/pandasai/helpers/code_manager.py +++ b/pandasai/helpers/code_manager.py @@ -319,6 +319,13 @@ def find_function_calls(self, node: ast.AST, context: CodeExecutionContext): for child_node in ast.iter_child_nodes(node): self.find_function_calls(child_node, context) + def check_direct_sql_func_def_exists(self, node: ast.AST): + return ( + self._validate_direct_sql(self._dfs) + and isinstance(node, ast.FunctionDef) + and node.name == "execute_sql_query" + ) + def _clean_code(self, code: str, context: CodeExecutionContext) -> str: """ A method to clean the code to prevent malicious code execution. @@ -350,7 +357,13 @@ def _clean_code(self, code: str, context: CodeExecutionContext) -> str: ): continue + # if generated code contain execute_sql_query def remove it + # function already defined + if self.check_direct_sql_func_def_exists(node): + continue + self.find_function_calls(node, context) + new_body.append(node) new_tree = ast.Module(body=new_body) diff --git a/tests/test_codemanager.py b/tests/test_codemanager.py index a875aa2fb..97d325bcd 100644 --- a/tests/test_codemanager.py +++ b/tests/test_codemanager.py @@ -510,3 +510,52 @@ def test_validate_true_direct_sql_with_two_different_connector( config={"llm": FakeLLM(output="")}, ) code_manager._validate_direct_sql([df1, df2]) + + def test_clean_code_direct_sql_code( + self, pgsql_connector: PostgreSQLConnector, exec_context: MagicMock + ): + """Test that the direct SQL function definition is removed when 'direct_sql' is True""" + df = SmartDataframe( + pgsql_connector, + config={"llm": FakeLLM(output=""), "direct_sql": True}, + ) + code_manager = df.lake._code_manager + safe_code = """ +import numpy as np +def execute_sql_query(sql_query: str) -> pd.DataFrame: + # code to connect to the database and execute the query + # ... + # return the result as a dataframe + return pd.DataFrame() +np.array() +""" + assert code_manager._clean_code(safe_code, exec_context) == "np.array()" + + def test_clean_code_direct_sql_code_false( + self, pgsql_connector: PostgreSQLConnector, exec_context: MagicMock + ): + """Test that the direct SQL function definition is removed when 'direct_sql' is False""" + df = SmartDataframe( + pgsql_connector, + config={"llm": FakeLLM(output=""), "direct_sql": False}, + ) + code_manager = df.lake._code_manager + + safe_code = """ +import numpy as np +def execute_sql_query(sql_query: str) -> pd.DataFrame: + # code to connect to the database and execute the query + # ... + # return the result as a dataframe + return pd.DataFrame() +np.array() +""" + print(code_manager._clean_code(safe_code, exec_context)) + assert ( + code_manager._clean_code(safe_code, exec_context) + == """def execute_sql_query(sql_query: str) ->pd.DataFrame: + return pd.DataFrame() + + +np.array()""" + )