Skip to content

Commit

Permalink
test done
Browse files Browse the repository at this point in the history
  • Loading branch information
BeachWang committed Dec 31, 2024
1 parent 8e01f7e commit af9e14d
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 34 deletions.
11 changes: 6 additions & 5 deletions data_juicer/ops/mapper/video_extract_frames_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@

from pydantic import PositiveInt

from data_juicer.utils.constant import Fields
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.file_utils import dict_to_hash
from data_juicer.utils.mm_utils import (
SpecialTokens, close_video, extract_key_frames,
extract_key_frames_by_seconds, extract_video_frames_uniformly,
extract_video_frames_uniformly_by_seconds, load_data_with_context,
load_video)

from ..base_op import OPERATORS, Mapper
from ..base_op import OPERATORS, TAGGING_OPS, Mapper
from ..op_fusion import LOADED_VIDEOS

OP_NAME = 'video_extract_frames_mapper'


@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoExtractFramesMapper(Mapper):
Expand All @@ -41,7 +42,7 @@ def __init__(
frame_num: PositiveInt = 3,
duration: float = 0,
frame_dir: str = None,
frame_key=Fields.video_frames,
frame_key=MetaKeys.video_frames,
*args,
**kwargs,
):
Expand Down Expand Up @@ -103,7 +104,7 @@ def _get_default_frame_dir(self, original_filepath):

def process_single(self, sample, context=False):
# check if it's generated already
if self.frame_key in sample:
if self.frame_key in sample[Fields.meta]:
return sample

# there is no videos in this sample
Expand Down Expand Up @@ -168,6 +169,6 @@ def process_single(self, sample, context=False):
for vid_key in videos:
close_video(videos[vid_key])

sample[self.frame_key] = json.dumps(video_to_frame_dir)
sample[Fields.meta][self.frame_key] = json.dumps(video_to_frame_dir)

return sample
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class MetaKeys(object):
video_frame_tags = 'video_frame_tags'
# # video-audio tags
video_audio_tags = 'video_audio_tags'
# # video frames
video_frames = 'video_frames'
# # image tags
image_tags = 'image_tags'

Expand Down
15 changes: 9 additions & 6 deletions tests/ops/grouper/test_naive_reverse_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.grouper.naive_reverse_grouper import NaiveReverseGrouper
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
from data_juicer.utils.constant import Fields


class NaiveReverseGrouperTest(DataJuicerTestCaseBase):
Expand Down Expand Up @@ -87,7 +88,7 @@ def test_rm_unbatched_keys1(self):
"Sur la plateforme MT4, plusieurs manières d'accéder à \n"
'ces fonctionnalités sont conçues simultanément.'
],
'batch_size': 2,
Fields.agg: {'batch_size': 2},
}
]

Expand All @@ -114,11 +115,13 @@ def test_rm_unbatched_keys2(self):
'query':[
'Can I help you?'
],
'reponse':[
'No',
'Yes'
],
'batch_size': 1,
Fields.agg: {
'reponse':[
'No',
'Yes'
],
'batch_size': 1,
}
}
]

Expand Down
54 changes: 31 additions & 23 deletions tests/ops/mapper/test_video_extract_frames_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.video_extract_frames_mapper import \
VideoExtractFramesMapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase

Expand Down Expand Up @@ -70,18 +70,20 @@ def test_duration(self):
vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir)

tgt_list = copy.deepcopy(ds_list)
tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})})
tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})})
tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})})
tgt_list[0].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}})
tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}})
tgt_list[2].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}})

op = VideoExtractFramesMapper(
frame_sampling_method='uniform',
frame_num=frame_num,
duration=0,
frame_dir=frame_dir)
frame_dir=frame_dir,
batch_size=2,
num_proc=1)

dataset = Dataset.from_list(ds_list)
dataset = dataset.map(op.process, batch_size=2, num_proc=1)
dataset = op.run(dataset)
res_list = dataset.to_list()
self.assertEqual(res_list, tgt_list)
self.assertListEqual(
Expand Down Expand Up @@ -114,18 +116,20 @@ def test_uniform_sampling(self):
vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir)

tgt_list = copy.deepcopy(ds_list)
tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})})
tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})})
tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})})
tgt_list[0].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}})
tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}})
tgt_list[2].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}})

op = VideoExtractFramesMapper(
frame_sampling_method='uniform',
frame_num=frame_num,
duration=10,
frame_dir=frame_dir)
frame_dir=frame_dir,
batch_size=2,
num_proc=1)

dataset = Dataset.from_list(ds_list)
dataset = dataset.map(op.process, batch_size=2, num_proc=1)
dataset = op.run(dataset)
res_list = dataset.to_list()
self.assertEqual(res_list, tgt_list)
self.assertListEqual(
Expand Down Expand Up @@ -158,22 +162,24 @@ def test_all_keyframes_sampling(self):
vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir)

tgt_list = copy.deepcopy(ds_list)
tgt_list[0].update({Fields.video_frames:
json.dumps({self.vid1_path: vid1_frame_dir})})
tgt_list[1].update({Fields.video_frames: json.dumps({
tgt_list[0].update({Fields.meta: {MetaKeys.video_frames:
json.dumps({self.vid1_path: vid1_frame_dir})}})
tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({
self.vid2_path: vid2_frame_dir,
self.vid3_path: vid3_frame_dir
})})
tgt_list[2].update({Fields.video_frames:
json.dumps({self.vid3_path: vid3_frame_dir})})
})}})
tgt_list[2].update({Fields.meta: {MetaKeys.video_frames:
json.dumps({self.vid3_path: vid3_frame_dir})}})

op = VideoExtractFramesMapper(
frame_sampling_method='all_keyframes',
frame_dir=frame_dir,
duration=5)
duration=5,
batch_size=2,
num_proc=2)

dataset = Dataset.from_list(ds_list)
dataset = dataset.map(op.process, batch_size=2, num_proc=2)
dataset = op.run(dataset)
res_list = dataset.to_list()
self.assertEqual(res_list, tgt_list)
self.assertListEqual(
Expand Down Expand Up @@ -205,19 +211,21 @@ def test_default_frame_dir(self):
frame_sampling_method='uniform',
frame_num=frame_num,
duration=5,
batch_size=2,
num_proc=1
)

vid1_frame_dir = op._get_default_frame_dir(self.vid1_path)
vid2_frame_dir = op._get_default_frame_dir(self.vid2_path)
vid3_frame_dir = op._get_default_frame_dir(self.vid3_path)

tgt_list = copy.deepcopy(ds_list)
tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})})
tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})})
tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})})
tgt_list[0].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}})
tgt_list[1].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}})
tgt_list[2].update({Fields.meta: {MetaKeys.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}})

dataset = Dataset.from_list(ds_list)
dataset = dataset.map(op.process, batch_size=2, num_proc=1)
dataset = op.run(dataset)
res_list = dataset.to_list()

frame_dir_prefix = self._get_default_frame_dir_prefix()
Expand Down

0 comments on commit af9e14d

Please sign in to comment.