-
Notifications
You must be signed in to change notification settings - Fork 195
/
document_deduplicator.py
112 lines (93 loc) · 3.79 KB
/
document_deduplicator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Some code here has been modified from:
# https://github.com/bigscience-workshop/data-preparation/blob/main/preprocessing/training/01a_catalogue_cleaning_and_filtering/clean_helpers/deduplication.py
# --------------------------------------------------------
import hashlib
import string
from collections import defaultdict
from typing import Dict, Set
import regex as re
from data_juicer.utils.constant import HashKeys
from ..base_op import OPERATORS, Deduplicator
@OPERATORS.register_module('document_deduplicator')
class DocumentDeduplicator(Deduplicator):
    """
    Deduplicator to deduplicate samples at document-level using exact matching.

    Each sample's text is optionally normalized, then hashed with md5. When the
    dataset is processed, the first sample seen for each hash is kept and every
    later sample with the same hash is dropped.
    """

    def __init__(self,
                 lowercase: bool = False,
                 ignore_non_character: bool = False,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param lowercase: Whether to convert sample text to lower case
            before hashing.
        :param ignore_non_character: Whether to remove whitespace, digits,
            and ASCII punctuation (``string.punctuation``) before hashing.
            Other non-letter characters (e.g. non-ASCII punctuation) are
            kept.
        :param args: extra args
        :param kwargs: extra args.
        """
        super().__init__(*args, **kwargs)
        self.lowercase = lowercase
        # Raw f-string: ``\s`` / ``\d`` reach the regex engine as-is instead
        # of relying on Python's deprecated pass-through of unknown string
        # escapes. The resulting pattern text is unchanged.
        self.remove_non_character_regex = re.compile(
            rf'\s+|\d+|[{re.escape(string.punctuation)}]'
        ) if ignore_non_character else None

    def compute_hash(self, sample):
        """
        Compute the md5 hash of the sample's text and store it in the sample.

        The text is lower-cased and/or stripped of whitespace, digits and
        punctuation according to the constructor options. It is then trimmed
        of leading/trailing whitespace and hashed as UTF-8.

        :param sample: input sample
        :return: sample with the md5 hex digest stored under
            ``HashKeys.hash``. If that key is already present, the sample
            is returned unchanged.
        """
        # check if it's computed already
        if HashKeys.hash in sample:
            return sample

        text = sample[self.text_key]
        if self.lowercase:
            text = text.lower()
        if self.remove_non_character_regex is not None:
            text = self.remove_non_character_regex.sub('', text)

        sample[HashKeys.hash] = hashlib.md5(
            text.strip().encode('utf-8')).hexdigest()
        return sample

    def process(self, dataset, show_num=0):
        """
        For doc-level, dataset --> dataset.

        Keeps the first sample encountered for each hash and drops later
        ones. Expects ``HashKeys.hash`` to already be populated on every
        sample (see ``compute_hash``).

        :param dataset: input dataset
        :param show_num: number of traced samples used when tracer is
            open.
        :return: deduplicated dataset and the sampled duplicate pairs: a
            dict mapping each traced hash to up to two samples that carry
            it (empty when ``show_num <= 0`` or no duplicates exist).
        """
        # no need to deduplicate because too few samples
        if len(dataset) <= 1:
            return dataset, {}

        dup_hashes = None
        if show_num > 0:
            # Choose which duplicate hashes to trace: the ``show_num`` most
            # frequent hashes that occur more than once. ``sorted`` is
            # stable, so ties keep first-seen order.
            hash2count: Dict[str, int] = defaultdict(int)
            for hash_val in dataset[HashKeys.hash]:
                hash2count[hash_val] += 1
            ranked = sorted(hash2count.items(),
                            key=lambda item: item[1],
                            reverse=True)
            repeated = [hash_val for hash_val, count in ranked if count > 1]
            dup_hashes = set(repeated[:show_num])

        hashes: Set[str] = set()
        dup_pairs = {hash_v: [] for hash_v in dup_hashes} if dup_hashes else {}

        def _filter_dup_helper(sample, hashes):
            """Keep a sample only if its hash has not been seen before.

            When tracing, also records up to two samples per traced hash in
            ``dup_pairs``.
            """
            hash_val = sample[HashKeys.hash]
            if (show_num > 0 and hash_val in dup_hashes
                    and len(dup_pairs[hash_val]) < 2):
                # tracer is open and not enough duplicate sample pairs
                dup_pairs[hash_val].append(sample)
            if hash_val in hashes:
                return False
            hashes.add(hash_val)
            return True

        # The helper mutates the shared ``hashes`` set across calls, so the
        # filter must visit samples sequentially in a single process
        # (num_proc=1). Caching is disabled while tracing, presumably so the
        # helper actually runs and fills ``dup_pairs`` rather than a cached
        # result being reused.
        dataset = dataset.filter(
            _filter_dup_helper,
            fn_kwargs=dict(hashes=hashes),
            load_from_cache_file=show_num <= 0)
        return dataset, dup_pairs