Skip to content

Commit

Permalink
Merge pull request #194 from bab2min/dev/issue192
Browse files Browse the repository at this point in the history
자잘한 버그 수정
  • Loading branch information
bab2min authored Oct 16, 2024
2 parents c1da90c + f5bfccb commit 0d72a27
Show file tree
Hide file tree
Showing 10 changed files with 236 additions and 65 deletions.
7 changes: 6 additions & 1 deletion include/kiwi/Dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,19 @@ namespace kiwi
HiddenMember<RaggedVector<uint32_t>, sizeof(Vector<size_t>) * 2> sents;
std::shared_ptr<lm::KnLangModelBase> knlm;
std::unique_ptr<utils::ThreadPool> workers;
std::shared_ptr<KiwiBuilder> dummyBuilder;
std::discrete_distribution<> dropout;
std::mt19937_64 rng;
Vector<ThreadLocal> locals;
Vector<size_t> shuffledIdx;
Vector<int32_t> tokenToVocab, vocabToToken;
Vector<uint8_t> windowTokenValidness;
Deque<OptionalFuture<size_t>> futures;
const Vector<MorphemeRaw>* morphemes = nullptr;
const Vector<FormRaw>* forms = nullptr;
size_t knlmVocabSize = 0;
size_t batchSize = 0;
size_t causalContextSize = 0;
size_t windowSize = 0;
size_t totalTokens = 0;
size_t passedSents = 0;
Expand All @@ -68,7 +71,7 @@ namespace kiwi
size_t _next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode, float& restLmOut, uint32_t& restLmCntOut);

public:
HSDataset(size_t _batchSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
HSDataset(size_t _batchSize = 0, size_t _causalContextSize = 0, size_t _windowSize = 0, size_t _workers = 0, double _dropoutProb = 0);
~HSDataset();
HSDataset(const HSDataset&) = delete;
HSDataset(HSDataset&&) /*noexcept*/;
Expand All @@ -80,7 +83,9 @@ namespace kiwi
size_t numTokens() const;

size_t getBatchSize() const { return batchSize; }
size_t getCausalContextSize() const { return causalContextSize; }
size_t getWindowSize() const { return windowSize; }
const Vector<uint8_t>& getWindowTokenValidness() const { return windowTokenValidness; }

void seed(size_t newSeed);
void reset();
Expand Down
10 changes: 7 additions & 3 deletions include/kiwi/Kiwi.h
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,8 @@ namespace kiwi

using MorphemeMap = UnorderedMap<std::tuple<KString, uint8_t, POSTag>, std::pair<size_t, size_t>>;

void initMorphemes();

template<class Fn>
MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter);

Expand Down Expand Up @@ -612,8 +614,7 @@ namespace kiwi
std::vector<std::string> corpora;
size_t minMorphCnt = 10;
size_t lmOrder = 4;
size_t lmMinCnt = 1;
size_t lmLastOrderMinCnt = 2;
std::vector<size_t> lmMinCnts = { 1 };
size_t numWorkers = 1;
size_t sbgSize = 1000000;
bool useLmTagHistory = true;
Expand Down Expand Up @@ -801,11 +802,14 @@ namespace kiwi
using TokenFilter = std::function<bool(const std::u16string&, POSTag)>;

HSDataset makeHSDataset(const std::vector<std::string>& inputPathes,
size_t batchSize, size_t windowSize, size_t numWorkers,
size_t batchSize, size_t causalContextSize, size_t windowSize, size_t numWorkers,
double dropoutProb = 0,
const TokenFilter& tokenFilter = {},
const TokenFilter& windowFilter = {},
double splitRatio = 0,
bool separateDefaultMorpheme = false,
const std::string& morphemeDefPath = {},
size_t morphemeDefMinCnt = 0,
HSDataset* splitDataset = nullptr
) const;
};
Expand Down
2 changes: 1 addition & 1 deletion include/kiwi/TagUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace kiwi
inline bool isSuffix(POSTag tag)
{
tag = clearIrregular(tag);
return POSTag::xsn <= tag && tag <= POSTag::xsa;
return POSTag::xsn <= tag && tag <= POSTag::xsm;
}

