Skip to content

Commit

Permalink
update data validator logic and add/fix test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
cyruszhang committed Dec 25, 2024
1 parent bbc303d commit 4b6065f
Show file tree
Hide file tree
Showing 10 changed files with 335 additions and 89 deletions.
17 changes: 6 additions & 11 deletions data_juicer/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from data_juicer.analysis import ColumnWiseAnalysis, OverallAnalysis
from data_juicer.config import init_configs
from data_juicer.format import load_formatter
from data_juicer.core.data.dataset_builder import DatasetBuilder
from data_juicer.ops import Filter, load_ops
from data_juicer.ops.op_fusion import fuse_operators
from data_juicer.utils import cache_utils
Expand Down Expand Up @@ -42,14 +42,9 @@ def __init__(self, cfg: Optional[Namespace] = None):
f'[{self.cfg.cache_compress}]')
cache_utils.CACHE_COMPRESS = self.cfg.cache_compress

# setup formatter
# setup dataset builder
logger.info('Setting up data formatter...')
self.formatter = load_formatter(
dataset_path=self.cfg.dataset_path,
generated_dataset_config=self.cfg.generated_dataset_config,
text_keys=self.cfg.text_keys,
suffixes=self.cfg.suffixes,
add_suffix=self.cfg.add_suffix)
self.dataset_builder = DatasetBuilder(self.cfg)

# prepare exporter and check export path suffix
# NOTICE: no need to export dataset texts for analyzer
Expand Down Expand Up @@ -84,9 +79,9 @@ def run(self,
"""
# 1. format data
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
dataset = self.formatter.load_dataset(load_data_np, self.cfg)
if load_data_np is not None:
self.dataset_builder.set_dataset_path(load_data_np)
dataset = self.dataset_builder.load_dataset()

# extract processes
logger.info('Preparing process operators...')
Expand Down
61 changes: 61 additions & 0 deletions data_juicer/core/data/config_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Dict


class ConfigValidationError(Exception):
    """Custom exception for configuration validation errors."""
    pass


class ConfigValidator:
    """Mixin class for configuration validation.

    Subclasses override ``CONFIG_VALIDATION_RULES`` to declare which fields
    are required/optional, their expected types, and any custom per-field
    validation callables.
    """

    # Define validation rules for each strategy type
    CONFIG_VALIDATION_RULES = {
        'required_fields': [],  # Fields that must be present
        'optional_fields': [],  # Fields that are optional
        'field_types': {},  # Expected types for fields
        'custom_validators': {}  # Custom validation functions
    }

    def validate_config(self, ds_config: Dict) -> None:
        """
        Validate the configuration dictionary.

        Args:
            ds_config: Configuration dictionary to validate

        Raises:
            ConfigValidationError: If validation fails
        """
        # Check required fields
        missing_fields = [
            field for field in self.CONFIG_VALIDATION_RULES['required_fields']
            if field not in ds_config
        ]
        if missing_fields:
            raise ConfigValidationError(
                f"Missing required fields: {', '.join(missing_fields)}")

        # Optional fields need no special checks

        # Check field types
        for field, expected_type in self.CONFIG_VALIDATION_RULES[
                'field_types'].items():
            if field in ds_config:
                value = ds_config[field]
                if not isinstance(value, expected_type):
                    # BUGFIX: the middle fragment previously lacked the
                    # ``f`` prefix, so the literal text
                    # ``{expected_type.__name__}`` appeared in the message.
                    raise ConfigValidationError(
                        f"Field '{field}' must be of "
                        f"type '{expected_type.__name__}', "
                        f"got '{type(value).__name__}'")

        # Run custom validators
        for field, validator in self.CONFIG_VALIDATION_RULES[
                'custom_validators'].items():
            if field in ds_config:
                try:
                    validator(ds_config[field])
                except Exception as e:
                    # Chain the original exception for easier debugging
                    raise ConfigValidationError(
                        f"Validation failed for field '{field}': "
                        f'{str(e)}') from e
54 changes: 51 additions & 3 deletions data_juicer/core/data/data_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
class DataValidator(ABC):
"""Base class for data validation"""

def __init__(self, config: Dict):
self.config = config

@abstractmethod
def validate(self, dataset) -> None:
"""
Expand Down Expand Up @@ -52,7 +55,8 @@ class ConversationDataValidator(DataValidator):
"""Validator for conversation data"""

def __init__(self, config: Dict):
self.config = config
super().__init__(config)

# Validation rules specific to conversation data
self.required_columns = ['text']
self.min_turns = config.get('min_turns', 2)
Expand Down Expand Up @@ -81,7 +85,8 @@ class CodeDataValidator(DataValidator):
"""Validator for code data"""

def __init__(self, config: Dict):
self.config = config
super().__init__(config)

self.required_columns = ['code', 'language']
self.supported_languages = config.get('supported_languages', [])

Expand All @@ -102,10 +107,14 @@ def __init__(self, config: Dict):
config: Dict containing:
- required_fields: List of field names that must exist
- field_types: Optional map of field names to expected types
- allow_missing: Optional float for max ratio missing allowed
"""
self.config = config
super().__init__(config)

self.required_fields = config['required_fields']
self.field_types = config.get('field_types', {})
# Default no missing allowed
self.allow_missing = config.get('allow_missing', 0.0)

def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None:
"""
Expand All @@ -130,3 +139,42 @@ def validate(self, dataset: Union[NestedDataset, RayDataset]) -> None:
if missing_fields:
raise DataValidationError(
f'Dataset missing required fields: {missing_fields}')

# Check field types and missing values
for field in self.required_fields:
# Get expected type if specified
expected_type = self.field_types.get(field)

# Sample data for validation
# For large datasets, we check a sample for performance
MAX_SAMPLE_SIZE = 1000
if isinstance(dataset, NestedDataset):
sample_size = min(MAX_SAMPLE_SIZE, len(dataset))
sample = dataset.select(range(sample_size))
values = sample[field]
elif isinstance(dataset, RayDataset): # RayDataset
sample_size = min(MAX_SAMPLE_SIZE, dataset.data.count())
sample = dataset.data.take(sample_size)
values = [row[field] for row in sample]
else:
raise NotImplementedError

# Check for missing values
missing_count = sum(1 for v in values if v is None)
missing_ratio = missing_count / len(values)
if missing_ratio > self.allow_missing:
raise DataValidationError(
f"Field '{field}' has {missing_ratio:.1%} missing values, "
f'exceeding allowed {self.allow_missing:.1%}')

# Check types if specified
if expected_type:
invalid_types = [
type(v) for v in values
if v is not None and not isinstance(v, expected_type)
]
if invalid_types:
raise DataValidationError(
f"Field '{field}' contains values of incorrect type. "
f'Expected {expected_type.__name__}, '
f'got {set(t.__name__ for t in invalid_types)}')
23 changes: 21 additions & 2 deletions data_juicer/core/data/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class DatasetBuilder(object):

def __init__(self, cfg):
self.cfg = cfg

# defaults to use dataset_path
if cfg.dataset_path is not None:
ds_configs = rewrite_cli_datapath(cfg.dataset_path)
Expand All @@ -22,6 +23,7 @@ def __init__(self, cfg):
raise ValueError(
'Unable to initialize dataset; should have one of '
'dataset_path or dataset in configurations')

# dataset config could be a list or a single entry; retrofit
if not isinstance(ds_configs, list):
ds_configs = [ds_configs]
Expand All @@ -35,6 +37,7 @@ def __init__(self, cfg):
DataLoadStrategyRegistry.get_strategy_class(
executor_type, data_type, data_source)(ds_config))

