diff --git a/data_juicer/utils/unittest_utils.py b/data_juicer/utils/unittest_utils.py index 5806fa60b..d545dbac9 100644 --- a/data_juicer/utils/unittest_utils.py +++ b/data_juicer/utils/unittest_utils.py @@ -51,30 +51,31 @@ def tearDownClass(cls, hf_model_name=None) -> None: shutil.rmtree(transformers.TRANSFORMERS_CACHE) @classmethod - def generate_dataset(cls, data, type='hf'): + def generate_dataset(cls, data, type='standalone'): """Generate dataset for a specific executor. Args: type (str, optional): `hf` or `ray`. Defaults to "hf". """ - if type == 'hf': + if type.startswith('standalone'): return Dataset.from_list(data) - elif type == 'ray': + elif type.startswith('ray'): return rd.from_items(data) + else: + raise ValueError("Unsupported type") @classmethod - def run_single_op(cls, dataset, op, type='hf'): + def run_single_op(cls, dataset, op, type='standalone'): """Run operator in the specific executor.""" - if type == 'hf': + if type.startswith('standalone'): if isinstance(op, Filter) and Fields.stats not in dataset.features: - # TODO: - # this is a temp solution, - # only add stats when calling filter op dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) dataset = dataset.map(op.compute_stats) dataset = dataset.filter(op.process) dataset = dataset.select_columns(column_names=['text']) return dataset.to_list() - elif type == 'ray': - pass + elif type.startswith('ray'): + raise ValueError("Unsupported type") + else: + raise ValueError("Unsupported type") \ No newline at end of file diff --git a/tests/ops/filter/test_alphanumeric_filter.py b/tests/ops/filter/test_alphanumeric_filter.py index 1ef5ebe2a..c0a64ffa7 100644 --- a/tests/ops/filter/test_alphanumeric_filter.py +++ b/tests/ops/filter/test_alphanumeric_filter.py @@ -9,7 +9,7 @@ class AlphanumericFilterTest(DataJuicerTestCaseBase): - @TEST_TAG("single") + @TEST_TAG("standalone") def test_case(self): ds_list = [{ @@ -40,10 +40,10 @@ def test_case(self): }] dataset = DataJuicerTestCaseBase.generate_dataset(ds_list) op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9) - result = DataJuicerTestCaseBase.run_single_op(dataset, op) + result = DataJuicerTestCaseBase.run_single_op(dataset, op, AlphanumericFilterTest.current_tag,) self.assertEqual(result, tgt_list) - @TEST_TAG("single") + @TEST_TAG("standalone") def test_token_case(self): ds_list = [{ diff --git a/tests/run.py b/tests/run.py index 24402fa55..a2fbeeb1c 100644 --- a/tests/run.py +++ b/tests/run.py @@ -22,7 +22,6 @@ parser.add_argument('--tag', choices=["standalone", "standalone-gpu", "ray", "ray-gpu"], default="standalone", help="the tag of tests being run") -parser.add_argument('--list_tests', action='store_true', help='list all tests') parser.add_argument('--pattern', default='test_*.py', help='test file pattern') parser.add_argument('--test_dir', default='tests', @@ -36,7 +35,8 @@ def __init__(self, tag=None): self.tag = tag def loadTestsFromTestCase(self, testCaseClass): - + # set tag to testcase class + setattr(testCaseClass, 'current_tag', self.tag) test_names = self.getTestCaseNames(testCaseClass) loaded_suite = self.suiteClass() for test_name in test_names: @@ -46,7 +46,7 @@ def loadTestsFromTestCase(self, testCaseClass): loaded_suite.addTest(test_case) return loaded_suite -def gather_test_cases(test_dir, pattern, list_tests, tag): +def gather_test_cases(test_dir, pattern, tag): test_to_run = unittest.TestSuite() test_loader = TaggedTestLoader(tag) discover = test_loader.discover(test_dir, pattern=pattern, top_level_dir=None) @@ -57,8 +57,8 @@ def gather_test_cases(test_dir, pattern, list_tests, tag): for test_case in test_suite: if type(test_case) in SKIPPED_TESTS.modules.values(): continue - if list_tests: - logger.info(f'Add test case [{str(test_case)}]') + logger.info(f'Add test case [{test_case._testMethodName}]' + f' from {test_case.__class__.__name__}') test_to_run.addTest(test_case) return test_to_run @@ -66,11 +66,10 @@ def gather_test_cases(test_dir, pattern, list_tests, tag): def main(): runner = unittest.TextTestRunner() test_suite = gather_test_cases(os.path.abspath(args.test_dir), - args.pattern, args.list_tests, args.tag) - if not args.list_tests: - res = runner.run(test_suite) - if not res.wasSuccessful(): - exit(1) + args.pattern, args.tag) + res = runner.run(test_suite) + if not res.wasSuccessful(): + exit(1) if __name__ == '__main__':