Skip to content

Commit

Permalink
refactor: type for prompt_id (#586) (#587)
Browse files Browse the repository at this point in the history
* (refactor): update type hint for `prompt_id`
* (refactor): update prompt id to be passed in tests, make it be uuid
  object, not empty string
* (fix): update type hint for `_last_prompt_id`
* (docs): update docstring for `execute_code()`
  • Loading branch information
nautics889 authored Sep 22, 2023
1 parent 325fed6 commit d5e6e38
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
13 changes: 6 additions & 7 deletions pandasai/helpers/code_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import ast
import uuid
from collections import defaultdict

import astor
Expand All @@ -15,7 +16,7 @@
WHITELISTED_LIBRARIES,
)
from ..middlewares.charts import ChartsMiddleware
from typing import Union, List, Optional, Generator
from typing import Union, List, Optional, Generator, Any
from ..helpers.logger import Logger
from ..schemas.df_config import Config
import logging
Expand Down Expand Up @@ -182,21 +183,19 @@ def _required_dfs(self, code: str) -> List[str]:
def execute_code(
self,
code: str,
prompt_id: str,
) -> str:
prompt_id: uuid.UUID,
) -> Any:
"""
Execute the python code generated by LLMs to answer the question
about the input dataframe. Run the code in the current context and return the
result.
Args:
code (str): Python code to execute.
data_frame (pd.DataFrame): Full Pandas DataFrame.
use_error_correction_framework (bool): Turn on Error Correction mechanism.
Default to True.
prompt_id (uuid.UUID): UUID of the request.
Returns:
str: The result of the code execution. The type of the result depends
Any: The result of the code execution. The type of the result depends
on the generated code.
"""
Expand Down
2 changes: 1 addition & 1 deletion pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SmartDatalake:
_cache: Cache = None
_logger: Logger
_start_time: float
_last_prompt_id: uuid
_last_prompt_id: uuid.UUID
_code_manager: CodeManager
_memory: Memory

Expand Down
14 changes: 10 additions & 4 deletions tests/test_codemanager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Unit tests for the CodeManager class"""
import uuid
from typing import Optional
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -75,18 +76,23 @@ def test_run_code_for_calculations(self, code_manager: CodeManager):
code = """def analyze_data(dfs):
return {'type': 'number', 'value': 1 + 1}"""

assert code_manager.execute_code(code, "")["value"] == 2
assert code_manager.execute_code(code, uuid.uuid4())["value"] == 2
assert code_manager.last_code_executed == code

def test_run_code_invalid_code(self, code_manager: CodeManager):
with pytest.raises(Exception):
code_manager.execute_code("1+ ", "")
# noinspection PyStatementEffect
code_manager.execute_code("1+ ", uuid.uuid4())["value"]

def test_clean_code_remove_builtins(self, code_manager: CodeManager):
builtins_code = """import set
def analyze_data(dfs):
return {'type': 'number', 'value': set([1, 2, 3])}"""
assert code_manager.execute_code(builtins_code, "")["value"] == {1, 2, 3}
assert code_manager.execute_code(builtins_code, uuid.uuid4())["value"] == {
1,
2,
3,
}
assert (
code_manager.last_code_executed
== """def analyze_data(dfs):
Expand Down Expand Up @@ -123,7 +129,7 @@ def test_clean_code_raise_bad_import_error(self, code_manager: CodeManager):
print(os.listdir())
"""
with pytest.raises(BadImportError):
code_manager.execute_code(malicious_code, "")
code_manager.execute_code(malicious_code, uuid.uuid4())

def test_remove_dfs_overwrites(self, code_manager: CodeManager):
hallucinated_code = """def analyze_data(dfs):
Expand Down

0 comments on commit d5e6e38

Please sign in to comment.