# initialize data validators
self.validators = []
if hasattr(cfg, 'validators'):
for validator_config in cfg.validators:
Expand All @@ -45,18 +48,34 @@ def __init__(self, cfg):
self.validators.append(validator_cls(validator_config))

def load_dataset(self) -> Union[NestedDataset, RayDataset]:
    """Load every configured dataset and run validation on each.

    Each configured load strategy produces one dataset; every configured
    validator is applied to each dataset as it is loaded.

    NOTE(review): data mixture is not handled yet — only the first loaded
    dataset is returned.
    """
    loaded = []
    for strategy in self.load_strategies:
        # load dataset with its load strategy
        dataset = strategy.load_data(self.cfg)

        # do data validation
        for validator in self.validators:
            validator.validate(dataset)
        loaded.append(dataset)

    # TODO: handle data mixture; for now return the first dataset only
    return loaded[0]

@classmethod
def load_dataset_by_generated_config(cls, generated_dataset_config):
    """
    Load a dataset from a generated-dataset configuration.

    Args:
        generated_dataset_config: dict that must contain a 'type' key
            naming the registered formatter; the remaining keys are passed
            to the formatter as keyword arguments.

    Returns:
        The dataset produced by the selected formatter.

    Raises:
        ValueError: if the config is not a dict or has no 'type' key.
    """
    # Validate explicitly instead of with `assert`, which is silently
    # stripped when Python runs with -O.
    if not isinstance(generated_dataset_config, dict) \
            or 'type' not in generated_dataset_config:
        raise ValueError(
            'generated_dataset_config must be a dict containing a '
            "'type' key")
    args = generated_dataset_config.copy()

    # TODO finish the auto local dataset part
    obj_name = args.pop('type')
    from data_juicer.format.formatter import FORMATTERS
    dataset = FORMATTERS.modules[obj_name](**args).load_dataset()
    return dataset


