Skip to content

Commit

Permalink
move _TemperatureSoftmaxto sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
huangzhengxiang committed Oct 31, 2024
1 parent 69e1187 commit 17329be
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion transformers/llm/engine/include/sampler/sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace MNN {
namespace Transformer {

MNN_PUBLIC VARP _TempratureSoftmax(VARP logits, float temperature, int axis = -1);
MNN_PUBLIC MNN::Express::VARP _TempratureSoftmax(MNN::Express::VARP logits, float temperature, int axis = -1);

class Llm;

Expand Down
10 changes: 5 additions & 5 deletions transformers/llm/engine/src/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
namespace MNN{
namespace Transformer{

VARP _TempratureSoftmax(VARP logits, float temperature, int axis) {
return _Softmax(logits * _Scalar<float>(1.0f / temperature), axis);
MNN::Express::VARP _TempratureSoftmax(MNN::Express::VARP logits, float temperature, int axis) {
return MNN::Express::_Softmax(logits * MNN::Express::_Scalar<float>(1.0f / temperature), axis);
}

/* ----------Sampler's members---------- */
Expand Down Expand Up @@ -43,7 +43,7 @@ int Sampler::randomSelect(MNN::Express::VARP probs) {
}

int Sampler::reSoftmaxSelect(struct SubsetLogits subset, float temperature) {
int token_index_id = randomSelect(MNN::Express::_TempratureSoftmax(subset.logits, temperature));
int token_index_id = randomSelect(_TempratureSoftmax(subset.logits, temperature));
return ((subset.is_subset) ? subset.index[token_index_id] : token_index_id);
}

Expand Down Expand Up @@ -257,7 +257,7 @@ struct SubsetLogits LocalSampler::topK(struct SubsetLogits superset) {
}

int LocalSampler::packSoftmax(MNN::Express::VARP logits, std::vector<IndexScore>& index_scores, float temperature) {
auto prob_varp = MNN::Express::_TempratureSoftmax(logits, temperature);
auto prob_varp = _TempratureSoftmax(logits, temperature);
auto probs = (float*)(prob_varp->readMap<float>());
auto size = prob_varp->getInfo()->size;
index_scores.resize(size);
Expand Down Expand Up @@ -358,7 +358,7 @@ struct SubsetLogits LocalSampler::tfs(struct SubsetLogits superset) {

struct SubsetLogits LocalSampler::typical(struct SubsetLogits superset) {
float p = mConfig.typical, temperature = mConfig.temperature;
auto prob_varp = MNN::Express::_TempratureSoftmax(superset.logits, temperature);
auto prob_varp = _TempratureSoftmax(superset.logits, temperature);
auto probs = (float*)(prob_varp->readMap<float>());
auto size = prob_varp->getInfo()->size;
std::vector<IndexScore> index_scores;
Expand Down

0 comments on commit 17329be

Please sign in to comment.