diff --git a/.gitignore b/.gitignore
index a6391ada9..f46989dfb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -360,3 +360,10 @@ pymnn_build/
 # mnncompress generated
 MNN_compression_pb2.py
+
+# releases
+release
+example
+debug
+hack
+tmp
\ No newline at end of file
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7b940476e..4676d62c7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -20,7 +20,9 @@ endif()
 project(MNN VERSION ${MNN_VERSION} LANGUAGES C CXX ASM)
 # complier options
 set(CMAKE_C_STANDARD 99)
-set(CMAKE_CXX_STANDARD 11)
+IF (NOT (CMAKE_CXX_STANDARD EQUAL 17))
+  set(CMAKE_CXX_STANDARD 11)
+ENDIF()
 set(CMAKE_MODULE_PATH
   ${CMAKE_MODULE_PATH}
   "${CMAKE_CURRENT_LIST_DIR}/cmake"
@@ -274,7 +276,7 @@ if(CMAKE_SYSTEM_NAME MATCHES "^Android")
 endif()
 option(MNN_USE_CPP11 "Enable MNN use c++11" ON)
 if (NOT MSVC)
-    if(MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE)
+    if((MNN_CUDA AND MNN_SUPPORT_TRANSFORMER_FUSE) OR (CMAKE_CXX_STANDARD EQUAL 17))
         set(CMAKE_CXX_STANDARD 17)
         set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=gnu99")
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
diff --git a/build_cmd.md b/build_cmd.md
new file mode 100644
index 000000000..dbac57e20
--- /dev/null
+++ b/build_cmd.md
@@ -0,0 +1,30 @@
+~~~
+mkdir build
+cd build
+cmake .. -DCMAKE_CXX_STANDARD=17 -DMNN_USE_SYSTEM_LIB=OFF -DMNN_BUILD_SHARED_LIBS=ON -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_TEST=ON -DMNN_BUILD_QUANTOOLS=ON -DMNN_EVALUATION=ON -DMNN_BUILD_CONVERTER=ON -DMNN_PORTABLE_BUILD=ON -DTFMODEL_OPTIMIZE=ON -DMNN_BUILD_LLM=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON -DMNN_LOW_MEMORY=ON -DMNN_AVX512=ON
+make -j20
+~~~
+
+## Converting other models to MNN
+### TensorFlow to MNN
+```bash
+./MNNConvert -f TF --modelFile XXX.pb --MNNModel XXX.mnn --bizCode biz
+```
+Note: the `*.pb` must be a frozen model; a saved_model cannot be used.
+
+### TensorFlow Lite to MNN
+```bash
+./MNNConvert -f TFLITE --modelFile XXX.tflite --MNNModel XXX.mnn --bizCode biz
+```
+
+e.g.
+~~~
+cd build
+
+./MNNConvert -f TF --modelFile ../example/model-mobilenet_v1_075.pb --MNNModel ../example/pose_model.mnn --bizCode biz --optimizeLevel 1 --optimizePrefer 2
+
+python3 ../tools/script/testMNNFromTf.py ../example/model-mobilenet_v1_075.pb
+
+# check supported Ops
+./MNNConvert -f TF --OP
+~~~
\ No newline at end of file
diff --git a/docs/compile/engine.md b/docs/compile/engine.md
index eb8eb6503..adb4b8fb6 100644
--- a/docs/compile/engine.md
+++ b/docs/compile/engine.md
@@ -1,5 +1,6 @@
 # Building the main library
 The default build artifacts are: `libMNN.so`, `express/libMNN_Express.so`
+or `libMNN.a`
 ## Linux/MacOS
 - Requirements
   - cmake >= 3.10
diff --git a/express/NeuralNetWorkOp.cpp b/express/NeuralNetWorkOp.cpp
index 18d58c3ec..589e4df91 100644
--- a/express/NeuralNetWorkOp.cpp
+++ b/express/NeuralNetWorkOp.cpp
@@ -476,6 +476,9 @@ VARP _Softmax(VARP logits, int axis) {
     softmax->main.AsAxis()->axis = axis;
     return (Variable::create(Expr::create(softmax.get(), {logits})));
 }
+VARP _TempratureSoftmax(VARP logits, float temperature, int axis) {
+    return _Softmax(logits * _Scalar(1.0f / temperature), axis);
+}
 /*Computes softplus: log(exp(features) + 1).
 Args:
 features: A variable. Must be Halide_Type_Float.
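Note: the new `_TempratureSoftmax` helper above just rescales the logits by `1 / temperature` before the ordinary `_Softmax`, so `T < 1` sharpens the distribution and `T > 1` flattens it. The sketch below is illustrative only and not part of the patch; it assumes the MNN Express API used elsewhere in this diff (`_Input`, `writeMap`/`readMap`).

```cpp
// Illustrative sketch (not part of the patch): behaviour of the new
// _TempratureSoftmax helper on a tiny 4-entry logits vector.
#include <MNN/expr/ExprCreator.hpp>
#include <cstdio>

using namespace MNN::Express;

int main() {
    VARP logits = _Input({4}, NCHW);               // fake "vocabulary" of 4 tokens
    float* in = logits->writeMap<float>();
    in[0] = 2.0f; in[1] = 1.0f; in[2] = 0.5f; in[3] = -1.0f;

    VARP sharp = _TempratureSoftmax(logits, 0.5f); // T < 1: sharper than plain _Softmax
    VARP flat  = _TempratureSoftmax(logits, 2.0f); // T > 1: flatter than plain _Softmax

    const float* s = sharp->readMap<float>();
    const float* f = flat->readMap<float>();
    for (int i = 0; i < 4; ++i) {
        printf("token %d: T=0.5 -> %.3f, T=2.0 -> %.3f\n", i, s[i], f[i]);
    }
    return 0;
}
```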
diff --git a/include/MNN/expr/NeuralNetWorkOp.hpp b/include/MNN/expr/NeuralNetWorkOp.hpp index 2a6bdfd61..181f029ec 100644 --- a/include/MNN/expr/NeuralNetWorkOp.hpp +++ b/include/MNN/expr/NeuralNetWorkOp.hpp @@ -58,6 +58,7 @@ MNN_PUBLIC VARP _Relu(VARP x, float slope = 0.0f); MNN_PUBLIC VARP _Relu6(VARP x, float minValue = 0.0f, float maxValue = 6.0f); MNN_PUBLIC VARP _PRelu(VARP x, std::vector &&slopes); MNN_PUBLIC VARP _Softmax(VARP logits, int axis = -1); +MNN_PUBLIC VARP _TempratureSoftmax(VARP logits, float temperature, int axis = -1); MNN_PUBLIC VARP _Softplus(VARP features); MNN_PUBLIC VARP _Softsign(VARP features); MNN_PUBLIC std::vector _Split(VARP value, INTS size_splits, int axis = 0); diff --git a/test/core/BufferAllocatorTest.cpp b/test/core/BufferAllocatorTest.cpp index 423b5e011..832f0bacc 100644 --- a/test/core/BufferAllocatorTest.cpp +++ b/test/core/BufferAllocatorTest.cpp @@ -38,13 +38,20 @@ class BufferAllocatorTest : public MNNTestCase { for (int i = 0; i < seqs.size(); i++) { int code = seqs[i]; if (code > 0) { + // printf("alloc: %d\n", code); auto res = allocator.alloc(code); allocs.push_back(res); } else { + // free the indexed chunk. + // printf("free: %d, idx: %d\n", code, abs(code) - 1); allocator.free(allocs[abs(code) - 1]); } } - size_t totalSize = allocator.compute(); + allocator.compute(); + // for (auto chunk : allocs){ + // printf("chunk: %p\n", chunk.ptr()); + // } + size_t totalSize = allocator.totalSize(); printf("StaticAllocator total size : %lu B, %f M\n", totalSize, totalSize / 1024.f / 1024.f); } virtual bool run(int precision) { diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index 5337afc3c..26b53e806 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -22,9 +22,11 @@ #include #include #include +#include "sampler/sampler.hpp" namespace MNN { namespace Transformer { +class Sampler; class Tokenizer; class Pipeline; class LlmConfig; @@ -53,20 +55,21 @@ class MNN_PUBLIC Llm { Llm(std::shared_ptr config) : config_(config) {} virtual ~Llm(); static Llm* createLLM(const std::string& config_path); - void chat(); + void chat(std::ostream* time_log=nullptr); void reset(); void trace(bool start); virtual void load(); - MNN::Express::VARP forward(const std::vector& input_ids); - int sample(MNN::Express::VARP logits, const std::vector& pre_ids); + MNN::Express::VARP forward(const std::vector& input_ids, bool prefill=true); + std::string decode(int id); + bool is_stop(int token_id); std::string apply_prompt_template(const std::string& user_content) const; std::string apply_chat_template(const std::vector& chat_prompts) const; std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr); - std::string response(const std::vector& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr); + std::string response(std::vector& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr); void generate_init(); std::string generate(const std::vector& input_ids, std::ostream* os, const char* end_with); - std::vector generate(const std::vector& input_ids, int max_new_tokens = -1); void print_speed(); + void print_speed(std::ostream* os); // config function std::string dump_config(); bool set_config(const std::string& content); @@ -85,9 +88,11 @@ class MNN_PUBLIC Llm { // time int64_t prefill_us_ = 0; int64_t decode_us_ = 0; + TimePerformance time_perf_; bool is_single_ = 
true; bool attention_fused_ = true; protected: + std::shared_ptr sampler_; std::shared_ptr config_; std::shared_ptr tokenizer_; std::vector key_value_shape_ = {}; @@ -97,9 +102,8 @@ class MNN_PUBLIC Llm { std::vector> modules_; std::vector> prefill_modules_, decode_modules_, current_modules_; const MNN::Express::Module* base_module_ = nullptr; + void initSampler(); void init_runtime(); - std::string decode(int id); - bool is_stop(int token_id); virtual std::vector tokenizer(const std::string& query); virtual MNN::Express::VARP embedding(const std::vector& input_ids); virtual MNN::Express::VARP gen_attention_mask(int seq_len); diff --git a/transformers/llm/engine/include/sampler/sampler.hpp b/transformers/llm/engine/include/sampler/sampler.hpp new file mode 100644 index 000000000..b28deb5e4 --- /dev/null +++ b/transformers/llm/engine/include/sampler/sampler.hpp @@ -0,0 +1,133 @@ +#ifndef SAMPLER_hpp +#define SAMPLER_hpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace MNN { +namespace Transformer { + +#define MICRO_TO_MILLI 1e-3f +#define MILLI_TO_MICRO 1000 +#define MICRO_TO_SEC 1e-6f +#define SEC_TO_MICRO 1000000 + +#define MEGA_TO_GIGA (1/1024.f) +#define GIGA_TO_MEGA 1024.f +#define KILLO_TO_GIGA (1/1024.f/1024.f) +#define GIGA_TO_KILLO (1024.f*1024.f) +#define KILLO_TO_MEGA (1/1024.f) +#define MEGA_TO_KILLO 1024.f + +struct PrefillTimePerformance { + size_t prefill_prev_token_ = 0; + size_t prefill_token_ = 0; + size_t prefill_us_ = 0; +}; + +struct DecodeTimePerformance { + size_t decode_prev_token_ = 0; + size_t decode_us_ = 0; +}; + +struct TimePerformance { + std::vector prefill_record_; + std::vector decode_record_; +}; + +void mergePerformance(struct TimePerformance* dst, struct TimePerformance* src); +void clearPerformance(struct TimePerformance* perf); + +class Llm; + +class MNN_PUBLIC Sampler { +protected: + Llm* mLlm; + std::vector> mCandidates; + std::vector mCommonPrefix; + int mMaxNewTokens; + int getGenLength(int candidate, int output_len) const { + return mCandidates[candidate].size() - (mCommonPrefix.size() - output_len); + } +public: + virtual std::string sample(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) = 0; + // prepare for another round of sampling + // in the future, only reset its own. + virtual void reset() {} +}; + +class MNN_PUBLIC LocalSampler : public Sampler { +public: + struct LocalSamplerConfig { + std::string type = "temperature"; + float temperature = 0.8; + int topK = 40; + float topP = 0.9; + float minP = 0.05; + float tfsZ = 1.0; + float typical = 0.95; + float penalty = 1.1; + int ngram = 8; + float ngram_factor = 1.0; // panalize repeated ngram with a multiplied ngram_factor. 
+ float max_penalty = 10.; + }; +private: + struct LocalSamplerConfig mConfig; + int randomSelect(float* probs, size_t size); + int argmax(MNN::Express::VARP logits); + int temperature(MNN::Express::VARP logits, float temperature = 1.0); + struct IndexProb { + int index; + float prob; + }; + struct IndexProbCmpLess{ + bool operator()(IndexProb a, IndexProb b) { + return a.prob < b.prob; + } + }; + struct IndexProbCmpGreater{ + bool operator()(IndexProb a, IndexProb b) { + return a.prob > b.prob; + } + }; + int reSoftmaxSelect(std::vector index, std::vector scores, float temperature); + void topK(MNN::Express::VARP logits, int K, std::vector& topKindex, std::vector& topKprob); + int topK(MNN::Express::VARP logits, int K = 40, float temperature = 1.0); + void topP(MNN::Express::VARP logits, float p, float temperature, std::vector& topPindex, std::vector& topPprob); + int topP(MNN::Express::VARP logits, float p = 0.9, float temperature = 1.0); + void minP(MNN::Express::VARP logits, float p, float temperature, std::vector& minPindex, std::vector& minPprob); + int minP(MNN::Express::VARP logits, float p = 0.1, float temperature = 1.0); + void tfs(MNN::Express::VARP logits, float z, float temperature, std::vector& index, std::vector& tfsprob); + int tfs(MNN::Express::VARP logits, float z = 1.0, float temperature = 1.0); + void typical(MNN::Express::VARP logits, float p, float temperature, std::vector& index, std::vector& minPprob); + int typical(MNN::Express::VARP logits, float p = 1.0, float temperature = 1.0); + void penalty(MNN::Express::VARP logits, float penalty = 1.0, bool penalizeNgram = false, int ngram = 8, float ngram_factor = 1.0); + int penalty(MNN::Express::VARP logits, float penalty = 1.0, int ngram = 8, float ngram_factor = 1.0, float temperature = 1.0); + // int mixed(MNN::Express::VARP logits); + std::string handleToken(int token, std::ostream* os = &std::cout, const char* end_with = nullptr); +public: + LocalSampler(Llm* llm, int max_new_tokens, struct LocalSamplerConfig config); + int algorithm(MNN::Express::VARP logits); + virtual std::string sample(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, struct TimePerformance* time_perf = nullptr) override; + virtual void reset() override; +}; + + +} // Transformer +} // MNN + + +#endif // SAMPLER_hpp \ No newline at end of file diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 991c6c7ef..0b3ad6d30 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -12,10 +12,12 @@ #include #include +#include #include #include #include "cpp/ExprDebug.hpp" #include "llm/llm.hpp" +#include "sampler/sampler.hpp" #include "tokenizer.hpp" #include "llmconfig.hpp" // 0: no debug, 1: test op time, 2: print tensor info @@ -84,6 +86,64 @@ bool Llm::set_config(const std::string& content) { return config_->config_.merge(content.c_str()); } +void Llm::initSampler() { + std::string sampler_type = config_->sampler_type(); + std::cout << "Selected Sampler: " << sampler_type << std::endl; + // LocalSampler + LocalSampler::LocalSamplerConfig local_sampler_config; + local_sampler_config.type = sampler_type; + if (sampler_type == "greedy") { + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), local_sampler_config)); + return; + } + if (sampler_type == "temperature") { + local_sampler_config.temperature = config_->temperature(); + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), 
local_sampler_config)); + return; + } + if (sampler_type == "topK") { + local_sampler_config.temperature = config_->temperature(); + local_sampler_config.topK = config_->topK(); + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), local_sampler_config)); + return; + } + if (sampler_type == "topP") { + local_sampler_config.temperature = config_->temperature(); + local_sampler_config.topP = config_->topP(); + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), local_sampler_config)); + return; + } + if (sampler_type == "minP") { + local_sampler_config.temperature = config_->temperature(); + local_sampler_config.minP = config_->minP(); + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), local_sampler_config)); + return; + } + if (sampler_type == "tfs") { + local_sampler_config.temperature = config_->temperature(); + local_sampler_config.tfsZ = config_->tfsZ(); + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), local_sampler_config)); + return; + } + if (sampler_type == "typical") { + local_sampler_config.temperature = config_->temperature(); + local_sampler_config.typical = config_->typical(); + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), local_sampler_config)); + return; + } + if (sampler_type == "penalty" || sampler_type == "penalize_ngram") { + local_sampler_config.temperature = config_->temperature(); + local_sampler_config.penalty = config_->penalty(); + local_sampler_config.ngram = config_->ngram(); + local_sampler_config.ngram_factor = config_->ngram_factor(); + sampler_ = std::shared_ptr(new LocalSampler(this, config_->max_new_tokens(), local_sampler_config)); + return; + } + // AdvancedSampler + // Not Implemented + MNN_ERROR("Designated Sampler Not Supported!\n"); +} + void Llm::init_runtime() { ScheduleConfig config; BackendConfig cpuBackendConfig; @@ -99,10 +159,13 @@ void Llm::init_runtime() { } else if (config_->memory() == "low") { cpuBackendConfig.memory = BackendConfig::Memory_Low; } + if (config_->precision() == "high") { cpuBackendConfig.precision = BackendConfig::Precision_High; } else if (config_->precision() == "low") { cpuBackendConfig.precision = BackendConfig::Precision_Low; + } else { + cpuBackendConfig.precision = BackendConfig::Precision_Normal; } config.backendConfig = &cpuBackendConfig; ExecutorScope::Current()->setGlobalExecutorConfig(config.type, cpuBackendConfig, config.numThread); @@ -131,6 +194,8 @@ void Llm::init_runtime() { { runtime_manager_->setCache(".tempcache"); } + // init sampler + initSampler(); } void Llm::load() { @@ -248,7 +313,14 @@ void Llm::trace(bool start) { mTracing = start; } -VARP Llm::forward(const std::vector& input_ids) { +VARP Llm::forward(const std::vector& input_ids, bool prefill) { + // set modules, seperate prefill and decode phase + if (prefill){ + current_modules_ = prefill_modules_; + } else + { + current_modules_ = decode_modules_; + } int seq_len = input_ids.size(); auto attention_mask = gen_attention_mask(seq_len); auto position_ids = gen_position_ids(seq_len); @@ -268,6 +340,7 @@ VARP Llm::forward(const std::vector& input_ids) { if (!attention_fused_) { past_key_values_[0] = outputs[1]; } + ExecutorScope::Current()->gc(Executor::FULL); } else { MNN_ERROR("Split models is depercarate\n"); return nullptr; @@ -277,28 +350,6 @@ VARP Llm::forward(const std::vector& input_ids) { return logits; } -int Llm::sample(VARP logits, const std::vector& pre_ids) { - std::unordered_set 
ids_set(pre_ids.begin(), pre_ids.end()); - auto scores = (float*)(logits->readMap()); - auto size = logits->getInfo()->size; - // repetition penalty - const float repetition_penalty = 1.1; - for (auto id : ids_set) { - float score = scores[id]; - scores[id] = score < 0 ? score * repetition_penalty : score / repetition_penalty; - } - // argmax - float max_score = 0; - int token_id = 0; - for (int i = 0; i < size; i++) { - float score = scores[i]; - if (score > max_score) { - max_score = score; - token_id = i; - } - } - return token_id; -} static std::string apply_template(std::string prompt_template, const std::string& content, const std::string& role = "") { if (prompt_template.empty()) return content; @@ -335,35 +386,42 @@ std::string Llm::apply_chat_template(const std::vector& chat_prompts return prompt_result; } -void Llm::chat() { +void Llm::chat(std::ostream* time_log) { std::vector history; history.push_back(std::make_pair("system", "You are a helpful assistant.")); while (true) { std::cout << "\nQ: "; std::string user_str; - std::cin >> user_str; + std::getline(std::cin, user_str); if (user_str == "/exit") { + if (time_log!=nullptr) this->print_speed(time_log); + history.clear(); + reset(); break; } if (user_str == "/reset") { - history.resize(1); + history.clear(); + history.push_back(std::make_pair("system", "You are a helpful assistant.")); + reset(); std::cout << "\nA: reset done." << std::endl; continue; } std::cout << "\nA: " << std::flush; - if (config_->reuse_kv()) { - response(user_str); - } else { - history.emplace_back(std::make_pair("user", user_str)); - auto assistant_str = response(history); - history.emplace_back(std::make_pair("assistant", assistant_str)); - } + history.emplace_back(std::make_pair("user", user_str)); + auto assistant_str = response(history); + if (!config_->reuse_kv()) + history.back().second += assistant_str + "<|im_end|>\n"; + else + history.back().second = "<|im_end|>\n"; std::cout << std::endl; } } void Llm::reset() { + clearPerformance(&time_perf_); history_ids_.clear(); + sampler_->reset(); + gen_seq_len_ = 0; all_seq_len_ = 0; } @@ -386,83 +444,18 @@ void Llm::generate_init() { } } -std::vector Llm::generate(const std::vector& input_ids, int max_new_tokens) { - generate_init(); - std::vector output_ids, all_ids = input_ids; - prompt_len_ = static_cast(input_ids.size()); - if (max_new_tokens < 0) { max_new_tokens = config_->max_new_tokens(); } - // prefill - current_modules_ = prefill_modules_; - auto logits = forward(input_ids); - if (logits.get() == nullptr) { - return {}; - } - int token = sample(logits, all_ids); - output_ids.push_back(token); - all_ids.push_back(token); - // decode - current_modules_ = decode_modules_; - while (gen_seq_len_ < max_new_tokens) { - logits = nullptr; - logits = forward({token}); - if (logits.get() == nullptr) { - return {}; - } - token = sample(logits, all_ids); - if (is_stop(token)) { break; } - output_ids.push_back(token); - all_ids.push_back(token); - } - return output_ids; -} std::string Llm::generate(const std::vector& input_ids, std::ostream* os, const char* end_with) { if (mTracing) { // Skip real forward - current_modules_ = prefill_modules_; - forward(input_ids); - current_modules_ = decode_modules_; - forward({input_ids[0]}); - forward({input_ids[0]}); + forward(input_ids, true); + forward({input_ids[0]}, false); + forward({input_ids[0]}, false); return "Test"; } prompt_len_ = static_cast(input_ids.size()); history_ids_.insert(history_ids_.end(), input_ids.begin(), input_ids.end()); // push to 
history_ids_ - auto st = std::chrono::system_clock::now(); - current_modules_ = prefill_modules_; - auto logits = forward(input_ids); - if (nullptr == logits.get()) { - return ""; - } - int token = sample(logits, history_ids_); - auto et = std::chrono::system_clock::now(); - current_modules_ = decode_modules_; - std::string output_str = decode(token); - prefill_us_ = std::chrono::duration_cast(et - st).count(); - *os << output_str << std::flush; - while (gen_seq_len_ < config_->max_new_tokens()) { - st = std::chrono::system_clock::now(); - history_ids_.push_back(token); - logits = nullptr; - logits = forward({token}); - if (nullptr == logits.get()) { - return ""; - } - if (logits->getInfo()->size == 0) { - return ""; - } - token = sample(logits, history_ids_); - et = std::chrono::system_clock::now(); - decode_us_ += std::chrono::duration_cast(et - st).count(); - if (is_stop(token)) { - *os << end_with << std::flush; - break; - } - auto word = decode(token); - *os << word << std::flush; - output_str += word; - } - ExecutorScope::Current()->gc(Executor::FULL); + std::string output_str = sampler_->sample(input_ids, os, end_with, &time_perf_); #ifdef DUMP_PROFILE_INFO print_speed(); #endif @@ -491,14 +484,16 @@ std::string Llm::response(const std::string& user_content, std::ostream* os, con return generate(input_ids, os, end_with); } -std::string Llm::response(const std::vector& chat_prompts, std::ostream* os, const char* end_with) { +std::string Llm::response(std::vector& chat_prompts, std::ostream* os, const char* end_with) { if (chat_prompts.empty()) { return ""; } generate_init(); if (!end_with) { end_with = "\n"; } auto prompt = apply_chat_template(chat_prompts); - if (config_->reuse_kv() && all_seq_len_ > 0) { - prompt = "<|im_end|>\n" + prompt; - } + chat_prompts.clear(); + chat_prompts.emplace_back(std::make_pair("", prompt)); + // if (config_->reuse_kv() && all_seq_len_ > 0) { + // prompt = "<|im_end|>\n" + prompt; + // } // std::cout << "# prompt : " << prompt << std::endl; auto input_ids = tokenizer_->encode(prompt); // printf("input_ids (%lu): ", input_ids.size()); for (auto id : input_ids) printf("%d, ", id); printf("\n"); @@ -536,21 +531,34 @@ Llm::~Llm() { } void Llm::print_speed() { - auto prefill_s = prefill_us_ * 1e-6; - auto decode_s = decode_us_ * 1e-6; - auto total_s = prefill_s + decode_s; - printf("\n#################################\n"); - printf(" total tokens num = %d\n", prompt_len_ + gen_seq_len_); - printf("prompt tokens num = %d\n", prompt_len_); - printf("output tokens num = %d\n", gen_seq_len_); - printf(" total time = %.2f s\n", total_s); - printf("prefill time = %.2f s\n", prefill_s); - printf(" decode time = %.2f s\n", decode_s); - printf(" total speed = %.2f tok/s\n", (prompt_len_ + gen_seq_len_) / total_s); - printf("prefill speed = %.2f tok/s\n", prompt_len_ / prefill_s); - printf(" decode speed = %.2f tok/s\n", gen_seq_len_ / decode_s); - printf(" chat speed = %.2f tok/s\n", gen_seq_len_ / total_s); - printf("##################################\n"); + // auto prefill_s = prefill_us_ * 1e-6; + // auto decode_s = decode_us_ * 1e-6; + // auto total_s = prefill_s + decode_s; + // printf("\n#################################\n"); + // printf(" total tokens num = %d\n", prompt_len_ + gen_seq_len_); + // printf("prompt tokens num = %d\n", prompt_len_); + // printf("output tokens num = %d\n", gen_seq_len_); + // printf(" total time = %.2f s\n", total_s); + // printf("prefill time = %.2f s\n", prefill_s); + // printf(" decode time = %.2f s\n", decode_s); + // 
printf(" total speed = %.2f tok/s\n", (prompt_len_ + gen_seq_len_) / total_s); + // printf("prefill speed = %.2f tok/s\n", prompt_len_ / prefill_s); + // printf(" decode speed = %.2f tok/s\n", gen_seq_len_ / decode_s); + // printf(" chat speed = %.2f tok/s\n", gen_seq_len_ / total_s); + // printf("##################################\n"); +} + +void Llm::print_speed(std::ostream* os) { + (*os) << "prefill " << time_perf_.prefill_record_.size() << std::endl; + (*os) << "prev_token token speed(token/s)" << std::endl; + for (auto record : time_perf_.prefill_record_) { + (*os) << record.prefill_prev_token_ << " " << record.prefill_token_ << " " << record.prefill_token_/(((float)record.prefill_us_)*MICRO_TO_SEC) << std::endl; + } + (*os) << "decode " << time_perf_.decode_record_.size() << std::endl; + (*os) << "prev_token speed(token/s)" << std::endl; + for (auto record : time_perf_.decode_record_) { + (*os) << record.decode_prev_token_ << " " << 1./(((float)record.decode_us_)*MICRO_TO_SEC) << std::endl; + } } static inline bool needNewVar(VARP var, int axis, int seq_len) { diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index 78bd3bc61..680e57417 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -98,6 +98,13 @@ class rapid_json_wrapper { return buffer.GetString(); } // read value + float value(const char* key, const float& default_value) const { + if (document.HasMember(key)) { + const auto& value = document[key]; + if (value.IsFloat()) return value.GetFloat(); + } + return default_value; + } int value(const char* key, const int& default_value) const { if (document.HasMember(key)) { const auto& value = document[key]; @@ -319,6 +326,48 @@ class LlmConfig { return llm_config_.value("prompt_template", ""); } // llm model config end > + + // < sampler config start + std::string sampler_type() const { + return config_.value("sampler_type", "greedy"); + } + + float temperature() const { + return config_.value("temperature", 1.0f); + } + + int topK() const { + return config_.value("topK", 40); + } + + float topP() const { + return config_.value("topP", 0.9f); + } + + float minP() const { + return config_.value("minP", 0.1f); + } + + float tfsZ() const { + return config_.value("tfsZ", 1.0f); + } + + float typical() const { + return config_.value("typical", 1.0f); + } + + float penalty() const { + return config_.value("penalty", 0.0f); + } + + int ngram() const { + return config_.value("n_gram", 8); + } + + float ngram_factor() const { + return config_.value("ngram_factor", 1.0f); + } + // sampler config end > }; } // Transformer } // MNN diff --git a/transformers/llm/engine/src/sampler.cpp b/transformers/llm/engine/src/sampler.cpp new file mode 100644 index 000000000..5eaec0e2f --- /dev/null +++ b/transformers/llm/engine/src/sampler.cpp @@ -0,0 +1,411 @@ +#include +#include +#include +#include +#include + +#include +#include +#include "llm/llm.hpp" +#include "sampler/sampler.hpp" + +namespace MNN{ +namespace Transformer{ + +void mergePerformance(struct TimePerformance* dst, struct TimePerformance* src) { + dst->prefill_record_.insert(dst->prefill_record_.end(), src->prefill_record_.begin(), src->prefill_record_.end()); + dst->decode_record_.insert(dst->decode_record_.end(), src->decode_record_.begin(), src->decode_record_.end()); +} + +void clearPerformance(struct TimePerformance* perf) { + perf->prefill_record_.clear(); + perf->decode_record_.clear(); +} + +LocalSampler::LocalSampler(Llm* llm, int 
max_new_tokens, struct LocalSamplerConfig config) { + mLlm = llm; + std::vector history_ids_; + mCandidates.emplace_back(history_ids_); // for LocalSampler, reference have never been modified manually. + mCommonPrefix = history_ids_; + mMaxNewTokens = max_new_tokens; + mConfig = config; +} + +int LocalSampler::randomSelect(float* probs, size_t size) { + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_real_distribution distribution(0.0, 1.0); + float target = distribution(generator); + float cumulative = 0.0; + for (int i = 0; i < size; i++) { + cumulative += probs[i]; + if (target < cumulative) { + return i; + } + } + return size - 1; +} + +int LocalSampler::temperature(MNN::Express::VARP logits, float temperature) { + logits = MNN::Express::_TempratureSoftmax(logits, temperature); + return randomSelect((float*)(logits->readMap()), logits->getInfo()->size); +} + +int LocalSampler::reSoftmaxSelect(std::vector index, std::vector scores, float temperature) { + auto varp = MNN::Express::_Input({(int)index.size()}, MNN::Express::NHWC); + auto scoresMap = (float*)(varp->writeMap()); + for (int i = 0; i < index.size(); ++i) { + scoresMap[i] = scores[i]; + } + int token_index_id = randomSelect((float*)(MNN::Express::_TempratureSoftmax(varp, temperature)->readMap()), index.size()); + return index[token_index_id]; +} + +void LocalSampler::topK(MNN::Express::VARP logits, int K, std::vector& index, std::vector& topKprob) { + auto scores = (float*)(logits->readMap()); + auto size = logits->getInfo()->size; + // 1. time complexity: O(nlogk) + std::priority_queue, IndexProbCmpGreater> heap; + for (int i = 0; i < size; i++) { + IndexProb m; + m.index = i; + m.prob = scores[i]; + if (heap.size() < K) { + heap.push(m); + } + else { + if (heap.top().prob < m.prob) { + heap.pop(); + heap.push(m); + } + } + } + // 2. store top K results + index.clear(); + index.resize(K); + topKprob.clear(); + topKprob.resize(K); + for (int i = 0; i < K; i++) { + index[K-i-1] = heap.top().index; + topKprob[K-i-1] = heap.top().prob; + heap.pop(); + } +} + +void LocalSampler::topP(MNN::Express::VARP logits, float p, float temperature, std::vector& index, std::vector& topPprob) { + auto prob = MNN::Express::_TempratureSoftmax(logits, temperature); + // 1. make max heap + auto scores = (float*)(prob->readMap()); + auto size = prob->getInfo()->size; + std::vector score_vector; + score_vector.resize(size); + for (int i = 0; i < size; i++) { + IndexProb m; + m.index = i; + m.prob = scores[i]; + score_vector[i] = m; + } + std::make_heap(score_vector.begin(), score_vector.end(), IndexProbCmpLess()); + // 2. top p algorithm + scores = (float*)(logits->readMap()); + float cumulative = 0.0f; + while (cumulative < p && !score_vector.empty()) { + std::pop_heap(score_vector.begin(), score_vector.end(), IndexProbCmpLess()); + IndexProb m = score_vector.back(); + score_vector.pop_back(); + index.push_back(m.index); + topPprob.push_back(scores[m.index]); + cumulative += m.prob; + } +} + +void LocalSampler::minP(MNN::Express::VARP logits, float p, float temperature, std::vector& index, std::vector& minPprob) { + auto prob = MNN::Express::_TempratureSoftmax(logits, temperature); + // 1. 
make max heap + auto scores = (float*)(prob->readMap()); + auto size = prob->getInfo()->size; + std::vector score_vector; + score_vector.resize(size); + for (int i = 0; i < size; i++) { + IndexProb m; + m.index = i; + m.prob = scores[i]; + score_vector[i] = m; + } + std::make_heap(score_vector.begin(), score_vector.end(), IndexProbCmpLess()); + // 2. min p algorithm + scores = (float*)(logits->readMap()); + for (int i = 0; i < size; ++i) { + std::pop_heap(score_vector.begin(), score_vector.end(), IndexProbCmpLess()); + IndexProb m = score_vector.back(); + if (m.prob < p && !index.empty()) break; + score_vector.pop_back(); + index.push_back(m.index); + minPprob.push_back(scores[m.index]); + } +} + +void LocalSampler::tfs(MNN::Express::VARP logits, float z, float temperature, std::vector& index, std::vector& tfsprob) { + // tfs algorithm + auto prob = MNN::Express::_TempratureSoftmax(logits, temperature); + // 1. softmax + auto scores = (float*)(prob->readMap()); + auto size = prob->getInfo()->size; + std::vector score_vector; + score_vector.resize(size); + for (int i = 0; i < size; i++) { + IndexProb m; + m.index = i; + m.prob = scores[i]; + score_vector[i] = m; + } + // 2. sort + std::sort(score_vector.begin(), score_vector.end(), IndexProbCmpGreater()); + scores = (float*)(logits->readMap()); + // 3. calculate derivatives + std::vector derivatives(size - 2, 0.0f); + float first = score_vector[0].prob - score_vector[1].prob; + float second = score_vector[1].prob - score_vector[2].prob; + for (int i = 0; i < size - 2; ++i) { + second = score_vector[i+1].prob - score_vector[i+2].prob; + derivatives[i] = std::fabs(first - second); + first = second; + } + // 4. normalize derivatives + float derivatives_sum = 0.0; + for (int i = 0; i < size - 2; ++i) derivatives_sum += derivatives[i]; + float derivatives_sum_rec = 1.0f / derivatives_sum; + for (int i = 0; i < size - 2; ++i) derivatives[i] *= derivatives_sum_rec; + // 5. cumulate, discard last 2 for sure. + float cumulative = 0.0; + for (int i = 0; i < size - 2; ++i) { + IndexProb m = score_vector[i]; + cumulative += derivatives[i]; + if (cumulative >= z && !index.empty()) break; + index.push_back(m.index); + tfsprob.push_back(scores[m.index]); + } +} + +void LocalSampler::typical(MNN::Express::VARP logits, float p, float temperature, std::vector& index, std::vector& minPprob) { + auto prob = MNN::Express::_TempratureSoftmax(logits, temperature); + auto scores = (float*)(prob->readMap()); + auto size = prob->getInfo()->size; + std::vector score_vector; + score_vector.resize(size); + // 1. calcaluate dist + float entropy = 0.0f; + for (int i = 0; i < size; i++) entropy -= scores[i] * std::log(scores[i]); + for (int i = 0; i < size; i++) { + IndexProb m; + m.index = i; + m.prob = std::fabs(entropy + std::log(scores[i])); + score_vector[i] = m; + } + // 2. make min heap for dist + std::make_heap(score_vector.begin(), score_vector.end(), IndexProbCmpGreater()); + // 3. 
typical p algorithm + auto probs = (float*)(prob->readMap()); + scores = (float*)(logits->readMap()); + float cumulative = 0.0f; + for (int i = 0; i < size; ++i) { + std::pop_heap(score_vector.begin(), score_vector.end(), IndexProbCmpGreater()); + IndexProb m = score_vector.back(); + cumulative += probs[m.index]; + if (cumulative >= p && !index.empty()) break; + score_vector.pop_back(); + index.push_back(m.index); + minPprob.push_back(scores[m.index]); + } +} + +int LocalSampler::topK(MNN::Express::VARP logits, int K, float temperature) { + // top K operation + std::vector index; + std::vector topKscores; + topK(logits, K, index, topKscores); + // apply Softmax and select + return reSoftmaxSelect(index, topKscores, temperature); +} + +int LocalSampler::topP(MNN::Express::VARP logits, float p, float temperature) { + // top p operation + std::vector index; + std::vector topPscores; + topP(logits, p, temperature, index, topPscores); + // apply Softmax and select + return reSoftmaxSelect(index, topPscores, temperature); +} + +int LocalSampler::minP(MNN::Express::VARP logits, float p, float temperature) { + // top p operation + std::vector index; + std::vector minPscores; + minP(logits, p, temperature, index, minPscores); + // apply Softmax and select + return reSoftmaxSelect(index, minPscores, temperature); +} + +int LocalSampler::tfs(MNN::Express::VARP logits, float z, float temperature) { + // top p operation + std::vector index; + std::vector scores; + tfs(logits, z, temperature, index, scores); + // apply Softmax and select + return reSoftmaxSelect(index, scores, temperature); +} + +int LocalSampler::typical(MNN::Express::VARP logits, float p, float temperature) { + // top p operation + std::vector index; + std::vector scores; + typical(logits, p, temperature, index, scores); + // apply Softmax and select + return reSoftmaxSelect(index, scores, temperature); +} + +int LocalSampler::argmax(MNN::Express::VARP logits) { + auto scores = (float*)(logits->readMap()); + auto size = logits->getInfo()->size; + float max_score = 0; + int token_id = 0; + for (int i = 0; i < size; i++) { + float score = scores[i]; + if (score > max_score) { + max_score = score; + token_id = i; + } + } + return token_id; +} + +// no frequency penalty now! +void LocalSampler::penalty(MNN::Express::VARP logits, float penalty, bool penalizeNgram, int ngram, float ngram_factor) { + if (penalty <= 1.0f) return; // no penalty! + if (ngram_factor <= 1.0f) penalizeNgram = false; + penalty = std::min(penalty, mConfig.max_penalty); + // initialization + std::vector& prev = mCandidates[0]; + std::unordered_map penalty_map; + // 1. local ngram info, reversed order + std::vector ngram_info(ngram-1); + if (penalizeNgram) { + for (int n = 0; n < ngram_info.size(); ++n) { + ngram_info[n] = prev[prev.size()-1-n]; + } + } + // 2. generate penalty map + for (int i = 0; i < prev.size(); ++i) { + if (penalty_map.count(prev[i]) == 0) penalty_map[prev[i]] = penalty; + if (penalizeNgram) { + float ngram_penalty = penalty; + for (int j = i-1; i-j < ngram && j>=0; --j) { + int idx = i-j-1; + if (prev[j] != ngram_info[idx]) break; + ngram_penalty *= ngram_factor; + // no repeat larger than ngram! + if (idx == ngram_info.size()-1) ngram_penalty = mConfig.max_penalty; + } + if (ngram_penalty > penalty_map[prev[i]]) penalty_map[prev[i]] = ngram_penalty; + } + } + // 3. 
penalize logits according to penalty_map + auto scoresMap = (float*)(logits->writeMap()); + for (auto it = penalty_map.begin(); it != penalty_map.end(); ++it) { + scoresMap[it->first] = (scoresMap[it->first] >= 0.0f) ? (scoresMap[it->first]/it->second) : (scoresMap[it->first]*it->second); + } +} + + +int LocalSampler::penalty(MNN::Express::VARP logits, float penalty, int ngram, float ngram_factor, float temperature) { + bool penalizeNgram = (mConfig.type == "penalize_ngram"); + this->penalty(logits, penalty, penalizeNgram, ngram, ngram_factor); + return this->temperature(logits, temperature); +} + +int LocalSampler::algorithm(MNN::Express::VARP logits) { + int res = 0; + if (mConfig.type == "greedy") res = argmax(logits); + if (mConfig.type == "temperature") res = temperature(logits, mConfig.temperature); + if (mConfig.type == "topK") res = topK(logits, mConfig.topK); + if (mConfig.type == "topP") res = topP(logits, mConfig.topP); + if (mConfig.type == "minP") res = minP(logits, mConfig.minP); + if (mConfig.type == "tfs") res = tfs(logits, mConfig.tfsZ); + if (mConfig.type == "typical") res = typical(logits, mConfig.typical); + if (mConfig.type == "penalty" || mConfig.type == "penalize_ngram") res = penalty(logits, mConfig.penalty, mConfig.ngram, mConfig.ngram_factor, mConfig.temperature); + // if (mConfig.type == "mixed") res = mixed(logits); + Express::ExecutorScope::Current()->gc(Express::Executor::FULL); + return res; +} + +std::string LocalSampler::handleToken(int token, std::ostream* os, const char* end_with) { + // CommonPrefix and Candidates managements + mCandidates[0].push_back(token); + mCommonPrefix.push_back(token); + std::string output_str = mLlm->decode(mCommonPrefix.back()); + // print + *os << output_str << std::flush; + return output_str; +} + +std::string LocalSampler::sample(const std::vector& input_ids, std::ostream* os, const char* end_with, struct TimePerformance* time_perf) { + // initialization for time performance + PrefillTimePerformance prefill_time; + prefill_time.prefill_prev_token_ = mCommonPrefix.size(); + prefill_time.prefill_token_ = input_ids.size(); + // initialization + std::string output_str; + mCandidates[0].insert(mCandidates[0].end(), input_ids.begin(), input_ids.end()); + mCommonPrefix.insert(mCommonPrefix.end(), input_ids.begin(), input_ids.end()); + // prefill + auto st = std::chrono::system_clock::now(); + auto logits = mLlm->forward(input_ids, true); + if (nullptr == logits.get()) { + return ""; + } + int token = algorithm(logits); + // record time + auto et = std::chrono::system_clock::now(); + prefill_time.prefill_us_ = std::chrono::duration_cast(et - st).count(); + time_perf->prefill_record_.push_back(prefill_time); + // handle the new token + output_str += handleToken(token, os, end_with); + // decode + while (getGenLength(0, output_str.size()) < mMaxNewTokens) { + DecodeTimePerformance decode_time; + decode_time.decode_prev_token_ = mCandidates[0].size(); + st = std::chrono::system_clock::now(); + // next token + logits = mLlm->forward({mCandidates[0].back()}, false); + if (nullptr == logits.get()) { + return output_str; + } + if (logits->getInfo()->size == 0) { + return output_str; + } + token = algorithm(logits); + et = std::chrono::system_clock::now(); + decode_time.decode_us_ = std::chrono::duration_cast(et - st).count(); + time_perf->decode_record_.push_back(decode_time); + if (mLlm->is_stop(token)) { + *os << end_with << std::flush; + break; + } else { + output_str += handleToken(token); + } + } + // return output_str + return 
output_str;
+}
+
+void LocalSampler::reset() {
+    // in the future, this should reset only the sampler's own state.
+    mCandidates.clear();
+    std::vector<int> history_ids_;
+    mCandidates.emplace_back(history_ids_); // for LocalSampler, these references are never modified manually.
+    mCommonPrefix = history_ids_;
+}
+
+
+} // Transformer
+} // MNN
\ No newline at end of file
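For context, the sampler that `initSampler()` instantiates is chosen from the JSON config through the new `LlmConfig` getters added in this diff (`sampler_type`, `temperature`, `topK`, `topP`, `minP`, `tfsZ`, `typical`, `penalty`, `n_gram`, `ngram_factor`). A minimal usage sketch, assuming the public `Llm` API shown above and a placeholder config path:

```cpp
// Illustrative sketch (not part of the patch): selecting one of the new
// samplers via the existing config mechanism. The model path is a placeholder;
// the JSON keys are the ones read by the LlmConfig getters in this diff.
#include <iostream>
#include <memory>
#include "llm/llm.hpp"

using MNN::Transformer::Llm;

int main() {
    std::unique_ptr<Llm> llm(Llm::createLLM("/path/to/config.json")); // placeholder path
    llm->set_config(R"({
        "sampler_type": "topP",
        "temperature": 0.8,
        "topP": 0.9
    })");
    llm->load();                 // init_runtime() -> initSampler() builds the LocalSampler
    llm->response("Hello, MNN"); // generation now goes through Sampler::sample()
    return 0;
}
```

If `sampler_type` does not match any of the cases in `initSampler()`, the patch logs "Designated Sampler Not Supported!" and leaves the sampler unset, so an unknown type does not silently fall back to greedy decoding.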