Skip to content

Commit

Permalink
improved multi sentence path searching
Browse files Browse the repository at this point in the history
  • Loading branch information
bab2min committed Oct 30, 2023
1 parent 7fb06fd commit a4a85c4
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 21 deletions.
27 changes: 16 additions & 11 deletions src/Kiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,10 @@ namespace kiwi

if (ret.empty())
{
ret.resize(pathes.size());
spStatesByRet.resize(pathes.size());
parentMap.resize(pathes.size());
const size_t n = pathes.size();
ret.resize(n);
spStatesByRet.resize(n);
parentMap.resize(n);
iota(parentMap.begin(), parentMap.end(), 0);
}
else
Expand Down Expand Up @@ -661,10 +662,14 @@ namespace kiwi
}
}

UnorderedMap<uint8_t, uint32_t> spStateCnt;
size_t validTarget = 0;
for (size_t i = 0; i < ret.size(); ++i)
{
if (parentMap[i] < pathes.size())
auto& r = pathes[parentMap[i]];
auto& rarr = ret[validTarget].first;

if (parentMap[i] < pathes.size() && spStateCnt[r.curState] < topN)
{
if (validTarget != i) ret[validTarget] = move(ret[i]);
}
Expand All @@ -673,8 +678,6 @@ namespace kiwi
continue;
}

auto& r = pathes[parentMap[i]];
auto& rarr = ret[validTarget].first;
const KString* prevMorph = nullptr;
for (auto& s : r.path)
{
Expand Down Expand Up @@ -726,6 +729,7 @@ namespace kiwi
rarr.erase(joinAffixTokens(rarr.begin(), rarr.end(), matchOptions), rarr.end());
ret[validTarget].second += r.score;
spStatesByRet[validTarget] = r.curState;
spStateCnt[r.curState]++;
validTarget++;
}
ret.erase(ret.begin() + validTarget, ret.end());
Expand Down Expand Up @@ -977,6 +981,12 @@ namespace kiwi
);
insertPathIntoResults(ret, spStatesByRet, res, topN, matchOptions, integrateAllomorph, positionTable, wordPositions, pretokenizedGroup, nodeInWhichPretokenized);
}

sort(ret.begin(), ret.end(), [](const TokenResult& a, const TokenResult& b)
{
return a.second > b.second;
});
if (ret.size() > topN) ret.erase(ret.begin() + topN, ret.end());

auto newlines = allNewLinePositions(str);
for (auto& r : ret)
Expand All @@ -985,11 +995,6 @@ namespace kiwi
fillSentLineInfo(r.first, newlines);
}

sort(ret.begin(), ret.end(), [](const TokenResult& a, const TokenResult& b)
{
return a.second > b.second;
});

if (ret.empty()) ret.emplace_back();
return ret;
}
Expand Down
39 changes: 30 additions & 9 deletions src/PathEvaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1030,9 +1030,13 @@ namespace kiwi
auto& cand = cache.back();
sort(cand.begin(), cand.end(),
[](const WordLL<LmState>& a, const WordLL<LmState>& b)
{
return a.accScore > b.accScore;
}
{
if (a.rootId < b.rootId) return true;
if (a.rootId > b.rootId) return false;
if (a.spState < b.spState) return true;
if (a.spState > b.spState) return false;
return a.accScore > b.accScore;
}
);

#ifdef DEBUG_PRINT
Expand All @@ -1052,14 +1056,31 @@ namespace kiwi

utils::ContainerSearcher<WordLL<LmState>> csearcher{ cache };
Vector<ChunkResult> ret;
for (size_t i = 0; i < min(topN, cand.size()); ++i)
size_t startIdx = 0;
pair<uint8_t, uint8_t> prevRootIdAndSpState;
if (!cand.empty()) prevRootIdAndSpState = make_pair(cand[0].rootId, (uint8_t)cand[0].spState);
for (size_t i = 0; i < cand.size(); ++i)
{
auto tokens = generateTokenList(
&cand[i], csearcher, graph, ownFormList, kw->typoCostWeight,
kw->morphemes.data(), langVocabSize
);
ret.emplace_back(move(tokens.first), cand[i].accScore, tokens.second->spState, cand[i].spState);
auto curRootIdAndSpState = make_pair(cand[i].rootId, (uint8_t)cand[i].spState);
if (prevRootIdAndSpState != curRootIdAndSpState)
{
startIdx = i;
prevRootIdAndSpState = curRootIdAndSpState;
}

if (i - startIdx < topN)
{
auto tokens = generateTokenList(
&cand[i], csearcher, graph, ownFormList, kw->typoCostWeight,
kw->morphemes.data(), langVocabSize
);
ret.emplace_back(move(tokens.first), cand[i].accScore, tokens.second->spState, cand[i].spState);
}
}
sort(ret.begin(), ret.end(), [](const ChunkResult& a, const ChunkResult& b)
{
return a.score > b.score;
});
return ret;
}
}
2 changes: 1 addition & 1 deletion test/test_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ TEST(KiwiCpp, SentenceBoundaryWithOrderedBullet)
u"1) 편당 요금을 지불한다. 2) 편당 요금을 지불한다. 3) 편당 요금을 지불한다.",
u"가. 편당 요금을 지불한다 나. 편당 요금을 지불한다 다. 편당 요금을 지불한다",
u"가) 편당 요금을 지불한다 나) 편당 요금을 지불한다 다) 편당 요금을 지불한다",
//u"1) 편당 요금을 지불한다 2) 편당 요금을 지불한다 3) 편당 요금을 지불한다",
u"1) 편당 요금을 지불한다 2) 편당 요금을 지불한다 3) 편당 요금을 지불한다",
u"가. 편당 요금을 지불 나. 편당 요금을 지불 다. 편당 요금을 지불",
u"가) 편당 요금을 지불 나) 편당 요금을 지불 다) 편당 요금을 지불",
u"1) 편당 요금을 지불 2) 편당 요금을 지불 3) 편당 요금을 지불",
Expand Down

0 comments on commit a4a85c4

Please sign in to comment.