Skip to content

Commit

Permalink
deactivate beam search temporarily for speed
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoyang1998 committed Mar 28, 2024
1 parent ebc0f3b commit 360f208
Showing 1 changed file with 10 additions and 24 deletions.
34 changes: 10 additions & 24 deletions egs/librispeech/ASR/whisper/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Fangjun Kuang,
# Wei Kang)
# 2024 Yuekai Zhang
# 2024 Xiaomi Corporation Xiaoyu Yang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
Expand Down Expand Up @@ -145,26 +146,6 @@ def remove_punctuation(text: str or List[str]):
raise Exception(f"Not support type {type(text)}")


def to_simple(text: str or List[str]):
"""Convert traditional Chinese to simplified Chinese.
Args:
text: It can be a string or a list of strings.
Returns:
Return a string or a list of strings converted to simplified Chinese.
"""
if isinstance(text, str):
text = convert(text, "zh-cn")
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = convert(t, "zh-cn")
result_text.append(t)
return result_text
else:
raise Exception(f"Not support type{type(text)}")


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
Expand Down Expand Up @@ -417,8 +398,8 @@ def main():
options = whisper.DecodingOptions(
task="transcribe",
language="en",
# without_timestamps=True,
# beam_size=params.beam_size,
without_timestamps=True,
#beam_size=params.beam_size,
)
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
Expand Down Expand Up @@ -481,12 +462,17 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

def remove_short_and_long_utt(c):
if c.duration < 1.0 or c.duration > 30.0:
return False
return True

# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)

test_clean_cuts = librispeech.test_clean_cuts().subset(first=200)
test_other_cuts = librispeech.test_other_cuts().subset(first=200)
test_clean_cuts = librispeech.test_clean_cuts().filter(remove_short_and_long_utt)
test_other_cuts = librispeech.test_other_cuts().filter(remove_short_and_long_utt)

test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
Expand Down

0 comments on commit 360f208

Please sign in to comment.