Skip to content

Commit

Permalink
change unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Apr 25, 2024
1 parent 74c3465 commit 16ce4a8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
34 changes: 34 additions & 0 deletions data_juicer/utils/unittest_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import shutil
import unittest
from datasets import Dataset
import ray.data as rd

from data_juicer.ops import Filter
from data_juicer.utils.registry import Registry
from data_juicer.utils.constant import Fields


SKIPPED_TESTS = Registry('SkippedTests')

Expand Down Expand Up @@ -32,3 +37,32 @@ def tearDownClass(cls, hf_model_name=None) -> None:
if os.path.exists(transformers.TRANSFORMERS_CACHE):
print('CLEAN all TRANSFORMERS_CACHE')
shutil.rmtree(transformers.TRANSFORMERS_CACHE)

@classmethod
def generate_dataset(cls, data, type="hf"):
    """Generate a dataset for a specific executor.

    Args:
        data: the items to build the dataset from. They are handed
            unchanged to `Dataset.from_list` (for `hf`) or to
            `ray.data.from_items` (for `ray`); presumably a list of
            dict samples such as `[{'text': ...}, ...]`.
        type (str, optional): `hf` or `ray`. Defaults to "hf".

    Returns:
        The dataset object created by the selected backend.

    Raises:
        ValueError: if `type` is neither `hf` nor `ray`.
    """
    # NOTE: the parameter name `type` shadows the builtin; it is kept
    # because it is part of the public signature used by callers.
    if type == "hf":
        return Dataset.from_list(data)
    if type == "ray":
        return rd.from_items(data)
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing error inside the test that called us.
    raise ValueError(
        f"Unsupported executor type {type!r}; expected 'hf' or 'ray'.")

@classmethod
def run_single_op(cls, dataset, op, type="hf"):
    """Run a single operator in the specific executor.

    For the `hf` executor the dataset is mapped with `op.compute_stats`,
    filtered with `op.process`, reduced to the `text` column and returned
    as a list.

    Args:
        dataset: the dataset to process; for `hf` it must provide
            `features`, `num_rows`, `add_column`, `map`, `filter`,
            `select_columns` and `to_list`.
        op: the operator to run. It must provide `compute_stats` and
            `process`.
        type (str, optional): `hf` or `ray`. Defaults to "hf".

    Returns:
        list: the remaining samples, each holding only the `text` column.

    Raises:
        NotImplementedError: if `type` is `ray`, which has no
            implementation here yet.
        ValueError: if `type` is neither `hf` nor `ray`.
    """
    # NOTE: the parameter name `type` shadows the builtin; it is kept
    # because it is part of the public signature used by callers.
    if type == "hf":
        if isinstance(op, Filter) and Fields.stats not in dataset.features:
            # TODO:
            # this is a temp solution,
            # only add stats when calling filter op
            dataset = dataset.add_column(name=Fields.stats,
                                         column=[{}] * dataset.num_rows)
        dataset = dataset.map(op.compute_stats)
        dataset = dataset.filter(op.process)
        dataset = dataset.select_columns(column_names=['text'])
        return dataset.to_list()
    if type == "ray":
        # Previously this branch was a bare `pass`, so callers received
        # None and failed later with a misleading assertion error.
        raise NotImplementedError(
            "run_single_op is not implemented for the 'ray' executor yet.")
    raise ValueError(
        f"Unsupported executor type {type!r}; expected 'hf' or 'ray'.")
10 changes: 6 additions & 4 deletions tests/ops/filter/test_alphanumeric_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def test_case(self):
}, {
'text': 'emoji表情测试下😊,😸31231\n'
}]
dataset = Dataset.from_list(ds_list)
dataset = DataJuicerTestCaseBase.generate_dataset(ds_list)
op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9)
self._run_alphanumeric_filter(dataset, tgt_list, op)
result = DataJuicerTestCaseBase.run_single_op(dataset, op)
self.assertEqual(result, tgt_list)

def test_token_case(self):

Expand All @@ -76,9 +77,10 @@ def test_token_case(self):
}, {
'text': 'Do you need a cup of coffee?'
}]
dataset = Dataset.from_list(ds_list)
dataset = DataJuicerTestCaseBase.generate_dataset(ds_list)
op = AlphanumericFilter(tokenization=True, min_ratio=1.5)
self._run_alphanumeric_filter(dataset, tgt_list, op)
result = DataJuicerTestCaseBase.run_single_op(dataset, op)
self.assertEqual(result, tgt_list)


if __name__ == '__main__':
Expand Down

0 comments on commit 16ce4a8

Please sign in to comment.