Skip to content

Commit

Permalink
Consider escaped characters as single characters in BPE (#322)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored Mar 1, 2023
1 parent 128b900 commit e8cf860
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
12 changes: 12 additions & 0 deletions bindings/python/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,18 @@ def test_bpe_learner_no_pairs(tmpdir):
tokenizer = learner.learn(model_path)


def test_bpe_learner_escaped_character(tmpdir):
text = "คุณอาจจะทำอย ่ างนั ้ นไปซักพัก จนคุณเริ ่ มจะรู ้ สึกถึงมันจริงๆ"

tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)
learner = pyonmttok.BPELearner(tokenizer=tokenizer, symbols=5, min_frequency=1)
learner.ingest(text)
tokenizer = learner.learn(str(tmpdir.join("bpe.model")))

tokens = tokenizer(text)
assert "■%0020่" in tokens


@pytest.mark.parametrize("keep_vocab", [False, True])
def test_sp_learner(tmpdir, keep_vocab):
learner = pyonmttok.SentencePieceLearner(
Expand Down
2 changes: 2 additions & 0 deletions include/onmt/Tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ namespace onmt
static const std::string spacer_marker;
static const std::string ph_marker_open;
static const std::string ph_marker_close;
static const std::string escaped_character_prefix;
static const size_t escaped_character_width;

Tokenizer(Options options,
const std::shared_ptr<const SubwordEncoder>& subword_encoder = nullptr);
Expand Down
16 changes: 15 additions & 1 deletion src/BPE.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,23 @@ namespace onmt
std::vector<std::string> pieces;
pieces.reserve(chars.size());

static const auto escaped_character_prefix = (
unicode::utf8_to_cp(Tokenizer::escaped_character_prefix.c_str()));
size_t escaped_character_length = 0;

for (const auto& c : chars)
{
if (c.char_type == unicode::CharType::Mark)
if (escaped_character_length > 0)
{
pieces.back().append(c.data, c.length);
escaped_character_length--;
}
else if (c.value == escaped_character_prefix)
{
pieces.emplace_back(c.data, c.length);
escaped_character_length = Tokenizer::escaped_character_width;
}
else if (c.char_type == unicode::CharType::Mark)
{
if (pieces.empty())
pieces.emplace_back(c.data, c.length);
Expand Down
23 changes: 13 additions & 10 deletions src/Tokenizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ namespace onmt
const std::string Tokenizer::spacer_marker = "";
const std::string Tokenizer::ph_marker_open = "";
const std::string Tokenizer::ph_marker_close = "";
const std::string Tokenizer::escaped_character_prefix = "";
const size_t Tokenizer::escaped_character_width = 4;
static const unicode::code_point_t ph_marker_open_cp = 0xFF5F;
static const unicode::code_point_t ph_marker_close_cp = 0xFF60;
static const std::string protected_character = "";
static const std::vector<std::pair<unicode::code_point_t, std::string>> substitutes = {
{0x2581 /**/, "_"},
{0xFFED /**/, ""},
Expand All @@ -32,7 +33,6 @@ namespace onmt

static const int placeholder_alphabet = -2;
static const int number_alphabet = -3;
static const int hex_value_width = 4;

Tokenizer::Mode Tokenizer::str_to_mode(const std::string& mode)
{
Expand Down Expand Up @@ -290,21 +290,23 @@ namespace onmt

static inline void unescape_characters(std::string& str)
{
const auto& prefix = Tokenizer::escaped_character_prefix;
const auto& width = Tokenizer::escaped_character_width;

for (size_t offset = 0;;)
{
const size_t index = str.find(protected_character, offset);
if (index == std::string::npos
|| index + protected_character.size() + hex_value_width > str.size())
const size_t index = str.find(prefix, offset);
if (index == std::string::npos || index + prefix.size() + width > str.size())
break;

const std::string code = str.substr(index + protected_character.size(), hex_value_width);
const std::string code = str.substr(index + prefix.size(), width);
const int v = hex_to_int(code);
const std::string c = unicode::cp_to_utf8(v);
if (c.empty() || !c[0] || int_to_hex(v, hex_value_width) != code)
offset = index + protected_character.size();
if (c.empty() || !c[0] || int_to_hex(v, width) != code)
offset = index + prefix.size();
else
{
str.replace(index, protected_character.size() + hex_value_width, c);
str.replace(index, prefix.size() + width, c);
offset = index + 1;
}
}
Expand Down Expand Up @@ -635,7 +637,8 @@ namespace onmt
if (_no_substitution)
append(character);
else
append(protected_character + int_to_hex(character.value, hex_value_width));
append(Tokenizer::escaped_character_prefix
+ int_to_hex(character.value, Tokenizer::escaped_character_width));
}

void flush_feature()
Expand Down

0 comments on commit e8cf860

Please sign in to comment.