Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster repetition penalty sampling operation #199

Merged
merged 4 commits into the base branch
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions chatglm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include <string>
#include <sys/stat.h>
#include <thread>
#include <unordered_set>

#ifdef __has_include
#if __has_include(<unistd.h>)
Expand Down Expand Up @@ -589,19 +588,19 @@ int BaseModelForCausalLM::generate_next_token(const std::vector<int> &input_ids,
// Applies a repetition penalty in place over the logits range [first, last):
// every token id that appears in `input_ids` has its logit divided by
// `penalty` when positive, or multiplied by `penalty` when negative, so that
// previously generated tokens become less likely. The penalty is applied at
// most once per distinct id; a flat occurrence bitmap replaces the previous
// std::unordered_set to avoid per-id hashing overhead.
//
// first/last:  [begin, end) pointers over the vocabulary logits.
// input_ids:   token ids generated so far; each must lie in [0, last - first).
// penalty:     repetition penalty factor; must be strictly positive.
void BaseModelForCausalLM::sampling_repetition_penalty(float *first, float *last, const std::vector<int> &input_ids,
                                                       float penalty) {
    CHATGLM_CHECK(penalty > 0) << "penalty must be a positive float, but got " << penalty;
    const float inv_penalty = 1.f / penalty;
    const int vocab_size = last - first;
    std::vector<bool> occurrence(vocab_size, false);
    for (const int id : input_ids) {
        // Validate before indexing: an out-of-range id would otherwise read
        // and write out of bounds (the old unordered_set version checked this).
        CHATGLM_CHECK(0 <= id && id < vocab_size) << "invalid input id " << id;
        if (!occurrence[id]) {
            first[id] *= (first[id] > 0) ? inv_penalty : penalty;
        }
        occurrence[id] = true;
    }
}

void BaseModelForCausalLM::sampling_temperature(float *first, float *last, float temp) {
float inv_temp = 1.f / temp;
const float inv_temp = 1.f / temp;
for (float *it = first; it != last; it++) {
*it *= inv_temp;
}
Expand All @@ -616,12 +615,12 @@ TokenIdScore *BaseModelForCausalLM::sampling_top_p(TokenIdScore *first, TokenIdS
sampling_softmax_inplace(first, last);

while (first + 1 < last) {
float pivot_score = (last - 1)->score; // use mid score?
const float pivot_score = (last - 1)->score; // use mid score?
TokenIdScore *mid =
std::partition(first, last - 1, [pivot_score](const TokenIdScore &x) { return x.score > pivot_score; });
std::swap(*mid, *(last - 1));

float prefix_sum =
const float prefix_sum =
std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore &x) { return sum + x.score; });
if (prefix_sum >= top_p) {
last = mid;
Expand Down
2 changes: 1 addition & 1 deletion chatglm_cpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import chatglm_cpp._C as _C
from chatglm_cpp._C import ChatMessage

__version__ = "0.3.0"
__version__ = "0.3.1.dev"


@dataclass
Expand Down
24 changes: 24 additions & 0 deletions chatglm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ static inline char *read_tensor_data(char *ptr, ggml_tensor *tensor) {

// Uniform random float in [0, 1], driven by rand().
static inline float random() { return static_cast<float>(rand()) / static_cast<float>(RAND_MAX); }

// Uniform random float in [lo, hi], linearly rescaled from random().
static inline float random(float lo, float hi) { return lo + (hi - lo) * random(); }

static inline void random_fill(ggml_tensor *tensor) {
std::vector<float> values(ggml_nelements(tensor));
for (float &v : values) {
Expand Down Expand Up @@ -115,6 +117,28 @@ TEST(Sampling, RepetitionPenalty) {
}
}

// Benchmark for sampling_repetition_penalty: large vocabulary, long prompt in
// which every token id is distinct (the worst case for de-duplication).
// Disabled by default; run with --gtest_also_run_disabled_tests to time it.
TEST(DISABLED_Sampling, BenchmarkRepetitionPenalty) {
    const float penalty = 1.2;
    constexpr size_t vocab_size = 128000;
    constexpr int seq_len = 32000;

    // Fill the logits with uniform random values in [-1, 1].
    std::vector<float> logits(vocab_size);
    for (size_t i = 0; i < logits.size(); i++) {
        logits[i] = random(-1, 1);
    }

    // input_ids = 0, 1, ..., seq_len - 1: all ids distinct.
    std::vector<int> input_ids(seq_len);
    int next_id = 0;
    for (int &id : input_ids) {
        id = next_id++;
    }

    const auto run_once = [&] {
        float *begin = logits.data();
        float *end = begin + logits.size();
        BaseModelForCausalLM::sampling_repetition_penalty(begin, end, input_ids, penalty);
    };
    const auto elapsed_ms = timeit(run_once, 2, 100);
    std::cout << "[" << ::testing::UnitTest::GetInstance()->current_test_info()->name() << "] " << elapsed_ms
              << " ms\n";
}

TEST(Sampling, Temperature) {
constexpr float temp = 0.7;
std::vector<float> logits(64);
Expand Down