Skip to content

Commit

Permalink
Merge pull request #2176 from coqui-ai/wav2vec2-decoder
Browse files Browse the repository at this point in the history
Wav2vec2 decoder
  • Loading branch information
reuben authored Apr 1, 2022
2 parents c9e73ee + e36e731 commit bb75afb
Show file tree
Hide file tree
Showing 29 changed files with 1,020 additions and 514 deletions.
24 changes: 13 additions & 11 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -274,17 +274,17 @@ jobs:
fetch-depth: 1
- uses: actions/setup-python@v2
with:
python-version: 3.6
python-version: "3.7"
- uses: actions/download-artifact@v2
with:
name: "coqui_stt_ctcdecoder-Linux-3.6.whl"
name: "coqui_stt_ctcdecoder-Linux-3.7.whl"
- run: |
python --version
pip --version
- run: |
pip install --upgrade pip setuptools wheel
- run: |
pip install coqui_stt_ctcdecoder-*-cp36-cp36m-*_x86_64.whl
pip install coqui_stt_ctcdecoder-*-cp37-cp37m-*_x86_64.whl
DS_NODECODER=y pip install --upgrade .
- run: |
# Easier to rename to that we can exercize the LDC93S1 importer code to
Expand Down Expand Up @@ -540,7 +540,7 @@ jobs:
if: ${{ github.event_name == 'pull_request' }}
strategy:
matrix:
python-version: ["3.6", "3.7"]
python-version: ["3.7"]
samplerate: ["8000", "16000"]
env:
CI_TMP_DIR: ${{ github.workspace }}/tmp/
Expand Down Expand Up @@ -700,7 +700,7 @@ jobs:
- run: |
python -m pip install --upgrade pip setuptools wheel jupyter
- run: |
python -m pip install coqui_stt_ctcdecoder-*-cp37-cp37m-*_x86_64.whl
python -m pip install coqui_stt_ctcdecoder*.whl
DS_NODECODER=y python -m pip install --upgrade .
- name: Run python notebooks
run: |
Expand All @@ -713,7 +713,7 @@ jobs:
strategy:
matrix:
samplerate: ["8000", "16000"]
pyver: [3.6, 3.7]
pyver: ["3.7"]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down Expand Up @@ -779,7 +779,7 @@ jobs:
strategy:
matrix:
samplerate: ["8000", "16000"]
pyver: [3.6, 3.7]
pyver: ["3.7"]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down Expand Up @@ -830,7 +830,7 @@ jobs:
strategy:
matrix:
samplerate: ["8000", "16000"]
pyver: [3.6, 3.7]
pyver: ["3.7"]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down Expand Up @@ -874,7 +874,7 @@ jobs:
strategy:
matrix:
samplerate: ["8000", "16000"]
pyver: [3.6, 3.7]
pyver: ["3.7"]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down Expand Up @@ -914,7 +914,7 @@ jobs:
strategy:
matrix:
samplerate: ["8000", "16000"]
pyver: [3.6, 3.7]
pyver: ["3.7"]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down Expand Up @@ -950,7 +950,7 @@ jobs:
strategy:
matrix:
samplerate: ["8000", "16000"]
pyver: [3.6, 3.7]
pyver: ["3.7"]
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
Expand Down Expand Up @@ -1193,6 +1193,8 @@ jobs:
docker-publish:
name: "Build and publish Docker training image to GHCR"
runs-on: ubuntu-20.04
needs: [upload-nc-release-assets]
if: always()
steps:
- uses: actions/checkout@v2
with:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: 'https://github.com/psf/black'
rev: "22.1.0"
rev: "22.3.0"
hooks:
- id: black
language_version: python3
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.build
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ RUN bazel build \
//native_client:libstt.so

# Copy built libs to /STT/native_client
RUN cp bazel-bin/native_client/libstt.so /STT/native_client/
RUN cp bazel-bin/native_client/libstt.so bazel-bin/native_client/libkenlm.so /STT/native_client/

# Build client.cc and install Python client and decoder bindings
ENV TFDIR /STT/tensorflow
Expand Down
Binary file modified data/smoke_test/LDC93S1.wav
Binary file not shown.
50 changes: 10 additions & 40 deletions lm_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
from __future__ import absolute_import, division, print_function

import sys

from coqui_stt_training.train import early_training_checks
from coqui_stt_training.util.config import (
Config,
initialize_globals_from_cli,
log_error,
)

from coqui_stt_training.util import lm_optimize as lm_opt


def main():
initialize_globals_from_cli()
early_training_checks()

if not Config.scorer_path:
log_error(
"Missing --scorer_path: can't optimize scorer alpha and beta "
"parameters without a scorer!"
)
sys.exit(1)

if not Config.test_files:
log_error(
"You need to specify what files to use for evaluation via "
"the --test_files flag."
)
sys.exit(1)

results = lm_opt.compute_lm_optimization()
if __name__ == "__main__":
print(
"Best params: lm_alpha={} and lm_beta={} with WER={}".format(
results.get("lm_alpha"),
results.get("lm_beta"),
results.get("wer"),
)
"Using the top level lm_optimizer.py script is deprecated and will be removed "
"in a future release. Instead use: python -m coqui_stt_training.util.lm_optimize"
)
try:
from coqui_stt_training.util import lm_optimize
except ImportError:
print("Training package is not installed. See training documentation.")
raise


