Refactor hotwords, support loading hotwords from file #296
Conversation
cmake/sentencepiece.cmake (outdated)

function(download_sentencepiece)
  include(FetchContent)

  set(sentencepiece_URL "https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.tar.gz")
Does sentencepiece support arm64 and arm?
I think so, but I have to check. I once compiled it successfully on the Android platform.
Finally, I find it is a bit of a hassle to encode hotwords on the C++ side, because sentencepiece and onnxruntime both depend on Google's protobuf, and symbol conflicts occur when linking them statically. One way to fix this is to change some code in sentencepiece and maintain our own branch. To keep the dependencies of sherpa-onnx as few as possible, we decided to encode the hotwords outside (with a Python command-line tool). The command-line tool is inside sherpa-onnx; you can use it like this:
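A minimal sketch of the invocation (only `--tokens-type` appears elsewhere in this review; the `--bpe-model` flag and the positional input/output files are assumptions about the script's interface, so check `python3 scripts/text2token.py --help` for the real one):

```bash
# Hypothetical invocation; flag names other than --tokens-type are assumptions.
python3 scripts/text2token.py \
  --tokens-type cjkchar+bpe \
  --bpe-model /path/to/bpe.model \
  hotwords.txt hotwords_tokens.txt
```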
Great work! Would love to see this integrated into iOS/Android just like in sherpa-ncnn 😄
@@ -0,0 +1,5539 @@
<blk> 0
Can we download the test data files from huggingface in GitHub Actions?
If we want to run the tests locally, we can either skip the tests if the test data files do not exist or download them ourselves.
Sure. I put them in the repo because I think they are small. GitHub also supports LFS; can we put them there?
I don't think you need to use git lfs to manage them. You can upload them to GitHub directly without git lfs.
I don't get your idea. You mean upload to GitHub, but not to our repo? Then where?
> Sure, I put them in the repo, because I think
If you want to use GitHub, you can create a separate repo for it.
The point is that we can download it in GitHub Actions during the tests.
tokens_type:
  The valid values are cjkchar, bpe, cjkchar+bpe.
bpe_model:
  The path of the bpe model.
Please document that it is required only when tokens_type is bpe or cjkchar+bpe.
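For instance, the docstring could read (wording illustrative):

```
bpe_model:
  The path of the bpe model. Used only when tokens_type is
  bpe or cjkchar+bpe.
```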
else:
    assert modeling_unit == "bpe+char", modeling_unit
if "bpe" in tokens_type:
    assert Path(bpe_model).is_file, f"File not exists, {bpe_model}"
-    assert Path(bpe_model).is_file, f"File not exists, {bpe_model}"
+    assert Path(bpe_model).is_file(), f"File not exists, {bpe_model}"
    texts_list = [list("".join(text.split())) for text in texts]
elif "bpe" == tokens_type:
    assert (
        sp is not None
Is sp always not None in this if branch? If not, in which case is sp None?
else:
    assert (
        "cjkchar+bpe" == tokens_type
    ), f"Supporting tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}"
), f"Supporting tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}" | |
), f"Supported tokens_type are cjkchar, bpe, cjkchar+bpe, given {tokens_type}" |
sherpa-onnx/csrc/utils.h (outdated)

 * @param is The input stream, it contains several lines, one hotwords for each
 *           line. For each hotword, the tokens (cjkchar or bpe) are separated
 *           by spaces.
 * @symbol_table The tokens table mapping symbols to ids. All the symbols in
- * @symbol_table The tokens table mapping symbols to ids. All the symbols in
+ * @param symbol_table The tokens table mapping symbols to ids. All the symbols in
sherpa-onnx/csrc/utils.h (outdated)

 *           the stream should be in the symbol_table, if not this function
 *           returns fasle.
 *
 * @hotwords The encoded ids to be written to.
- * @hotwords The encoded ids to be written to.
+ * @param hotwords The encoded ids to be written to.
sherpa-onnx/csrc/online-recognizer.h (outdated)

      : feat_config(feat_config),
        model_config(model_config),
        endpoint_config(endpoint_config),
        enable_endpoint(enable_endpoint),
        decoding_method(decoding_method),
        max_active_paths(max_active_paths),
        context_score(context_score) {}
        hotwords_file(hotwords_file),
Please use the same order as the one in which you define them.
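(In C++, members are initialized in the order they are declared in the class, not in the order they appear in the mem-initializer list; keeping the two orders in sync avoids -Wreorder warnings and surprises about which member is initialized first.)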
@@ -53,7 +61,8 @@ std::string OfflineRecognizerConfig::ToString() const {
  os << "lm_config=" << lm_config.ToString() << ", ";
  os << "decoding_method=\"" << decoding_method << "\", ";
  os << "max_active_paths=" << max_active_paths << ", ";
  os << "context_score=" << context_score << ")";
  os << "hotwords_file=" << hotwords_file << ", ";
os << "hotwords_file=" << hotwords_file << ", "; | |
os << "hotwords_file=\"" << hotwords_file << "\", "; |
    for (int32_t j = 0; j < word_len; ++j) {
      tmp.push_back(char_dist(mt));
    }
    contexts.push_back(tmp);
-    contexts.push_back(tmp);
+    contexts.push_back(std::move(tmp));
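Moving avoids copying the vector's contents into contexts; assuming tmp is rebuilt on the next loop iteration (as the surrounding code suggests), it is safe to move from here.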
scripts/text2token.py (outdated)

    "--tokens-type",
    type=str,
    required=True,
    default="cjkchar",
If you give it a default value, please remove required=True; otherwise the default can never be used.
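A minimal standalone sketch of why (hypothetical example, not the script's code): with required=True, parsing fails whenever the flag is missing, so the default is dead code.

```python
import argparse

parser = argparse.ArgumentParser()
# Optional flag with a default; adding required=True here would make
# default="cjkchar" unreachable, since omitting the flag would be an error.
parser.add_argument("--tokens-type", type=str, default="cjkchar")

args = parser.parse_args([])          # no flag supplied
assert args.tokens_type == "cjkchar"  # the default takes effect
```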
@@ -342,6 +367,9 @@ def check_args(args):
    assert Path(args.decoder).is_file(), args.decoder
    assert Path(args.joiner).is_file(), args.joiner

    if args.hotwords_file != "":
        assert args.decoding_method == "modified_beam_search", args.decoding_method
     assert args.decoding_method == "modified_beam_search", args.decoding_method
+    assert Path(args.hotwords_file).is_file(), args.hotwords_file
    ], encoded_ids


if __name__ == "__main__":
Please skip the test if the expected directory does not exist, and print a message to tell users the test is skipped.
    print(
        f"No test data found, skipping test_bpe().\n"
        f"You can download the test data by: \n"
        f"git clone [email protected]:pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
Please use https to download it.
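For example, the message could become (same snippet as above, with the clone URL switched to https):

```python
print(
    f"No test data found, skipping test_bpe().\n"
    f"You can download the test data by: \n"
    f"git clone https://github.com/pkufool/sherpa-test-data.git /tmp/sherpa-test-data"
)
```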
Thanks!
This PR does some refactoring of the hotwords pipeline, mainly to support loading hotwords from a file and encoding hotwords on the C++ side (which makes it easier to wrap for other programming languages). I think encoding hotwords internally is more user friendly, so that users can provide hotwords in the most natural way.
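For reference, based on the utils.h comment above, a hotwords file contains one hotword per line, with its tokens (cjkchar or bpe) separated by spaces; the tokens below are purely illustrative:

```
▁HE LL O ▁WORLD
礼 拜 二
```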