From a6192ee03e138567309d61bf4ef1c65c5e0c743e Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Fri, 13 Dec 2024 12:53:02 +0100 Subject: [PATCH 1/2] fix(load): path fixed --- pandasai/__init__.py | 4 ++++ pandasai/chat/code_execution/environment.py | 4 ++-- pandasai/data_loader/loader.py | 25 ++++++++++++++------- pandasai/data_loader/query_builder.py | 18 ++++++++++----- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 31ecf943c..068a15787 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -86,6 +86,10 @@ def load(dataset_path: str, virtualized=False) -> DataFrame: Returns: DataFrame: A new PandasAI DataFrame instance with loaded data. """ + path_parts = dataset_path.split("/") + if len(path_parts) != 2: + raise ValueError("Path must be in format 'organization/dataset'") + global _dataset_loader dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path) if not os.path.exists(dataset_full_path): diff --git a/pandasai/chat/code_execution/environment.py b/pandasai/chat/code_execution/environment.py index 412c255ec..327c9ca36 100644 --- a/pandasai/chat/code_execution/environment.py +++ b/pandasai/chat/code_execution/environment.py @@ -6,7 +6,7 @@ import importlib import sys import warnings -from typing import List +from typing import List, Union from pandas.util.version import Version @@ -92,7 +92,7 @@ def import_dependency( name: str, extra: str = "", errors: str = "raise", - min_version: str | None = None, + min_version: Union[str, None] = None, ): """ Import an optional dependency. diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py index 9612e3c17..d1724bea8 100644 --- a/pandasai/data_loader/loader.py +++ b/pandasai/data_loader/loader.py @@ -47,6 +47,13 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame: ) else: # Initialize new dataset loader for virtualization + source_type = self.schema["source"]["type"] + + if source_type in ["csv", "parquet"]: + raise ValueError( + "Virtualization is not supported for CSV and Parquet files." + ) + data_loader = self.copy() table_name = self.schema["source"].get("table", None) or self.schema["name"] table_description = self.schema.get("description", None) @@ -58,10 +65,11 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame: path=dataset_path, ) + def _get_abs_dataset_path(self): + return os.path.join(find_project_root(), "datasets", self.dataset_path) + def _load_schema(self): - schema_path = os.path.join( - find_project_root(), "datasets", self.dataset_path, "schema.yaml" - ) + schema_path = os.path.join(self._get_abs_dataset_path(), "schema.yaml") if not os.path.exists(schema_path): raise FileNotFoundError(f"Schema file not found: {schema_path}") @@ -79,13 +87,13 @@ def _validate_source_type(self): def _get_cache_file_path(self) -> str: if "path" in self.schema["destination"]: return os.path.join( - "datasets", self.dataset_path, self.schema["destination"]["path"] + self._get_abs_dataset_path(), self.schema["destination"]["path"] ) file_extension = ( "parquet" if self.schema["destination"]["format"] == "parquet" else "csv" ) - return os.path.join("datasets", self.dataset_path, f"data.{file_extension}") + return os.path.join(self._get_abs_dataset_path(), f"data.{file_extension}") def _is_cache_valid(self, cache_file: str) -> bool: if not os.path.exists(cache_file): @@ -154,10 +162,11 @@ def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame: def _load_from_source(self) -> pd.DataFrame: source_type = self.schema["source"]["type"] if source_type in ["csv", "parquet"]: - filpath = os.path.join( - "datasets", self.dataset_path, self.schema["source"]["path"] + filepath = os.path.join( + self._get_abs_dataset_path(), + self.schema["source"]["path"], ) - return self._read_csv_or_parquet(filpath, source_type) + return self._read_csv_or_parquet(filepath, source_type) query_builder = QueryBuilder(self.schema) query = query_builder.build_query() diff --git a/pandasai/data_loader/query_builder.py b/pandasai/data_loader/query_builder.py index 5fc548951..317bd7aeb 100644 --- a/pandasai/data_loader/query_builder.py +++ b/pandasai/data_loader/query_builder.py @@ -21,6 +21,16 @@ def _get_columns(self) -> str: else: return "*" + def _get_table_name(self): + table_name = self.schema["source"].get("table", None) or self.schema["name"] + + if not table_name: + raise ValueError("Table name not found in schema!") + + table_name = table_name.lower() + + return table_name + def _add_order_by(self) -> str: if "order_by" not in self.schema: return "" @@ -40,16 +50,14 @@ def get_head_query(self, n=5): source = self.schema.get("source", {}) source_type = source.get("type") - table_name = self.schema["source"]["table"] + table_name = self._get_table_name() columns = self._get_columns() - order_by = "RAND()" - if source_type in {"sqlite", "postgres"}: - order_by = "RANDOM()" + order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()" return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}" def get_row_count(self): - table_name = self.schema["source"]["table"] + table_name = self._get_table_name() return f"SELECT COUNT(*) FROM {table_name}" From 8c2cb3807bf83b302575358edd3c39467f30f26a Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Fri, 13 Dec 2024 13:52:35 +0100 Subject: [PATCH 2/2] Update error message Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com> --- pandasai/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandasai/__init__.py b/pandasai/__init__.py index 068a15787..94432ccfe 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -88,7 +88,7 @@ def load(dataset_path: str, virtualized=False) -> DataFrame: """ path_parts = dataset_path.split("/") if len(path_parts) != 2: - raise ValueError("Path must be in format 'organization/dataset'") + raise ValueError("The path must be in the format 'organization/dataset'.") global _dataset_loader dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path)