def rewrite_cli_datapath(dataset_path) -> List:
"""
Expand Down
61 changes: 1 addition & 60 deletions data_juicer/core/data/load_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, Optional, Type, Union

from data_juicer.core.data import DJDataset, RayDataset
from data_juicer.core.data.config_validator import ConfigValidator
from data_juicer.download.downloader import validate_snapshot_format
from data_juicer.utils.lazy_loader import LazyLoader

Expand Down Expand Up @@ -40,66 +41,6 @@ def matches(self, other: 'StrategyKey') -> bool:
and fnmatch.fnmatch(other.data_source, self.data_source))


class ConfigValidationError(Exception):
    """Custom exception for configuration validation errors."""
    pass


class ConfigValidator:
    """Mixin class for configuration validation.

    Subclasses override ``CONFIG_VALIDATION_RULES`` to declare which fields
    are required/optional, their expected types, and any custom per-field
    validation callables.
    """

    # Define validation rules for each strategy type
    CONFIG_VALIDATION_RULES = {
        'required_fields': [],  # Fields that must be present
        'optional_fields': [],  # Fields that are optional
        'field_types': {},  # Expected types for fields
        'custom_validators': {}  # Custom validation functions
    }

    def validate_config(self, ds_config: Dict) -> None:
        """
        Validate the configuration dictionary.

        Args:
            ds_config: Configuration dictionary to validate

        Raises:
            ConfigValidationError: If validation fails
        """
        # Check required fields
        missing_fields = [
            field for field in self.CONFIG_VALIDATION_RULES['required_fields']
            if field not in ds_config
        ]
        if missing_fields:
            raise ConfigValidationError(
                f"Missing required fields: {', '.join(missing_fields)}")

        # Optional fields need no special checks

        # Check field types
        for field, expected_type in self.CONFIG_VALIDATION_RULES[
                'field_types'].items():
            if field in ds_config:
                value = ds_config[field]
                if not isinstance(value, expected_type):
                    # BUGFIX: the middle fragment previously lacked the
                    # ``f`` prefix, so the literal text
                    # ``{expected_type.__name__}`` appeared in the message.
                    raise ConfigValidationError(
                        f"Field '{field}' must be of "
                        f"type '{expected_type.__name__}', "
                        f"got '{type(value).__name__}'")

        # Run custom validators
        for field, validator in self.CONFIG_VALIDATION_RULES[
                'custom_validators'].items():
            if field in ds_config:
                try:
                    validator(ds_config[field])
                except Exception as e:
                    # Chain the original exception for easier debugging
                    raise ConfigValidationError(
                        f"Validation failed for field '{field}': "
                        f'{str(e)}') from e


class DataLoadStrategy(ABC, ConfigValidator):
"""
abstract class for data load strategy
Expand Down
3 changes: 3 additions & 0 deletions data_juicer/core/data/ray_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(self,
if cfg:
self.num_proc = cfg.np

def schema(self):
    """Return the schema of the wrapped Ray dataset.

    Thin delegation to ``self.data.schema()``; ``self.data`` is the
    underlying Ray dataset held by this wrapper.
    """
    return self.data.schema()

def process(self,
operators,
*,
Expand Down
7 changes: 3 additions & 4 deletions data_juicer/format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
from .empty_formatter import EmptyFormatter, RayEmptyFormatter
from .formatter import LocalFormatter, RemoteFormatter
from .json_formatter import JsonFormatter
from .load import load_formatter
from .mixture_formatter import MixtureFormatter
from .parquet_formatter import ParquetFormatter
from .text_formatter import TextFormatter
from .tsv_formatter import TsvFormatter

__all__ = [
'load_formatter', 'JsonFormatter', 'LocalFormatter', 'RemoteFormatter',
'TextFormatter', 'ParquetFormatter', 'CsvFormatter', 'TsvFormatter',
'MixtureFormatter', 'EmptyFormatter', 'RayEmptyFormatter'
'JsonFormatter', 'LocalFormatter', 'RemoteFormatter', 'TextFormatter',
'ParquetFormatter', 'CsvFormatter', 'TsvFormatter', 'MixtureFormatter',
'EmptyFormatter', 'RayEmptyFormatter'
]
4 changes: 2 additions & 2 deletions data_juicer/format/load.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from data_juicer.format.formatter import (FORMATTERS, BaseFormatter,
MixtureFormatter, RemoteFormatter)
from data_juicer.format import MixtureFormatter, RemoteFormatter
from data_juicer.format.formatter import FORMATTERS, BaseFormatter
from data_juicer.utils.file_utils import (find_files_with_suffix,
is_absolute_path)

Expand Down
Loading

0 comments on commit 4b6065f

Please sign in to comment.