diff --git a/chatglm.cpp b/chatglm.cpp
index daaa26a..5006c83 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -663,8 +663,9 @@ std::vector<int> BaseModelForCausalLM::generate(const std::vector<int> &input_id
     int n_past = 0;
     const int n_ctx = input_ids.size();
+    const int max_new_tokens = (gen_config.max_new_tokens > 0) ? gen_config.max_new_tokens : gen_config.max_length;

-    while ((int)output_ids.size() < gen_config.max_length) {
+    while ((int)output_ids.size() < std::min(gen_config.max_length, n_ctx + max_new_tokens)) {
         int next_token_id = generate_next_token(output_ids, gen_config, n_past, n_ctx);

         n_past = output_ids.size();
diff --git a/chatglm.h b/chatglm.h
index 2394e15..ae1d8fe 100644
--- a/chatglm.h
+++ b/chatglm.h
@@ -814,6 +814,7 @@ class ModelLoader {

 struct GenerationConfig {
     int max_length;
+    int max_new_tokens;
     int max_context_length;
     bool do_sample;
     int top_k;
@@ -822,10 +823,12 @@ struct GenerationConfig {
     float repetition_penalty;
     int num_threads;

-    GenerationConfig(int max_length = 2048, int max_context_length = 512, bool do_sample = true, int top_k = 0,
-                     float top_p = 0.7, float temperature = 0.95, float repetition_penalty = 1.f, int num_threads = 0)
-        : max_length(max_length), max_context_length(max_context_length), do_sample(do_sample), top_k(top_k),
-          top_p(top_p), temperature(temperature), repetition_penalty(repetition_penalty), num_threads(num_threads) {}
+    GenerationConfig(int max_length = 2048, int max_new_tokens = -1, int max_context_length = 512,
+                     bool do_sample = true, int top_k = 0, float top_p = 0.7, float temperature = 0.95,
+                     float repetition_penalty = 1.f, int num_threads = 0)
+        : max_length(max_length), max_new_tokens(max_new_tokens), max_context_length(max_context_length),
+          do_sample(do_sample), top_k(top_k), top_p(top_p), temperature(temperature),
+          repetition_penalty(repetition_penalty), num_threads(num_threads) {}
 };

 int get_num_physical_cores();
diff --git a/chatglm_cpp/_C.pyi b/chatglm_cpp/_C.pyi
index 20a5df8..edca216 100644
--- a/chatglm_cpp/_C.pyi
+++ b/chatglm_cpp/_C.pyi
@@ -66,12 +66,13 @@ class GenerationConfig:
     do_sample: bool
     max_context_length: int
     max_length: int
+    max_new_tokens: int
     num_threads: int
     repetition_penalty: float
     temperature: float
     top_k: int
     top_p: float
-    def __init__(self, max_length: int = 2048, max_context_length: int = 512, do_sample: bool = True, top_k: int = 0, top_p: float = 0.7, temperature: float = 0.95, repetition_penalty: float = 1.0, num_threads: int = 0) -> None:
+    def __init__(self, max_length: int = 2048, max_new_tokens: int = -1, max_context_length: int = 512, do_sample: bool = True, top_k: int = 0, top_p: float = 0.7, temperature: float = 0.95, repetition_penalty: float = 1.0, num_threads: int = 0) -> None:
         ...
 class InternLM20BForCausalLM(BaseModelForCausalLM):
     pass
diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py
index cdb79cd..46712cc 100644
--- a/chatglm_cpp/__init__.py
+++ b/chatglm_cpp/__init__.py
@@ -47,6 +47,7 @@ def chat(
         messages: List[ChatMessage],
         *,
         max_length: int = 2048,
+        max_new_tokens: int = -1,
         max_context_length: int = 512,
         do_sample: bool = True,
         top_k: int = 0,
@@ -60,6 +61,7 @@ def chat(
         input_ids = self.tokenizer.encode_messages(messages, max_context_length)
         gen_config = _C.GenerationConfig(
             max_length=max_length,
+            max_new_tokens=max_new_tokens,
             max_context_length=max_context_length,
             do_sample=do_sample,
             top_k=top_k,
@@ -77,6 +79,7 @@ def generate(
         prompt: str,
         *,
         max_length: int = 2048,
+        max_new_tokens: int = -1,
         max_context_length: int = 512,
         do_sample: bool = True,
         top_k: int = 0,
@@ -89,6 +92,7 @@ def generate(
         input_ids = self.tokenizer.encode(prompt, max_context_length)
         gen_config = _C.GenerationConfig(
             max_length=max_length,
+            max_new_tokens=max_new_tokens,
             max_context_length=max_context_length,
             do_sample=do_sample,
             top_k=top_k,
@@ -105,8 +109,9 @@ def _stream_generate_ids(self, input_ids: List[int], gen_config: _C.GenerationCo
         input_ids = input_ids.copy()
         n_past = 0
         n_ctx = len(input_ids)
+        max_new_tokens = gen_config.max_new_tokens if gen_config.max_new_tokens > 0 else gen_config.max_length

-        while len(input_ids) < gen_config.max_length:
+        while len(input_ids) < min(gen_config.max_length, n_ctx + max_new_tokens):
             next_token_id = self.model.generate_next_token(input_ids, gen_config, n_past, n_ctx)
             yield next_token_id
             n_past = len(input_ids)
diff --git a/chatglm_pybind.cpp b/chatglm_pybind.cpp
index 24749b8..e3f2312 100644
--- a/chatglm_pybind.cpp
+++ b/chatglm_pybind.cpp
@@ -69,10 +69,11 @@ PYBIND11_MODULE(_C, m) {
         .def_property_readonly("model_type_name", &ModelConfig::model_type_name);

     py::class_<GenerationConfig>(m, "GenerationConfig")
-        .def(py::init<int, int, bool, int, float, float, float, int>(), "max_length"_a = 2048,
-             "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0, "top_p"_a = 0.7, "temperature"_a = 0.95,
-             "repetition_penalty"_a = 1.0, "num_threads"_a = 0)
+        .def(py::init<int, int, int, bool, int, float, float, float, int>(), "max_length"_a = 2048,
+             "max_new_tokens"_a = -1, "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0,
+             "top_p"_a = 0.7, "temperature"_a = 0.95, "repetition_penalty"_a = 1.0, "num_threads"_a = 0)
         .def_readwrite("max_length", &GenerationConfig::max_length)
+        .def_readwrite("max_new_tokens", &GenerationConfig::max_new_tokens)
         .def_readwrite("max_context_length", &GenerationConfig::max_context_length)
         .def_readwrite("do_sample", &GenerationConfig::do_sample)
         .def_readwrite("top_k", &GenerationConfig::top_k)
diff --git a/main.cpp b/main.cpp
index e5bed73..d76ccc8 100644
--- a/main.cpp
+++ b/main.cpp
@@ -28,6 +28,7 @@ struct Args {
     std::string prompt = "你好";
     std::string system = "";
     int max_length = 2048;
+    int max_new_tokens = -1;
     int max_context_length = 512;
     bool interactive = false;
     int top_k = 0;
@@ -44,7 +45,7 @@ static void usage(const std::string &prog) {
               << "options:\n"
               << "  -h, --help              show this help message and exit\n"
               << "  -m, --model PATH        model path (default: chatglm-ggml.bin)\n"
-              << "  --mode                  inference mode chose from {chat, generate} (default: chat)\n"
+              << "  --mode                  inference mode chosen from {chat, generate} (default: chat)\n"
               << "  --sync                  synchronized generation without streaming\n"
               << "  -p, --prompt PROMPT     prompt to start generation with (default: 你好)\n"
               << "  --pp, --prompt_path     path to the plain text file that stores the prompt\n"
               << "  -s, --system SYSTEM     system message to set the behavior of the assistant\n"
@@ -52,6 +53,7 @@ static void usage(const std::string &prog) {
               << "  --sp, --system_path     path to the plain text file that stores the system message\n"
               << "  -i, --interactive       run in interactive mode\n"
               << "  -l, --max_length N      max total length including prompt and output (default: 2048)\n"
+              << "  --max_new_tokens N      max number of tokens to generate, ignoring the number of prompt tokens\n"
               << "  -c, --max_context_length N\n"
               << "                          max context length (default: 512)\n"
               << "  --top_k N               top-k sampling (default: 0)\n"
@@ -97,6 +99,8 @@ static Args parse_args(const std::vector<std::string> &argv) {
             args.interactive = true;
         } else if (arg == "-l" || arg == "--max_length") {
             args.max_length = std::stoi(argv.at(++i));
+        } else if (arg == "--max_new_tokens") {
+            args.max_new_tokens = std::stoi(argv.at(++i));
         } else if (arg == "-c" || arg == "--max_context_length") {
             args.max_context_length = std::stoi(argv.at(++i));
         } else if (arg == "--top_k") {
@@ -179,8 +183,8 @@ static void chat(Args &args) {
     }
     auto streamer = std::make_unique<chatglm::StreamerGroup>(std::move(streamers));

-    chatglm::GenerationConfig gen_config(args.max_length, args.max_context_length, args.temp > 0, args.top_k,
-                                         args.top_p, args.temp, args.repeat_penalty, args.num_threads);
+    chatglm::GenerationConfig gen_config(args.max_length, args.max_new_tokens, args.max_context_length, args.temp > 0,
+                                         args.top_k, args.top_p, args.temp, args.repeat_penalty, args.num_threads);

     if (args.verbose) {
         std::cout << "system info: | "
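Usage note: a minimal sketch of driving the new max_new_tokens cap from Python, assuming the usual chatglm_cpp.Pipeline entry point and an already converted GGML model file (the path below is a placeholder, not part of this patch):

    import chatglm_cpp

    # Placeholder path: point this at an actual converted GGML weights file.
    pipeline = chatglm_cpp.Pipeline("chatglm3-ggml.bin")

    # max_new_tokens caps only the generated part; max_length still bounds
    # prompt + output combined, and the stricter of the two limits wins.
    output = pipeline.generate(
        "Write a short poem about autumn.",
        max_length=2048,
        max_new_tokens=64,
    )
    print(output)

Leaving max_new_tokens at its default of -1 keeps the old behavior of generating until max_length is reached. On the command line, the same cap is exposed through the new --max_new_tokens flag parsed in main.cpp above.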