From e4dee39c09a36214328e20b12ad45679fdc2dee6 Mon Sep 17 00:00:00 2001 From: zhong-al <74470739+zhong-al@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:55:12 -0500 Subject: [PATCH] Add hub checkpoint check --- tests/test_miniscene2behavior.py | 42 +++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/tests/test_miniscene2behavior.py b/tests/test_miniscene2behavior.py index 93b5cd5..ba57960 100644 --- a/tests/test_miniscene2behavior.py +++ b/tests/test_miniscene2behavior.py @@ -29,6 +29,12 @@ class TestMiniscene2Behavior(unittest.TestCase): def setUpClass(cls): cls.local_checkpoint = "checkpoint_epoch_00075.pyth" + # run tracks_extractor + sys.argv = ["tracks_extractor.py", + "--video", "tests/detection_example/DJI_0068.mp4", + "--annotation", "tests/detection_example/DJI_0068.xml"] + tracks_extractor.main() + @classmethod def download_model(cls): if not os.path.exists(cls.local_checkpoint): @@ -39,7 +45,7 @@ def download_model(cls): with open(f"{cls.local_checkpoint}.zip", "wb") as f: f.write(r.content) - # unzip model checkpoint + # Unzip model checkpoint with zipfile.ZipFile(f"{cls.local_checkpoint}.zip", "r") as zip_ref: zip_ref.extractall(".") @@ -54,7 +60,8 @@ def tearDownClass(cls): def setUp(self): self.tool = "miniscene2behavior.py" self.hub = "imageomics/x3d-kabr-kinetics" - self.checkpoint = "checkpoint_epoch_00075.pyth.zip" + self.checkpoint = "checkpoint_epoch_00075.pyth" + self.checkpoint_archive = "checkpoint_epoch_00075.pyth.zip" self.miniscene = "mini-scenes/tests|detection_example|DJI_0068" self.video = "DJI_0068" self.config = "config.yml" @@ -65,21 +72,38 @@ def tearDown(self): # TODO: delete outputs del_file(self.output) - def test_run(self): - # run tracks_extractor - sys.argv = ["tracks_extractor.py", - "--video", "tests/detection_example/DJI_0068.mp4", - "--annotation", "tests/detection_example/DJI_0068.xml"] - tracks_extractor.main() + def test_hub_checkpoint_archive(self): + # annotate mini-scenes + sys.argv = [self.tool, + "--hub", self.hub, + "--checkpoint", self.checkpoint_archive, + "--miniscene", self.miniscene, + "--video", self.video] + run() + @patch("kabr_tools.miniscene2behavior.create_model") + @patch("kabr_tools.miniscene2behavior.annotate_miniscene") + def test_hub_checkpoint(self, annotate_miniscene, create_model): # annotate mini-scenes sys.argv = [self.tool, "--hub", self.hub, - "--checkpoint", self.checkpoint, + "--checkpoint", self.checkpoint_archive, "--miniscene", self.miniscene, "--video", self.video] + + # patch create_model + create_model.return_value = (Mock(), Mock()) run() + # check arguments to create_model + config_path = create_model.call_args[0][0] + checkpoint_path = create_model.call_args[0][1] + download_folder = f"{checkpoint_path.rsplit("/", 1)[0]}/" + self.assertEqual(self.checkpoint, + checkpoint_path.replace(download_folder, "")) + self.assertEqual(self.config, + config_path.replace(download_folder, "")) + @patch('kabr_tools.utils.slowfast.utils.process_cv2_inputs') @patch('kabr_tools.utils.slowfast.utils.cv2.VideoCapture') def test_matching_tracks(self, video_capture, process_cv2_inputs):