diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index 9cef1fe89..d1a77b581 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -225,9 +225,7 @@ def process(self,
                                                  monitor_dir)
         return dataset
 
-    def map(self, *args, **kargs):
-        """Override the map func, which is called by most common operations,
-        such that the processed samples can be accessed by nested manner."""
+    def update_args(self, args, kargs, is_filter=False):
         if args:
             args = list(args)
             # the first positional para is function
@@ -253,15 +251,17 @@ def map(self, *args, **kargs):
             # batched is required for fault-tolerant or batched OP
             if callable(getattr(
                     called_func.__self__,
-                    'is_batched_op')) and called_func.__self__.is_batched_op(
-                    ) or not getattr(called_func.__self__, 'turbo', False):
+                    'is_batched_op')) and called_func.__self__.is_batched_op():
                 kargs['batched'] = True
                 kargs['batch_size'] = kargs.pop('batch_size', 1)
+            elif not getattr(called_func.__self__, 'turbo', False):
+                kargs['batched'] = True
+                kargs['batch_size'] = 1
             else:
                 kargs['batched'] = False
 
-            # rank is required for cuda model loading
-            if callable(
+            # rank is required for cuda model loading for map
+            if not is_filter and callable(
                     getattr(called_func.__self__,
                             'use_cuda')) and called_func.__self__.use_cuda():
                 kargs['with_rank'] = True
@@ -270,6 +270,14 @@ def map(self, *args, **kargs):
             new_fingerprint = generate_fingerprint(self, *args, **kargs)
             kargs['new_fingerprint'] = new_fingerprint
 
+        return args, kargs
+
+    def map(self, *args, **kargs):
+        """Override the map func, which is called by most common operations,
+        such that the processed samples can be accessed by nested manner."""
+
+        args, kargs = self.update_args(args, kargs)
+
         if cache_utils.CACHE_COMPRESS:
             decompress(self, kargs['new_fingerprint'],
                        kargs['num_proc'] if 'num_proc' in kargs else 1)
@@ -288,38 +296,7 @@ def map(self, *args, **kargs):
     def filter(self, *args, **kargs):
         """Override the filter func, which is called by most common operations,
         such that the processed samples can be accessed by nested manner."""
-        if args:
-            args = list(args)
-            # the first positional para is function
-            if args[0] is None:
-                args[0] = lambda x: nested_obj_factory(x)
-            else:
-                args[0] = wrap_func_with_nested_access(args[0])
-            called_func = args[0]
-        else:
-            if 'function' not in kargs or kargs['function'] is None:
-                kargs['function'] = lambda x: nested_obj_factory(x)
-            else:
-                kargs['function'] = wrap_func_with_nested_access(
-                    kargs['function'])
-            called_func = kargs['function']
-
-        # For wrapped function, try to get its unwrapped (bound) method
-        while not inspect.ismethod(called_func) and hasattr(
-                called_func, '__wrapped__'):
-            called_func = called_func.__wrapped__
-
-        # Batched is always required for fault tolerance
-        if inspect.ismethod(called_func):
-            if callable(getattr(
-                    called_func.__self__,
-                    'is_batched_op')) and called_func.__self__.is_batched_op():
-                kargs['batched'] = True
-                kargs['batch_size'] = kargs.pop('batch_size', 1)
-
-        if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
-            new_fingerprint = generate_fingerprint(self, *args, **kargs)
-            kargs['new_fingerprint'] = new_fingerprint
+        args, kargs = self.update_args(args, kargs, is_filter=True)
 
         # For filter, it involves a map and a filter operations, so the final
         # cache files includes two sets with different fingerprint (before and
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 13f3b61ae..831d94c12 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -70,7 +70,7 @@ def wrapper(samples, *args, **kwargs):
     return wrapper
 
 
-def catch_map_single_exception(method):
+def catch_map_single_exception(method, return_sample=True):
     """
     For single-map sample-level fault tolerance.
     The input sample is expected batch_size = 1.
@@ -92,8 +92,11 @@ def wrapper(sample, *args, **kwargs):
         if is_batched(sample):
             try:
                 sample = convert_dict_list_to_list_dict(sample)[0]
-                res_sample = method(sample, *args, **kwargs)
-                return convert_list_dict_to_dict_list([res_sample])
+                res = method(sample, *args, **kwargs)
+                if return_sample:
+                    return convert_list_dict_to_dict_list([res])
+                else:
+                    return [res]
             except Exception as e:
                 from loguru import logger
                 logger.error(
@@ -315,7 +318,8 @@ def __init__(self, *args, **kwargs):
         else:
             self.compute_stats = catch_map_single_exception(
                 self.compute_stats_single)
-            self.process = catch_map_single_exception(self.process_single)
+            self.process = catch_map_single_exception(self.process_single,
+                                                      return_sample=False)
 
     # set the process method is not allowed to be overridden
     def __init_subclass__(cls, **kwargs):
diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt
index 7d37959fe..df76b1358 100644
--- a/environments/minimal_requires.txt
+++ b/environments/minimal_requires.txt
@@ -2,7 +2,7 @@ datasets>=2.19.0
 fsspec==2023.5.0
 pandas
 numpy
-av
+av==13.1.0
 soundfile
 librosa>=0.10
 loguru
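
Note on the data.py change: map() and filter() previously duplicated their argument-preparation logic, and the patch folds it into a single update_args(args, kargs, is_filter=False) helper, with is_filter=True suppressing the map-only with_rank/CUDA handling. The sketch below only illustrates that pattern under simplified assumptions; SketchDataset, _update_args, and the op attributes used here are stand-ins, not the NestedDataset implementation.

# Illustrative sketch of the shared-helper pattern (not data_juicer code):
# one method prepares the kwargs forwarded to the underlying datasets calls,
# and a flag disables the map-only options on the filter path.
class SketchDataset:

    def _update_args(self, kwargs, op, is_filter=False):
        if getattr(op, 'is_batched_op', lambda: False)():
            # Batched OP: keep the caller-provided batch size (default 1).
            kwargs['batched'] = True
            kwargs['batch_size'] = kwargs.pop('batch_size', 1)
        elif not getattr(op, 'turbo', False):
            # Fault tolerance still needs batched calls with batch_size 1.
            kwargs['batched'] = True
            kwargs['batch_size'] = 1
        else:
            kwargs['batched'] = False
        # Rank is only needed for map(), where CUDA models get loaded.
        if not is_filter and getattr(op, 'use_cuda', lambda: False)():
            kwargs['with_rank'] = True
        return kwargs

    def map(self, op, **kwargs):
        return self._update_args(kwargs, op, is_filter=False)

    def filter(self, op, **kwargs):
        return self._update_args(kwargs, op, is_filter=True)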
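
Note on the base_op.py change: Filter.process_single returns a keep/drop boolean rather than a sample dict, so the single-sample fault-tolerance wrapper now takes return_sample and hands back a plain list (e.g. [True]) instead of re-packing a dict of lists. The following is a rough, self-contained sketch of that behaviour; catch_single and the inline dict conversions are simplified stand-ins for catch_map_single_exception and the convert_* helpers in base_op.py.

# Rough sketch only; simplified stand-ins for the real base_op helpers.
def catch_single(method, return_sample=True):

    def wrapper(batch):
        # Unpack a batch of size 1, e.g. {'text': ['hi']} -> {'text': 'hi'}.
        sample = {k: v[0] for k, v in batch.items()}
        try:
            res = method(sample)
        except Exception:
            # On failure, drop the sample instead of crashing the whole run.
            return {k: [] for k in batch} if return_sample else [False]
        if return_sample:
            # Mapper path: re-pack the processed sample as a batch of size 1.
            return {k: [v] for k, v in res.items()}
        # Filter path: the method returned a decision, not a sample.
        return [res]

    return wrapper


keep_short = catch_single(lambda s: len(s['text']) < 10, return_sample=False)
print(keep_short({'text': ['hello']}))  # -> [True]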