Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dataloader): add support in dataloader for csv and parquet #1451

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
cache_file = self._get_cache_file_path()

if self._is_cache_valid(cache_file):
return self._read_cache(cache_file)
cache_format = self.schema["destination"]["format"]
return self._read_csv_or_parquet(cache_file, cache_format)

df = self._load_from_source()
df = self._apply_transformations(df)
self._cache_data(df, cache_file)

table_name = self.schema["source"]["table"]
table_name = self.schema["source"].get("table", None) or self.schema["name"]
table_description = self.schema.get("description", None)

return DataFrame(
Expand All @@ -47,7 +48,7 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
else:
# Initialize new dataset loader for virtualization
data_loader = self.copy()
table_name = self.schema["source"]["table"]
table_name = self.schema["source"].get("table", None) or self.schema["name"]
table_description = self.schema.get("description", None)
return VirtualDataFrame(
schema=self.schema,
Expand All @@ -69,7 +70,10 @@ def _load_schema(self):

def _validate_source_type(self):
    """Validate that the schema's source type can be loaded.

    Raises:
        ValueError: if the source type is neither a supported database
            connector (listed in ``SUPPORTED_SOURCES``) nor a local
            ``csv``/``parquet`` file.
    """
    source_type = self.schema["source"]["type"]
    # Local file formats are handled directly by the loader and are not
    # part of SUPPORTED_SOURCES (which maps connector types to drivers).
    if source_type not in SUPPORTED_SOURCES and source_type not in (
        "csv",
        "parquet",
    ):
        raise ValueError(f"Unsupported database type: {source_type}")

def _get_cache_file_path(self) -> str:
Expand All @@ -88,36 +92,13 @@ def _is_cache_valid(self, cache_file: str) -> bool:
return False

file_mtime = datetime.fromtimestamp(os.path.getmtime(cache_file))
update_frequency = self.schema["update_frequency"]
update_frequency = self.schema.get("update_frequency", None)

if update_frequency == "weekly":
if update_frequency and update_frequency == "weekly":
return file_mtime > datetime.now() - timedelta(weeks=1)

return False

def _read_cache(self, cache_file: str) -> DataFrame:
    """Load the cached dataset file from disk and wrap it in a DataFrame.

    Args:
        cache_file: Path to the cached file on disk.

    Returns:
        DataFrame: the cached data, annotated with the schema's table
            name and optional description.

    Raises:
        ValueError: if the destination format is not ``parquet``/``csv``.
    """
    cache_format = self.schema["destination"]["format"]
    table_name = self.schema["source"]["table"]
    table_description = self.schema.get("description", None)

    # Single dispatch table instead of duplicating the DataFrame
    # construction once per supported format.
    readers = {"parquet": pd.read_parquet, "csv": pd.read_csv}
    if cache_format not in readers:
        raise ValueError(f"Unsupported cache format: {cache_format}")
    return DataFrame(
        readers[cache_format](cache_file),
        schema=self.schema,
        path=self.dataset_path,
        name=table_name,
        description=table_description,
    )

def _get_loader_function(self, source_type: str):
"""
Get the loader function for a specified data source type.
Expand Down Expand Up @@ -148,7 +129,36 @@ def _get_loader_function(self, source_type: str):
f"Please install the {SUPPORTED_SOURCES[source_type]} library."
) from e

def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
    """Read a csv or parquet file into a DataFrame.

    Args:
        file_path: Path of the file to read.
        format: Either ``"csv"`` or ``"parquet"``.
            NOTE(review): the parameter shadows the ``format`` builtin;
            kept as-is since renaming would break keyword callers.

    Returns:
        DataFrame: the file contents, annotated with the schema's table
            name (falling back to the dataset name) and description.

    Raises:
        ValueError: for any other format.
    """
    # Fall back to the dataset's name when no explicit table is given.
    table_name = self.schema["source"].get("table") or self.schema.get("name", None)
    table_description = self.schema.get("description", None)

    # Dispatch on format instead of duplicating the DataFrame
    # construction in two near-identical branches.
    readers = {"parquet": pd.read_parquet, "csv": pd.read_csv}
    if format not in readers:
        raise ValueError(f"Unsupported file format: {format}")
    return DataFrame(
        readers[format](file_path),
        schema=self.schema,
        path=self.dataset_path,
        name=table_name,
        description=table_description,
    )

def _load_from_source(self) -> pd.DataFrame:
    """Load the dataset from its configured source.

    Local ``csv``/``parquet`` sources are read straight from disk under
    the ``datasets/`` directory; any other source type is loaded by
    building and executing a SQL query.

    Returns:
        pd.DataFrame: the freshly loaded data.
    """
    source_type = self.schema["source"]["type"]
    if source_type in ("csv", "parquet"):
        # fix: local variable was misspelled "filpath" in the original
        file_path = os.path.join(
            "datasets", self.dataset_path, self.schema["source"]["path"]
        )
        return self._read_csv_or_parquet(file_path, source_type)

    query_builder = QueryBuilder(self.schema)
    query = query_builder.build_query()
    return self.execute_query(query)
Expand Down
9 changes: 6 additions & 3 deletions pandasai/dataframe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,18 @@ def _create_yml_template(self, name, description, columns: List[dict]):
table_name (str): Name of the table or dataset.
"""
# Metadata template
metadata = {
return {
"name": name,
"description": description,
"columns": columns,
"source": {"type": "parquet", "path": "data.parquet"},
"destination": {
"type": "local",
"format": "parquet",
"path": "data.parquet",
},
}

return metadata

def save(
self, path: str, name: str, description: str = None, columns: List[dict] = []
):
Expand Down
Loading