diff --git a/pandasai/dataframe/__init__.py b/pandasai/dataframe/__init__.py new file mode 100644 index 000000000..36288a4ae --- /dev/null +++ b/pandasai/dataframe/__init__.py @@ -0,0 +1,5 @@ +from .base import DataFrame + +__all__ = [ + "DataFrame", +] diff --git a/pandasai/dataframe/loader.py b/pandasai/dataframe/loader.py index 5b9619a14..f19d64e7a 100644 --- a/pandasai/dataframe/loader.py +++ b/pandasai/dataframe/loader.py @@ -15,7 +15,7 @@ def __init__(self): self.schema = None self.dataset_path = None - def load(self, dataset_path: str) -> DataFrame: + def load(self, dataset_path: str, lazy=False) -> DataFrame: self.dataset_path = dataset_path self._load_schema() self._validate_source_type() diff --git a/pandasai/pipelines/chat/code_cleaning.py b/pandasai/pipelines/chat/code_cleaning.py index e9f7774d5..355b6b16e 100644 --- a/pandasai/pipelines/chat/code_cleaning.py +++ b/pandasai/pipelines/chat/code_cleaning.py @@ -388,19 +388,26 @@ def _get_originals(self, dfs): """ original_dfs = [] for df in dfs: + # TODO - Check why this None check is there if df is None: original_dfs.append(None) continue - - df.execute() - - original_dfs.append(df.pandas_df) + original_dfs.append(df.head()) return original_dfs def _extract_fix_dataframe_redeclarations( self, node: ast.AST, code_lines: list[str] ) -> ast.AST: + """ + Checks if dataframe reclaration in the code like pd.DataFrame({...}) + Args: + node (ast.AST): Code Node + code_lines (list[str]): List of code str line by line + + Returns: + ast.AST: Updated Ast Node fixing redeclaration + """ if isinstance(node, ast.Assign): target_names, is_slice, target = self._get_target_names(node.targets) diff --git a/pandasai/pipelines/chat/code_execution.py b/pandasai/pipelines/chat/code_execution.py index 5f8a4dc3a..a408137c9 100644 --- a/pandasai/pipelines/chat/code_execution.py +++ b/pandasai/pipelines/chat/code_execution.py @@ -1,7 +1,6 @@ import ast import logging import traceback -from collections import defaultdict from typing import Any, Callable, Generator, List, Union from pandasai.exceptions import InvalidLLMOutputType, InvalidOutputValueMismatch @@ -10,13 +9,13 @@ from ...exceptions import NoResultFoundError from ...helpers.logger import Logger -from ...helpers.node_visitors import AssignmentVisitor, CallVisitor from ...helpers.optional import get_environment from ...helpers.output_validator import OutputValidator from ...schemas.df_config import Config from ..base_logic_unit import BaseLogicUnit from ..pipeline_context import PipelineContext from .code_cleaning import CodeExecutionContext +import pandas as pd class CodeExecution(BaseLogicUnit): @@ -205,116 +204,21 @@ def _get_originals(self, dfs): list: List of dfs """ original_dfs = [] - for index, df in enumerate(dfs): + for df in dfs: + # TODO - Check why this None check is there if df is None: original_dfs.append(None) continue - extracted_filters = self._extract_filters(self._current_code_executed) - filters = extracted_filters.get(f"dfs[{index}]", []) - df.set_additional_filters(filters) - - df.execute() - # df.load_connector(partial=len(filters) > 0) - - original_dfs.append(df.pandas_df) + if isinstance(df, pd.DataFrame): + original_dfs.append(df) + else: + # Execute to fetch only if not dataframe + df.execute() + original_dfs.append(df.pandas_df) return original_dfs - def _extract_filters(self, code) -> dict[str, list]: - """ - Extract filters to be applied to the dataframe from passed code. - - Args: - code (str): A snippet of code to be parsed. - - Returns: - dict: The dictionary containing all filters parsed from - the passed code. The dictionary has the following structure: - { - "": [ - ("", "", "") - ] - } - - Raises: - SyntaxError: If the code is unable to be parsed by `ast.parse()`. - Exception: If any exception is raised during working with nodes - of the code tree. - """ - try: - parsed_tree = ast.parse(code) - except SyntaxError: - self.logger.log( - "Invalid code passed for extracting filters", level=logging.ERROR - ) - self.logger.log(f"{traceback.format_exc()}", level=logging.DEBUG) - raise - - try: - filters = self._extract_comparisons(parsed_tree) - except Exception: - self.logger.log( - "Unable to extract filters for passed code", level=logging.ERROR - ) - self.logger.log(f"Error: {traceback.format_exc()}", level=logging.DEBUG) - return {} - - return filters - - def _extract_comparisons(self, tree: ast.Module) -> dict[str, list]: - """ - Process nodes from passed tree to extract filters. - - Collects all assignments in the tree. - Collects all function calls in the tree. - Walk over the tree and handle each comparison node. - For each comparison node, defined what `df` is this node related to. - Parse constants values from the comparison node. - Add to the result dict. - - Args: - tree (str): A snippet of code to be parsed. - - Returns: - dict: The `defaultdict(list)` instance containing all filters - parsed from the passed instructions tree. The dictionary has - the following structure: - { - "": [ - ("", "", "") - ] - } - """ - comparisons = defaultdict(list) - current_df = "dfs[0]" - - visitor = AssignmentVisitor() - visitor.visit(tree) - assignments = visitor.assignment_nodes - - call_visitor = CallVisitor() - call_visitor.visit(tree) - - for node in ast.walk(tree): - if isinstance(node, ast.Compare) and isinstance(node.left, ast.Subscript): - name, *slices = self._tokenize_operand(node.left) - current_df = ( - self._get_df_id_by_nearest_assignment( - node.lineno, assignments, name - ) - or current_df - ) - left_str = slices[-1] if slices else name - - for op, right in zip(node.ops, node.comparators): - op_str = self._ast_comparator_map.get(type(op), "Unknown") - name, *slices = self._tokenize_operand(right) - right_str = slices[-1] if slices else name - - comparisons[current_df].append((left_str, op_str, right_str)) - return comparisons - def _retry_run_code( self, code: str, diff --git a/tests/unit_tests/pipelines/smart_datalake/test_code_execution.py b/tests/unit_tests/pipelines/smart_datalake/test_code_execution.py index 635320363..df7e8722a 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_code_execution.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_code_execution.py @@ -324,70 +324,3 @@ def test_get_environment(self): "__build_class__": __build_class__, "__name__": "__main__", } - - @pytest.mark.parametrize("df_name", ["df", "foobar"]) - def test_extract_filters_col_index(self, df_name, code_execution): - code = f""" -{df_name} = dfs[0] -filtered_df = ( - {df_name}[ - ({df_name}['loan_status'] == 'PAIDOFF') & ({df_name}['Gender'] == 'male') - ] -) -num_loans = len(filtered_df) -result = {{'type': 'number', 'value': num_loans}} -""" - filters = code_execution._extract_filters(code) - assert isinstance(filters, dict) - assert "dfs[0]" in filters - assert isinstance(filters["dfs[0]"], list) - assert len(filters["dfs[0]"]) == 2 - - assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF") - assert filters["dfs[0]"][1] == ("Gender", "=", "male") - - def test_extract_filters_col_index_multiple_df(self, code_execution, logger): - code = """ -df = dfs[0] -filtered_paid_df_male = df[( - df['loan_status'] == 'PAIDOFF') & (df['Gender'] == 'male' -)] -num_loans_paid_off_male = len(filtered_paid_df) - -df = dfs[1] -filtered_pend_df_male = df[( - df['loan_status'] == 'PENDING') & (df['Gender'] == 'male' -)] -num_loans_pending_male = len(filtered_pend_df) - -df = dfs[2] -filtered_paid_df_female = df[( - df['loan_status'] == 'PAIDOFF') & (df['Gender'] == 'female' -)] -num_loans_paid_off_female = len(filtered_pend_df) - -value = num_loans_paid_off + num_loans_pending + num_loans_paid_off_female -result = { - 'type': 'number', - 'value': value -} -""" - code_execution.logger = logger - filters = code_execution._extract_filters(code) - print(filters) - assert isinstance(filters, dict) - assert "dfs[0]" in filters - assert "dfs[1]" in filters - assert "dfs[2]" in filters - assert isinstance(filters["dfs[0]"], list) - assert len(filters["dfs[0]"]) == 2 - assert len(filters["dfs[1]"]) == 2 - - assert filters["dfs[0]"][0] == ("loan_status", "=", "PAIDOFF") - assert filters["dfs[0]"][1] == ("Gender", "=", "male") - - assert filters["dfs[1]"][0] == ("loan_status", "=", "PENDING") - assert filters["dfs[1]"][1] == ("Gender", "=", "male") - - assert filters["dfs[2]"][0] == ("loan_status", "=", "PAIDOFF") - assert filters["dfs[2]"][1] == ("Gender", "=", "female")