From 9dfcea4e866c090e8558aeae630c1ee0823500ba Mon Sep 17 00:00:00 2001
From: ArslanSaleem
Date: Fri, 6 Dec 2024 18:18:55 +0100
Subject: [PATCH] fix(dataloader): add support in dataloader for csv and parquet

---
 pandasai/data_loader/loader.py | 68 +++++++++++++++++++---------------
 pandasai/dataframe/base.py     |  9 +++--
 2 files changed, 45 insertions(+), 32 deletions(-)

diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index d7a133320..9612e3c17 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -28,13 +28,14 @@ def load(self, dataset_path: str, virtualized=False) -> DataFrame:
             cache_file = self._get_cache_file_path()
 
             if self._is_cache_valid(cache_file):
-                return self._read_cache(cache_file)
+                cache_format = self.schema["destination"]["format"]
+                return self._read_csv_or_parquet(cache_file, cache_format)
 
             df = self._load_from_source()
             df = self._apply_transformations(df)
             self._cache_data(df, cache_file)
 
-            table_name = self.schema["source"]["table"]
+            table_name = self.schema["source"].get("table", None) or self.schema["name"]
             table_description = self.schema.get("description", None)
 
             return DataFrame(
@@ -47,7 +48,7 @@
         else:
             # Initialize new dataset loader for virtualization
             data_loader = self.copy()
-            table_name = self.schema["source"]["table"]
+            table_name = self.schema["source"].get("table", None) or self.schema["name"]
             table_description = self.schema.get("description", None)
             return VirtualDataFrame(
                 schema=self.schema,
@@ -69,7 +70,10 @@ def _load_schema(self):
 
     def _validate_source_type(self):
         source_type = self.schema["source"]["type"]
-        if source_type not in SUPPORTED_SOURCES:
+        if source_type not in SUPPORTED_SOURCES and source_type not in [
+            "csv",
+            "parquet",
+        ]:
             raise ValueError(f"Unsupported database type: {source_type}")
 
     def _get_cache_file_path(self) -> str:
@@ -88,36 +92,13 @@ def _is_cache_valid(self, cache_file: str) -> bool:
             return False
 
         file_mtime = datetime.fromtimestamp(os.path.getmtime(cache_file))
-        update_frequency = self.schema["update_frequency"]
+        update_frequency = self.schema.get("update_frequency", None)
 
-        if update_frequency == "weekly":
+        if update_frequency and update_frequency == "weekly":
             return file_mtime > datetime.now() - timedelta(weeks=1)
 
         return False
 
-    def _read_cache(self, cache_file: str) -> DataFrame:
-        cache_format = self.schema["destination"]["format"]
-        table_name = self.schema["source"]["table"]
-        table_description = self.schema.get("description", None)
-        if cache_format == "parquet":
-            return DataFrame(
-                pd.read_parquet(cache_file),
-                schema=self.schema,
-                path=self.dataset_path,
-                name=table_name,
-                description=table_description,
-            )
-        elif cache_format == "csv":
-            return DataFrame(
-                pd.read_csv(cache_file),
-                schema=self.schema,
-                path=self.dataset_path,
-                name=table_name,
-                description=table_description,
-            )
-        else:
-            raise ValueError(f"Unsupported cache format: {cache_format}")
-
     def _get_loader_function(self, source_type: str):
         """
        Get the loader function for a specified data source type.
@@ -148,7 +129,36 @@
                 f"Please install the {SUPPORTED_SOURCES[source_type]} library."
             ) from e
 
+    def _read_csv_or_parquet(self, file_path: str, format: str) -> DataFrame:
+        table_name = self.schema["source"].get("table") or self.schema.get("name", None)
+        table_description = self.schema.get("description", None)
+        if format == "parquet":
+            return DataFrame(
+                pd.read_parquet(file_path),
+                schema=self.schema,
+                path=self.dataset_path,
+                name=table_name,
+                description=table_description,
+            )
+        elif format == "csv":
+            return DataFrame(
+                pd.read_csv(file_path),
+                schema=self.schema,
+                path=self.dataset_path,
+                name=table_name,
+                description=table_description,
+            )
+        else:
+            raise ValueError(f"Unsupported file format: {format}")
+
     def _load_from_source(self) -> pd.DataFrame:
+        source_type = self.schema["source"]["type"]
+        if source_type in ["csv", "parquet"]:
+            filepath = os.path.join(
+                "datasets", self.dataset_path, self.schema["source"]["path"]
+            )
+            return self._read_csv_or_parquet(filepath, source_type)
+
         query_builder = QueryBuilder(self.schema)
         query = query_builder.build_query()
         return self.execute_query(query)
diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py
index e5a40957d..12ce616e3 100644
--- a/pandasai/dataframe/base.py
+++ b/pandasai/dataframe/base.py
@@ -166,15 +166,18 @@ def _create_yml_template(self, name, description, columns: List[dict]):
             table_name (str): Name of the table or dataset.
         """
         # Metadata template
-        metadata = {
+        return {
             "name": name,
             "description": description,
             "columns": columns,
             "source": {"type": "parquet", "path": "data.parquet"},
+            "destination": {
+                "type": "local",
+                "format": "parquet",
+                "path": "data.parquet",
+            },
         }
 
-        return metadata
-
     def save(
         self, path: str, name: str, description: str = None, columns: List[dict] = []
    ):
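
Usage sketch (not part of the patch): assuming the class defined in pandasai/data_loader/loader.py is DatasetLoader and that datasets live under datasets/<dataset_path>/ as the patched _load_from_source() expects, a schema like the one below exercises the new csv path end to end. The schema keys mirror the fields this patch reads (name, update_frequency, source.type, source.path, destination.format); the dataset name, folder, and the bare DatasetLoader() constructor call are illustrative assumptions, not part of the change.

    # datasets/acme/sales/schema.yaml  (hypothetical dataset)
    #
    # name: sales
    # update_frequency: weekly   # required by _is_cache_valid() for cache hits
    # source:
    #   type: csv                # now accepted by _validate_source_type()
    #   path: data.csv
    # destination:
    #   type: local
    #   format: parquet          # read back by _read_csv_or_parquet() on cache hits
    #   path: data.parquet

    from pandasai.data_loader.loader import DatasetLoader

    loader = DatasetLoader()
    # With source.type == "csv", _load_from_source() skips the QueryBuilder
    # branch and reads datasets/acme/sales/data.csv via _read_csv_or_parquet();
    # the result is cached as parquet, and a later load() within a week is
    # served from that cache through the same helper.
    df = loader.load("acme/sales")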