diff --git a/CMakeLists.txt b/CMakeLists.txt index 93bb0382..0b5a1738 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,7 @@ set ( CORE_SRCS src/Joiner.cpp src/Kiwi.cpp src/KiwiBuilder.cpp + src/Knlm.cpp src/KTrie.cpp src/PatternMatcher.cpp src/search.cpp diff --git a/include/kiwi/FrozenTrie.h b/include/kiwi/FrozenTrie.h index b6fbc8a4..5c4b39b4 100644 --- a/include/kiwi/FrozenTrie.h +++ b/include/kiwi/FrozenTrie.h @@ -97,6 +97,26 @@ namespace kiwi std::unique_ptr nextKeys; std::unique_ptr nextDiffs; + template + void traverse(Fn&& visitor, const Node* node, std::vector& prefix, size_t maxDepth) const + { + auto* keys = &nextKeys[node->nextOffset]; + auto* diffs = &nextDiffs[node->nextOffset]; + for (size_t i = 0; i < node->numNexts; ++i) + { + const auto* child = node + diffs[i]; + const auto val = child->val(*this); + if (!hasMatch(val)) continue; + prefix.emplace_back(keys[i]); + visitor(val, prefix); + if (prefix.size() < maxDepth) + { + traverse(visitor, child, prefix, maxDepth); + } + prefix.pop_back(); + } + } + public: FrozenTrie() = default; @@ -117,6 +137,13 @@ namespace kiwi const Value& value(size_t idx) const { return values[idx]; }; bool hasMatch(_Value v) const { return !this->isNull(v) && !this->hasSubmatch(v); } + + template + void traverse(Fn&& visitor, size_t maxDepth = -1) const + { + std::vector prefix; + traverse(std::forward(visitor), root(), prefix, maxDepth); + } }; } } diff --git a/include/kiwi/Knlm.h b/include/kiwi/Knlm.h index cea37e21..8b96de8d 100644 --- a/include/kiwi/Knlm.h +++ b/include/kiwi/Knlm.h @@ -7,7 +7,9 @@ #include #include +#include "Utils.h" #include "Mmap.h" +#include "ArchUtils.h" namespace kiwi { @@ -20,6 +22,7 @@ namespace kiwi uint64_t num_nodes, node_offset, key_offset, ll_offset, gamma_offset, qtable_offset, htx_offset; uint64_t unk_id, bos_id, eos_id, vocab_size; uint8_t order, key_size, diff_size, quantized; + uint32_t extra_buf_size; }; template @@ -43,6 +46,7 @@ namespace kiwi virtual float _progress(ptrdiff_t& node_idx, size_t next) const = 0; virtual std::vector allNextLL(ptrdiff_t node_idx) const = 0; virtual std::vector allNextLL(ptrdiff_t node_idx, std::vector& next_node_idx) const = 0; + virtual void nextTopN(ptrdiff_t node_idx, size_t top_n, uint32_t* idx_out, float* ll_out) const = 0; public: @@ -55,21 +59,28 @@ namespace kiwi virtual size_t llSize() const = 0; virtual const float* getLLBuf() const = 0; virtual const float* getGammaBuf() const = 0; + virtual const void* getExtraBuf() const = 0; static std::unique_ptr create(utils::MemoryObject&& mem, ArchType archType = ArchType::none); - template> - static utils::MemoryOwner build(const utils::ContinuousTrie& ngram_cf, - size_t order, size_t min_cf, size_t last_min_cf, + template> + static utils::MemoryOwner build(Trie&& ngram_cf, + size_t order, const std::vector& min_cf_by_order, size_t unk_id, size_t bos_id, size_t eos_id, float unigram_alpha, size_t quantize, bool compress, const std::vector>* bigram_list = nullptr, - const HistoryTx* historyTransformer = nullptr + const HistoryTx* history_transformer = nullptr, + const void* extra_buf = nullptr, + size_t extra_buf_size = 0 ); const utils::MemoryObject& getMemory() const { return base; } - //virtual float progress(ptrdiff_t& node_idx, size_t next) const = 0; + template + float progress(ptrdiff_t& node_idx, Ty next) const + { + return _progress(node_idx, next); + } template void evaluate(InTy in_first, InTy in_last, OutTy out_first) const @@ -130,6 +141,19 @@ namespace kiwi } } + template + void predictTopN(InTy in_first, InTy in_last, size_t top_n, uint32_t* idx_out, float* ll_out) const + { + ptrdiff_t node_idx = 0; + for (; in_first != in_last; ++in_first) + { + _progress(node_idx, *in_first); + nextTopN(node_idx, top_n, idx_out, ll_out); + idx_out += top_n; + ll_out += top_n; + } + } + template void fillIn(PfTy prefix_first, PfTy prefix_last, SfTy suffix_first, SfTy suffix_last, OutTy out_first, bool reduce = true) const { diff --git a/include/kiwi/Mmap.h b/include/kiwi/Mmap.h index 0401d89d..a2c6ef3d 100644 --- a/include/kiwi/Mmap.h +++ b/include/kiwi/Mmap.h @@ -303,6 +303,11 @@ namespace kiwi setp(epptr() + off, epptr()); else if (dir == std::ios_base::beg) setp(pbase() + off, epptr()); + + if (!(which & std::ios_base::in)) + { + return pptr() - pbase(); + } } return gptr() - eback(); } diff --git a/include/kiwi/SubstringExtractor.h b/include/kiwi/SubstringExtractor.h index c4ba7c90..ec71a48e 100644 --- a/include/kiwi/SubstringExtractor.h +++ b/include/kiwi/SubstringExtractor.h @@ -3,6 +3,9 @@ #include #include +#include +#include + namespace kiwi { std::vector> extractSubstrings( @@ -13,4 +16,50 @@ namespace kiwi size_t maxLength = 32, bool longestOnly = true, char16_t stopChr = 0); + + + class PrefixCounter + { + size_t prefixSize = 0, minCf = 0, numArrays = 0; + UnorderedMap token2id; + Vector id2Token; + Vector buf; + Vector tokenClusters; + Vector tokenCnts; + std::shared_ptr threadPool; + + template + void _addArray(It first, It last); + + Vector> computeClusterScore() const; + + public: + PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers, + const std::vector>& clusters = {} + ); + void addArray(const uint16_t* first, const uint16_t* last); + void addArray(const uint32_t* first, const uint32_t* last); + void addArray(const uint64_t* first, const uint64_t* last); + utils::FrozenTrie count() const; + std::unique_ptr buildLM( + const std::vector& minCfByOrder, + size_t bosTokenId, + size_t eosTokenId, + size_t unkTokenId, + ArchType archType = ArchType::none + ) const; + }; + + class ClusterData + { + const std::pair* clusterScores = nullptr; + size_t clusterSize = 0; + public: + ClusterData(); + ClusterData(const void* _ptr, size_t _size); + + size_t size() const; + size_t cluster(size_t i) const; + float score(size_t i) const; + }; } diff --git a/include/kiwi/Trie.hpp b/include/kiwi/Trie.hpp index b2f24581..2bd9a42a 100644 --- a/include/kiwi/Trie.hpp +++ b/include/kiwi/Trie.hpp @@ -132,6 +132,24 @@ namespace kiwi return; } + template + void traverse(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const + { + fn(this->val, rkeys); + + if (rkeys.size() >= maxDepth) return; + + for (auto& p : next) + { + if (ignoreNegative ? (p.second > 0) : (p.second)) + { + rkeys.emplace_back(p.first); + getNext(p.first)->traverse(fn, rkeys, maxDepth, ignoreNegative); + rkeys.pop_back(); + } + } + } + template void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const { @@ -462,6 +480,13 @@ namespace kiwi return nodes[0].fillFail(std::forward(htx), ignoreNegative); } + template + void traverse(_Fn&& fn, size_t maxDepth = -1, bool ignoreNegative = false) const + { + std::vector rkeys; + return nodes[0].traverse(std::forward<_Fn>(fn), rkeys, maxDepth, ignoreNegative); + } + template void traverseWithKeys(_Fn&& fn, std::vector<_CKey>& rkeys, size_t maxDepth = -1, bool ignoreNegative = false) const { diff --git a/include/kiwi/Utils.h b/include/kiwi/Utils.h index 870d8a9e..2cf5674d 100644 --- a/include/kiwi/Utils.h +++ b/include/kiwi/Utils.h @@ -1,7 +1,8 @@ -#pragma once +#pragma once #include #include #include +#include #include "Types.h" namespace kiwi @@ -82,6 +83,11 @@ namespace kiwi return within(chr, 0x302E, 0x3030); } + inline bool isCompatibleHangulConsonant(char16_t chr) + { + return within(chr, 0x3131, 0x314E) || within(chr, 0x3165, 0x3186); + } + struct ComparatorIgnoringSpace { static bool less(const KString& a, const KString& b, const kchar_t space = u' '); @@ -146,6 +152,38 @@ namespace kiwi return joinHangul(hangul.begin(), hangul.end()); } + inline bool isHighSurrogate(char16_t c) + { + return (c & 0xFC00) == 0xD800; + } + + inline bool isLowSurrogate(char16_t c) + { + return (c & 0xFC00) == 0xDC00; + } + + inline char32_t mergeSurrogate(char16_t h, char16_t l) + { + return (((h & 0x3FF) << 10) | (l & 0x3FF)) + 0x10000; + } + + inline std::array decomposeSurrogate(char32_t c) + { + std::array ret; + if (c < 0x10000) + { + ret[0] = c; + ret[1] = 0; + } + else + { + c -= 0x10000; + ret[0] = ((c >> 10) & 0x3FF) | 0xD800; + ret[1] = (c & 0x3FF) | 0xDC00; + } + return ret; + } + POSTag identifySpecialChr(char32_t chr); size_t getSSType(char16_t c); size_t getSBType(const std::u16string& form); diff --git a/src/FrozenTrie.hpp b/src/FrozenTrie.hpp index fe650c34..13371e3a 100644 --- a/src/FrozenTrie.hpp +++ b/src/FrozenTrie.hpp @@ -111,7 +111,7 @@ namespace kiwi for (size_t i = 0; i < trie.size(); ++i) { auto& o = trie[i]; - nodes[i].numNexts = o.next.size(); + nodes[i].numNexts = (Key)o.next.size(); values[i] = xform(o); nodes[i].nextOffset = ptr; diff --git a/src/Joiner.cpp b/src/Joiner.cpp index d6c3de03..80f7df55 100644 --- a/src/Joiner.cpp +++ b/src/Joiner.cpp @@ -352,7 +352,7 @@ namespace kiwi for (size_t i = 0; i < candidates.size(); ++i) { auto& c = candidates[i]; - auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, i)); + auto inserted = bestScoreByState.emplace(c.lmState, make_pair(c.score, (uint32_t)i)); if (!inserted.second) { if (inserted.first->second.first < c.score) diff --git a/src/Kiwi.cpp b/src/Kiwi.cpp index c01df3e1..f239dcca 100644 --- a/src/Kiwi.cpp +++ b/src/Kiwi.cpp @@ -847,6 +847,7 @@ namespace kiwi morph.tag = s.tokenization[0].tag; morph.vowel = CondVowel::none; morph.polar = CondPolarity::none; + morph.complex = 0; morph.lmMorphemeId = getDefaultMorphemeId(s.tokenization[0].tag); form.candidate[0] = &morph; } diff --git a/src/KiwiBuilder.cpp b/src/KiwiBuilder.cpp index 06225241..60dd20e4 100644 --- a/src/KiwiBuilder.cpp +++ b/src/KiwiBuilder.cpp @@ -723,10 +723,11 @@ KiwiBuilder::KiwiBuilder(const ModelBuildArgs& args) new (&pool) utils::ThreadPool{ args.numWorkers }; } auto cntNodes = utils::count(sents.begin(), sents.end(), args.lmMinCnt, 1, args.lmOrder, (args.numWorkers > 1 ? &pool : nullptr), &bigramList, args.useLmTagHistory ? &historyTx : nullptr); - cntNodes.root().getNext(lmVocabSize)->val /= 2; + std::vector minCnts(args.lmOrder, args.lmMinCnt); + minCnts.back() = args.lmLastOrderMinCnt; langMdl.knlm = lm::KnLangModelBase::create(lm::KnLangModelBase::build( cntNodes, - args.lmOrder, args.lmMinCnt, args.lmLastOrderMinCnt, + args.lmOrder, minCnts, 2, 0, 1, 1e-5, args.quantizeLm ? 8 : 0, args.compressLm, diff --git a/src/Knlm.cpp b/src/Knlm.cpp new file mode 100644 index 00000000..2ab66388 --- /dev/null +++ b/src/Knlm.cpp @@ -0,0 +1,46 @@ +#include "Knlm.hpp" + +namespace kiwi +{ + namespace lm + { + template + std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) + { + auto* ptr = reinterpret_cast(mem.get()); + auto& header = *reinterpret_cast(ptr); + switch (header.key_size) + { + case 1: + return make_unique>(std::move(mem)); + case 2: + return make_unique>(std::move(mem)); + case 4: + return make_unique>(std::move(mem)); + case 8: + return make_unique>(std::move(mem)); + default: + throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.key_size) }; + } + } + + using FnCreateOptimizedModel = decltype(&createOptimizedModel); + + struct CreateOptimizedModelGetter + { + template + struct Wrapper + { + static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; + }; + }; + + std::unique_ptr KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType) + { + static tp::Table table{ CreateOptimizedModelGetter{} }; + auto fn = table[static_cast(archType)]; + if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; + return (*fn)(std::move(mem)); + } + } +} \ No newline at end of file diff --git a/src/Knlm.hpp b/src/Knlm.hpp index 1f9f9baf..bb7c756d 100644 --- a/src/Knlm.hpp +++ b/src/Knlm.hpp @@ -4,9 +4,9 @@ #include #include #include +#include #include #include -#include #include "ArchAvailable.h" #include "search.h" #include "BitEncoder.hpp" @@ -17,6 +17,9 @@ namespace kiwi { namespace lm { + static constexpr size_t serialAlignment = 16; + + using QCode = qe::QCode<0, 2, 8, 16>; template @@ -79,6 +82,12 @@ namespace kiwi } } + inline const void* toAlignedPtr(const void* ptr, size_t alignment = serialAlignment) + { + auto addr = reinterpret_cast(ptr); + return reinterpret_cast((addr + alignment - 1) & ~(alignment - 1)); + } + template class KnLangModel : public KnLangModelBase { @@ -92,6 +101,7 @@ namespace kiwi const float* ll_data = nullptr; const float* gamma_data = nullptr; const KeyType* htx_data = nullptr; + const void* extra_buf = nullptr; Vector restored_floats; float unk_ll = 0; ptrdiff_t bos_node_idx = 0; @@ -159,8 +169,8 @@ namespace kiwi { auto* ptr = reinterpret_cast(base.get()); auto& header = getHeader(); - size_t quantized = header.quantized & 0x1F; - bool compressed = header.quantized & 0x80; + const size_t quantized = header.quantized & 0x1F; + const bool compressed = header.quantized & 0x80; Vector d_node_size; auto* node_sizes = reinterpret_cast(ptr + header.node_offset); @@ -209,12 +219,14 @@ namespace kiwi num_non_leaf_nodes, num_leaf_nodes ); + extra_buf = toAlignedPtr(gamma_table + ((size_t)1 << quantized)); } else { ll_data = reinterpret_cast(ptr + header.ll_offset); gamma_data = reinterpret_cast(ptr + header.gamma_offset); leaf_ll_data = ll_data + num_non_leaf_nodes; + extra_buf = toAlignedPtr(gamma_data + num_non_leaf_nodes); } size_t htx_vocab_size = header.vocab_size; @@ -222,6 +234,12 @@ namespace kiwi { htx_data = reinterpret_cast(ptr + header.htx_offset); htx_vocab_size = *std::max_element(htx_data, htx_data + header.vocab_size) + 1; + extra_buf = toAlignedPtr(htx_data + header.vocab_size); + } + + if (!header.extra_buf_size) + { + extra_buf = nullptr; } // restore node's data @@ -448,6 +466,11 @@ namespace kiwi return gamma_data; } + const void* getExtraBuf() const final + { + return extra_buf; + } + ptrdiff_t getLowerNode(ptrdiff_t node_idx) const final { return node_idx + node_data[node_idx].lower; @@ -627,47 +650,94 @@ namespace kiwi } return ret; } - }; - template - std::unique_ptr createOptimizedModel(utils::MemoryObject&& mem) - { - auto* ptr = reinterpret_cast(mem.get()); - auto& header = *reinterpret_cast(ptr); - switch (header.key_size) + template + void _nextTopN(ptrdiff_t node_idx, size_t top_n, KeyOut* idx_out, float* ll_out) const { - case 1: - return make_unique>(std::move(mem)); - case 2: - return make_unique>(std::move(mem)); - case 4: - return make_unique>(std::move(mem)); - case 8: - return make_unique>(std::move(mem)); - default: - throw std::runtime_error{ "Unsupported `key_size` : " + std::to_string((size_t)header.key_size) }; - } - } + thread_local Vector> buf; + buf.clear(); + auto* node = &node_data[node_idx]; + auto* keys = &key_data[node->next_offset]; + auto* values = &value_data[node->next_offset]; + for (size_t i = 0; i < node->num_nexts; ++i) + { + if (values[i] < 0) + { + buf.emplace_back(reinterpret_cast(values[i]), (KeyOut)keys[i]); + } + else + { + buf.emplace_back(ll_data[node_idx + values[i]], (KeyOut)keys[i]); + } + } + std::make_heap(buf.begin(), buf.end()); - using FnCreateOptimizedModel = decltype(&createOptimizedModel); + float acc = 0; + while (node->num_nexts < top_n && node->lower) + { + acc += gamma_data[node - &node_data[0]]; + node += node->lower; + keys = &key_data[node->next_offset]; + values = &value_data[node->next_offset]; + for (size_t i = 0; i < node->num_nexts; ++i) + { + if (values[i] < 0) + { + buf.emplace_back(acc + reinterpret_cast(values[i]), (KeyOut)keys[i]); + } + else + { + buf.emplace_back(acc + ll_data[node - &node_data[0] + values[i]], (KeyOut)keys[i]); + } + std::push_heap(buf.begin(), buf.end()); + } + } - struct CreateOptimizedModelGetter - { - template - struct Wrapper + size_t i; + if (top_n <= 16) + { + for (i = 0; i < top_n && !buf.empty();) + { + std::pop_heap(buf.begin(), buf.end()); + if (std::find(idx_out, idx_out + i, buf.back().second) == idx_out + i) + { + idx_out[i] = buf.back().second; + ll_out[i] = buf.back().first; + ++i; + } + buf.pop_back(); + } + } + else + { + thread_local std::unordered_set uniq; + uniq.clear(); + for (i = 0; i < top_n && !buf.empty();) + { + std::pop_heap(buf.begin(), buf.end()); + if (uniq.insert(buf.back().second).second) + { + idx_out[i] = buf.back().second; + ll_out[i] = buf.back().first; + ++i; + } + buf.pop_back(); + } + } + + for (; i < top_n; ++i) + { + idx_out[i] = 0; + ll_out[i] = -INFINITY; + } + } + + void nextTopN(ptrdiff_t node_idx, size_t top_n, uint32_t* idx_out, float* ll_out) const final { - static constexpr FnCreateOptimizedModel value = &createOptimizedModel(i)>; - }; + return _nextTopN(node_idx, top_n, idx_out, ll_out); + } }; - inline std::unique_ptr KnLangModelBase::create(utils::MemoryObject&& mem, ArchType archType) - { - static tp::Table table{ CreateOptimizedModelGetter{} }; - auto fn = table[static_cast(archType)]; - if (!fn) throw std::runtime_error{ std::string{"Unsupported architecture : "} + archToStr(archType) }; - return (*fn)(std::move(mem)); - } - template void quantize(const std::vector& ll_table, const std::vector& gamma_table, const std::vector& ll, const std::vector& leaf_ll, @@ -725,15 +795,33 @@ namespace kiwi return table[bits - 1](ll_table, gamma_table, ll, leaf_ll, gamma, llq, gammaq); } + inline size_t alignedOffsetInc(size_t& offset, size_t inc, size_t alignment = serialAlignment) + { + return offset = (offset + inc + alignment - 1) & ~(alignment - 1); + } + + inline std::ostream& writePadding(std::ostream& os, size_t alignment = serialAlignment) + { + const size_t pos = os.tellp(); + size_t pad = ((pos + alignment - 1) & ~(alignment - 1)) - pos; + for (size_t i = 0; i < pad; ++i) + { + os.put(0); + } + return os; + } + template utils::MemoryOwner buildCompressedModel(Header header, - size_t min_cf, size_t last_min_cf, + const std::vector& min_cf_by_order, float unigram_alpha, utils::ContinuousTrie&& compressed_ngrams, const std::vector& unigram_pats, const std::vector& unigram_cnts, const std::vector>& ngram_ncnt, - const HistoryTx* history_transformer = nullptr + const HistoryTx* history_transformer = nullptr, + const void* extra_buf = nullptr, + size_t extra_buf_size = 0 ) { header.key_size = sizeof(KeyType); @@ -743,10 +831,10 @@ namespace kiwi std::vector ll_table, gamma_table; std::ostringstream llq, gammaq, c_node_size; std::vector keys; - size_t quantized = header.quantized & 0x1F; - bool compressed = (header.quantized & 0x80) != 0; + const size_t quantized = header.quantized & 0x1F; + const bool compressed = (header.quantized & 0x80) != 0; - size_t quantize_size = (1 << (header.quantized & 0x1F)); + const size_t quantize_size = ((size_t)1 << (header.quantized & 0x1F)); for (auto& node : compressed_ngrams) { size_t i = (size_t)(&node - &compressed_ngrams[0]); @@ -767,7 +855,7 @@ namespace kiwi double y = ncnt[0] / (ncnt[0] + 2. * ncnt[1]); for (size_t j = 0; j < 3; ++j) { - discnts[i][j] = ncnt[j] ? ((j + 1) - (j + 2) * y * ncnt[j + 1] / ncnt[j]) : 0; + discnts[i][j] = ncnt[j] ? std::max((j + 1) - (j + 2) * y * ncnt[j + 1] / ncnt[j], 0.) : 0; } } if (history_transformer) @@ -775,13 +863,13 @@ namespace kiwi for (auto& e : discnts[1]) e *= 0.25; } - std::vector rkeys; + std::vector rkeys; // set gamma & unigram ll - compressed_ngrams[0].traverseWithKeys([&](const TrieNode* node, const std::vector& rkeys) + compressed_ngrams[0].traverseWithKeys([&](const TrieNode* node, const std::vector& rkeys) { if (rkeys.empty()) return; ptrdiff_t i = (ptrdiff_t)(node - &compressed_ngrams[0]); - size_t min_cnt = rkeys.size() < header.order - 1 ? min_cf : last_min_cf; + const size_t min_cnt = std::max(min_cf_by_order[std::max(std::min(rkeys.size(), min_cf_by_order.size()), (size_t)1) - 1], (size_t)1); std::array pats = { 0, }; ptrdiff_t rest = node->val; @@ -828,12 +916,12 @@ namespace kiwi // set n-gram ll for (size_t o = 2; o <= header.order; ++o) { - compressed_ngrams[0].traverseWithKeys([&](const TrieNode* node, const std::vector& rkeys) + compressed_ngrams[0].traverseWithKeys([&](const TrieNode* node, const std::vector& rkeys) { ptrdiff_t i = (ptrdiff_t)(node - &compressed_ngrams[0]); if (rkeys.size() == o) { - size_t min_cnt = o < header.order ? min_cf : last_min_cf; + const size_t min_cnt = std::max(min_cf_by_order[std::max(std::min(rkeys.size(), min_cf_by_order.size()), (size_t)1) - 1], (size_t)1); if (node->val) { double l = (node->val - min_cnt * discnts[rkeys.size() - 1][std::min(node->val / min_cnt, (size_t)3) - 1]) / (double)node->getParent()->val; @@ -917,42 +1005,44 @@ namespace kiwi size_t final_size = 0; - header.node_offset = (final_size += sizeof(Header)); + header.node_offset = alignedOffsetInc(final_size, sizeof(Header)); if (compressed) { - header.key_offset = (final_size += c_node_size.tellp()); + header.key_offset = alignedOffsetInc(final_size, c_node_size.tellp()); } else { - header.key_offset = (final_size += sizeof(KeyType) * node_sizes.size()); + header.key_offset = alignedOffsetInc(final_size, sizeof(KeyType) * node_sizes.size()); } - header.ll_offset = (final_size += sizeof(KeyType) * keys.size()); + header.ll_offset = alignedOffsetInc(final_size, sizeof(KeyType) * keys.size()); if (quantized) { - header.gamma_offset = (final_size += llq.tellp()); - header.qtable_offset = (final_size += gammaq.tellp()); - final_size += sizeof(float) * quantize_size * 2; + header.gamma_offset = alignedOffsetInc(final_size, llq.tellp()); + header.qtable_offset = alignedOffsetInc(final_size, gammaq.tellp()); + alignedOffsetInc(final_size, sizeof(float) * quantize_size * 2); } else { - header.gamma_offset = (final_size += sizeof(float) * (ll.size() + leaf_ll.size())); + header.gamma_offset = alignedOffsetInc(final_size, sizeof(float) * (ll.size() + leaf_ll.size())); header.qtable_offset = 0; - final_size += sizeof(float) * gamma.size(); + alignedOffsetInc(final_size, sizeof(float) * gamma.size()); } if (history_transformer) { header.htx_offset = final_size; - final_size += sizeof(KeyType) * header.vocab_size; + alignedOffsetInc(final_size, sizeof(KeyType) * header.vocab_size); } else { header.htx_offset = 0; } + header.extra_buf_size = extra_buf_size; - utils::MemoryOwner ret{ final_size }; + utils::MemoryOwner ret{ final_size + extra_buf_size }; utils::omstream ostr{ (char*)ret.get(), (std::ptrdiff_t)ret.size() }; ostr.write((const char*)&header, sizeof(Header)); + writePadding(ostr); if (compressed) { ostr.write((const char*)c_node_size.str().data(), c_node_size.tellp()); @@ -961,11 +1051,15 @@ namespace kiwi { ostr.write((const char*)node_sizes.data(), sizeof(KeyType) * node_sizes.size()); } + writePadding(ostr); ostr.write((const char*)keys.data(), sizeof(KeyType) * keys.size()); + writePadding(ostr); if (quantized) { ostr.write((const char*)llq.str().data(), llq.tellp()); + writePadding(ostr); ostr.write((const char*)gammaq.str().data(), gammaq.tellp()); + writePadding(ostr); ostr.write((const char*)ll_table.data(), sizeof(float) * quantize_size); ostr.write((const char*)gamma_table.data(), sizeof(float) * quantize_size); } @@ -973,8 +1067,10 @@ namespace kiwi { ostr.write((const char*)ll.data(), sizeof(float) * ll.size()); ostr.write((const char*)leaf_ll.data(), sizeof(float) * leaf_ll.size()); + writePadding(ostr); ostr.write((const char*)gamma.data(), sizeof(float) * gamma.size()); } + writePadding(ostr); if (history_transformer) { @@ -983,25 +1079,46 @@ namespace kiwi htx.resize(header.vocab_size); ostr.write((const char*)htx.data(), sizeof(KeyType) * htx.size()); } + writePadding(ostr); + + if (extra_buf_size) + { + ostr.write((const char*)extra_buf, extra_buf_size); + } return ret; } - template - utils::MemoryOwner KnLangModelBase::build(const utils::ContinuousTrie& ngram_cf, - size_t order, size_t min_cf, size_t last_min_cf, + template + struct GetNodeType; + + template + struct GetNodeType> + { + using type = TrieNode; + }; + + template + struct GetNodeType> + { + using type = utils::TrieNodeEx; + }; + + template + utils::MemoryOwner KnLangModelBase::build(Trie&& ngram_cf, + size_t order, const std::vector& min_cf_by_order, size_t unk_id, size_t bos_id, size_t eos_id, float unigram_alpha, size_t quantize, bool compress, - const std::vector>* bigram_list, const HistoryTx* history_transformer + const std::vector>* bigram_list, const HistoryTx* history_transformer, + const void* extra_buf, size_t extra_buf_size ) { + using TrieNode = typename GetNodeType::type>::type>::type; + using Key = typename TrieNode::Key; if (quantize > 16) throw std::invalid_argument{ "16+ bits quantization not supported."}; size_t max_vid = 0; utils::ContinuousTrie compressed_ngrams{ 1 }; std::vector unigram_pats, unigram_cnts; std::vector> ngram_ncnt(order); - if (min_cf == 0) min_cf = 1; - if (last_min_cf < min_cf) last_min_cf = min_cf; - if (bigram_list) { for (auto& p : *bigram_list) @@ -1022,16 +1139,16 @@ namespace kiwi } { - std::vector rkeys; + std::vector rkeys; utils::ContinuousTrie reverse_ngrams{ 1 }; - ngram_cf[0].traverseWithKeys([&](const TrieNode* node, const std::vector& rkeys) + ngram_cf.traverse([&](const uint32_t cnt, const std::vector& rkeys) { // unigram prob counting if (rkeys.size() == 1) { if (rkeys[0] >= unigram_cnts.size()) unigram_cnts.resize(rkeys[0] + 1); - unigram_cnts[rkeys[0]] += node->val; + unigram_cnts[rkeys[0]] += cnt; } if (bigram_list == nullptr && rkeys.size() == 2) @@ -1040,25 +1157,25 @@ namespace kiwi unigram_pats[rkeys[1]] += 1; } - size_t min_cnt = rkeys.size() == order ? last_min_cf : min_cf; + const size_t min_cnt = std::max(min_cf_by_order[std::max(std::min(rkeys.size(), min_cf_by_order.size()), (size_t)1) - 1], (size_t)1); - if (node->val < min_cnt) return; + if (cnt < min_cnt) return; if (!rkeys.empty()) max_vid = std::max(max_vid, (size_t)rkeys.back()); // last-gram discounting if (rkeys.size() == order) { - size_t n = node->val / last_min_cf; + size_t n = cnt / min_cnt; if (n <= 4) ngram_ncnt[order - 1][n - 1]++; } if (rkeys.size() >= 2) { - reverse_ngrams.build(rkeys.rbegin(), rkeys.rend(), 0)->val = node->val; + reverse_ngrams.build(rkeys.rbegin(), rkeys.rend(), 0)->val = cnt; } - compressed_ngrams.build(rkeys.begin(), rkeys.end(), 0)->val += node->val; - }, rkeys); + compressed_ngrams.build(rkeys.begin(), rkeys.end(), 0)->val += cnt; + }); if (history_transformer) { compressed_ngrams.fillFail([&](size_t i) { return (*history_transformer)[i]; }, true); @@ -1068,7 +1185,7 @@ namespace kiwi compressed_ngrams.fillFail(true); } - reverse_ngrams[0].traverseWithKeys([&](const TrieNode* node, const std::vector& rkeys) + reverse_ngrams.traverseWithKeys([&](const TrieNode* node, const std::vector& rkeys) { if (rkeys.size() >= 1) { @@ -1096,19 +1213,35 @@ namespace kiwi if (max_vid <= 0xFF) { - return buildCompressedModel(header, min_cf, last_min_cf, unigram_alpha, move(compressed_ngrams), unigram_pats, unigram_cnts, ngram_ncnt, history_transformer); + return buildCompressedModel(header, min_cf_by_order, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, + history_transformer, + extra_buf, extra_buf_size); } else if (max_vid <= 0xFFFF) { - return buildCompressedModel(header, min_cf, last_min_cf, unigram_alpha, move(compressed_ngrams), unigram_pats, unigram_cnts, ngram_ncnt, history_transformer); + return buildCompressedModel(header, min_cf_by_order, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, + history_transformer, + extra_buf, extra_buf_size); } else if (max_vid <= 0xFFFFFFFF) { - return buildCompressedModel(header, min_cf, last_min_cf, unigram_alpha, move(compressed_ngrams), unigram_pats, unigram_cnts, ngram_ncnt, history_transformer); + return buildCompressedModel(header, min_cf_by_order, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, + history_transformer, + extra_buf, extra_buf_size); } else { - return buildCompressedModel(header, min_cf, last_min_cf, unigram_alpha, move(compressed_ngrams), unigram_pats, unigram_cnts, ngram_ncnt, history_transformer); + return buildCompressedModel(header, min_cf_by_order, + unigram_alpha, move(compressed_ngrams), + unigram_pats, unigram_cnts, ngram_ncnt, + history_transformer, + extra_buf, extra_buf_size); } } } diff --git a/src/LmState.hpp b/src/LmState.hpp index 089cf224..c019875d 100644 --- a/src/LmState.hpp +++ b/src/LmState.hpp @@ -30,7 +30,7 @@ namespace kiwi template class KnLMState { - friend class Hash>; + friend struct Hash>; int32_t node = 0; public: static constexpr ArchType arch = _arch; @@ -57,7 +57,7 @@ namespace kiwi template class SbgState : public KnLMState<_arch, VocabTy> { - friend class Hash>; + friend struct Hash>; size_t historyPos = 0; std::array history = { {0,} }; public: diff --git a/src/PathEvaluator.hpp b/src/PathEvaluator.hpp index dadb48a4..1c0e9993 100644 --- a/src/PathEvaluator.hpp +++ b/src/PathEvaluator.hpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -585,7 +585,7 @@ namespace kiwi } else { - auto inserted = bestPathIndex.emplace(ph, make_pair(bestPathValues.size(), 1)); + auto inserted = bestPathIndex.emplace(ph, make_pair((uint32_t)bestPathValues.size(), 1)); if (inserted.second) { bestPathValues.emplace_back(curMorph, candScore, prevPath.accTypoCost + node->typoCost, &prevPath, move(cLmState), spState); diff --git a/src/StrUtils.h b/src/StrUtils.h index b6a7b2ac..1268c3f5 100644 --- a/src/StrUtils.h +++ b/src/StrUtils.h @@ -1,9 +1,8 @@ #pragma once -#include #include #include -#include #include +#include #include "string_view.hpp" namespace kiwi @@ -372,38 +371,6 @@ namespace kiwi return ret; } - inline bool isHighSurrogate(char16_t c) - { - return (c & 0xFC00) == 0xD800; - } - - inline bool isLowSurrogate(char16_t c) - { - return (c & 0xFC00) == 0xDC00; - } - - inline char32_t mergeSurrogate(char16_t h, char16_t l) - { - return (((h & 0x3FF) << 10) | (l & 0x3FF)) + 0x10000; - } - - inline std::array decomposeSurrogate(char32_t c) - { - std::array ret; - if (c < 0x10000) - { - ret[0] = c; - ret[1] = 0; - } - else - { - c -= 0x10000; - ret[0] = ((c >> 10) & 0x3FF) | 0xD800; - ret[1] = (c & 0x3FF) | 0xDC00; - } - return ret; - } - inline std::string utf16To8(nonstd::u16string_view str) { std::string ret; diff --git a/src/SubstringExtractor.cpp b/src/SubstringExtractor.cpp index bdcaffda..cc41541b 100644 --- a/src/SubstringExtractor.cpp +++ b/src/SubstringExtractor.cpp @@ -1,12 +1,30 @@ #include +#include #include +#include #include #include "StrUtils.h" +#include "FrozenTrie.hpp" +#include "Knlm.hpp" #include "sais/fm_index.hpp" + +#ifdef KIWI_USE_BTREE + +#ifdef _WIN32 +using ssize_t = ptrdiff_t; +#else +#include +#endif + +#include +#else +#endif + + using namespace std; namespace kiwi @@ -47,26 +65,33 @@ namespace kiwi vector> ret; fi.enumSuffices(minCnt, [&](const sais::FmIndex::SuffixTy& s, const sais::FmIndex::TraceTy& t) { - if (s.size() > maxLength) return false; + auto u32size = s.size(); + for (size_t i = 0; i < s.size(); ++i) + { + if (isLowSurrogate(s[i])) u32size--; + } + + if (u32size > maxLength) return false; if (find(s.begin(), s.end(), stopChr) != s.end()) { return false; } - if (isLowSurrogate(s.back()) || isHighSurrogate(s.front())) return false; + if (isHighSurrogate(s.front())) return false; + if (isLowSurrogate(s.back())) return true; if (testRepetition(s.data(), s.size())) { return false; } - const auto ssLength = s.size(); - if (ssLength < minLength) + if (u32size < minLength) { return true; } + const auto ssLength = s.size(); const auto ssCnt = t.back().second - t.back().first; if (ssCnt < minCnt) { @@ -138,4 +163,261 @@ namespace kiwi return ret; } + +#ifdef KIWI_USE_BTREE + template using map = btree::map; +#else + template using map = std::map; +#endif + + template + using PrefixTrieNode = utils::TrieNodeEx>>; + + PrefixCounter::PrefixCounter( + size_t _prefixSize, + size_t _minCf, + size_t _numWorkers, + const std::vector>& clusters + ) + : prefixSize(_prefixSize), minCf(_minCf), id2Token(2), buf(1) + { + if (_numWorkers == (size_t)-1) _numWorkers = min(thread::hardware_concurrency(), 8u); + if (_numWorkers > 1) + { + threadPool = make_unique(_numWorkers); + } + + if (clusters.empty()) return; + + unordered_set alreadyAllocated; + for (auto cs : clusters) + { + if (cs.empty()) continue; + sort(cs.begin(), cs.end()); + const auto cid = cs[0]; + for (auto c : cs) + { + if (alreadyAllocated.find(c) != alreadyAllocated.end()) + { + throw runtime_error("Duplicated cluster id"); + } + alreadyAllocated.insert(c); + + if (c >= tokenClusters.size()) + { + const auto e = c + 1; + tokenClusters.resize(e, -1); + tokenCnts.resize(e); + } + tokenClusters[c] = cid; + } + } + } + + template + void PrefixCounter::_addArray(It first, It last) + { + for (; first != last; ++first) + { + auto token = *first; + if (token < tokenClusters.size()) + { + if (tokenClusters[token] != (size_t)-1) + { + tokenCnts[token]++; + token = tokenClusters[token]; + } + } + + auto it = token2id.find(token); + if (it == token2id.end()) + { + const auto id = id2Token.size(); + it = token2id.emplace(token, id).first; + id2Token.push_back(token); + } + if (it->second < 0x4000) + { + buf.emplace_back(it->second); + } + else if (it->second < 0x10000000) + { + buf.emplace_back((it->second & 0x3FFF) | 0x4000); + buf.emplace_back((it->second >> 14) | 0x8000); + } + else + { + throw runtime_error("Too many tokens"); + } + } + buf.emplace_back(1); + numArrays += 1; + } + + + void PrefixCounter::addArray(const uint16_t* first, const uint16_t* last) + { + _addArray(first, last); + } + + void PrefixCounter::addArray(const uint32_t* first, const uint32_t* last) + { + _addArray(first, last); + } + + void PrefixCounter::addArray(const uint64_t* first, const uint64_t* last) + { + _addArray(first, last); + } + + utils::FrozenTrie PrefixCounter::count() const + { + sais::FmIndex fi{ (const char16_t*)buf.data(), buf.size(), (mp::ThreadPool*)threadPool.get()}; + utils::ContinuousTrie> trie{ 1 }; + trie.root().val = buf.size() - 1 - numArrays; + + unique_ptr mtx; + if (threadPool) + { + mtx = make_unique(); + } + + fi.enumSuffices(minCf, [&](const sais::FmIndex::SuffixTy& s, const sais::FmIndex::TraceTy& t) + { + auto u32size = s.size(); + for (size_t i = 0; i < s.size(); ++i) + { + if (s[i] & 0x8000) + { + u32size--; + } + } + + if (u32size > prefixSize) return false; + + if (find(s.begin(), s.end(), 0) != s.end() || find(s.begin(), s.end(), 1) != s.end()) + { + return false; + } + + if (s.front() & 0x4000) return false; + if (s.back() & 0x8000) return true; + + const auto suffixCnt = t.back().second - t.back().first; + if (suffixCnt < minCf) + { + return false; + } + + thread_local Vector restoredBuf; + restoredBuf.clear(); + for(auto rit = s.rbegin(); rit != s.rend(); ++rit) + { + if (*rit & 0x4000) + { + const auto merged = (rit[0] & 0x3FFF) | ((rit[1] & 0x3FFF) << 14); + restoredBuf.push_back(id2Token[merged]); + ++rit; + } + else if (*rit & 0x8000) + { + throw runtime_error("Invalid token"); + } + else + { + restoredBuf.push_back(id2Token[*rit]); + } + } + mp::OptionalLockGuard lock{ mtx.get() }; + trie.build(restoredBuf.begin(), restoredBuf.end(), suffixCnt); + return true; + }, (mp::ThreadPool*)threadPool.get()); + return utils::freezeTrie(move(trie), ArchType::balanced); + } + + Vector> PrefixCounter::computeClusterScore() const + { + UnorderedMap clusterCnts; + for (size_t i = 0; i < tokenClusters.size(); ++i) + { + if (tokenClusters[i] != (size_t)-1) + { + clusterCnts[tokenClusters[i]] += tokenCnts[i]; + } + } + + Vector> ret; + ret.reserve(tokenClusters.size()); + for (size_t i = 0; i < tokenClusters.size(); ++i) + { + if (tokenClusters[i] == (size_t)-1) + { + ret.emplace_back(-1, 0); + } + else + { + ret.emplace_back(tokenClusters[i], (float)log((double)tokenCnts[i] / clusterCnts[tokenClusters[i]])); + } + } + return ret; + } + + unique_ptr PrefixCounter::buildLM( + const std::vector& minCfByOrder, + size_t bosTokenId, + size_t eosTokenId, + size_t unkTokenId, + ArchType archType) const + { + Vector extraBuf; + if (!tokenClusters.empty()) + { + auto clusterScore = computeClusterScore(); + extraBuf.resize(clusterScore.size() * sizeof(uint64_t) + sizeof(uint64_t) * 2); + memcpy(extraBuf.data(), "UNIGRAM\0", sizeof(uint64_t)); + uint64_t size = clusterScore.size(); + memcpy(extraBuf.data() + sizeof(uint64_t), &size, sizeof(uint64_t)); + memcpy(extraBuf.data() + sizeof(uint64_t) * 2, clusterScore.data(), clusterScore.size() * sizeof(uint64_t)); + } + + utils::MemoryOwner mem; + { + auto trie = count(); + mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCfByOrder, unkTokenId, bosTokenId, eosTokenId, + 1e-5f, 0, false, nullptr, (const Vector*)nullptr, + extraBuf.data(), extraBuf.size()); + } + return lm::KnLangModelBase::create(move(mem), archType); + } + + ClusterData::ClusterData() = default; + + ClusterData::ClusterData(const void* _ptr, size_t _size) + { + if (!_ptr || !_size) return; + if (_size < sizeof(uint64_t) * 2) throw runtime_error("Invalid cluster data"); + auto ptr = (const uint64_t*)_ptr; + if (memcmp(ptr, "UNIGRAM\0", sizeof(uint64_t)) != 0) throw runtime_error("Invalid cluster data"); + const auto size = ptr[1]; + if (_size < sizeof(uint64_t) * 2 + size * sizeof(uint64_t)) throw runtime_error("Invalid cluster data"); + clusterScores = (const pair*)(ptr + 2); + clusterSize = size; + } + + size_t ClusterData::size() const + { + return clusterSize; + } + + size_t ClusterData::cluster(size_t i) const + { + if (i >= clusterSize || clusterScores[i].first == (uint32_t)-1) return i; + return clusterScores[i].first; + } + + float ClusterData::score(size_t i) const + { + if (i >= clusterSize || clusterScores[i].first == (uint32_t)-1) return 0; + return clusterScores[i].second; + } } diff --git a/src/sais/fm_index.hpp b/src/sais/fm_index.hpp index a6f6513f..5c21aca9 100644 --- a/src/sais/fm_index.hpp +++ b/src/sais/fm_index.hpp @@ -14,7 +14,7 @@ namespace sais std::unique_ptr cKeys; std::unique_ptr cValues; size_t length = 0, vocabSize = 0; - WaveletTree waveletTree; + WaveletTree waveletTree; public: FmIndex() = default; @@ -25,14 +25,14 @@ namespace sais if (length < 0x80000000) { auto ibuf = std::unique_ptr(new int32_t[length + 1]); - bwt(data, bwtData.get(), ibuf.get(), length, 0, nullptr, pool); + bwt(data, bwtData.get(), ibuf.get(), length, 0, nullptr, pool); } else { auto ibuf = std::unique_ptr(new int64_t[length + 1]); - bwt(data, bwtData.get(), ibuf.get(), length, 0, nullptr, pool); + bwt(data, bwtData.get(), ibuf.get(), length, 0, nullptr, pool); } - waveletTree = WaveletTree{ bwtData.get(), length }; + waveletTree = WaveletTree{ bwtData.get(), length }; /*freqs = std::unique_ptr(new size_t[(size_t)1 << (sizeof(ChrTy) * 8)]); std::fill(freqs.get(), freqs.get() + ((size_t)1 << (sizeof(ChrTy) * 8)), 0); @@ -114,30 +114,34 @@ namespace sais } template - size_t enumSuffices(size_t minCnt, Fn&& fn) const + size_t enumSuffices(size_t minCnt, Fn&& fn, mp::ThreadPool* tp = nullptr) const { - SuffixTy suffix; - TraceTy trace; - size_t numSuffices = 0; - for (size_t k = 0; k < vocabSize; ++k) + auto numSuffices = mp::runParallel(tp, [&](const size_t i, const size_t numWorkers, mp::Barrier*) { - auto p = std::make_pair(cValues[k], (k + 1 < vocabSize) ? cValues[k + 1] : length); - if (p.second - p.first < minCnt) continue; - suffix.push_back(cKeys[k]); - trace.emplace_back(p); - if (!fn(const_cast(suffix), const_cast(trace))) + SuffixTy suffix; + TraceTy trace; + size_t numSuffices = 0; + for (size_t k = i; k < vocabSize; k += numWorkers) { + auto p = std::make_pair(cValues[k], (k + 1 < vocabSize) ? cValues[k + 1] : length); + if (p.second - p.first < minCnt) continue; + suffix.push_back(cKeys[k]); + trace.emplace_back(p); + if (!fn(const_cast(suffix), const_cast(trace))) + { + suffix.pop_back(); + trace.pop_back(); + continue; + } + numSuffices++; + + numSuffices += enumSuffices(minCnt, suffix, trace, p.first, p.second, fn); suffix.pop_back(); trace.pop_back(); - continue; } - numSuffices++; - - numSuffices += enumSuffices(minCnt, suffix, trace, p.first, p.second, fn); - suffix.pop_back(); - trace.pop_back(); - } - return numSuffices; + return numSuffices; + }); + return std::accumulate(numSuffices.begin(), numSuffices.end(), (size_t)0); } }; } diff --git a/src/sais/mp_utils.hpp b/src/sais/mp_utils.hpp index 01bc7ff4..374f6054 100644 --- a/src/sais/mp_utils.hpp +++ b/src/sais/mp_utils.hpp @@ -406,4 +406,24 @@ namespace mp if (pool) pool->_limitedSize = prevSize; } }; + + template + class OptionalLockGuard + { + Mutex* mutex; + + public: + OptionalLockGuard(Mutex* _mutex = nullptr) : mutex{ _mutex } + { + if (mutex) mutex->lock(); + } + + ~OptionalLockGuard() + { + if (mutex) mutex->unlock(); + } + + OptionalLockGuard(const OptionalLockGuard&) = delete; + OptionalLockGuard& operator=(const OptionalLockGuard&) = delete; + }; } diff --git a/src/sais/sais.hpp b/src/sais/sais.hpp index fa38b89b..8ffeb02e 100644 --- a/src/sais/sais.hpp +++ b/src/sais/sais.hpp @@ -436,11 +436,14 @@ class SaisImpl gather_lms_suffixes_16u(T, SA, n, (fast_sint_t)n - 1 - m, omp_block_start, omp_block_size); - mp::barrier(barrier); - - if (pool && thread_state[id].state.m > 0) + if (num_threads > 1) { - SA[(fast_sint_t)n - 1 - m] = (SaTy)thread_state[id].state.last_lms_suffix; + mp::barrier(barrier); + + if (pool && thread_state[id].state.m > 0) + { + SA[(fast_sint_t)n - 1 - m] = (SaTy)thread_state[id].state.last_lms_suffix; + } } }, mp::ParallelCond{n >= 65536}); } diff --git a/src/sais/wavelet_tree.hpp b/src/sais/wavelet_tree.hpp index 843ef409..4c438834 100644 --- a/src/sais/wavelet_tree.hpp +++ b/src/sais/wavelet_tree.hpp @@ -24,7 +24,8 @@ namespace sais return ret; } - inline size_t writeLSBs(uint8_t* out, size_t bitOffset, const char16_t* data, size_t size) + template + inline size_t writeLSBs(uint8_t* out, size_t bitOffset, const ChrTy* data, size_t size) { using namespace kiwi::utils; @@ -68,7 +69,8 @@ namespace sais return oneCnt; } - inline void splitByLSB(char16_t* ones, char16_t* zeros, const char16_t* data, size_t size) + template + inline void splitByLSB(ChrTy* ones, ChrTy* zeros, const ChrTy* data, size_t size) { for (size_t i = 0; i < size; ++i) { @@ -84,11 +86,11 @@ namespace sais } } - - inline size_t writeMSBs(uint8_t* out, size_t bitOffset, const char16_t* data, size_t size) + template + inline size_t writeMSBs(uint8_t* out, size_t bitOffset, const ChrTy* data, size_t size) { using namespace kiwi::utils; - + static constexpr size_t bitSize = sizeof(ChrTy) * 8; static constexpr size_t rSize = sizeof(size_t) * 8; size_t headSize = std::min(bitOffset ? (rSize - bitOffset) : 0, size); size_t bodySize = (std::max(size, headSize) - headSize) & ~(rSize - 1); @@ -97,7 +99,7 @@ namespace sais for (size_t i = 0; i < headSize; ++i) { - auto v = (size_t)((data[i] & (1 << 15)) ? 1 : 0); + auto v = (size_t)((data[i] & (1 << (bitSize - 1))) ? 1 : 0); ((size_t*)out)[0] |= v << (i + bitOffset); oneCnt += v ? 1 : 0; } @@ -108,7 +110,7 @@ namespace sais size_t g = 0; for (size_t j = 0; j < rSize; ++j) { - auto v = (size_t)((curData[j] & (1 << 15)) ? 1 : 0); + auto v = (size_t)((curData[j] & (1 << (bitSize - 1))) ? 1 : 0); g |= v << j; } ((size_t*)out)[i + (headSize ? 1 : 0)] = g; @@ -121,7 +123,7 @@ namespace sais tail = 0; for (size_t i = 0; i < tailSize; ++i) { - auto v = (size_t)((data[headSize + bodySize + i] & (1 << 15)) ? 1 : 0); + auto v = (size_t)((data[headSize + bodySize + i] & (1 << (bitSize - 1))) ? 1 : 0); tail |= v << (i & (rSize - 1)); oneCnt += v ? 1 : 0; } @@ -129,11 +131,13 @@ namespace sais return oneCnt; } - inline void splitByMSB(char16_t* ones, char16_t* zeros, const char16_t* data, size_t size) + template + inline void splitByMSB(ChrTy* ones, ChrTy* zeros, const ChrTy* data, size_t size) { + static constexpr size_t bitSize = sizeof(ChrTy) * 8; for (size_t i = 0; i < size; ++i) { - if ((*data & (1 << 15))) + if ((*data & (1 << (bitSize - 1)))) { *ones++ = *data << 1; } @@ -171,13 +175,14 @@ namespace sais return n / 2 + ((size_t)1 << (k - h - 1)); } + template class WaveletTree { static constexpr size_t bitAlignmentSize = sizeof(size_t) * 8; static constexpr size_t superBlockSize = 64; static constexpr size_t superBlockBitSize = superBlockSize * 8; static constexpr size_t rSize = sizeof(size_t) * 8; - static constexpr size_t depth = sizeof(char16_t) * 8; + static constexpr size_t depth = sizeof(ChrTy) * 8; size_t length = 0; std::unique_ptr bits; @@ -192,7 +197,7 @@ namespace sais } template - size_t enumerate(size_t i, char16_t c, size_t l, size_t r, size_t offsetIdx, Fn&& fn) const + size_t enumerate(size_t i, ChrTy c, size_t l, size_t r, size_t offsetIdx, Fn&& fn) const { size_t ret = 0; @@ -237,15 +242,15 @@ namespace sais public: WaveletTree() = default; - WaveletTree(const char16_t* data, size_t size) + WaveletTree(const ChrTy* data, size_t size) { const size_t alignedSize = (size + bitAlignmentSize - 1) & ~(bitAlignmentSize - 1); length = size; - bits = std::unique_ptr(new uint8_t[alignedSize * sizeof(char16_t)]); + bits = std::unique_ptr(new uint8_t[alignedSize * sizeof(ChrTy)]); offsets = std::unique_ptr(new size_t[(size_t)1 << depth]); if (alignedSize / superBlockBitSize > 0) superBlocks = std::unique_ptr(new size_t[alignedSize / superBlockBitSize * depth]); - std::vector buf(size * 2); + std::vector buf(size * 2); size_t oneCnt = writeMSBs(&bits[0], 0, data, size); fillSuperBlocks(&superBlocks[0], &bits[0], alignedSize / 8); splitByMSB(&buf[0], &buf[oneCnt], data, size); @@ -274,7 +279,7 @@ namespace sais } } - size_t rank(char16_t c, size_t l) const + size_t rank(ChrTy c, size_t l) const { const size_t alignedSize = (length + bitAlignmentSize - 1) & ~(bitAlignmentSize - 1); size_t offsetIdx = 0; @@ -286,7 +291,7 @@ namespace sais const size_t offset = offsets[offsetIdx]; const size_t oneCnt = countOne(curBits, curSuperBlocks, l + offset) - countOne(curBits, curSuperBlocks, offset); - const size_t msb = (c & (1 << 15)); + const size_t msb = (c & (1 << (depth - 1))); l = msb ? oneCnt : (l - oneCnt); offsetIdx += msb ? 0 : ((size_t)1 << (depth - i - 1)); c <<= 1; diff --git a/vsproj/kiwi_shared_library.vcxproj b/vsproj/kiwi_shared_library.vcxproj index 35a21b68..aca1d9e3 100644 --- a/vsproj/kiwi_shared_library.vcxproj +++ b/vsproj/kiwi_shared_library.vcxproj @@ -119,6 +119,7 @@ +