inline bool isSpecialClass(POSTag tag)
Expand Down
51 changes: 40 additions & 11 deletions src/Dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

using namespace kiwi;

HSDataset::HSDataset(size_t _batchSize, size_t _windowSize, size_t _workers, double _dropoutProb)
HSDataset::HSDataset(size_t _batchSize, size_t _causalContextSize, size_t _windowSize, size_t _workers, double _dropoutProb)
: workers{ _workers ? make_unique<utils::ThreadPool>(_workers) : nullptr },
dropout{ {1 - _dropoutProb * 3, _dropoutProb, _dropoutProb, _dropoutProb} },
locals( _workers ? workers->size() : 1),
batchSize{ _batchSize },
causalContextSize{ _causalContextSize },
windowSize{ _windowSize }
{
}
Expand Down Expand Up @@ -113,12 +114,21 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,

local.lmLProbsBuf.resize(tokens.size());
local.outNgramNodeBuf.resize(tokens.size());
knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
if (knlm)
{
knlm->evaluate(tokens.begin(), tokens.end(), local.lmLProbsBuf.begin(), local.outNgramNodeBuf.begin());
}

auto& history = local.historyBuf;
history.clear();
history.resize(windowSize, -1);
history.back() = tokenToVocab[tokens[0]];
if (windowSize)
{
history.resize(windowSize, -1);
if (windowTokenValidness[tokens[0]])
{
history.back() = tokenToVocab[tokens[0]];
}
}
for (size_t i = 1; i < tokens.size(); ++i)
{
int32_t v = tokenToVocab[tokens[i]];
Expand All @@ -134,13 +144,32 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
local.restLmLProbsCntData[r] += 1;
continue;
}
std::copy(history.begin(), history.end(), std::back_inserter(local.inData));

if (causalContextSize)
{
for (size_t j = 0; j < causalContextSize; ++j)
{
local.inData.emplace_back(i + j < causalContextSize ?
nonVocab : tokenToVocab[tokens[i + j - causalContextSize]]);
}
}
if (windowSize)
{
if (windowTokenValidness[v])
{
std::copy(history.begin(), history.end(), std::back_inserter(local.inData));
history.pop_front();
history.push_back(v);
}
else
{
local.inData.resize(local.inData.size() + windowSize, -1);
}
}

local.outData.emplace_back(v);
local.lmLProbsData.emplace_back(local.lmLProbsBuf[i]);
local.outNgramNodeData.emplace_back(local.outNgramNodeBuf[i]);

history.pop_front();
history.push_back(v);
}

size_t r = local.outData.size() / batchSize;
Expand Down Expand Up @@ -217,14 +246,14 @@ size_t HSDataset::_next(InTy in, OutTy out, LmTy lmLProbs, NgramTy outNgramNode,
auto& l = locals[localId];

size_t rest = std::min(l.outData.size(), batchSize);
std::copy(l.inData.begin(), l.inData.begin() + rest * windowSize, in);
std::copy(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize), in);
std::copy(l.outData.begin(), l.outData.begin() + rest, out);
std::copy(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest, lmLProbs);
std::copy(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest, outNgramNode);
restLmOut = l.restLmLProbsData.front();
restLmCntOut = l.restLmLProbsCntData.front();

l.inData.erase(l.inData.begin(), l.inData.begin() + rest * windowSize);
l.inData.erase(l.inData.begin(), l.inData.begin() + rest * (causalContextSize + windowSize));
l.outData.erase(l.outData.begin(), l.outData.begin() + rest);
l.lmLProbsData.erase(l.lmLProbsData.begin(), l.lmLProbsData.begin() + rest);
l.outNgramNodeData.erase(l.outNgramNodeData.begin(), l.outNgramNodeData.begin() + rest);
Expand All @@ -245,7 +274,7 @@ size_t HSDataset::next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outN

