Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Chinese distill-whisper fine-tuning results #1648

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions egs/multi_zh-hans/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

Character Error Rates (CERs) listed below are produced by the checkpoint of the second epoch using greedy search.

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
| Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
|Model| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech |
|-|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|-------------------|
| | Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | test meeting |
|whisper-large-v2-ft |Greedy Search | 23.22 | 28.24 | 0.61 | 0.66 | 2.67 | 2.80 | 16.61 | 2.56 | 2.21 | 4.73 | 1.90 | 5.98 | 8.13 |
|whisper-large-v2-ft-distill |Greedy Search | 24.91 | 26.73 | 0.91 | 0.94 | 2.71 | 2.98 | 17.65 | 2.81 | 2.47 | 5.16 | 2.10 | 6.27 | 8.34 |

Command for training is:
```bash
Expand Down
11 changes: 7 additions & 4 deletions egs/speechio/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
| 7 | aispeech_api_zh | 3.62% | 2023.12 |
| 8 | **whisper-large-ft-v1** | **4.32%** | 2024.04 |
| 9 | **whisper-large-ft-v0.5** | **4.60%** | 2024.04 |
| 10 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
| 11 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 12 | baidu_pro_api_zh | 7.29% | 2023.12 |
| 10 | **whisper-large-ft-v1-distill** | **4.71%** | 2024.04 |
| 11 | **zipformer (70Mb)** | **6.17%** | 2023.10 |
| 12 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 13 | baidu_pro_api_zh | 7.29% | 2023.12 |

Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)

For **whisper-large-ft-v1-distill**, instead of actually using distillation loss for training, the model structure and parameter initialization method from the [distill-whisper](https://arxiv.org/abs/2311.00430) paper were adopted: only the first and last layers of the decoder were retained.

<details><summary> Detail all models </summary><p>

| Model | Training Set | Note |
Expand All @@ -31,7 +34,7 @@ Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leade
|[whisper-large-ft-v0](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/tree/main/exp_large_v2)| wenetspeech | greedy_search, 3 epochs|
|[whisper-large-ft-v0.5](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper/blob/main/epoch-2-avg-5.pt)| wenetspeech(updated) | [wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy_search, 2 epochs |
|[whisper-large-ft-v1](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1.1)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 3 epochs|

|[whisper-large-ft-v1-distill](https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper/tree/main/v1-distill)|wenetspeech(updated), other multi-hans-zh exclude datatang 200h|[wenetspeech update method](https://github.com/k2-fsa/icefall/blob/master/egs/wenetspeech/ASR/local/fix_manifest.py), greedy search, 6 epochs|
</details>


Expand Down
12 changes: 11 additions & 1 deletion egs/speechio/ASR/whisper/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

Expand Down Expand Up @@ -215,7 +216,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "base", "small", "tiny"],
help="""The model name to use.
""",
)
Expand All @@ -227,6 +228,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)

parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)

return parser


Expand Down Expand Up @@ -431,6 +439,8 @@ def main():

if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:
Expand Down
Loading