Skip to content

Commit

Permalink
* change the thresholds
Browse files Browse the repository at this point in the history
  • Loading branch information
HYLcool committed Dec 30, 2024
1 parent 2025a2e commit 85e757d
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions tests/ops/filter/test_video_motion_score_raft_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@
from data_juicer.ops.filter.video_motion_score_raft_filter import \
VideoMotionScoreRaftFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS

# skip due to conflicts when run lazy_load in multiprocessing in librosa
# tests passed locally.
# @SKIPPED_TESTS.register_module()
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
class VideoMotionScoreRaftFilterTest(DataJuicerTestCaseBase):

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
Expand All @@ -26,7 +22,6 @@ def _run_helper(self, op, source_list, target_list, np=1):
column=[{}] * dataset.num_rows)
dataset = dataset.map(op.compute_stats, num_proc=np)
dataset = dataset.filter(op.process, num_proc=np)
print(dataset[Fields.stats])
dataset = dataset.select_columns(column_names=[op.video_key])
res_list = dataset.to_list()
self.assertEqual(res_list, target_list)
Expand Down Expand Up @@ -134,7 +129,7 @@ def test_middle(self):
'videos': [self.vid3_path]
}]
tgt_list = [{'videos': [self.vid2_path]}]
op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.5)
op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.2)
self._run_helper(op, ds_list, tgt_list)

def test_any(self):
Expand All @@ -151,7 +146,7 @@ def test_any(self):
'videos': [self.vid2_path, self.vid3_path]
}]
op = VideoMotionScoreRaftFilter(min_score=3,
max_score=10.5,
max_score=10.2,
any_or_all='any')
self._run_helper(op, ds_list, tgt_list)

Expand All @@ -165,7 +160,7 @@ def test_all(self):
}]
tgt_list = []
op = VideoMotionScoreRaftFilter(min_score=3,
max_score=10.5,
max_score=10.2,
any_or_all='all')
self._run_helper(op, ds_list, tgt_list)

Expand All @@ -181,7 +176,7 @@ def test_parallel(self):
'videos': [self.vid3_path]
}]
tgt_list = [{'videos': [self.vid2_path]}]
op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.5)
op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.2)
self._run_helper(op, ds_list, tgt_list, np=2)


Expand Down

0 comments on commit 85e757d

Please sign in to comment.