diff --git a/chatglm.cpp b/chatglm.cpp
index 3cac2a4..daaa26a 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -14,7 +14,6 @@
 #include
 #include
 #include
-#include <unordered_set>
 
 #ifdef __has_include
 #if __has_include()
@@ -589,19 +588,19 @@ int BaseModelForCausalLM::generate_next_token(const std::vector<int> &input_ids,
 void BaseModelForCausalLM::sampling_repetition_penalty(float *first, float *last, const std::vector<int> &input_ids,
                                                         float penalty) {
     CHATGLM_CHECK(penalty > 0) << "penalty must be a positive float, but got " << penalty;
-    std::unordered_set<int> unique_input_ids(input_ids.begin(), input_ids.end());
-    for (int id : unique_input_ids) {
-        CHATGLM_CHECK(first <= first + id && first + id < last) << "invalid input id " << id;
-        if (first[id] > 0) {
-            first[id] /= penalty;
-        } else {
-            first[id] *= penalty;
+    const float inv_penalty = 1.f / penalty;
+    const int vocab_size = last - first;
+    std::vector<bool> occurrence(vocab_size, false);
+    for (const int id : input_ids) {
+        if (!occurrence[id]) {
+            first[id] *= (first[id] > 0) ? inv_penalty : penalty;
         }
+        occurrence[id] = true;
     }
 }
 
 void BaseModelForCausalLM::sampling_temperature(float *first, float *last, float temp) {
-    float inv_temp = 1.f / temp;
+    const float inv_temp = 1.f / temp;
     for (float *it = first; it != last; it++) {
         *it *= inv_temp;
     }
@@ -616,12 +615,12 @@ TokenIdScore *BaseModelForCausalLM::sampling_top_p(TokenIdScore *first, TokenIdS
     sampling_softmax_inplace(first, last);
 
     while (first + 1 < last) {
-        float pivot_score = (last - 1)->score; // use mid score?
+        const float pivot_score = (last - 1)->score; // use mid score?
         TokenIdScore *mid =
             std::partition(first, last - 1, [pivot_score](const TokenIdScore &x) { return x.score > pivot_score; });
         std::swap(*mid, *(last - 1));
 
-        float prefix_sum =
+        const float prefix_sum =
             std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore &x) { return sum + x.score; });
         if (prefix_sum >= top_p) {
             last = mid;
diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py
index d3938bd..cdb79cd 100644
--- a/chatglm_cpp/__init__.py
+++ b/chatglm_cpp/__init__.py
@@ -6,7 +6,7 @@
 import chatglm_cpp._C as _C
 from chatglm_cpp._C import ChatMessage
 
-__version__ = "0.3.0"
+__version__ = "0.3.1.dev"
 
 
 @dataclass
diff --git a/chatglm_test.cpp b/chatglm_test.cpp
index 03b8cd5..6295a3e 100644
--- a/chatglm_test.cpp
+++ b/chatglm_test.cpp
@@ -36,6 +36,8 @@ static inline char *read_tensor_data(char *ptr, ggml_tensor *tensor) {
 
 static inline float random() { return rand() / (float)RAND_MAX; }
 
+static inline float random(float lo, float hi) { return lo + random() * (hi - lo); }
+
 static inline void random_fill(ggml_tensor *tensor) {
     std::vector<float> values(ggml_nelements(tensor));
     for (float &v : values) {
@@ -115,6 +117,28 @@ TEST(Sampling, RepetitionPenalty) {
     }
 }
 
+TEST(DISABLED_Sampling, BenchmarkRepetitionPenalty) {
+    const float penalty = 1.2;
+    constexpr size_t vocab_size = 128000;
+    constexpr int seq_len = 32000;
+    std::vector<float> logits(vocab_size);
+    for (auto &x : logits) {
+        x = random(-1, 1);
+    }
+    std::vector<int> input_ids(seq_len);
+    for (size_t i = 0; i < input_ids.size(); i++) {
+        input_ids[i] = i;
+    }
+
+    auto fn = [&logits, &input_ids, penalty] {
+        BaseModelForCausalLM::sampling_repetition_penalty(logits.data(), logits.data() + logits.size(), input_ids,
+                                                           penalty);
+    };
+    auto elapsed_ms = timeit(fn, 2, 100);
+    std::cout << "[" << ::testing::UnitTest::GetInstance()->current_test_info()->name() << "] " << elapsed_ms
+              << " ms\n";
+}
+
 TEST(Sampling, Temperature) {
     constexpr float temp = 0.7;
     std::vector<float> logits(64);
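
For reference, here is a minimal standalone sketch (not part of the patch) of the penalty scheme the rewritten loop implements: every token id that occurs in input_ids has its logit scaled exactly once, with positive logits divided by the penalty and negative logits multiplied by it, so previously generated tokens become less likely to be sampled again. The helper name apply_repetition_penalty and the toy values are illustrative only.

// Illustrative sketch only -- mirrors the penalized-logits scheme in the patch above.
#include <cstdio>
#include <vector>

// Scale each distinct token id's logit once: positive logits shrink (divide by penalty),
// negative logits move further from zero (multiply by penalty).
static void apply_repetition_penalty(std::vector<float> &logits, const std::vector<int> &input_ids, float penalty) {
    const float inv_penalty = 1.f / penalty;
    std::vector<bool> seen(logits.size(), false);
    for (const int id : input_ids) {
        if (!seen[id]) {
            logits[id] *= (logits[id] > 0) ? inv_penalty : penalty;
            seen[id] = true;
        }
    }
}

int main() {
    std::vector<float> logits = {2.f, -1.f, 0.5f, 3.f};
    std::vector<int> input_ids = {0, 1, 1, 3}; // token 1 repeats but is only penalized once
    apply_repetition_penalty(logits, input_ids, 1.2f);
    for (const float x : logits) {
        std::printf("%.4f ", x); // expected: 1.6667 -1.2000 0.5000 2.5000
    }
    std::printf("\n");
    return 0;
}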