size_t HSDataset::ngramNodeSize() const
{
return knlm->nonLeafNodeSize();
return knlm ? knlm->nonLeafNodeSize() : 0;
}

const MorphemeRaw& HSDataset::vocabInfo(uint32_t vocab) const
Expand Down
28 changes: 19 additions & 9 deletions src/KTrie.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -653,8 +653,8 @@ size_t kiwi::splitByTrie(
const auto scanStart = max(endPosMap[nBeginWithMultiplier].first, (uint32_t)1), scanEnd = endPosMap[nBeginWithMultiplier].second;
const bool longestMatched = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
const auto start = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos == start || specialStartPos == start);
const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
return nBeginWithMultiplier == g.endPos && (lastSpecialEndPos * posMultiplier == startPos || specialStartPos * posMultiplier == startPos);
});

// insert unknown form
Expand Down Expand Up @@ -742,7 +742,7 @@ size_t kiwi::splitByTrie(
const auto scanStart = max(endPosMap[unkFormEndPos * posMultiplier].first, (uint32_t)1), scanEnd = endPosMap[unkFormEndPos * posMultiplier].second;
const bool duplicated = scanStart < scanEnd && any_of(out.begin() + scanStart, out.begin() + scanEnd, [&](const KGraphNode& g)
{
size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
const size_t startPos = g.endPos - (g.uform.empty() ? g.form->sizeWithoutSpace() : g.uform.size()) * posMultiplier;
return startPos == lastSpecialEndPos * posMultiplier && g.endPos == unkFormEndPos * posMultiplier;
});
if (unkFormEndPos > lastSpecialEndPos && !duplicated)
Expand Down Expand Up @@ -1215,9 +1215,10 @@ size_t kiwi::splitByTrie(
return n + startOffset;
}

template<ArchType arch>
template<ArchType arch, bool typoTolerant>
const Form* kiwi::findForm(
const utils::FrozenTrie<kchar_t, const Form*>& trie,
const Form* formData,
const KString& str
)
{
Expand All @@ -1228,7 +1229,12 @@ const Form* kiwi::findForm(
if (!node) return nullptr;
}
if (trie.hasSubmatch(node->val(trie))) return nullptr;
return node->val(trie);
auto ret = node->val(trie);
if (typoTolerant)
{
ret = &reinterpret_cast<const TypoForm*>(ret)->form(formData);
}
return ret;
}

namespace kiwi
Expand Down Expand Up @@ -1266,19 +1272,23 @@ FnSplitByTrie kiwi::getSplitByTrieFn(ArchType arch, bool typoTolerant, bool cont

namespace kiwi
{
template<bool typoTolerant>
struct FindFormGetter
{
template<std::ptrdiff_t i>
struct Wrapper
{
static constexpr FnFindForm value = &findForm<static_cast<ArchType>(i)>;
static constexpr FnFindForm value = &findForm<static_cast<ArchType>(i), typoTolerant>;
};
};
}

FnFindForm kiwi::getFindFormFn(ArchType arch)
FnFindForm kiwi::getFindFormFn(ArchType arch, bool typoTolerant)
{
static tp::Table<FnFindForm, AvailableArch> table{ FindFormGetter{} };
static std::array<tp::Table<FnFindForm, AvailableArch>, 2> table{
FindFormGetter<false>{},
FindFormGetter<true>{},
};

return table[static_cast<std::ptrdiff_t>(arch)];
return table[typoTolerant ? 1 : 0][static_cast<std::ptrdiff_t>(arch)];
}
7 changes: 4 additions & 3 deletions src/KTrie.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,18 @@ namespace kiwi
const PretokenizedSpanGroup::Span* pretokenizedLast
);

template<ArchType arch>
template<ArchType arch, bool typoTolerant>
const Form* findForm(
const utils::FrozenTrie<kchar_t, const Form*>& trie,
const Form* formData,
const KString& str
);

using FnSplitByTrie = decltype(&splitByTrie<ArchType::default_>);
FnSplitByTrie getSplitByTrieFn(ArchType arch, bool typoTolerant, bool continualTypoTolerant, bool lengtheningTypoTolerant);

using FnFindForm = decltype(&findForm<ArchType::default_>);
FnFindForm getFindFormFn(ArchType arch);
using FnFindForm = decltype(&findForm<ArchType::default_, false>);
FnFindForm getFindFormFn(ArchType arch, bool typoTolerant);

struct KTrie : public utils::TrieNode<char16_t, const Form*, utils::ConstAccess<map<char16_t, int32_t>>, KTrie>
{
Expand Down
16 changes: 9 additions & 7 deletions src/Kiwi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace kiwi
typoTolerant,
continualTypoTolerant,
lengtheningTypoTolerant);
dfFindForm = (void*)getFindFormFn(selectedArch);
dfFindForm = (void*)getFindFormFn(selectedArch, typoTolerant);

static tp::Table<FnFindBestPath, AvailableArch> lmKnLM_8{ FindBestPathGetter<WrappedKnLM<uint8_t>::type>{} };
static tp::Table<FnFindBestPath, AvailableArch> lmKnLM_16{ FindBestPathGetter<WrappedKnLM<uint16_t>::type>{} };
Expand Down Expand Up @@ -802,7 +802,8 @@ namespace kiwi
const Vector<uint32_t>& positionTable,
const KString& normStr,
FnFindForm findForm,
const utils::FrozenTrie<kchar_t, const Form*>& formTrie
const utils::FrozenTrie<kchar_t, const Form*>& formTrie,
const Form* formData
)
{
if (pretokenized.empty()) return;
Expand Down Expand Up @@ -833,7 +834,7 @@ namespace kiwi
if (s.tokenization.empty())
{
auto formStr = normStr.substr(span.begin, span.end - span.begin);
span.form = findForm(formTrie, formStr); // reuse the predefined form & morpheme
span.form = findForm(formTrie, formData, formStr); // reuse the predefined form & morpheme
if (!span.form) // or use a fallback form
{
span.form = formTrie.value((size_t)POSTag::nnp);
Expand All @@ -842,7 +843,7 @@ namespace kiwi
else if (s.tokenization.size() == 1)
{
auto formStr = normalizeHangul(s.tokenization[0].form);
auto* tform = findForm(formTrie, formStr);
auto* tform = findForm(formTrie, formData, formStr);
if (tform && tform->candidate.size() == 1 &&
areTagsEqual(tform->candidate[0]->tag, s.tokenization[0].tag, !!s.tokenization[0].inferRegularity))
// reuse the predefined form & morpheme
Expand Down Expand Up @@ -908,7 +909,7 @@ namespace kiwi
{
auto& t = s.tokenization[i];
auto formStr = normalizeHangul(t.form);
auto* tform = findForm(formTrie, formStr);
auto* tform = findForm(formTrie, formData, formStr);
const Morpheme* foundMorph = nullptr;
if (tform)
{
Expand Down Expand Up @@ -999,7 +1000,8 @@ namespace kiwi
positionTable,
normalizedStr,
reinterpret_cast<FnFindForm>(dfFindForm),
formTrie
formTrie,
forms.data()
);

// 분석할 문장에 포함된 개별 문자에 대해 어절번호를 생성한다
Expand Down Expand Up @@ -1317,7 +1319,7 @@ namespace kiwi
void Kiwi::findMorpheme(vector<const Morpheme*>& ret, const u16string& s, POSTag tag) const
{
auto normalized = normalizeHangul(s);
auto form = (*reinterpret_cast<FnFindForm>(dfFindForm))(formTrie, normalized);
auto form = (*reinterpret_cast<FnFindForm>(dfFindForm))(formTrie, forms.data(), normalized);
if (!form) return;
tag = clearIrregular(tag);
for (auto c : form->candidate)
Expand Down
Loading

0 comments on commit 0d72a27

Please sign in to comment.