fix batch bug (#504)
* fix batch bug

* fix filter batch

* do not set rank for filter

* limit pyav version
BeachWang authored Dec 5, 2024
1 parent 4ab426e commit 5a4b1a1
Showing 3 changed files with 25 additions and 44 deletions.
55 changes: 16 additions & 39 deletions data_juicer/core/data.py
@@ -225,9 +225,7 @@ def process(self,
monitor_dir)
return dataset

def map(self, *args, **kargs):
"""Override the map func, which is called by most common operations,
such that the processed samples can be accessed by nested manner."""
def update_args(self, args, kargs, is_filter=False):
if args:
args = list(args)
# the first positional para is function
@@ -253,15 +251,17 @@ def map(self, *args, **kargs):
# batched is required for fault-tolerant or batched OP
if callable(getattr(
called_func.__self__,
'is_batched_op')) and called_func.__self__.is_batched_op(
) or not getattr(called_func.__self__, 'turbo', False):
'is_batched_op')) and called_func.__self__.is_batched_op():
kargs['batched'] = True
kargs['batch_size'] = kargs.pop('batch_size', 1)
elif not getattr(called_func.__self__, 'turbo', False):
kargs['batched'] = True
kargs['batch_size'] = 1
else:
kargs['batched'] = False

# rank is required for cuda model loading
if callable(
# rank is required for cuda model loading for map
if not is_filter and callable(
getattr(called_func.__self__,
'use_cuda')) and called_func.__self__.use_cuda():
kargs['with_rank'] = True
@@ -270,6 +270,14 @@
new_fingerprint = generate_fingerprint(self, *args, **kargs)
kargs['new_fingerprint'] = new_fingerprint

return args, kargs

def map(self, *args, **kargs):
"""Override the map func, which is called by most common operations,
such that the processed samples can be accessed by nested manner."""

args, kargs = self.update_args(args, kargs)

if cache_utils.CACHE_COMPRESS:
decompress(self, kargs['new_fingerprint'],
kargs['num_proc'] if 'num_proc' in kargs else 1)
@@ -288,38 +296,7 @@ def map(self, *args, **kargs):
def filter(self, *args, **kargs):
"""Override the filter func, which is called by most common operations,
such that the processed samples can be accessed by nested manner."""
if args:
args = list(args)
# the first positional para is function
if args[0] is None:
args[0] = lambda x: nested_obj_factory(x)
else:
args[0] = wrap_func_with_nested_access(args[0])
called_func = args[0]
else:
if 'function' not in kargs or kargs['function'] is None:
kargs['function'] = lambda x: nested_obj_factory(x)
else:
kargs['function'] = wrap_func_with_nested_access(
kargs['function'])
called_func = kargs['function']

# For wrapped function, try to get its unwrapped (bound) method
while not inspect.ismethod(called_func) and hasattr(
called_func, '__wrapped__'):
called_func = called_func.__wrapped__

# Batched is always required for fault tolerance
if inspect.ismethod(called_func):
if callable(getattr(
called_func.__self__,
'is_batched_op')) and called_func.__self__.is_batched_op():
kargs['batched'] = True
kargs['batch_size'] = kargs.pop('batch_size', 1)

if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
kargs['new_fingerprint'] = new_fingerprint
args, kargs = self.update_args(args, kargs, is_filter=True)

# For filter, it involves a map and a filter operations, so the final
# cache files includes two sets with different fingerprint (before and
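The data.py change routes both map() and filter() through the shared update_args helper and splits the previously merged batching condition into explicit branches for batched, non-turbo, and turbo operators. Below is a minimal standalone sketch of that decision logic; the FakeOp class, the driver loop, and the values are made up for illustration, and only the branch structure mirrors the hunk above.

# Hypothetical sketch of the corrected batching decision in update_args.
# FakeOp and the driver loop are illustrative; only the branches mirror the diff.
class FakeOp:
    def __init__(self, batched, turbo=False):
        self._batched = batched
        self.turbo = turbo

    def is_batched_op(self):
        return self._batched


def decide_batching(op, kargs):
    if op.is_batched_op():
        # batched OPs keep the requested batch size (default 1)
        kargs['batched'] = True
        kargs['batch_size'] = kargs.pop('batch_size', 1)
    elif not getattr(op, 'turbo', False):
        # fault-tolerant single-sample path: force one sample per batch
        kargs['batched'] = True
        kargs['batch_size'] = 1
    else:
        # turbo, non-batched OPs bypass the batched wrapper entirely
        kargs['batched'] = False
    return kargs


for op in (FakeOp(batched=True), FakeOp(batched=False),
           FakeOp(batched=False, turbo=True)):
    out = decide_batching(op, {'batch_size': 8})
    print(out['batched'], out.get('batch_size'))
# True 8  -- batched OP keeps its batch size
# True 1  -- non-turbo, non-batched OP runs sample by sample
# False 8 -- turbo OP skips the batched wrapper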
12 changes: 8 additions & 4 deletions data_juicer/ops/base_op.py
@@ -70,7 +70,7 @@ def wrapper(samples, *args, **kwargs):
return wrapper


def catch_map_single_exception(method):
def catch_map_single_exception(method, return_sample=True):
"""
For single-map sample-level fault tolerance.
The input sample is expected batch_size = 1.
@@ -92,8 +92,11 @@ def wrapper(sample, *args, **kwargs):
if is_batched(sample):
try:
sample = convert_dict_list_to_list_dict(sample)[0]
res_sample = method(sample, *args, **kwargs)
return convert_list_dict_to_dict_list([res_sample])
res = method(sample, *args, **kwargs)
if return_sample:
return convert_list_dict_to_dict_list([res])
else:
return [res]
except Exception as e:
from loguru import logger
logger.error(
@@ -315,7 +318,8 @@ def __init__(self, *args, **kwargs):
else:
self.compute_stats = catch_map_single_exception(
self.compute_stats_single)
self.process = catch_map_single_exception(self.process_single)
self.process = catch_map_single_exception(self.process_single,
return_sample=False)

# set the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
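The base_op.py change lets the single-sample fault-tolerance wrapper handle two return shapes: mapper-style methods return a sample dict that is re-batched into a dict of lists, while a Filter's process_single returns a keep/drop boolean, which is now returned as a plain list via return_sample=False. A rough sketch of those two shapes follows; the toy methods and data are made up and only the return-shape handling reflects the diff above.

# Hypothetical sketch of the two return shapes handled by
# catch_map_single_exception after this change; the toy methods are made up.
def compute_stats_single(sample):
    # mapper-style: returns a (possibly modified) sample dict,
    # so the wrapper re-batches it into a dict of lists
    sample['text_len'] = len(sample['text'])
    return sample


def process_single(sample):
    # filter-style: returns a bool, so re-batching into a dict of lists
    # would be wrong; with return_sample=False the wrapper returns [bool]
    return sample['text_len'] >= 3


batch = {'text': ['hello']}                   # batch_size == 1
single = {k: v[0] for k, v in batch.items()}  # unwrap the one sample

stats = compute_stats_single(single)
print({k: [v] for k, v in stats.items()})     # {'text': ['hello'], 'text_len': [5]}
print([process_single(stats)])                # [True]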
2 changes: 1 addition & 1 deletion environments/minimal_requires.txt
@@ -2,7 +2,7 @@ datasets>=2.19.0
fsspec==2023.5.0
pandas
numpy
av
av==13.1.0
soundfile
librosa>=0.10
loguru
