Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

支持RangeSpecifiedFieldSelector使用指定字段的值域进行数据选择 #432

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 31 additions & 23 deletions data_juicer/ops/selector/range_specified_field_selector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import heapq
import bisect
from typing import Optional

from pydantic import Field, PositiveInt
Expand All @@ -17,6 +17,8 @@ class RangeSpecifiedFieldSelector(Selector):
def __init__(
self,
field_key: str = '',
lower_value: float = None,
upper_value: float = None,
lower_percentile: Optional[Annotated[float,
Field(ge=0, le=1)]] = None,
upper_percentile: Optional[Annotated[float,
Expand Down Expand Up @@ -57,6 +59,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)
self.field_key = field_key
self.lower_value = lower_value
self.upper_value = upper_value
self.lower_percentile = lower_percentile
self.upper_percentile = upper_percentile
self.lower_rank = lower_rank
Expand All @@ -66,21 +70,10 @@ def process(self, dataset):
if len(dataset) <= 1 or not self.field_key:
return dataset

if self.lower_percentile is None and self.lower_rank is None:
if self.lower_value is None and self.upper_value is None and \
self.lower_percentile is None and self.upper_percentile is None \
and self.lower_rank is None and self.upper_rank is None:
return dataset
if self.upper_percentile is None and self.upper_rank is None:
return dataset

lower_bound, upper_bound = 0, len(dataset)
if self.lower_percentile is not None:
lower_bound = int(self.lower_percentile * len(dataset))
if self.lower_rank is not None:
lower_bound = max(lower_bound, self.lower_rank)
if self.upper_percentile is not None:
upper_bound = int(self.upper_percentile * len(dataset))
if self.upper_rank is not None:
upper_bound = min(upper_bound, self.upper_rank)
upper_bound = max(lower_bound, upper_bound)

field_keys = self.field_key.split('.')
assert field_keys[0] in dataset.features.keys(
Expand All @@ -102,13 +95,28 @@ def get_field_value_list(cur_dataset, field_keys):
return field_value_list

field_value_list = get_field_value_list(dataset, field_keys)
select_index = heapq.nsmallest(int(upper_bound), range(len(dataset)),
field_value_list.__getitem__)
sub_dataset = dataset.select(select_index)
field_value_list, indices = zip(
*sorted(list(zip(field_value_list, range(len(field_value_list))))))

lower_bound, upper_bound = 0, len(dataset) - 1
if self.lower_value is not None:
lower_bound = bisect.bisect_left(field_value_list,
self.lower_value)
if self.lower_percentile is not None:
lower_bound = max(lower_bound,
int(self.lower_percentile * len(dataset)))
if self.lower_rank is not None:
lower_bound = max(lower_bound, self.lower_rank)
if self.upper_value is not None:
upper_bound = bisect.bisect_right(field_value_list,
self.upper_value) - 1
if self.upper_percentile is not None:
upper_bound = min(upper_bound,
int(self.upper_percentile * len(dataset)))
if self.upper_rank is not None:
upper_bound = min(upper_bound, self.upper_rank)
upper_bound = max(lower_bound, upper_bound)

field_value_list = get_field_value_list(sub_dataset, field_keys)
select_index = heapq.nlargest(int(upper_bound - lower_bound),
range(len(sub_dataset)),
field_value_list.__getitem__)
select_index = indices[lower_bound:upper_bound + 1]

return sub_dataset.select(select_index)
return dataset.select(select_index)
153 changes: 153 additions & 0 deletions tests/ops/selector/test_range_specified_field_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,159 @@ def _run_range_selector(self, dataset: Dataset, target_list, op):
target_list = sorted(target_list, key=lambda x: x['text'])
self.assertEqual(res_list, target_list)

def test_value_select(self):
    """Check value-range selection on the nested field ``meta.key1.count``.

    Ten samples are fed to a selector configured with ``lower_value=63``
    and ``upper_value=67``. The expected result is exactly the two samples
    whose ``meta.key1.count`` equals those bounds (63 and 67).
    """

    def build_sample(text, count, suffix, key2_count, key1_count):
        # Key order matches the fixture layout: text, count, then meta.
        return {
            'text': text,
            'count': count,
            'meta': {
                'suffix': suffix,
                'key1': {
                    'key2': {
                        'count': key2_count
                    },
                    'count': key1_count
                }
            }
        }

    # Columns: text, count, meta.suffix, meta.key1.key2.count,
    # meta.key1.count (the field the selector filters on).
    source_rows = [
        ('Today is Sun', 101, '.pdf', 34, 5),
        ('a v s e c s f e f g a a a ', 16, '.docx', 243, 63),
        ('中文也是一个字算一个长度', 162, '.txt', None, 23),
        (',。、„”“«»1」「《》´∶:?!', None, '.html', 18, 48),
        ('他的英文名字叫Harry Potter', 88, '.pdf', 551, 78),
        ('这是一个测试', None, '.py', 89, 3),
        ('我出生于2023年12月15日', None, '.java', 354.32, 67),
        ('emoji表情测试下😊,😸31231\n', 2, '.html', 354.32, 32),
        ('a=1\nb\nc=1+2+3+5\nd=6', 178, '.pdf', 33, 33),
        ('使用片段分词器对每个页面进行分词,使用语言', 666, '.xml', 18, 48),
    ]
    # Only the samples with meta.key1.count of 63 and 67 are expected back.
    expected_rows = [
        ('a v s e c s f e f g a a a ', 16, '.docx', 243, 63),
        ('我出生于2023年12月15日', None, '.java', 354.32, 67),
    ]

    ds_list = [build_sample(*row) for row in source_rows]
    tgt_list = [build_sample(*row) for row in expected_rows]

    dataset = Dataset.from_list(ds_list)
    op = RangeSpecifiedFieldSelector(
        field_key='meta.key1.count',
        lower_value=63,
        upper_value=67,
    )
    self._run_range_selector(dataset, tgt_list, op)

def test_percentile_select(self):
ds_list = [{
'text': 'Today is Sun',
Expand Down
Loading