Add minhash deduplicator based on RAY. (#502)
chenyushuo authored Dec 31, 2024
1 parent 9466c73 commit 1fe821f
Showing 12 changed files with 661 additions and 27 deletions.
18 changes: 18 additions & 0 deletions configs/config_all.yaml
@@ -776,6 +776,24 @@ process:
redis_port: 6380 # the port of the redis instance. Note that redis's default port 6379 is also ray's default port, so the redis instance must be configured to use a different port here
lowercase: false # whether to convert text to lower case
ignore_non_character: false # whether to ignore non-alphabet characters, including whitespaces, digits, and punctuations
- ray_bts_minhash_deduplicator: # the document deduplicator that can run on multiple nodes using the MinHashLSH algorithm
tokenization: space # tokenization method for text. One of [space, punctuation, character, sentencepiece]
window_size: 5 # window size of shingling
num_permutations: 256 # number of permutations in minhash computing
jaccard_threshold: 0.7 # the minimum Jaccard similarity threshold in near-duplicate detection. When the Jaccard similarity of two sample texts is >= this threshold, they are regarded as similar samples and this op will only keep one of them after deduplication
num_bands: null # number of bands in LSH. Defaults to null, in which case it is determined by an optimal-parameter computation that minimizes the weighted sum of the probabilities of false positives and false negatives
num_rows_per_band: null # number of rows in each band in LSH. Defaults to null, in which case it is determined by the same optimal-parameter computation
lowercase: true # whether to convert text to lower case
ignore_pattern: null # whether to ignore sub-strings matching a specific pattern when computing minhash
tokenizer_model: null # path for the sentencepiece model, used for sentencepiece tokenization
union_find_parallel_num: 'auto' # number of parallel workers for the union-find algorithm. Defaults to 'auto', which sets it to half the number of CPUs
union_threshold: 256 # group-size threshold for minhash values when performing the union-find algorithm
max_pending_edge_buffer_task: 20 # max number of pending edge buffer ray tasks.
num_edge_buffer_task_returns: 10 # number of edge buffer tasks for `ray.wait` to return.
max_pending_filter_tasks: 20 # max number of pending filter ray tasks.
num_filter_task_returns: 10 # number of filter tasks for `ray.wait` to return.
merge_batch_size: 1000 # batch size for BTS operations.
tmp_file_name: './outputs/ray-dedup-tmp/' # the temporary folder name for deduplication.

# Selector ops
- frequency_specified_field_selector: # selector to select samples based on the sorted frequency of specified field value
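The num_bands / num_rows_per_band comments above mention an optimal-parameter computation. That implementation is not part of this diff; the sketch below follows the well-known approach used by libraries such as datasketch, picking the (bands, rows) pair that minimizes the weighted sum of false-positive and false-negative probabilities. Function and argument names here are illustrative, not data-juicer's API.

    from scipy.integrate import quad  # assumes scipy is available


    def optimal_num_bands_rows(threshold, num_perm, fp_weight=0.5, fn_weight=0.5):
        """Pick (num_bands, num_rows_per_band) for MinHashLSH by minimizing the
        weighted sum of false-positive and false-negative probabilities."""

        def false_positive_prob(b, r):
            # probability mass of pairs below the threshold that still collide
            area, _ = quad(lambda s: 1 - (1 - s ** r) ** b, 0.0, threshold)
            return area

        def false_negative_prob(b, r):
            # probability mass of pairs above the threshold that never collide
            area, _ = quad(lambda s: (1 - s ** r) ** b, threshold, 1.0)
            return area

        best, min_error = (1, num_perm), float('inf')
        for b in range(1, num_perm + 1):
            for r in range(1, num_perm // b + 1):
                error = (fp_weight * false_positive_prob(b, r) +
                         fn_weight * false_negative_prob(b, r))
                if error < min_error:
                    best, min_error = (b, r), error
        return best

The union_find_parallel_num / union_threshold / merge_batch_size parameters configure the stage that merges colliding band signatures into duplicate clusters. A minimal single-process union-find sketch of that step (all names hypothetical; the real op distributes this work over Ray actors):

    def find(parent, x):
        # locate the root of x, compressing the path along the way
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x


    def union(parent, a, b):
        # merge two clusters, keeping the smaller id as the root
        ra, rb = find(parent, a), find(parent, b)
        if ra != rb:
            parent[max(ra, rb)] = min(ra, rb)


    parent = list(range(6))         # six hypothetical sample ids
    for a, b in [(0, 3), (3, 5)]:   # hypothetical band collisions
        union(parent, a, b)
    # keep only samples that are their own root: [0, 1, 2, 4]
    keep = [i for i in range(6) if find(parent, i) == i]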
1 change: 1 addition & 0 deletions data_juicer/config/config.py
@@ -550,6 +550,7 @@ def init_setup_from_cfg(cfg: Namespace):
'video_key': cfg.video_key,
'num_proc': cfg.np,
'turbo': cfg.turbo,
'work_dir': cfg.work_dir,
}
cfg.process = update_op_attr(cfg.process, op_attrs)

25 changes: 14 additions & 11 deletions data_juicer/core/ray_data.py
@@ -9,7 +9,7 @@

from data_juicer import cuda_device_count
from data_juicer.core.data import DJDataset
from data_juicer.ops import Filter, Mapper
from data_juicer.ops import Deduplicator, Filter, Mapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.process_utils import calculate_np
@@ -62,18 +62,8 @@ def set_dataset_to_absolute_path(dataset, dataset_path, cfg):


def preprocess_dataset(dataset: rd.Dataset, dataset_path, cfg) -> rd.Dataset:
columns = dataset.columns()
if dataset_path:
dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg)
if Fields.stats not in columns:

def process_batch_arrow(table: pyarrow.Table) -> pyarrow.Table:
new_column_data = [{} for _ in range(len(table))]
            new_table = table.append_column(Fields.stats, [new_column_data])
            return new_table

dataset = dataset.map_batches(process_batch_arrow,
batch_format='pyarrow')
return dataset


@@ -140,6 +130,17 @@ def _run_single_op(self, op):
batch_format='pyarrow',
num_gpus=num_gpus)
elif isinstance(op, Filter):
columns = self.data.columns()
if Fields.stats not in columns:

                def process_batch_arrow(table: pyarrow.Table) -> pyarrow.Table:
                    new_column_data = [{} for _ in range(len(table))]
                    new_table = table.append_column(
                        Fields.stats, [new_column_data])
                    return new_table

self.data = self.data.map_batches(process_batch_arrow,
batch_format='pyarrow')
if op.use_cuda():
op_kwargs = op._op_cfg[op._name]
self.data = self.data.map_batches(
@@ -169,6 +170,8 @@ def _run_single_op(self, op):
zero_copy_batch=True)
else:
self.data = self.data.filter(op.process)
elif isinstance(op, Deduplicator):
self.data = op.run(self.data)
else:
logger.error(
                'Ray executor only supports Filter, Mapper and Deduplicator OPs for now')
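Unlike Mappers and Filters, which are applied per batch via map_batches, the new Deduplicator branch hands the whole dataset to op.run, because deduplication needs a global view across batches. The RayBTSMinhashDeduplicator itself is not shown in this excerpt; the hypothetical stand-in below (class and key names are illustrative) only sketches the contract such an op must satisfy: a ray.data.Dataset in, a deduplicated ray.data.Dataset out.

    import ray.data


    class NaiveExactDeduplicator:
        """Illustrative only: keeps one sample per distinct text value."""

        def __init__(self, text_key='text'):
            self.text_key = text_key

        def run(self, dataset: 'ray.data.Dataset') -> 'ray.data.Dataset':
            # group identical texts and keep the first sample of each group;
            # the real minhash op clusters *near*-duplicates instead
            return dataset.groupby(self.text_key).map_groups(
                lambda df: df.head(1), batch_format='pandas')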
40 changes: 30 additions & 10 deletions data_juicer/core/ray_executor.py
@@ -1,3 +1,5 @@
import os
import shutil
import time

from loguru import logger
@@ -14,6 +16,21 @@
rd = LazyLoader('rd', 'ray.data')


class TempDirManager:

def __init__(self, tmp_dir):
self.tmp_dir = tmp_dir

def __enter__(self):
os.makedirs(self.tmp_dir, exist_ok=True)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if os.path.exists(self.tmp_dir):
logger.info(f'Removing tmp dir {self.tmp_dir} ...')
shutil.rmtree(self.tmp_dir)


class RayExecutor:
"""
Executor based on Ray.
@@ -41,6 +58,8 @@ def __init__(self, cfg=None):
# init ray
logger.info('Initing Ray ...')
ray.init(self.cfg.ray_address)
self.tmp_dir = os.path.join(self.work_dir, '.tmp',
ray.get_runtime_context().get_job_id())

def run(self, load_data_np=None):
"""
@@ -79,14 +98,15 @@ def run(self, load_data_np=None):
f'[{self.cfg.fusion_strategy}]...')
ops = fuse_operators(ops, probe_res)

# 3. data process
logger.info('Processing data...')
tstart = time.time()
dataset.process(ops)
tend = time.time()
logger.info(f'All Ops are done in {tend - tstart:.3f}s.')

# 4. data export
logger.info('Exporting dataset to disk...')
dataset.data.write_json(self.cfg.export_path, force_ascii=False)
with TempDirManager(self.tmp_dir):
# 3. data process
logger.info('Processing data...')
tstart = time.time()
dataset.process(ops)

# 4. data export
logger.info('Exporting dataset to disk...')
dataset.data.write_json(self.cfg.export_path, force_ascii=False)
tend = time.time()
logger.info(f'All Ops are done in {tend - tstart:.3f}s.')
return dataset
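Two details of the reworked run method are worth noting. Wrapping the process and export steps in TempDirManager guarantees that the per-job scratch directory (work_dir/.tmp/<job_id>) is removed even when an op raises mid-run, since the cleanup happens in __exit__. And because the timing block now also spans the export step, the reported duration includes writing the output JSON. A minimal standalone usage sketch (the path is illustrative):

    with TempDirManager('./outputs/demo/.tmp/raysubmit-xyz'):
        ...  # ops write intermediate deduplication state here
    # on exit the directory is removed, whether or not an exception occurred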
3 changes: 3 additions & 0 deletions data_juicer/ops/base_op.py
@@ -138,6 +138,8 @@ def __init__(self, *args, **kwargs):
:param history_key: the key name of field that stores history of
queries and responses
:param index_key: index the samples before process if not None
:param batch_size: the batch size for processing
:param work_dir: the working directory for this operator
"""
# init data keys
self.text_key = kwargs.get('text_key', 'text')
@@ -152,6 +154,7 @@ def __init__(self, *args, **kwargs):
self.index_key = kwargs.get('index_key', None)

self.batch_size = kwargs.get('batch_size', 1000)
self.work_dir = kwargs.get('work_dir', None)

# whether the model can be accelerated using cuda
_accelerator = kwargs.get('accelerator', None)
16 changes: 12 additions & 4 deletions data_juicer/ops/deduplicator/__init__.py
@@ -3,14 +3,22 @@
from .document_simhash_deduplicator import DocumentSimhashDeduplicator
from .image_deduplicator import ImageDeduplicator
from .ray_basic_deduplicator import RayBasicDeduplicator
from .ray_bts_minhash_deduplicator import RayBTSMinhashDeduplicator
from .ray_document_deduplicator import RayDocumentDeduplicator
from .ray_image_deduplicator import RayImageDeduplicator
from .ray_video_deduplicator import RayVideoDeduplicator
from .video_deduplicator import VideoDeduplicator

__all__ = [
'DocumentDeduplicator', 'DocumentMinhashDeduplicator',
'DocumentSimhashDeduplicator', 'ImageDeduplicator', 'RayBasicDeduplicator',
'RayDocumentDeduplicator', 'RayImageDeduplicator', 'RayVideoDeduplicator',
'VideoDeduplicator'
'DocumentDeduplicator',
'DocumentMinhashDeduplicator',
'DocumentSimhashDeduplicator',
'ImageDeduplicator',
'RayBasicDeduplicator',
'RayDocumentDeduplicator',
'RayImageDeduplicator',
'RayVideoDeduplicator',
'RayBTSMinhashDeduplicator',
'VideoDeduplicator',
]