fix(load): issue load path of dataset and virtualized=true error on csv #1474

Merged · 2 commits · Dec 13, 2024
4 changes: 4 additions & 0 deletions pandasai/__init__.py
```diff
@@ -86,6 +86,10 @@ def load(dataset_path: str, virtualized=False) -> DataFrame:
     Returns:
         DataFrame: A new PandasAI DataFrame instance with loaded data.
     """
+    path_parts = dataset_path.split("/")
+    if len(path_parts) != 2:
+        raise ValueError("The path must be in the format 'organization/dataset'.")
+
     global _dataset_loader
     dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path)
     if not os.path.exists(dataset_full_path):
```
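For context, a minimal sketch of the behavior this guard adds (the `pai` alias and dataset name are illustrative):

```python
import pandasai as pai

# The new check rejects any path that is not exactly "organization/dataset"
# before touching the filesystem.
try:
    pai.load("heart-data")  # missing the organization segment
except ValueError as exc:
    print(exc)  # -> The path must be in the format 'organization/dataset'.

# A two-segment path passes the check and proceeds to the datasets/ lookup
# (which may still raise if the dataset does not exist locally):
# pai.load("my-org/heart-data")
```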
4 changes: 2 additions & 2 deletions pandasai/chat/code_execution/environment.py
```diff
@@ -6,7 +6,7 @@
 import importlib
 import sys
 import warnings
-from typing import List
+from typing import List, Union

 from pandas.util.version import Version

@@ -92,7 +92,7 @@ def import_dependency(
     name: str,
     extra: str = "",
     errors: str = "raise",
-    min_version: str | None = None,
+    min_version: Union[str, None] = None,
 ):
     """
     Import an optional dependency.
```
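The `Union` change is a compatibility fix, not a cosmetic one: PEP 604's `str | None` spelling is only valid at runtime on Python 3.10+, and annotations on a plain `def` are evaluated eagerly, so the old signature broke on 3.9 and earlier. A self-contained illustration:

```python
from typing import Union

# Portable: works on every Python 3 version the typing module supports.
def import_dependency(name: str, min_version: Union[str, None] = None):
    ...

# On Python <= 3.9 the PEP 604 form fails the moment the function is defined,
# because `str | None` is evaluated as `str.__or__(None)`:
#
#   def import_dependency(name: str, min_version: str | None = None): ...
#   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
```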
25 changes: 17 additions & 8 deletions pandasai/data_loader/loader.py
```diff
@@ -47,6 +47,13 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
             )
         else:
             # Initialize new dataset loader for virtualization
+            source_type = self.schema["source"]["type"]
+
+            if source_type in ["csv", "parquet"]:
+                raise ValueError(
+                    "Virtualization is not supported for CSV and Parquet files."
+                )
+
             data_loader = self.copy()
             table_name = self.schema["source"].get("table", None) or self.schema["name"]
             table_description = self.schema.get("description", None)
@@ -58,10 +65,11 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
                 path=dataset_path,
             )

+    def _get_abs_dataset_path(self):
+        return os.path.join(find_project_root(), "datasets", self.dataset_path)
+
     def _load_schema(self):
-        schema_path = os.path.join(
-            find_project_root(), "datasets", self.dataset_path, "schema.yaml"
-        )
+        schema_path = os.path.join(self._get_abs_dataset_path(), "schema.yaml")
         if not os.path.exists(schema_path):
             raise FileNotFoundError(f"Schema file not found: {schema_path}")

@@ -79,13 +87,13 @@ def _validate_source_type(self):
     def _get_cache_file_path(self) -> str:
         if "path" in self.schema["destination"]:
             return os.path.join(
-                "datasets", self.dataset_path, self.schema["destination"]["path"]
+                self._get_abs_dataset_path(), self.schema["destination"]["path"]
             )

         file_extension = (
             "parquet" if self.schema["destination"]["format"] == "parquet" else "csv"
         )
-        return os.path.join("datasets", self.dataset_path, f"data.{file_extension}")
+        return os.path.join(self._get_abs_dataset_path(), f"data.{file_extension}")

     def _is_cache_valid(self, cache_file: str) -> bool:
         if not os.path.exists(cache_file):
@@ -154,10 +162,11 @@ def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
     def _load_from_source(self) -> pd.DataFrame:
         source_type = self.schema["source"]["type"]
         if source_type in ["csv", "parquet"]:
-            filpath = os.path.join(
-                "datasets", self.dataset_path, self.schema["source"]["path"]
+            filepath = os.path.join(
+                self._get_abs_dataset_path(),
+                self.schema["source"]["path"],
             )
-            return self._read_csv_or_parquet(filpath, source_type)
+            return self._read_csv_or_parquet(filepath, source_type)

         query_builder = QueryBuilder(self.schema)
         query = query_builder.build_query()
```
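Two distinct fixes land in this file: every dataset path is now rooted at `find_project_root()` via the new `_get_abs_dataset_path()` helper, so loading no longer depends on the current working directory, and `virtualized=True` fails fast for file-backed sources. A sketch of the latter, assuming a local dataset whose `schema.yaml` declares a `csv` source (names are illustrative):

```python
import pandasai as pai

# datasets/my-org/heart-data/schema.yaml declares a source of type "csv".
# Plain loading still works, and now resolves paths from the project root:
df = pai.load("my-org/heart-data")

# Virtualization only makes sense for queryable backends, so requesting it
# for a CSV (or Parquet) source raises immediately:
try:
    pai.load("my-org/heart-data", virtualized=True)
except ValueError as exc:
    print(exc)  # -> Virtualization is not supported for CSV and Parquet files.
```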
18 changes: 13 additions & 5 deletions pandasai/data_loader/query_builder.py
```diff
@@ -21,6 +21,16 @@ def _get_columns(self) -> str:
         else:
             return "*"

+    def _get_table_name(self):
+        table_name = self.schema["source"].get("table", None) or self.schema["name"]
+
+        if not table_name:
+            raise ValueError("Table name not found in schema!")
+
+        table_name = table_name.lower()
+
+        return table_name
+
     def _add_order_by(self) -> str:
         if "order_by" not in self.schema:
             return ""
@@ -40,16 +50,14 @@ def get_head_query(self, n=5):
         source = self.schema.get("source", {})
         source_type = source.get("type")

-        table_name = self.schema["source"]["table"]
+        table_name = self._get_table_name()

         columns = self._get_columns()

-        order_by = "RAND()"
-        if source_type in {"sqlite", "postgres"}:
-            order_by = "RANDOM()"
+        order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"

         return f"SELECT {columns} FROM {table_name} ORDER BY {order_by} LIMIT {n}"

     def get_row_count(self):
-        table_name = self.schema["source"]["table"]
+        table_name = self._get_table_name()
         return f"SELECT COUNT(*) FROM {table_name}"
```
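A sketch of the queries the builder emits after this change, using a minimal schema dict whose shape is inferred from the diff (real schemas are parsed from `schema.yaml`):

```python
from pandasai.data_loader.query_builder import QueryBuilder

# No "table" key under "source", so _get_table_name() falls back to the
# dataset name and lowercases it; previously this access raised a KeyError.
schema = {
    "name": "Heart",
    "source": {"type": "postgres"},  # sqlite/postgres -> RANDOM(), else RAND()
}

builder = QueryBuilder(schema)
print(builder.get_head_query())  # SELECT * FROM heart ORDER BY RANDOM() LIMIT 5
print(builder.get_row_count())   # SELECT COUNT(*) FROM heart
```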