Skip to content

Commit

Permalink
Add hub checkpoint check
Browse files Browse the repository at this point in the history
  • Loading branch information
zhong-al committed Nov 19, 2024
1 parent a9f2deb commit e4dee39
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions tests/test_miniscene2behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class TestMiniscene2Behavior(unittest.TestCase):
def setUpClass(cls):
cls.local_checkpoint = "checkpoint_epoch_00075.pyth"

# run tracks_extractor
sys.argv = ["tracks_extractor.py",
"--video", "tests/detection_example/DJI_0068.mp4",
"--annotation", "tests/detection_example/DJI_0068.xml"]
tracks_extractor.main()

@classmethod
def download_model(cls):
if not os.path.exists(cls.local_checkpoint):
Expand All @@ -39,7 +45,7 @@ def download_model(cls):
with open(f"{cls.local_checkpoint}.zip", "wb") as f:
f.write(r.content)

# unzip model checkpoint
# Unzip model checkpoint
with zipfile.ZipFile(f"{cls.local_checkpoint}.zip", "r") as zip_ref:
zip_ref.extractall(".")

Expand All @@ -54,7 +60,8 @@ def tearDownClass(cls):
def setUp(self):
self.tool = "miniscene2behavior.py"
self.hub = "imageomics/x3d-kabr-kinetics"
self.checkpoint = "checkpoint_epoch_00075.pyth.zip"
self.checkpoint = "checkpoint_epoch_00075.pyth"
self.checkpoint_archive = "checkpoint_epoch_00075.pyth.zip"
self.miniscene = "mini-scenes/tests|detection_example|DJI_0068"
self.video = "DJI_0068"
self.config = "config.yml"
Expand All @@ -65,21 +72,38 @@ def tearDown(self):
# TODO: delete outputs
del_file(self.output)

def test_run(self):
# run tracks_extractor
sys.argv = ["tracks_extractor.py",
"--video", "tests/detection_example/DJI_0068.mp4",
"--annotation", "tests/detection_example/DJI_0068.xml"]
tracks_extractor.main()
def test_hub_checkpoint_archive(self):
# annotate mini-scenes
sys.argv = [self.tool,
"--hub", self.hub,
"--checkpoint", self.checkpoint_archive,
"--miniscene", self.miniscene,
"--video", self.video]
run()

@patch("kabr_tools.miniscene2behavior.create_model")
@patch("kabr_tools.miniscene2behavior.annotate_miniscene")
def test_hub_checkpoint(self, annotate_miniscene, create_model):
# annotate mini-scenes
sys.argv = [self.tool,
"--hub", self.hub,
"--checkpoint", self.checkpoint,
"--checkpoint", self.checkpoint_archive,
"--miniscene", self.miniscene,
"--video", self.video]

# patch create_model
create_model.return_value = (Mock(), Mock())
run()

# check arguments to create_model
config_path = create_model.call_args[0][0]
checkpoint_path = create_model.call_args[0][1]
download_folder = f"{checkpoint_path.rsplit("/", 1)[0]}/"
self.assertEqual(self.checkpoint,
checkpoint_path.replace(download_folder, ""))
self.assertEqual(self.config,
config_path.replace(download_folder, ""))

@patch('kabr_tools.utils.slowfast.utils.process_cv2_inputs')
@patch('kabr_tools.utils.slowfast.utils.cv2.VideoCapture')
def test_matching_tracks(self, video_capture, process_cv2_inputs):
Expand Down

0 comments on commit e4dee39

Please sign in to comment.