Skip to content

Commit

Permalink
Add PrefixCount to clustering feature
Browse files Browse the repository at this point in the history
  • Loading branch information
bab2min committed Aug 30, 2024
1 parent d950587 commit 90b121f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 7 deletions.
24 changes: 22 additions & 2 deletions include/kiwi/SubstringExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,42 @@ namespace kiwi
UnorderedMap<uint32_t, uint32_t> token2id;
Vector<uint32_t> id2Token;
Vector<uint16_t> buf;
Vector<size_t> tokenClusters;
Vector<size_t> tokenCnts;
std::shared_ptr<void> threadPool;

template<class It>
void _addArray(It first, It last);

Vector<std::pair<uint32_t, float>> computeClusterScore() const;

public:
PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers);
PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers,
const std::vector<std::vector<size_t>>& clusters = {}
);
void addArray(const uint16_t* first, const uint16_t* last);
void addArray(const uint32_t* first, const uint32_t* last);
void addArray(const uint64_t* first, const uint64_t* last);
utils::FrozenTrie<uint32_t, uint32_t> count() const;
std::unique_ptr<lm::KnLangModelBase> buildLM(
size_t lastMinCf,
const std::vector<size_t>& minCfByOrder,
size_t bosTokenId,
size_t eosTokenId,
size_t unkTokenId,
ArchType archType = ArchType::none
) const;
};

class ClusterData
{
const std::pair<uint32_t, float>* clusterScores = nullptr;
size_t clusterSize = 0;
public:
ClusterData();
ClusterData(const void* _ptr, size_t _size);

size_t size() const;
size_t cluster(size_t i) const;
float score(size_t i) const;
};
}
121 changes: 116 additions & 5 deletions src/SubstringExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,62 @@ namespace kiwi
template<class IntTy>
using PrefixTrieNode = utils::TrieNodeEx<IntTy, uint32_t, utils::ConstAccess<map<IntTy, int32_t>>>;

PrefixCounter::PrefixCounter(size_t _prefixSize, size_t _minCf, size_t _numWorkers)
PrefixCounter::PrefixCounter(
size_t _prefixSize,
size_t _minCf,
size_t _numWorkers,
const std::vector<std::vector<size_t>>& clusters
)
: prefixSize(_prefixSize), minCf(_minCf), id2Token(2), buf(1)
{
if (_numWorkers == 0) _numWorkers = min(thread::hardware_concurrency(), 8u);
if (_numWorkers == (size_t)-1) _numWorkers = min(thread::hardware_concurrency(), 8u);
if (_numWorkers > 1)
{
threadPool = make_unique<mp::ThreadPool>(_numWorkers);
}

if (clusters.empty()) return;

unordered_set<size_t> alreadyAllocated;
for (auto cs : clusters)
{
if (cs.empty()) continue;
sort(cs.begin(), cs.end());
const auto cid = cs[0];
for (auto c : cs)
{
if (alreadyAllocated.find(c) != alreadyAllocated.end())
{
throw runtime_error("Duplicated cluster id");
}
alreadyAllocated.insert(c);

if (c >= tokenClusters.size())
{
const auto e = c + 1;
tokenClusters.resize(e, -1);
tokenCnts.resize(e);
}
tokenClusters[c] = cid;
}
}
}

template<class It>
void PrefixCounter::_addArray(It first, It last)
{
for (; first != last; ++first)
{
const auto token = *first;
auto token = *first;
if (token < tokenClusters.size())
{
if (tokenClusters[token] != (size_t)-1)
{
tokenCnts[token]++;
token = tokenClusters[token];
}
}

auto it = token2id.find(token);
if (it == token2id.end())
{
Expand Down Expand Up @@ -295,18 +335,89 @@ namespace kiwi
return utils::freezeTrie(move(trie), ArchType::balanced);
}

Vector<pair<uint32_t, float>> PrefixCounter::computeClusterScore() const
{
UnorderedMap<size_t, size_t> clusterCnts;
for (size_t i = 0; i < tokenClusters.size(); ++i)
{
if (tokenClusters[i] != (size_t)-1)
{
clusterCnts[tokenClusters[i]] += tokenCnts[i];
}
}

Vector<pair<uint32_t, float>> ret;
ret.reserve(tokenClusters.size());
for (size_t i = 0; i < tokenClusters.size(); ++i)
{
if (tokenClusters[i] == (size_t)-1)
{
ret.emplace_back(-1, 0);
}
else
{
ret.emplace_back(tokenClusters[i], (float)log((double)tokenCnts[i] / clusterCnts[tokenClusters[i]]));
}
}
return ret;
}

unique_ptr<lm::KnLangModelBase> PrefixCounter::buildLM(
size_t lastMinCf,
const std::vector<size_t>& minCfByOrder,
size_t bosTokenId,
size_t eosTokenId,
size_t unkTokenId,
ArchType archType) const
{
Vector<char> extraBuf;
if (!tokenClusters.empty())
{
auto clusterScore = computeClusterScore();
extraBuf.resize(clusterScore.size() * sizeof(uint64_t) + sizeof(uint64_t) * 2);
memcpy(extraBuf.data(), "UNIGRAM\0", sizeof(uint64_t));
uint64_t size = clusterScore.size();
memcpy(extraBuf.data() + sizeof(uint64_t), &size, sizeof(uint64_t));
memcpy(extraBuf.data() + sizeof(uint64_t) * 2, clusterScore.data(), clusterScore.size() * sizeof(uint64_t));
}

utils::MemoryOwner mem;
{
auto trie = count();
mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCf, lastMinCf, unkTokenId, bosTokenId, eosTokenId, 1e-5f, 0, false);
mem = lm::KnLangModelBase::build(move(trie), prefixSize, minCfByOrder, unkTokenId, bosTokenId, eosTokenId,
1e-5f, 0, false, nullptr, (const Vector<int>*)nullptr,
extraBuf.data(), extraBuf.size());
}
return lm::KnLangModelBase::create(move(mem), archType);
}

ClusterData::ClusterData() = default;

ClusterData::ClusterData(const void* _ptr, size_t _size)
{
if (!_ptr || !_size) return;
if (_size < sizeof(uint64_t) * 2) throw runtime_error("Invalid cluster data");
auto ptr = (const uint64_t*)_ptr;
if (memcmp(ptr, "UNIGRAM\0", sizeof(uint64_t)) != 0) throw runtime_error("Invalid cluster data");
const auto size = ptr[1];
if (_size < sizeof(uint64_t) * 2 + size * sizeof(uint64_t)) throw runtime_error("Invalid cluster data");
clusterScores = (const pair<uint32_t, float>*)(ptr + 2);
clusterSize = size;
}

size_t ClusterData::size() const
{
return clusterSize;
}

size_t ClusterData::cluster(size_t i) const
{
if (i >= clusterSize || clusterScores[i].first == (uint32_t)-1) return i;
return clusterScores[i].first;
}

float ClusterData::score(size_t i) const
{
if (i >= clusterSize || clusterScores[i].first == (uint32_t)-1) return 0;
return clusterScores[i].second;
}
}

0 comments on commit 90b121f

Please sign in to comment.