🚀 Use vocabulary config for characters and subwords
nglehuy committed Oct 16, 2020
1 parent fc1fd9f commit 6729e44
Showing 11 changed files with 310 additions and 77 deletions.
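Across the example scripts, the former `--subwords_prefix` flag (a filename prefix to which `.subwords` was appended) is replaced by `--subwords`, which takes the full path of the generated subwords file. A minimal sketch of the before/after existence check (the `args` values here are made up for illustration):

```python
import os
from argparse import Namespace

# Hypothetical parsed arguments, for illustration only.
args = Namespace(subwords="vocab/librispeech.subwords")

# Before: the flag held a prefix and ".subwords" was appended:
#   os.path.exists(f"{args.subwords_prefix}.subwords")
# After: the flag holds the complete file path.
found = bool(args.subwords) and os.path.exists(args.subwords)
print("Loading subwords ..." if found else "subwords file not found")
```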
11 changes: 5 additions & 6 deletions examples/conformer/test_subword_conformer.py
@@ -43,8 +43,8 @@
 parser.add_argument("--cpu", default=False, action="store_true",
                     help="Whether to only use cpu")

-parser.add_argument("--subwords_prefix", type=str, default=None,
-                    help="Prefix of file that stores generated subwords")
+parser.add_argument("--subwords", type=str, default=None,
+                    help="Path to file that stores generated subwords")

 parser.add_argument("--output_name", type=str, default="test",
                     help="Result filename name prefix")
@@ -65,12 +65,11 @@
 config = UserConfig(DEFAULT_YAML, args.config, learning=True)
 speech_featurizer = TFSpeechFeaturizer(config["speech_config"])

-if args.subwords_prefix and os.path.exists(f"{args.subwords_prefix}.subwords"):
+if args.subwords and os.path.exists(args.subwords):
     print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"],
-                                                       args.subwords_prefix)
+    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
 else:
-    raise ValueError("subwords_prefix must be set")
+    raise ValueError("subwords must be set")

 tf.random.set_seed(0)
 assert args.saved
11 changes: 5 additions & 6 deletions examples/conformer/tflite_subword_conformer.py
@@ -36,8 +36,8 @@
 parser.add_argument("--saved", type=str, default=None,
                     help="Path to saved model")

-parser.add_argument("--subwords_prefix", type=str, default=None,
-                    help="Prefix of file that stores generated subwords")
+parser.add_argument("--subwords", type=str, default=None,
+                    help="Path to file that stores generated subwords")

 parser.add_argument("output", type=str, default=None,
                     help="TFLite file path to be exported")
@@ -49,12 +49,11 @@
 config = UserConfig(DEFAULT_YAML, args.config, learning=True)
 speech_featurizer = TFSpeechFeaturizer(config["speech_config"])

-if args.subwords_prefix and os.path.exists(f"{args.subwords_prefix}.subwords"):
+if args.subwords and os.path.exists(args.subwords):
     print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"],
-                                                       args.subwords_prefix)
+    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
 else:
-    raise ValueError("subwords_prefix must be set")
+    raise ValueError("subwords must be set")

 # build model
 conformer = Conformer(
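The remainder of `tflite_subword_conformer.py` is collapsed above, so the actual conversion code is not shown. As a rough sketch only of how a TF 2.x script typically exports a concrete function to TFLite (the `make_tflite_function` helper name is an assumption, not confirmed by this diff):

```python
import tensorflow as tf

# Hypothetical serving function obtained from the built model; the
# real helper in TensorFlowASR may be named differently.
concrete_func = conformer.make_tflite_function().get_concrete_function()

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # standard TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TF ops where needed
]
tflite_model = converter.convert()

with open(args.output, "wb") as f:  # "output" is the positional flag above
    f.write(tflite_model)
```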
131 changes: 131 additions & 0 deletions examples/conformer/train_ga_conformer.py
@@ -0,0 +1,131 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import argparse
from tensorflow_asr.utils import setup_environment, setup_strategy

setup_environment()
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Conformer Training")

parser.add_argument("--config", type=str, default=DEFAULT_YAML,
help="The file path of model configuration file")

parser.add_argument("--max_ckpts", type=int, default=10,
help="Max number of checkpoints to keep")

parser.add_argument("--tfrecords", default=False, action="store_true",
help="Whether to use tfrecords")

parser.add_argument("--tbs", type=int, default=None,
help="Train batch size per replica")

parser.add_argument("--ebs", type=int, default=None,
help="Evaluation batch size per replica")

parser.add_argument("--devices", type=int, nargs="*", default=[0],
help="Devices' ids to apply distributed training")

parser.add_argument("--mxp", default=False, action="store_true",
help="Enable mixed precision")

parser.add_argument("--cache", default=False, action="store_true",
help="Enable caching for dataset")

args = parser.parse_args()

tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

strategy = setup_strategy(args.devices)

from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["train_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config["learning_config"]["augmentations"],
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval", cache=args.cache, shuffle=True
    )
else:
    train_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["train_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config["learning_config"]["augmentations"],
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval", cache=args.cache, shuffle=True
    )

conformer_trainer = TransducerTrainerGA(
    config=config["learning_config"]["running_config"],
    text_featurizer=text_featurizer, strategy=strategy
)

with conformer_trainer.strategy.scope():
    # build model
    conformer = Conformer(
        **config["model_config"],
        vocabulary_size=text_featurizer.num_classes
    )
    conformer._build(speech_featurizer.shape)
    conformer.summary(line_length=120)

    optimizer_config = config["learning_config"]["optimizer_config"]
    optimizer = tf.keras.optimizers.Adam(
        TransformerSchedule(
            d_model=config["model_config"]["dmodel"],
            warmup_steps=optimizer_config["warmup_steps"],
            max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"]))
        ),
        beta_1=optimizer_config["beta1"],
        beta_2=optimizer_config["beta2"],
        epsilon=optimizer_config["epsilon"]
    )

conformer_trainer.compile(model=conformer, optimizer=optimizer,
                          max_to_keep=args.max_ckpts)

conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
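`TransducerTrainerGA` is imported but not defined in this diff; judging by the name, it performs gradient accumulation, i.e. it sums gradients over several micro-batches before applying a single optimizer update, which lets small-memory GPUs emulate a larger batch size. A generic sketch of that idea (not the trainer's actual implementation):

```python
import tensorflow as tf

def accumulated_train_step(model, optimizer, loss_fn, micro_batches):
    """Accumulate gradients over micro-batches, then apply them once.

    Generic illustration of gradient accumulation; not taken from
    tensorflow_asr.runners.transducer_runners.
    """
    accum = [tf.zeros_like(v) for v in model.trainable_variables]
    for features, labels in micro_batches:
        with tf.GradientTape() as tape:
            loss = loss_fn(labels, model(features, training=True))
        grads = tape.gradient(loss, model.trainable_variables)
        accum = [a + g for a, g in zip(accum, grads)]
    # Average so the step matches one update on the combined batch.
    accum = [a / len(micro_batches) for a in accum]
    optimizer.apply_gradients(zip(accum, model.trainable_variables))
```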
147 changes: 147 additions & 0 deletions examples/conformer/train_ga_subword_conformer.py
@@ -0,0 +1,147 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import argparse
from tensorflow_asr.utils import setup_environment, setup_strategy

setup_environment()
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Conformer Training")

parser.add_argument("--config", type=str, default=DEFAULT_YAML,
help="The file path of model configuration file")

parser.add_argument("--max_ckpts", type=int, default=10,
help="Max number of checkpoints to keep")

parser.add_argument("--tfrecords", default=False, action="store_true",
help="Whether to use tfrecords")

parser.add_argument("--tbs", type=int, default=None,
help="Train batch size per replica")

parser.add_argument("--ebs", type=int, default=None,
help="Evaluation batch size per replica")

parser.add_argument("--devices", type=int, nargs="*", default=[0],
help="Devices' ids to apply distributed training")

parser.add_argument("--mxp", default=False, action="store_true",
help="Enable mixed precision")

parser.add_argument("--cache", default=False, action="store_true",
help="Enable caching for dataset")

parser.add_argument("--subwords", type=str, default=None,
help="Path to file that stores generated subwords")

parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
help="Transcript files for generating subwords")

args = parser.parse_args()

tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

strategy = setup_strategy(args.devices)

from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
from tensorflow_asr.runners.transducer_runners import TransducerTrainerGA
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.optimizers.schedules import TransformerSchedule

config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])

if args.subwords and os.path.exists(args.subwords):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
else:
    print("Generating subwords ...")
    text_featurizer = SubwordFeaturizer.build_from_corpus(
        config["decoder_config"],
        corpus_files=args.subwords_corpus
    )
    text_featurizer.subwords.save_to_file(args.subwords)

if args.tfrecords:
    train_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["train_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config["learning_config"]["augmentations"],
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval", cache=args.cache, shuffle=True
    )
else:
    train_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["train_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        augmentations=config["learning_config"]["augmentations"],
        stage="train", cache=args.cache, shuffle=True
    )
    eval_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["eval_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="eval", cache=args.cache, shuffle=True
    )

conformer_trainer = TransducerTrainerGA(
    config=config["learning_config"]["running_config"],
    text_featurizer=text_featurizer, strategy=strategy
)

with conformer_trainer.strategy.scope():
    # build model
    conformer = Conformer(
        **config["model_config"],
        vocabulary_size=text_featurizer.num_classes
    )
    conformer._build(speech_featurizer.shape)
    conformer.summary(line_length=120)

    optimizer_config = config["learning_config"]["optimizer_config"]
    optimizer = tf.keras.optimizers.Adam(
        TransformerSchedule(
            d_model=config["model_config"]["dmodel"],
            warmup_steps=optimizer_config["warmup_steps"],
            max_lr=(0.05 / math.sqrt(config["model_config"]["dmodel"]))
        ),
        beta_1=optimizer_config["beta1"],
        beta_2=optimizer_config["beta2"],
        epsilon=optimizer_config["epsilon"]
    )

conformer_trainer.compile(model=conformer, optimizer=optimizer,
                          max_to_keep=args.max_ckpts)

conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
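Both gradient-accumulation scripts build their learning rate from `TransformerSchedule` with `max_lr = 0.05 / sqrt(dmodel)`. Assuming it follows the standard Transformer schedule (linear warmup, then inverse-square-root decay) with an added ceiling, a sketch of what it presumably computes:

```python
import tensorflow as tf

class TransformerScheduleSketch(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Assumed behavior of TransformerSchedule; illustration only."""

    def __init__(self, d_model, warmup_steps=4000, max_lr=None):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)
        self.max_lr = max_lr  # e.g. 0.05 / sqrt(dmodel), as above

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Linear warmup for warmup_steps, then ~1/sqrt(step) decay.
        lr = tf.math.rsqrt(self.d_model) * tf.minimum(
            tf.math.rsqrt(step), step * (self.warmup_steps ** -1.5))
        if self.max_lr is not None:
            lr = tf.minimum(lr, self.max_lr)  # clip at the configured ceiling
        return lr
```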
9 changes: 4 additions & 5 deletions examples/conformer/train_subword_conformer.py
@@ -50,8 +50,8 @@
 parser.add_argument("--cache", default=False, action="store_true",
                     help="Enable caching for dataset")

-parser.add_argument("--subwords_prefix", type=str, default=None,
-                    help="Prefix of file that stores generated subwords")
+parser.add_argument("--subwords", type=str, default=None,
+                    help="Path to file that stores generated subwords")

 parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[],
                     help="Transcript files for generating subwords")
@@ -73,10 +73,9 @@
 config = UserConfig(DEFAULT_YAML, args.config, learning=True)
 speech_featurizer = TFSpeechFeaturizer(config["speech_config"])

-if args.subwords_prefix and os.path.exists(f"{args.subwords_prefix}.subwords"):
+if args.subwords and os.path.exists(args.subwords):
     print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"],
-                                                       args.subwords_prefix)
+    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
 else:
     print("Generating subwords ...")
     text_featurizer = SubwordFeaturizer.build_from_corpus(
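The generation branch is collapsed above. `SubwordFeaturizer.build_from_corpus` presumably wraps TensorFlow Datasets' `SubwordTextEncoder`; a rough standalone sketch of generating and saving subwords that way (the corpus format and the wrapper's exact behavior are assumptions):

```python
import tensorflow_datasets as tfds

def build_subwords(corpus_files, target_vocab_size=1024):
    """Sketch: build a subword vocabulary from plain-text transcripts.

    Treats every line as raw text; TensorFlowASR's real transcript
    format and decoder_config options may differ.
    """
    def corpus_generator():
        for path in corpus_files:
            with open(path, encoding="utf-8") as f:
                for line in f:
                    yield line.strip()

    return tfds.deprecated.text.SubwordTextEncoder.build_from_corpus(
        corpus_generator(), target_vocab_size=target_vocab_size)

encoder = build_subwords(["transcripts_train.tsv"])  # hypothetical file
encoder.save_to_file("english")  # writes english.subwords
```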
5 changes: 1 addition & 4 deletions setup.py
@@ -37,17 +37,14 @@

 setuptools.setup(
     name="TensorFlowASR",
-    version="0.2.5",
+    version="0.2.6",
     author="Huy Le Nguyen",
     author_email="[email protected]",
     description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
     long_description=long_description,
     long_description_content_type="text/markdown",
     url="https://github.com/TensorSpeech/TensorFlowASR",
     packages=setuptools.find_packages(include=["tensorflow_asr*"]),
-    package_data={
-        "tensorflow_asr": ["featurizers/*.txt"]
-    },
     install_requires=requirements,
     classifiers=[
         "Programming Language :: Python :: 3.6",
16 changes: 0 additions & 16 deletions tensorflow_asr/featurizers/__init__.py
@@ -1,16 +0,0 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

ENGLISH = os.path.abspath(os.path.join(os.path.dirname(__file__), "english.txt"))
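Deleting this module removes the hardcoded `ENGLISH` vocabulary path (along with the `package_data` entry in setup.py that shipped `english.txt`), in line with the commit title: the vocabulary location now comes from the decoder configuration. An illustrative sketch only; the actual `decoder_config` key is not shown in this diff:

```python
from typing import Optional

def resolve_vocabulary(decoder_config: dict) -> Optional[str]:
    # Hypothetical key name "vocabulary": fall back to a built-in
    # default character set when unset, instead of a packaged english.txt.
    return decoder_config.get("vocabulary", None)

vocab_path = resolve_vocabulary({"vocabulary": "vocabularies/english.txt"})
```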
