From 165f368dd0d5d585d388a90b3407dfbd94429a4f Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 12 Oct 2024 20:25:48 +0900 Subject: [PATCH 1/9] Fix #192 --- src/KTrie.cpp | 22 ++++++++++++++++------ src/KTrie.h | 7 ++++--- src/Kiwi.cpp | 16 +++++++++------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/KTrie.cpp b/src/KTrie.cpp index 5b77bbb6..025b906f 100644 --- a/src/KTrie.cpp +++ b/src/KTrie.cpp @@ -1215,9 +1215,10 @@ size_t kiwi::splitByTrie( return n + startOffset; } -template +template const Form* kiwi::findForm( const utils::FrozenTrie& trie, + const Form* formData, const KString& str ) { @@ -1228,7 +1229,12 @@ const Form* kiwi::findForm( if (!node) return nullptr; } if (trie.hasSubmatch(node->val(trie))) return nullptr; - return node->val(trie); + auto ret = node->val(trie); + if (typoTolerant) + { + ret = &reinterpret_cast(ret)->form(formData); + } + return ret; } namespace kiwi @@ -1266,19 +1272,23 @@ FnSplitByTrie kiwi::getSplitByTrieFn(ArchType arch, bool typoTolerant, bool cont namespace kiwi { + template struct FindFormGetter { template struct Wrapper { - static constexpr FnFindForm value = &findForm(i)>; + static constexpr FnFindForm value = &findForm(i), typoTolerant>; }; }; } -FnFindForm kiwi::getFindFormFn(ArchType arch) +FnFindForm kiwi::getFindFormFn(ArchType arch, bool typoTolerant) { - static tp::Table table{ FindFormGetter{} }; + static std::array, 2> table{ + FindFormGetter{}, + FindFormGetter{}, + }; - return table[static_cast(arch)]; + return table[typoTolerant ? 1 : 0][static_cast(arch)]; } diff --git a/src/KTrie.h b/src/KTrie.h index c07e0636..e6ef78ed 100644 --- a/src/KTrie.h +++ b/src/KTrie.h @@ -103,17 +103,18 @@ namespace kiwi const PretokenizedSpanGroup::Span* pretokenizedLast ); - template + template const Form* findForm( const utils::FrozenTrie& trie, + const Form* formData, const KString& str ); using FnSplitByTrie = decltype(&splitByTrie); FnSplitByTrie getSplitByTrieFn(ArchType arch, bool typoTolerant, bool continualTypoTolerant, bool lengtheningTypoTolerant); - using FnFindForm = decltype(&findForm); - FnFindForm getFindFormFn(ArchType arch); + using FnFindForm = decltype(&findForm); + FnFindForm getFindFormFn(ArchType arch, bool typoTolerant); struct KTrie : public utils::TrieNode>, KTrie> { diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index 5a155e18..295ad550 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -53,7 +53,7 @@ namespace kiwi typoTolerant, continualTypoTolerant, lengtheningTypoTolerant); - dfFindForm = (void*)getFindFormFn(selectedArch); + dfFindForm = (void*)getFindFormFn(selectedArch, typoTolerant); static tp::Table lmKnLM_8{ FindBestPathGetter::type>{} }; static tp::Table lmKnLM_16{ FindBestPathGetter::type>{} }; @@ -802,7 +802,8 @@ namespace kiwi const Vector& positionTable, const KString& normStr, FnFindForm findForm, - const utils::FrozenTrie& formTrie + const utils::FrozenTrie& formTrie, + const Form* formData ) { if (pretokenized.empty()) return; @@ -833,7 +834,7 @@ namespace kiwi if (s.tokenization.empty()) { auto formStr = normStr.substr(span.begin, span.end - span.begin); - span.form = findForm(formTrie, formStr); // reuse the predefined form & morpheme + span.form = findForm(formTrie, formData, formStr); // reuse the predefined form & morpheme if (!span.form) // or use a fallback form { span.form = formTrie.value((size_t)POSTag::nnp); @@ -842,7 +843,7 @@ namespace kiwi else if (s.tokenization.size() == 1) { auto formStr = normalizeHangul(s.tokenization[0].form); - auto* tform = findForm(formTrie, formStr); + auto* tform = findForm(formTrie, formData, formStr); if (tform && tform->candidate.size() == 1 && areTagsEqual(tform->candidate[0]->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity)) // reuse the predefined form & morpheme @@ -908,7 +909,7 @@ namespace kiwi { auto& t = s.tokenization[i]; auto formStr = normalizeHangul(t.form); - auto* tform = findForm(formTrie, formStr); + auto* tform = findForm(formTrie, formData, formStr); const Morpheme* foundMorph = nullptr; if (tform) { @@ -999,7 +1000,8 @@ namespace kiwi positionTable, normalizedStr, reinterpret_cast(dfFindForm), - formTrie + formTrie, + forms.data() ); // 분석할 문장에 포함된 개별 문자에 대해 어절번호를 생성한다 @@ -1317,7 +1319,7 @@ namespace kiwi void Kiwi::findMorpheme(vector& ret, const u16string& s, POSTag tag) const { auto normalized = normalizeHangul(s); - auto form = (*reinterpret_cast(dfFindForm))(formTrie, normalized); + auto form = (*reinterpret_cast(dfFindForm))(formTrie, forms.data(), normalized); if (!form) return; tag = clearIrregular(tag); for (auto c : form->candidate) From 3b29e28f2e91dbd5669b162ce15589a5b7e71f7a Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 12 Oct 2024 20:27:10 +0900 Subject: [PATCH 2/9] Prevent XSM tag to appear at left boundaries --- include/kiwi/TagUtils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/kiwi/TagUtils.h b/include/kiwi/TagUtils.h index cdc38fd0..9bd5e396 100644 --- a/include/kiwi/TagUtils.h +++ b/include/kiwi/TagUtils.h @@ -35,7 +35,7 @@ namespace kiwi inline bool isSuffix(POSTag tag) { tag = clearIrregular(tag); - return POSTag::xsn <= tag && tag <= POSTag::xsa; + return POSTag::xsn <= tag && tag <= POSTag::xsm; } inline bool isSpecialClass(POSTag tag) From 01a22a8b74da5e36b1b06f5efe5f291f8dcfceea Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 12 Oct 2024 20:29:11 +0900 Subject: [PATCH 3/9] Add new features to `HSDataset` --- include/kiwi/Dataset.h | 7 +++- include/kiwi/Kiwi.h | 7 +++- src/Dataset.cpp | 51 +++++++++++++++++++------ src/KiwiBuilder.cpp | 86 +++++++++++++++++++++++++++++++++--------- test/test_cpp.cpp | 2 +- 5 files changed, 121 insertions(+), 32 deletions(-) diff --git a/include/kiwi/Dataset.h b/include/kiwi/Dataset.h index 6ce86c38..32be4879 100644 --- a/include/kiwi/Dataset.h +++ b/include/kiwi/Dataset.h @@ -47,16 +47,19 @@ namespace kiwi HiddenMember, sizeof(Vector) * 2> sents; std::shared_ptr knlm; std::unique_ptr workers; + std::shared_ptr dummyBuilder; std::discrete_distribution<> dropout; std::mt19937_64 rng; Vector locals; Vector shuffledIdx; Vector tokenToVocab, vocabToToken; + Vector windowTokenValidness; Deque> futures; const Vector* morphemes = nullptr; const Vector* forms = nullptr; size_t knlmVocabSize = 0; size_t batchSize = 0; + size_t causalContextSize = 0; size_t windowSize = 0; size_t totalTokens = 0; size_t passedSents = 0; @@ -68,7 +71,7 @@ namespace kiwi size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut); public: - HSDataset(size_t _batchSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0); + HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0); ~HSDataset(); HSDataset(const HSDataset&) = delete; HSDataset(HSDataset&&) /*noexcept*/; @@ -80,7 +83,9 @@ namespace kiwi size_t numTokens() const; size_t getBatchSize() const { return batchSize; } + size_t getCausalContextSize() const { return causalContextSize; } size_t getWindowSize() const { return windowSize; } + const Vector& getWindowTokenValidness() const { return windowTokenValidness; } void seed(size_t newSeed); void reset(); diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 7cec3b2b..44a14d5c 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -548,6 +548,8 @@ namespace kiwi using MorphemeMap = UnorderedMap, std::pair>; + void initMorphemes(); + template MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter); @@ -801,11 +803,14 @@ namespace kiwi using TokenFilter = std::function; HSDataset makeHSDataset(const std::vector& inputPathes, - size_t batchSize, size_t windowSize, size_t numWorkers, + size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb = 0, const TokenFilter& tokenFilter = {}, + const TokenFilter& windowFilter = {}, double splitRatio = 0, bool separateDefaultMorpheme = false, + const std::string& morphemeDefPath = {}, + size_t morphemeDefMinCnt = 0, HSDataset* splitDataset = nullptr ) const; }; diff --git a/src/Dataset.cpp b/src/Dataset.cpp index bf9bd1a4..f3e47edd 100644 --- a/src/Dataset.cpp +++ b/src/Dataset.cpp @@ -3,11 +3,12 @@ using namespace kiwi; -HSDataset::HSDataset(size_t _batchSize, size_t _windowSize, size_t _workers, double _dropoutProb) +HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb) : workers{ _workers ? make_unique(_workers) : nullptr }, dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} }, locals( _workers ? workers->size() : 1), batchSize{ _batchSize }, + causalContextSize{ _causalContextSize }, windowSize{ _windowSize } { } @@ -113,12 +114,21 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, local.lmLProbsBuf.resize(tokens.size()); local.outNgramNodeBuf.resize(tokens.size()); - knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin()); + if (knlm) + { + knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin()); + } auto& history = local.historyBuf; history.clear(); - history.resize(windowSize, -1); - history.back() = tokenToVocab[tokens[0]]; + if (windowSize) + { + history.resize(windowSize, -1); + if (windowTokenValidness[tokens[0]]) + { + history.back() = tokenToVocab[tokens[0]]; + } + } for (size_t i = 1; i < tokens.size(); ++i) { int32_t v = tokenToVocab[tokens[i]]; @@ -134,13 +144,32 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, local.restLmLProbsCntData[r] += 1; continue; } - std::copy(history.begin(), history.end(), std::back_inserter(local.inData)); + + if (causalContextSize) + { + for (size_t j = 0; j < causalContextSize; ++j) + { + local.inData.emplace_back(i + j < causalContextSize ? + nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]); + } + } + if (windowSize) + { + if (windowTokenValidness[v]) + { + std::copy(history.begin(), history.end(), std::back_inserter(local.inData)); + history.pop_front(); + history.push_back(v); + } + else + { + local.inData.resize(local.inData.size() + windowSize, -1); + } + } + local.outData.emplace_back(v); local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]); local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]); - - history.pop_front(); - history.push_back(v); } size_t r = local.outData.size() / batchSize; @@ -217,14 +246,14 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, auto& l = locals[localId]; size_t rest = std::min(l.outData.size(), batchSize); - std::copy(l.inData.begin(), l.inData.begin() + rest * windowSize, in); + std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in); std::copy(l.outData.begin(), l.outData.begin() + rest, out); std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs); std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode); restLmOut = l.restLmLProbsData.front(); restLmCntOut = l.restLmLProbsCntData.front(); - l.inData.erase(l.inData.begin(), l.inData.begin() + rest * windowSize); + l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize)); l.outData.erase(l.outData.begin(), l.outData.begin() + rest); l.lmLProbsData.erase(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest); l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest); @@ -245,7 +274,7 @@ size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outN size_t HSDataset::ngramNodeSize() const { - return knlm->nonLeafNodeSize(); + return knlm ? knlm->nonLeafNodeSize() : 0; } const MorphemeRaw& HSDataset::vocabInfo(uint32_t vocab) const diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index b216fff5..10a798e3 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -784,10 +784,8 @@ KiwiBuilder::KiwiBuilder(const string& modelPath, size_t _numThreads, BuildOptio } } -KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) +void KiwiBuilder::initMorphemes() { - archType = getSelectedArch(ArchType::default_); - forms.resize(defaultFormSize); morphemes.resize(defaultFormSize + 2); // additional places for , for (size_t i = 1; i < defaultTagSize; ++i) @@ -805,6 +803,18 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) morphemes[i + defaultTagSize + 1].userScore = -1.5f; } +} + +KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) +{ + if (!(args.lmMinCnts.size() == 1 || args.lmMinCnts.size() == args.lmOrder)) + { + throw invalid_argument{ "lmMinCnts should have 1 or lmOrder elements" }; + } + + archType = getSelectedArch(ArchType::default_); + initMorphemes(); + ifstream ifs; auto realMorph = loadMorphemesFromTxt(openFile(ifs, args.morphemeDef), [&](POSTag tag, float cnt) { @@ -2179,43 +2189,72 @@ vector KiwiBuilder::extractAddWords(const U16MultipleReader& reader, s } HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, - size_t batchSize, size_t windowSize, size_t numWorkers, + size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers, double dropoutProb, const TokenFilter& tokenFilter, + const TokenFilter& windowFilter, double splitRatio, bool separateDefaultMorpheme, + const string& morphemeDefPath, + size_t morphemeDefMinCnt, HSDataset* splitDataset ) const { - auto realMorph = restoreMorphemeMap(separateDefaultMorpheme); - HSDataset dataset{ batchSize, windowSize, numWorkers, dropoutProb }; + HSDataset dataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb }; auto& sents = dataset.sents.get(); - dataset.knlm = langMdl.knlm; - dataset.morphemes = &morphemes; - dataset.forms = &forms; + const KiwiBuilder* srcBuilder = this; + MorphemeMap realMorph; + size_t maxTokenId = 0; + if (morphemeDefPath.empty()) + { + realMorph = restoreMorphemeMap(separateDefaultMorpheme); + } + else + { + dataset.dummyBuilder = make_shared(); + dataset.dummyBuilder->initMorphemes(); + ifstream ifs; + realMorph = dataset.dummyBuilder->loadMorphemesFromTxt(openFile(ifs, morphemeDefPath), [&](POSTag tag, float cnt) + { + return cnt >= morphemeDefMinCnt; + }); + srcBuilder = dataset.dummyBuilder.get(); + + for (auto& p : realMorph) + { + maxTokenId = max(p.second.first + 1, maxTokenId); + } + } + + auto& knlm = srcBuilder->langMdl.knlm; + dataset.knlm = knlm; + dataset.morphemes = &srcBuilder->morphemes; + dataset.forms = &srcBuilder->forms; if (splitDataset) { - *splitDataset = HSDataset{ batchSize, windowSize, numWorkers, dropoutProb }; - splitDataset->knlm = langMdl.knlm; - splitDataset->morphemes = &morphemes; - splitDataset->forms = &forms; + *splitDataset = HSDataset{ batchSize, causalContextSize, windowSize, numWorkers, dropoutProb }; + splitDataset->dummyBuilder = dataset.dummyBuilder; + splitDataset->knlm = knlm; + splitDataset->morphemes = &srcBuilder->morphemes; + splitDataset->forms = &srcBuilder->forms; } for (auto& path : inputPathes) { ifstream ifs; - addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr); + srcBuilder->addCorpusTo(sents, openFile(ifs, path), realMorph, splitRatio, splitDataset ? &splitDataset->sents.get() : nullptr); } - size_t tokenSize = sents.raw().empty() ? 0 : *std::max_element(sents.raw().begin(), sents.raw().end()) + 1; + size_t tokenSize = sents.raw().empty() ? 0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1; if (splitDataset) { auto& sents = splitDataset->sents.get(); - tokenSize = std::max(tokenSize, sents.raw().empty() ? (size_t)0 : *std::max_element(sents.raw().begin(), sents.raw().end()) + 1); + tokenSize = max(tokenSize, sents.raw().empty() ? (size_t)0 : *max_element(sents.raw().begin(), sents.raw().end()) + 1); } - const size_t knlmVocabSize = langMdl.knlm->getHeader().vocab_size; + const size_t knlmVocabSize = knlm ? knlm->getHeader().vocab_size : maxTokenId; + tokenSize = max(tokenSize, knlmVocabSize); size_t filteredKnlmVocabSize = 0; for (size_t i = 0; i < tokenSize; ++i) { @@ -2223,7 +2262,17 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, { filteredKnlmVocabSize = dataset.vocabToToken.size(); } - if (tokenFilter && !tokenFilter(joinHangul(forms[morphemes[i].kform].form), morphemes[i].tag)) + + if (windowFilter && !windowFilter(joinHangul(srcBuilder->forms[srcBuilder->morphemes[i].kform].form), srcBuilder->morphemes[i].tag)) + { + dataset.windowTokenValidness.emplace_back(0); + } + else + { + dataset.windowTokenValidness.emplace_back(1); + } + + if (tokenFilter && !tokenFilter(joinHangul(srcBuilder->forms[srcBuilder->morphemes[i].kform].form), srcBuilder->morphemes[i].tag)) { dataset.tokenToVocab.emplace_back(HSDataset::nonVocab); continue; @@ -2244,6 +2293,7 @@ HSDataset KiwiBuilder::makeHSDataset(const vector& inputPathes, if (splitDataset) { + splitDataset->windowTokenValidness = dataset.windowTokenValidness; splitDataset->tokenToVocab = dataset.tokenToVocab; splitDataset->vocabToToken = dataset.vocabToToken; splitDataset->knlmVocabSize = dataset.knlmVocabSize; diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index d9605450..602fa73b 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -414,7 +414,7 @@ TEST(KiwiCpp, HSDataset) }; HSDataset trainset, devset; - trainset = kw.makeHSDataset(data, batchSize, windowSize, 1, 0., tokenFilter, 0.1, false, &devset); + trainset = kw.makeHSDataset(data, batchSize, 0, windowSize, 1, 0., tokenFilter, {}, 0.1, false, {}, 0, &devset); for (size_t i = 0; i < 2; ++i) { { From 7b2f0f630996a00bdaa3f6068c4e8d4ecaef6a46 Mon Sep 17 00:00:00 2001 From: bab2min Date: Sat, 12 Oct 2024 20:30:42 +0900 Subject: [PATCH 4/9] Change minCnt of KnLM from `size_t` type to `vector` --- include/kiwi/Kiwi.h | 3 +-- src/KiwiBuilder.cpp | 15 ++++++++++++--- tools/model_builder.cpp | 41 ++++++++++++++++++++++++++++++++++++++--- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/include/kiwi/Kiwi.h b/include/kiwi/Kiwi.h index 44a14d5c..b4f846b1 100644 --- a/include/kiwi/Kiwi.h +++ b/include/kiwi/Kiwi.h @@ -614,8 +614,7 @@ namespace kiwi std::vector corpora; size_t minMorphCnt = 10; size_t lmOrder = 4; - size_t lmMinCnt = 1; - size_t lmLastOrderMinCnt = 2; + std::vector lmMinCnts = { 1 }; size_t numWorkers = 1; size_t sbgSize = 1000000; bool useLmTagHistory = true; diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 10a798e3..e6bfa7a9 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -872,7 +872,8 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) pool.~ThreadPool(); new (&pool) utils::ThreadPool{ args.numWorkers }; } - auto cntNodes = utils::count(sents.begin(), sents.end(), args.lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); + size_t lmMinCnt = *std::min(args.lmMinCnts.begin(), args.lmMinCnts.end()); + auto cntNodes = utils::count(sents.begin(), sents.end(), lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); // discount for bos node cnt if (args.useLmTagHistory) { @@ -882,8 +883,16 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) { cntNodes.root().getNext(0)->val /= 2; } - std::vector minCnts(args.lmOrder, args.lmMinCnt); - minCnts.back() = args.lmLastOrderMinCnt; + std::vector minCnts; + if (args.lmMinCnts.size() == 1) + { + minCnts.clear(); + minCnts.resize(args.lmOrder, args.lmMinCnts[0]); + } + else if (args.lmMinCnts.size() == args.lmOrder) + { + minCnts = args.lmMinCnts; + } langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build( cntNodes, args.lmOrder, minCnts, diff --git a/tools/model_builder.cpp b/tools/model_builder.cpp index 199e13ab..c9b37302 100644 --- a/tools/model_builder.cpp +++ b/tools/model_builder.cpp @@ -8,6 +8,27 @@ using namespace std; using namespace kiwi; +vector splitMultipleInts(const string& s, const char delim = ',') +{ + vector ret; + size_t p = 0, e = 0; + while (1) + { + size_t t = s.find(delim, p); + if (t == s.npos) + { + ret.emplace_back(atoi(&s[e])); + return ret; + } + else + { + ret.emplace_back(atoi(&s[e])); + p = t + 1; + e = t + 1; + } + } +} + int run(const KiwiBuilder::ModelBuildArgs& args, const string& output, bool skipBigram) { try @@ -49,7 +70,7 @@ int main(int argc, const char* argv[]) ValueArg workers{ "w", "workers", "number of workers", false, 1, "int" }; ValueArg morMinCnt{ "", "morpheme_min_cnt", "min count of morpheme", false, 10, "int" }; ValueArg lmOrder{ "", "order", "order of LM", false, 4, "int" }; - ValueArg lmMinCnt{ "", "min_cnt", "min count of LM", false, 1, "int" }; + ValueArg lmMinCnt{ "", "min_cnt", "min count of LM", false, "1", "multiple ints with comma"}; ValueArg lmLastOrderMinCnt{ "", "last_min_cnt", "min count of the last order of LM", false, 2, "int" }; ValueArg output{ "o", "output", "output model path", true, "", "string" }; ValueArg sbgSize{ "", "sbg_size", "sbg size", false, 1000000, "int" }; @@ -86,10 +107,24 @@ int main(int argc, const char* argv[]) args.useLmTagHistory = tagHistory; args.minMorphCnt = morMinCnt; args.lmOrder = lmOrder; - args.lmMinCnt = lmMinCnt; - args.lmLastOrderMinCnt = lmLastOrderMinCnt; args.numWorkers = workers; args.sbgSize = sbgSize; + + auto v = splitMultipleInts(lmMinCnt.getValue()); + + if (v.empty()) + { + args.lmMinCnts.resize(1, 1); + } + else if (v.size() == 1 || v.size() == lmOrder) + { + args.lmMinCnts = v; + } + else + { + cerr << "error: min_cnt size should be 1 or equal to order" << endl; + return -1; + } return run(args, output, skipBigram); } From 0fceff105783d66fbac3fa8817f5f132839ca0c6 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 14 Oct 2024 01:39:12 +0900 Subject: [PATCH 5/9] Fix missing combined morphemes with single consonant --- src/KiwiBuilder.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index e6bfa7a9..c27014e9 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -734,10 +734,11 @@ void KiwiBuilder::updateMorphemes() void KiwiBuilder::loadMorphBin(std::istream& is) { serializer::readMany(is, serializer::toKey("KIWI"), forms, morphemes); - size_t cnt = 0; for (auto& form : forms) { - formMap.emplace(form.form, cnt++); + const size_t idx = &form - &forms[0]; + if (idx < defaultFormSize + 27) continue; + formMap.emplace(form.form, idx); } } @@ -1363,10 +1364,13 @@ void KiwiBuilder::addCombinedMorphemes( else return newForms[id - forms.size()]; }; - auto res = combiningRule->combine(getForm(getMorph(leftId).kform).form, getForm(getMorph(rightId).kform).form, ruleId); + auto& leftForm = getForm(getMorph(leftId).kform).form; + auto& rightForm = getForm(getMorph(rightId).kform).form; + + auto res = combiningRule->combine(leftForm, rightForm, ruleId); for (auto& r : res) { - if (!r.ignoreRCond && !FeatureTestor::isMatched(&getForm(getMorph(leftId).kform).form, getMorph(rightId).vowel())) + if (!r.ignoreRCond && !FeatureTestor::isMatched(&leftForm, getMorph(rightId).vowel())) { continue; } From e25012aa5fd148dafe30ae8409dbec76d990cd15 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 14 Oct 2024 01:40:49 +0900 Subject: [PATCH 6/9] Add test cases for single consonant morphemes --- test/test_cpp.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 602fa73b..171e60c2 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -135,6 +135,16 @@ TEST(KiwiCpp, SingleResult) } } +TEST(KiwiCpp, SingleConsonantMorpheme) +{ + Kiwi& kiwi = reuseKiwiInstance(); + auto res = kiwi.analyze(u"구원의 손길을 내민 시민들", Match::allWithNormalizing).first; + EXPECT_EQ(res[4].str, u"내밀"); + + res = kiwi.analyze(u"서툰 모습을", Match::allWithNormalizing).first; + EXPECT_EQ(res[0].str, u"서툴"); +} + TEST(KiwiCpp, SplitComplex) { Kiwi& kiwi = reuseKiwiInstance(); From f854d811fd8ca16830ea7c2e4573c91323465353 Mon Sep 17 00:00:00 2001 From: bab2min Date: Wed, 16 Oct 2024 00:18:24 +0900 Subject: [PATCH 7/9] Fix wrong splitting on continualTypo mode --- src/KTrie.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/KTrie.cpp b/src/KTrie.cpp index 025b906f..b1fd8140 100644 --- a/src/KTrie.cpp +++ b/src/KTrie.cpp @@ -653,8 +653,8 @@ size_t kiwi::splitByTrie( const auto scanStart = max(endPosMap[nBeginWithMultiplier].first, (uint32_t)1), scanEnd = endPosMap[nBeginWithMultiplier].second; const bool longestMatched = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g) { - const auto start = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier; - return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos == start || specialStartPos == start); + const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier; + return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos * posMultiplier == startPos || specialStartPos * posMultiplier == startPos); }); // insert unknown form @@ -742,7 +742,7 @@ size_t kiwi::splitByTrie( const auto scanStart = max(endPosMap[unkFormEndPos * posMultiplier].first, (uint32_t)1), scanEnd = endPosMap[unkFormEndPos * posMultiplier].second; const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g) { - size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier; + const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier; return startPos == lastSpecialEndPos * posMultiplier && g.endPos == unkFormEndPos * posMultiplier; }); if (unkFormEndPos > lastSpecialEndPos && !duplicated) From e4d6f86f02780736cb4071ad5a4ed0a4983b3bf2 Mon Sep 17 00:00:00 2001 From: bab2min Date: Wed, 16 Oct 2024 00:19:45 +0900 Subject: [PATCH 8/9] Add test cases for splitting errors on continual typos --- test/test_cpp.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 171e60c2..7b94e647 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -145,6 +145,18 @@ TEST(KiwiCpp, SingleConsonantMorpheme) EXPECT_EQ(res[0].str, u"서툴"); } +TEST(KiwiCpp, SpecialTokenErrorOnContinualTypo) +{ + KiwiBuilder builder{ MODEL_PATH, 0, BuildOption::default_, }; + Kiwi typoKiwi = builder.build(DefaultTypoSet::continualTypoSet); + + auto res = typoKiwi.analyze(u"감사합니다 -친구들과", Match::allWithNormalizing).first; + EXPECT_EQ(res[0].str, u"감사"); + EXPECT_EQ(res[1].str, u"하"); + EXPECT_EQ(res[3].str, u"-"); + EXPECT_EQ(res[3].tag, POSTag::so); +} + TEST(KiwiCpp, SplitComplex) { Kiwi& kiwi = reuseKiwiInstance(); From f5bfccb4f2e3eb223a9bd805ba838bd918ef0c9b Mon Sep 17 00:00:00 2001 From: bab2min Date: Wed, 16 Oct 2024 03:15:11 +0900 Subject: [PATCH 9/9] Fix `TEST(KiwiCpp, HSDataset)` --- test/test_cpp.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 7b94e647..7c8a8b62 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -413,7 +413,7 @@ TEST(KiwiCpp, HSDataset) for (size_t w : {0, 1, 2, 4}) { //std::cout << w << std::endl; - auto dataset = kw.makeHSDataset(data, batchSize, windowSize, w, 0.); + auto dataset = kw.makeHSDataset(data, batchSize, 0, windowSize, w, 0.); for (size_t i = 0; i < 2; ++i) { size_t totalBatchCnt = 0, totalTokenCnt = 0, s;