Commit

update speechio whisper ft results
yuekaizhang committed Apr 24, 2024
1 parent df36f93 commit b970ba5
Showing 13 changed files with 1,629 additions and 213 deletions.
43 changes: 43 additions & 0 deletions egs/multi_zh-hans/ASR/RESULTS.md
@@ -1,5 +1,48 @@
## Results

### Multi Chinese datasets (without datatang 200h) finetuning results on Whisper-large-v2
#### Whisper
[./whisper](./whisper)

The Character Error Rates (CERs) listed below were produced with the epoch-2 checkpoint using greedy search.

| Datasets | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech |
|--------------------------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------|
| Split | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net |
| Greedy Search | 23.45 | 25.42 | 0.78 | 0.83 | 2.75 | 2.93 | 17.11 | 2.68 | 2.33 | 4.97 | 2.02 | 6.34 | 5.06 | 8.38 | 6.94 |
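
For reference, a CER value like those in the table is the character-level edit distance between hypothesis and reference, divided by the reference length. The sketch below only illustrates that definition; the function name and example strings are ours, and the actual scoring is done inside `./whisper/decode.py`.

```python
def character_error_rate(ref: str, hyp: str) -> float:
    """Levenshtein distance over characters, normalized by the reference length."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,             # delete a reference character
                cur[j - 1] + 1,          # insert a hypothesis character
                prev[j - 1] + (r != h),  # substitute (or match for free)
            )
        prev = cur
    return prev[-1] / max(len(ref), 1)


# One deleted character out of six -> CER of about 16.7%
print(character_error_rate("今天天气很好", "今天天很好"))
```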

Command for training is:
```bash
pip install -r whisper/requirements.txt

# We updated the label of wenetspeech to remove OCR deletion errors, see https://github.com/wenet-e2e/WenetSpeech/discussions/54

torchrun --nproc-per-node 8 ./whisper/train.py \
--max-duration 200 \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--deepspeed \
--deepspeed_config ./whisper/ds_config_zero1.json
```
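
The `--deepspeed` and `--deepspeed_config` flags hand the model over to DeepSpeed (the file name suggests a ZeRO stage-1 config). The following is only a rough sketch of that wiring, not the recipe's actual `train.py`, and it assumes the script runs under `torchrun` or the `deepspeed` launcher so the distributed environment is already set up.

```python
import deepspeed
import whisper

# Hedged sketch: wrap Whisper in a DeepSpeed engine configured from the JSON file,
# which is roughly what --deepspeed / --deepspeed_config enable in ./whisper/train.py.
model = whisper.load_model("large-v2", "cpu")
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="./whisper/ds_config_zero1.json",
)

# A training step then goes through the engine instead of plain PyTorch calls:
#   model_engine.backward(loss)
#   model_engine.step()
```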

Command for decoding using fine-tuned models:
```bash
git lfs install
git clone https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper
ln -s icefall_asr_multi-hans-zh_whisper/v1/epoch-2-avg4.pt whisper/exp_large_v2/epoch-999.pt

python3 ./whisper/decode.py \
--exp-dir whisper/exp_large_v2 \
--model-name large-v2 \
--epoch 999 --avg 1 \
--beam-size 10 --max-duration 50
```

Fine-tuned models, training logs, decoding logs, TensorBoard logs, and decoding results
are available at
<https://huggingface.co/yuekai/icefall_asr_multi-hans-zh_whisper>
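
The symlink above exposes the already-averaged checkpoint `epoch-2-avg4.pt` under the name `epoch-999.pt`, so `--epoch 999 --avg 1` simply makes `decode.py` load that single file with no further averaging. In rough terms it amounts to the sketch below; treating the checkpoint as either a bare state dict or one stored under a `"model"` key is our assumption about the file layout.

```python
import torch
import whisper

# Hedged sketch of what --epoch 999 --avg 1 boils down to:
# load large-v2 and overwrite its weights with the fine-tuned checkpoint.
model = whisper.load_model("large-v2", "cpu")
state = torch.load("whisper/exp_large_v2/epoch-999.pt", map_location="cpu")
state_dict = state.get("model", state)  # assumption: weights may sit under a "model" key
model.load_state_dict(state_dict, strict=True)
model.eval()
```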


### Multi Chinese datasets char-based training results (Non-streaming) on zipformer model

This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall.
13 changes: 2 additions & 11 deletions egs/multi_zh-hans/ASR/prepare.sh
@@ -226,8 +226,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare WenetSpeech"
if [ -e ../../wenetspeech/ASR/data/fbank/.preprocess_complete ]; then
cd data/fbank
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_DEV_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_L_fixed.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) .
ln -svf $(realpath ../../../../wenetspeech/ASR/data/fbank/cuts_TEST_NET.jsonl.gz) .

@@ -299,15 +299,6 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Compute KeSpeech fbank for test/dev"
./local/compute_fbank_kespeech_dev_test.py

if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz ]; then
pieces=$(find data/fbank/kespeech/train_phase1_split_${num_splits} -name "kespeech-asr_cuts_train_phase1.*.jsonl.gz")
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
fi
if [ ! -f data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz ]; then
pieces=$(find data/fbank/kespeech/train_phase2_split_${num_splits} -name "kespeech-asr_cuts_train_phase2.*.jsonl.gz")
lhotse combine $pieces data/fbank/kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
fi

touch data/fbank/.kespeech.done
fi
fi
48 changes: 46 additions & 2 deletions egs/multi_zh-hans/ASR/whisper/decode.py
100644 → 100755
@@ -58,6 +58,7 @@
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from zhconv import convert

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
@@ -214,7 +215,7 @@ def get_parser():
"--model-name",
type=str,
default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "base", "tiny"],
choices=["large-v2", "large-v3", "medium", "small", "tiny"],
help="""The model name to use.
""",
)
@@ -226,6 +227,13 @@ def get_parser():
help="replace whisper encoder forward method to remove input length restriction",
)

parser.add_argument(
"--use-distill-whisper",
type=str2bool,
default=False,
help="Whether to use architecture of distill whisper.",
)

return parser


@@ -289,7 +297,6 @@ def decode_one_batch(
print(hyps)
return {"beam-search": hyps}


def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
@@ -307,6 +314,40 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("A", "A")
text = text.replace("a", "A")
text = text.replace("b", "B")
text = text.replace("c", "C")
text = text.replace("k", "K")
text = text.replace("t", "T")
text = text.replace(",", "")
text = text.replace("丶", "")
text = text.replace("。", "")
text = text.replace("、", "")
text = text.replace("?", "")
return text
results = []

num_cuts = 0
@@ -331,6 +372,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text)
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))

@@ -430,6 +472,8 @@ def main():

if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
if params.use_distill_whisper:
replace_whisper_decoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
if params.avg > 1:
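
For context on the two `replace_*` calls above: both monkey-patch modules follow the same pattern of rebinding a method on an openai-whisper class before `whisper.load_model()` is called. The sketch below shows only the pattern; the replacement body is a placeholder, not the recipe's implementation in `whisper_decoder_forward_monkey_patch.py` (the encoder patch rebinds `whisper.model.AudioEncoder.forward` in the same way).

```python
import whisper.model


def custom_decoder_forward(self, x, xa, kv_cache=None):
    # Placeholder body: a distil-whisper-style decoder forward would go here.
    raise NotImplementedError


def replace_whisper_decoder_forward():
    # Rebind the class method so every model created afterwards by
    # whisper.load_model() picks up the custom forward.
    whisper.model.TextDecoder.forward = custom_decoder_forward
```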
93 changes: 22 additions & 71 deletions egs/multi_zh-hans/ASR/whisper/multi_dataset.py
100644 → 100755
@@ -43,7 +43,7 @@ def __init__(self, fbank_dir: str):
- thchs_30_cuts_train.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz
- kespeech/kespeech-asr_cuts_train_phase2.jsonl.gz
- wenetspeech/cuts_L.jsonl.gz
- wenetspeech/cuts_L_fixed.jsonl.gz
"""
self.fbank_dir = Path(fbank_dir)

@@ -105,7 +105,7 @@ def train_cuts(self) -> CutSet:
# WeNetSpeech
logging.info("Loading WeNetSpeech in lazy mode")
wenetspeech_L_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_L.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_L_fixed.jsonl.gz"
)

# KeSpeech
@@ -124,10 +124,10 @@ def train_cuts(self) -> CutSet:
aishell_4_L_cuts,
aishell_4_M_cuts,
aishell_4_S_cuts,
alimeeting_cuts,
stcmds_cuts,
primewords_cuts,
magicdata_cuts,
alimeeting_cuts,
wenetspeech_L_cuts,
kespeech_1_cuts,
kespeech_2_cuts,
@@ -138,10 +138,10 @@ def train_cuts(self) -> CutSet:
len(aishell_4_L_cuts),
len(aishell_4_M_cuts),
len(aishell_4_S_cuts),
len(alimeeting_cuts),
len(stcmds_cuts),
len(primewords_cuts),
len(magicdata_cuts),
len(alimeeting_cuts),
len(wenetspeech_L_cuts),
len(kespeech_1_cuts),
len(kespeech_2_cuts),
@@ -151,55 +151,13 @@ def train_cuts(self) -> CutSet:
def dev_cuts(self) -> CutSet:
logging.info("About to get multidataset dev cuts")

# AISHELL
logging.info("Loading Aishell DEV set in lazy mode")
aishell_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell_cuts_dev.jsonl.gz"
)

# AISHELL-2
logging.info("Loading Aishell-2 DEV set in lazy mode")
aishell2_dev_cuts = load_manifest_lazy(
self.fbank_dir / "aishell2_cuts_dev.jsonl.gz"
)

# Ali-Meeting
logging.info("Loading Ali-Meeting DEV set in lazy mode")
alimeeting_dev_cuts = load_manifest_lazy(
self.fbank_dir / "alimeeting-far_cuts_eval.jsonl.gz"
)

# MagicData
logging.info("Loading MagicData DEV set in lazy mode")
magicdata_dev_cuts = load_manifest_lazy(
self.fbank_dir / "magicdata_cuts_dev.jsonl.gz"
)

# KeSpeech
logging.info("Loading KeSpeech DEV set in lazy mode")
kespeech_dev_phase1_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase1.jsonl.gz"
)
kespeech_dev_phase2_cuts = load_manifest_lazy(
self.fbank_dir / "kespeech" / "kespeech-asr_cuts_dev_phase2.jsonl.gz"
)

# WeNetSpeech
logging.info("Loading WeNetSpeech DEV set in lazy mode")
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)

return wenetspeech_dev_cuts
# return [
# aishell_dev_cuts,
# aishell2_dev_cuts,
# alimeeting_dev_cuts,
# magicdata_dev_cuts,
# kespeech_dev_phase1_cuts,
# kespeech_dev_phase2_cuts,
# wenetspeech_dev_cuts,
# ]

def test_cuts(self) -> Dict[str, CutSet]:
logging.info("About to get multidataset test cuts")
@@ -267,30 +225,23 @@ def test_cuts(self) -> Dict[str, CutSet]:
self.fbank_dir / "wenetspeech" / "cuts_TEST_NET.jsonl.gz"
)
wenetspeech_dev_cuts = load_manifest_lazy(
self.fbank_dir / "wenetspeech" / "cuts_DEV.jsonl.gz"
self.fbank_dir / "wenetspeech" / "cuts_DEV_fixed.jsonl.gz"
)

return {
"aishell-2_test": aishell2_test_cuts,
"aishell-4": aishell4_test_cuts,
"magicdata_test": magicdata_test_cuts,
"kespeech-asr_test": kespeech_test_cuts,
}

# return {
# "alimeeting_test": alimeeting_test_cuts,
# "alimeeting_eval": alimeeting_eval_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "aishell-4": aishell4_test_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
# }
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
# "aishell_test": aishell_test_cuts,
# "aishell_dev": aishell_dev_cuts,
# "ali-meeting_test": alimeeting_test_cuts,
# "ali-meeting_eval": alimeeting_eval_cuts,
# "aishell-4_test": aishell4_test_cuts,
# "aishell-2_test": aishell2_test_cuts,
# "aishell-2_dev": aishell2_dev_cuts,
# "magicdata_test": magicdata_test_cuts,
# "magicdata_dev": magicdata_dev_cuts,
# "kespeech-asr_test": kespeech_test_cuts,
# "kespeech-asr_dev_phase1": kespeech_dev_phase1_cuts,
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
}
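
The corpora loaded in `multi_dataset.py` above are combined by code outside this diff. As a hedged sketch, assuming the unchanged code uses lhotse's `CutSet.mux` with size-proportional weights (which the `weights` lists in `train_cuts()` suggest), the combination looks roughly like the following; the paths are examples taken from the class docstring.

```python
from lhotse import CutSet, load_manifest_lazy

# Hedged sketch: interleave two of the training manifests the way the lists in
# train_cuts() suggest. The CutSet.mux call is an assumption about the unchanged
# code surrounding this diff, not something shown in it.
wenetspeech_L_cuts = load_manifest_lazy(
    "data/fbank/wenetspeech/cuts_L_fixed.jsonl.gz"
)
kespeech_1_cuts = load_manifest_lazy(
    "data/fbank/kespeech/kespeech-asr_cuts_train_phase1.jsonl.gz"
)

train_cuts = CutSet.mux(
    wenetspeech_L_cuts,
    kespeech_1_cuts,
    weights=[len(wenetspeech_L_cuts), len(kespeech_1_cuts)],
)
```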