Skip to content

Commit

Permalink
refactor(pandasai): make pandasai v3 work for dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Nov 18, 2024
1 parent afa899d commit 5ad125d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 177 deletions.
5 changes: 5 additions & 0 deletions pandasai/dataframe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .base import DataFrame

__all__ = [
"DataFrame",
]
2 changes: 1 addition & 1 deletion pandasai/dataframe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
self.schema = None
self.dataset_path = None

def load(self, dataset_path: str) -> DataFrame:
def load(self, dataset_path: str, lazy=False) -> DataFrame:
self.dataset_path = dataset_path
self._load_schema()
self._validate_source_type()
Expand Down
15 changes: 11 additions & 4 deletions pandasai/pipelines/chat/code_cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,19 +388,26 @@ def _get_originals(self, dfs):
"""
original_dfs = []
for df in dfs:
# TODO - Check why this None check is there
if df is None:
original_dfs.append(None)
continue

df.execute()

original_dfs.append(df.pandas_df)
original_dfs.append(df.head())

return original_dfs

def _extract_fix_dataframe_redeclarations(
self, node: ast.AST, code_lines: list[str]
) -> ast.AST:
"""
Checks if dataframe reclaration in the code like pd.DataFrame({...})
Args:
node (ast.AST): Code Node
code_lines (list[str]): List of code str line by line
Returns:
ast.AST: Updated Ast Node fixing redeclaration
"""
if isinstance(node, ast.Assign):
target_names, is_slice, target = self._get_target_names(node.targets)

Expand Down
114 changes: 9 additions & 105 deletions pandasai/pipelines/chat/code_execution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ast
import logging
import traceback
from collections import defaultdict
from typing import Any, Callable, Generator, List, Union

from pandasai.exceptions import InvalidLLMOutputType, InvalidOutputValueMismatch
Expand All @@ -10,13 +9,13 @@

from ...exceptions import NoResultFoundError
from ...helpers.logger import Logger
from ...helpers.node_visitors import AssignmentVisitor, CallVisitor
from ...helpers.optional import get_environment
from ...helpers.output_validator import OutputValidator
from ...schemas.df_config import Config
from ..base_logic_unit import BaseLogicUnit
from ..pipeline_context import PipelineContext
from .code_cleaning import CodeExecutionContext
import pandas as pd


class CodeExecution(BaseLogicUnit):
Expand Down Expand Up @@ -205,116 +204,21 @@ def _get_originals(self, dfs):
list: List of dfs
"""
original_dfs = []
for index, df in enumerate(dfs):
for df in dfs:
# TODO - Check why this None check is there
if df is None:
original_dfs.append(None)
continue

extracted_filters = self._extract_filters(self._current_code_executed)
filters = extracted_filters.get(f"dfs[{index}]", [])
df.set_additional_filters(filters)

df.execute()
# df.load_connector(partial=len(filters) > 0)

original_dfs.append(df.pandas_df)
if isinstance(df, pd.DataFrame):
original_dfs.append(df)
else:
# Execute to fetch only if not dataframe
df.execute()
original_dfs.append(df.pandas_df)

return original_dfs

def _extract_filters(self, code) -> dict[str, list]:
"""
Extract filters to be applied to the dataframe from passed code.
Args:
code (str): A snippet of code to be parsed.
Returns:
dict: The dictionary containing all filters parsed from
the passed code. The dictionary has the following structure:
{
"<df_number>": [
("<left_operand>", "<operator>", "<right_operand>")
]
}
Raises:
SyntaxError: If the code is unable to be parsed by `ast.parse()`.
Exception: If any exception is raised during working with nodes
of the code tree.
"""
try:
parsed_tree = ast.parse(code)
except SyntaxError:
self.logger.log(
"Invalid code passed for extracting filters", level=logging.ERROR
)
self.logger.log(f"{traceback.format_exc()}", level=logging.DEBUG)
raise

try:
filters = self._extract_comparisons(parsed_tree)
except Exception:
self.logger.log(
"Unable to extract filters for passed code", level=logging.ERROR
)
self.logger.log(f"Error: {traceback.format_exc()}", level=logging.DEBUG)
return {}

