From 23442b003679e34af5e58dc3725c5654dd6c2fc7 Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Sat, 20 Jan 2024 23:09:31 +0800 Subject: [PATCH] Add perplexity example (#210) --- CMakeLists.txt | 18 ++++- README.md | 17 +++++ chatglm.cpp | 33 +++++--- chatglm.h | 23 +++--- chatglm_cpp/__init__.py | 2 +- chatglm_pybind.cpp | 7 +- main.cpp | 45 +++++------ pyproject.toml | 5 +- setup.py | 1 + tests/perplexity.cpp | 162 ++++++++++++++++++++++++++++++++++++++++ tests/ppl.sh | 21 ++++++ 11 files changed, 284 insertions(+), 50 deletions(-) create mode 100644 tests/perplexity.cpp create mode 100644 tests/ppl.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 373b98c7..b7a440fe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,8 @@ if (GGML_PERF) add_compile_definitions(GGML_PERF) endif () +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + file(GLOB CPP_SOURCES ${PROJECT_SOURCE_DIR}/*.h ${PROJECT_SOURCE_DIR}/*.cpp) @@ -47,8 +49,19 @@ set_source_files_properties(${CPP_SOURCES} PROPERTIES COMPILE_FLAGS "-pedantic-e add_library(chatglm STATIC chatglm.cpp) target_link_libraries(chatglm PUBLIC ggml sentencepiece-static) -add_executable(main main.cpp) -target_link_libraries(main PRIVATE chatglm) +# c++ examples +option(CHATGLM_ENABLE_EXAMPLES "chatglm: enable c++ examples" ON) +if (CHATGLM_ENABLE_EXAMPLES) + add_executable(main main.cpp) + target_link_libraries(main PRIVATE chatglm) + + find_package(OpenMP) + if (OpenMP_CXX_FOUND) + set(CHATGLM_OPENMP_TARGET OpenMP::OpenMP_CXX) + endif () + add_executable(perplexity tests/perplexity.cpp) + target_link_libraries(perplexity PRIVATE chatglm ${CHATGLM_OPENMP_TARGET}) +endif () # GoogleTest option(CHATGLM_ENABLE_TESTING "chatglm: enable testing" OFF) @@ -75,7 +88,6 @@ endif () option(CHATGLM_ENABLE_PYBIND "chatglm: enable python binding" OFF) if (CHATGLM_ENABLE_PYBIND) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}) set_target_properties(chatglm ggml sentencepiece-static PROPERTIES POSITION_INDEPENDENT_CODE TRUE) add_subdirectory(third_party/pybind11) pybind11_add_module(_C chatglm_pybind.cpp) diff --git a/README.md b/README.md index 157e36c1..6dd2f8cb 100644 --- a/README.md +++ b/README.md @@ -272,6 +272,8 @@ cmake -B build -DGGML_CUBLAS=ON -DCUDA_ARCHITECTURES="80" # for A100 cmake -B build -DGGML_CUBLAS=ON -DCUDA_ARCHITECTURES="70;75" # compatible with both V100 and T4 ``` +To find out the CUDA architecture of your GPU device, see [Matching CUDA arch and CUDA gencode for various NVIDIA architectures](https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/). + **Metal** MPS (Metal Performance Shaders) allows computation to run on powerful Apple Silicon GPU. Add the CMake flag `-DGGML_METAL=ON` to enable it. @@ -628,6 +630,21 @@ InternLM-20B: | ms/token (CPU @ Platinum 8260) | 230.0 | 236.7 | 276.6 | 290.6 | 357.1 | N/A | | ms/token (CUDA @ V100 SXM2) | 21.6 | 23.2 | 25.0 | 25.9 | 33.4 | N/A | +## Model Quality + +We measure model quality by evaluating the perplexity over the WikiText-2 test dataset, following the strided sliding window strategy in https://huggingface.co/docs/transformers/perplexity. Lower perplexity usually indicates a better model. + +Download and unzip the dataset from [link](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip). Measure the perplexity with a stride of 512 and max input length of 2048: +```sh +./build/bin/perplexity -m -f wikitext-2-raw/wiki.test.raw -s 512 -l 2048 +``` + +| | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | F16 | +|-------------------------|-------|-------|-------|-------|-------|-------| +| [ChatGLM3-6B-Base][1] | 6.215 | 6.184 | 5.997 | 6.015 | 5.965 | 5.971 | + +[1]: https://huggingface.co/THUDM/chatglm3-6b-base + ## Development **Unit Test & Benchmark** diff --git a/chatglm.cpp b/chatglm.cpp index 5006c831..bf592a62 100644 --- a/chatglm.cpp +++ b/chatglm.cpp @@ -81,15 +81,20 @@ std::string to_string(ggml_tensor *tensor, bool with_data) { for (int i0 = 0; i0 < tensor->ne[0]; i0++) { auto ptr = (char *)tensor->data + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]; - float val; - if (tensor->type == GGML_TYPE_F32) { - val = *(float *)ptr; - } else if (tensor->type == GGML_TYPE_F16) { - val = ggml_fp16_to_fp32(*(ggml_fp16_t *)ptr); + oss << (i0 > 0 ? ", " : ""); + if (tensor->type == GGML_TYPE_I32) { + oss << *(int *)ptr; } else { - CHATGLM_THROW << "unimplemented"; + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(float *)ptr; + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(ggml_fp16_t *)ptr); + } else { + CHATGLM_THROW << "unimplemented"; + } + oss << std::setw(7) << std::fixed << std::setprecision(4) << val; } - oss << (i0 > 0 ? ", " : "") << std::setw(7) << std::fixed << std::setprecision(4) << val; } oss << "]"; } @@ -496,12 +501,11 @@ BaseModelForCausalLM::BaseModelForCausalLM(ModelConfig config, size_t mem_size, #endif } -int BaseModelForCausalLM::generate_next_token(const std::vector &input_ids, const GenerationConfig &gen_config, - int n_past, int n_ctx) { +ggml_tensor *BaseModelForCausalLM::forward_graph_compute(const std::vector &input_ids, int n_past, int n_ctx, + int n_threads, bool is_decoding) { ctx_.ctx_b = make_unique_ggml_context(ctx_.compute_buffer.size(), ctx_.compute_buffer.data(), false); ctx_.gf = {}; - int n_threads = gen_config.num_threads; // user defined if (n_threads <= 0) { n_threads = get_default_num_threads(); // default thread num } @@ -513,7 +517,7 @@ int BaseModelForCausalLM::generate_next_token(const std::vector &input_ids, ggml_tensor *curr_input_ids = ggml_new_tensor_1d(ctx_.ctx_b.get(), GGML_TYPE_I32, curr_input_ids_size); memcpy(curr_input_ids->data, input_ids.data() + n_past, ggml_nbytes(curr_input_ids)); - ggml_tensor *lm_logits = forward(&ctx_, curr_input_ids, n_past, n_ctx); + ggml_tensor *lm_logits = forward(&ctx_, curr_input_ids, n_past, n_ctx, is_decoding); lm_logits->backend = GGML_BACKEND_CPU; ggml_build_forward_expand(&ctx_.gf, lm_logits); @@ -527,6 +531,13 @@ int BaseModelForCausalLM::generate_next_token(const std::vector &input_ids, ggml_graph_print(&ctx_.gf); #endif + return lm_logits; +} + +int BaseModelForCausalLM::generate_next_token(const std::vector &input_ids, const GenerationConfig &gen_config, + int n_past, int n_ctx) { + ggml_tensor *lm_logits = forward_graph_compute(input_ids, n_past, n_ctx, gen_config.num_threads, true); + int vocab_size = lm_logits->ne[0]; float *next_token_logits = (float *)lm_logits->data; diff --git a/chatglm.h b/chatglm.h index ae1d8fe4..49a33607 100644 --- a/chatglm.h +++ b/chatglm.h @@ -855,7 +855,11 @@ class BaseModelForCausalLM { virtual ~BaseModelForCausalLM() = default; virtual void load(ModelLoader &loader) = 0; - virtual ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const = 0; + virtual ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx, + bool is_decoding) const = 0; + + ggml_tensor *forward_graph_compute(const std::vector &input_ids, int n_past, int n_ctx, int n_threads, + bool is_decoding); std::vector generate(const std::vector &input_ids, const GenerationConfig &gen_config, BaseStreamer *streamer = nullptr); @@ -896,10 +900,11 @@ class BasicModelForCausalLM : public BaseModelForCausalLM { ~BasicModelForCausalLM() { to_cpu(); } public: - ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const override { + ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx, + bool is_decoding) const override { ggml_tensor *transformer_outputs = transformer.forward(ctx, input_ids, n_past, n_ctx); - // NOTE: only compute next_token_logits for the last token - if (input_ids->ne[0] > 1) { + // NOTE: only compute next token logits for decoding + if (is_decoding && input_ids->ne[0] > 1) { transformer_outputs = tensor_assign_buffers( ggml_view_1d(ctx->ctx_b.get(), transformer_outputs, config.hidden_size, (input_ids->ne[0] - 1) * config.hidden_size * ggml_element_size(transformer_outputs))); @@ -1011,7 +1016,7 @@ class ChatGLMForCausalLM : public BasicModelForCausalLM { StateDict state_dict() const; public: - static constexpr size_t MEM_SIZE = 512 * MB; // 2k context + static constexpr size_t MEM_SIZE = 1280 * MB; // 2k context static constexpr size_t SCRATCH_SIZE = 1024 * MB; // 2k context }; @@ -1061,7 +1066,7 @@ class ChatGLM2ForCausalLM : public BasicModelForCausalLM { StateDict state_dict() const; public: - static constexpr size_t MEM_SIZE = 512 * MB; // 2k context + static constexpr size_t MEM_SIZE = 1280 * MB; // 2k context static constexpr size_t SCRATCH_SIZE = 1280 * MB; // 2k context }; @@ -1161,7 +1166,7 @@ class Baichuan7BForCausalLM : public BasicModelForCausalLM { StateDict state_dict() const; public: - static constexpr size_t MEM_SIZE = 512 * MB; + static constexpr size_t MEM_SIZE = 1280 * MB; static constexpr size_t SCRATCH_SIZE = 1280 * MB; }; @@ -1187,7 +1192,7 @@ class Baichuan13BForCausalLM : public BasicModelForCausalLM { StateDict state_dict() const; public: - static constexpr size_t MEM_SIZE = 512 * MB; + static constexpr size_t MEM_SIZE = 1280 * MB; static constexpr size_t SCRATCH_SIZE = 1280 * MB; }; @@ -1248,7 +1253,7 @@ class InternLMForCausalLM : public BasicModelForCausalLM { StateDict state_dict() const; public: - static constexpr size_t MEM_SIZE = 512 * MB; + static constexpr size_t MEM_SIZE = 1280 * MB; static constexpr size_t SCRATCH_SIZE = 1280 * MB; }; diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py index 46712ccd..72fbca22 100644 --- a/chatglm_cpp/__init__.py +++ b/chatglm_cpp/__init__.py @@ -6,7 +6,7 @@ import chatglm_cpp._C as _C from chatglm_cpp._C import ChatMessage -__version__ = "0.3.1.dev" +__version__ = "0.3.1" @dataclass diff --git a/chatglm_pybind.cpp b/chatglm_pybind.cpp index e3f23123..9105be82 100644 --- a/chatglm_pybind.cpp +++ b/chatglm_pybind.cpp @@ -27,8 +27,11 @@ class PyBaseModelForCausalLM : public BaseModelForCausalLM { using BaseModelForCausalLM::BaseModelForCausalLM; void load(ModelLoader &loader) override { PYBIND11_OVERLOAD_PURE(void, PyBaseModelForCausalLM, load, loader); } - ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx) const override { - PYBIND11_OVERLOAD_PURE(ggml_tensor *, PyBaseModelForCausalLM, forward, ctx, input_ids, n_past, n_ctx) + + ggml_tensor *forward(ModelContext *ctx, ggml_tensor *input_ids, int n_past, int n_ctx, + bool is_decoding) const override { + PYBIND11_OVERLOAD_PURE(ggml_tensor *, PyBaseModelForCausalLM, forward, ctx, input_ids, n_past, n_ctx, + is_decoding) } }; diff --git a/main.cpp b/main.cpp index d76ccc8a..f68e0e74 100644 --- a/main.cpp +++ b/main.cpp @@ -40,28 +40,29 @@ struct Args { }; static void usage(const std::string &prog) { - std::cout << "Usage: " << prog << " [options]\n" - << "\n" - << "options:\n" - << " -h, --help show this help message and exit\n" - << " -m, --model PATH model path (default: chatglm-ggml.bin)\n" - << " --mode inference mode chosen from {chat, generate} (default: chat)\n" - << " --sync synchronized generation without streaming\n" - << " -p, --prompt PROMPT prompt to start generation with (default: 你好)\n" - << " --pp, --prompt_path path to the plain text file that stores the prompt\n" - << " -s, --system SYSTEM system message to set the behavior of the assistant\n" - << " --sp, --system_path path to the plain text file that stores the system message\n" - << " -i, --interactive run in interactive mode\n" - << " -l, --max_length N max total length including prompt and output (default: 2048)\n" - << " --max_new_tokens N max number of tokens to generate, ignoring the number of prompt tokens\n" - << " -c, --max_context_length N\n" - << " max context length (default: 512)\n" - << " --top_k N top-k sampling (default: 0)\n" - << " --top_p N top-p sampling (default: 0.7)\n" - << " --temp N temperature (default: 0.95)\n" - << " --repeat_penalty N penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled)\n" - << " -t, --threads N number of threads for inference\n" - << " -v, --verbose display verbose output including config/system/performance info\n"; + std::cout << "Usage: " << prog << R"( [options] + +options: + -h, --help show this help message and exit + -m, --model PATH model path (default: chatglm-ggml.bin) + --mode inference mode chosen from {chat, generate} (default: chat) + --sync synchronized generation without streaming + -p, --prompt PROMPT prompt to start generation with (default: 你好) + --pp, --prompt_path path to the plain text file that stores the prompt + -s, --system SYSTEM system message to set the behavior of the assistant + --sp, --system_path path to the plain text file that stores the system message + -i, --interactive run in interactive mode + -l, --max_length N max total length including prompt and output (default: 2048) + --max_new_tokens N max number of tokens to generate, ignoring the number of prompt tokens + -c, --max_context_length N + max context length (default: 512) + --top_k N top-k sampling (default: 0) + --top_p N top-p sampling (default: 0.7) + --temp N temperature (default: 0.95) + --repeat_penalty N penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) + -t, --threads N number of threads for inference + -v, --verbose display verbose output including config/system/performance info +)"; } static std::string read_text(std::string path) { diff --git a/pyproject.toml b/pyproject.toml index fc442293..19736ea5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,10 +13,10 @@ authors = [ maintainers = [ {name = "Jiahao Li", email = "liplus17@163.com"}, ] -description = "C++ implementation of ChatGLM-6B & ChatGLM2-6B" +description = "C++ implementation of ChatGLM family models and more LLMs" readme = "README.md" requires-python = ">=3.7" -keywords = ["ChatGLM", "ChatGLM2", "Large Language Model"] +keywords = ["ChatGLM", "ChatGLM2", "ChatGLM3", "Large Language Model"] license = {text = "MIT License"} classifiers = [ "Development Status :: 3 - Alpha", @@ -27,6 +27,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] dynamic = ["version"] diff --git a/setup.py b/setup.py index 4a94e2ab..ec6815cc 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ def build_extension(self, ext: CMakeExtension) -> None: f"-DPYTHON_EXECUTABLE={sys.executable}", f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm f"-DCHATGLM_ENABLE_PYBIND=ON", + f"-DCHATGLM_ENABLE_EXAMPLES=OFF", f"-DBUILD_SHARED_LIBS=OFF", ] build_args = [] diff --git a/tests/perplexity.cpp b/tests/perplexity.cpp new file mode 100644 index 00000000..a98e624c --- /dev/null +++ b/tests/perplexity.cpp @@ -0,0 +1,162 @@ +#include "chatglm.h" +#include +#include +#include +#include + +struct Args { + std::string model_path = "chatglm-ggml.bin"; + std::string corpus_path = "data/wikitext-2-raw/wiki.test.raw"; + int max_length = 1024; + int stride = 512; + int num_threads = 0; +}; + +static void usage(const std::string &prog) { + std::cout << "Usage: " << prog << R"( [options] + +options: + -h, --help show this help message and exit + -m, --model PATH model path + -f, --file path to the corpus + -l, --max_length N max total length including prompt and output + -s, --stride N stride size of the sliding window + -t, --threads N number of threads for inference +)"; +} + +static Args parse_args(const std::vector &argv) { + Args args; + + for (size_t i = 1; i < argv.size(); i++) { + const std::string &arg = argv.at(i); + + if (arg == "-h" || arg == "--help") { + usage(argv.at(0)); + exit(EXIT_SUCCESS); + } else if (arg == "-m" || arg == "--model") { + args.model_path = argv.at(++i); + } else if (arg == "-f" || arg == "--file") { + args.corpus_path = argv.at(++i); + } else if (arg == "-l" || arg == "--max_length") { + args.max_length = std::stoi(argv.at(++i)); + } else if (arg == "-s" || arg == "--stride") { + args.stride = std::stoi(argv.at(++i)); + } else if (arg == "-t" || arg == "--threads") { + args.num_threads = std::stoi(argv.at(++i)); + } else { + std::cerr << "Unknown argument: " << arg << std::endl; + usage(argv.at(0)); + exit(EXIT_FAILURE); + } + } + + return args; +} + +static Args parse_args(int argc, char **argv) { + std::vector argv_vec(argv, argv + argc); + return parse_args(argv_vec); +} + +static std::string read_text(std::string path) { + std::ifstream fin(path); + CHATGLM_CHECK(fin) << "cannot open file " << path; + std::ostringstream oss; + oss << fin.rdbuf(); + return oss.str(); +} + +static float cross_entropy(const ggml_tensor *input, const ggml_tensor *target) { + CHATGLM_CHECK(ggml_is_contiguous(input) && input->n_dims == 2 && input->type == GGML_TYPE_F32); + CHATGLM_CHECK(ggml_is_contiguous(target) && target->n_dims == 1 && target->type == GGML_TYPE_I32); + CHATGLM_CHECK(input->ne[1] == target->ne[0]); + + const int num_classes = input->ne[0]; + const int batch_size = input->ne[1]; + + float loss = 0.f; +#pragma omp parallel for reduction(+ : loss) + for (int i = 0; i < batch_size; i++) { + const int target_i = ((const int *)target->data)[i]; + const float *row = (const float *)input->data + i * input->ne[0]; + const float max_val = *std::max_element(row, row + num_classes); + float sum = 0.f; + for (int j = 0; j < num_classes; j++) { + sum += std::exp(row[j] - max_val); + } + loss += -(row[target_i] - max_val - std::log(sum)); + } + + return loss / batch_size; +} + +// reference: https://huggingface.co/docs/transformers/perplexity +static void perplexity(Args &args) { + std::cout << "Loading model from " << args.model_path << " ...\n"; + chatglm::Pipeline pipeline(args.model_path); + + std::cout << "Loading corpus from " << args.corpus_path << " ...\n"; + std::string corpus = read_text(args.corpus_path); + + std::cout << "Tokenizing corpus of " << corpus.size() << " bytes ...\n"; + std::vector corpus_ids = pipeline.tokenizer->encode(corpus, std::numeric_limits::max()); + corpus_ids.erase(corpus_ids.begin(), corpus_ids.begin() + 2); + + std::cout << "Computing perplexity against " << corpus_ids.size() << " tokens ...\n"; + + float total_loss = 0.f; + size_t num_samples = 0; + + size_t prev_end = 0; + for (size_t begin = 0; begin < corpus_ids.size(); begin += args.stride) { + const auto clk_start = std::chrono::system_clock::now(); + size_t end = std::min(begin + args.max_length, corpus_ids.size()); + size_t target_len = std::min(end - prev_end, size_t(args.max_length - 1)); + std::vector input_ids(corpus_ids.begin() + begin, corpus_ids.begin() + end); + + ggml_tensor *lm_logits = pipeline.model->forward_graph_compute(input_ids, 0, 0, args.num_threads, false); + + const auto clk_fwd = std::chrono::system_clock::now(); + + auto ctx = chatglm::make_unique_ggml_context(512 * chatglm::MB, nullptr, false); + ggml_tensor *next_lm_logits = ggml_view_2d(ctx.get(), lm_logits, lm_logits->ne[0], target_len, lm_logits->nb[1], + (input_ids.size() - target_len - 1) * lm_logits->nb[1]); + ggml_tensor *next_input_ids = ggml_new_tensor_1d(ctx.get(), GGML_TYPE_I32, target_len); + memcpy(next_input_ids->data, input_ids.data() + input_ids.size() - target_len, target_len * sizeof(int)); + + const float loss = cross_entropy(next_lm_logits, next_input_ids); + + total_loss += loss * target_len; + num_samples += target_len; + + const auto clk_end = std::chrono::system_clock::now(); + + const auto elapsed_fwd = std::chrono::duration_cast(clk_fwd - clk_start).count(); + const auto elapsed_ce = std::chrono::duration_cast(clk_end - clk_fwd).count(); + + const int progress = end * 100 / corpus_ids.size(); + std::cout << "[" << progress << "%] chunk [" << end - target_len << ", " << end + << ") perplexity: " << std::fixed << std::setprecision(3) << std::exp(loss) + << ", forward time: " << elapsed_fwd << " ms, cross entropy time: " << elapsed_ce << " ms\n"; + + prev_end = end; + if (end == corpus_ids.size()) { + break; + } + } + + const float ppl = std::exp(total_loss / num_samples); + std::cout << "Final perplexity: " << std::fixed << std::setprecision(3) << ppl << "\n"; +} + +int main(int argc, char **argv) { + try { + Args args = parse_args(argc, argv); + perplexity(args); + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + exit(EXIT_FAILURE); + } + return 0; +} diff --git a/tests/ppl.sh b/tests/ppl.sh new file mode 100644 index 00000000..206025ea --- /dev/null +++ b/tests/ppl.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +export CUDA_VISIBLE_DEVICES=0 + +# ChatGLM3-6B-Base +hf_model=THUDM/chatglm3-6b-base +ggml_model=chatglm3-base-ggml.bin + +# Baichuan2-7B-Base +# hf_model=baichuan-inc/Baichuan2-7B-Base +# ggml_model=baichuan2-7b-base-ggml.bin + +# InternLM +# hf_model=internlm/internlm-7b +# ggml_model=internlm-7b-base-ggml.bin + +for dtype in f16; do + python3 chatglm_cpp/convert.py -i $hf_model -o $ggml_model -t $dtype + echo "[perplexity] dtype=$dtype" + ./build/bin/perplexity -m $ggml_model -f data/wikitext-2-raw/wiki.test.raw -s 512 -l 2048 +done