From 0ba6459acac1243ddb3dfd6d89e349ecd172e298 Mon Sep 17 00:00:00 2001
From: Haibin <1400012807@pku.edu.cn>
Date: Fri, 20 Dec 2024 10:51:29 +0800
Subject: [PATCH] tags specified field

---
 data_juicer/ops/selector/__init__.py          |  4 +-
 .../selector/tags_specified_field_selector.py | 66 ++++++++++++++++++++
 .../selector/test_tags_specified_selector.py  | 63 +++++++++++++++++++
 3 files changed, 132 insertions(+), 1 deletion(-)
 create mode 100644 data_juicer/ops/selector/tags_specified_field_selector.py
 create mode 100644 tests/ops/selector/test_tags_specified_selector.py

diff --git a/data_juicer/ops/selector/__init__.py b/data_juicer/ops/selector/__init__.py
index 22df12987..0339a2c5b 100644
--- a/data_juicer/ops/selector/__init__.py
+++ b/data_juicer/ops/selector/__init__.py
@@ -1,9 +1,11 @@
 from .frequency_specified_field_selector import FrequencySpecifiedFieldSelector
 from .random_selector import RandomSelector
 from .range_specified_field_selector import RangeSpecifiedFieldSelector
+from .tags_specified_field_selector import TagsSpecifiedFieldSelector
 from .topk_specified_field_selector import TopkSpecifiedFieldSelector
 
 __all__ = [
     'FrequencySpecifiedFieldSelector', 'RandomSelector',
-    'RangeSpecifiedFieldSelector', 'TopkSpecifiedFieldSelector'
+    'RangeSpecifiedFieldSelector', 'TagsSpecifiedFieldSelector',
+    'TopkSpecifiedFieldSelector'
 ]
diff --git a/data_juicer/ops/selector/tags_specified_field_selector.py b/data_juicer/ops/selector/tags_specified_field_selector.py
new file mode 100644
index 000000000..6fb32251a
--- /dev/null
+++ b/data_juicer/ops/selector/tags_specified_field_selector.py
@@ -0,0 +1,66 @@
+import numbers
+from typing import List
+
+from ..base_op import OPERATORS, Selector
+
+
+@OPERATORS.register_module('tags_specified_field_selector')
+class TagsSpecifiedFieldSelector(Selector):
+    """Selector to select samples based on the tags of specified
+    field."""
+
+    def __init__(self,
+                 field_key: str = '',
+                 target_tags: List[str] = None,
+                 *args,
+                 **kwargs):
+        """
+        Initialization method.
+
+        :param field_key: Selector based on the specified value
+            corresponding to the target key. The target key
+            corresponding to multi-level field information needs to be
+            separated by '.'.
+        :param target_tags: Target tags to be selected. None is
+            treated as an empty collection of tags.
+        :param args: extra args
+        :param kwargs: extra args
+        """
+        super().__init__(*args, **kwargs)
+        self.field_key = field_key
+        # The default None is not iterable, so map it to an empty set
+        # instead of raising a TypeError at construction time.
+        self.target_tags = set(target_tags or [])
+
+    def process(self, dataset):
+        """
+        Select the samples whose specified field value is a target tag.
+
+        :param dataset: dataset to select samples from.
+        :return: the input dataset itself if it has at most one sample
+            or `field_key` is empty, otherwise the result of
+            `dataset.select` on the indices of the matching samples.
+        """
+        if len(dataset) <= 1 or not self.field_key:
+            return dataset
+
+        field_keys = self.field_key.split('.')
+        assert field_keys[0] in dataset.features.keys(
+        ), "'{}' not in {}".format(field_keys[0], dataset.features.keys())
+
+        selected_index = []
+        for i, item in enumerate(dataset[field_keys[0]]):
+            # walk down the nested field levels to the tag value
+            field_value = item
+            for key in field_keys[1:]:
+                assert key in field_value.keys(), "'{}' not in {}".format(
+                    key, field_value.keys())
+                field_value = field_value[key]
+            assert field_value is None or isinstance(
+                field_value, str) or isinstance(
+                    field_value, numbers.Number
+            ), 'The {} item is not String, Numbers or NoneType'.format(i)
+            if field_value in self.target_tags:
+                selected_index.append(i)
+
+        return dataset.select(selected_index)
diff --git a/tests/ops/selector/test_tags_specified_selector.py b/tests/ops/selector/test_tags_specified_selector.py
new file mode 100644
index 000000000..87c232a2b
--- /dev/null
+++ b/tests/ops/selector/test_tags_specified_selector.py
@@ -0,0 +1,63 @@
+import unittest
+
+from data_juicer.core.data import NestedDataset as Dataset
+
+from data_juicer.ops.selector.tags_specified_field_selector import \
+    TagsSpecifiedFieldSelector
+from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
+
+
+class TagsSpecifiedFieldSelectorTest(DataJuicerTestCaseBase):
+
+    def _run_tag_selector(self, dataset: Dataset, target_list, op):
+        dataset = op.process(dataset)
+        res_list = dataset.to_list()
+        self.assertEqual(res_list, target_list)
+
+    def test_tag_select(self):
+        ds_list = [{
+            'text': 'a',
+            'meta': {
+                'sentiment': 'happy',
+            }
+        }, {
+            'text': 'b',
+            'meta': {
+                'sentiment': 'happy',
+            }
+        }, {
+            'text': 'c',
+            'meta': {
+                'sentiment': 'sad',
+            }
+        }, {
+            'text': 'd',
+            'meta': {
+                'sentiment': 'angry',
+            }
+        }]
+        tgt_list = [{
+            'text': 'a',
+            'meta': {
+                'sentiment': 'happy',
+            }
+        }, {
+            'text': 'b',
+            'meta': {
+                'sentiment': 'happy',
+            }
+        }, {
+            'text': 'c',
+            'meta': {
+                'sentiment': 'sad',
+            }
+        }]
+        dataset = Dataset.from_list(ds_list)
+        op = TagsSpecifiedFieldSelector(
+            field_key='meta.sentiment',
+            target_tags=['happy', 'sad'])
+        self._run_tag_selector(dataset, tgt_list, op)
+
+
+if __name__ == '__main__':
+    unittest.main()