From 0936c80be758725ec23adc0279f64985157f8230 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 12 Jun 2024 08:43:05 +0000 Subject: [PATCH] add distill whisper results --- egs/multi_zh-hans/ASR/RESULTS.md | 9 +++++---- egs/speechio/ASR/RESULTS.md | 11 +++++++---- egs/speechio/ASR/whisper/decode.py | 12 +++++++++++- .../whisper/whisper_decoder_forward_monkey_patch.py | 1 + 4 files changed, 24 insertions(+), 9 deletions(-) create mode 120000 egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index a7f3bc4f79..e411e80a37 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -6,10 +6,11 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search. -| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | -|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| -| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | -| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | +|Model| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | +|-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------| +| | Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting | +|whisper-large-v2-ft |Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 | +|whisper-large-v2-ft-distill |Greedy Search | 24.91 | 26.73 | 0.91 | 0.94 | 2.71 | 2.98 | 17.65 | 2.81 | 2.47 | 5.16 | 2.10 | 6.27 | 8.34 | Command for training is: ```bash diff --git a/egs/speechio/ASR/RESULTS.md b/egs/speechio/ASR/RESULTS.md index f1273d41e7..3c556f74ed 100644 --- a/egs/speechio/ASR/RESULTS.md +++ b/egs/speechio/ASR/RESULTS.md @@ -17,12 +17,15 @@ | 7 | aispeech_api_zh | 3.62% | 2023.12 | | 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 | | 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 | -| 10 | **zipformer (70Mb)** | **6.17%** | 2023.10 | -| 11 | **whisper-large-ft-v0** | **6.34%** | 2023.03 | -| 12 | baidu_pro_api_zh | 7.29% | 2023.12 | +| 10 | **whisper-large-ft-v1-distill** | **4.71%** | 2024.04 | +| 11 | **zipformer (70Mb)** | **6.17%** | 2023.10 | +| 12 | **whisper-large-ft-v0** | **6.34%** | 2023.03 | +| 13 | baidu_pro_api_zh | 7.29% | 2023.12 | Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67) +For **whisper-large-ft-v1-distill**, instead of actually using distillation loss for training, the model structure and parameter initialization method from the [distill-whisper](https://arxiv.org/abs/2311.00430) paper were adopted: only the first and last layers of the decoder were retained. +
Detail all models

| Model | Training Set | Note | @@ -31,7 +34,7 @@ Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leade |[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs| |[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs | |[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs| - +|[whisper-large-ft-v1-distill](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1-distill)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 6 epochs|

diff --git a/egs/speechio/ASR/whisper/decode.py b/egs/speechio/ASR/whisper/decode.py index 70f743eeec..c20f1f7149 100644 --- a/egs/speechio/ASR/whisper/decode.py +++ b/egs/speechio/ASR/whisper/decode.py @@ -58,6 +58,7 @@ from multi_dataset import MultiDataset from tn.chinese.normalizer import Normalizer from whisper.normalizers import BasicTextNormalizer +from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from zhconv import convert @@ -215,7 +216,7 @@ def get_parser(): "--model-name", type=str, default="large-v2", - choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"], + choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"], help="""The model name to use. """, ) @@ -227,6 +228,13 @@ def get_parser(): help="replace whisper encoder forward method to remove input length restriction", ) + parser.add_argument( + "--use-distill-whisper", + type=str2bool, + default=False, + help="Whether to use architecture of distill whisper.", + ) + return parser @@ -431,6 +439,8 @@ def main(): if params.remove_whisper_encoder_input_length_restriction: replace_whisper_encoder_forward() + if params.use_distill_whisper: + replace_whisper_decoder_forward() model = whisper.load_model(params.model_name, "cpu") if params.epoch > 0: if params.avg > 1: diff --git a/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py b/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py new file mode 120000 index 0000000000..167fba1eb4 --- /dev/null +++ b/egs/speechio/ASR/whisper/whisper_decoder_forward_monkey_patch.py @@ -0,0 +1 @@ +../../../multi_zh-hans/ASR/whisper/whisper_decoder_forward_monkey_patch.py \ No newline at end of file