Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
yuekaizhang committed Apr 24, 2024
1 parent b970ba5 commit 838bf22
Show file tree
Hide file tree
Showing 11 changed files with 740 additions and 561 deletions.
6 changes: 5 additions & 1 deletion egs/multi_zh-hans/ASR/whisper/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
from multi_dataset import MultiDataset
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
Expand Down Expand Up @@ -297,6 +297,7 @@ def decode_one_batch(
print(hyps)
return {"beam-search": hyps}


def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
Expand All @@ -314,6 +315,7 @@ def decode_dataset(
Returns:
Return a dict, whose key may be "beam-search".
"""

def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
Expand All @@ -323,6 +325,7 @@ def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
return text
elif normalize == "m2met":
import re

text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
Expand All @@ -348,6 +351,7 @@ def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
text = text.replace("、", "")
text = text.replace("?", "")
return text

results = []

num_cuts = 0
Expand Down
2 changes: 1 addition & 1 deletion egs/multi_zh-hans/ASR/whisper/multi_dataset.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,4 @@ def test_cuts(self) -> Dict[str, CutSet]:
# "kespeech-asr_dev_phase2": kespeech_dev_phase2_cuts,
# "wenetspeech-net_test": wenetspeech_test_net_cuts,
# "wenetspeech_dev": wenetspeech_dev_cuts,
}
}
3 changes: 2 additions & 1 deletion egs/multi_zh-hans/ASR/whisper/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from whisper_decoder_forward_monkey_patch import replace_whisper_decoder_forward
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
Expand Down Expand Up @@ -458,6 +458,7 @@ def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
return text
elif normalize == "m2met":
import re

text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
import whisper
from torch import Tensor
from torch import nn
from typing import Dict, Iterable, Optional
from whisper.model import ResidualAttentionBlock, LayerNorm
import numpy as np
from torch import Tensor, nn
from whisper.model import LayerNorm, ResidualAttentionBlock


def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
Expand All @@ -19,10 +20,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = (
x
+ self.positional_embedding[offset : offset + x.shape[1]]
)
x = x + self.positional_embedding[offset : offset + x.shape[1]]
x = x.to(xa.dtype)

# for block in self.blocks:
Expand All @@ -39,6 +37,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):

return logits


def replace_whisper_decoder_forward():
"""
This function monkey patches the forward method of the whisper decoder.
Expand Down
2 changes: 1 addition & 1 deletion egs/speechio/ASR/RESULTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
| 10 | **whisper-large-ft-v0** | **6.34%** | 2023.03 |
| 11 | baidu_pro_api_zh | 7.29% | 2023.12 |

Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)
Note: Above API results are from [SPEECHIO](https://github.com/SpeechColab/Leaderboard). All results used the default [normalize method.](https://github.com/SpeechColab/Leaderboard/blob/master/utils/benchmark.sh#L67)

<details><summary> Detail all models </summary><p>

Expand Down
4 changes: 3 additions & 1 deletion egs/speechio/ASR/local/normalize_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
from typing import Dict, List, Optional, Tuple

import kaldialign
from speechio_norm import TextNorm

from icefall.utils import store_transcripts, write_error_stats
from speechio_norm import TextNorm


def get_parser():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -140,6 +141,7 @@ def get_filenames(
results.append(whisper_filename)
return results


def main():
parser = get_parser()
args = parser.parse_args()
Expand Down
Loading

0 comments on commit 838bf22

Please sign in to comment.