From a4a85c466351353651e9e29be8d77232fc1459f7 Mon Sep 17 00:00:00 2001 From: bab2min Date: Mon, 30 Oct 2023 22:35:06 +0900 Subject: [PATCH] improved multi sentence path searching --- src/Kiwi.cpp | 27 ++++++++++++++++----------- src/PathEvaluator.hpp | 39 ++++++++++++++++++++++++++++++--------- test/test_cpp.cpp | 2 +- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index 9686e516..eee69a8e 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -613,9 +613,10 @@ namespace kiwi if (ret.empty()) { - ret.resize(pathes.size()); - spStatesByRet.resize(pathes.size()); - parentMap.resize(pathes.size()); + const size_t n = pathes.size(); + ret.resize(n); + spStatesByRet.resize(n); + parentMap.resize(n); iota(parentMap.begin(), parentMap.end(), 0); } else @@ -661,10 +662,14 @@ namespace kiwi } } + UnorderedMap spStateCnt; size_t validTarget = 0; for (size_t i = 0; i < ret.size(); ++i) { - if (parentMap[i] < pathes.size()) + auto& r = pathes[parentMap[i]]; + auto& rarr = ret[validTarget].first; + + if (parentMap[i] < pathes.size() && spStateCnt[r.curState] < topN) { if (validTarget != i) ret[validTarget] = move(ret[i]); } @@ -673,8 +678,6 @@ namespace kiwi continue; } - auto& r = pathes[parentMap[i]]; - auto& rarr = ret[validTarget].first; const KString* prevMorph = nullptr; for (auto& s : r.path) { @@ -726,6 +729,7 @@ namespace kiwi rarr.erase(joinAffixTokens(rarr.begin(), rarr.end(), matchOptions), rarr.end()); ret[validTarget].second += r.score; spStatesByRet[validTarget] = r.curState; + spStateCnt[r.curState]++; validTarget++; } ret.erase(ret.begin() + validTarget, ret.end()); @@ -977,6 +981,12 @@ namespace kiwi ); insertPathIntoResults(ret, spStatesByRet, res, topN, matchOptions, integrateAllomorph, positionTable, wordPositions, pretokenizedGroup, nodeInWhichPretokenized); } + + sort(ret.begin(), ret.end(), [](const TokenResult& a, const TokenResult& b) + { + return a.second > b.second; + }); + if (ret.size() > topN) ret.erase(ret.begin() + topN, ret.end()); auto newlines = allNewLinePositions(str); for (auto& r : ret) @@ -985,11 +995,6 @@ namespace kiwi fillSentLineInfo(r.first, newlines); } - sort(ret.begin(), ret.end(), [](const TokenResult& a, const TokenResult& b) - { - return a.second > b.second; - }); - if (ret.empty()) ret.emplace_back(); return ret; } diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index 8a2561cb..4a5d895d 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -1030,9 +1030,13 @@ namespace kiwi auto& cand = cache.back(); sort(cand.begin(), cand.end(), [](const WordLL& a, const WordLL& b) - { - return a.accScore > b.accScore; - } + { + if (a.rootId < b.rootId) return true; + if (a.rootId > b.rootId) return false; + if (a.spState < b.spState) return true; + if (a.spState > b.spState) return false; + return a.accScore > b.accScore; + } ); #ifdef DEBUG_PRINT @@ -1052,14 +1056,31 @@ namespace kiwi utils::ContainerSearcher> csearcher{ cache }; Vector ret; - for (size_t i = 0; i < min(topN, cand.size()); ++i) + size_t startIdx = 0; + pair prevRootIdAndSpState; + if (!cand.empty()) prevRootIdAndSpState = make_pair(cand[0].rootId, (uint8_t)cand[0].spState); + for (size_t i = 0; i < cand.size(); ++i) { - auto tokens = generateTokenList( - &cand[i], csearcher, graph, ownFormList, kw->typoCostWeight, - kw->morphemes.data(), langVocabSize - ); - ret.emplace_back(move(tokens.first), cand[i].accScore, tokens.second->spState, cand[i].spState); + auto curRootIdAndSpState = make_pair(cand[i].rootId, (uint8_t)cand[i].spState); + if (prevRootIdAndSpState != curRootIdAndSpState) + { + startIdx = i; + prevRootIdAndSpState = curRootIdAndSpState; + } + + if (i - startIdx < topN) + { + auto tokens = generateTokenList( + &cand[i], csearcher, graph, ownFormList, kw->typoCostWeight, + kw->morphemes.data(), langVocabSize + ); + ret.emplace_back(move(tokens.first), cand[i].accScore, tokens.second->spState, cand[i].spState); + } } + sort(ret.begin(), ret.end(), [](const ChunkResult& a, const ChunkResult& b) + { + return a.score > b.score; + }); return ret; } } diff --git a/test/test_cpp.cpp b/test/test_cpp.cpp index 0859739d..51275fd9 100644 --- a/test/test_cpp.cpp +++ b/test/test_cpp.cpp @@ -417,7 +417,7 @@ TEST(KiwiCpp, SentenceBoundaryWithOrderedBullet) u"1) 편당 요금을 지불한다. 2) 편당 요금을 지불한다. 3) 편당 요금을 지불한다.", u"가. 편당 요금을 지불한다 나. 편당 요금을 지불한다 다. 편당 요금을 지불한다", u"가) 편당 요금을 지불한다 나) 편당 요금을 지불한다 다) 편당 요금을 지불한다", - //u"1) 편당 요금을 지불한다 2) 편당 요금을 지불한다 3) 편당 요금을 지불한다", + u"1) 편당 요금을 지불한다 2) 편당 요금을 지불한다 3) 편당 요금을 지불한다", u"가. 편당 요금을 지불 나. 편당 요금을 지불 다. 편당 요금을 지불", u"가) 편당 요금을 지불 나) 편당 요금을 지불 다) 편당 요금을 지불", u"1) 편당 요금을 지불 2) 편당 요금을 지불 3) 편당 요금을 지불",