diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index 31ecf943c..94432ccfe 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -86,6 +86,10 @@ def load(dataset_path: str, virtualized=False) -> DataFrame:
     Returns:
         DataFrame: A new PandasAI DataFrame instance with loaded data.
     """
+    path_parts = dataset_path.split("/")
+    if len(path_parts) != 2:
+        raise ValueError("The path must be in the format 'organization/dataset'.")
+
     global _dataset_loader
     dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path)
     if not os.path.exists(dataset_full_path):
diff --git a/pandasai/chat/code_execution/environment.py b/pandasai/chat/code_execution/environment.py
index 412c255ec..327c9ca36 100644
--- a/pandasai/chat/code_execution/environment.py
+++ b/pandasai/chat/code_execution/environment.py
@@ -6,7 +6,7 @@
 import importlib
 import sys
 import warnings
-from typing import List
+from typing import List, Union
 
 from pandas.util.version import Version
 
@@ -92,7 +92,7 @@ def import_dependency(
     name: str,
     extra: str = "",
     errors: str = "raise",
-    min_version: str | None = None,
+    min_version: Union[str, None] = None,
 ):
     """
     Import an optional dependency.
diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index 9612e3c17..d1724bea8 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -47,6 +47,13 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
             )
         else:
             # Initialize new dataset loader for virtualization
+            source_type = self.schema["source"]["type"]
+
+            if source_type in ["csv", "parquet"]:
+                raise ValueError(
+                    "Virtualization is not supported for CSV and Parquet files."
+                )
+
             data_loader = self.copy()
             table_name = self.schema["source"].get("table", None) or self.schema["name"]
             table_description = self.schema.get("description", None)
@@ -58,10 +65,11 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
                 path=dataset_path,
             )
 
+    def _get_abs_dataset_path(self):
+        return os.path.join(find_project_root(), "datasets", self.dataset_path)
+
     def _load_schema(self):
-        schema_path = os.path.join(
-            find_project_root(), "datasets", self.dataset_path, "schema.yaml"
-        )
+        schema_path = os.path.join(self._get_abs_dataset_path(), "schema.yaml")
         if not os.path.exists(schema_path):
             raise FileNotFoundError(f"Schema file not found: {schema_path}")
 
@@ -79,13 +87,13 @@ def _validate_source_type(self):
 
     def _get_cache_file_path(self) -> str:
         if "path" in self.schema["destination"]:
             return os.path.join(
-                "datasets", self.dataset_path, self.schema["destination"]["path"]
+                self._get_abs_dataset_path(), self.schema["destination"]["path"]
             )
         file_extension = (
             "parquet" if self.schema["destination"]["format"] == "parquet" else "csv"
         )
-        return os.path.join("datasets", self.dataset_path, f"data.{file_extension}")
+        return os.path.join(self._get_abs_dataset_path(), f"data.{file_extension}")
 
     def _is_cache_valid(self, cache_file: str) -> bool:
         if not os.path.exists(cache_file):
@@ -154,10 +162,11 @@ def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
 
     def _load_from_source(self) -> pd.DataFrame:
         source_type = self.schema["source"]["type"]
         if source_type in ["csv", "parquet"]:
-            filpath = os.path.join(
-                "datasets", self.dataset_path, self.schema["source"]["path"]
+            filepath = os.path.join(
+                self._get_abs_dataset_path(),
+                self.schema["source"]["path"],
             )
-            return self._read_csv_or_parquet(filpath, source_type)
+            return self._read_csv_or_parquet(filepath, source_type)
         query_builder = QueryBuilder(self.schema)
         query = query_builder.build_query()
diff --git a/pandasai/data_loader/query_builder.py b/pandasai/data_loader/query_builder.py
index 5fc548951..317bd7aeb 100644
--- a/pandasai/data_loader/query_builder.py
+++ b/pandasai/data_loader/query_builder.py
@@ -21,6 +21,16 @@ def _get_columns(self) -> str:
         else:
             return "*"
 
+    def _get_table_name(self):
+        table_name = self.schema["source"].get("table", None) or self.schema["name"]
+
+        if not table_name:
+            raise ValueError("Table name not found in schema!")
+
+        table_name = table_name.lower()
+
+        return table_name
+
     def _add_order_by(self) -> str:
         if "order_by" not in self.schema:
             return ""
@@ -40,16 +50,14 @@ def get_head_query(self, n=5):
 
         source = self.schema.get("source", {})
         source_type = source.get("type")
-        table_name = self.schema["source"]["table"]
+        table_name = self._get_table_name()
 
         columns = self._get_columns()
 
-        order_by = "RAND()"
-        if source_type in {"sqlite", "postgres"}:
-            order_by = "RANDOM()"
+        order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"
 
         return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}"
 
     def get_row_count(self):
-        table_name = self.schema["source"]["table"]
+        table_name = self._get_table_name()
         return f"SELECT COUNT(*) FROM {table_name}"
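
For reviewers, a minimal sketch of what the new validation in pandasai.load
rejects; the dataset names below are hypothetical:

    import pandasai as pai

    # Exactly two slash-separated segments are required, so both calls
    # fail fast with ValueError("The path must be in the format
    # 'organization/dataset'.") before any dataset lookup happens.
    pai.load("sales")              # missing the organization segment
    pai.load("acme/sales/extra")   # one segment too many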
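
The typing change in environment.py is a compatibility fix rather than style:
a bare "str | None" annotation is evaluated when the def statement runs, so it
raises at import time on Python 3.9 and earlier, while typing.Union works on
all versions. A self-contained illustration (the function names are invented):

    from typing import Union

    # Portable to every Python version the library supports:
    def ok(min_version: Union[str, None] = None): ...

    # On Python <= 3.9 this definition raises
    # "TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'"
    # the moment the module is imported, unless the module uses
    # "from __future__ import annotations":
    def broken(min_version: str | None = None): ...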
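
Finally, a sketch of the new table-name fallback in QueryBuilder; the schema
dict is hypothetical, but its keys are the ones the class actually reads:

    from pandasai.data_loader.query_builder import QueryBuilder

    schema = {
        "name": "Sales",                 # no source.table: falls back to this
        "source": {"type": "postgres"},  # no columns list, so SELECT *
    }

    qb = QueryBuilder(schema)
    print(qb.get_row_count())    # SELECT COUNT(*) FROM sales
    print(qb.get_head_query(3))  # SELECT * FROM sales ORDER BY RANDOM() LIMIT 3

Before this change, both calls raised KeyError whenever a schema omitted
source.table; now the dataset name is lower-cased and used instead.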