Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Kiwi Builder #188

Merged
merged 7 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions include/kiwi/Dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace kiwi

static constexpr int32_t nonVocab = -1;

HiddenMember<RaggedVector<uint16_t>, sizeof(Vector<size_t>) * 2> sents;
HiddenMember<RaggedVector<uint32_t>, sizeof(Vector<size_t>) * 2> sents;
std::shared_ptr<lm::KnLangModelBase> knlm;
std::unique_ptr<utils::ThreadPool> workers;
std::discrete_distribution<> dropout;
Expand All @@ -55,6 +55,7 @@ namespace kiwi
Deque<OptionalFuture<size_t>> futures;
const Vector<MorphemeRaw>* morphemes = nullptr;
const Vector<FormRaw>* forms = nullptr;
size_t knlmVocabSize = 0;
size_t batchSize = 0;
size_t windowSize = 0;
size_t totalTokens = 0;
Expand Down Expand Up @@ -87,12 +88,13 @@ namespace kiwi
size_t next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut);

size_t vocabSize() const { return vocabToToken.size(); }
size_t getKnlmVocabSize() const;
size_t ngramNodeSize() const;
const MorphemeRaw& vocabInfo(uint32_t vocab) const;
std::u16string vocabForm(uint32_t vocab) const;
std::vector<size_t> estimVocabFrequency() const;

Range<Vector<uint16_t>::const_iterator> getSent(size_t idx) const;
std::vector<uint16_t> getAugmentedSent(size_t idx);
Range<Vector<uint32_t>::const_iterator> getSent(size_t idx) const;
std::vector<uint32_t> getAugmentedSent(size_t idx);
};
}
11 changes: 9 additions & 2 deletions include/kiwi/Kiwi.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,14 +546,19 @@ namespace kiwi
FormRaw& addForm(const KString& form);
size_t addForm(Vector<FormRaw>& newForms, UnorderedMap<KString, size_t>& newFormMap, KString form) const;

using MorphemeMap = UnorderedMap<std::pair<KString, POSTag>, std::pair<size_t, size_t>>;
using MorphemeMap = UnorderedMap<std::tuple<KString, uint8_t, POSTag>, std::pair<size_t, size_t>>;

template<class Fn>
MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter);

MorphemeMap restoreMorphemeMap() const;
MorphemeMap restoreMorphemeMap(bool separateDefaultMorpheme = false) const;

template<class VocabTy>
void _addCorpusTo(RaggedVector<VocabTy>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector<VocabTy>* splitOut) const;

void addCorpusTo(RaggedVector<uint8_t>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector<uint8_t>* splitOut = nullptr) const;
void addCorpusTo(RaggedVector<uint16_t>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector<uint16_t>* splitOut = nullptr) const;
void addCorpusTo(RaggedVector<uint32_t>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector<uint32_t>* splitOut = nullptr) const;
void updateForms();
void updateMorphemes();

Expand Down Expand Up @@ -610,6 +615,7 @@ namespace kiwi
size_t lmMinCnt = 1;
size_t lmLastOrderMinCnt = 2;
size_t numWorkers = 1;
size_t sbgSize = 1000000;
bool useLmTagHistory = true;
bool quantizeLm = true;
bool compressLm = true;
Expand Down Expand Up @@ -799,6 +805,7 @@ namespace kiwi
double dropoutProb = 0,
const TokenFilter& tokenFilter = {},
double splitRatio = 0,
bool separateDefaultMorpheme = false,
HSDataset* splitDataset = nullptr
) const;
};
Expand Down
11 changes: 8 additions & 3 deletions src/Dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ std::u16string HSDataset::vocabForm(uint32_t vocab) const
return joinHangul((*forms)[(*morphemes)[vocabToToken[vocab]].kform].form);
}

size_t HSDataset::getKnlmVocabSize() const
{
return knlmVocabSize;
}

std::vector<size_t> kiwi::HSDataset::estimVocabFrequency() const
{
std::vector<size_t> ret(vocabSize()), augs(getDefaultMorphemeId(POSTag::max));
Expand All @@ -279,7 +284,7 @@ std::vector<size_t> kiwi::HSDataset::estimVocabFrequency() const
return ret;
}

Range<Vector<uint16_t>::const_iterator> HSDataset::getSent(size_t idx) const
Range<Vector<uint32_t>::const_iterator> HSDataset::getSent(size_t idx) const
{
return sents.get()[idx];
}
Expand All @@ -289,9 +294,9 @@ void HSDataset::seed(size_t newSeed)
rng.seed(newSeed);
}

std::vector<uint16_t> HSDataset::getAugmentedSent(size_t idx)
std::vector<uint32_t> HSDataset::getAugmentedSent(size_t idx)
{
std::vector<uint16_t> ret;
std::vector<uint32_t> ret;
auto sent = sents.get()[idx];
ret.emplace_back(*sent.begin());
for (auto p = sent.begin() + 1; p != sent.end() - 1; ++p)
Expand Down
Loading
Loading