if __name__ == "__main__":
main()
lm_optimize.main()
14 changes: 12 additions & 2 deletions native_client/alphabet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ Alphabet::SerializeText()
<< "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
<< "# to use '#' in the Alphabet.\n";

for (int idx = 0; idx < entrySize(); ++idx) {
out << getEntry(idx) << "\n";
for (const std::string& label : GetLabels()) {
out << label << "\n";
}

out << "# The last (non-comment) line needs to end with a newline.\n";
Expand Down Expand Up @@ -174,6 +174,16 @@ Alphabet::GetSize() const
return entrySize();
}

std::vector<std::string>
Alphabet::GetLabels() const
{
std::vector<std::string> labels;
for (int idx = 0; idx < GetSize(); ++idx) {
labels.push_back(DecodeSingle(idx));
}
return labels;
}

bool
Alphabet::CanEncodeSingle(const std::string& input) const
{
Expand Down
39 changes: 20 additions & 19 deletions native_client/alphabet.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,42 @@ class Alphabet : public fl::lib::text::Dictionary

size_t GetSize() const;

bool IsSpace(unsigned int label) const {
return label == space_index_;
bool IsSpace(unsigned int index) const {
return index == space_index_;
}

unsigned int GetSpaceLabel() const {
return space_index_;
}

// Returns true if the single character/output class has a corresponding label
virtual std::vector<std::string> GetLabels() const;

// Returns true if the single character/output class has a corresponding index
// in the alphabet.
virtual bool CanEncodeSingle(const std::string& string) const;
virtual bool CanEncodeSingle(const std::string& label) const;

// Returns true if the entire string can be encoded into labels in this
// alphabet.
virtual bool CanEncode(const std::string& string) const;
// Returns true if the entire string can be encoded with this alphabet.
virtual bool CanEncode(const std::string& label) const;

// Decode a single label into a string.
std::string DecodeSingle(unsigned int label) const;
// Decode a single index into its label.
std::string DecodeSingle(unsigned int index) const;

// Encode a single character/output class into a label. Character must be in
// Encode a single character/output class into its index. Character must be in
// the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
unsigned int EncodeSingle(const std::string& string) const;
unsigned int EncodeSingle(const std::string& label) const;

// Decode a sequence of labels into a string.
std::string Decode(const std::vector<unsigned int>& input) const;
// Decode a sequence of indices into a string.
std::string Decode(const std::vector<unsigned int>& indices) const;

// We provide a C-style overload for accepting NumPy arrays as input, since
// the NumPy library does not have built-in typemaps for std::vector<T>.
std::string Decode(const unsigned int* input, int length) const;
std::string Decode(const unsigned int* indices, int length) const;

// Encode a sequence of character/output classes into a sequence of labels.
// Encode a sequence of character/output classes into a sequence of indices.
// Characters are assumed to always take a single Unicode codepoint.
// Characters must be in the alphabet, this method will assert that. Use
// `CanEncode` and `CanEncodeSingle` to test.
virtual std::vector<unsigned int> Encode(const std::string& input) const;
virtual std::vector<unsigned int> Encode(const std::string& labels) const;

protected:
unsigned int space_index_;
Expand All @@ -93,9 +94,9 @@ class UTF8Alphabet : public Alphabet
return 0;
}

bool CanEncodeSingle(const std::string& string) const override;
bool CanEncode(const std::string& string) const override;
std::vector<unsigned int> Encode(const std::string& input) const override;
bool CanEncodeSingle(const std::string& label) const override;
bool CanEncode(const std::string& label) const override;
std::vector<unsigned int> Encode(const std::string& label) const override;
};

#endif //ALPHABET_H
63 changes: 63 additions & 0 deletions native_client/ctcdecode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,69 @@ def ctc_beam_search_decoder(
return beam_results


def ctc_beam_search_decoder_for_wav2vec2am(
probs_seq,
alphabet,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
blank_id=-1,
ignored_symbols=frozenset(),
scorer=None,
hot_words=dict(),
num_results=1,
):
"""Wrapper for the CTC Beam Search Decoder.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over alphabet and blank.
:type probs_seq: 2-D list
:param alphabet: Alphabet
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in alphabet will be
used in beam search, default 40.
:type cutoff_top_n: int
:param scorer: External scorer for partially decoded sentence, e.g. word
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: dict[string, float]
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
results, in descending order of the confidence.
:rtype: list
"""
beam_results = swigwrapper.ctc_beam_search_decoder_for_wav2vec2am(
probs_seq,
alphabet,
beam_size,
cutoff_prob,
cutoff_top_n,
blank_id,
ignored_symbols,
scorer,
hot_words,
num_results,
)
beam_results = [
DecodeResult(
res.confidence,
alphabet.Decode(res.tokens),
[int(t) for t in res.tokens],
[int(t) for t in res.timesteps],
)
for res in beam_results
]
return beam_results


def ctc_beam_search_decoder_batch(
probs_seq,
seq_lengths,
Expand Down
Loading

0 comments on commit bb75afb

Please sign in to comment.