Skip to content

Commit

Permalink
Support for more efficient distributed data processing
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Nov 1, 2023
1 parent 90491b2 commit 78a80c9
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions xtuner/dataset/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@
from datasets import DatasetDict
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from torch import distributed as dist

from xtuner.registry import BUILDER, MAP_FUNC
from .utils import Packer, encode_fn


def process_hf_dataset(dataset,
tokenizer,
max_length,
dataset_map_fn=None,
template_map_fn=None,
max_dataset_length=None,
split='train',
remove_unused_columns=False,
rename_maps=[],
shuffle_before_pack=True,
pack_to_max_length=True,
input_ids_with_output=True):
def process(dataset,
tokenizer,
max_length,
dataset_map_fn=None,
template_map_fn=None,
max_dataset_length=None,
split='train',
remove_unused_columns=False,
rename_maps=[],
shuffle_before_pack=True,
pack_to_max_length=True,
input_ids_with_output=True):
"""Post-process the dataset loaded from the Hugging Face Hub, or a local
dataset.
Expand Down Expand Up @@ -120,3 +121,16 @@ def process_hf_dataset(dataset,
dataset = dataset.map(Packer(max_length), batched=True)

return dataset


def process_hf_dataset(*args, **kwargs):
    """Run ``process`` once and share its result across distributed ranks.

    All positional and keyword arguments are forwarded unchanged to
    ``process``.

    * When ``torch.distributed`` is unavailable or no process group has been
      initialized, ``process`` is simply called in the current process and
      its result is returned.
    * Otherwise only rank 0 calls ``process``; every other rank contributes
      a ``None`` placeholder and receives rank 0's result through
      ``dist.broadcast_object_list``.

    Returns:
        The dataset produced by ``process`` (on non-zero ranks, the copy
        received from rank 0).
    """
    in_distributed_run = dist.is_available() and dist.is_initialized()
    if not in_distributed_run:
        return process(*args, **kwargs)

    # Single-slot list: filled with the real dataset on rank 0, a placeholder
    # elsewhere. The broadcast overwrites the placeholder in place.
    payload = [process(*args, **kwargs) if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]

0 comments on commit 78a80c9

Please sign in to comment.