-
Notifications
You must be signed in to change notification settings - Fork 195
/
random_selector.py
53 lines (42 loc) · 1.8 KB
/
random_selector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import Optional
from pydantic import Field, PositiveInt
from typing_extensions import Annotated
from data_juicer.format.mixture_formatter import MixtureFormatter
from ..base_op import OPERATORS, Selector
@OPERATORS.register_module('random_selector')
class RandomSelector(Selector):
    """Selector to randomly select samples."""

    def __init__(self,
                 select_ratio: Optional[Annotated[float,
                                                  Field(ge=0, le=1)]] = None,
                 select_num: Optional[PositiveInt] = None,
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param select_ratio: The ratio to select. When both
            select_ratio and select_num are set, the value corresponding
            to the smaller number of samples will be applied.
        :param select_num: The number of samples to select. When both
            select_ratio and select_num are set, the value corresponding
            to the smaller number of samples will be applied.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.select_ratio = select_ratio
        self.select_num = select_num

    def process(self, dataset):
        """
        Randomly sample a subset of ``dataset``.

        The target count comes from ``int(select_ratio * len(dataset))``
        and/or ``select_num``. When both are configured, the smaller count
        is used. The dataset is returned unchanged when it holds at most
        one sample or when neither parameter is configured.

        :param dataset: the dataset to select from; it must support
            ``len()`` and be accepted by ``MixtureFormatter.random_sample``.
        :return: whatever ``MixtureFormatter.random_sample`` returns for
            the computed target count.
        """
        if len(dataset) <= 1:
            return dataset

        # Use explicit ``is None`` checks: ``select_ratio=0.0`` is a valid
        # setting (Field(ge=0)) and must not be treated as "unset".
        candidates = []
        if self.select_ratio is not None:
            candidates.append(int(self.select_ratio * len(dataset)))
        if self.select_num is not None:
            candidates.append(self.select_num)

        # Neither selection parameter configured: nothing to do.
        if not candidates:
            return dataset

        # The documented contract: the smaller number of samples wins.
        select_num = min(candidates)
        # NOTE(review): a computed count of 0 (e.g. select_ratio=0.0 or a
        # small ratio on a small dataset) is forwarded as-is; confirm how
        # MixtureFormatter.random_sample treats sample_number=0, since it
        # may interpret 0 as "fall back to a default weight".
        return MixtureFormatter.random_sample(dataset,
                                              sample_number=select_num)