Skip to content

Commit

Permalink
fix(dataloader): add support in dataloader for csv and parquet
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Dec 6, 2024
1 parent 61773a8 commit 9dfcea4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 32 deletions.
68 changes: 39 additions & 29 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
cache_file = self._get_cache_file_path()

if self._is_cache_valid(cache_file):
return self._read_cache(cache_file)
cache_format = self.schema["destination"]["format"]
return self._read_csv_or_parquet(cache_file, cache_format)

df = self._load_from_source()
df = self._apply_transformations(df)
self._cache_data(df, cache_file)

table_name = self.schema["source"]["table"]
table_name = self.schema["source"].get("table", None) or self.schema["name"]
table_description = self.schema.get("description", None)

return DataFrame(
Expand All @@ -47,7 +48,7 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
else:
# Initialize new dataset loader for virtualization
data_loader = self.copy()
table_name = self.schema["source"]["table"]
table_name = self.schema["source"].get("table", None) or self.schema["name"]
table_description = self.schema.get("description", None)
return VirtualDataFrame(
schema=self.schema,
Expand All @@ -69,7 +70,10 @@ def _load_schema(self):

def _validate_source_type(self):
source_type = self.schema["source"]["type"]
if source_type not in SUPPORTED_SOURCES:
if source_type not in SUPPORTED_SOURCES and source_type not in [
"csv",
"parquet",
]:
raise ValueError(f"Unsupported database type: {source_type}")

def _get_cache_file_path(self) -> str:
Expand All @@ -88,36 +92,13 @@ def _is_cache_valid(self, cache_file: str) -> bool:
return False

file_mtime = datetime.fromtimestamp(os.path.getmtime(cache_file))
update_frequency = self.schema["update_frequency"]
update_frequency = self.schema.get("update_frequency", None)

if update_frequency == "weekly":
if update_frequency and update_frequency == "weekly":
return file_mtime > datetime.now() - timedelta(weeks=1)

return False

def _read_cache(self, cache_file: str) -> DataFrame:
cache_format = self.schema["destination"]["format"]
table_name = self.schema["source"]["table"]
table_description = self.schema.get("description", None)
if cache_format == "parquet":
return DataFrame(
pd.read_parquet(cache_file),
schema=self.schema,
path=self.dataset_path,
name=table_name,
description=table_description,
)
elif cache_format == "csv":
return DataFrame(
pd.read_csv(cache_file),
schema=self.schema,
path=self.dataset_path,
name=table_name,
description=table_description,
)
else:
raise ValueError(f"Unsupported cache format: {cache_format}")

def _get_loader_function(self, source_type: str):
"""
Get the loader function for a specified data source type.
Expand Down Expand Up @@ -148,7 +129,36 @@ def _get_loader_function(self, source_type: str):
f"Please install the {SUPPORTED_SOURCES[source_type]} library."
) from e

def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
table_name = self.schema["source"].get("table") or self.schema.get("name", None)
table_description = self.schema.get("description", None)
if format == "parquet":
return DataFrame(
pd.read_parquet(file_path),
schema=self.schema,
path=self.dataset_path,
name=table_name,
description=table_description,
)
elif format == "csv":
return DataFrame(
pd.read_csv(file_path),
schema=self.schema,
path=self.dataset_path,
name=table_name,
description=table_description,
)
else:
raise ValueError(f"Unsupported file format: {format}")

def _load_from_source(self) -> pd.DataFrame:
source_type = self.schema["source"]["type"]
if source_type in ["csv", "parquet"]:
filpath = os.path.join(
"datasets", self.dataset_path, self.schema["source"]["path"]
)
return self._read_csv_or_parquet(filpath, source_type)

query_builder = QueryBuilder(self.schema)
query = query_builder.build_query()
return self.execute_query(query)
Expand Down
9 changes: 6 additions & 3 deletions pandasai/dataframe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,18 @@ def _create_yml_template(self, name, description, columns: List[dict]):
table_name (str): Name of the table or dataset.
"""
# Metadata template
metadata = {
return {
"name": name,
"description": description,
"columns": columns,
"source": {"type": "parquet", "path": "data.parquet"},
"destination": {
"type": "local",
"format": "parquet",
"path": "data.parquet",
},
}

return metadata

def save(
self, path: str, name: str, description: str = None, columns: List[dict] = []
):
Expand Down

0 comments on commit 9dfcea4

Please sign in to comment.