Skip to content

Commit

Permalink
add otc related scripts using phone instead of bpe
Browse files Browse the repository at this point in the history
  • Loading branch information
DongjiGao committed Apr 21, 2024
1 parent 3f62460 commit fa13951
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 24 deletions.
20 changes: 6 additions & 14 deletions egs/librispeech/WSASR/conformer_ctc2/decode_phone.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
Expand All @@ -41,7 +40,6 @@
from icefall.decode import get_lattice, one_best_decoding
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
get_texts,
Expand Down Expand Up @@ -94,7 +92,7 @@ def get_parser():
parser.add_argument(
"--avg",
type=int,
default=1,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
Expand Down Expand Up @@ -195,7 +193,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
HLG: k2.Fsa,
batch: dict,
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
Expand Down Expand Up @@ -239,10 +237,7 @@ def decode_one_batch(
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
Expand Down Expand Up @@ -271,7 +266,6 @@ def decode_one_batch(
1,
).to(torch.int32)

assert HLG is not None
decoding_graph = HLG

lattice = get_lattice(
Expand Down Expand Up @@ -303,7 +297,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
HLG: k2.Fsa,
word_table: k2.SymbolTable,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
Expand Down Expand Up @@ -452,7 +446,7 @@ def main():

lexicon = Lexicon(params.lang_dir)
# remove otc_token from decoding units
max_token_id = len(lexicon.tokens) - 1
max_token_id = len(lexicon.tokens) - 1
num_classes = max_token_id + 1 # +1 for the blank

device = torch.device("cpu")
Expand All @@ -463,9 +457,7 @@ def main():

params.num_classes = num_classes

HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = HLG.to(device)
assert HLG.requires_grad is False

Expand Down
10 changes: 0 additions & 10 deletions egs/librispeech/WSASR/conformer_ctc2/train_phone.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,15 +899,6 @@ def run(rank, world_size, args):
if torch.cuda.is_available():
device = torch.device("cuda", rank)

if params.show_alignment:
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
params.HLG = HLG


lexicon = Lexicon(params.lang_dir)
graph_compiler = OtcPhoneTrainingGraphCompiler(
lexicon,
Expand Down Expand Up @@ -1118,7 +1109,6 @@ def main():
args.exp_dir = Path(args.exp_dir)
args.otc_token = f"{args.otc_token}"


world_size = args.world_size
assert world_size >= 1
if world_size > 1:
Expand Down

0 comments on commit fa13951

Please sign in to comment.