From 23a1c9fba282b5ae909c16ce7e7ecc169d7bfde2 Mon Sep 17 00:00:00 2001 From: awkrail Date: Mon, 7 Oct 2024 14:56:49 +0900 Subject: [PATCH] fix evaluate.py --- README.md | 2 +- training/evaluate.py | 31 +++++++++++++++++++------------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 439d104..8dd2bc9 100755 --- a/README.md +++ b/README.md @@ -228,7 +228,7 @@ python training/train.py --model moment_detr --dataset tvsum --feature clip_slow #### Evaluation The evaluation command is: ``` -python training/evaluate.py --model MODEL --dataset DATASET --feature FEATURE --split {val,test} --model_path MODEL_PATH --eval_path EVAL_PATH +python training/evaluate.py --model MODEL --dataset DATASET --feature FEATURE --split {val,test} --model_path MODEL_PATH --eval_path EVAL_PATH [--domain DOMAIN] ``` (**Example 1**) Evaluating Moment DETR w/ CLIP+Slowfast on the QVHighlights val set: ``` diff --git a/training/evaluate.py b/training/evaluate.py index 5d58adc..8265257 100755 --- a/training/evaluate.py +++ b/training/evaluate.py @@ -392,7 +392,7 @@ def start_inference(opt, domain=None): logger.info("metrics_no_nms {}".format(pprint.pformat(metrics["brief"], indent=4))) -def check_valid_combination(dataset, feature): +def check_valid_combination(dataset, feature, domain): dataset_feature_map = { 'qvhighlight': ['resnet_glove', 'clip', 'clip_slowfast', 'clip_slowfast_pann'], 'qvhighlight_pretrain': ['resnet_glove', 'clip', 'clip_slowfast', 'clip_slowfast_pann'], @@ -403,7 +403,16 @@ def check_valid_combination(dataset, feature): 'youtube_highlight': ['clip', 'clip_slowfast'], 'clotho-moment': ['clap'], } - return feature in dataset_feature_map[dataset] + + domain_map = { + 'tvsum': ['BK', 'BT', 'DS', 'FM', 'GA', 'MS', 'PK', 'PR', 'VT', 'VU'], + 'youtube_highlight': ['dog', 'gymnastics', 'parkour', 'skating', 'skiing', 'surfing'], + } + + if dataset in domain_map: + return feature in dataset_feature_map[dataset] and domain in domain_map[dataset] + else: + return feature in dataset_feature_map[dataset] if __name__ == '__main__': @@ -421,12 +430,16 @@ def check_valid_combination(dataset, feature): parser.add_argument('--model_path', type=str, required=True, help='saved model path') parser.add_argument('--split', type=str, required=True, choices=['val', 'test'], help='val or test') parser.add_argument('--eval_path', type=str, required=True, help='evaluation data') - args = parser.parse_args() + parser.add_argument('--domain', '-dm', type=str, + choices=['BK', 'BT', 'DS', 'FM', 'GA', 'MS', 'PK', 'PR', 'VT', 'VU', + 'dog', 'gymnastics', 'parkour', 'skating', 'skiing', 'surfing'], + help='domain for highlight detection dataset (e.g., BK for TVSum, dog for YouTube Highlight).') - is_valid = check_valid_combination(args.dataset, args.feature) + args = parser.parse_args() + is_valid = check_valid_combination(args.dataset, args.feature, args.domain) if is_valid: - option_manager = BaseOptions(args.model, args.dataset, args.feature) + option_manager = BaseOptions(args.model, args.dataset, args.feature, args.domain) option_manager.parse() opt = option_manager.option os.makedirs(opt.results_dir, exist_ok=True) @@ -434,13 +447,7 @@ def check_valid_combination(dataset, feature): opt.model_path = args.model_path opt.eval_split_name = args.split opt.eval_path = args.eval_path - - if 'domains' in opt: - for domain in opt.domains: - opt.results_dir = os.path.join(opt.results_dir, domain) - start_inference(opt, domain=domain) - else: - start_inference(opt) + start_inference(opt, domain=args.domain) else: raise ValueError('The combination of dataset and feature is invalid: dataset={}, feature={}'.format(args.dataset, args.feature))