return filters

def _extract_comparisons(self, tree: ast.Module) -> dict[str, list]:
"""
Process nodes from passed tree to extract filters.
Collects all assignments in the tree.
Collects all function calls in the tree.
Walk over the tree and handle each comparison node.
For each comparison node, defined what `df` is this node related to.
Parse constants values from the comparison node.
Add to the result dict.
Args:
tree (str): A snippet of code to be parsed.
Returns:
dict: The `defaultdict(list)` instance containing all filters
parsed from the passed instructions tree. The dictionary has
the following structure:
{
"<df_number>": [
("<left_operand>", "<operator>", "<right_operand>")
]
}
"""
comparisons = defaultdict(list)
current_df = "dfs[0]"

visitor = AssignmentVisitor()
visitor.visit(tree)
assignments = visitor.assignment_nodes

call_visitor = CallVisitor()
call_visitor.visit(tree)

for node in ast.walk(tree):
if isinstance(node, ast.Compare) and isinstance(node.left, ast.Subscript):
name, *slices = self._tokenize_operand(node.left)
current_df = (
self._get_df_id_by_nearest_assignment(
node.lineno, assignments, name
)
or current_df
)
left_str = slices[-1] if slices else name

for op, right in zip(node.ops, node.comparators):
op_str = self._ast_comparator_map.get(type(op), "Unknown")
name, *slices = self._tokenize_operand(right)
right_str = slices[-1] if slices else name

comparisons[current_df].append((left_str, op_str, right_str))
return comparisons

def _retry_run_code(
self,
code: str,
Expand Down
67 changes: 0 additions & 67 deletions tests/unit_tests/pipelines/smart_datalake/test_code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,70 +324,3 @@ def test_get_environment(self):
"__build_class__": __build_class__,
"__name__": "__main__",
}

@pytest.mark.parametrize("df_name", ["df", "foobar"])
def test_extract_filters_col_index(self, df_name, code_execution):
code = f"""
{df_name} = dfs[0]
filtered_df = (
{df_name}[
({df_name}['loan_status'] == 'PAIDOFF') & ({df_name}['Gender'] == 'male')
]
)
num_loans = len(filtered_df)
result = {{'type': 'number', 'value': num_loans}}
"""
filters = code_execution._extract_filters(code)
assert isinstance(filters, dict)
assert "dfs[0]" in filters
assert isinstance(filters["dfs[0]"], list)
assert len(filters["dfs[0]"]) == 2

assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "=", "male")

def test_extract_filters_col_index_multiple_df(self, code_execution, logger):
code = """
df = dfs[0]
filtered_paid_df_male = df[(
df['loan_status'] == 'PAIDOFF') & (df['Gender'] == 'male'
)]
num_loans_paid_off_male = len(filtered_paid_df)
df = dfs[1]
filtered_pend_df_male = df[(
df['loan_status'] == 'PENDING') & (df['Gender'] == 'male'
)]
num_loans_pending_male = len(filtered_pend_df)
df = dfs[2]
filtered_paid_df_female = df[(
df['loan_status'] == 'PAIDOFF') & (df['Gender'] == 'female'
)]
num_loans_paid_off_female = len(filtered_pend_df)
value = num_loans_paid_off + num_loans_pending + num_loans_paid_off_female
result = {
'type': 'number',
'value': value
}
"""
code_execution.logger = logger
filters = code_execution._extract_filters(code)
print(filters)
assert isinstance(filters, dict)
assert "dfs[0]" in filters
assert "dfs[1]" in filters
assert "dfs[2]" in filters
assert isinstance(filters["dfs[0]"], list)
assert len(filters["dfs[0]"]) == 2
assert len(filters["dfs[1]"]) == 2

assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[0]"][1] == ("Gender", "=", "male")

assert filters["dfs[1]"][0] == ("loan_status", "=", "PENDING")
assert filters["dfs[1]"][1] == ("Gender", "=", "male")

assert filters["dfs[2]"][0] == ("loan_status", "=", "PAIDOFF")
assert filters["dfs[2]"][1] == ("Gender", "=", "female")

0 comments on commit 5ad125d

Please sign in to comment.