From dd82686a0fce98a806f3b4142777c8908f435af4 Mon Sep 17 00:00:00 2001 From: jinzr Date: Wed, 4 Sep 2024 22:16:41 +0800 Subject: [PATCH 01/33] init commit --- .../ASR/local/compute_fbank_libritts.py | 160 ++ egs/libritts/ASR/local/compute_fbank_musan.py | 1 + .../ASR/local/compute_spectrogram_libritts.py | 107 ++ .../ASR/local/display_manifest_statistics.py | 341 ++++ egs/libritts/ASR/local/validate_manifest.py | 71 + egs/libritts/ASR/prepare.sh | 108 ++ egs/libritts/ASR/shared | 1 + egs/libritts/ASR/zipformer/.gitignore | 1 + egs/libritts/ASR/zipformer/asr_datamodule.py | 459 +++++ .../ASR/zipformer/attention_decoder.py | 1 + egs/libritts/ASR/zipformer/beam_search.py | 1 + egs/libritts/ASR/zipformer/ctc_decode.py | 991 +++++++++++ egs/libritts/ASR/zipformer/decode.py | 1085 ++++++++++++ egs/libritts/ASR/zipformer/decode_stream.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + egs/libritts/ASR/zipformer/export-onnx-ctc.py | 1 + .../zipformer/export-onnx-streaming-ctc.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/libritts/ASR/zipformer/export-onnx.py | 1 + egs/libritts/ASR/zipformer/export.py | 1 + .../ASR/zipformer/generate_averaged_model.py | 1 + egs/libritts/ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/libritts/ASR/zipformer/joiner.py | 1 + egs/libritts/ASR/zipformer/label_smoothing.py | 1 + egs/libritts/ASR/zipformer/model.py | 1 + egs/libritts/ASR/zipformer/my_profile.py | 1 + egs/libritts/ASR/zipformer/onnx_check.py | 1 + egs/libritts/ASR/zipformer/onnx_decode.py | 324 ++++ .../onnx_pretrained-streaming-ctc.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + egs/libritts/ASR/zipformer/onnx_pretrained.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_H.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 + .../onnx_pretrained_ctc_HLG_streaming.py | 1 + egs/libritts/ASR/zipformer/optim.py | 1 + egs/libritts/ASR/zipformer/pretrained.py | 1 + egs/libritts/ASR/zipformer/pretrained_ctc.py | 1 + egs/libritts/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 904 ++++++++++ egs/libritts/ASR/zipformer/subsampling.py | 1 + egs/libritts/ASR/zipformer/train.py | 1511 +++++++++++++++++ egs/libritts/ASR/zipformer/zipformer.py | 1 + egs/libritts/CODEC/encodec/binary.py | 161 ++ .../CODEC/encodec/codec_datamodule.py | 271 +++ egs/libritts/CODEC/encodec/discriminators.py | 117 ++ egs/libritts/CODEC/encodec/encodec.py | 261 +++ egs/libritts/CODEC/encodec/loss.py | 298 ++++ .../CODEC/encodec/models/discriminators.py | 229 +++ egs/libritts/CODEC/encodec/models/utils.py | 12 + .../CODEC/encodec/modules/__init__.py | 20 + egs/libritts/CODEC/encodec/modules/conv.py | 334 ++++ egs/libritts/CODEC/encodec/modules/lstm.py | 27 + egs/libritts/CODEC/encodec/modules/norm.py | 28 + egs/libritts/CODEC/encodec/modules/seanet.py | 368 ++++ .../CODEC/encodec/modules/transformer.py | 141 ++ .../CODEC/encodec/quantization/__init__.py | 7 + egs/libritts/CODEC/encodec/quantization/ac.py | 311 ++++ .../CODEC/encodec/quantization/core_vq.py | 377 ++++ .../CODEC/encodec/quantization/distrib.py | 126 ++ egs/libritts/CODEC/encodec/quantization/vq.py | 121 ++ egs/libritts/CODEC/encodec/train.py | 902 ++++++++++ egs/libritts/CODEC/encodec/utils.py | 1 + 68 files changed, 10210 insertions(+) create mode 
100755 egs/libritts/ASR/local/compute_fbank_libritts.py create mode 120000 egs/libritts/ASR/local/compute_fbank_musan.py create mode 100755 egs/libritts/ASR/local/compute_spectrogram_libritts.py create mode 100755 egs/libritts/ASR/local/display_manifest_statistics.py create mode 100755 egs/libritts/ASR/local/validate_manifest.py create mode 100755 egs/libritts/ASR/prepare.sh create mode 120000 egs/libritts/ASR/shared create mode 100644 egs/libritts/ASR/zipformer/.gitignore create mode 100644 egs/libritts/ASR/zipformer/asr_datamodule.py create mode 120000 egs/libritts/ASR/zipformer/attention_decoder.py create mode 120000 egs/libritts/ASR/zipformer/beam_search.py create mode 100755 egs/libritts/ASR/zipformer/ctc_decode.py create mode 100755 egs/libritts/ASR/zipformer/decode.py create mode 120000 egs/libritts/ASR/zipformer/decode_stream.py create mode 120000 egs/libritts/ASR/zipformer/encoder_interface.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx-ctc.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx.py create mode 120000 egs/libritts/ASR/zipformer/export.py create mode 120000 egs/libritts/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/libritts/ASR/zipformer/jit_pretrained.py create mode 120000 egs/libritts/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/libritts/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/libritts/ASR/zipformer/joiner.py create mode 120000 egs/libritts/ASR/zipformer/label_smoothing.py create mode 120000 egs/libritts/ASR/zipformer/model.py create mode 120000 egs/libritts/ASR/zipformer/my_profile.py create mode 120000 egs/libritts/ASR/zipformer/onnx_check.py create mode 100755 egs/libritts/ASR/zipformer/onnx_decode.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py create mode 120000 egs/libritts/ASR/zipformer/optim.py create mode 120000 egs/libritts/ASR/zipformer/pretrained.py create mode 120000 egs/libritts/ASR/zipformer/pretrained_ctc.py create mode 120000 egs/libritts/ASR/zipformer/scaling.py create mode 120000 egs/libritts/ASR/zipformer/scaling_converter.py create mode 120000 egs/libritts/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/libritts/ASR/zipformer/streaming_decode.py create mode 120000 egs/libritts/ASR/zipformer/subsampling.py create mode 100755 egs/libritts/ASR/zipformer/train.py create mode 120000 egs/libritts/ASR/zipformer/zipformer.py create mode 100644 egs/libritts/CODEC/encodec/binary.py create mode 100644 egs/libritts/CODEC/encodec/codec_datamodule.py create mode 100644 egs/libritts/CODEC/encodec/discriminators.py create mode 100644 egs/libritts/CODEC/encodec/encodec.py create mode 100644 egs/libritts/CODEC/encodec/loss.py create mode 100644 egs/libritts/CODEC/encodec/models/discriminators.py create mode 100644 egs/libritts/CODEC/encodec/models/utils.py create mode 100644 
egs/libritts/CODEC/encodec/modules/__init__.py create mode 100644 egs/libritts/CODEC/encodec/modules/conv.py create mode 100644 egs/libritts/CODEC/encodec/modules/lstm.py create mode 100644 egs/libritts/CODEC/encodec/modules/norm.py create mode 100644 egs/libritts/CODEC/encodec/modules/seanet.py create mode 100644 egs/libritts/CODEC/encodec/modules/transformer.py create mode 100644 egs/libritts/CODEC/encodec/quantization/__init__.py create mode 100644 egs/libritts/CODEC/encodec/quantization/ac.py create mode 100644 egs/libritts/CODEC/encodec/quantization/core_vq.py create mode 100644 egs/libritts/CODEC/encodec/quantization/distrib.py create mode 100644 egs/libritts/CODEC/encodec/quantization/vq.py create mode 100644 egs/libritts/CODEC/encodec/train.py create mode 120000 egs/libritts/CODEC/encodec/utils.py diff --git a/egs/libritts/ASR/local/compute_fbank_libritts.py b/egs/libritts/ASR/local/compute_fbank_libritts.py new file mode 100755 index 0000000000..5e78af18b1 --- /dev/null +++ b/egs/libritts/ASR/local/compute_fbank_libritts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao,) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LibriTTS dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. 
If None, we will use all""", + ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""", + ) + + return parser.parse_args() + + +def compute_fbank_libritts( + dataset: Optional[str] = None, + sampling_rate: int = 24000, + perturb_speed: Optional[bool] = True, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(32, os.cpu_count()) + + num_mel_bins = 80 + + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "libritts" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if sampling_rate != 24000: + logging.info(f"Resampling audio to {sampling_rate}") + cut_set = cut_set.resample(sampling_rate) + if "train" in partition: + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + compute_fbank_libritts( + dataset=args.dataset, + sampling_rate=args.sampling_rate, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/libritts/ASR/local/compute_fbank_musan.py b/egs/libritts/ASR/local/compute_fbank_musan.py new file mode 120000 index 0000000000..5833f2484e --- /dev/null +++ b/egs/libritts/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/compute_spectrogram_libritts.py b/egs/libritts/ASR/local/compute_spectrogram_libritts.py new file mode 100755 index 0000000000..181353fdd6 --- /dev/null +++ b/egs/libritts/ASR/local/compute_spectrogram_libritts.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao,) +# 2024 The Chinese Univ. 
of HK (authors: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes spectrogram features of the LibriTTS dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated spectrogram features are saved in data/spectrogram.
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import torch
+from lhotse import (
+    CutSet,
+    LilcomChunkyWriter,
+    Spectrogram,
+    SpectrogramConfig,
+    load_manifest,
+)
+from lhotse.audio import RecordingSet
+from lhotse.supervision import SupervisionSet
+
+from icefall.utils import get_executor
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slow things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_spectrogram_libritts():
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/spectrogram")
+    num_jobs = min(32, os.cpu_count())
+
+    sampling_rate = 24000
+    frame_length = 1024 / sampling_rate  # (in second)
+    frame_shift = 256 / sampling_rate  # (in second)
+    use_fft_mag = True
+
+    prefix = "libritts"
+    suffix = "jsonl.gz"
+    partition = "all"
+
+    recordings = load_manifest(
+        src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet
+    ).resample(sampling_rate=sampling_rate)
+    supervisions = load_manifest(
+        src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet
+    )
+
+    config = SpectrogramConfig(
+        sampling_rate=sampling_rate,
+        frame_length=frame_length,
+        frame_shift=frame_shift,
+        use_fft_mag=use_fft_mag,
+    )
+    extractor = Spectrogram(config)
+
+    with get_executor() as ex:  # Initialize the executor only once.
+        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+        if (output_dir / cuts_filename).is_file():
+            logging.info(f"{partition} already exists - skipping.")
+            return
+        logging.info(f"Processing {partition}")
+        cut_set = CutSet.from_manifests(
+            recordings=recordings, supervisions=supervisions
+        )
+
+        cut_set = cut_set.compute_and_store_features(
+            extractor=extractor,
+            storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+            # when an executor is specified, make more partitions
+            num_jobs=num_jobs if ex is None else 80,
+            executor=ex,
+            storage_type=LilcomChunkyWriter,
+        )
+        cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    compute_spectrogram_libritts()
diff --git a/egs/libritts/ASR/local/display_manifest_statistics.py b/egs/libritts/ASR/local/display_manifest_statistics.py
new file mode 100755
index 0000000000..ddd022c96f
--- /dev/null
+++ b/egs/libritts/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,341 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+# 2024 The Chinese Univ.
of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + paths = [ + "./data/fbank/libritts_cuts_train-clean-100.jsonl.gz", + "./data/fbank/libritts_cuts_train-clean-360.jsonl.gz", + "./data/fbank/libritts_cuts_train-other-500.jsonl.gz", + "./data/fbank/libritts_cuts_dev-clean.jsonl.gz", + "./data/fbank/libritts_cuts_dev-other.jsonl.gz", + "./data/fbank/libritts_cuts_test-clean.jsonl.gz", + "./data/fbank/libritts_cuts_test-other.jsonl.gz", + ] + for path in paths: + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +./data/fbank/libritts_cuts_train-clean-100.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 33236 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 53:47:18 _ +________________________________________ +_ mean _ 5.8 _ +________________________________________ +_ std _ 4.6 _ +________________________________________ +_ min _ 0.2 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.5 _ +________________________________________ +_ 75% _ 7.9 _ +________________________________________ +_ 99% _ 21.4 _ +________________________________________ +_ 99.5% _ 23.7 _ +________________________________________ +_ 99.9% _ 27.8 _ +________________________________________ +_ max _ 33.2 _ +________________________________________ +_ Recordings available: _ 33236 _ +________________________________________ +_ Features available: _ 33236 _ +________________________________________ +_ Supervisions available: _ 33236 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_train-clean-360.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 116500 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 191:17:42 _ +_________________________________________ +_ mean _ 5.9 _ +_________________________________________ +_ std _ 4.6 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.4 _ +_________________________________________ +_ 50% _ 4.6 _ 
+_________________________________________ +_ 75% _ 8.1 _ +_________________________________________ +_ 99% _ 21.3 _ +_________________________________________ +_ 99.5% _ 23.4 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 40.4 _ +_________________________________________ +_ Recordings available: _ 116500 _ +_________________________________________ +_ Features available: _ 116500 _ +_________________________________________ +_ Supervisions available: _ 116500 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/fbank/libritts_cuts_train-other-500.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 205043 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 310:04:36 _ +_________________________________________ +_ mean _ 5.4 _ +_________________________________________ +_ std _ 4.4 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.3 _ +_________________________________________ +_ 50% _ 4.2 _ +_________________________________________ +_ 75% _ 7.3 _ +_________________________________________ +_ 99% _ 20.6 _ +_________________________________________ +_ 99.5% _ 22.8 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 43.9 _ +_________________________________________ +_ Recordings available: _ 205043 _ +_________________________________________ +_ Features available: _ 205043 _ +_________________________________________ +_ Supervisions available: _ 205043 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/fbank/libritts_cuts_dev-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5736 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:58:13 _ +________________________________________ +_ mean _ 5.6 _ +________________________________________ +_ std _ 4.3 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.4 _ +________________________________________ +_ 75% _ 7.8 _ +________________________________________ +_ 99% _ 19.9 _ +________________________________________ +_ 99.5% _ 21.9 _ +________________________________________ +_ 99.9% _ 26.3 _ +________________________________________ +_ max _ 30.1 _ +________________________________________ +_ 
Recordings available: _ 5736 _ +________________________________________ +_ Features available: _ 5736 _ +________________________________________ +_ Supervisions available: _ 5736 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_dev-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4613 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:25:52 _ +________________________________________ +_ mean _ 5.0 _ +________________________________________ +_ std _ 4.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.2 _ +________________________________________ +_ 50% _ 3.8 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 19.7 _ +________________________________________ +_ 99.5% _ 24.5 _ +________________________________________ +_ 99.9% _ 31.0 _ +________________________________________ +_ max _ 32.6 _ +________________________________________ +_ Recordings available: _ 4613 _ +________________________________________ +_ Features available: _ 4613 _ +________________________________________ +_ Supervisions available: _ 4613 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_test-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4837 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:34:09 _ +________________________________________ +_ mean _ 6.4 _ +________________________________________ +_ std _ 5.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.8 _ +________________________________________ +_ 75% _ 8.9 _ +________________________________________ +_ 99% _ 22.6 _ +________________________________________ +_ 99.5% _ 24.4 _ +________________________________________ +_ 99.9% _ 29.6 _ +________________________________________ +_ max _ 36.7 _ +________________________________________ +_ Recordings available: _ 4837 _ +________________________________________ +_ Features available: _ 4837 _ +________________________________________ +_ Supervisions available: _ 4837 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 08:34:09 _ 100.00% 
of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:34:09 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_test-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5120 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:41:31 _ +________________________________________ +_ mean _ 4.7 _ +________________________________________ +_ std _ 3.8 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 3.6 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 17.8 _ +________________________________________ +_ 99.5% _ 20.4 _ +________________________________________ +_ 99.9% _ 23.8 _ +________________________________________ +_ max _ 27.3 _ +________________________________________ +_ Recordings available: _ 5120 _ +________________________________________ +_ Features available: _ 5120 _ +________________________________________ +_ Supervisions available: _ 5120 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ +""" diff --git a/egs/libritts/ASR/local/validate_manifest.py b/egs/libritts/ASR/local/validate_manifest.py new file mode 100755 index 0000000000..abd4da88af --- /dev/null +++ b/egs/libritts/ASR/local/validate_manifest.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao,) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. 
+ +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest +from lhotse.dataset.speech_recognition import validate_for_asr + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest(manifest) + assert isinstance(cut_set, CutSet) + + validate_for_asr(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh new file mode 100755 index 0000000000..77c3c38422 --- /dev/null +++ b/egs/libritts/ASR/prepare.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 +sampling_rate=24000 +perturb_speed=true + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriTTS, + # you can create a symlink + # + # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS + # + if [ ! -d $dl_dir/LibriTTS ]; then + lhotse download libritts $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/musan + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriTTS manifest" + # We assume that you have downloaded the LibriTTS corpus + # to $dl_dir/LibriTTS + mkdir -p data/manifests + if [ ! -e data/manifests/.libritts.done ]; then + lhotse prepare libritts $dl_dir/LibriTTS data/manifests + touch data/manifests/.libritts.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + if [ ! -f data/manifests/.musan_manifests.done ]; then + log "It may take 6 minutes" + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan_manifests.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute Fbank for LibriTTS" + mkdir -p data/fbank + if [ ! -e data/fbank/.libritts.done ]; then + ./local/compute_fbank_libritts.py \ + --sampling-rate $sampling_rate \ + --perturb-speed $perturb_speed + touch data/fbank/.libritts.done + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 and + # train-other-500 together to form the training set. + if [ ! 
-f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + fi + + if [ ! -e data/fbank/.libritts-validated.done ]; then + log "Validating data/fbank for LibriTTS" + ./local/validate_manifest.py \ + data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + touch data/fbank/.libritts-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + if [ ! -f data/fbank/.msuan.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_musan.py + touch data/fbank/.msuan.done + fi +fi \ No newline at end of file diff --git a/egs/libritts/ASR/shared b/egs/libritts/ASR/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/libritts/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/.gitignore b/egs/libritts/ASR/zipformer/.gitignore new file mode 100644 index 0000000000..e47ac15828 --- /dev/null +++ b/egs/libritts/ASR/zipformer/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/libritts/ASR/zipformer/asr_datamodule.py b/egs/libritts/ASR/zipformer/asr_datamodule.py new file mode 100644 index 0000000000..8d2b9eaddf --- /dev/null +++ b/egs/libritts/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,459 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriTTSAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. libritts test-clean + and test-other). 
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""When enabled, use 960h LibriTTS. + Otherwise, use the 100h subset.""", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. 
" + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)() + ), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz" + 
) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/libritts/ASR/zipformer/attention_decoder.py b/egs/libritts/ASR/zipformer/attention_decoder.py new file mode 120000 index 0000000000..384e1b95ea --- /dev/null +++ b/egs/libritts/ASR/zipformer/attention_decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/beam_search.py b/egs/libritts/ASR/zipformer/beam_search.py new file mode 120000 index 0000000000..e24eca39f2 --- /dev/null +++ b/egs/libritts/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py new file mode 100755 index 0000000000..c31b1362ac --- /dev/null +++ b/egs/libritts/ASR/zipformer/ctc_decode.py @@ -0,0 +1,991 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) ctc-greedy-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(3) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(4) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(5) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(6) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring + +(7) attention-decoder-rescoring-no-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(8) attention-decoder-rescoring-with-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + ctc_greedy_search, + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_attention_decoder_with_ngram, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="zipformer/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_bpe_500",
+        help="The lang dir containing word table and LG graph",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="ctc-decoding",
+        help="""Decoding method.
+        Supported values are:
+        - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
+        - (3) 1best. Extract the best path from the decoding lattice as the
+          decoding result.
+        - (4) nbest. Extract n paths from the decoding lattice; the path
+          with the highest score is the decoding result.
+        - (5) nbest-rescoring. Extract n paths from the decoding lattice,
+          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
+          the highest score is the decoding result.
+        - (6) whole-lattice-rescoring. Rescore the decoding lattice with an
+          n-gram LM (e.g., a 4-gram LM), the best path of the rescored lattice
+          is the decoding result.
+        - (7) nbest-oracle. Its WER is the lower bound that any n-best
+          rescoring method can achieve. Useful for debugging n-best
+          rescoring methods.
+        - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+          lattice, rescore them with the attention decoder.
+        - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+          rescored lattice, rescore them with the attention decoder.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""Number of paths for n-best based decoding method.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=1.0,
+        help="""The scale to be applied to `lattice.scores`.
+        It's needed if you use any kind of n-best based rescoring.
+        Used only when "method" is one of the following values:
+        nbest, nbest-rescoring, and nbest-oracle
+        A smaller value results in more unique paths.
+        """,
+    )
+
+    parser.add_argument(
+        "--hlg-scale",
+        type=float,
+        default=0.6,
+        help="""The scale to be applied to `hlg.scores`.
+        """,
+    )
+
+    parser.add_argument(
+        "--lm-dir",
+        type=str,
+        default="data/lm",
+        help="""The n-gram LM dir.
+ It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + if params.decoding_method == "ctc-greedy-search": + hyps = ctc_greedy_search(ctc_output, encoder_out_lens) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(hyps) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... 
] + hyps = [s.split() for s in hyps] + key = "ctc-greedy-search" + return {key: hyps} + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} # note: returns words + + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
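+        # A minimal sketch of the behaviour assumed here: nbest_oracle() samples
+        # num_paths paths from the lattice and, for every utterance, keeps the
+        # path whose word sequence is closest to the reference, so the reported
+        # WER is a lower bound for any n-best rescoring applied to this lattice.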
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no-rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} # note: returns BPE tokens + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "attention-decoder-rescoring-with-ngram": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + best_path_dict = rescore_with_attention_decoder_with_ngram( + lattice=rescored_lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. 
+ Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.decoding_method in ( + "attention-decoder-rescoring-with-ngram", + "whole-lattice-rescoring", + ): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
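+        # write_error_stats() returns the WER for this decoding key; the values
+        # are collected so that the summary below can list every setting from
+        # best to worst.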
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "ctc-greedy-search", + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + "attention-decoder-rescoring-no-ngram", + "attention-decoder-rescoring-with-ngram", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
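+        # Record the streaming configuration in the result-file suffix so that
+        # runs with different chunk sizes or left-context sizes do not
+        # overwrite each other's outputs.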
+ params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + params.eos_id = 1 + params.sos_id = 1 + + if params.decoding_method in [ + "ctc-greedy-search", + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method in [ + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
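+    # With return_cuts=True the dataset includes the original Cut objects in
+    # batch["supervisions"]["cut"], which is where decode_dataset() reads the
+    # cut ids from.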
+ args.return_cuts = True + librispeech = LibriTTSAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py new file mode 100755 index 0000000000..1249254efd --- /dev/null +++ b/egs/libritts/ASR/zipformer/decode.py @@ -0,0 +1,1085 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. 
+ Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
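+        # enable_log=True also echoes the aggregate WER statistics to the
+        # decoding log in addition to writing them to the errs-* file.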
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." 
+ params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + 
average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
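+    # Note: return_cuts must be set before LibriTTSAsrDataModule builds its
+    # dataloaders below; the attached Cut objects carry the utterance ids.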
+ args.return_cuts = True + librispeech = LibriTTSAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/decode_stream.py b/egs/libritts/ASR/zipformer/decode_stream.py new file mode 120000 index 0000000000..b8d8ddfc4c --- /dev/null +++ b/egs/libritts/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/encoder_interface.py b/egs/libritts/ASR/zipformer/encoder_interface.py new file mode 120000 index 0000000000..653c5b09af --- /dev/null +++ b/egs/libritts/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-ctc.py b/egs/libritts/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 0000000000..f9d7563520 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 0000000000..652346001e --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-streaming.py b/egs/libritts/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 0000000000..2962eb7847 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx.py b/egs/libritts/ASR/zipformer/export-onnx.py new file mode 120000 index 0000000000..70a15683c2 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export.py b/egs/libritts/ASR/zipformer/export.py new file mode 120000 index 0000000000..dfc1bec080 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/generate_averaged_model.py b/egs/libritts/ASR/zipformer/generate_averaged_model.py new file mode 120000 index 0000000000..5a015ee6c1 --- /dev/null +++ b/egs/libritts/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained.py b/egs/libritts/ASR/zipformer/jit_pretrained.py new file mode 120000 index 0000000000..25108391fa --- /dev/null +++ b/egs/libritts/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py b/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 0000000000..9a8da58444 --- /dev/null +++ b/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py b/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 0000000000..1962351e9a --- /dev/null +++ b/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/joiner.py b/egs/libritts/ASR/zipformer/joiner.py new file mode 120000 index 0000000000..5b8a36332e --- /dev/null +++ b/egs/libritts/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/label_smoothing.py b/egs/libritts/ASR/zipformer/label_smoothing.py new file mode 120000 index 0000000000..175c633cc7 --- /dev/null +++ b/egs/libritts/ASR/zipformer/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/model.py b/egs/libritts/ASR/zipformer/model.py new file mode 120000 index 0000000000..cd7e07d72b --- /dev/null +++ b/egs/libritts/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/my_profile.py b/egs/libritts/ASR/zipformer/my_profile.py new file mode 120000 index 0000000000..3a90b26289 --- /dev/null +++ b/egs/libritts/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_check.py b/egs/libritts/ASR/zipformer/onnx_check.py new file mode 120000 index 0000000000..f3dd420046 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_decode.py b/egs/libritts/ASR/zipformer/onnx_decode.py new file mode 100755 index 0000000000..4b1e2cc5cf --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_decode.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +2. Run this file + +./zipformer/onnx_decode.py \ + --exp-dir $repo/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + The token table. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. 
+ """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + hyps = [token_ids_to_words(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + The token table. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." 
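As a side note on the decoding output handled above: token_ids_to_words rebuilds words from BPE pieces by concatenating them and turning the SentencePiece word-boundary marker "▁" back into spaces. A minimal standalone sketch of that conversion, using made-up pieces rather than entries from a real tokens.txt:

# Hypothetical BPE pieces; a real token_table maps token IDs to strings like these.
pieces = ["▁HE", "LLO", "▁WORL", "D"]

# Concatenate, restore word boundaries, and trim the leading space.
text = "".join(pieces).replace("▁", " ").strip()
assert text == "HELLO WORLD"
print(text)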
+ res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = SymbolTable.from_file(args.tokens) + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriTTSAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py new file mode 120000 index 0000000000..d623a8462c --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 0000000000..cfea104c27 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained.py b/egs/libritts/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 0000000000..8f32f4ee7a --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py new file mode 120000 index 0000000000..a3183ebf66 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py new file mode 120000 index 0000000000..a4fd76ac2e --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py new file mode 120000 index 0000000000..f805e3761a --- /dev/null +++ 
b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py new file mode 120000 index 0000000000..8343d50793 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py new file mode 120000 index 0000000000..3568e7cab7 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/optim.py b/egs/libritts/ASR/zipformer/optim.py new file mode 120000 index 0000000000..5eaa3cffd4 --- /dev/null +++ b/egs/libritts/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/pretrained.py b/egs/libritts/ASR/zipformer/pretrained.py new file mode 120000 index 0000000000..0bd71dde4d --- /dev/null +++ b/egs/libritts/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/pretrained_ctc.py b/egs/libritts/ASR/zipformer/pretrained_ctc.py new file mode 120000 index 0000000000..c2f6f6fc38 --- /dev/null +++ b/egs/libritts/ASR/zipformer/pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/scaling.py b/egs/libritts/ASR/zipformer/scaling.py new file mode 120000 index 0000000000..6f398f431d --- /dev/null +++ b/egs/libritts/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/scaling_converter.py b/egs/libritts/ASR/zipformer/scaling_converter.py new file mode 120000 index 0000000000..b0ecee05e1 --- /dev/null +++ b/egs/libritts/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/streaming_beam_search.py b/egs/libritts/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 0000000000..b1ed545579 --- /dev/null +++ b/egs/libritts/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/streaming_decode.py b/egs/libritts/ASR/zipformer/streaming_decode.py new file mode 100755 index 0000000000..e771bbafe9 --- /dev/null +++ b/egs/libritts/ASR/zipformer/streaming_decode.py @@ -0,0 +1,904 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import LibriTTSAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet, set_caching_enabled +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. 
+ """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
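To make the batching pattern of stack_states/unstack_states above concrete, here is a standalone sketch with dummy tensors for two utterances and a single encoder layer. It models only two of the six per-layer caches, and all names and shapes are illustrative rather than taken from a real model.

import torch

left_context_len, key_dim, channels, left_pad = 64, 128, 256, 3

def dummy_state():
    # A reduced version of the per-layer layout above: one attention cache of
    # shape (time, batch, dim) and one convolution cache of shape (batch, channels, left_pad).
    cached_key = torch.randn(left_context_len, 1, key_dim)
    cached_conv = torch.randn(1, channels, left_pad)
    return [cached_key, cached_conv]

states = [dummy_state(), dummy_state()]  # one state list per utterance

# "stack": concatenate each cache along its batch dimension.
batch_key = torch.cat([s[0] for s in states], dim=1)   # (left_context_len, 2, key_dim)
batch_conv = torch.cat([s[1] for s in states], dim=0)  # (2, channels, left_pad)

# "unstack": chunk along the same dimensions to recover the per-utterance states.
key_list = batch_key.chunk(chunks=2, dim=1)
conv_list = batch_conv.chunk(chunks=2, dim=0)

for i in range(2):
    assert torch.equal(key_list[i], states[i][0])
    assert torch.equal(conv_list[i], states[i][1])

The real functions apply the same cat/chunk pairing to all six caches of every layer, plus the encoder_embed left padding and the processed_lens entry.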
+ """ + cached_embed_left_pad = states[-2] + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
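The processed_mask computation in streaming_forward above is easiest to follow with small numbers. A standalone sketch, assuming a made-up left-context length of 8 and two streams that have processed 0 and 5 frames so far:

import torch

left_context_len = 8
processed_lens = torch.tensor([0, 5])  # frames already processed per stream

positions = torch.arange(left_context_len).expand(2, left_context_len)
mask = (processed_lens.unsqueeze(1) <= positions).flip(1)

print(mask.int())
# [[1, 1, 1, 1, 1, 1, 1, 1],   nothing processed yet: the whole cached left context is masked
#  [1, 1, 1, 0, 0, 0, 0, 0]]   5 frames processed: only the last 5 cache slots are attended to

In other words, left-context cache positions that have not been filled yet come out as True and are excluded through src_key_padding_mask.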
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
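The tail padding arithmetic in decode_one_chunk above can be checked directly. With a chunk size of 32 (the value used in the usage example at the top of this file, counted at the 50 Hz rate after subsampling), each chunk consumes 64 fbank frames, and the extra 7 + 2 * 3 frames cover the subsampling overhead plus the ConvNeXt right context:

chunk_size = 32
tail_length = chunk_size * 2 + 7 + 2 * 3   # 77 input frames at 100 Hz

# encoder_embed keeps (T - 7) // 2 frames, so 77 input frames give exactly
# the chunk itself plus the 3 frames of right context the ConvNeXt module needs.
subsampled = (tail_length - 7) // 2
assert subsampled == chunk_size + 3
print(tail_length, subsampled)  # 77 35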
+ initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + recogs_filename = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
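For reference, the feature extraction that decode_dataset above applies to each cut can be reproduced in isolation. The sketch below only mirrors the kaldifeat calls already used in that function, on a synthetic one-second waveform; it requires kaldifeat, and random noise stands in for real speech.

import torch
from kaldifeat import Fbank, FbankOptions

opts = FbankOptions()
opts.device = torch.device("cpu")
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80

fbank = Fbank(opts)
samples = torch.rand(16000) * 2 - 1   # 1 s of "audio" in [-1, 1], as the assertions above expect
feature = fbank(samples)              # (num_frames, 80); roughly 100 frames at a 10 ms shift
print(feature.shape)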
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + if params.label: + params.suffix += f"-{params.label}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + 
model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriTTSAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/subsampling.py b/egs/libritts/ASR/zipformer/subsampling.py new file mode 120000 index 0000000000..01ae9002c6 --- /dev/null +++ b/egs/libritts/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py new file mode 100755 index 0000000000..fef2e2ae5e --- /dev/null +++ b/egs/libritts/ASR/zipformer/train.py @@ -0,0 +1,1511 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` + - ctc loss & attention decoder loss, no transducer loss, + with `--use-transducer False --use-ctc True --use-attention-decoder True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
+ return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
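The adjustment in get_adjusted_batch_count above is plain arithmetic: the raw batch index is rescaled so that batch-count-driven schedules behave as if every batch carried ref_duration seconds of audio on a single device. A standalone check with made-up settings:

batch_idx_train = 3000
max_duration = 1000    # seconds of audio per batch, per GPU
world_size = 4
ref_duration = 600

adjusted = batch_idx_train * (max_duration * world_size) / ref_duration
assert adjusted == 20000.0
print(adjusted)   # training has seen the equivalent of 20000 "reference" batches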
+ """, + ) + + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + help="""Feedforward dimension used in attention decoder""", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. 
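Most of the per-stack options above ("2,2,3,4,3,2", "192,256,384,512,384,256", "16,32,64,-1", and so on) are plain comma-separated strings. A standalone sketch of how they are consumed, with a helper (to_int_tuple) that mirrors the small _to_int_tuple function defined later in this file; -1 in the chunk-size list selects whole-utterance attention:

def to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

encoder_dim = to_int_tuple("192,256,384,512,384,256")   # one value per encoder stack
chunk_size = to_int_tuple("16,32,64,-1")

assert encoder_dim == (192, 256, 384, 512, 384, 256)
assert max(encoder_dim) == 512   # the largest stack dim is what the joiner projects from
print(encoder_dim, chunk_size)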
We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. 
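The model_avg update rule quoted in the --average-period help above is a running mean in disguise: applied every average_period batches, it keeps model_avg equal to the average of the parameters sampled at the update points. A standalone numeric check with a single made-up scalar parameter:

import math

average_period = 200
sampled = {200: 1.0, 400: 2.0, 600: 6.0}   # parameter value at each update point

model_avg = None
for batch_idx_train, model in sorted(sampled.items()):
    if model_avg is None:
        model_avg = model   # first update has weight average_period / batch_idx_train == 1
    else:
        model_avg = model * (average_period / batch_idx_train) + model_avg * (
            (batch_idx_train - average_period) / batch_idx_train
        )

assert math.isclose(model_avg, (1.0 + 2.0 + 6.0) / 3)   # running mean of the three samples
print(model_avg)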
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. 
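A quick check of the frame-rate bookkeeping described in the comment above: one second of 100 Hz fbank frames shrinks to (T - 7) // 2 frames after encoder_embed, and the encoder output is downsampled once more by a factor of 2 (the exact rounding of that last step is glossed over here):

T = 100                            # input fbank frames (1 s at a 10 ms shift)
after_embed = (T - 7) // 2         # 46 frames, roughly 50 Hz
after_encoder = after_embed // 2   # roughly 23 frames, i.e. about 25 Hz at the encoder output
print(after_embed, after_encoder)  # 46 23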
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. 
+ + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
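The resume rules at the top of load_checkpoint_if_available above come down to a few lines: --start-batch takes precedence over --start-epoch, and --start-epoch N resumes from the checkpoint written at the end of epoch N-1. A standalone sketch; the helper name and paths are made up for illustration:

from pathlib import Path

def resume_filename(exp_dir: Path, start_batch: int, start_epoch: int):
    if start_batch > 0:
        return exp_dir / f"checkpoint-{start_batch}.pt"
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None   # fresh run, nothing to load

exp_dir = Path("zipformer/exp")
assert resume_filename(exp_dir, start_batch=0, start_epoch=1) is None
assert resume_filename(exp_dir, start_batch=0, start_epoch=11) == exp_dir / "epoch-10.pt"
assert resume_filename(exp_dir, start_batch=4000, start_epoch=11) == exp_dir / "checkpoint-4000.pt"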
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. 
It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_autocast: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
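+            # So we grow it manually here: the scale is doubled every 100 batches
+            # while it stays below 8.0, and every 400 batches while it stays below
+            # 32.0. Reading `scaler._scale` relies on a private GradScaler attribute.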
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" 
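+        # bf16 shares fp32's exponent range, so it is far less prone to overflow
+        # than fp16; the two autocast dtypes cannot be combined.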
+ assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriTTSAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = librispeech.train_clean_100_cuts() + # train_cuts += librispeech.train_clean_360_cuts() + # train_cuts += librispeech.train_other_500_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. 
" + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/zipformer.py b/egs/libritts/ASR/zipformer/zipformer.py new file mode 120000 index 0000000000..23011dda71 --- /dev/null +++ b/egs/libritts/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/libritts/CODEC/encodec/binary.py b/egs/libritts/CODEC/encodec/binary.py new file mode 100644 index 0000000000..3004831272 --- /dev/null +++ b/egs/libritts/CODEC/encodec/binary.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" + +import io +import json +import struct +from typing import IO, Any, List, Optional + +# format is `ECDC` magic code, followed by the header size as uint32. +# Then an uint8 indicates the protocol version (0.) +# The header is then provided as json and should contain all required +# informations for decoding. A raw stream of bytes is then provided +# and should be interpretable using the json header. +_encodec_header_struct = struct.Struct("!4sBI") +_ENCODEC_MAGIC = b"ECDC" + + +def write_ecdc_header(fo: IO[bytes], metadata: Any): + meta_dumped = json.dumps(metadata).encode("utf-8") + version = 0 + header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped)) + fo.write(header) + fo.write(meta_dumped) + fo.flush() + + +def _read_exactly(fo: IO[bytes], size: int) -> bytes: + buf = b"" + while len(buf) < size: + new_buf = fo.read(size) + if not new_buf: + raise EOFError( + "Impossible to read enough data from the stream, " + f"{size} bytes remaining." 
+ ) + buf += new_buf + size -= len(new_buf) + return buf + + +def read_ecdc_header(fo: IO[bytes]): + header_bytes = _read_exactly(fo, _encodec_header_struct.size) + magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) + if magic != _ENCODEC_MAGIC: + raise ValueError("File is not in ECDC format.") + if version != 0: + raise ValueError("Version not supported.") + meta_bytes = _read_exactly(fo, meta_size) + return json.loads(meta_bytes.decode("utf-8")) + + +class BitPacker: + """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. + Note that for some bandwidth (1.5, 3), the codebook representation + will not cover an integer number of bytes. + + Args: + bits (int): number of bits per value that will be pushed. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: IO[bytes]): + self._current_value = 0 + self._current_bits = 0 + self.bits = bits + self.fo = fo + + def push(self, value: int): + """Push a new value to the stream. This will immediately + write as many uint8 as possible to the underlying file-object.""" + self._current_value += value << self._current_bits + self._current_bits += self.bits + while self._current_bits >= 8: + lower_8bits = self._current_value & 0xFF + self._current_bits -= 8 + self._current_value >>= 8 + self.fo.write(bytes([lower_8bits])) + + def flush(self): + """Flushes the remaining partial uint8, call this at the end + of the stream to encode.""" + if self._current_bits: + self.fo.write(bytes([self._current_value])) + self._current_value = 0 + self._current_bits = 0 + self.fo.flush() + + +class BitUnpacker: + """BitUnpacker does the opposite of `BitPacker`. + + Args: + bits (int): number of bits of the values to decode. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: IO[bytes]): + self.bits = bits + self.fo = fo + self._mask = (1 << bits) - 1 + self._current_value = 0 + self._current_bits = 0 + + def pull(self) -> Optional[int]: + """ + Pull a single value from the stream, potentially reading some + extra bytes from the underlying file-object. + Returns `None` when reaching the end of the stream. + """ + while self._current_bits < self.bits: + buf = self.fo.read(1) + if not buf: + return None + character = buf[0] + self._current_value += character << self._current_bits + self._current_bits += 8 + + out = self._current_value & self._mask + self._current_value >>= self.bits + self._current_bits -= self.bits + return out + + +def test(): + import torch + + torch.manual_seed(1234) + for rep in range(4): + length: int = torch.randint(10, 2_000, (1,)).item() + bits: int = torch.randint(1, 16, (1,)).item() + tokens: List[int] = torch.randint(2**bits, (length,)).tolist() + rebuilt: List[int] = [] + buf = io.BytesIO() + packer = BitPacker(bits, buf) + for token in tokens: + packer.push(token) + packer.flush() + buf.seek(0) + unpacker = BitUnpacker(bits, buf) + while True: + value = unpacker.pull() + if value is None: + break + rebuilt.append(value) + assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) + # The flushing mechanism might lead to "ghost" values at the end of the stream. 
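+        # At most `8 // bits` extra values can be decoded from the single byte
+        # flushed at the end, e.g. bits=3 allows up to 8 // 3 == 2 ghost values.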
+ assert len(rebuilt) <= len(tokens) + 8 // bits, ( + len(rebuilt), + len(tokens), + bits, + ) + for idx, (a, b) in enumerate(zip(tokens, rebuilt)): + assert a == b, (idx, a, b) + + +if __name__ == "__main__": + test() diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py new file mode 100644 index 0000000000..996569d215 --- /dev/null +++ b/egs/libritts/CODEC/encodec/codec_datamodule.py @@ -0,0 +1,271 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriTTSCodecDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Codec data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=False, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
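+        # Each dataloader worker is then re-seeded with `seed + worker_id` (see
+        # _SeedWorkers above), so workers draw different random numbers while the
+        # run stays reproducible for a fixed base seed.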
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") diff --git a/egs/libritts/CODEC/encodec/discriminators.py b/egs/libritts/CODEC/encodec/discriminators.py new file mode 100644 index 0000000000..484f1ee431 --- /dev/null +++ b/egs/libritts/CODEC/encodec/discriminators.py @@ -0,0 +1,117 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn +from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT +from torch.nn import AvgPool1d + + +class MultiPeriodDiscriminator(nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + 
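+                # Pooling is applied cumulatively: the 2nd sub-discriminator sees a
+                # 2x average-pooled waveform and the 3rd a 4x average-pooled one.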
y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiScaleSTFTDiscriminator(nn.Module): + """Multi-Scale STFT (MS-STFT) discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + + def __init__( + self, + filters: int, + in_channels: int = 1, + out_channels: int = 1, + n_ffts: List[int] = [1024, 2048, 512, 256, 128], + hop_lengths: List[int] = [256, 512, 128, 64, 32], + win_lengths: List[int] = [1024, 2048, 512, 256, 128], + **kwargs + ): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList( + [ + DiscriminatorSTFT( + filters, + in_channels=in_channels, + out_channels=out_channels, + n_fft=n_ffts[i], + win_length=win_lengths[i], + hop_length=hop_lengths[i], + **kwargs + ) + for i in range(len(n_ffts)) + ] + ) + self.num_discriminators = len(self.discriminators) + + def forward(self, x: torch.Tensor): + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py new file mode 100644 index 0000000000..e7c5ad590a --- /dev/null +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -0,0 +1,261 @@ +import math +import random +from typing import List + +import numpy as np +import torch +from loss import loss_dis, loss_g +from torch import nn +from torch.cuda.amp import autocast + + +class Encodec(nn.Module): + def __init__( + self, + sample_rate: int, + target_bandwidths: List[float], + params: dict, + encoder: nn.Module, + quantizer: nn.Module, + decoder: nn.Module, + multi_scale_discriminator: nn.Module, + multi_period_discriminator: nn.Module, + multi_scale_stft_discriminator: nn.Module, + cache_generator_outputs: bool = True, + ): + super(Encodec, self).__init__() + + self.params = params + + # setup the generator + self.sample_rate = sample_rate + self.encoder = encoder + self.quantizer = quantizer + self.decoder = decoder + + self.ratios = encoder.ratios + self.hop_length = np.prod(self.ratios) + self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios)) + self.target_bandwidths = target_bandwidths + + # discriminators + self.multi_scale_discriminator = multi_scale_discriminator + self.multi_period_discriminator = multi_period_discriminator + self.multi_scale_stft_discriminator = multi_scale_stft_discriminator + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + def _forward_generator( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + global_step: int, + return_sample: bool = False, + ): + """Perform generator forward. + + Args: + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + global_step (int): Global step. + return_sample (bool): Return the generator output. + + Returns: + * loss (Tensor): Loss scalar tensor. 
+ * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + e = self.encoder(speech) + bw = random.choice(self.target_bandwidths) + quantized, codes, bandwidth, commit_loss = self.quantizer( + e, self.frame_rate, bw + ) + speech_hat = self.decoder(quantized) + else: + speech_hat = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = speech_hat + + # calculate discriminator outputs + y_hat, fmap_hat = self.multi_scale_stft_discriminator(speech_hat.contiguous()) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) + + # calculate losses + with autocast(enabled=False): + loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g( + commit_loss, + speech, + speech_hat, + fmap, + fmap_hat, + y, + y_hat, + global_step, + y_p, + y_p_hat, + y_s, + y_s_hat, + fmap_p, + fmap_p_hat, + fmap_s, + fmap_s_hat, + args=self.params, + ) + + stats = dict( + generator_loss=loss.item(), + generator_reconstruction_loss=rec_loss.item(), + generator_feature_loss=feat_loss.item(), + generator_adv_loss=adv_loss.item(), + generator_commit_loss=commit_loss.item(), + d_weight=d_weight.item(), + ) + + if return_sample: + stats["returned_sample"] = ( + speech_hat[0].data.cpu().numpy(), + speech[0].data.cpu().numpy(), + fmap_hat[0][0].data.cpu().numpy(), + fmap[0][0].data.cpu().numpy(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def _forward_discriminator( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + global_step: int, + ): + """ + Args: + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + global_step (int): Global step. + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. 
+ """ + # setup + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + e = self.encoder(speech) + bw = random.choice(self.target_bandwidths) + quantized, codes, bandwidth, commit_loss = self.quantizer( + e, self.frame_rate, bw + ) + speech_hat = self.decoder(quantized) + else: + speech_hat = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = speech_hat + + # calculate discriminator outputs + y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) + y_hat, fmap_hat = self.multi_scale_stft_discriminator( + speech_hat.contiguous().detach() + ) + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) + # calculate losses + with autocast(enabled=False): + loss = loss_dis( + y, + y_hat, + fmap, + fmap_hat, + y_p, + y_p_hat, + fmap_p, + fmap_p_hat, + y_s, + y_s_hat, + fmap_s, + fmap_s_hat, + global_step, + args=self.params, + ) + stats = dict( + discriminator_loss=loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return loss, stats + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + global_step: int, + return_sample: bool, + forward_generator: bool, + ): + if forward_generator: + return self._forward_generator( + speech=speech, + speech_lengths=speech_lengths, + global_step=global_step, + return_sample=return_sample, + ) + else: + return self._forward_discriminator( + speech=speech, + speech_lengths=speech_lengths, + global_step=global_step, + ) + + def encode(self, x, target_bw=None, st=None): + e = self.encoder(x) + if target_bw is None: + bw = self.target_bandwidths[-1] + else: + bw = target_bw + if st is None: + st = 0 + codes = self.quantizer.encode(e, self.frame_rate, bw, st) + return codes + + def decode(self, codes): + quantized = self.quantizer.decode(codes) + o = self.decoder(quantized) + return o diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py new file mode 100644 index 0000000000..1bb78f2839 --- /dev/null +++ b/egs/libritts/CODEC/encodec/loss.py @@ -0,0 +1,298 @@ +import torch +import torch.nn.functional as F +from torchaudio.transforms import MelSpectrogram + + +def adversarial_g_loss(y_disc_gen): + """Hinge loss""" + loss = 0.0 + for i in range(len(y_disc_gen)): + stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze() + loss += stft_loss + return loss / len(y_disc_gen) + + +def feature_loss(fmap_r, fmap_gen): + loss = 0.0 + for i in range(len(fmap_r)): + for j in range(len(fmap_r[i])): + stft_loss = ( + (fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean()) + ).mean() + loss += stft_loss + return loss / (len(fmap_r) * len(fmap_r[0])) + + +def sim_loss(y_disc_r, y_disc_gen): + loss = 0.0 + for i in range(len(y_disc_r)): + loss += F.mse_loss(y_disc_r[i], y_disc_gen[i]) + return loss / len(y_disc_r) + + +# def sisnr_loss(x, s, eps=1e-8): +# """ +# calculate training loss +# input: +# x: separated signal, N x S tensor, estimate value +# s: reference signal, N x S tensor, True value +# Return: +# sisnr: N tensor +# """ +# if x.shape != s.shape: +# if x.shape[-1] > s.shape[-1]: +# x = x[:, :s.shape[-1]] +# else: +# s = s[:, :x.shape[-1]] +# def l2norm(mat, 
keepdim=False): +# return torch.norm(mat, dim=-1, keepdim=keepdim) +# if x.shape != s.shape: +# raise RuntimeError( +# "Dimention mismatch when calculate si-snr, {} vs {}".format( +# x.shape, s.shape)) +# x_zm = x - torch.mean(x, dim=-1, keepdim=True) +# s_zm = s - torch.mean(s, dim=-1, keepdim=True) +# t = torch.sum( +# x_zm * s_zm, dim=-1, +# keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps) +# loss = -20. * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) +# return torch.sum(loss) / x.shape[0] + + +def reconstruction_loss(x, G_x, args, eps=1e-7): + # NOTE (lsx): hard-coded now + L = args.lambda_wav * F.mse_loss(x, G_x) # wav L1 loss + # loss_sisnr = sisnr_loss(G_x, x) # + # L += 0.01*loss_sisnr + # 2^6=64 -> 2^10=1024 + # NOTE (lsx): add 2^11 + for i in range(6, 12): + # for i in range(5, 12): # Encodec setting + s = 2**i + melspec = MelSpectrogram( + sample_rate=args.sr, + n_fft=max(s, 512), + win_length=s, + hop_length=s // 4, + n_mels=64, + wkwargs={"device": args.device}, + ).to(args.device) + S_x = melspec(x) + S_G_x = melspec(G_x) + l1_loss = (S_x - S_G_x).abs().mean() + l2_loss = ( + ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean( + dim=-2 + ) + ** 0.5 + ).mean() + + alpha = (s / 2) ** 0.5 + L += l1_loss + alpha * l2_loss + return L + + +def criterion_d( + y_disc_r, + y_disc_gen, + fmap_r_det, + fmap_gen_det, + y_df_hat_r, + y_df_hat_g, + fmap_f_r, + fmap_f_g, + y_ds_hat_r, + y_ds_hat_g, + fmap_s_r, + fmap_s_g, +): + """Hinge Loss""" + loss = 0.0 + loss1 = 0.0 + loss2 = 0.0 + loss3 = 0.0 + for i in range(len(y_disc_r)): + loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean() + for i in range(len(y_df_hat_r)): + loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean() + for i in range(len(y_ds_hat_r)): + loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean() + + loss = ( + loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r) + ) / 3.0 + + return loss + + +def criterion_g( + commit_loss, + x, + G_x, + fmap_r, + fmap_gen, + y_disc_r, + y_disc_gen, + y_df_hat_r, + y_df_hat_g, + fmap_f_r, + fmap_f_g, + y_ds_hat_r, + y_ds_hat_g, + fmap_s_r, + fmap_s_g, + args, +): + adv_g_loss = adversarial_g_loss(y_disc_gen) + feat_loss = ( + feature_loss(fmap_r, fmap_gen) + + sim_loss(y_disc_r, y_disc_gen) + + feature_loss(fmap_f_r, fmap_f_g) + + sim_loss(y_df_hat_r, y_df_hat_g) + + feature_loss(fmap_s_r, fmap_s_g) + + sim_loss(y_ds_hat_r, y_ds_hat_g) + ) / 3.0 + rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args) + total_loss = ( + args.lambda_com * commit_loss + + args.lambda_adv * adv_g_loss + + args.lambda_feat * feat_loss + + args.lambda_rec * rec_loss + ) + return total_loss, adv_g_loss, feat_loss, rec_loss + + +def adopt_weight(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +def adopt_dis_weight(weight, global_step, threshold=0, value=0.0): + # 0,3,6,9,13....这些时间步,不更新dis + if global_step % 3 == 0: + weight = value + return weight + + +def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + print("last_layer cannot be none") + assert 1 == 2 + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 1.0, 1.0).detach() + d_weight = d_weight * 
args.lambda_adv + return d_weight + + +def loss_g( + codebook_loss, + speech, + speech_hat, + fmap, + fmap_hat, + y, + y_hat, + global_step, + y_df, + y_df_hat, + y_ds, + y_ds_hat, + fmap_f, + fmap_f_hat, + fmap_s, + fmap_s_hat, + args=None, +): + """ + args: + codebook_loss: commit loss. + speech: ground-truth wav. + speech_hat: reconstructed wav. + fmap: real stft-D feature map. + fmap_hat: fake stft-D feature map. + y: real stft-D logits. + y_hat: fake stft-D logits. + global_step: global training step. + y_df: real MPD logits. + y_df_hat: fake MPD logits. + y_ds: real MSD logits. + y_ds_hat: fake MSD logits. + fmap_f: real MPD feature map. + fmap_f_hat: fake MPD feature map. + fmap_s: real MSD feature map. + fmap_s_hat: fake MSD feature map. + """ + rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args) + adv_g_loss = adversarial_g_loss(y_hat) + adv_mpd_loss = adversarial_g_loss(y_df_hat) + adv_msd_loss = adversarial_g_loss(y_ds_hat) + adv_loss = ( + adv_g_loss + adv_mpd_loss + adv_msd_loss + ) / 3.0 # NOTE(lsx): need to divide by 3? + feat_loss = feature_loss( + fmap, fmap_hat + ) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits? + feat_loss_mpd = feature_loss( + fmap_f, fmap_f_hat + ) # + sim_loss(y_df_hat_r, y_df_hat_g) + feat_loss_msd = feature_loss( + fmap_s, fmap_s_hat + ) # + sim_loss(y_ds_hat_r, y_ds_hat_g) + feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0 + d_weight = torch.tensor(1.0) + + disc_factor = adopt_weight( + args.lambda_adv, global_step, threshold=args.discriminator_iter_start + ) + if disc_factor == 0.0: + fm_loss_wt = 0 + else: + fm_loss_wt = args.lambda_feat + + loss = ( + rec_loss + + d_weight * disc_factor * adv_loss + + fm_loss_wt * feat_loss_tot + + args.lambda_com * codebook_loss + ) + return loss, rec_loss, adv_loss, feat_loss_tot, d_weight + + +def loss_dis( + y, + y_hat, + fmap, + fmap_hat, + y_df, + y_df_hat, + fmap_f, + fmap_f_hat, + y_ds, + y_ds_hat, + fmap_s, + fmap_s_hat, + global_step, + args, +): + disc_factor = adopt_weight( + args.lambda_adv, global_step, threshold=args.discriminator_iter_start + ) + d_loss = disc_factor * criterion_d( + y, + y_hat, + fmap, + fmap_hat, + y_df, + y_df_hat, + fmap_f, + fmap_f_hat, + y_ds, + y_ds_hat, + fmap_s, + fmap_s_hat, + ) + return d_loss diff --git a/egs/libritts/CODEC/encodec/models/discriminators.py b/egs/libritts/CODEC/encodec/models/discriminators.py new file mode 100644 index 0000000000..900349b554 --- /dev/null +++ b/egs/libritts/CODEC/encodec/models/discriminators.py @@ -0,0 +1,229 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from einops import rearrange +from utils import get_2d_padding, get_padding + +from ..modules import NormConv1d, NormConv2d + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period, + kernel_size=5, + stride=3, + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super(DiscriminatorP, self).__init__() + + self.period = period + self.activation = getattr(torch.nn, activation)(**activation_params) + self.convs = nn.ModuleList( + [ + NormConv2d( + 1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0) + ), + NormConv2d( + 32, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ), + NormConv2d( + 32, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ), + NormConv2d( + 32, + 32, + (kernel_size, 1), + (stride, 1), + 
padding=(get_padding(5, 1), 0), + ), + NormConv2d(32, 32, (kernel_size, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = NormConv2d(32, 1, (3, 1), 1, padding=(1, 0)) + + def forward(self, x): + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = self.activation(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(nn.Module): + def __init__( + self, + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super(DiscriminatorS, self).__init__() + self.activation = getattr(torch.nn, activation)(**activation_params) + self.convs = nn.ModuleList( + [ + NormConv1d(1, 32, 15, 1, padding=7), + NormConv1d(32, 32, 41, 2, groups=4, padding=20), + NormConv1d(32, 32, 41, 2, groups=16, padding=20), + NormConv1d(32, 32, 41, 4, groups=16, padding=20), + NormConv1d(32, 32, 41, 4, groups=16, padding=20), + NormConv1d(32, 32, 41, 1, groups=16, padding=20), + NormConv1d(32, 32, 5, 1, padding=2), + ] + ) + self.conv_post = NormConv1d(32, 1, 3, 1, padding=1) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = self.activation(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. 
Default: 1 + """ + + def __init__( + self, + n_filters: int, + in_channels: int = 1, + out_channels: int = 1, + n_fft: int = 1024, + hop_length: int = 256, + win_length: int = 1024, + max_filters: int = 1024, + filters_scale: int = 1, + kernel_size: Tuple[int, int] = (3, 9), + dilations: List[int] = [1, 2, 4], + stride: Tuple[int, int] = (1, 2), + normalized: bool = True, + norm: str = "weight_norm", + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = n_filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window_fn=torch.hann_window, + normalized=self.normalized, + center=False, + pad_mode=None, + power=None, + ) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d( + spec_channels, + self.filters, + kernel_size=kernel_size, + padding=get_2d_padding(kernel_size), + ) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=(dilation, 1), + padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm, + ) + ) + in_chs = out_chs + out_chs = min( + (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + ) + self.conv_post = NormConv2d( + out_chs, + self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + + def forward(self, x: torch.Tensor): + fmap = [] + # print('x ', x.shape) + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + # print('z ', z.shape) + z = torch.cat([z.real, z.imag], dim=1) + # print('cat_z ', z.shape) + z = rearrange(z, "b c w t -> b c t w") + for i, layer in enumerate(self.convs): + z = layer(z) + z = self.activation(z) + # print('z i', i, z.shape) + fmap.append(z) + z = self.conv_post(z) + # print('logit ', z.shape) + return z, fmap diff --git a/egs/libritts/CODEC/encodec/models/utils.py b/egs/libritts/CODEC/encodec/models/utils.py new file mode 100644 index 0000000000..2be73a312e --- /dev/null +++ b/egs/libritts/CODEC/encodec/models/utils.py @@ -0,0 +1,12 @@ +from typing import Tuple + + +def get_padding(kernel_size, dilation=1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) diff --git a/egs/libritts/CODEC/encodec/modules/__init__.py b/egs/libritts/CODEC/encodec/modules/__init__.py new file mode 100644 index 0000000000..e9f7584647 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Torch modules.""" +# flake8: noqa +from .conv import ( + NormConv1d, + NormConv2d, + NormConvTranspose1d, + NormConvTranspose2d, + SConv1d, + SConvTranspose1d, + pad1d, + unpad1d, +) +from .lstm import SLSTM +from .seanet import SEANetDecoder, SEANetEncoder +from .transformer import StreamingTransformerEncoder diff --git a/egs/libritts/CODEC/encodec/modules/conv.py b/egs/libritts/CODEC/encodec/modules/conv.py new file mode 100644 index 0000000000..45518a3f8f --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/conv.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Convolutional layers wrappers and utilities.""" +import logging +import math +from typing import Any, Dict, Tuple + +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .norm import ConvLayerNorm + +CONV_NORMALIZATIONS = frozenset( + [ + "none", + "weight_norm", + "spectral_norm", + "time_layer_norm", + "layer_norm", + "time_group_norm", + ] +) + + +def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == "weight_norm": + return weight_norm(module) + elif norm == "spectral_norm": + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == "layer_norm": + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == "time_group_norm": + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d( + x: Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! 
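+    Here length = 5, kernel_size = 4, stride = 2 and padding_total = 4, so
+    n_frames = (5 - 4 + 4) / 2 + 1 = 3.5 and ideal_length = (4 - 1) * 2 + (4 - 4) = 6:
+    `get_extra_padding_for_conv1d` returns 6 - 5 = 1 extra zero of padding.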
+ """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d( + x: Tensor, + paddings: Tuple[int, int], + mode: str = "zero", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: Tensor, paddings: Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. 
+ """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + pad_mode: str = "reflect", + ): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + logging.warning( + "SConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + padding_total = (kernel_size - 1) * dilation - (stride - 1) + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: Dict[str, Any] = {}, + ): + super().__init__() + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. 
+ if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/egs/libritts/CODEC/encodec/modules/lstm.py b/egs/libritts/CODEC/encodec/modules/lstm.py new file mode 100644 index 0000000000..7d5b8af885 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/lstm.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""LSTM layers module.""" +from torch import nn + + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y diff --git a/egs/libritts/CODEC/encodec/modules/norm.py b/egs/libritts/CODEC/encodec/modules/norm.py new file mode 100644 index 0000000000..b7ab72f9ea --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/norm.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Normalization modules.""" + +from typing import List, Union + +import einops +import torch +from torch import nn + + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + + def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, "b ... t -> b t ...") + x = super().forward(x) + x = einops.rearrange(x, "b t ... -> b ... t") + return diff --git a/egs/libritts/CODEC/encodec/modules/seanet.py b/egs/libritts/CODEC/encodec/modules/seanet.py new file mode 100644 index 0000000000..50d6c3f13e --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/seanet.py @@ -0,0 +1,368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Encodec SEANet-based encoder and decoder implementation.""" + +from typing import Any, Dict, List, Optional + +import numpy as np +import torch.nn as nn +from modules import SLSTM, SConv1d, SConvTranspose1d + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + Args: + dim (int): Dimension of the input/output + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. 
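SLSTM above keeps the convolutional (batch, channels, time) layout so it can be dropped between SConv1d blocks; a minimal shape check (assuming the `modules` package is importable as in the previous sketch):

import torch
from modules import SLSTM

lstm = SLSTM(dimension=64, num_layers=2)
x = torch.randn(3, 64, 100)   # (batch, channels, time)
y = lstm(x)                   # internally permuted to (time, batch, channels) for nn.LSTM
assert y.shape == x.shape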
+ activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3) + true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. + """ + + def __init__( + self, + dim: int, + kernel_sizes: List[int] = [3, 1], + dilations: List[int] = [1, 1], + activation: str = "ELU", + activation_params: Dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: Dict[str, Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): + super().__init__() + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + SConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. 
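The residual block keeps the (batch, channels, time) shape and, with compress=2, squeezes the hidden width of its branch to dim // 2; an illustrative check (same import assumption as above):

import torch
from modules.seanet import SEANetResnetBlock

block = SEANetResnetBlock(dim=32, kernel_sizes=[3, 1], dilations=[4, 1], compress=2)
x = torch.randn(2, 32, 120)
y = block(x)   # shortcut + dilated 3x1 conv + 1x1 conv, length preserved
assert y.shape == x.shape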
+ """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: Dict[str, Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + ): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) # 计算乘积 + + act = getattr(nn, activation) + mult = 1 + model: List[nn.Module] = [ + SConv1d( + channels, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + # Add downsampling layers + model += [ + act(**activation_params), + SConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + mult *= 2 + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + SConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. 
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: Optional[str] = None, + final_activation_params: Optional[dict] = None, + norm: str = "weight_norm", + norm_params: Dict[str, Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + trim_right_ratio: float = 1.0, + ): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: List[nn.Module] = [ + SConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + act(**activation_params), + SConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + SConv1d( + n_filters, + channels, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [final_act(**final_activation_params)] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + print("z ", z.shape) + assert 1 == 2 + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == "__main__": + test() diff --git a/egs/libritts/CODEC/encodec/modules/transformer.py b/egs/libritts/CODEC/encodec/modules/transformer.py new file mode 100644 index 0000000000..9ef2c7ac15 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/transformer.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+"""A streamable transformer.""" +import typing as tp +from typing import Any, List, Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +def create_sin_embedding(positions: Tensor, dim: int, max_period: float = 10000): + """Create time embedding for the given positions, target dimension `dim`.""" + # We aim for BTC format + assert dim % 2 == 0 + half_dim = dim // 2 + adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) + phase = positions / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ) + + +class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): + def forward(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore + if self.norm_first: + sa_input = self.norm1(x) + x = x + self._sa_block(sa_input, x_past, past_context) + x = x + self._ff_block(self.norm2(x)) + else: + sa_input = x + x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) + x = self.norm2(x + self._ff_block(x)) + + return x, sa_input + + # self-attention block + def _sa_block(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore + _, T, _ = x.shape + _, H, _ = x_past.shape + + queries = x + keys = torch.cat([x_past, x], dim=1) + values = keys + + queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) + keys_pos = torch.arange(T + H, device=x.device).view(1, -1) + delta = queries_pos - keys_pos + valid_access = (delta >= 0) & (delta <= past_context) + x = self.self_attn( + queries, keys, values, attn_mask=~valid_access, need_weights=False + )[0] + return self.dropout1(x) + + +class StreamingTransformerEncoder(nn.Module): + """TransformerEncoder with streaming support. + + Args: + dim (int): dimension of the data. + hidden_scale (int): intermediate dimension of FF module is this times the dimension. + num_heads (int): number of heads. + num_layers (int): number of layers. + max_period (float): maxium period of cosines in the positional embedding. + past_context (int or None): receptive field for the causal mask, infinite if None. + gelu (bool): if true uses GeLUs, otherwise use ReLUs. + norm_in (bool): normalize the input. + dropout (float): dropout probability. + **kwargs: See `nn.TransformerEncoderLayer`. 
+ """ + + def __init__( + self, + dim, + hidden_scale: float = 4.0, + num_heads: int = 8, + num_layers: int = 5, + max_period: float = 10000, + past_context: int = 1000, + gelu: bool = True, + norm_in: bool = True, + dropout: float = 0.0, + **kwargs + ): + super().__init__() + assert dim % num_heads == 0 + hidden_dim = int(dim * hidden_scale) + + self.max_period = max_period + self.past_context = past_context + activation: Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + else: + self.norm_in = nn.Identity() + + self.layers = nn.ModuleList() + for idx in range(num_layers): + self.layers.append( + StreamingTransformerEncoderLayer( + dim, + num_heads, + hidden_dim, + activation=activation, + batch_first=True, + dropout=dropout, + **kwargs + ) + ) + + def forward( + self, + x: Tensor, + states: Optional[List[Tensor]] = None, + offset: Union[int, Tensor] = 0, + ): + B, T, C = x.shape + if states is None: + states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] + + positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) + + new_state: List[Tensor] = [] + x = self.norm_in(x) + x = x + pos_emb + + for layer_state, layer in zip(states, self.layers): + x, new_layer_state = layer(x, layer_state, self.past_context) + new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) + new_state.append(new_layer_state[:, -self.past_context :, :]) + return x, new_state, offset + T diff --git a/egs/libritts/CODEC/encodec/quantization/__init__.py b/egs/libritts/CODEC/encodec/quantization/__init__.py new file mode 100644 index 0000000000..7364623400 --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/egs/libritts/CODEC/encodec/quantization/ac.py b/egs/libritts/CODEC/encodec/quantization/ac.py new file mode 100644 index 0000000000..660931b410 --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/ac.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Arithmetic coder.""" +import io +import math +import random +from typing import IO, Any, List, Optional + +import torch +from torch import Tensor + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf( + pdf: Tensor, + total_range_bits: int, + roundoff: float = 1e-8, + min_range: int = 2, + check: bool = True, +) -> Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. 
Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2**total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] + if ( + (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range + ).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. + + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take less bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. + Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. 
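Replaying the construction above by hand on a 3-symbol pdf (roundoff omitted, illustrative numbers): with 8 range bits there are 256 slots, a small fraction is reserved so every symbol keeps at least min_range slots, and the rest is split proportionally to the pdf.

import torch

pdf = torch.tensor([0.7, 0.2, 0.1])
total_range_bits, min_range = 8, 2
total_range = 2 ** total_range_bits            # 256 slots to distribute
alpha = min_range * len(pdf) / total_range     # fraction reserved for the minimum ranges
ranges = (((1 - alpha) * total_range) * pdf).floor().long() + min_range
q_cdf = torch.cumsum(ranges, dim=-1)
assert int(q_cdf[-1]) <= total_range           # fits into the coder's range
assert int(ranges.min()) >= min_range          # every symbol stays encodable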
+ """ + + def __init__(self, fo: IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: List[Any] = [] + self._dbg2: List[Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. + """ + while self.delta < 2**self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, ( + effective_low, + effective_high, + range_low, + range_high, + ) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream.""" + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. 
+ + """ + + def __init__(self, fo: IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: List[Any] = [] + self._dbg2: List[Any] = [] + self._last: Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + self.current -= b1 << self.max_bit + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: Tensor) -> Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. + """ + while self.delta < 2**self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/egs/libritts/CODEC/encodec/quantization/core_vq.py 
b/egs/libritts/CODEC/encodec/quantization/core_vq.py new file mode 100644 index 0000000000..66d3dcf5dd --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/core_vq.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Core vector quantization implementation.""" + +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn + +from .distrib import broadcast_tensors + + +def default(val: Any, d: Any) -> Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. 
+ codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: Union[Callable[..., torch.Tensor], Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) 
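EuclideanCodebook maps vectors to the index of their nearest codeword and back; a round-trip sketch (kmeans_init=False so no training batch is needed; assumes the quantization package is importable from egs/libritts/CODEC/encodec):

import torch
from quantization.core_vq import EuclideanCodebook

codebook = EuclideanCodebook(dim=16, codebook_size=64, kmeans_init=False)
x = torch.randn(2, 10, 16)       # (batch, time, dim)
codes = codebook.encode(x)       # (2, 10) integer indices
recon = codebook.decode(codes)   # (2, 10, 16) codewords
assert codes.shape == (2, 10) and recon.shape == x.shape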
d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. 
+ """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x, n_q: Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode( + self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None + ) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + st = st or 0 + for layer in self.layers[st:n_q]: # 设置解码的起止layer + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/egs/libritts/CODEC/encodec/quantization/distrib.py b/egs/libritts/CODEC/encodec/quantization/distrib.py new file mode 100644 index 0000000000..5b1b06d688 --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/distrib.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Torch distributed utilities.""" +from typing import Dict, Iterable, List + +import torch +from torch import distributed as dist + + +def rank(): + if dist.is_initialized(): + return dist.get_rank() + else: + return 0 + + +def world_size(): + if dist.is_initialized(): + return dist.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM): + if is_distributed(): + return dist.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + # print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, " + "at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. 
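Residual VQ in a nutshell: each layer quantizes what the previous layers left over, encode() returns one row of codes per quantizer, and decode() sums the per-layer reconstructions; an illustrative round trip (same import assumption as above):

import torch
from quantization.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(
    num_quantizers=4, dim=32, codebook_size=128, kmeans_init=False
)
x = torch.randn(2, 32, 50)
codes = rvq.encode(x, n_q=4)   # (4, 2, 50): one code sequence per quantizer
recon = rvq.decode(codes)      # (2, 32, 50)
assert codes.shape == (4, 2, 50) and recon.shape == x.shape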
+ """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + # src = int(rank()) # added code + handle = dist.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = dist.all_reduce( + buffer.data, op=dist.ReduceOp.SUM, async_op=True + ) + else: + handle = dist.broadcast(buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: Dict[str, float], count=1.0): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/egs/libritts/CODEC/encodec/quantization/vq.py b/egs/libritts/CODEC/encodec/quantization/vq.py new file mode 100644 index 0000000000..22212a7942 --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/vq.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Residual vector quantizer implementation.""" +import math +from dataclasses import dataclass, field +from typing import Optional + +import torch +from torch import Tensor, nn + +from .core_vq import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: Tensor + codes: Tensor + bandwidth: Tensor # bandwidth in kb/s used, per batch item. + penalty: Optional[Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
+ """ + + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward( + self, x: Tensor, sample_rate: int, bandwidth: Optional[float] = None + ) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (Tensor): Input tensor. + sample_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. + """ + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return quantized, codes, bw, torch.mean(commit_loss) + # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth( + self, sample_rate: int, bandwidth: Optional[float] = None + ) -> int: + """Return n_q based on specified target bandwidth.""" + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.0: + n_q = int(max(1, math.floor(bandwidth / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, sample_rate: int): + """Return bandwidth per quantizer for a given input sample rate.""" + return math.log2(self.bins) * sample_rate / 1000 + + def encode( + self, + x: Tensor, + sample_rate: int, + bandwidth: Optional[float] = None, + st: Optional[int] = None, + ) -> Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. 
+ """ + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + st = st or 0 + codes = self.vq.encode(x, n_q=n_q, st=st) + return codes + + def decode(self, codes: Tensor) -> Tensor: + """Decode the given codes to the quantized representation.""" + quantized = self.vq.decode(codes) + return quantized diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py new file mode 100644 index 0000000000..0d08a2e24c --- /dev/null +++ b/egs/libritts/CODEC/encodec/train.py @@ -0,0 +1,902 @@ +import argparse +import itertools +import logging +import math +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from encodec import Encodec +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import nn +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from utils import MetricsTracker, plot_feature, save_checkpoint + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=500, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lr", type=float, default=3.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 24000, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_wav": 100.0, # loss scaling coefficient for waveform loss + "lambda_feat": 1.0, # loss scaling coefficient for feat loss + "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss + "lambda_com": 1000.0, # loss scaling coefficient for commitment loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + """Get the model based on the configuration.""" + + from discriminators import ( + MultiPeriodDiscriminator, + MultiScaleDiscriminator, + MultiScaleSTFTDiscriminator, + ) + from modules.seanet import SEANetDecoder, SEANetEncoder + from quantization import ResidualVectorQuantizer + + generator_params = { + "generator_n_filters": 32, + "dimension": 512, + "ratios": [2, 2, 2, 4], + "target_bandwidths": [7.5, 15], + "bins": 1024, + } + discriminator_params = { + "stft_discriminator_n_filters": 32, + } + + params.update(generator_params) + params.update(discriminator_params) + + hop_length = np.prod(params.ratios) + n_q = int( + 1000 + * params.target_bandwidths[-1] + // (math.ceil(params.sample_rate / hop_length) * 10) + ) + + encoder = SEANetEncoder( + n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios + ) + decoder = SEANetDecoder( + n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios + ) + quantizer = ResidualVectorQuantizer( + dimension=params.dimension, n_q=n_q, bins=params.bins + ) + + model = Encodec( + params=params, + sample_rate=params.sampling_rate, + target_bandwidths=params.target_bandwidths, + encoder=encoder, + quantizer=quantizer, + decoder=decoder, + multi_scale_discriminator=MultiScaleDiscriminator(), + multi_period_discriminator=MultiPeriodDiscriminator(), + multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(), + ) + return model + + +def prepare_input( + batch: dict, + device: torch.device, +): + """Parse batch data""" + audio = batch["audio"].to(device, memory_format=torch.contiguous_format) + features = batch["features"].to(device, memory_format=torch.contiguous_format) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + + return audio, audio_lens, features, features_lens + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model to be trained. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + _, + _, + ) = prepare_input(batch, device) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + speech=audio, + speech_lengths=audio_lens, + global_step=params.batch_idx_train, + return_sample=False, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + speech=audio, + speech_lengths=audio_lens, + global_step=params.batch_idx_train, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valdi_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valdi_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + 
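Both the training and validation loops store per-batch sums (per-sample value times batch size) and divide by the accumulated sample count at the end. A rough mental model, assuming the MetricsTracker symlinked from the VCTK VITS recipe behaves like a dict-based counter:

from collections import Counter

# Two batches of 8 samples with per-sample generator losses 0.5 and 0.4:
tot = Counter()
for batch_stats in ({"samples": 8, "generator_loss": 8 * 0.5},
                    {"samples": 8, "generator_loss": 8 * 0.4}):
    tot.update(batch_stats)                      # element-wise sum, like tot_loss + loss_info

avg = tot["generator_loss"] / tot["samples"]     # (4.0 + 3.2) / 16 = 0.45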
tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + _, + _, + ) = prepare_input(batch, device) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + speech=audio, + speech_lengths=audio_lens, + global_step=params.batch_idx_train, + return_sample=False, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + speech=audio, + speech_lengths=audio_lens, + global_step=params.batch_idx_train, + forward_generator=True, + return_sample=batch_idx == 0, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + speech_hat_, speech_, _, _ = stats_g["returned_sample"] + + returned_sample = (speech_hat_, speech_) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + _, + _, + ) = prepare_input(batch, device) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + speech=audio, + speech_lengths=audio_lens, + global_step=params.batch_idx_train, + return_sample=False, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + global_step=params.batch_idx_train, + return_sample=False, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + vctk = VctkTtsDataModule(args) + + train_cuts = vctk.train_cuts() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + encoder = model.encoder + decoder = model.decoder + quantizer = model.quantizer + multi_scale_discriminator = model.multi_scale_discriminator + multi_period_discriminator = model.multi_period_discriminator + multi_scale_stft_discriminator = model.multi_scale_stft_discriminator + + num_param_e = sum([p.numel() for p in encoder.parameters()]) + logging.info(f"Number of parameters in encoder: {num_param_e}") + num_param_d = sum([p.numel() for p in decoder.parameters()]) + logging.info(f"Number of parameters in decoder: {num_param_d}") + num_param_q = sum([p.numel() for p in quantizer.parameters()]) + logging.info(f"Number of parameters in quantizer: {num_param_q}") + num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) + logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") + num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) + logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") + num_param_dstft = sum( + [p.numel() for p in multi_scale_stft_discriminator.parameters()] + ) + logging.info( + f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}" + ) + logging.info( + f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}" + ) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + itertools.chain( + encoder.parameters(), + quantizer.parameters(), + decoder.parameters(), + ), + lr=params.lr, + betas=(0.5, 0.9), + ) + optimizer_d = torch.optim.AdamW( + itertools.chain( + multi_scale_stft_discriminator.parameters(), + multi_scale_discriminator.parameters(), + multi_period_discriminator.parameters(), + ), + lr=params.lr, + betas=(0.5, 0.9), + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in 
checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + train_dl = vctk.train_dataloaders(train_cuts) + + valid_cuts = vctk.valid_cuts() + valid_dl = vctk.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + VctkTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libritts/CODEC/encodec/utils.py b/egs/libritts/CODEC/encodec/utils.py new file mode 120000 index 0000000000..7c95867761 --- /dev/null +++ b/egs/libritts/CODEC/encodec/utils.py @@ -0,0 +1 @@ +../../../vctk/TTS/vits/utils.py \ No newline at end of file From 6e4a9ea85a50bfaf67cdf4f8e65ba73f07496991 Mon Sep 17 00:00:00 2001 From: JinZr Date: Thu, 5 Sep 2024 22:30:07 +0800 Subject: [PATCH 02/33] a little bit coarse commit --- .../ASR/local/compute_spectrogram_libritts.py | 104 +++++++++---- egs/libritts/ASR/prepare.sh | 11 +- ...scriminators.py => base_discriminators.py} | 13 +- .../CODEC/encodec/codec_datamodule.py | 77 ++++++++-- 
egs/libritts/CODEC/encodec/discriminators.py | 8 +- egs/libritts/CODEC/encodec/encodec.py | 24 ++- egs/libritts/CODEC/encodec/loss.py | 12 +- egs/libritts/CODEC/encodec/models/utils.py | 12 -- egs/libritts/CODEC/encodec/train.py | 144 ++++++++++++------ 9 files changed, 274 insertions(+), 131 deletions(-) rename egs/libritts/CODEC/encodec/{models/discriminators.py => base_discriminators.py} (95%) delete mode 100644 egs/libritts/CODEC/encodec/models/utils.py diff --git a/egs/libritts/ASR/local/compute_spectrogram_libritts.py b/egs/libritts/ASR/local/compute_spectrogram_libritts.py index 181353fdd6..6cdc55bc89 100755 --- a/egs/libritts/ASR/local/compute_spectrogram_libritts.py +++ b/egs/libritts/ASR/local/compute_spectrogram_libritts.py @@ -25,19 +25,16 @@ The generated fbank features are saved in data/spectrogram. """ +import argparse import logging import os from pathlib import Path +from typing import Optional import torch -from lhotse import ( - CutSet, - LilcomChunkyWriter, - Spectrogram, - SpectrogramConfig, - load_manifest, -) +from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig from lhotse.audio import RecordingSet +from lhotse.recipes.utils import read_manifests_if_cached from lhotse.supervision import SupervisionSet from icefall.utils import get_executor @@ -49,26 +46,62 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) +def get_args(): + parser = argparse.ArgumentParser() -def compute_spectrogram_libritts(): + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""", + ) + + return parser.parse_args() + + +def compute_spectrogram_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000,): src_dir = Path("data/manifests") output_dir = Path("data/spectrogram") num_jobs = min(32, os.cpu_count()) - sampling_rate = 24000 + frame_length = 1024 / sampling_rate # (in second) frame_shift = 256 / sampling_rate # (in second) use_fft_mag = True prefix = "libritts" suffix = "jsonl.gz" - partition = "all" + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None - recordings = load_manifest( - src_dir / f"{prefix}_recordings_{partition}.jsonl.gz", RecordingSet - ).resample(sampling_rate=sampling_rate) - supervisions = load_manifest( - src_dir / f"{prefix}_supervisions_{partition}.jsonl.gz", SupervisionSet + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, ) config = SpectrogramConfig( @@ -80,24 +113,29 @@ def compute_spectrogram_libritts(): extractor = Spectrogram(config) with get_executor() as ex: # Initialize the executor only once. 
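For reference, the frame parameters above are plain sample counts divided by the sampling rate, so at the default 24 kHz they correspond to roughly a 42.7 ms analysis window with a 10.7 ms hop:

sampling_rate = 24000
frame_length = 1024 / sampling_rate   # ~0.0427 s window
frame_shift = 256 / sampling_rate     # ~0.0107 s hop
print(frame_length, frame_shift)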
- cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" - if (output_dir / cuts_filename).is_file(): - logging.info(f"{partition} already exists - skipping.") - return - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=recordings, supervisions=supervisions - ) - - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/{prefix}_feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - cut_set.to_file(output_dir / cuts_filename) + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if sampling_rate != 24000: + logging.info(f"Resampling audio to {sampling_rate}") + cut_set = cut_set.resample(sampling_rate) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) if __name__ == "__main__": diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh index 77c3c38422..f3a78bdb81 100755 --- a/egs/libritts/ASR/prepare.sh +++ b/egs/libritts/ASR/prepare.sh @@ -8,6 +8,7 @@ set -eou pipefail stage=0 stop_stage=100 sampling_rate=24000 +nj=32 perturb_speed=true dl_dir=$PWD/download @@ -54,7 +55,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/LibriTTS mkdir -p data/manifests if [ ! -e data/manifests/.libritts.done ]; then - lhotse prepare libritts $dl_dir/LibriTTS data/manifests + lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests touch data/manifests/.libritts.done fi fi @@ -84,10 +85,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then # Here we shuffle and combine the train-clean-100, train-clean-360 and # train-other-500 together to form the training set. if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then - cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + cat <(gunzip -c ./libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c ./libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c ./libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > ./libritts_cuts_train-all-shuf.jsonl.gz fi if [ ! 
-e data/fbank/.libritts-validated.done ]; then diff --git a/egs/libritts/CODEC/encodec/models/discriminators.py b/egs/libritts/CODEC/encodec/base_discriminators.py similarity index 95% rename from egs/libritts/CODEC/encodec/models/discriminators.py rename to egs/libritts/CODEC/encodec/base_discriminators.py index 900349b554..e112436e50 100644 --- a/egs/libritts/CODEC/encodec/models/discriminators.py +++ b/egs/libritts/CODEC/encodec/base_discriminators.py @@ -5,9 +5,18 @@ import torch.nn.functional as F import torchaudio from einops import rearrange -from utils import get_2d_padding, get_padding +from modules.conv import NormConv1d, NormConv2d -from ..modules import NormConv1d, NormConv2d + +def get_padding(kernel_size, dilation=1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) class DiscriminatorP(nn.Module): diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py index 996569d215..b547e8513f 100644 --- a/egs/libritts/CODEC/encodec/codec_datamodule.py +++ b/egs/libritts/CODEC/encodec/codec_datamodule.py @@ -80,6 +80,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "augmentations, etc.", ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""When enabled, use the entire LibriTTS training set. + Otherwise, use the clean-100 subset.""", + ) group.add_argument( "--manifest-dir", type=Path, @@ -210,8 +217,8 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: validate = SpeechSynthesisDataset( return_text=False, - return_tokens=True, - return_spk_ids=True, + return_tokens=False, + return_spk_ids=False, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -236,8 +243,8 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: test = SpeechSynthesisDataset( return_text=False, - return_tokens=True, - return_spk_ids=True, + return_tokens=False, + return_spk_ids=False, feature_input_strategy=eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) @@ -256,16 +263,60 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: return test_dl @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_train.jsonl.gz") + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" + ) @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get validation cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_valid.jsonl.gz") + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" + ) @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "vctk_cuts_test.jsonl.gz") + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled 
train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/libritts/CODEC/encodec/discriminators.py b/egs/libritts/CODEC/encodec/discriminators.py index 484f1ee431..471aa92443 100644 --- a/egs/libritts/CODEC/encodec/discriminators.py +++ b/egs/libritts/CODEC/encodec/discriminators.py @@ -1,8 +1,8 @@ -from typing import List, Tuple +from typing import List import torch import torch.nn as nn -from models.discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT +from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT from torch.nn import AvgPool1d @@ -81,7 +81,7 @@ class MultiScaleSTFTDiscriminator(nn.Module): def __init__( self, - filters: int, + n_filters: int, in_channels: int = 1, out_channels: int = 1, n_ffts: List[int] = [1024, 2048, 512, 256, 128], @@ -94,7 +94,7 @@ def __init__( self.discriminators = nn.ModuleList( [ DiscriminatorSTFT( - filters, + n_filters, in_channels=in_channels, out_channels=out_channels, n_fft=n_ffts[i], diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index e7c5ad590a..071dc19bae 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -12,7 +12,7 @@ class Encodec(nn.Module): def __init__( self, - sample_rate: int, + sampling_rate: int, target_bandwidths: List[float], params: dict, encoder: nn.Module, @@ -21,21 +21,21 @@ def __init__( multi_scale_discriminator: nn.Module, multi_period_discriminator: nn.Module, multi_scale_stft_discriminator: nn.Module, - cache_generator_outputs: bool = True, + cache_generator_outputs: bool = False, ): super(Encodec, self).__init__() self.params = params # setup the generator - self.sample_rate = sample_rate + self.sampling_rate = sampling_rate self.encoder = encoder self.quantizer = quantizer self.decoder = decoder self.ratios = encoder.ratios self.hop_length = np.prod(self.ratios) - self.frame_rate = math.ceil(self.sample_rate / np.prod(self.ratios)) + self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios)) self.target_bandwidths = target_bandwidths # discriminators @@ -133,10 +133,10 @@ def _forward_generator( if return_sample: stats["returned_sample"] = ( - speech_hat[0].data.cpu().numpy(), - speech[0].data.cpu().numpy(), - fmap_hat[0][0].data.cpu().numpy(), - fmap[0][0].data.cpu().numpy(), + speech_hat.cpu(), + speech.cpu(), + fmap_hat[0][0].data.cpu(), + fmap[0][0].data.cpu(), ) # reset cache @@ -259,3 +259,11 @@ def decode(self, codes): quantized = self.quantizer.decode(codes) o = self.decoder(quantized) return o + + def inference(self, x, target_bw=None, st=None): + # setup + x = 
x.unsqueeze(1) + + codes = self.encode(x, target_bw, st) + o = self.decode(codes) + return o diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 1bb78f2839..9ec80f5369 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -59,9 +59,9 @@ def sim_loss(y_disc_r, y_disc_gen): # return torch.sum(loss) / x.shape[0] -def reconstruction_loss(x, G_x, args, eps=1e-7): +def reconstruction_loss(x, x_hat, args, eps=1e-7): # NOTE (lsx): hard-coded now - L = args.lambda_wav * F.mse_loss(x, G_x) # wav L1 loss + L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss # loss_sisnr = sisnr_loss(G_x, x) # # L += 0.01*loss_sisnr # 2^6=64 -> 2^10=1024 @@ -70,15 +70,15 @@ def reconstruction_loss(x, G_x, args, eps=1e-7): # for i in range(5, 12): # Encodec setting s = 2**i melspec = MelSpectrogram( - sample_rate=args.sr, + sample_rate=args.sampling_rate, n_fft=max(s, 512), win_length=s, hop_length=s // 4, n_mels=64, - wkwargs={"device": args.device}, - ).to(args.device) + wkwargs={"device": x_hat.device}, + ).to(x_hat.device) S_x = melspec(x) - S_G_x = melspec(G_x) + S_G_x = melspec(x_hat) l1_loss = (S_x - S_G_x).abs().mean() l2_loss = ( ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean( diff --git a/egs/libritts/CODEC/encodec/models/utils.py b/egs/libritts/CODEC/encodec/models/utils.py deleted file mode 100644 index 2be73a312e..0000000000 --- a/egs/libritts/CODEC/encodec/models/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Tuple - - -def get_padding(kernel_size, dilation=1) -> int: - return int((kernel_size * dilation - dilation) / 2) - - -def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)): - return ( - ((kernel_size[0] - 1) * dilation[0]) // 2, - ((kernel_size[1] - 1) * dilation[1]) // 2, - ) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 0d08a2e24c..6057ba2abe 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -2,6 +2,7 @@ import itertools import logging import math +import random from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -10,6 +11,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn +from codec_datamodule import LibriTTSCodecDataModule from encodec import Encodec from lhotse.cut import Cut from lhotse.utils import fix_random_seed @@ -76,7 +78,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="vits/exp", + default="encodec/exp", help="""The experiment dir. 
It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -127,6 +129,12 @@ def get_parser(): default=False, help="Whether to use half precision training.", ) + parser.add_argument( + "--chunk-size", + type=int, + default=1, + help="The chunk size for the dataset (in second).", + ) return parser @@ -249,23 +257,32 @@ def get_model(params: AttributeDict) -> nn.Module: } discriminator_params = { "stft_discriminator_n_filters": 32, + "discriminator_iter_start": 500, + } + inference_params = { + "target_bw": 7.5, } params.update(generator_params) params.update(discriminator_params) + params.update(inference_params) hop_length = np.prod(params.ratios) n_q = int( 1000 * params.target_bandwidths[-1] - // (math.ceil(params.sample_rate / hop_length) * 10) + // (math.ceil(params.sampling_rate / hop_length) * 10) ) encoder = SEANetEncoder( - n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios + n_filters=params.generator_n_filters, + dimension=params.dimension, + ratios=params.ratios, ) decoder = SEANetDecoder( - n_filters=params.n_filters, dimension=params.dimension, ratios=params.ratios + n_filters=params.generator_n_filters, + dimension=params.dimension, + ratios=params.ratios, ) quantizer = ResidualVectorQuantizer( dimension=params.dimension, n_q=n_q, bins=params.bins @@ -273,21 +290,25 @@ def get_model(params: AttributeDict) -> nn.Module: model = Encodec( params=params, - sample_rate=params.sampling_rate, + sampling_rate=params.sampling_rate, target_bandwidths=params.target_bandwidths, encoder=encoder, quantizer=quantizer, decoder=decoder, multi_scale_discriminator=MultiScaleDiscriminator(), multi_period_discriminator=MultiPeriodDiscriminator(), - multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator(), + multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( + n_filters=params.stft_discriminator_n_filters + ), ) return model def prepare_input( + params: AttributeDict, batch: dict, device: torch.device, + is_training: bool = True, ): """Parse batch data""" audio = batch["audio"].to(device, memory_format=torch.contiguous_format) @@ -295,6 +316,18 @@ def prepare_input( audio_lens = batch["audio_lens"].to(device) features_lens = batch["features_lens"].to(device) + if is_training: + audio_dims = audio.size(-1) + start_idx = random.randint( + 0, max(0, audio_dims - params.chunk_size * params.sampling_rate) + ) + audio = audio[:, start_idx : params.sampling_rate + start_idx] + else: + # NOTE: a very coarse setup + audio = audio[ + :, params.sampling_rate : params.sampling_rate + params.sampling_rate + ] + return audio, audio_lens, features, features_lens @@ -371,13 +404,13 @@ def save_bad_model(suffix: str = ""): for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 - batch_size = len(batch["tokens"]) + batch_size = len(batch["audio"]) ( audio, audio_lens, _, _, - ) = prepare_input(batch, device) + ) = prepare_input(params, batch, device) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -476,31 +509,38 @@ def save_bad_model(suffix: str = ""): "train/grad_scale", cur_grad_scale, params.batch_idx_train ) if "returned_sample" in stats_g: - speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + # speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + speech_hat_, speech_, _, _ = stats_g["returned_sample"] + + speech_hat_i = speech_hat_[0] + speech_i = speech_[0] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) 
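# Note on the indexing/squeeze above (layout assumed from the decoder output):
# the cached sample tensors are batched as (batch, channel, time), so [0] keeps
# a (1, T) tensor and squeeze(0) drops the channel axis, leaving the 1-D mono
# waveform that is passed to add_audio just below.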
tb_writer.add_audio( - "train/speech_hat_", - speech_hat_, + f"train/speech_hat_", + speech_hat_i, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( - "train/speech_", - speech_, + f"train/speech_", + speech_i, params.batch_idx_train, params.sampling_rate, ) - tb_writer.add_image( - "train/mel_hat_", - plot_feature(mel_hat_), - params.batch_idx_train, - dataformats="HWC", - ) - tb_writer.add_image( - "train/mel_", - plot_feature(mel_), - params.batch_idx_train, - dataformats="HWC", - ) + # tb_writer.add_image( + # "train/mel_hat_", + # plot_feature(mel_hat_), + # params.batch_idx_train, + # dataformats="HWC", + # ) + # tb_writer.add_image( + # "train/mel_", + # plot_feature(mel_), + # params.batch_idx_train, + # dataformats="HWC", + # ) if ( params.batch_idx_train % params.valid_interval == 0 @@ -522,15 +562,20 @@ def save_bad_model(suffix: str = ""): valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) + speech_hat_i = speech_hat[0] + speech_i = speech[0] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) tb_writer.add_audio( "train/valdi_speech_hat", - speech_hat, + speech_hat_i, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( "train/valdi_speech", - speech, + speech_i, params.batch_idx_train, params.sampling_rate, ) @@ -559,13 +604,13 @@ def compute_validation_loss( with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): - batch_size = len(batch["tokens"]) + batch_size = len(batch["audio"]) ( audio, audio_lens, _, _, - ) = prepare_input(batch, device) + ) = prepare_input(params, batch, device, is_training=False) loss_info = MetricsTracker() loss_info["samples"] = batch_size @@ -588,7 +633,7 @@ def compute_validation_loss( speech_lengths=audio_lens, global_step=params.batch_idx_train, forward_generator=True, - return_sample=batch_idx == 0, + return_sample=False, ) assert loss_g.requires_grad is False for k, v in stats_g.items(): @@ -599,9 +644,9 @@ def compute_validation_loss( # infer for first batch: if batch_idx == 0 and rank == 0: - speech_hat_, speech_, _, _ = stats_g["returned_sample"] - - returned_sample = (speech_hat_, speech_) + inner_model = model.module if isinstance(model, DDP) else model + audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw) + returned_sample = (audio_pred, audio) if world_size > 1: tot_loss.reduce(device) @@ -635,7 +680,7 @@ def scan_pessimistic_batches_for_oom( audio_lens, _, _, - ) = prepare_input(batch, device) + ) = prepare_input(params, batch, device) try: # for discriminator with autocast(enabled=params.use_fp16): @@ -706,9 +751,12 @@ def run(rank, world_size, args): device = torch.device("cuda", rank) logging.info(f"Device: {device}") - vctk = VctkTtsDataModule(args) + libritts = LibriTTSCodecDataModule(args) - train_cuts = vctk.train_cuts() + if params.full_libri: + train_cuts = libritts.train_all_shuf_cuts() + else: + train_cuts = libritts.train_clean_100_cuts() logging.info(params) @@ -798,19 +846,19 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - train_dl = vctk.train_dataloaders(train_cuts) + train_dl = libritts.train_dataloaders(train_cuts) - valid_cuts = vctk.valid_cuts() - valid_dl = vctk.valid_dataloaders(valid_cuts) + valid_cuts = libritts.dev_clean_cuts() + valid_dl = libritts.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer_g=optimizer_g, - 
optimizer_d=optimizer_d, - params=params, - ) + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer_g=optimizer_g, + # optimizer_d=optimizer_d, + # params=params, + # ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -883,7 +931,7 @@ def run(rank, world_size, args): def main(): parser = get_parser() - VctkTtsDataModule.add_arguments(parser) + LibriTTSCodecDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) From 2df992f98a7ab7eb56caf98cc4d3d5ba01f1f4aa Mon Sep 17 00:00:00 2001 From: JinZr Date: Thu, 5 Sep 2024 22:35:57 +0800 Subject: [PATCH 03/33] fixed a typo --- egs/libritts/CODEC/encodec/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 6057ba2abe..11c8458567 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -568,13 +568,13 @@ def save_bad_model(suffix: str = ""): speech_hat_i = speech_hat_i.squeeze(0) speech_i = speech_i.squeeze(0) tb_writer.add_audio( - "train/valdi_speech_hat", + "train/valid_speech_hat", speech_hat_i, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", + "train/valid_speech", speech_i, params.batch_idx_train, params.sampling_rate, From 91f7b1ce6f70b34f72a385702c77edb421f793ec Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 18:07:50 +0800 Subject: [PATCH 04/33] sort of fixed DDP training issue --- .../CODEC/encodec/codec_datamodule.py | 24 +++++++++++++++---- egs/libritts/CODEC/encodec/encodec.py | 15 +++++++++--- egs/libritts/CODEC/encodec/loss.py | 6 ++--- egs/libritts/CODEC/encodec/train.py | 24 +++++++++++++++---- 4 files changed, 54 insertions(+), 15 deletions(-) diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py index b547e8513f..e84f08e708 100644 --- a/egs/libritts/CODEC/encodec/codec_datamodule.py +++ b/egs/libritts/CODEC/encodec/codec_datamodule.py @@ -139,7 +139,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--num-workers", type=int, - default=8, + default=2, help="The number of training dataloader workers that " "collect the batches.", ) @@ -155,6 +155,8 @@ def train_dataloaders( self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, ) -> DataLoader: """ Args: @@ -182,6 +184,8 @@ def train_dataloaders( buffer_size=self.args.num_buckets * 2000, shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, ) else: logging.info("Using SimpleCutSampler.") @@ -189,6 +193,8 @@ def train_dataloaders( cuts_train, max_duration=self.args.max_duration, shuffle=self.args.shuffle, + world_size=world_size, + rank=rank, ) logging.info("About to create train dataloader") @@ -206,13 +212,18 @@ def train_dataloaders( sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, - persistent_workers=False, + persistent_workers=True, worker_init_fn=worker_init_fn, ) return train_dl - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ) -> DataLoader: logging.info("About to create dev dataset") validate = 
SpeechSynthesisDataset( @@ -226,14 +237,17 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: cuts_valid, max_duration=self.args.max_duration, shuffle=False, + world_size=world_size, + rank=rank, ) logging.info("About to create valid dataloader") valid_dl = DataLoader( validate, sampler=valid_sampler, batch_size=None, - num_workers=2, - persistent_workers=False, + num_workers=1, + drop_last=False, + persistent_workers=True, ) return valid_dl diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 071dc19bae..32d80eb38d 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -74,14 +74,18 @@ def _forward_generator( if not self.cache_generator_outputs or self._cache is None: reuse_cache = False e = self.encoder(speech) - bw = random.choice(self.target_bandwidths) + index = torch.tensor( + random.randint(0, len(self.target_bandwidths) - 1), device=speech.device, + ) + if torch.distributed.is_initialized(): + torch.distributed.broadcast(index, src=0) + bw = self.target_bandwidths[index.item()] quantized, codes, bandwidth, commit_loss = self.quantizer( e, self.frame_rate, bw ) speech_hat = self.decoder(quantized) else: speech_hat = self._cache - # store cache if self.training and self.cache_generator_outputs and not reuse_cache: self._cache = speech_hat @@ -169,7 +173,12 @@ def _forward_discriminator( if not self.cache_generator_outputs or self._cache is None: reuse_cache = False e = self.encoder(speech) - bw = random.choice(self.target_bandwidths) + index = torch.tensor( + random.randint(0, len(self.target_bandwidths) - 1), device=speech.device, + ) + if torch.distributed.is_initialized(): + torch.distributed.broadcast(index, src=0) + bw = self.target_bandwidths[index.item()] quantized, codes, bandwidth, commit_loss = self.quantizer( e, self.frame_rate, bw ) diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 9ec80f5369..0614abf92f 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -78,10 +78,10 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7): wkwargs={"device": x_hat.device}, ).to(x_hat.device) S_x = melspec(x) - S_G_x = melspec(x_hat) - l1_loss = (S_x - S_G_x).abs().mean() + S_x_hat = melspec(x_hat) + l1_loss = (S_x - S_x_hat).abs().mean() l2_loss = ( - ((torch.log(S_x.abs() + eps) - torch.log(S_G_x.abs() + eps)) ** 2).mean( + ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean( dim=-2 ) ** 0.5 diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 11c8458567..7dfbef2b64 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -552,13 +552,14 @@ def save_bad_model(suffix: str = ""): model=model, valid_dl=valid_dl, world_size=world_size, + rank=rank, ) model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) - if tb_writer is not None: + if tb_writer is not None and rank == 0 and speech_hat is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) @@ -647,6 +648,8 @@ def compute_validation_loss( inner_model = model.module if isinstance(model, DDP) else model audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw) returned_sample = (audio_pred, audio) + else: + returned_sample = (None, None) if world_size > 1: tot_loss.reduce(device) @@ -796,7 
+799,12 @@ def run(rank, world_size, args): if world_size > 1: logging.info("Using DDP") model = nn.SyncBatchNorm.convert_sync_batchnorm(model) - model = DDP(model, device_ids=[rank], find_unused_parameters=True) + model = DDP( + model, + device_ids=[rank], + find_unused_parameters=True, + broadcast_buffers=False, + ) optimizer_g = torch.optim.AdamW( itertools.chain( @@ -846,10 +854,18 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - train_dl = libritts.train_dataloaders(train_cuts) + train_dl = libritts.train_dataloaders( + train_cuts, + world_size=world_size, + rank=rank, + ) valid_cuts = libritts.dev_clean_cuts() - valid_dl = libritts.valid_dataloaders(valid_cuts) + valid_dl = libritts.valid_dataloaders( + valid_cuts, + world_size=world_size, + rank=rank, + ) # if not params.print_diagnostics: # scan_pessimistic_batches_for_oom( From 2e5055a847f81e4b8e16b48090a4ef9f94c1cea4 Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 18:16:18 +0800 Subject: [PATCH 05/33] minor updates --- egs/libritts/CODEC/encodec/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 7dfbef2b64..bc39da877e 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -803,7 +803,6 @@ def run(rank, world_size, args): model, device_ids=[rank], find_unused_parameters=True, - broadcast_buffers=False, ) optimizer_g = torch.optim.AdamW( From 0150961a335e877d502a4d3bb51489484222ef4d Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 21:20:45 +0800 Subject: [PATCH 06/33] minor fixes --- egs/libritts/CODEC/encodec/encodec.py | 1 - egs/libritts/CODEC/encodec/loss.py | 30 --------------------------- egs/libritts/CODEC/encodec/train.py | 7 ++++--- 3 files changed, 4 insertions(+), 34 deletions(-) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 32d80eb38d..bde03034f1 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -146,7 +146,6 @@ def _forward_generator( # reset cache if reuse_cache or not self.training: self._cache = None - return loss, stats def _forward_discriminator( diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 0614abf92f..96300e9d67 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -30,35 +30,6 @@ def sim_loss(y_disc_r, y_disc_gen): return loss / len(y_disc_r) -# def sisnr_loss(x, s, eps=1e-8): -# """ -# calculate training loss -# input: -# x: separated signal, N x S tensor, estimate value -# s: reference signal, N x S tensor, True value -# Return: -# sisnr: N tensor -# """ -# if x.shape != s.shape: -# if x.shape[-1] > s.shape[-1]: -# x = x[:, :s.shape[-1]] -# else: -# s = s[:, :x.shape[-1]] -# def l2norm(mat, keepdim=False): -# return torch.norm(mat, dim=-1, keepdim=keepdim) -# if x.shape != s.shape: -# raise RuntimeError( -# "Dimention mismatch when calculate si-snr, {} vs {}".format( -# x.shape, s.shape)) -# x_zm = x - torch.mean(x, dim=-1, keepdim=True) -# s_zm = s - torch.mean(s, dim=-1, keepdim=True) -# t = torch.sum( -# x_zm * s_zm, dim=-1, -# keepdim=True) * s_zm / (l2norm(s_zm, keepdim=True)**2 + eps) -# loss = -20. 
* torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) -# return torch.sum(loss) / x.shape[0] - - def reconstruction_loss(x, x_hat, args, eps=1e-7): # NOTE (lsx): hard-coded now L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss @@ -169,7 +140,6 @@ def adopt_weight(weight, global_step, threshold=0, value=0.0): def adopt_dis_weight(weight, global_step, threshold=0, value=0.0): - # 0,3,6,9,13....这些时间步,不更新dis if global_step % 3 == 0: weight = value return weight diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index bc39da877e..e207b12f7d 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -559,7 +559,7 @@ def save_bad_model(suffix: str = ""): logging.info( f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) - if tb_writer is not None and rank == 0 and speech_hat is not None: + if tb_writer is not None and rank == 0: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) @@ -634,11 +634,12 @@ def compute_validation_loss( speech_lengths=audio_lens, global_step=params.batch_idx_train, forward_generator=True, - return_sample=False, + return_sample=True, ) assert loss_g.requires_grad is False for k, v in stats_g.items(): - loss_info[k] = v * batch_size + if "returned_sample" not in k: + loss_info[k] = v * batch_size # summary stats tot_loss = tot_loss + loss_info From 8da57a04496ee10d3e3630799a5def95281500a1 Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 21:21:58 +0800 Subject: [PATCH 07/33] black formatted --- egs/libritts/CODEC/encodec/encodec.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index bde03034f1..385551d06a 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -75,7 +75,8 @@ def _forward_generator( reuse_cache = False e = self.encoder(speech) index = torch.tensor( - random.randint(0, len(self.target_bandwidths) - 1), device=speech.device, + random.randint(0, len(self.target_bandwidths) - 1), + device=speech.device, ) if torch.distributed.is_initialized(): torch.distributed.broadcast(index, src=0) @@ -173,7 +174,8 @@ def _forward_discriminator( reuse_cache = False e = self.encoder(speech) index = torch.tensor( - random.randint(0, len(self.target_bandwidths) - 1), device=speech.device, + random.randint(0, len(self.target_bandwidths) - 1), + device=speech.device, ) if torch.distributed.is_initialized(): torch.distributed.broadcast(index, src=0) From 4483c6e700975746061c471ac90033f0a2c54e49 Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 21:52:59 +0800 Subject: [PATCH 08/33] tensorboard should work properly --- egs/libritts/CODEC/encodec/train.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index e207b12f7d..73d6980087 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -601,7 +601,7 @@ def compute_validation_loss( # used to summary the stats over iterations tot_loss = MetricsTracker() - returned_sample = None + returned_sample = (None, None) with torch.no_grad(): for batch_idx, batch in enumerate(valid_dl): @@ -634,7 +634,7 @@ def compute_validation_loss( speech_lengths=audio_lens, global_step=params.batch_idx_train, forward_generator=True, - return_sample=True, + return_sample=False, ) assert loss_g.requires_grad is False for k, v 
in stats_g.items(): @@ -649,8 +649,6 @@ def compute_validation_loss( inner_model = model.module if isinstance(model, DDP) else model audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw) returned_sample = (audio_pred, audio) - else: - returned_sample = (None, None) if world_size > 1: tot_loss.reduce(device) From 12c7a16a5a3eb2966469b3027c070bccf5a35805 Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Sep 2024 22:05:21 +0800 Subject: [PATCH 09/33] minor updates --- egs/libritts/CODEC/encodec/train.py | 42 +++++++++++++++++------------ 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 73d6980087..842689155a 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -58,6 +58,13 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--num-samples", + type=int, + default=3, + help="Number of samples to generate for tensorboard.", + ) + parser.add_argument( "--num-epochs", type=int, @@ -563,23 +570,24 @@ def save_bad_model(suffix: str = ""): valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) - speech_hat_i = speech_hat[0] - speech_i = speech[0] - if speech_hat_i.dim() > 1: - speech_hat_i = speech_hat_i.squeeze(0) - speech_i = speech_i.squeeze(0) - tb_writer.add_audio( - "train/valid_speech_hat", - speech_hat_i, - params.batch_idx_train, - params.sampling_rate, - ) - tb_writer.add_audio( - "train/valid_speech", - speech_i, - params.batch_idx_train, - params.sampling_rate, - ) + for index in range(params.num_samples): # 3 + speech_hat_i = speech_hat[index] + speech_i = speech[index] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) + tb_writer.add_audio( + f"train/valid_speech_hat_{index}", + speech_hat_i, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + f"train/valid_speech_{index}", + speech_i, + params.batch_idx_train, + params.sampling_rate, + ) loss_value = tot_loss["generator_loss"] / tot_loss["samples"] params.train_loss = loss_value From c2367576743d379256f64db6a0443fac35dc5cac Mon Sep 17 00:00:00 2001 From: JinZr Date: Sat, 7 Sep 2024 23:33:52 +0800 Subject: [PATCH 10/33] * added script for inference * minor updates --- egs/libritts/CODEC/encodec/encodec.py | 8 +- egs/libritts/CODEC/encodec/infer.py | 300 ++++++++++++++++++ .../CODEC/encodec/quantization/core_vq.py | 2 +- egs/libritts/CODEC/encodec/train.py | 15 +- 4 files changed, 311 insertions(+), 14 deletions(-) create mode 100755 egs/libritts/CODEC/encodec/infer.py diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 385551d06a..4f45be9c25 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -267,13 +267,13 @@ def encode(self, x, target_bw=None, st=None): def decode(self, codes): quantized = self.quantizer.decode(codes) - o = self.decoder(quantized) - return o + x_hat = self.decoder(quantized) + return x_hat def inference(self, x, target_bw=None, st=None): # setup x = x.unsqueeze(1) codes = self.encode(x, target_bw, st) - o = self.decode(codes) - return o + x_hat = self.decode(codes) + return codes, x_hat diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py new file mode 100755 index 0000000000..dccff984d5 --- /dev/null +++ b/egs/libritts/CODEC/encodec/infer.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# 
+# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Dict, List + +import torch +import torch.nn.functional as F +import torchaudio +from codec_datamodule import LibriTTSCodecDataModule +from torch import nn +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="encodec/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--target-bw", + type=float, + default=7.5, + help="The target bandwidth for the generator", + ) + + return parser + + +# implementation from https://github.com/yangdongchao/AcademiCodec/blob/master/academicodec/models/encodec/test.py +def remove_encodec_weight_norm(model) -> None: + from modules import SConv1d + from modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + subset: + The name of the subset. 
+ params: + It is returned by :func:`get_params`. + model: + The neural model. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_recon.wav"), + audio_pred[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["audio"]) + + audios = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + codes, audio_hats = model.inference( + audios.to(device), target_bw=params.target_bw + ) + audio_hats = audio_hats.squeeze(1).cpu() + + futures.append( + executor.submit( + _save_worker, + subset, + batch_size, + cut_ids, + audios, + audio_hats, + audio_lens, + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSCodecDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + # we need cut ids to display results of both constructed and ground-truth audio + args.return_cuts = True + libritts = LibriTTSCodecDataModule(args) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + remove_encodec_weight_norm(model) + + model.to(device) + model.eval() + + encoder = model.encoder + decoder = model.decoder + quantizer = model.quantizer + multi_scale_discriminator = model.multi_scale_discriminator + multi_period_discriminator = model.multi_period_discriminator + multi_scale_stft_discriminator = model.multi_scale_stft_discriminator + + num_param_e = sum([p.numel() for p in encoder.parameters()]) + logging.info(f"Number of parameters in encoder: {num_param_e}") + num_param_d = sum([p.numel() for p in decoder.parameters()]) + logging.info(f"Number of parameters in decoder: {num_param_d}") + num_param_q = sum([p.numel() for p in quantizer.parameters()]) + logging.info(f"Number of parameters in quantizer: {num_param_q}") + num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) + logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") + num_param_dp = sum([p.numel() for p in 
multi_period_discriminator.parameters()]) + logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") + num_param_dstft = sum( + [p.numel() for p in multi_scale_stft_discriminator.parameters()] + ) + logging.info( + f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}" + ) + logging.info( + f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}" + ) + + test_clean_cuts = libritts.test_clean_cuts() + test_clean = libritts.test_dataloaders(test_clean_cuts) + + test_other_cuts = libritts.test_other_cuts() + test_other = libritts.test_dataloaders(test_other_cuts) + + dev_clean_cuts = libritts.dev_clean_cuts() + dev_clean = libritts.valid_dataloaders(dev_clean_cuts) + + dev_other_cuts = libritts.dev_other_cuts() + dev_other = libritts.valid_dataloaders(dev_other_cuts) + + infer_sets = { + "test-clean": test_clean, + "test-other": test_other, + "dev-clean": dev_clean, + "dev-other": dev_other, + } + + for subset, dl in infer_sets.items(): + save_wav_dir = params.res_dir / "wav" / subset + save_wav_dir.mkdir(parents=True, exist_ok=True) + + logging.info(f"Processing {subset} set, saving to {save_wav_dir}") + + infer_dataset( + dl=dl, + subset=subset, + params=params, + model=model, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/CODEC/encodec/quantization/core_vq.py b/egs/libritts/CODEC/encodec/quantization/core_vq.py index 66d3dcf5dd..4719e20f7f 100644 --- a/egs/libritts/CODEC/encodec/quantization/core_vq.py +++ b/egs/libritts/CODEC/encodec/quantization/core_vq.py @@ -360,7 +360,7 @@ def encode( all_indices = [] n_q = n_q or len(self.layers) st = st or 0 - for layer in self.layers[st:n_q]: # 设置解码的起止layer + for layer in self.layers[st:n_q]: indices = layer.encode(residual) quantized = layer.decode(indices) residual = residual - quantized diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 842689155a..65aec13831 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -136,12 +136,6 @@ def get_parser(): default=False, help="Whether to use half precision training.", ) - parser.add_argument( - "--chunk-size", - type=int, - default=1, - help="The chunk size for the dataset (in second).", - ) return parser @@ -191,6 +185,7 @@ def get_params() -> AttributeDict: "valid_interval": 200, "env_info": get_env_info(), "sampling_rate": 24000, + "chunk_size": 1.0, # in seconds "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss "lambda_wav": 100.0, # loss scaling coefficient for waveform loss "lambda_feat": 1.0, # loss scaling coefficient for feat loss @@ -570,7 +565,7 @@ def save_bad_model(suffix: str = ""): valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) - for index in range(params.num_samples): # 3 + for index in range(params.num_samples): # 3 speech_hat_i = speech_hat[index] speech_i = speech[index] if speech_hat_i.dim() > 1: @@ -655,8 +650,10 @@ def compute_validation_loss( # infer for first batch: if batch_idx == 0 and rank == 0: inner_model = model.module if isinstance(model, DDP) else model - audio_pred = inner_model.inference(x=audio, target_bw=params.target_bw) - returned_sample = (audio_pred, audio) + _, audio_hat = inner_model.inference( + x=audio, target_bw=params.target_bw + ) + returned_sample = (audio_hat, audio) if world_size > 1: 
tot_loss.reduce(device) From d45b400805c0fbd063d9a23d45efbd97e3c42252 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 8 Sep 2024 11:16:12 +0800 Subject: [PATCH 11/33] minor updates --- egs/libritts/ASR/prepare.sh | 10 +- .../local/compute_spectrogram_libritts.py | 6 +- .../local/display_manifest_statistics.py | 341 ++++++++++++++++++ egs/libritts/CODEC/local/validate_manifest.py | 1 + egs/libritts/CODEC/prepare.sh | 87 +++++ egs/libritts/CODEC/shared | 1 + 6 files changed, 439 insertions(+), 7 deletions(-) rename egs/libritts/{ASR => CODEC}/local/compute_spectrogram_libritts.py (97%) create mode 100755 egs/libritts/CODEC/local/display_manifest_statistics.py create mode 120000 egs/libritts/CODEC/local/validate_manifest.py create mode 100755 egs/libritts/CODEC/prepare.sh create mode 120000 egs/libritts/CODEC/shared diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh index f3a78bdb81..23c84e8386 100755 --- a/egs/libritts/ASR/prepare.sh +++ b/egs/libritts/ASR/prepare.sh @@ -85,10 +85,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then # Here we shuffle and combine the train-clean-100, train-clean-360 and # train-other-500 together to form the training set. if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then - cat <(gunzip -c ./libritts_cuts_train-clean-100.jsonl.gz) \ - <(gunzip -c ./libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c ./libritts_cuts_train-other-500.jsonl.gz) | \ - shuf | gzip -c > ./libritts_cuts_train-all-shuf.jsonl.gz + cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz fi if [ ! -e data/fbank/.libritts-validated.done ]; then @@ -106,4 +106,4 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then ./local/compute_fbank_musan.py touch data/fbank/.msuan.done fi -fi \ No newline at end of file +fi diff --git a/egs/libritts/ASR/local/compute_spectrogram_libritts.py b/egs/libritts/CODEC/local/compute_spectrogram_libritts.py similarity index 97% rename from egs/libritts/ASR/local/compute_spectrogram_libritts.py rename to egs/libritts/CODEC/local/compute_spectrogram_libritts.py index 6cdc55bc89..8d864db92b 100755 --- a/egs/libritts/ASR/local/compute_spectrogram_libritts.py +++ b/egs/libritts/CODEC/local/compute_spectrogram_libritts.py @@ -46,6 +46,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) + def get_args(): parser = argparse.ArgumentParser() @@ -64,12 +65,13 @@ def get_args(): return parser.parse_args() -def compute_spectrogram_libritts(dataset: Optional[str] = None, sampling_rate: int = 24000,): +def compute_spectrogram_libritts( + dataset: Optional[str] = None, sampling_rate: int = 24000 +): src_dir = Path("data/manifests") output_dir = Path("data/spectrogram") num_jobs = min(32, os.cpu_count()) - frame_length = 1024 / sampling_rate # (in second) frame_shift = 256 / sampling_rate # (in second) use_fft_mag = True diff --git a/egs/libritts/CODEC/local/display_manifest_statistics.py b/egs/libritts/CODEC/local/display_manifest_statistics.py new file mode 100755 index 0000000000..ec00e04545 --- /dev/null +++ b/egs/libritts/CODEC/local/display_manifest_statistics.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# 2024 The Chinese Univ. 
of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + paths = [ + "./data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz", + "./data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz", + "./data/spectrogram/libritts_cuts_train-other-500.jsonl.gz", + "./data/spectrogram/libritts_cuts_dev-clean.jsonl.gz", + "./data/spectrogram/libritts_cuts_dev-other.jsonl.gz", + "./data/spectrogram/libritts_cuts_test-clean.jsonl.gz", + "./data/spectrogram/libritts_cuts_test-other.jsonl.gz", + ] + for path in paths: + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +./data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 33236 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 53:47:18 _ +________________________________________ +_ mean _ 5.8 _ +________________________________________ +_ std _ 4.6 _ +________________________________________ +_ min _ 0.2 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.5 _ +________________________________________ +_ 75% _ 7.9 _ +________________________________________ +_ 99% _ 21.4 _ +________________________________________ +_ 99.5% _ 23.7 _ +________________________________________ +_ 99.9% _ 27.8 _ +________________________________________ +_ max _ 33.2 _ +________________________________________ +_ Recordings available: _ 33236 _ +________________________________________ +_ Features available: _ 33236 _ +________________________________________ +_ Supervisions available: _ 33236 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 116500 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 191:17:42 _ +_________________________________________ +_ mean _ 5.9 _ +_________________________________________ +_ std _ 4.6 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.4 _ 
+_________________________________________ +_ 50% _ 4.6 _ +_________________________________________ +_ 75% _ 8.1 _ +_________________________________________ +_ 99% _ 21.3 _ +_________________________________________ +_ 99.5% _ 23.4 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 40.4 _ +_________________________________________ +_ Recordings available: _ 116500 _ +_________________________________________ +_ Features available: _ 116500 _ +_________________________________________ +_ Supervisions available: _ 116500 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/spectrogram/libritts_cuts_train-other-500.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 205043 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 310:04:36 _ +_________________________________________ +_ mean _ 5.4 _ +_________________________________________ +_ std _ 4.4 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.3 _ +_________________________________________ +_ 50% _ 4.2 _ +_________________________________________ +_ 75% _ 7.3 _ +_________________________________________ +_ 99% _ 20.6 _ +_________________________________________ +_ 99.5% _ 22.8 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 43.9 _ +_________________________________________ +_ Recordings available: _ 205043 _ +_________________________________________ +_ Features available: _ 205043 _ +_________________________________________ +_ Supervisions available: _ 205043 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/spectrogram/libritts_cuts_dev-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5736 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:58:13 _ +________________________________________ +_ mean _ 5.6 _ +________________________________________ +_ std _ 4.3 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.4 _ +________________________________________ +_ 75% _ 7.8 _ +________________________________________ +_ 99% _ 19.9 _ +________________________________________ +_ 99.5% _ 21.9 _ +________________________________________ +_ 99.9% _ 26.3 _ 
+________________________________________ +_ max _ 30.1 _ +________________________________________ +_ Recordings available: _ 5736 _ +________________________________________ +_ Features available: _ 5736 _ +________________________________________ +_ Supervisions available: _ 5736 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_dev-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4613 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:25:52 _ +________________________________________ +_ mean _ 5.0 _ +________________________________________ +_ std _ 4.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.2 _ +________________________________________ +_ 50% _ 3.8 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 19.7 _ +________________________________________ +_ 99.5% _ 24.5 _ +________________________________________ +_ 99.9% _ 31.0 _ +________________________________________ +_ max _ 32.6 _ +________________________________________ +_ Recordings available: _ 4613 _ +________________________________________ +_ Features available: _ 4613 _ +________________________________________ +_ Supervisions available: _ 4613 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_test-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4837 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:34:09 _ +________________________________________ +_ mean _ 6.4 _ +________________________________________ +_ std _ 5.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.8 _ +________________________________________ +_ 75% _ 8.9 _ +________________________________________ +_ 99% _ 22.6 _ +________________________________________ +_ 99.5% _ 24.4 _ +________________________________________ +_ 99.9% _ 29.6 _ +________________________________________ +_ max _ 36.7 _ +________________________________________ +_ Recordings available: _ 4837 _ +________________________________________ +_ Features available: _ 4837 _ +________________________________________ +_ Supervisions available: _ 4837 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: 
+__________________________________________________________________ +_ Total speech duration _ 08:34:09 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:34:09 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_test-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5120 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:41:31 _ +________________________________________ +_ mean _ 4.7 _ +________________________________________ +_ std _ 3.8 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 3.6 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 17.8 _ +________________________________________ +_ 99.5% _ 20.4 _ +________________________________________ +_ 99.9% _ 23.8 _ +________________________________________ +_ max _ 27.3 _ +________________________________________ +_ Recordings available: _ 5120 _ +________________________________________ +_ Features available: _ 5120 _ +________________________________________ +_ Supervisions available: _ 5120 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ +""" diff --git a/egs/libritts/CODEC/local/validate_manifest.py b/egs/libritts/CODEC/local/validate_manifest.py new file mode 120000 index 0000000000..b4d52ebca0 --- /dev/null +++ b/egs/libritts/CODEC/local/validate_manifest.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/libritts/CODEC/prepare.sh b/egs/libritts/CODEC/prepare.sh new file mode 100755 index 0000000000..3dcb734745 --- /dev/null +++ b/egs/libritts/CODEC/prepare.sh @@ -0,0 +1,87 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 +sampling_rate=24000 +nj=32 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriTTS, + # you can create a symlink + # + # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS + # + if [ ! 
-d $dl_dir/LibriTTS ]; then + lhotse download libritts $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/musan + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriTTS manifest" + # We assume that you have downloaded the LibriTTS corpus + # to $dl_dir/LibriTTS + mkdir -p data/manifests + if [ ! -e data/manifests/.libritts.done ]; then + lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests + touch data/manifests/.libritts.done + fi +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute Spectrogram for LibriTTS" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.libritts.done ]; then + ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate + touch data/spectrogram/.libritts.done + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 and + # train-other-500 together to form the training set. + if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + fi + + if [ ! -e data/spectrogram/.libritts-validated.done ]; then + log "Validating data/spectrogram for LibriTTS" + ./local/validate_manifest.py \ + data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + touch data/spectrogram/.libritts-validated.done + fi +fi + diff --git a/egs/libritts/CODEC/shared b/egs/libritts/CODEC/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/libritts/CODEC/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From c43977ea054a158a03052c9a1ca6c69285cec565 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 8 Sep 2024 11:23:27 +0800 Subject: [PATCH 12/33] black formatted --- egs/libritts/ASR/zipformer/ctc_decode.py | 1 + egs/libritts/ASR/zipformer/decode.py | 1 + egs/libritts/ASR/zipformer/onnx_decode.py | 1 + egs/libritts/ASR/zipformer/streaming_decode.py | 7 ++----- egs/libritts/ASR/zipformer/train.py | 1 + 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py index c31b1362ac..177f2e392d 100755 --- a/egs/libritts/ASR/zipformer/ctc_decode.py +++ b/egs/libritts/ASR/zipformer/ctc_decode.py @@ -4,6 +4,7 @@ # Liyong Guo, # Quandong Wang, # Zengwei Yao) +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py index 1249254efd..8b033ce90f 100755 --- a/egs/libritts/ASR/zipformer/decode.py +++ b/egs/libritts/ASR/zipformer/decode.py @@ -2,6 +2,7 @@ # # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, # Zengwei Yao) +# Copyright 2024 The Chinese Univ.
of HK (Author: Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/libritts/ASR/zipformer/onnx_decode.py b/egs/libritts/ASR/zipformer/onnx_decode.py index 4b1e2cc5cf..99a02c5cf3 100755 --- a/egs/libritts/ASR/zipformer/onnx_decode.py +++ b/egs/libritts/ASR/zipformer/onnx_decode.py @@ -3,6 +3,7 @@ # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, # Zengwei Yao, # Xiaoyu Yang) +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/libritts/ASR/zipformer/streaming_decode.py b/egs/libritts/ASR/zipformer/streaming_decode.py index e771bbafe9..4e2f1ecb9f 100755 --- a/egs/libritts/ASR/zipformer/streaming_decode.py +++ b/egs/libritts/ASR/zipformer/streaming_decode.py @@ -2,6 +2,7 @@ # Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, # Fangjun Kuang, # Zengwei Yao) +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -388,11 +389,7 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py index fef2e2ae5e..5eb90efbcb 100755 --- a/egs/libritts/ASR/zipformer/train.py +++ b/egs/libritts/ASR/zipformer/train.py @@ -4,6 +4,7 @@ # Mingshuang Luo, # Zengwei Yao, # Daniel Povey) +# Copyright 2024 The Chinese Univ. 
of HK (author: Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # From 1e65a976d03c6ec07736a841c1c1e3b912ad7daa Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 8 Sep 2024 15:37:06 +0800 Subject: [PATCH 13/33] added pesq and stoi for reconstruction performance evaluation --- egs/libritts/CODEC/encodec/infer.py | 52 ++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py index dccff984d5..c407b4a593 100755 --- a/egs/libritts/CODEC/encodec/infer.py +++ b/egs/libritts/CODEC/encodec/infer.py @@ -30,12 +30,16 @@ import logging from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Dict, List +from statistics import mean +from typing import List, Tuple +import numpy as np import torch -import torch.nn.functional as F import torchaudio from codec_datamodule import LibriTTSCodecDataModule +from pesq import pesq +from pystoi import stoi +from scipy import signal from torch import nn from train import get_model, get_params @@ -105,12 +109,25 @@ def remove_encodec_weight_norm(model) -> None: remove_weight_norm(decoder._modules[key].conv.conv) +def compute_pesq(ref_wav: np.ndarray, gen_wav: np.ndarray) -> float: + """Compute PESQ score between reference and generated audio.""" + DEFAULT_SAMPLING_RATE = 16000 + ref = signal.resample(ref_wav, DEFAULT_SAMPLING_RATE) + deg = signal.resample(gen_wav, DEFAULT_SAMPLING_RATE) + return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb") + + +def compute_stoi(ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int) -> float: + """Compute STOI score between reference and generated audio.""" + return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False) + + def infer_dataset( dl: torch.utils.data.DataLoader, subset: str, params: AttributeDict, model: nn.Module, -) -> None: +) -> Tuple[float, float]: """Decode dataset. The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. @@ -123,6 +140,9 @@ def infer_dataset( It is returned by :func:`get_params`. model: The neural model. + + Returns: + The average PESQ and STOI scores. """ # Background worker save audios to disk. 
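As a usage reference for the metric helpers added above, the sketch below scores a single ground-truth/reconstruction pair with the same pesq and pystoi calls. The wav paths and the source sampling rate are placeholders assumed only for this illustration; wide-band PESQ is evaluated at 16 kHz, and since scipy.signal.resample takes a target sample count, the count is scaled by the rate ratio here.

# Hedged standalone sketch (not part of the patch): score one reconstruction against its reference.
import torchaudio
from pesq import pesq
from pystoi import stoi
from scipy import signal

ref, sr = torchaudio.load("ref.wav")      # placeholder path, e.g. 24 kHz LibriTTS audio
rec, _ = torchaudio.load("recon.wav")     # placeholder path, reconstruction at the same rate
ref = ref[0].numpy()
rec = rec[0, : ref.shape[0]].numpy()      # trim to a common length

PESQ_RATE = 16000                          # wide-band PESQ operates on 16 kHz input
num = int(len(ref) * PESQ_RATE / sr)       # signal.resample() expects a target sample count
pesq_wb = pesq(
    fs=PESQ_RATE,
    ref=signal.resample(ref, num),
    deg=signal.resample(rec, num),
    mode="wb",
)
stoi_score = stoi(x=ref, y=rec, fs_sig=sr, extended=False)
print(f"PESQ-WB: {pesq_wb:.4f}  STOI: {stoi_score:.4f}")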
@@ -150,6 +170,9 @@ def _save_worker( num_cuts = 0 log_interval = 5 + pesq_wb_scores = [] + stoi_scores = [] + try: num_batches = len(dl) except TypeError: @@ -169,6 +192,25 @@ def _save_worker( ) audio_hats = audio_hats.squeeze(1).cpu() + for cut_id, audio, audio_hat, audio_len in zip( + cut_ids, audios, audio_hats, audio_lens + ): + try: + pesq_wb = compute_pesq( + ref_wav=audio[:audio_len].numpy(), + gen_wav=audio_hat[:audio_len].numpy(), + ) + pesq_wb_scores.append(pesq_wb) + except Exception as e: + logging.error(f"Error while computing PESQ for cut {cut_id}: {e}") + + stoi_score = compute_stoi( + ref_wav=audio[:audio_len].numpy(), + gen_wav=audio_hat[:audio_len].numpy(), + sampling_rate=params.sampling_rate, + ) + stoi_scores.append(stoi_score) + futures.append( executor.submit( _save_worker, @@ -192,6 +234,7 @@ def _save_worker( # return results for f in futures: f.result() + return mean(pesq_wb_scores), mean(stoi_scores) @torch.no_grad() @@ -285,12 +328,13 @@ def main(): logging.info(f"Processing {subset} set, saving to {save_wav_dir}") - infer_dataset( + pesq_wb, stoi = infer_dataset( dl=dl, subset=subset, params=params, model=model, ) + logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi:.4f}") logging.info(f"Wav files are saved to {params.save_wav_dir}") logging.info("Done!") From f9340cc5d7b8f4f5f374d053057c6d382d7ee867 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sat, 5 Oct 2024 23:11:43 +0800 Subject: [PATCH 14/33] refactored loss functions --- .../CODEC/encodec/codec_datamodule.py | 2 +- egs/libritts/CODEC/encodec/encodec.py | 156 +++++-- egs/libritts/CODEC/encodec/infer.py | 2 +- egs/libritts/CODEC/encodec/loss.py | 426 +++++++++++++----- egs/libritts/CODEC/encodec/train.py | 220 ++++++++- 5 files changed, 620 insertions(+), 186 deletions(-) diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py index e84f08e708..e77a255e56 100644 --- a/egs/libritts/CODEC/encodec/codec_datamodule.py +++ b/egs/libritts/CODEC/encodec/codec_datamodule.py @@ -139,7 +139,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--num-workers", type=int, - default=2, + default=8, help="The number of training dataloader workers that " "collect the batches.", ) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 4f45be9c25..4701423922 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -4,7 +4,13 @@ import numpy as np import torch -from loss import loss_dis, loss_g +from loss import ( + DiscriminatorAdversarialLoss, + FeatureMatchLoss, + GeneratorAdversarialLoss, + MelSpectrogramReconstructionLoss, + WavReconstructionLoss, +) from torch import nn from torch.cuda.amp import autocast @@ -47,11 +53,23 @@ def __init__( self.cache_generator_outputs = cache_generator_outputs self._cache = None + # construct loss functions + self.generator_adversarial_loss = GeneratorAdversarialLoss( + average_by_discriminators=True, loss_type="hinge" + ) + self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss( + average_by_discriminators=True, loss_type="hinge" + ) + self.feature_match_loss = FeatureMatchLoss(average_by_layers=False) + self.wav_reconstruction_loss = WavReconstructionLoss() + self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss( + sampling_rate=self.sampling_rate + ) + def _forward_generator( self, speech: torch.Tensor, speech_lengths: torch.Tensor, - global_step: int, return_sample: bool = False, ): 
"""Perform generator forward. @@ -59,7 +77,6 @@ def _forward_generator( Args: speech (Tensor): Speech waveform tensor (B, T_wav). speech_lengths (Tensor): Speech length tensor (B,). - global_step (int): Global step. return_sample (bool): Return the generator output. Returns: @@ -107,33 +124,56 @@ def _forward_generator( # calculate losses with autocast(enabled=False): - loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g( - commit_loss, - speech, - speech_hat, - fmap, - fmap_hat, - y, - y_hat, - global_step, - y_p, - y_p_hat, - y_s, - y_s_hat, - fmap_p, - fmap_p_hat, - fmap_s, - fmap_s_hat, - args=self.params, + gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) + gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat) + gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) + + feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat) + feature_period_loss = self.feature_match_loss( + feats=fmap_p, feats_hat=fmap_p_hat + ) + feature_scale_loss = self.feature_match_loss( + feats=fmap_s, feats_hat=fmap_s_hat + ) + + wav_reconstruction_loss = self.wav_reconstruction_loss( + x=speech, x_hat=speech_hat + ) + mel_reconstruction_loss = self.mel_reconstruction_loss( + x=speech, x_hat=speech_hat ) + # loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g( + # commit_loss, + # speech, + # speech_hat, + # fmap, + # fmap_hat, + # y, + # y_hat, + # y_p, + # y_p_hat, + # y_s, + # y_s_hat, + # fmap_p, + # fmap_p_hat, + # fmap_s, + # fmap_s_hat, + # args=self.params, + # ) + stats = dict( - generator_loss=loss.item(), - generator_reconstruction_loss=rec_loss.item(), - generator_feature_loss=feat_loss.item(), - generator_adv_loss=adv_loss.item(), + # generator_loss=loss.item(), + generator_wav_reconstruction_loss=wav_reconstruction_loss.item(), + generator_mel_reconstruction_loss=mel_reconstruction_loss.item(), + generator_feature_stft_loss=feature_stft_loss.item(), + generator_feature_period_loss=feature_period_loss.item(), + generator_feature_scale_loss=feature_scale_loss.item(), + generator_stft_adv_loss=gen_stft_adv_loss.item(), + generator_period_adv_loss=gen_period_adv_loss.item(), + generator_scale_adv_loss=gen_scale_adv_loss.item(), generator_commit_loss=commit_loss.item(), - d_weight=d_weight.item(), + # d_weight=d_weight.item(), ) if return_sample: @@ -147,19 +187,28 @@ def _forward_generator( # reset cache if reuse_cache or not self.training: self._cache = None - return loss, stats + return ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats, + ) def _forward_discriminator( self, speech: torch.Tensor, speech_lengths: torch.Tensor, - global_step: int, ): """ Args: speech (Tensor): Speech waveform tensor (B, T_wav). speech_lengths (Tensor): Speech length tensor (B,). - global_step (int): Global step. Returns: * loss (Tensor): Loss scalar tensor. 
@@ -206,37 +255,46 @@ ) # calculate losses with autocast(enabled=False): - loss = loss_dis( - y, - y_hat, - fmap, - fmap_hat, - y_p, - y_p_hat, - fmap_p, - fmap_p_hat, - y_s, - y_s_hat, - fmap_s, - fmap_s_hat, - global_step, - args=self.params, - ) + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat) + ( + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat) + ( + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat) + stats = dict( - discriminator_loss=loss.item(), + discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(), + discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(), + discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(), + discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(), + discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(), + discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(), ) # reset cache if reuse_cache or not self.training: self._cache = None - return loss, stats + return ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats, + ) def forward( self, speech: torch.Tensor, speech_lengths: torch.Tensor, - global_step: int, return_sample: bool, forward_generator: bool, ): @@ -244,14 +302,12 @@ return self._forward_generator( speech=speech, speech_lengths=speech_lengths, - global_step=global_step, return_sample=return_sample, ) else: return self._forward_discriminator( speech=speech, speech_lengths=speech_lengths, - global_step=global_step, ) def encode(self, x, target_bw=None, st=None): diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py index c407b4a593..e5d69fa600 100755 --- a/egs/libritts/CODEC/encodec/infer.py +++ b/egs/libritts/CODEC/encodec/infer.py @@ -71,7 +71,7 @@ def get_parser(): parser.add_argument( "--target-bw", type=float, - default=7.5, + default=24, help="The target bandwidth for the generator", ) diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 96300e9d67..7e9bf5590d 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -1,8 +1,310 @@ +from typing import List, Tuple, Union + import torch import torch.nn.functional as F +from lhotse.features.kaldi import Wav2LogFilterBank from torchaudio.transforms import MelSpectrogram +class GeneratorAdversarialLoss(torch.nn.Module): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "hinge", + ): + """Initialize GeneratorAdversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward( + self, + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Calculate generator adversarial loss.
+ + Args: + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs. + + Returns: + Tensor: Generator adversarial loss value. + + """ + if isinstance(outputs, (tuple, list)): + adv_loss = 0.0 + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + adv_loss = self.criterion(outputs) + + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _hinge_loss(self, x): + return F.relu(1 - x).mean() + + +class DiscriminatorAdversarialLoss(torch.nn.Module): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "hinge", + ): + """Initialize DiscriminatorAdversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + else: + self.fake_criterion = self._hinge_fake_loss + self.real_criterion = self._hinge_real_loss + + def forward( + self, + outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate discriminator adversarial loss. + + Args: + outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from generator. + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from groundtruth. + + Returns: + Tensor: Discriminator real loss value. + Tensor: Discriminator fake loss value.
+ + """ + if isinstance(outputs, (tuple, list)): + real_loss = 0.0 + fake_loss = 0.0 + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + real_loss = self.real_criterion(outputs) + fake_loss = self.fake_criterion(outputs_hat) + + return real_loss, fake_loss + + def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.relu(x.new_ones(x.size()) - x).mean() + + def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.relu(x.new_ones(x.size()) + x).mean() + + +class FeatureMatchLoss(torch.nn.Module): + """Feature matching loss module.""" + + def __init__( + self, + average_by_layers: bool = True, + average_by_discriminators: bool = True, + include_final_outputs: bool = False, + ): + """Initialize FeatureMatchLoss module. + + Args: + average_by_layers (bool): Whether to average the loss by the number + of layers. + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + include_final_outputs (bool): Whether to include the final output of + each discriminator for loss calculation. + + """ + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward( + self, + feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], + feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], + ) -> torch.Tensor: + """Calculate feature matching loss. + + Args: + feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calculated + from generator's outputs. + feats (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calculated + from groundtruth. + + Returns: + Tensor: Feature matching loss value.
+ + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss + + +class MelSpectrogramReconstructionLoss(torch.nn.Module): + """Mel Spec Reconstruction loss.""" + + def __init__( + self, + sampling_rate: int = 22050, + n_mels: int = 64, + use_fft_mag: bool = True, + return_mel: bool = False, + ): + super().__init__() + self.wav_to_specs = [] + for i in range(5, 12): + s = 2**i + # self.wav_to_specs.append( + # Wav2LogFilterBank( + # sampling_rate=sampling_rate, + # frame_length=s, + # frame_shift=s // 4, + # use_fft_mag=use_fft_mag, + # num_filters=n_mels, + # ) + # ) + self.wav_to_specs.append( + MelSpectrogram( + sample_rate=sampling_rate, + n_fft=max(s, 512), + win_length=s, + hop_length=s // 4, + n_mels=n_mels, + ) + ) + self.return_mel = return_mel + + def forward( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Calculate Mel-spectrogram loss. + + Args: + x_hat (Tensor): Generated waveform tensor (B, 1, T). + x (Tensor): Groundtruth waveform tensor (B, 1, T). + spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor + (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth + waveform. + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_loss = 0.0 + + for i, wav_to_spec in enumerate(self.wav_to_specs): + s = 2 ** (i + 5) + wav_to_spec.to(x.device) + + mel_hat = wav_to_spec(x_hat.squeeze(1)) + mel = wav_to_spec(x.squeeze(1)) + + alpha = (s / 2) ** 0.5 + mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel) + + # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) + # mel = self.wav_to_spec(x.squeeze(1)) + # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel) + + if self.return_mel: + return mel_loss, (mel_hat, mel) + + return mel_loss + + +class WavReconstructionLoss(torch.nn.Module): + """Wav Reconstruction loss.""" + + def __init__(self): + super().__init__() + + def forward( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + ) -> torch.Tensor: + """Calculate wav loss. + + Args: + x_hat (Tensor): Generated waveform tensor (B, 1, T). + x (Tensor): Groundtruth waveform tensor (B, 1, T). + + Returns: + Tensor: Wav loss value. 
+ + """ + wav_loss = F.mse_loss(x, x_hat) + + return wav_loss + + def adversarial_g_loss(y_disc_gen): """Hinge loss""" loss = 0.0 @@ -63,88 +365,12 @@ def reconstruction_loss(x, x_hat, args, eps=1e-7): return L -def criterion_d( - y_disc_r, - y_disc_gen, - fmap_r_det, - fmap_gen_det, - y_df_hat_r, - y_df_hat_g, - fmap_f_r, - fmap_f_g, - y_ds_hat_r, - y_ds_hat_g, - fmap_s_r, - fmap_s_g, -): - """Hinge Loss""" - loss = 0.0 - loss1 = 0.0 - loss2 = 0.0 - loss3 = 0.0 - for i in range(len(y_disc_r)): - loss1 += F.relu(1 - y_disc_r[i]).mean() + F.relu(1 + y_disc_gen[i]).mean() - for i in range(len(y_df_hat_r)): - loss2 += F.relu(1 - y_df_hat_r[i]).mean() + F.relu(1 + y_df_hat_g[i]).mean() - for i in range(len(y_ds_hat_r)): - loss3 += F.relu(1 - y_ds_hat_r[i]).mean() + F.relu(1 + y_ds_hat_g[i]).mean() - - loss = ( - loss1 / len(y_disc_gen) + loss2 / len(y_df_hat_r) + loss3 / len(y_ds_hat_r) - ) / 3.0 - - return loss - - -def criterion_g( - commit_loss, - x, - G_x, - fmap_r, - fmap_gen, - y_disc_r, - y_disc_gen, - y_df_hat_r, - y_df_hat_g, - fmap_f_r, - fmap_f_g, - y_ds_hat_r, - y_ds_hat_g, - fmap_s_r, - fmap_s_g, - args, -): - adv_g_loss = adversarial_g_loss(y_disc_gen) - feat_loss = ( - feature_loss(fmap_r, fmap_gen) - + sim_loss(y_disc_r, y_disc_gen) - + feature_loss(fmap_f_r, fmap_f_g) - + sim_loss(y_df_hat_r, y_df_hat_g) - + feature_loss(fmap_s_r, fmap_s_g) - + sim_loss(y_ds_hat_r, y_ds_hat_g) - ) / 3.0 - rec_loss = reconstruction_loss(x.contiguous(), G_x.contiguous(), args) - total_loss = ( - args.lambda_com * commit_loss - + args.lambda_adv * adv_g_loss - + args.lambda_feat * feat_loss - + args.lambda_rec * rec_loss - ) - return total_loss, adv_g_loss, feat_loss, rec_loss - - def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value return weight -def adopt_dis_weight(weight, global_step, threshold=0, value=0.0): - if global_step % 3 == 0: - weight = value - return weight - - def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): if last_layer is not None: nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] @@ -166,7 +392,6 @@ def loss_g( fmap_hat, y, y_hat, - global_step, y_df, y_df_hat, y_ds, @@ -215,9 +440,10 @@ def loss_g( feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0 d_weight = torch.tensor(1.0) - disc_factor = adopt_weight( - args.lambda_adv, global_step, threshold=args.discriminator_iter_start - ) + # disc_factor = adopt_weight( + # args.lambda_adv, global_step, threshold=args.discriminator_iter_start + # ) + disc_factor = 1 if disc_factor == 0.0: fm_loss_wt = 0 else: @@ -232,37 +458,9 @@ def loss_g( return loss, rec_loss, adv_loss, feat_loss_tot, d_weight -def loss_dis( - y, - y_hat, - fmap, - fmap_hat, - y_df, - y_df_hat, - fmap_f, - fmap_f_hat, - y_ds, - y_ds_hat, - fmap_s, - fmap_s_hat, - global_step, - args, -): - disc_factor = adopt_weight( - args.lambda_adv, global_step, threshold=args.discriminator_iter_start - ) - d_loss = disc_factor * criterion_d( - y, - y_hat, - fmap, - fmap_hat, - y_df, - y_df_hat, - fmap_f, - fmap_f_hat, - y_ds, - y_ds_hat, - fmap_s, - fmap_s_hat, - ) - return d_loss +if __name__ == "__main__": + la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False) + aa = [torch.rand(192, 192) for _ in range(3)] + bb = [torch.rand(192, 192) for _ in range(3)] + print(la(bb, aa)) + print(feature_loss(aa, bb)) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 65aec13831..206a72a760 100644 --- 
a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -15,12 +15,13 @@ from encodec import Encodec from lhotse.cut import Cut from lhotse.utils import fix_random_seed +from loss import adopt_weight from torch import nn from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter -from utils import MetricsTracker, plot_feature, save_checkpoint +from utils import MetricsTracker, save_checkpoint from icefall import diagnostics from icefall.checkpoint import load_checkpoint @@ -250,11 +251,26 @@ def get_model(params: AttributeDict) -> nn.Module: from modules.seanet import SEANetDecoder, SEANetEncoder from quantization import ResidualVectorQuantizer + # generator_params = { + # "generator_n_filters": 32, + # "dimension": 512, + # "ratios": [2, 2, 2, 4], + # "target_bandwidths": [7.5, 15], + # "bins": 1024, + # } + # discriminator_params = { + # "stft_discriminator_n_filters": 32, + # "discriminator_iter_start": 500, + # } + # inference_params = { + # "target_bw": 7.5, + # } + generator_params = { "generator_n_filters": 32, "dimension": 512, - "ratios": [2, 2, 2, 4], - "target_bandwidths": [7.5, 15], + "ratios": [8, 5, 4, 2], + "target_bandwidths": [1.5, 3, 6, 12, 24], "bins": 1024, } discriminator_params = { @@ -262,7 +278,7 @@ def get_model(params: AttributeDict) -> nn.Module: "discriminator_iter_start": 500, } inference_params = { - "target_bw": 7.5, + "target_bw": 12, } params.update(generator_params) @@ -419,36 +435,93 @@ def save_bad_model(suffix: str = ""): try: with autocast(enabled=params.use_fp16): + d_weight = adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, + ) # forward discriminator - loss_d, stats_d = model( + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats_d, + ) = model( speech=audio, speech_lengths=audio_lens, - global_step=params.batch_idx_train, return_sample=False, forward_generator=False, ) + disc_loss = ( + ( + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) + * d_weight + / 3 + ) for k, v in stats_d.items(): loss_info[k] = v * batch_size # update discriminator optimizer_d.zero_grad() - scaler.scale(loss_d).backward() + scaler.scale(disc_loss).backward() scaler.step(optimizer_d) with autocast(enabled=params.use_fp16): + g_weight = adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, + ) # forward generator - loss_g, stats_g = model( + ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats_g, + ) = model( speech=audio, speech_lengths=audio_lens, - global_step=params.batch_idx_train, forward_generator=True, return_sample=params.batch_idx_train % params.log_interval == 0, ) + gen_adv_loss = ( + (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) + * g_weight + / 3 + ) + feature_loss = ( + feature_stft_loss + feature_period_loss + feature_scale_loss + ) / 3 + reconstruction_loss = ( + params.lambda_wav * wav_reconstruction_loss + + mel_reconstruction_loss + ) + gen_loss = ( + gen_adv_loss + + 
params.lambda_rec * reconstruction_loss + + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + + params.lambda_com * commit_loss + ) for k, v in stats_g.items(): if "returned_sample" not in k: loss_info[k] = v * batch_size # update generator optimizer_g.zero_grad() - scaler.scale(loss_g).backward() + scaler.scale(gen_loss).backward() scaler.step(optimizer_g) scaler.update() @@ -619,27 +692,84 @@ def compute_validation_loss( loss_info = MetricsTracker() loss_info["samples"] = batch_size + d_weight = adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, + ) + # forward discriminator - loss_d, stats_d = model( + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats_d, + ) = model( speech=audio, speech_lengths=audio_lens, - global_step=params.batch_idx_train, return_sample=False, forward_generator=False, ) - assert loss_d.requires_grad is False + disc_loss = ( + ( + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) + * d_weight + / 3 + ) + assert disc_loss.requires_grad is False for k, v in stats_d.items(): loss_info[k] = v * batch_size + g_weight = adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, + ) # forward generator - loss_g, stats_g = model( + ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats_g, + ) = model( speech=audio, speech_lengths=audio_lens, - global_step=params.batch_idx_train, forward_generator=True, return_sample=False, ) - assert loss_g.requires_grad is False + gen_adv_loss = ( + (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) + * g_weight + / 3 + ) + feature_loss = ( + feature_stft_loss + feature_period_loss + feature_scale_loss + ) / 3 + reconstruction_loss = ( + params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss + ) + gen_loss = ( + gen_adv_loss + + params.lambda_rec * reconstruction_loss + + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + + params.lambda_com * commit_loss + ) + assert gen_loss.requires_grad is False for k, v in stats_g.items(): if "returned_sample" not in k: loss_info[k] = v * batch_size @@ -691,24 +821,74 @@ def scan_pessimistic_batches_for_oom( try: # for discriminator with autocast(enabled=params.use_fp16): - loss_d, stats_d = model( + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats_d, + ) = model( speech=audio, speech_lengths=audio_lens, - global_step=params.batch_idx_train, return_sample=False, forward_generator=False, ) + loss_d = ( + ( + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) + * adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, + ) + / 3 + ) optimizer_d.zero_grad() loss_d.backward() # for generator with autocast(enabled=params.use_fp16): - loss_g, stats_g = model( + ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + 
feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats_g, + ) = model( speech=audio, speech_lengths=audio_lens, forward_generator=True, - global_step=params.batch_idx_train, return_sample=False, ) + loss_g = ( + (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) + * adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, + ) + / 3 + + params.lambda_rec + * ( + params.lambda_wav * wav_reconstruction_loss + + mel_reconstruction_loss + ) + + params.lambda_feat + * (feature_stft_loss + feature_period_loss + feature_scale_loss) + + params.lambda_com * commit_loss + ) optimizer_g.zero_grad() loss_g.backward() except Exception as e: From e788bb4853c7455020a829b1476426aeb189bc11 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 6 Oct 2024 13:38:05 +0800 Subject: [PATCH 15/33] making MSD and MPD optional --- egs/libritts/CODEC/encodec/encodec.py | 100 +++++++++++++++--------- egs/libritts/CODEC/encodec/train.py | 108 ++++++++++++-------------- 2 files changed, 115 insertions(+), 93 deletions(-) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 4701423922..a2e540dcd1 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -1,6 +1,6 @@ import math import random -from typing import List +from typing import List, Optional import numpy as np import torch @@ -25,8 +25,8 @@ def __init__( quantizer: nn.Module, decoder: nn.Module, multi_scale_discriminator: nn.Module, - multi_period_discriminator: nn.Module, - multi_scale_stft_discriminator: nn.Module, + multi_period_discriminator: Optional[nn.Module] = None, + multi_scale_stft_discriminator: Optional[nn.Module] = None, cache_generator_outputs: bool = False, ): super(Encodec, self).__init__() @@ -113,28 +113,42 @@ def _forward_generator( with torch.no_grad(): # do not store discriminator gradient in generator turn y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) - y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( - speech.contiguous(), - speech_hat.contiguous(), - ) - y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( - speech.contiguous(), - speech_hat.contiguous(), - ) + + gen_period_adv_loss = torch.tensor(0.0) + feature_period_loss = torch.tensor(0.0) + if self.multi_period_discriminator is not None: + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) + + gen_scale_adv_loss = torch.tensor(0.0) + feature_scale_loss = torch.tensor(0.0) + if self.multi_scale_discriminator is not None: + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) # calculate losses with autocast(enabled=False): gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) - gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat) - gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) + + if self.multi_period_discriminator is not None: + gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat) + if self.multi_scale_discriminator is not None: + gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat) - feature_period_loss = self.feature_match_loss( - feats=fmap_p, feats_hat=fmap_p_hat - ) - feature_scale_loss = self.feature_match_loss( - 
feats=fmap_s, feats_hat=fmap_s_hat - ) + + if self.multi_period_discriminator is not None: + feature_period_loss = self.feature_match_loss( + feats=fmap_p, feats_hat=fmap_p_hat + ) + if self.multi_scale_discriminator is not None: + feature_scale_loss = self.feature_match_loss( + feats=fmap_s, feats_hat=fmap_s_hat + ) wav_reconstruction_loss = self.wav_reconstruction_loss( x=speech, x_hat=speech_hat @@ -245,28 +259,44 @@ def _forward_discriminator( y_hat, fmap_hat = self.multi_scale_stft_discriminator( speech_hat.contiguous().detach() ) - y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( - speech.contiguous(), - speech_hat.contiguous().detach(), - ) - y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( - speech.contiguous(), - speech_hat.contiguous().detach(), - ) + + disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor( + 0.0 + ), torch.tensor(0.0) + if self.multi_period_discriminator is not None: + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) + + disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor( + 0.0 + ), torch.tensor(0.0) + if self.multi_scale_discriminator is not None: + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) # calculate losses with autocast(enabled=False): ( disc_stft_real_adv_loss, disc_stft_fake_adv_loss, ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat) - ( - disc_period_real_adv_loss, - disc_period_fake_adv_loss, - ) = self.discriminator_adversarial_loss(outputs=y_p, outputs_hat=y_p_hat) - ( - disc_scale_real_adv_loss, - disc_scale_fake_adv_loss, - ) = self.discriminator_adversarial_loss(outputs=y_s, outputs_hat=y_s_hat) + if self.multi_period_discriminator is not None: + ( + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + ) = self.discriminator_adversarial_loss( + outputs=y_p, outputs_hat=y_p_hat + ) + if self.multi_scale_discriminator is not None: + ( + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + ) = self.discriminator_adversarial_loss( + outputs=y_s, outputs_hat=y_s_hat + ) stats = dict( discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(), diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 206a72a760..0adffb6587 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -313,8 +313,8 @@ def get_model(params: AttributeDict) -> nn.Module: encoder=encoder, quantizer=quantizer, decoder=decoder, - multi_scale_discriminator=MultiScaleDiscriminator(), - multi_period_discriminator=MultiPeriodDiscriminator(), + multi_scale_discriminator=None, + multi_period_discriminator=None, multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( n_filters=params.stft_discriminator_n_filters ), @@ -456,17 +456,13 @@ def save_bad_model(suffix: str = ""): forward_generator=False, ) disc_loss = ( - ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) - * d_weight - / 3 - ) + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * d_weight for k, v in stats_d.items(): loss_info[k] = v * batch_size # update discriminator @@ -499,13 +495,11 @@ def save_bad_model(suffix: str = ""): 
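# [Editor's annotation, not part of the patch] The hunks in this commit drop the
# extra "/ 3" normalisation now that the scale/period discriminators are optional,
# while the adversarial terms stay gated by adopt_weight(), which returns 0.0
# until the warm-up threshold is reached. A minimal sketch of that gating
# (the weight and step values below are examples only):
#
#     from loss import adopt_weight
#     adopt_weight(3.0, global_step=100, threshold=500)  # -> 0.0, GAN terms disabled
#     adopt_weight(3.0, global_step=800, threshold=500)  # -> 3.0, adversarial training on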
return_sample=params.batch_idx_train % params.log_interval == 0, ) gen_adv_loss = ( - (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) - * g_weight - / 3 - ) + gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss + ) * g_weight feature_loss = ( feature_stft_loss + feature_period_loss + feature_scale_loss - ) / 3 + ) reconstruction_loss = ( params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss @@ -714,17 +708,13 @@ def compute_validation_loss( forward_generator=False, ) disc_loss = ( - ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) - * d_weight - / 3 - ) + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * d_weight assert disc_loss.requires_grad is False for k, v in stats_d.items(): loss_info[k] = v * batch_size @@ -753,13 +743,9 @@ def compute_validation_loss( return_sample=False, ) gen_adv_loss = ( - (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) - * g_weight - / 3 - ) - feature_loss = ( - feature_stft_loss + feature_period_loss + feature_scale_loss - ) / 3 + gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss + ) * g_weight + feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss reconstruction_loss = ( params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss ) @@ -836,20 +822,16 @@ def scan_pessimistic_batches_for_oom( forward_generator=False, ) loss_d = ( - ( - disc_stft_real_adv_loss - + disc_stft_fake_adv_loss - + disc_period_real_adv_loss - + disc_period_fake_adv_loss - + disc_scale_real_adv_loss - + disc_scale_fake_adv_loss - ) - * adopt_weight( - params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, - ) - / 3 + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * adopt_weight( + params.lambda_adv, + params.batch_idx_train, + threshold=params.discriminator_iter_start, ) optimizer_d.zero_grad() loss_d.backward() @@ -879,7 +861,6 @@ def scan_pessimistic_batches_for_oom( params.batch_idx_train, threshold=params.discriminator_iter_start, ) - / 3 + params.lambda_rec * ( params.lambda_wav * wav_reconstruction_loss @@ -962,9 +943,17 @@ def run(rank, world_size, args): logging.info(f"Number of parameters in decoder: {num_param_d}") num_param_q = sum([p.numel() for p in quantizer.parameters()]) logging.info(f"Number of parameters in quantizer: {num_param_q}") - num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) + num_param_ds = ( + sum([p.numel() for p in multi_scale_discriminator.parameters()]) + if multi_scale_discriminator is not None + else 0 + ) logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") - num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) + num_param_dp = ( + sum([p.numel() for p in multi_period_discriminator.parameters()]) + if multi_period_discriminator is not None + else 0 + ) logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") num_param_dstft = sum( [p.numel() for p in multi_scale_stft_discriminator.parameters()] @@ -998,12 +987,15 @@ def run(rank, world_size, args): lr=params.lr, betas=(0.5, 0.9), ) + discriminator_params = [ + 
multi_scale_stft_discriminator.parameters(), + ] + if multi_scale_discriminator is not None: + discriminator_params.append(multi_scale_discriminator.parameters()) + if multi_period_discriminator is not None: + discriminator_params.append(multi_period_discriminator.parameters()) optimizer_d = torch.optim.AdamW( - itertools.chain( - multi_scale_stft_discriminator.parameters(), - multi_scale_discriminator.parameters(), - multi_period_discriminator.parameters(), - ), + itertools.chain(*discriminator_params), lr=params.lr, betas=(0.5, 0.9), ) From d83ce89fcac3f108cdd0b8955ac042b91514b38c Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 6 Oct 2024 15:55:49 +0800 Subject: [PATCH 16/33] fixed loss normalization & scaling factors --- egs/libritts/CODEC/encodec/encodec.py | 4 ++-- egs/libritts/CODEC/encodec/loss.py | 27 +++++++++++++++------------ egs/libritts/CODEC/encodec/train.py | 22 +++++++++++----------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index a2e540dcd1..725ce5d019 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -6,7 +6,7 @@ import torch from loss import ( DiscriminatorAdversarialLoss, - FeatureMatchLoss, + FeatureLoss, GeneratorAdversarialLoss, MelSpectrogramReconstructionLoss, WavReconstructionLoss, @@ -60,7 +60,7 @@ def __init__( self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss( average_by_discriminators=True, loss_type="hinge" ) - self.feature_match_loss = FeatureMatchLoss(average_by_layers=False) + self.feature_match_loss = FeatureLoss(average_by_layers=False) self.wav_reconstruction_loss = WavReconstructionLoss() self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss( sampling_rate=self.sampling_rate diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 7e9bf5590d..f4188a3134 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -57,7 +57,7 @@ def forward( else: adv_loss = self.criterion(outputs) - return adv_loss + return adv_loss / len(outputs) def _mse_loss(self, x): return F.mse_loss(x, x.new_ones(x.size())) @@ -129,7 +129,7 @@ def forward( real_loss = self.real_criterion(outputs) fake_loss = self.fake_criterion(outputs_hat) - return real_loss, fake_loss + return real_loss / len(outputs), fake_loss / len(outputs) def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, x.new_ones(x.size())) @@ -144,14 +144,14 @@ def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: return F.relu(x.new_ones(x.size()) + x).mean() -class FeatureMatchLoss(torch.nn.Module): - """Feature matching loss module.""" +class FeatureLoss(torch.nn.Module): + """Feature loss module.""" def __init__( self, average_by_layers: bool = True, average_by_discriminators: bool = True, - include_final_outputs: bool = False, + include_final_outputs: bool = True, ): """Initialize FeatureMatchLoss module. 
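# [Editor's annotation, illustrative only] FeatureLoss compares the intermediate
# feature maps returned by each discriminator for real vs. generated audio, using
# an L1 distance normalised by the magnitude of the real features. A hypothetical
# call with one discriminator and two layers of feature maps (shapes are made up):
#
#     feats     = [[torch.rand(4, 32, 50), torch.rand(4, 64, 25)]]  # real
#     feats_hat = [[torch.rand(4, 32, 50), torch.rand(4, 64, 25)]]  # generated
#     loss = FeatureLoss()(feats_hat=feats_hat, feats=feats)        # scalar tensor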
@@ -195,14 +195,16 @@ def forward( feats_hat_ = feats_hat_[:-1] feats_ = feats_[:-1] for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): - feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) + feat_match_loss_ += ( + (feat_hat_ - feat_).abs() / (feat_.abs().mean()) + ).mean() if self.average_by_layers: feat_match_loss_ /= j + 1 feat_match_loss += feat_match_loss_ if self.average_by_discriminators: feat_match_loss /= i + 1 - return feat_match_loss + return feat_match_loss / (len(feats) * len(feats[0])) class MelSpectrogramReconstructionLoss(torch.nn.Module): @@ -231,7 +233,7 @@ def __init__( self.wav_to_specs.append( MelSpectrogram( sample_rate=sampling_rate, - n_fft=max(s, 512), + n_fft=s, win_length=s, hop_length=s // 4, n_mels=n_mels, @@ -266,8 +268,9 @@ def forward( mel_hat = wav_to_spec(x_hat.squeeze(1)) mel = wav_to_spec(x.squeeze(1)) - alpha = (s / 2) ** 0.5 - mel_loss += F.l1_loss(mel_hat, mel) + alpha * F.mse_loss(mel_hat, mel) + mel_loss += F.l1_loss( + mel_hat, mel, reduce=True, reduction="mean" + ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean") # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) # mel = self.wav_to_spec(x.squeeze(1)) @@ -300,7 +303,7 @@ def forward( Tensor: Wav loss value. """ - wav_loss = F.mse_loss(x, x_hat) + wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean") return wav_loss @@ -459,7 +462,7 @@ def loss_g( if __name__ == "__main__": - la = FeatureMatchLoss(average_by_layers=False, average_by_discriminators=False) + la = FeatureLoss(average_by_layers=False, average_by_discriminators=False) aa = [torch.rand(192, 192) for _ in range(3)] bb = [torch.rand(192, 192) for _ in range(3)] print(la(bb, aa)) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 0adffb6587..5b21c81dd2 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -187,11 +187,11 @@ def get_params() -> AttributeDict: "env_info": get_env_info(), "sampling_rate": 24000, "chunk_size": 1.0, # in seconds - "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss - "lambda_wav": 100.0, # loss scaling coefficient for waveform loss - "lambda_feat": 1.0, # loss scaling coefficient for feat loss + "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss + "lambda_wav": 0.1, # loss scaling coefficient for waveform loss + "lambda_feat": 3.0, # loss scaling coefficient for feat loss "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss - "lambda_com": 1000.0, # loss scaling coefficient for commitment loss + "lambda_com": 1.0, # loss scaling coefficient for commitment loss } ) @@ -502,11 +502,11 @@ def save_bad_model(suffix: str = ""): ) reconstruction_loss = ( params.lambda_wav * wav_reconstruction_loss - + mel_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss ) gen_loss = ( gen_adv_loss - + params.lambda_rec * reconstruction_loss + + reconstruction_loss + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + params.lambda_com * commit_loss ) @@ -747,11 +747,12 @@ def compute_validation_loss( ) * g_weight feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss reconstruction_loss = ( - params.lambda_wav * wav_reconstruction_loss + mel_reconstruction_loss + params.lambda_wav * wav_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss ) gen_loss = ( gen_adv_loss - + params.lambda_rec * reconstruction_loss + + reconstruction_loss + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + 
params.lambda_com * commit_loss ) @@ -861,10 +862,9 @@ def scan_pessimistic_batches_for_oom( params.batch_idx_train, threshold=params.discriminator_iter_start, ) - + params.lambda_rec - * ( + + ( params.lambda_wav * wav_reconstruction_loss - + mel_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss ) + params.lambda_feat * (feature_stft_loss + feature_period_loss + feature_scale_loss) From 58f656282424c568bcb6c543fa6e81c75a6303c3 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 6 Oct 2024 19:07:07 +0800 Subject: [PATCH 17/33] added scheduler w/ warmup --- egs/libritts/CODEC/encodec/encodec.py | 2 +- egs/libritts/CODEC/encodec/loss.py | 43 +++++--- egs/libritts/CODEC/encodec/scheduler.py | 141 ++++++++++++++++++++++++ egs/libritts/CODEC/encodec/train.py | 61 ++++++---- 4 files changed, 209 insertions(+), 38 deletions(-) create mode 100644 egs/libritts/CODEC/encodec/scheduler.py diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 725ce5d019..aa0373bfab 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -60,7 +60,7 @@ def __init__( self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss( average_by_discriminators=True, loss_type="hinge" ) - self.feature_match_loss = FeatureLoss(average_by_layers=False) + self.feature_match_loss = FeatureLoss() self.wav_reconstruction_loss = WavReconstructionLoss() self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss( sampling_rate=self.sampling_rate diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index f4188a3134..a4e0ec06de 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -45,8 +45,8 @@ def forward( Tensor: Generator adversarial loss value. """ + adv_loss = 0.0 if isinstance(outputs, (tuple, list)): - adv_loss = 0.0 for i, outputs_ in enumerate(outputs): if isinstance(outputs_, (tuple, list)): # NOTE(kan-bayashi): case including feature maps @@ -55,9 +55,10 @@ def forward( if self.average_by_discriminators: adv_loss /= i + 1 else: - adv_loss = self.criterion(outputs) - - return adv_loss / len(outputs) + for i, outputs_ in enumerate(outputs): + adv_loss += self.criterion(outputs_) + adv_loss /= i + 1 + return adv_loss def _mse_loss(self, x): return F.mse_loss(x, x.new_ones(x.size())) @@ -112,9 +113,9 @@ def forward( Tensor: Discriminator fake loss value. 
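            Note (editor's illustrative addition, not part of this patch): with
            ``loss_type="hinge"`` the two terms reduce to
            ``F.relu(1 - outputs).mean()`` for real inputs and
            ``F.relu(1 + outputs_hat).mean()`` for generated inputs, i.e. real
            scores are pushed above +1 and fake scores below -1, averaged over
            the discriminator outputs.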
""" + real_loss = 0.0 + fake_loss = 0.0 if isinstance(outputs, (tuple, list)): - real_loss = 0.0 - fake_loss = 0.0 for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): if isinstance(outputs_hat_, (tuple, list)): # NOTE(kan-bayashi): case including feature maps @@ -126,10 +127,13 @@ def forward( fake_loss /= i + 1 real_loss /= i + 1 else: - real_loss = self.real_criterion(outputs) - fake_loss = self.fake_criterion(outputs_hat) + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + fake_loss /= i + 1 + real_loss /= i + 1 - return real_loss / len(outputs), fake_loss / len(outputs) + return real_loss, fake_loss def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, x.new_ones(x.size())) @@ -204,7 +208,7 @@ def forward( if self.average_by_discriminators: feat_match_loss /= i + 1 - return feat_match_loss / (len(feats) * len(feats[0])) + return feat_match_loss class MelSpectrogramReconstructionLoss(torch.nn.Module): @@ -233,7 +237,7 @@ def __init__( self.wav_to_specs.append( MelSpectrogram( sample_rate=sampling_rate, - n_fft=s, + n_fft=max(s, 512), win_length=s, hop_length=s // 4, n_mels=n_mels, @@ -462,8 +466,15 @@ def loss_g( if __name__ == "__main__": - la = FeatureLoss(average_by_layers=False, average_by_discriminators=False) - aa = [torch.rand(192, 192) for _ in range(3)] - bb = [torch.rand(192, 192) for _ in range(3)] - print(la(bb, aa)) - print(feature_loss(aa, bb)) + # la = FeatureLoss(average_by_layers=True, average_by_discriminators=True) + # aa = [torch.rand(192, 192) for _ in range(3)] + # bb = [torch.rand(192, 192) for _ in range(3)] + # print(la(bb, aa)) + # print(feature_loss(aa, bb)) + la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge") + aa = torch.Tensor([0.1, 0.2, 0.3, 0.4]) + bb = torch.Tensor([0.4, 0.3, 0.2, 0.1]) + print(la(aa)) + print(adversarial_g_loss(aa)) + print(la(bb)) + print(adversarial_g_loss(bb)) diff --git a/egs/libritts/CODEC/encodec/scheduler.py b/egs/libritts/CODEC/encodec/scheduler.py new file mode 100644 index 0000000000..1a62e96f29 --- /dev/null +++ b/egs/libritts/CODEC/encodec/scheduler.py @@ -0,0 +1,141 @@ +import math +from bisect import bisect_right + +from torch.optim.lr_scheduler import _LRScheduler + + +class WarmupLrScheduler(_LRScheduler): + def __init__( + self, + optimizer, + warmup_epoch=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.warmup_epoch = warmup_epoch + self.warmup_ratio = warmup_ratio + self.warmup = warmup + super(WarmupLrScheduler, self).__init__(optimizer, last_epoch) + + def get_lr(self): + ratio = self.get_lr_ratio() + lrs = [ratio * lr for lr in self.base_lrs] + return lrs + + def get_lr_ratio(self): + if self.last_epoch < self.warmup_epoch: + ratio = self.get_warmup_ratio() + else: + ratio = self.get_main_ratio() + return ratio + + def get_main_ratio(self): + raise NotImplementedError + + def get_warmup_ratio(self): + assert self.warmup in ("linear", "exp") + alpha = self.last_epoch / self.warmup_epoch + if self.warmup == "linear": + ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha + elif self.warmup == "exp": + ratio = self.warmup_ratio ** (1.0 - alpha) + return ratio + + +class WarmupPolyLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + power, + max_epoch, + warmup_epoch=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.power = power + self.max_epoch = max_epoch + 
super(WarmupPolyLrScheduler, self).__init__( + optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_epoch = self.last_epoch - self.warmup_epoch + real_max_epoch = self.max_epoch - self.warmup_epoch + alpha = real_epoch / real_max_epoch + ratio = (1 - alpha) ** self.power + return ratio + + +class WarmupExpLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + gamma, + interval=1, + warmup_epoch=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.gamma = gamma + self.interval = interval + super(WarmupExpLrScheduler, self).__init__( + optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_epoch = self.last_epoch - self.warmup_epoch + ratio = self.gamma ** (real_epoch // self.interval) + return ratio + + +class WarmupCosineLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + max_epoch, + eta_ratio=0, + warmup_epoch=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.eta_ratio = eta_ratio + self.max_epoch = max_epoch + super(WarmupCosineLrScheduler, self).__init__( + optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_max_epoch = self.max_epoch - self.warmup_epoch + return ( + self.eta_ratio + + (1 - self.eta_ratio) + * (1 + math.cos(math.pi * self.last_epoch / real_max_epoch)) + / 2 + ) + + +class WarmupStepLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + milestones: list, + gamma=0.1, + warmup_epoch=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.milestones = milestones + self.gamma = gamma + super(WarmupStepLrScheduler, self).__init__( + optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_epoch = self.last_epoch - self.warmup_epoch + ratio = self.gamma ** bisect_right(self.milestones, real_epoch) + return ratio diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 5b21c81dd2..0c761b8edc 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -16,6 +16,7 @@ from lhotse.cut import Cut from lhotse.utils import fix_random_seed from loss import adopt_weight +from scheduler import WarmupCosineLrScheduler from torch import nn from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP @@ -188,10 +189,10 @@ def get_params() -> AttributeDict: "sampling_rate": 24000, "chunk_size": 1.0, # in seconds "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss - "lambda_wav": 0.1, # loss scaling coefficient for waveform loss + "lambda_wav": 1.0, # loss scaling coefficient for waveform loss "lambda_feat": 3.0, # loss scaling coefficient for feat loss "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss - "lambda_com": 1.0, # loss scaling coefficient for commitment loss + "lambda_com": 100.0, # loss scaling coefficient for commitment loss } ) @@ -260,7 +261,7 @@ def get_model(params: AttributeDict) -> nn.Module: # } # discriminator_params = { # "stft_discriminator_n_filters": 32, - # "discriminator_iter_start": 500, + # "discriminator_epoch_start": 5, # } # inference_params = { # "target_bw": 7.5, @@ -275,7 +276,10 @@ def get_model(params: AttributeDict) -> nn.Module: } discriminator_params = { "stft_discriminator_n_filters": 32, - "discriminator_iter_start": 500, + "discriminator_epoch_start": 3, + "n_ffts": [1024, 2048, 512], + "hop_lengths": [256, 512, 128], + 
"win_lengths": [1024, 2048, 512], } inference_params = { "target_bw": 12, @@ -316,7 +320,10 @@ def get_model(params: AttributeDict) -> nn.Module: multi_scale_discriminator=None, multi_period_discriminator=None, multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( - n_filters=params.stft_discriminator_n_filters + n_filters=params.stft_discriminator_n_filters, + n_ffts=params.n_ffts, + hop_lengths=params.hop_lengths, + win_lengths=params.win_lengths, ), ) return model @@ -437,8 +444,8 @@ def save_bad_model(suffix: str = ""): with autocast(enabled=params.use_fp16): d_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward discriminator ( @@ -473,8 +480,8 @@ def save_bad_model(suffix: str = ""): with autocast(enabled=params.use_fp16): g_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward generator ( @@ -507,7 +514,7 @@ def save_bad_model(suffix: str = ""): gen_loss = ( gen_adv_loss + reconstruction_loss - + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + + params.lambda_feat * feature_loss + params.lambda_com * commit_loss ) for k, v in stats_g.items(): @@ -688,8 +695,8 @@ def compute_validation_loss( d_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward discriminator @@ -721,8 +728,8 @@ def compute_validation_loss( g_weight = adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_epoch_start, ) # forward generator ( @@ -753,7 +760,7 @@ def compute_validation_loss( gen_loss = ( gen_adv_loss + reconstruction_loss - + (params.lambda_feat if g_weight != 0.0 else 0.0) * feature_loss + + params.lambda_feat * feature_loss + params.lambda_com * commit_loss ) assert gen_loss.requires_grad is False @@ -831,8 +838,8 @@ def scan_pessimistic_batches_for_oom( + disc_scale_fake_adv_loss ) * adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + params.cur_epoch, + threshold=params.discriminator_train_start, ) optimizer_d.zero_grad() loss_d.backward() @@ -859,8 +866,8 @@ def scan_pessimistic_batches_for_oom( (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) * adopt_weight( params.lambda_adv, - params.batch_idx_train, - threshold=params.discriminator_iter_start, + 0, + threshold=params.discriminator_epoch_start, ) + ( params.lambda_wav * wav_reconstruction_loss @@ -1000,8 +1007,20 @@ def run(rank, world_size, args): betas=(0.5, 0.9), ) - scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + scheduler_g = WarmupCosineLrScheduler( + optimizer=optimizer_g, + max_epoch=params.num_epochs, + eta_ratio=0.1, + warmup_epoch=params.discriminator_epoch_start, + warmup_ratio=1e-4, + ) + scheduler_d = WarmupCosineLrScheduler( + optimizer=optimizer_d, + max_epoch=params.num_epochs, + eta_ratio=0.1, + warmup_epoch=params.discriminator_epoch_start, + warmup_ratio=1e-4, + ) if checkpoints is not None: # load state_dict for optimizers From 01cc3076647425ab11d24cad1d8f64479f1fbcd6 Mon Sep 17 00:00:00 2001 From: JinZr Date: 
Mon, 7 Oct 2024 01:03:26 +0800 Subject: [PATCH 18/33] fixed loss functions & scaling factors --- egs/libritts/CODEC/encodec/loss.py | 21 ++++--- egs/libritts/CODEC/encodec/scheduler.py | 75 ++++++++++++++++--------- egs/libritts/CODEC/encodec/train.py | 26 +++++---- 3 files changed, 80 insertions(+), 42 deletions(-) diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index a4e0ec06de..ae1e34bddf 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -142,10 +142,10 @@ def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: return F.mse_loss(x, x.new_zeros(x.size())) def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.relu(x.new_ones(x.size()) - x).mean() + return F.relu(torch.ones_like(x) - x).mean() def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: - return F.relu(x.new_ones(x.size()) + x).mean() + return F.relu(torch.ones_like(x) + x).mean() class FeatureLoss(torch.nn.Module): @@ -200,7 +200,7 @@ def forward( feats_ = feats_[:-1] for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): feat_match_loss_ += ( - (feat_hat_ - feat_).abs() / (feat_.abs().mean()) + F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean()) ).mean() if self.average_by_layers: feat_match_loss_ /= j + 1 @@ -272,9 +272,16 @@ def forward( mel_hat = wav_to_spec(x_hat.squeeze(1)) mel = wav_to_spec(x.squeeze(1)) - mel_loss += F.l1_loss( - mel_hat, mel, reduce=True, reduction="mean" - ) + F.mse_loss(mel_hat, mel, reduce=True, reduction="mean") + mel_loss += ( + F.l1_loss(mel_hat, mel, reduce=True, reduction="mean") + + ( + ( + (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)) + ** 2 + ).mean(dim=-2) + ** 0.5 + ).mean() + ) # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) # mel = self.wav_to_spec(x.squeeze(1)) @@ -307,7 +314,7 @@ def forward( Tensor: Wav loss value. """ - wav_loss = F.l1_loss(x, x_hat, reduce=True, reduction="mean") + wav_loss = F.l1_loss(x, x_hat) return wav_loss diff --git a/egs/libritts/CODEC/encodec/scheduler.py b/egs/libritts/CODEC/encodec/scheduler.py index 1a62e96f29..fb6ba087d6 100644 --- a/egs/libritts/CODEC/encodec/scheduler.py +++ b/egs/libritts/CODEC/encodec/scheduler.py @@ -4,16 +4,40 @@ from torch.optim.lr_scheduler import _LRScheduler +# It will be replaced with huggingface optimization +class WarmUpLR(_LRScheduler): + """warmup_training learning rate scheduler + Args: + optimizer: optimzier(e.g. 
SGD) + total_iters: totoal_iters of warmup phase + """ + + def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1): + + self.total_iters = iter_per_epoch * warmup_epoch + self.iter_per_epoch = iter_per_epoch + super().__init__(optimizer, last_epoch) + + def get_lr(self): + """we will use the first m batches, and set the learning + rate to base_lr * m / total_iters + """ + return [ + base_lr * self.last_epoch / (self.total_iters + 1e-8) + for base_lr in self.base_lrs + ] + + class WarmupLrScheduler(_LRScheduler): def __init__( self, optimizer, - warmup_epoch=500, + warmup_iter=500, warmup_ratio=5e-4, warmup="exp", last_epoch=-1, ): - self.warmup_epoch = warmup_epoch + self.warmup_iter = warmup_iter self.warmup_ratio = warmup_ratio self.warmup = warmup super(WarmupLrScheduler, self).__init__(optimizer, last_epoch) @@ -24,7 +48,7 @@ def get_lr(self): return lrs def get_lr_ratio(self): - if self.last_epoch < self.warmup_epoch: + if self.last_epoch < self.warmup_iter: ratio = self.get_warmup_ratio() else: ratio = self.get_main_ratio() @@ -35,7 +59,7 @@ def get_main_ratio(self): def get_warmup_ratio(self): assert self.warmup in ("linear", "exp") - alpha = self.last_epoch / self.warmup_epoch + alpha = self.last_epoch / self.warmup_iter if self.warmup == "linear": ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha elif self.warmup == "exp": @@ -48,22 +72,22 @@ def __init__( self, optimizer, power, - max_epoch, - warmup_epoch=500, + max_iter, + warmup_iter=500, warmup_ratio=5e-4, warmup="exp", last_epoch=-1, ): self.power = power - self.max_epoch = max_epoch + self.max_iter = max_iter super(WarmupPolyLrScheduler, self).__init__( - optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch ) def get_main_ratio(self): - real_epoch = self.last_epoch - self.warmup_epoch - real_max_epoch = self.max_epoch - self.warmup_epoch - alpha = real_epoch / real_max_epoch + real_iter = self.last_epoch - self.warmup_iter + real_max_iter = self.max_iter - self.warmup_iter + alpha = real_iter / real_max_iter ratio = (1 - alpha) ** self.power return ratio @@ -74,7 +98,7 @@ def __init__( optimizer, gamma, interval=1, - warmup_epoch=500, + warmup_iter=500, warmup_ratio=5e-4, warmup="exp", last_epoch=-1, @@ -82,12 +106,12 @@ def __init__( self.gamma = gamma self.interval = interval super(WarmupExpLrScheduler, self).__init__( - optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch ) def get_main_ratio(self): - real_epoch = self.last_epoch - self.warmup_epoch - ratio = self.gamma ** (real_epoch // self.interval) + real_iter = self.last_epoch - self.warmup_iter + ratio = self.gamma ** (real_iter // self.interval) return ratio @@ -95,25 +119,26 @@ class WarmupCosineLrScheduler(WarmupLrScheduler): def __init__( self, optimizer, - max_epoch, + max_iter, eta_ratio=0, - warmup_epoch=500, + warmup_iter=500, warmup_ratio=5e-4, warmup="exp", last_epoch=-1, ): self.eta_ratio = eta_ratio - self.max_epoch = max_epoch + self.max_iter = max_iter super(WarmupCosineLrScheduler, self).__init__( - optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch ) def get_main_ratio(self): - real_max_epoch = self.max_epoch - self.warmup_epoch + real_iter = self.last_epoch - self.warmup_iter + real_max_iter = self.max_iter - self.warmup_iter return ( self.eta_ratio + (1 - self.eta_ratio) - * (1 + math.cos(math.pi * self.last_epoch / 
real_max_epoch)) + * (1 + math.cos(math.pi * self.last_epoch / real_max_iter)) / 2 ) @@ -124,7 +149,7 @@ def __init__( optimizer, milestones: list, gamma=0.1, - warmup_epoch=500, + warmup_iter=500, warmup_ratio=5e-4, warmup="exp", last_epoch=-1, @@ -132,10 +157,10 @@ def __init__( self.milestones = milestones self.gamma = gamma super(WarmupStepLrScheduler, self).__init__( - optimizer, warmup_epoch, warmup_ratio, warmup, last_epoch + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch ) def get_main_ratio(self): - real_epoch = self.last_epoch - self.warmup_epoch - ratio = self.gamma ** bisect_right(self.milestones, real_epoch) + real_iter = self.last_epoch - self.warmup_iter + ratio = self.gamma ** bisect_right(self.milestones, real_iter) return ratio diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 0c761b8edc..088dbc5774 100644 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -187,6 +187,7 @@ def get_params() -> AttributeDict: "valid_interval": 200, "env_info": get_env_info(), "sampling_rate": 24000, + "audio_normalization": False, "chunk_size": 1.0, # in seconds "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss "lambda_wav": 1.0, # loss scaling coefficient for waveform loss @@ -276,13 +277,13 @@ def get_model(params: AttributeDict) -> nn.Module: } discriminator_params = { "stft_discriminator_n_filters": 32, - "discriminator_epoch_start": 3, + "discriminator_epoch_start": 5, "n_ffts": [1024, 2048, 512], "hop_lengths": [256, 512, 128], "win_lengths": [1024, 2048, 512], } inference_params = { - "target_bw": 12, + "target_bw": 6, } params.update(generator_params) @@ -353,6 +354,11 @@ def prepare_input( :, params.sampling_rate : params.sampling_rate + params.sampling_rate ] + if params.audio_normalization: + mean = audio.mean(dim=-1, keepdim=True) + std = audio.std(dim=-1, keepdim=True) + audio = (audio - mean) / (std + 1e-7) + return audio, audio_lens, features, features_lens @@ -532,6 +538,10 @@ def save_bad_model(suffix: str = ""): save_bad_model() raise + # step per iteration + scheduler_g.step() + scheduler_d.step() + if params.print_diagnostics and batch_idx == 5: return @@ -1009,16 +1019,16 @@ def run(rank, world_size, args): scheduler_g = WarmupCosineLrScheduler( optimizer=optimizer_g, - max_epoch=params.num_epochs, + max_iter=params.num_epochs * 1500, eta_ratio=0.1, - warmup_epoch=params.discriminator_epoch_start, + warmup_iter=params.discriminator_epoch_start * 1500, warmup_ratio=1e-4, ) scheduler_d = WarmupCosineLrScheduler( optimizer=optimizer_d, - max_epoch=params.num_epochs, + max_iter=params.num_epochs * 1500, eta_ratio=0.1, - warmup_epoch=params.discriminator_epoch_start, + warmup_iter=params.discriminator_epoch_start * 1500, warmup_ratio=1e-4, ) @@ -1128,10 +1138,6 @@ def run(rank, world_size, args): best_valid_filename = params.exp_dir / "best-valid-loss.pt" copyfile(src=filename, dst=best_valid_filename) - # step per epoch - scheduler_g.step() - scheduler_d.step() - logging.info("Done!") if world_size > 1: From b65eba2d1578623afe11170939e18358bfe2fbf0 Mon Sep 17 00:00:00 2001 From: JinZr Date: Mon, 7 Oct 2024 09:53:09 +0800 Subject: [PATCH 19/33] fixed script for inference --- egs/libritts/CODEC/encodec/infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py index e5d69fa600..6be869534b 100755 --- a/egs/libritts/CODEC/encodec/infer.py +++ b/egs/libritts/CODEC/encodec/infer.py 
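# [Editor's annotation, illustrative only; refers to the scheduler change above]
# The warmup-cosine schedule is now stepped once per training iteration rather
# than once per epoch, and the "* 1500" factor in train.py appears to be a
# hard-coded iterations-per-epoch estimate. A minimal usage sketch with
# placeholder model, learning rate and epoch counts:
#
#     import torch
#     from scheduler import WarmupCosineLrScheduler
#
#     model = torch.nn.Linear(10, 10)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.5, 0.9))
#     scheduler = WarmupCosineLrScheduler(
#         optimizer=optimizer, max_iter=20 * 1500, eta_ratio=0.1,
#         warmup_iter=5 * 1500, warmup_ratio=1e-4,
#     )
#     for step in range(20 * 1500):
#         optimizer.step()
#         scheduler.step()  # per-iteration step, matching the train.py change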
@@ -289,9 +289,9 @@ def main(): logging.info(f"Number of parameters in decoder: {num_param_d}") num_param_q = sum([p.numel() for p in quantizer.parameters()]) logging.info(f"Number of parameters in quantizer: {num_param_q}") - num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) + num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) if multi_scale_discriminator is not None else 0 logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") - num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) + num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) if multi_period_discriminator is not None else 0 logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") num_param_dstft = sum( [p.numel() for p in multi_scale_stft_discriminator.parameters()] From 266e8404754542e5ee6ac5c7c3b339e7c5b5c0df Mon Sep 17 00:00:00 2001 From: JinZr Date: Mon, 7 Oct 2024 16:10:13 +0800 Subject: [PATCH 20/33] fixed ``+x`` permission --- egs/libritts/CODEC/encodec/train.py | 1 - 1 file changed, 1 deletion(-) mode change 100644 => 100755 egs/libritts/CODEC/encodec/train.py diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py old mode 100644 new mode 100755 index 088dbc5774..8475ab6e86 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -13,7 +13,6 @@ import torch.nn as nn from codec_datamodule import LibriTTSCodecDataModule from encodec import Encodec -from lhotse.cut import Cut from lhotse.utils import fix_random_seed from loss import adopt_weight from scheduler import WarmupCosineLrScheduler From 32a7d2222d53070cc9cd9b26727ef1a8622d17d5 Mon Sep 17 00:00:00 2001 From: JinZr Date: Mon, 7 Oct 2024 20:55:53 +0800 Subject: [PATCH 21/33] minor updates to the scripts --- .../ASR/zipformer/attention_decoder.py | 26 +- .../ASR/zipformer/export-onnx-streaming.py | 12 +- egs/librispeech/ASR/zipformer/export-onnx.py | 8 +- .../ASR/local/compute_fbank_libritts.py | 2 +- .../convert_transcript_words_to_tokens.py | 1 + egs/libritts/ASR/local/download_lm.py | 1 + egs/libritts/ASR/local/norm_text.py | 1 + egs/libritts/ASR/local/prepare_lang.py | 1 + egs/libritts/ASR/local/prepare_lang_bpe.py | 1 + egs/libritts/ASR/local/prepare_lang_fst.py | 1 + .../ASR/local/prepare_lm_training_data.py | 1 + egs/libritts/ASR/local/train_bpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/libritts/ASR/prepare.sh | 87 +++++- egs/libritts/ASR/prepare_lm.sh | 264 ++++++++++++++++++ egs/libritts/ASR/zipformer/decode.py | 10 +- .../ASR/zipformer/streaming_decode.py | 6 +- egs/libritts/ASR/zipformer/train.py | 35 ++- egs/libritts/CODEC/prepare.sh | 9 - 19 files changed, 422 insertions(+), 46 deletions(-) create mode 120000 egs/libritts/ASR/local/convert_transcript_words_to_tokens.py create mode 120000 egs/libritts/ASR/local/download_lm.py create mode 120000 egs/libritts/ASR/local/norm_text.py create mode 120000 egs/libritts/ASR/local/prepare_lang.py create mode 120000 egs/libritts/ASR/local/prepare_lang_bpe.py create mode 120000 egs/libritts/ASR/local/prepare_lang_fst.py create mode 120000 egs/libritts/ASR/local/prepare_lm_training_data.py create mode 120000 egs/libritts/ASR/local/train_bpe_model.py create mode 120000 egs/libritts/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/libritts/ASR/prepare_lm.sh diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py 
b/egs/librispeech/ASR/zipformer/attention_decoder.py index 81682e87b5..bff536f90b 100644 --- a/egs/librispeech/ASR/zipformer/attention_decoder.py +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -236,7 +236,7 @@ def forward( causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len) attn_mask = torch.logical_or( padding_mask.unsqueeze(1), # (batch, 1, seq_len) - torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len) + torch.logical_not(causal_mask).unsqueeze(0), # (1, seq_len, seq_len) ) # (batch, seq_len, seq_len) if memory is not None: @@ -367,7 +367,9 @@ def __init__( self.num_heads = num_heads self.head_dim = attention_dim // num_heads assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, num_heads, attention_dim + self.head_dim, + num_heads, + attention_dim, ) self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. @@ -437,15 +439,19 @@ def forward( if key_padding_mask is not None: assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), ) if attn_mask is not None: - assert ( - attn_mask.shape == (batch, 1, src_len) - or attn_mask.shape == (batch, tgt_len, src_len) + assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == ( + batch, + tgt_len, + src_len, ), attn_mask.shape - attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf")) + attn_weights = attn_weights.masked_fill( + attn_mask.unsqueeze(1), float("-inf") + ) attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -456,7 +462,11 @@ def forward( # (batch * head, tgt_len, head_dim) attn_output = torch.bmm(attn_weights, v) - assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape + assert attn_output.shape == ( + batch * num_heads, + tgt_len, + head_dim, + ), attn_output.shape attn_output = attn_output.transpose(0, 1).contiguous() attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 88c58f5818..a35eb52877 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -487,6 +487,7 @@ def build_inputs_outputs(tensors, i): add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, decoder_filename: str, @@ -754,30 +755,31 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") - if(params.fp16) : + if params.fp16: from onnxconverter_common import float16 + logging.info("Generate fp16 models") encoder = onnx.load(encoder_filename) encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16,encoder_filename_fp16) + onnx.save(encoder_fp16, encoder_filename_fp16) decoder = onnx.load(decoder_filename) decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16,decoder_filename_fp16) + onnx.save(decoder_fp16, decoder_filename_fp16) joiner = onnx.load(joiner_filename) joiner_fp16 = 
float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16,joiner_filename_fp16) + onnx.save(joiner_fp16, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection logging.info("Generate int8 quantization models") - + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" quantize_dynamic( model_input=encoder_filename, diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index ca3cbf0d59..a56a7a3e67 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -592,23 +592,23 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") - if(params.fp16) : + if params.fp16: logging.info("Generate fp16 models") encoder = onnx.load(encoder_filename) encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16,encoder_filename_fp16) + onnx.save(encoder_fp16, encoder_filename_fp16) decoder = onnx.load(decoder_filename) decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16,decoder_filename_fp16) + onnx.save(decoder_fp16, decoder_filename_fp16) joiner = onnx.load(joiner_filename) joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16,joiner_filename_fp16) + onnx.save(joiner_fp16, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection diff --git a/egs/libritts/ASR/local/compute_fbank_libritts.py b/egs/libritts/ASR/local/compute_fbank_libritts.py index 5e78af18b1..b6e2a4c436 100755 --- a/egs/libritts/ASR/local/compute_fbank_libritts.py +++ b/egs/libritts/ASR/local/compute_fbank_libritts.py @@ -124,7 +124,7 @@ def compute_fbank_libritts( supervisions=m["supervisions"], ) if sampling_rate != 24000: - logging.info(f"Resampling audio to {sampling_rate}") + logging.info(f"Resampling audio to {sampling_rate}Hz") cut_set = cut_set.resample(sampling_rate) if "train" in partition: if perturb_speed: diff --git a/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py b/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py new file mode 120000 index 0000000000..2ce13fd69a --- /dev/null +++ b/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/download_lm.py b/egs/libritts/ASR/local/download_lm.py new file mode 120000 index 0000000000..c9668bd2dc --- /dev/null +++ b/egs/libritts/ASR/local/download_lm.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/download_lm.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/norm_text.py b/egs/libritts/ASR/local/norm_text.py new file mode 120000 index 0000000000..dea3c051f8 --- /dev/null +++ b/egs/libritts/ASR/local/norm_text.py @@ -0,0 +1 @@ +../../../libriheavy/ASR/local/norm_text.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang.py b/egs/libritts/ASR/local/prepare_lang.py new file mode 120000 index 
0000000000..747f2ab398 --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang_bpe.py b/egs/libritts/ASR/local/prepare_lang_bpe.py new file mode 120000 index 0000000000..36b40e7fc2 --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang_fst.py b/egs/libritts/ASR/local/prepare_lang_fst.py new file mode 120000 index 0000000000..c5787c5340 --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lm_training_data.py b/egs/libritts/ASR/local/prepare_lm_training_data.py new file mode 120000 index 0000000000..abc00d421f --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/train_bpe_model.py b/egs/libritts/ASR/local/train_bpe_model.py new file mode 120000 index 0000000000..6fad36421e --- /dev/null +++ b/egs/libritts/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/validate_bpe_lexicon.py b/egs/libritts/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 0000000000..721bb48e7c --- /dev/null +++ b/egs/libritts/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh index 23c84e8386..4b551385f5 100755 --- a/egs/libritts/ASR/prepare.sh +++ b/egs/libritts/ASR/prepare.sh @@ -7,9 +7,15 @@ set -eou pipefail stage=0 stop_stage=100 -sampling_rate=24000 +sampling_rate=16000 nj=32 perturb_speed=true +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) dl_dir=$PWD/download @@ -27,6 +33,15 @@ log() { log "dl_dir: $dl_dir" +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download LM" # we directly use the librispeech lm here + mkdir -p $dl_dir/lm + if [ ! -e $dl_dir/lm/.done ]; then + ./local/download_lm.py --out-dir=$dl_dir/lm + touch $dl_dir/lm/.done + fi +fi + if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" @@ -107,3 +122,73 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then touch data/fbank/.msuan.done fi fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Train BPE model for normalized text" + + if [ ! -f data/texts ]; then + gunzip -c data/manifests/libritts_supervisions_train-clean-100.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py > data/texts + + gunzip -c data/manifests/libritts_supervisions_train-clean-360.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> data/texts + + gunzip -c data/manifests/libritts_supervisions_train-other-500.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> data/texts + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + cp data/texts $lang_dir/text + + if [ ! 
-f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then + log "No lexicon file in $dl_dir/lm, please run :" + log "prepare.sh --stage -1 --stop-stage -1" + exit -1 + fi + + if [ ! -f $lang_dir/lexicon.txt ]; then + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi diff --git a/egs/libritts/ASR/prepare_lm.sh b/egs/libritts/ASR/prepare_lm.sh new file mode 100755 index 0000000000..1c690983b8 --- /dev/null +++ b/egs/libritts/ASR/prepare_lm.sh @@ -0,0 +1,264 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +# This script generate Ngram LM / NNLM and related files that needed by decoding. + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# + +. prepare.sh --stage -1 --stop-stage 6 || exit 1 + +log "Running prepare_lm.sh" + +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare BPE based lexicon." + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare word level G" + # We assume you have installed kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! 
-f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare token level ngram G" + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + for ngram in 2 3 4 5; do + if [ ! -f $lang_dir/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt + fi + done + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate NNLM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate NNLM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! 
-f $out_dir/valid.txt ]; then + gunzip -c data/manifests/libritts_supervisions_dev-clean.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py > $out_dir/valid.txt + + gunzip -c data/manifests/libritts_supervisions_dev-other.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Generate NNLM test data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/test.txt ]; then + gunzip -c data/manifests/libritts_supervisions_test-clean.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py > $out_dir/test.txt + + gunzip -c data/manifests/libritts_supervisions_test-other.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> $out_dir/test.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Sort NNLM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py index 8b033ce90f..15267b0cb2 100755 --- a/egs/libritts/ASR/zipformer/decode.py +++ b/egs/libritts/ASR/zipformer/decode.py @@ -1041,13 +1041,13 @@ def main(): # we need cut ids to display recognition results. 
args.return_cuts = True - librispeech = LibriTTSAsrDataModule(args) + libritts = LibriTTSAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + test_clean_cuts = libritts.test_clean_cuts() + test_other_cuts = libritts.test_other_cuts() - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + test_other_dl = libritts.test_dataloaders(test_other_cuts) test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] diff --git a/egs/libritts/ASR/zipformer/streaming_decode.py b/egs/libritts/ASR/zipformer/streaming_decode.py index 4e2f1ecb9f..3ecc5c94f1 100755 --- a/egs/libritts/ASR/zipformer/streaming_decode.py +++ b/egs/libritts/ASR/zipformer/streaming_decode.py @@ -864,10 +864,10 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriTTSAsrDataModule(args) + libritts = LibriTTSAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + test_clean_cuts = libritts.test_clean_cuts() + test_other_cuts = libritts.test_other_cuts() test_sets = ["test-clean", "test-other"] test_cuts = [test_clean_cuts, test_other_cuts] diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py index 5eb90efbcb..98bbafc4a3 100755 --- a/egs/libritts/ASR/zipformer/train.py +++ b/egs/libritts/ASR/zipformer/train.py @@ -603,6 +603,15 @@ def _to_int_tuple(s: str): return tuple(map(int, s.split(","))) +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + def get_encoder_embed(params: AttributeDict) -> nn.Module: # encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, (T - 7) // 2, encoder_dims). 
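
[Editorial aside, not part of the patch: the `remove_punc_to_upper` helper added in the hunk above folds curly apostrophes to ASCII, drops everything except letters, digits and apostrophes, upper-cases the rest, and squeezes the gaps into single spaces, so LibriTTS transcripts end up in LibriSpeech-style normalized form. Below is a minimal, self-contained sketch of that behaviour for reference; the sample sentence is invented and the snippet simply mirrors the function body shown in the diff.]

```python
# Standalone copy of the normalization added in train.py above (illustrative only).
def remove_punc_to_upper(text: str) -> str:
    # Map curly apostrophes to ASCII ones.
    text = text.replace("\u2018", "'").replace("\u2019", "'")
    # Keep letters, digits and apostrophes; everything else becomes a space.
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    s_list = [x.upper() if x in tokens else " " for x in text]
    # Collapse runs of spaces and trim.
    return " ".join("".join(s_list).split()).strip()

# Made-up example transcript:
print(remove_punc_to_upper("Mr. Brown's dog, it barked, didn't it?"))
# -> MR BROWN'S DOG IT BARKED DIDN'T IT
```
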
@@ -1284,21 +1293,26 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - librispeech = LibriTTSAsrDataModule(args) + libritts = LibriTTSAsrDataModule(args) if params.full_libri: - train_cuts = librispeech.train_all_shuf_cuts() + train_cuts = libritts.train_all_shuf_cuts() # previously we used the following code to load all training cuts, # strictly speaking, shuffled training cuts should be used instead, # but we leave the code here to demonstrate that there is an option # like this to combine multiple cutsets - # train_cuts = librispeech.train_clean_100_cuts() - # train_cuts += librispeech.train_clean_360_cuts() - # train_cuts += librispeech.train_other_500_cuts() + # train_cuts = libritts.train_clean_100_cuts() + # train_cuts += libritts.train_clean_360_cuts() + # train_cuts += libritts.train_other_500_cuts() else: - train_cuts = librispeech.train_clean_100_cuts() + train_cuts = libritts.train_clean_100_cuts() + + def normalize_text(c: Cut): + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1338,6 +1352,7 @@ def remove_short_and_long_utt(c: Cut): return True train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.map(normalize_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1346,13 +1361,13 @@ def remove_short_and_long_utt(c: Cut): else: sampler_state_dict = None - train_dl = librispeech.train_dataloaders( + train_dl = libritts.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_cuts = libritts.dev_clean_cuts() + valid_cuts += libritts.dev_other_cuts() + valid_dl = libritts.valid_dataloaders(valid_cuts) if not params.print_diagnostics: scan_pessimistic_batches_for_oom( diff --git a/egs/libritts/CODEC/prepare.sh b/egs/libritts/CODEC/prepare.sh index 3dcb734745..6a471c3adc 100755 --- a/egs/libritts/CODEC/prepare.sh +++ b/egs/libritts/CODEC/prepare.sh @@ -37,15 +37,6 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then if [ ! -d $dl_dir/LibriTTS ]; then lhotse download libritts $dl_dir fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/musan - # - if [ ! 
-d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then From f0744877a6b0642239c861fca8b2f4dbda00a272 Mon Sep 17 00:00:00 2001 From: JinZr Date: Mon, 7 Oct 2024 23:32:03 +0800 Subject: [PATCH 22/33] minor updates --- egs/libritts/ASR/local/compile_hlg.py | 1 + egs/libritts/ASR/local/compile_lg.py | 1 + egs/libritts/ASR/prepare.sh | 10 +++++----- egs/libritts/ASR/zipformer/decoder.py | 1 + egs/libritts/ASR/zipformer/train.py | 2 +- 5 files changed, 9 insertions(+), 6 deletions(-) create mode 120000 egs/libritts/ASR/local/compile_hlg.py create mode 120000 egs/libritts/ASR/local/compile_lg.py create mode 120000 egs/libritts/ASR/zipformer/decoder.py diff --git a/egs/libritts/ASR/local/compile_hlg.py b/egs/libritts/ASR/local/compile_hlg.py new file mode 120000 index 0000000000..471aa7fb40 --- /dev/null +++ b/egs/libritts/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/compile_lg.py b/egs/libritts/ASR/local/compile_lg.py new file mode 120000 index 0000000000..462d6d3fb9 --- /dev/null +++ b/egs/libritts/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh index 4b551385f5..9d9ce8f870 100755 --- a/egs/libritts/ASR/prepare.sh +++ b/egs/libritts/ASR/prepare.sh @@ -126,25 +126,25 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Train BPE model for normalized text" - if [ ! -f data/texts ]; then + if [ ! -f data/text ]; then gunzip -c data/manifests/libritts_supervisions_train-clean-100.jsonl.gz \ | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py > data/texts + | ./local/norm_text.py > data/text gunzip -c data/manifests/libritts_supervisions_train-clean-360.jsonl.gz \ | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py >> data/texts + | ./local/norm_text.py >> data/text gunzip -c data/manifests/libritts_supervisions_train-other-500.jsonl.gz \ | jq ".text" | sed 's/"//g' \ - | ./local/norm_text.py >> data/texts + | ./local/norm_text.py >> data/text fi for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} mkdir -p $lang_dir - cp data/texts $lang_dir/text + cp data/text $lang_dir/text if [ ! 
-f $lang_dir/bpe.model ]; then ./local/train_bpe_model.py \ diff --git a/egs/libritts/ASR/zipformer/decoder.py b/egs/libritts/ASR/zipformer/decoder.py new file mode 120000 index 0000000000..5a8018680d --- /dev/null +++ b/egs/libritts/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py index 98bbafc4a3..0fa32d7f64 100755 --- a/egs/libritts/ASR/zipformer/train.py +++ b/egs/libritts/ASR/zipformer/train.py @@ -1351,8 +1351,8 @@ def remove_short_and_long_utt(c: Cut): return True - train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.map(normalize_text) + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint From 156af46a6e0f6f266003a51b758cfa03c0fd3418 Mon Sep 17 00:00:00 2001 From: JinZr Date: Tue, 8 Oct 2024 00:02:16 +0800 Subject: [PATCH 23/33] applied text norm to valid & test cuts --- egs/libritts/ASR/zipformer/decode.py | 6 +++--- egs/libritts/ASR/zipformer/train.py | 28 ++++++++++++++-------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py index 15267b0cb2..759d9d50a1 100755 --- a/egs/libritts/ASR/zipformer/decode.py +++ b/egs/libritts/ASR/zipformer/decode.py @@ -123,7 +123,7 @@ modified_beam_search_LODR, ) from lhotse import set_caching_enabled -from train import add_model_arguments, get_model, get_params +from train import add_model_arguments, get_model, get_params, normalize_text from icefall import ContextGraph, LmScorer, NgramLm from icefall.checkpoint import ( @@ -1043,8 +1043,8 @@ def main(): args.return_cuts = True libritts = LibriTTSAsrDataModule(args) - test_clean_cuts = libritts.test_clean_cuts() - test_other_cuts = libritts.test_other_cuts() + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) test_clean_dl = libritts.test_dataloaders(test_clean_cuts) test_other_dl = libritts.test_dataloaders(test_other_cuts) diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py index 0fa32d7f64..5485eaf0ab 100755 --- a/egs/libritts/ASR/zipformer/train.py +++ b/egs/libritts/ASR/zipformer/train.py @@ -603,13 +603,18 @@ def _to_int_tuple(s: str): return tuple(map(int, s.split(","))) -def remove_punc_to_upper(text: str) -> str: - text = text.replace("‘", "'") - text = text.replace("’", "'") - tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") - s_list = [x.upper() if x in tokens else " " for x in text] - s = " ".join("".join(s_list).split()).strip() - return s +def normalize_text(c: Cut): + def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c def get_encoder_embed(params: AttributeDict) -> nn.Module: @@ -1309,11 +1314,6 @@ def run(rank, world_size, args): else: train_cuts = libritts.train_clean_100_cuts() - def normalize_text(c: Cut): - text = remove_punc_to_upper(c.supervisions[0].text) - 
c.supervisions[0].text = text - return c - def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -1365,8 +1365,8 @@ def remove_short_and_long_utt(c: Cut): train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = libritts.dev_clean_cuts() - valid_cuts += libritts.dev_other_cuts() + valid_cuts = libritts.dev_clean_cuts().map(normalize_text) + valid_cuts += libritts.dev_other_cuts().map(normalize_text) valid_dl = libritts.valid_dataloaders(valid_cuts) if not params.print_diagnostics: From 43267e3e29dcc68c89ec8c5d1fdd192832216fda Mon Sep 17 00:00:00 2001 From: JinZr Date: Tue, 8 Oct 2024 13:12:12 +0800 Subject: [PATCH 24/33] black formatted --- egs/libritts/CODEC/encodec/infer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py index 6be869534b..2abe2493ba 100755 --- a/egs/libritts/CODEC/encodec/infer.py +++ b/egs/libritts/CODEC/encodec/infer.py @@ -289,9 +289,17 @@ def main(): logging.info(f"Number of parameters in decoder: {num_param_d}") num_param_q = sum([p.numel() for p in quantizer.parameters()]) logging.info(f"Number of parameters in quantizer: {num_param_q}") - num_param_ds = sum([p.numel() for p in multi_scale_discriminator.parameters()]) if multi_scale_discriminator is not None else 0 + num_param_ds = ( + sum([p.numel() for p in multi_scale_discriminator.parameters()]) + if multi_scale_discriminator is not None + else 0 + ) logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") - num_param_dp = sum([p.numel() for p in multi_period_discriminator.parameters()]) if multi_period_discriminator is not None else 0 + num_param_dp = ( + sum([p.numel() for p in multi_period_discriminator.parameters()]) + if multi_period_discriminator is not None + else 0 + ) logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") num_param_dstft = sum( [p.numel() for p in multi_scale_stft_discriminator.parameters()] From 2356621059184b53fc10c338f8a46a17d923fdd4 Mon Sep 17 00:00:00 2001 From: JinZr Date: Wed, 9 Oct 2024 14:04:21 +0800 Subject: [PATCH 25/33] minor updates --- egs/libritts/CODEC/encodec/encodec.py | 31 +---- egs/libritts/CODEC/encodec/loss.py | 168 -------------------------- egs/libritts/CODEC/encodec/train.py | 25 ++-- 3 files changed, 19 insertions(+), 205 deletions(-) diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index aa0373bfab..e1b646d725 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -157,27 +157,7 @@ def _forward_generator( x=speech, x_hat=speech_hat ) - # loss, rec_loss, adv_loss, feat_loss, d_weight = loss_g( - # commit_loss, - # speech, - # speech_hat, - # fmap, - # fmap_hat, - # y, - # y_hat, - # y_p, - # y_p_hat, - # y_s, - # y_s_hat, - # fmap_p, - # fmap_p_hat, - # fmap_s, - # fmap_s_hat, - # args=self.params, - # ) - stats = dict( - # generator_loss=loss.item(), generator_wav_reconstruction_loss=wav_reconstruction_loss.item(), generator_mel_reconstruction_loss=mel_reconstruction_loss.item(), generator_feature_stft_loss=feature_stft_loss.item(), @@ -187,7 +167,6 @@ def _forward_generator( generator_period_adv_loss=gen_period_adv_loss.item(), generator_scale_adv_loss=gen_scale_adv_loss.item(), generator_commit_loss=commit_loss.item(), - # d_weight=d_weight.item(), ) if return_sample: @@ -260,18 +239,16 @@ def _forward_discriminator( 
speech_hat.contiguous().detach() ) - disc_period_real_adv_loss, disc_period_fake_adv_loss = torch.tensor( - 0.0 - ), torch.tensor(0.0) + disc_period_real_adv_loss = torch.tensor(0.0) + disc_period_fake_adv_loss = torch.tensor(0.0) if self.multi_period_discriminator is not None: y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( speech.contiguous(), speech_hat.contiguous().detach(), ) - disc_scale_real_adv_loss, disc_scale_fake_adv_loss = torch.tensor( - 0.0 - ), torch.tensor(0.0) + disc_scale_real_adv_loss = torch.tensor(0.0) + disc_scale_fake_adv_loss = torch.tensor(0.0) if self.multi_scale_discriminator is not None: y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( speech.contiguous(), diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index ae1e34bddf..8ec80bb9c9 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -317,171 +317,3 @@ def forward( wav_loss = F.l1_loss(x, x_hat) return wav_loss - - -def adversarial_g_loss(y_disc_gen): - """Hinge loss""" - loss = 0.0 - for i in range(len(y_disc_gen)): - stft_loss = F.relu(1 - y_disc_gen[i]).mean().squeeze() - loss += stft_loss - return loss / len(y_disc_gen) - - -def feature_loss(fmap_r, fmap_gen): - loss = 0.0 - for i in range(len(fmap_r)): - for j in range(len(fmap_r[i])): - stft_loss = ( - (fmap_r[i][j] - fmap_gen[i][j]).abs() / (fmap_r[i][j].abs().mean()) - ).mean() - loss += stft_loss - return loss / (len(fmap_r) * len(fmap_r[0])) - - -def sim_loss(y_disc_r, y_disc_gen): - loss = 0.0 - for i in range(len(y_disc_r)): - loss += F.mse_loss(y_disc_r[i], y_disc_gen[i]) - return loss / len(y_disc_r) - - -def reconstruction_loss(x, x_hat, args, eps=1e-7): - # NOTE (lsx): hard-coded now - L = args.lambda_wav * F.mse_loss(x, x_hat) # wav L1 loss - # loss_sisnr = sisnr_loss(G_x, x) # - # L += 0.01*loss_sisnr - # 2^6=64 -> 2^10=1024 - # NOTE (lsx): add 2^11 - for i in range(6, 12): - # for i in range(5, 12): # Encodec setting - s = 2**i - melspec = MelSpectrogram( - sample_rate=args.sampling_rate, - n_fft=max(s, 512), - win_length=s, - hop_length=s // 4, - n_mels=64, - wkwargs={"device": x_hat.device}, - ).to(x_hat.device) - S_x = melspec(x) - S_x_hat = melspec(x_hat) - l1_loss = (S_x - S_x_hat).abs().mean() - l2_loss = ( - ((torch.log(S_x.abs() + eps) - torch.log(S_x_hat.abs() + eps)) ** 2).mean( - dim=-2 - ) - ** 0.5 - ).mean() - - alpha = (s / 2) ** 0.5 - L += l1_loss + alpha * l2_loss - return L - - -def adopt_weight(weight, global_step, threshold=0, value=0.0): - if global_step < threshold: - weight = value - return weight - - -def calculate_adaptive_weight(nll_loss, g_loss, last_layer, args): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - print("last_layer cannot be none") - assert 1 == 2 - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 1.0, 1.0).detach() - d_weight = d_weight * args.lambda_adv - return d_weight - - -def loss_g( - codebook_loss, - speech, - speech_hat, - fmap, - fmap_hat, - y, - y_hat, - y_df, - y_df_hat, - y_ds, - y_ds_hat, - fmap_f, - fmap_f_hat, - fmap_s, - fmap_s_hat, - args=None, -): - """ - args: - codebook_loss: commit loss. - speech: ground-truth wav. - speech_hat: reconstructed wav. - fmap: real stft-D feature map. - fmap_hat: fake stft-D feature map. - y: real stft-D logits. - y_hat: fake stft-D logits. 
- global_step: global training step. - y_df: real MPD logits. - y_df_hat: fake MPD logits. - y_ds: real MSD logits. - y_ds_hat: fake MSD logits. - fmap_f: real MPD feature map. - fmap_f_hat: fake MPD feature map. - fmap_s: real MSD feature map. - fmap_s_hat: fake MSD feature map. - """ - rec_loss = reconstruction_loss(speech.contiguous(), speech_hat.contiguous(), args) - adv_g_loss = adversarial_g_loss(y_hat) - adv_mpd_loss = adversarial_g_loss(y_df_hat) - adv_msd_loss = adversarial_g_loss(y_ds_hat) - adv_loss = ( - adv_g_loss + adv_mpd_loss + adv_msd_loss - ) / 3.0 # NOTE(lsx): need to divide by 3? - feat_loss = feature_loss( - fmap, fmap_hat - ) # + sim_loss(y_disc_r, y_disc_gen) # NOTE(lsx): need logits? - feat_loss_mpd = feature_loss( - fmap_f, fmap_f_hat - ) # + sim_loss(y_df_hat_r, y_df_hat_g) - feat_loss_msd = feature_loss( - fmap_s, fmap_s_hat - ) # + sim_loss(y_ds_hat_r, y_ds_hat_g) - feat_loss_tot = (feat_loss + feat_loss_mpd + feat_loss_msd) / 3.0 - d_weight = torch.tensor(1.0) - - # disc_factor = adopt_weight( - # args.lambda_adv, global_step, threshold=args.discriminator_iter_start - # ) - disc_factor = 1 - if disc_factor == 0.0: - fm_loss_wt = 0 - else: - fm_loss_wt = args.lambda_feat - - loss = ( - rec_loss - + d_weight * disc_factor * adv_loss - + fm_loss_wt * feat_loss_tot - + args.lambda_com * codebook_loss - ) - return loss, rec_loss, adv_loss, feat_loss_tot, d_weight - - -if __name__ == "__main__": - # la = FeatureLoss(average_by_layers=True, average_by_discriminators=True) - # aa = [torch.rand(192, 192) for _ in range(3)] - # bb = [torch.rand(192, 192) for _ in range(3)] - # print(la(bb, aa)) - # print(feature_loss(aa, bb)) - la = GeneratorAdversarialLoss(average_by_discriminators=True, loss_type="hinge") - aa = torch.Tensor([0.1, 0.2, 0.3, 0.4]) - bb = torch.Tensor([0.4, 0.3, 0.2, 0.1]) - print(la(aa)) - print(adversarial_g_loss(aa)) - print(la(bb)) - print(adversarial_g_loss(bb)) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 8475ab6e86..11f352911f 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -14,7 +14,6 @@ from codec_datamodule import LibriTTSCodecDataModule from encodec import Encodec from lhotse.utils import fix_random_seed -from loss import adopt_weight from scheduler import WarmupCosineLrScheduler from torch import nn from torch.cuda.amp import GradScaler, autocast @@ -189,10 +188,10 @@ def get_params() -> AttributeDict: "audio_normalization": False, "chunk_size": 1.0, # in seconds "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss - "lambda_wav": 1.0, # loss scaling coefficient for waveform loss - "lambda_feat": 3.0, # loss scaling coefficient for feat loss + "lambda_wav": 0.1, # loss scaling coefficient for waveform loss + "lambda_feat": 4.0, # loss scaling coefficient for feat loss "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss - "lambda_com": 100.0, # loss scaling coefficient for commitment loss + "lambda_com": 1.0, # loss scaling coefficient for commitment loss } ) @@ -361,6 +360,12 @@ def prepare_input( return audio, audio_lens, features, features_lens +def train_discriminator(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], @@ -447,7 +452,7 @@ def save_bad_model(suffix: str = ""): try: with autocast(enabled=params.use_fp16): - d_weight = adopt_weight( + d_weight = train_discriminator( 
params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -483,7 +488,7 @@ def save_bad_model(suffix: str = ""): scaler.step(optimizer_d) with autocast(enabled=params.use_fp16): - g_weight = adopt_weight( + g_weight = train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -702,7 +707,7 @@ def compute_validation_loss( loss_info = MetricsTracker() loss_info["samples"] = batch_size - d_weight = adopt_weight( + d_weight = train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -735,7 +740,7 @@ def compute_validation_loss( for k, v in stats_d.items(): loss_info[k] = v * batch_size - g_weight = adopt_weight( + g_weight = train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_epoch_start, @@ -845,7 +850,7 @@ def scan_pessimistic_batches_for_oom( + disc_period_fake_adv_loss + disc_scale_real_adv_loss + disc_scale_fake_adv_loss - ) * adopt_weight( + ) * train_discriminator( params.lambda_adv, params.cur_epoch, threshold=params.discriminator_train_start, @@ -873,7 +878,7 @@ def scan_pessimistic_batches_for_oom( ) loss_g = ( (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) - * adopt_weight( + * train_discriminator( params.lambda_adv, 0, threshold=params.discriminator_epoch_start, From df87a0fe2c018f036c20ecb1d0aef3354fdc5365 Mon Sep 17 00:00:00 2001 From: JinZr Date: Wed, 9 Oct 2024 14:12:41 +0800 Subject: [PATCH 26/33] updated train.py --- egs/libritts/CODEC/encodec/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 11f352911f..934d480f59 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -527,6 +527,7 @@ def save_bad_model(suffix: str = ""): + params.lambda_feat * feature_loss + params.lambda_com * commit_loss ) + loss_info["generator_loss"] = gen_loss for k, v in stats_g.items(): if "returned_sample" not in k: loss_info[k] = v * batch_size @@ -737,6 +738,7 @@ def compute_validation_loss( + disc_scale_fake_adv_loss ) * d_weight assert disc_loss.requires_grad is False + loss_info["discriminator_loss"] = disc_loss for k, v in stats_d.items(): loss_info[k] = v * batch_size @@ -778,6 +780,7 @@ def compute_validation_loss( + params.lambda_com * commit_loss ) assert gen_loss.requires_grad is False + loss_info["generator_loss"] = gen_loss for k, v in stats_g.items(): if "returned_sample" not in k: loss_info[k] = v * batch_size From 5492a6a5e2fcfb3c7973696164d19d41c0b72364 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sat, 12 Oct 2024 15:33:38 +0800 Subject: [PATCH 27/33] comments updated --- egs/libritts/CODEC/encodec/loss.py | 22 ++++++++++++---------- egs/libritts/CODEC/encodec/train.py | 9 +++------ 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py index 8ec80bb9c9..9cf1d42d28 100644 --- a/egs/libritts/CODEC/encodec/loss.py +++ b/egs/libritts/CODEC/encodec/loss.py @@ -1,8 +1,19 @@ +# Modified from egs/ljspeech/TTS/vits/loss.py by: Zengrui JIN (Tsinghua University) +# original implementation is from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encodec-related loss modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. 
+ +""" + from typing import List, Tuple, Union import torch import torch.nn.functional as F -from lhotse.features.kaldi import Wav2LogFilterBank from torchaudio.transforms import MelSpectrogram @@ -225,15 +236,6 @@ def __init__( self.wav_to_specs = [] for i in range(5, 12): s = 2**i - # self.wav_to_specs.append( - # Wav2LogFilterBank( - # sampling_rate=sampling_rate, - # frame_length=s, - # frame_shift=s // 4, - # use_fft_mag=use_fft_mag, - # num_filters=n_mels, - # ) - # ) self.wav_to_specs.append( MelSpectrogram( sample_rate=sampling_rate, diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 934d480f59..49e9743105 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -186,12 +186,11 @@ def get_params() -> AttributeDict: "env_info": get_env_info(), "sampling_rate": 24000, "audio_normalization": False, - "chunk_size": 1.0, # in seconds "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss "lambda_wav": 0.1, # loss scaling coefficient for waveform loss "lambda_feat": 4.0, # loss scaling coefficient for feat loss "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss - "lambda_com": 1.0, # loss scaling coefficient for commitment loss + "lambda_com": 1000.0, # loss scaling coefficient for commitment loss } ) @@ -342,12 +341,10 @@ def prepare_input( if is_training: audio_dims = audio.size(-1) - start_idx = random.randint( - 0, max(0, audio_dims - params.chunk_size * params.sampling_rate) - ) + start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate)) audio = audio[:, start_idx : params.sampling_rate + start_idx] else: - # NOTE: a very coarse setup + # NOTE(zengrui): a very coarse setup audio = audio[ :, params.sampling_rate : params.sampling_rate + params.sampling_rate ] From cd96f635c3b405b5ae6dcee5fb243e5fc5648a00 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sat, 12 Oct 2024 16:08:07 +0800 Subject: [PATCH 28/33] added text norm for other decoding scripts --- egs/libritts/ASR/zipformer/ctc_decode.py | 12 ++++++------ egs/libritts/ASR/zipformer/onnx_decode.py | 11 ++++++----- egs/libritts/ASR/zipformer/streaming_decode.py | 6 +++--- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py index 177f2e392d..d77aa59626 100755 --- a/egs/libritts/ASR/zipformer/ctc_decode.py +++ b/egs/libritts/ASR/zipformer/ctc_decode.py @@ -122,7 +122,7 @@ import torch.nn as nn from asr_datamodule import LibriTTSAsrDataModule from lhotse import set_caching_enabled -from train import add_model_arguments, get_model, get_params +from train import add_model_arguments, get_model, get_params, normalize_text from icefall.checkpoint import ( average_checkpoints, @@ -949,13 +949,13 @@ def main(): # we need cut ids to display recognition results. 
args.return_cuts = True - librispeech = LibriTTSAsrDataModule(args) + libritts = LibriTTSAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + test_other_dl = libritts.test_dataloaders(test_other_cuts) test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] diff --git a/egs/libritts/ASR/zipformer/onnx_decode.py b/egs/libritts/ASR/zipformer/onnx_decode.py index 99a02c5cf3..6f09cc8f7b 100755 --- a/egs/libritts/ASR/zipformer/onnx_decode.py +++ b/egs/libritts/ASR/zipformer/onnx_decode.py @@ -80,6 +80,7 @@ from asr_datamodule import LibriTTSAsrDataModule from k2 import SymbolTable from onnx_pretrained import OnnxModel, greedy_search +from train import normalize_text from icefall.utils import setup_logger, store_transcripts, write_error_stats @@ -290,13 +291,13 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - librispeech = LibriTTSAsrDataModule(args) + libritts = LibriTTSAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + test_other_dl = libritts.test_dataloaders(test_other_cuts) test_sets = ["test-clean", "test-other"] test_dl = [test_clean_dl, test_other_dl] diff --git a/egs/libritts/ASR/zipformer/streaming_decode.py b/egs/libritts/ASR/zipformer/streaming_decode.py index 3ecc5c94f1..b210187886 100755 --- a/egs/libritts/ASR/zipformer/streaming_decode.py +++ b/egs/libritts/ASR/zipformer/streaming_decode.py @@ -52,7 +52,7 @@ ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params +from train import add_model_arguments, get_model, get_params, normalize_text from icefall.checkpoint import ( average_checkpoints, @@ -866,8 +866,8 @@ def main(): libritts = LibriTTSAsrDataModule(args) - test_clean_cuts = libritts.test_clean_cuts() - test_other_cuts = libritts.test_other_cuts() + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) test_sets = ["test-clean", "test-other"] test_cuts = [test_clean_cuts, test_other_cuts] From 74a738f22c0eb3ff3cbab6886f6fc4773a523f88 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sat, 12 Oct 2024 16:12:22 +0800 Subject: [PATCH 29/33] comments updated --- egs/libritts/CODEC/encodec/infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py index 2abe2493ba..3c6ea15f9a 100755 --- a/egs/libritts/CODEC/encodec/infer.py +++ b/egs/libritts/CODEC/encodec/infer.py @@ -19,9 +19,9 @@ This script performs model inference on test set. 
Usage: -./vits/infer.py \ - --epoch 1000 \ - --exp-dir ./vits/exp \ +./codec/infer.py \ + --epoch 300 \ + --exp-dir ./codec/exp \ --max-duration 500 """ From 7eee6b9e9d4d711c0558d18916af494bdbb17e69 Mon Sep 17 00:00:00 2001 From: JinZr Date: Sun, 13 Oct 2024 02:40:04 +0800 Subject: [PATCH 30/33] updated default param --- egs/libritts/CODEC/encodec/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 49e9743105..c2dac0f15e 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -120,7 +120,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=20, + default=1, help="""Save checkpoint after processing this number of epochs" periodically. We save checkpoint to exp-dir/ whenever params.cur_epoch % save_every_n == 0. The checkpoint filename From 283157268a964d46bbb41c9ac18eff2b83d65424 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 20 Oct 2024 21:38:20 +0800 Subject: [PATCH 31/33] added README.md and RESULTS.md --- egs/libritts/ASR/README.md | 26 +++++++++ egs/libritts/ASR/RESULTS.md | 58 ++++++++++++++++++++ egs/libritts/ASR/zipformer/asr_datamodule.py | 2 +- 3 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 egs/libritts/ASR/README.md create mode 100644 egs/libritts/ASR/RESULTS.md diff --git a/egs/libritts/ASR/README.md b/egs/libritts/ASR/README.md new file mode 100644 index 0000000000..138f4ae80a --- /dev/null +++ b/egs/libritts/ASR/README.md @@ -0,0 +1,26 @@ +# Introduction + +LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. +The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. +The main differences from the LibriSpeech corpus are listed below: +1. The audio files are at 24kHz sampling rate. +2. The speech is split at sentence breaks. +3. Both original and normalized texts are included. +4. Contextual information (e.g., neighbouring sentences) can be extracted. +5. Utterances with significant background noise are excluded. +For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. + + +This recipe includes some different ASR models trained with [LibriTTS](https://openslr.org/60/). + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +| | Encoder | Decoder | +|---------------------------------------|---------------------|--------------------| +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | + +The decoder is modified from the paper +[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/libritts/ASR/RESULTS.md b/egs/libritts/ASR/RESULTS.md new file mode 100644 index 0000000000..574f81eb62 --- /dev/null +++ b/egs/libritts/ASR/RESULTS.md @@ -0,0 +1,58 @@ +# Results + +## zipformer (zipformer + pruned stateless transducer) + +See for more details. 
+ +[zipformer](./zipformer) + +### Non-streaming + +#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 2.83 | 5.91 | --epoch 30 --avg 5 | +| modified_beam_search | 2.80 | 5.87 | --epoch 30 --avg 5 | +| fast_beam_search | 2.87 | 5.86 | --epoch 30 --avg 5 | +| greedy_search | 2.76 | 5.68 | --epoch 40 --avg 16| +| modified_beam_search | 2.74 | 5.66 | --epoch 40 --avg 16| +| fast_beam_search | 2.75 | 5.67 | --epoch 40 --avg 16| +| greedy_search | 2.74 | 5.67 | --epoch 50 --avg 30| +| modified_beam_search | 2.73 | 5.58 | --epoch 50 --avg 30| +| fast_beam_search | 2.78 | 5.61 | --epoch 50 --avg 30| + + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 0 \ + --full-libri 1 \ + --max-duration 3600 +``` +This was used on 2 Nvidia A800 GPUs, you'll need to adjust the `CUDA_VISIBLE_DEVICES`, `--world-size` and `--max-duration` according to your hardware. + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search fast_beam_search; do + ./zipformer/decode.py \ + --epoch 50 \ + --avg 30 \ + --use-averaged-model 1 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method $m +done +``` diff --git a/egs/libritts/ASR/zipformer/asr_datamodule.py b/egs/libritts/ASR/zipformer/asr_datamodule.py index 8d2b9eaddf..dab8343032 100644 --- a/egs/libritts/ASR/zipformer/asr_datamodule.py +++ b/egs/libritts/ASR/zipformer/asr_datamodule.py @@ -86,7 +86,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "--full-libri", type=str2bool, default=True, - help="""When enabled, use 960h LibriTTS. + help="""When enabled, use the entire LibriTTS training set. 
Otherwise, use the 100h subset.""", ) From d7522302875eed66fc991c79061ddf773a9c1b6a Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 21 Oct 2024 10:55:27 +0800 Subject: [PATCH 32/33] added licensing info --- .../CODEC/encodec/base_discriminators.py | 23 +++++++++++++++---- egs/libritts/CODEC/encodec/binary.py | 4 ++-- egs/libritts/CODEC/encodec/discriminators.py | 6 +++++ egs/libritts/CODEC/encodec/encodec.py | 17 ++++++++++++++ egs/libritts/CODEC/encodec/loss.py | 2 +- .../CODEC/encodec/modules/__init__.py | 2 +- egs/libritts/CODEC/encodec/modules/conv.py | 2 +- egs/libritts/CODEC/encodec/modules/lstm.py | 2 +- egs/libritts/CODEC/encodec/modules/norm.py | 2 +- egs/libritts/CODEC/encodec/modules/seanet.py | 6 ++--- .../CODEC/encodec/modules/transformer.py | 2 +- .../CODEC/encodec/quantization/__init__.py | 2 +- egs/libritts/CODEC/encodec/quantization/ac.py | 18 +++++++-------- .../CODEC/encodec/quantization/core_vq.py | 2 +- .../CODEC/encodec/quantization/distrib.py | 2 +- egs/libritts/CODEC/encodec/quantization/vq.py | 2 +- egs/libritts/CODEC/encodec/scheduler.py | 5 ++++ egs/libritts/CODEC/encodec/train.py | 18 +++++++++++++++ 18 files changed, 88 insertions(+), 29 deletions(-) diff --git a/egs/libritts/CODEC/encodec/base_discriminators.py b/egs/libritts/CODEC/encodec/base_discriminators.py index e112436e50..7bc035554a 100644 --- a/egs/libritts/CODEC/encodec/base_discriminators.py +++ b/egs/libritts/CODEC/encodec/base_discriminators.py @@ -1,3 +1,21 @@ +#!/usr/bin/env python3 +# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from typing import List, Tuple import torch @@ -222,17 +240,12 @@ def __init__( def forward(self, x: torch.Tensor): fmap = [] - # print('x ', x.shape) z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] - # print('z ', z.shape) z = torch.cat([z.real, z.imag], dim=1) - # print('cat_z ', z.shape) z = rearrange(z, "b c w t -> b c t w") for i, layer in enumerate(self.convs): z = layer(z) z = self.activation(z) - # print('z i', i, z.shape) fmap.append(z) z = self.conv_post(z) - # print('logit ', z.shape) return z, fmap diff --git a/egs/libritts/CODEC/encodec/binary.py b/egs/libritts/CODEC/encodec/binary.py index 3004831272..0f552fb99f 100644 --- a/egs/libritts/CODEC/encodec/binary.py +++ b/egs/libritts/CODEC/encodec/binary.py @@ -2,7 +2,7 @@ # All rights reserved. # # This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE """Raw binary format for Encodec compressed audio. 
 Actual compression API is in `encodec.compress`."""

 import io
@@ -132,7 +132,7 @@ def test():
     for rep in range(4):
         length: int = torch.randint(10, 2_000, (1,)).item()
         bits: int = torch.randint(1, 16, (1,)).item()
-        tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
+        tokens: List[int] = torch.randint(2 ** bits, (length,)).tolist()
         rebuilt: List[int] = []
         buf = io.BytesIO()
         packer = BitPacker(bits, buf)
diff --git a/egs/libritts/CODEC/encodec/discriminators.py b/egs/libritts/CODEC/encodec/discriminators.py
index 471aa92443..e6b7f09290 100644
--- a/egs/libritts/CODEC/encodec/discriminators.py
+++ b/egs/libritts/CODEC/encodec/discriminators.py
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
 from typing import List

 import torch
diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py
index e1b646d725..f21d494b62 100644
--- a/egs/libritts/CODEC/encodec/encodec.py
+++ b/egs/libritts/CODEC/encodec/encodec.py
@@ -1,3 +1,20 @@
+#!/usr/bin/env python3
+# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 import random
 from typing import List, Optional
diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py
index 9cf1d42d28..8675841b27 100644
--- a/egs/libritts/CODEC/encodec/loss.py
+++ b/egs/libritts/CODEC/encodec/loss.py
@@ -235,7 +235,7 @@ def __init__(
         super().__init__()
         self.wav_to_specs = []
         for i in range(5, 12):
-            s = 2**i
+            s = 2 ** i
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
diff --git a/egs/libritts/CODEC/encodec/modules/__init__.py b/egs/libritts/CODEC/encodec/modules/__init__.py
index e9f7584647..b903a28b0e 100644
--- a/egs/libritts/CODEC/encodec/modules/__init__.py
+++ b/egs/libritts/CODEC/encodec/modules/__init__.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """Torch modules."""
 # flake8: noqa
 from .conv import (
diff --git a/egs/libritts/CODEC/encodec/modules/conv.py b/egs/libritts/CODEC/encodec/modules/conv.py
index 45518a3f8f..a70a5c67fe 100644
--- a/egs/libritts/CODEC/encodec/modules/conv.py
+++ b/egs/libritts/CODEC/encodec/modules/conv.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """Convolutional layers wrappers and utilities."""
 import logging
 import math
diff --git a/egs/libritts/CODEC/encodec/modules/lstm.py b/egs/libritts/CODEC/encodec/modules/lstm.py
index 7d5b8af885..5307552c01 100644
--- a/egs/libritts/CODEC/encodec/modules/lstm.py
+++ b/egs/libritts/CODEC/encodec/modules/lstm.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """LSTM layers module."""

 from torch import nn
diff --git a/egs/libritts/CODEC/encodec/modules/norm.py b/egs/libritts/CODEC/encodec/modules/norm.py
index b7ab72f9ea..3002b3a265 100644
--- a/egs/libritts/CODEC/encodec/modules/norm.py
+++ b/egs/libritts/CODEC/encodec/modules/norm.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """Normalization modules."""

 from typing import List, Union
diff --git a/egs/libritts/CODEC/encodec/modules/seanet.py b/egs/libritts/CODEC/encodec/modules/seanet.py
index 50d6c3f13e..38f2f8728c 100644
--- a/egs/libritts/CODEC/encodec/modules/seanet.py
+++ b/egs/libritts/CODEC/encodec/modules/seanet.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """Encodec SEANet-based encoder and decoder implementation."""

 from typing import Any, Dict, List, Optional
@@ -161,7 +161,7 @@ def __init__(
                     SEANetResnetBlock(
                         mult * n_filters,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base**j, 1],
+                        dilations=[dilation_base ** j, 1],
                         norm=norm,
                         norm_params=norm_params,
                         activation=activation,
@@ -311,7 +311,7 @@ def __init__(
                     SEANetResnetBlock(
                         mult * n_filters // 2,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base**j, 1],
+                        dilations=[dilation_base ** j, 1],
                         activation=activation,
                         activation_params=activation_params,
                         norm=norm,
diff --git a/egs/libritts/CODEC/encodec/modules/transformer.py b/egs/libritts/CODEC/encodec/modules/transformer.py
index 9ef2c7ac15..1768d88f99 100644
--- a/egs/libritts/CODEC/encodec/modules/transformer.py
+++ b/egs/libritts/CODEC/encodec/modules/transformer.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """A streamable transformer."""
 import typing as tp
 from typing import Any, List, Optional, Union
diff --git a/egs/libritts/CODEC/encodec/quantization/__init__.py b/egs/libritts/CODEC/encodec/quantization/__init__.py
index 7364623400..82d744f5fb 100644
--- a/egs/libritts/CODEC/encodec/quantization/__init__.py
+++ b/egs/libritts/CODEC/encodec/quantization/__init__.py
@@ -2,6 +2,6 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 # flake8: noqa
 from .vq import QuantizedResult, ResidualVectorQuantizer
diff --git a/egs/libritts/CODEC/encodec/quantization/ac.py b/egs/libritts/CODEC/encodec/quantization/ac.py
index 660931b410..99b62d14ba 100644
--- a/egs/libritts/CODEC/encodec/quantization/ac.py
+++ b/egs/libritts/CODEC/encodec/quantization/ac.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """Arithmetic coder."""
 import io
 import math
@@ -41,7 +41,7 @@ def build_stable_quantized_cdf(
     if roundoff:
         pdf = (pdf / roundoff).floor() * roundoff
     # interpolate with uniform distribution to achieve desired minimum probability.
-    total_range = 2**total_range_bits
+    total_range = 2 ** total_range_bits
     cardinality = len(pdf)
     alpha = min_range * cardinality / total_range
     assert alpha <= 1, "you must reduce min_range"
@@ -51,7 +51,7 @@ def build_stable_quantized_cdf(
     if min_range < 2:
         raise ValueError("min_range must be at least 2.")
     if check:
-        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
+        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
         if (
             (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
         ).any() or quantized_cdf[0] < min_range:
@@ -142,7 +142,7 @@ def push(self, symbol: int, quantized_cdf: Tensor):
             quantized_cdf (Tensor): use `build_stable_quantized_cdf`
                 to build this from your pdf estimate.
         """
-        while self.delta < 2**self.total_range_bits:
+        while self.delta < 2 ** self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1
@@ -150,10 +150,10 @@ def push(self, symbol: int, quantized_cdf: Tensor):
         range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
         range_high = quantized_cdf[symbol].item() - 1
         effective_low = int(
-            math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
+            math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
         )
         effective_high = int(
-            math.floor(range_high * (self.delta / (2**self.total_range_bits)))
+            math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
         )
         assert self.low <= self.high
         self.high = self.low + effective_high
@@ -238,7 +238,7 @@ def pull(self, quantized_cdf: Tensor) -> Optional[int]:
                 to build this from your pdf estimate. This must be **exatly**
                 the same cdf as the one used at encoding time.
         """
-        while self.delta < 2**self.total_range_bits:
+        while self.delta < 2 ** self.total_range_bits:
             bit = self.unpacker.pull()
             if bit is None:
                 return None
@@ -255,10 +255,10 @@ def bin_search(low_idx: int, high_idx: int):
             range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
             range_high = quantized_cdf[mid].item() - 1
             effective_low = int(
-                math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
+                math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
             )
             effective_high = int(
-                math.floor(range_high * (self.delta / (2**self.total_range_bits)))
+                math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
             )
             low = effective_low + self.low
             high = effective_high + self.low
diff --git a/egs/libritts/CODEC/encodec/quantization/core_vq.py b/egs/libritts/CODEC/encodec/quantization/core_vq.py
index 4719e20f7f..0b342f2b0d 100644
--- a/egs/libritts/CODEC/encodec/quantization/core_vq.py
+++ b/egs/libritts/CODEC/encodec/quantization/core_vq.py
@@ -76,7 +76,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):

     for _ in range(num_iters):
         diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-        dists = -(diffs**2).sum(dim=-1)
+        dists = -(diffs ** 2).sum(dim=-1)

         buckets = dists.max(dim=-1).indices
         bins = torch.bincount(buckets, minlength=num_clusters)
diff --git a/egs/libritts/CODEC/encodec/quantization/distrib.py b/egs/libritts/CODEC/encodec/quantization/distrib.py
index 5b1b06d688..41ac7525fe 100644
--- a/egs/libritts/CODEC/encodec/quantization/distrib.py
+++ b/egs/libritts/CODEC/encodec/quantization/distrib.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """Torch distributed utilities."""

 from typing import Dict, Iterable, List
diff --git a/egs/libritts/CODEC/encodec/quantization/vq.py b/egs/libritts/CODEC/encodec/quantization/vq.py
index 22212a7942..8e59887a6c 100644
--- a/egs/libritts/CODEC/encodec/quantization/vq.py
+++ b/egs/libritts/CODEC/encodec/quantization/vq.py
@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
 """Residual vector quantizer implementation."""
 import math
 from dataclasses import dataclass, field
diff --git a/egs/libritts/CODEC/encodec/scheduler.py b/egs/libritts/CODEC/encodec/scheduler.py
index fb6ba087d6..00ef9882a5 100644
--- a/egs/libritts/CODEC/encodec/scheduler.py
+++ b/egs/libritts/CODEC/encodec/scheduler.py
@@ -1,3 +1,8 @@
+# original implementation is from https://github.com/ZhikangNiu/encodec-pytorch/blob/main/scheduler.py
+
+# Copyright 2024 Zhi-Kang Niu
+# MIT License
+
 import math
 from bisect import bisect_right

diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py
index c2dac0f15e..bf231c5b66 100755
--- a/egs/libritts/CODEC/encodec/train.py
+++ b/egs/libritts/CODEC/encodec/train.py
@@ -1,3 +1,21 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (Author: Zengwei Yao)
+# 2024 The Chinese University of HK (Author: Zengrui Jin)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import itertools
 import logging

From 01a003a67566ed4f1f871c5b48d49e2382a3f328 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Mon, 21 Oct 2024 11:14:59 +0800
Subject: [PATCH 33/33] black formatted

---
 egs/libritts/CODEC/encodec/binary.py          |  2 +-
 egs/libritts/CODEC/encodec/loss.py            |  2 +-
 egs/libritts/CODEC/encodec/modules/seanet.py  |  4 ++--
 egs/libritts/CODEC/encodec/quantization/ac.py | 16 ++++++++--------
 .../CODEC/encodec/quantization/core_vq.py     |  2 +-
 5 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/egs/libritts/CODEC/encodec/binary.py b/egs/libritts/CODEC/encodec/binary.py
index 0f552fb99f..003bcfaf59 100644
--- a/egs/libritts/CODEC/encodec/binary.py
+++ b/egs/libritts/CODEC/encodec/binary.py
@@ -132,7 +132,7 @@ def test():
     for rep in range(4):
         length: int = torch.randint(10, 2_000, (1,)).item()
         bits: int = torch.randint(1, 16, (1,)).item()
-        tokens: List[int] = torch.randint(2 ** bits, (length,)).tolist()
+        tokens: List[int] = torch.randint(2**bits, (length,)).tolist()
         rebuilt: List[int] = []
         buf = io.BytesIO()
         packer = BitPacker(bits, buf)
diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py
index 8675841b27..9cf1d42d28 100644
--- a/egs/libritts/CODEC/encodec/loss.py
+++ b/egs/libritts/CODEC/encodec/loss.py
@@ -235,7 +235,7 @@ def __init__(
         super().__init__()
         self.wav_to_specs = []
         for i in range(5, 12):
-            s = 2 ** i
+            s = 2**i
             self.wav_to_specs.append(
                 MelSpectrogram(
                     sample_rate=sampling_rate,
diff --git a/egs/libritts/CODEC/encodec/modules/seanet.py b/egs/libritts/CODEC/encodec/modules/seanet.py
index 38f2f8728c..76999b2984 100644
--- a/egs/libritts/CODEC/encodec/modules/seanet.py
+++ b/egs/libritts/CODEC/encodec/modules/seanet.py
@@ -161,7 +161,7 @@ def __init__(
                     SEANetResnetBlock(
                         mult * n_filters,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base ** j, 1],
+                        dilations=[dilation_base**j, 1],
                         norm=norm,
                         norm_params=norm_params,
                         activation=activation,
@@ -311,7 +311,7 @@ def __init__(
                     SEANetResnetBlock(
                         mult * n_filters // 2,
                         kernel_sizes=[residual_kernel_size, 1],
-                        dilations=[dilation_base ** j, 1],
+                        dilations=[dilation_base**j, 1],
                         activation=activation,
                         activation_params=activation_params,
                         norm=norm,
diff --git a/egs/libritts/CODEC/encodec/quantization/ac.py b/egs/libritts/CODEC/encodec/quantization/ac.py
index 99b62d14ba..8d8a770caf 100644
--- a/egs/libritts/CODEC/encodec/quantization/ac.py
+++ b/egs/libritts/CODEC/encodec/quantization/ac.py
@@ -41,7 +41,7 @@ def build_stable_quantized_cdf(
     if roundoff:
         pdf = (pdf / roundoff).floor() * roundoff
     # interpolate with uniform distribution to achieve desired minimum probability.
-    total_range = 2 ** total_range_bits
+    total_range = 2**total_range_bits
     cardinality = len(pdf)
     alpha = min_range * cardinality / total_range
     assert alpha <= 1, "you must reduce min_range"
@@ -51,7 +51,7 @@ def build_stable_quantized_cdf(
     if min_range < 2:
         raise ValueError("min_range must be at least 2.")
     if check:
-        assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
         if (
             (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range
         ).any() or quantized_cdf[0] < min_range:
@@ -142,7 +142,7 @@ def push(self, symbol: int, quantized_cdf: Tensor):
             quantized_cdf (Tensor): use `build_stable_quantized_cdf`
                 to build this from your pdf estimate.
         """
-        while self.delta < 2 ** self.total_range_bits:
+        while self.delta < 2**self.total_range_bits:
             self.low *= 2
             self.high = self.high * 2 + 1
             self.max_bit += 1
@@ -150,10 +150,10 @@ def push(self, symbol: int, quantized_cdf: Tensor):
         range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
         range_high = quantized_cdf[symbol].item() - 1
         effective_low = int(
-            math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
+            math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
         )
         effective_high = int(
-            math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
+            math.floor(range_high * (self.delta / (2**self.total_range_bits)))
         )
         assert self.low <= self.high
         self.high = self.low + effective_high
@@ -238,7 +238,7 @@ def pull(self, quantized_cdf: Tensor) -> Optional[int]:
                 to build this from your pdf estimate. This must be **exatly**
                 the same cdf as the one used at encoding time.
         """
-        while self.delta < 2 ** self.total_range_bits:
+        while self.delta < 2**self.total_range_bits:
             bit = self.unpacker.pull()
             if bit is None:
                 return None
@@ -255,10 +255,10 @@ def bin_search(low_idx: int, high_idx: int):
             range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
             range_high = quantized_cdf[mid].item() - 1
             effective_low = int(
-                math.ceil(range_low * (self.delta / (2 ** self.total_range_bits)))
+                math.ceil(range_low * (self.delta / (2**self.total_range_bits)))
             )
             effective_high = int(
-                math.floor(range_high * (self.delta / (2 ** self.total_range_bits)))
+                math.floor(range_high * (self.delta / (2**self.total_range_bits)))
             )
             low = effective_low + self.low
             high = effective_high + self.low
diff --git a/egs/libritts/CODEC/encodec/quantization/core_vq.py b/egs/libritts/CODEC/encodec/quantization/core_vq.py
index 0b342f2b0d..4719e20f7f 100644
--- a/egs/libritts/CODEC/encodec/quantization/core_vq.py
+++ b/egs/libritts/CODEC/encodec/quantization/core_vq.py
@@ -76,7 +76,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):

     for _ in range(num_iters):
         diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
-        dists = -(diffs ** 2).sum(dim=-1)
+        dists = -(diffs**2).sum(dim=-1)

         buckets = dists.max(dim=-1).indices
         bins = torch.bincount(buckets, minlength=num_clusters)