diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py index d224a9b0a..17235a2b8 100644 --- a/data_juicer/core/ray_data.py +++ b/data_juicer/core/ray_data.py @@ -57,8 +57,9 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg): def preprocess_dataset(dataset: Dataset, dataset_path, cfg) -> Dataset: if dataset_path: dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) - if Fields.stats not in dataset.columns(fetch_if_missing=False): - logger.info(f'columns {dataset.columns(fetch_if_missing=False)}') + columns = dataset.columns() + if Fields.stats not in columns: + logger.info(f'columns {columns}') def process_batch_arrow(table: pa.Table) -> pa.Table: new_column_data = [{} for _ in range(len(table))]