Skip to content

Commit

Permalink
Add lm decode for the Python API.
Browse files Browse the repository at this point in the history
  • Loading branch information
kamirdin committed Oct 10, 2023
1 parent 8455057 commit 909c4af
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python-api-examples/online-decode-files.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,24 @@ def get_args():
""",
)

# Path of the language model. The default must be an empty string: the value
# is passed on as `lm=args.lm`, and the recognizer factory treats any truthy
# `lm` as "an LM was requested", raising ValueError unless the decoding method
# is modified_beam_search. A float default such as 0.1 is truthy and is not a
# valid path, so it would break every run that does not pass --lm.
parser.add_argument(
    "--lm",
    type=str,
    default="",
    help="""Used only when --decoding-method is modified_beam_search.
path of language model.
""",
)

# Scale of the language model; per the help text it is only used when
# --decoding-method is modified_beam_search.
parser.add_argument(
    "--lm-scale",
    default=0.1,
    type=float,
    help="""Used only when --decoding-method is modified_beam_search.
scale of language model.
""",
)

parser.add_argument(
"--provider",
type=str,
Expand Down Expand Up @@ -215,6 +233,8 @@ def main():
feature_dim=80,
decoding_method=args.decoding_method,
max_active_paths=args.max_active_paths,
lm=args.lm,
scale=args.lm_scale,
hotwords_file=args.hotwords_file,
hotwords_score=args.hotwords_score,
)
Expand Down
1 change: 1 addition & 0 deletions sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::arg("hotwords_score") = 0)
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("lm_config", &PyClass::lm_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
Expand Down
15 changes: 15 additions & 0 deletions sherpa-onnx/python/sherpa_onnx/online_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
OnlineRecognizer as _Recognizer,
OnlineRecognizerConfig,
OnlineStream,
OnlineLMConfig,
OnlineTransducerModelConfig,
)

Expand Down Expand Up @@ -46,6 +47,8 @@ def from_transducer(
hotwords_file: str = "",
provider: str = "cpu",
model_type: str = "",
lm:str = "",
scale:float = 0.1,
):
"""
Please refer to
Expand Down Expand Up @@ -137,10 +140,22 @@ def from_transducer(
"Please use --decoding-method=modified_beam_search when using "
f"--hotwords-file. Currently given: {decoding_method}"
)

# A language model may only be supplied together with modified_beam_search;
# reject any other decoding method before the configs are built.
if lm and decoding_method != "modified_beam_search":
    lm_error = (
        "Please use --decoding-method=modified_beam_search when using "
        f"--lm. Currently given: {decoding_method}"
    )
    raise ValueError(lm_error)

# Bundle the language-model path and its scale into the LM config object.
lm_config = OnlineLMConfig(model=lm, scale=scale)

recognizer_config = OnlineRecognizerConfig(
feat_config=feat_config,
model_config=model_config,
lm_config = lm_config ,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
Expand Down

0 comments on commit 909c4af

Please sign in to comment.