Skip to content

Commit

Permalink
fix: remove df override from the output code (Sinaptik-AI#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed May 16, 2023
1 parent e0bcd23 commit 86d9f06
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
28 changes: 26 additions & 2 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PandasAI:
{df_head}.
When asked about the data, your response should include a python code that describes the dataframe `df`.
Return the python code (do not import anything) and make sure to prefix the requested python code with {START_CODE_TAG} exactly and suffix the code with {END_CODE_TAG} exactly to get the answer to the following question:
Using the provided dataframe, df, return the python code and make sure to prefix the requested python code with {START_CODE_TAG} exactly and suffix the code with {END_CODE_TAG} exactly to get the answer to the following question:
"""
_response_instruction: str = """
Question: {question}
Expand Down Expand Up @@ -185,6 +185,30 @@ def remove_unsafe_imports(self, code: str) -> str:
new_tree = ast.Module(body=new_body)
return astor.to_source(new_tree).strip()

def remove_df_overwrites(self, code: str) -> str:
    """Remove top-level statements that rebind the name ``df``.

    The generated code is expected to operate on the dataframe supplied
    by the caller; any module-level assignment that would replace ``df``
    is dropped so it cannot be swapped for different data.

    A statement is removed when any of its assignment targets binds the
    bare name ``df``. This covers plain assignment (``df = ...``),
    chained assignment (``x = df = ...``), tuple/starred unpacking
    (``df, n = ...``) and annotated assignment (``df: T = ...``).
    Writes *into* the dataframe such as ``df["col"] = ...`` or
    ``df.attr = ...`` are kept, because there ``df`` is only read.

    Only statements directly in the module body are examined; assignments
    nested inside ``if``/``for``/function bodies are not inspected.

    Args:
        code: Python source code to sanitise.

    Returns:
        The source regenerated from the filtered syntax tree, with
        surrounding whitespace stripped.

    Raises:
        SyntaxError: If ``code`` cannot be parsed by ``ast.parse``.
    """

    def rebinds_df(node: ast.AST) -> bool:
        """Return True if *node* is an assignment that stores to ``df``."""
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # A bare annotation (``df: T``) binds nothing, so only
            # annotated assignments carrying a value are considered.
            targets = [node.target]
        else:
            return False

        # Walk each target so names nested in tuple/list/starred targets
        # are found. Requiring a Store context skips ``df[...] = ...``
        # and ``df.attr = ...``, where the inner ``df`` is a Load.
        return any(
            isinstance(sub, ast.Name)
            and sub.id == "df"
            and isinstance(sub.ctx, ast.Store)
            for target in targets
            for sub in ast.walk(target)
        )

    tree = ast.parse(code)
    new_body = [node for node in tree.body if not rebinds_df(node)]
    new_tree = ast.Module(body=new_body)
    return astor.to_source(new_tree).strip()

def clean_code(self, code: str) -> str:
    """Sanitise generated code before it is executed.

    The code is passed through ``remove_unsafe_imports`` first and the
    result is then passed through ``remove_df_overwrites``.

    Args:
        code: Python source code to sanitise.

    Returns:
        The source code after both clean-up passes.
    """

    # TODO: avoid iterating over the code twice
    without_imports = self.remove_unsafe_imports(code)
    return self.remove_df_overwrites(without_imports)

def run_code(
self,
code: str,
Expand All @@ -198,7 +222,7 @@ def run_code(
with redirect_stdout(io.StringIO()) as output:
# Execute the code
count = 0
code_to_run = self.remove_unsafe_imports(code)
code_to_run = self.clean_code(code)
while count < self._max_retries:
try:
exec(
Expand Down
12 changes: 10 additions & 2 deletions tests/test_pandasai.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_run_with_privacy_enforcement(self, pandasai, llm):
Index: [].
When asked about the data, your response should include a python code that describes the dataframe `df`.
Return the python code (do not import anything) and make sure to prefix the requested python code with <startCode> exactly and suffix the code with <endCode> exactly to get the answer to the following question:
Using the provided dataframe, df, return the python code and make sure to prefix the requested python code with <startCode> exactly and suffix the code with <endCode> exactly to get the answer to the following question:
How many countries are in the dataframe?
Code:
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_run_without_privacy_enforcement(self, pandasai):
2 France.
When asked about the data, your response should include a python code that describes the dataframe `df`.
Return the python code (do not import anything) and make sure to prefix the requested python code with <startCode> exactly and suffix the code with <endCode> exactly to get the answer to the following question:
Using the provided dataframe, df, return the python code and make sure to prefix the requested python code with <startCode> exactly and suffix the code with <endCode> exactly to get the answer to the following question:
How many countries are in the dataframe?
Code:
Expand Down Expand Up @@ -249,3 +249,11 @@ def test_remove_unsafe_imports(self, pandasai):
pandasai._llm._output = malicious_code
assert pandasai.remove_unsafe_imports(malicious_code) == "print(os.listdir())"
assert pandasai.run_code(malicious_code, pd.DataFrame()) == ""

def test_remove_df_overwrites(self, pandasai):
    """A top-level ``df = ...`` statement is stripped from the code."""
    generated = """
df = pd.DataFrame([1,2,3])
print(df)
"""
    pandasai._llm._output = generated
    cleaned = pandasai.remove_df_overwrites(generated)
    assert cleaned == "print(df)"

0 comments on commit 86d9f06

Please sign in to comment.