Skip to content

Commit

Permalink
Merge pull request #188 from bab2min/dev/kiwi_builder
Browse files Browse the repository at this point in the history
Improve Kiwi Builder
  • Loading branch information
bab2min authored Oct 3, 2024
2 parents c2552a4 + 981497a commit 93f3188
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 60 deletions.
8 changes: 5 additions & 3 deletions include/kiwi/Dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace kiwi

static constexpr int32_t nonVocab = -1;

HiddenMember<RaggedVector<uint16_t>, sizeof(Vector<size_t>) * 2> sents;
HiddenMember<RaggedVector<uint32_t>, sizeof(Vector<size_t>) * 2> sents;
std::shared_ptr<lm::KnLangModelBase> knlm;
std::unique_ptr<utils::ThreadPool> workers;
std::discrete_distribution<> dropout;
Expand All @@ -55,6 +55,7 @@ namespace kiwi
Deque<OptionalFuture<size_t>> futures;
const Vector<MorphemeRaw>* morphemes = nullptr;
const Vector<FormRaw>* forms = nullptr;
size_t knlmVocabSize = 0;
size_t batchSize = 0;
size_t windowSize = 0;
size_t totalTokens = 0;
Expand Down Expand Up @@ -87,12 +88,13 @@ namespace kiwi
size_t next(int64_t* in, int64_t* out, float* lmLProbs, int64_t* outNgramNode, float& restLmOut, uint32_t& restLmCntOut);

/** Number of entries currently held in vocabToToken. */
size_t vocabSize() const
{
	return vocabToToken.size();
}
size_t getKnlmVocabSize() const;
size_t ngramNodeSize() const;
const MorphemeRaw& vocabInfo(uint32_t vocab) const;
std::u16string vocabForm(uint32_t vocab) const;
std::vector<size_t> estimVocabFrequency() const;

Range<Vector<uint16_t>::const_iterator> getSent(size_t idx) const;
std::vector<uint16_t> getAugmentedSent(size_t idx);
Range<Vector<uint32_t>::const_iterator> getSent(size_t idx) const;
std::vector<uint32_t> getAugmentedSent(size_t idx);
};
}
11 changes: 9 additions & 2 deletions include/kiwi/Kiwi.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,14 +546,19 @@ namespace kiwi
FormRaw& addForm(const KString& form);
size_t addForm(Vector<FormRaw>& newForms, UnorderedMap<KString, size_t>& newFormMap, KString form) const;

using MorphemeMap = UnorderedMap<std::pair<KString, POSTag>, std::pair<size_t, size_t>>;
using MorphemeMap = UnorderedMap<std::tuple<KString, uint8_t, POSTag>, std::pair<size_t, size_t>>;

template<class Fn>
MorphemeMap loadMorphemesFromTxt(std::istream& is, Fn&& filter);

MorphemeMap restoreMorphemeMap() const;
MorphemeMap restoreMorphemeMap(bool separateDefaultMorpheme = false) const;

template<class VocabTy>
void _addCorpusTo(RaggedVector<VocabTy>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio, RaggedVector<VocabTy>* splitOut) const;

void addCorpusTo(RaggedVector<uint8_t>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector<uint8_t>* splitOut = nullptr) const;
void addCorpusTo(RaggedVector<uint16_t>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector<uint16_t>* splitOut = nullptr) const;
void addCorpusTo(RaggedVector<uint32_t>& out, std::istream& is, MorphemeMap& morphMap, double splitRatio = 0, RaggedVector<uint32_t>* splitOut = nullptr) const;
void updateForms();
void updateMorphemes();

Expand Down Expand Up @@ -610,6 +615,7 @@ namespace kiwi
size_t lmMinCnt = 1;
size_t lmLastOrderMinCnt = 2;
size_t numWorkers = 1;
size_t sbgSize = 1000000;
bool useLmTagHistory = true;
bool quantizeLm = true;
bool compressLm = true;
Expand Down Expand Up @@ -799,6 +805,7 @@ namespace kiwi
double dropoutProb = 0,
const TokenFilter& tokenFilter = {},
double splitRatio = 0,
bool separateDefaultMorpheme = false,
HSDataset* splitDataset = nullptr
) const;
};
Expand Down
11 changes: 8 additions & 3 deletions src/Dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ std::u16string HSDataset::vocabForm(uint32_t vocab) const
return joinHangul((*forms)[(*morphemes)[vocabToToken[vocab]].kform].form);
}

size_t HSDataset::getKnlmVocabSize() const
{
return knlmVocabSize;
}

std::vector<size_t> kiwi::HSDataset::estimVocabFrequency() const
{
std::vector<size_t> ret(vocabSize()), augs(getDefaultMorphemeId(POSTag::max));
Expand All @@ -279,7 +284,7 @@ std::vector<size_t> kiwi::HSDataset::estimVocabFrequency() const
return ret;
}

Range<Vector<uint16_t>::const_iterator> HSDataset::getSent(size_t idx) const
/**
 * Looks up one stored sentence.
 *
 * @param idx position of the sentence inside `sents`; this function itself
 *            performs no validation of the index.
 * @return the element of `sents` at `idx`, exposed as a range of const
 *         iterators over uint32_t values.
 */
Range<Vector<uint32_t>::const_iterator> HSDataset::getSent(size_t idx) const
{
	const auto& storedSents = sents.get();
	return storedSents[idx];
}
Expand All @@ -289,9 +294,9 @@ void HSDataset::seed(size_t newSeed)
rng.seed(newSeed);
}

std::vector<uint16_t> HSDataset::getAugmentedSent(size_t idx)
std::vector<uint32_t> HSDataset::getAugmentedSent(size_t idx)
{
std::vector<uint16_t> ret;
std::vector<uint32_t> ret;
auto sent = sents.get()[idx];
ret.emplace_back(*sent.begin());
for (auto p = sent.begin() + 1; p != sent.end() - 1; ++p)
Expand Down
Loading

0 comments on commit 93f3188

Please sign in to comment.