Support max_new_tokens generation option
li-plus committed Nov 24, 2023
1 parent e09726e commit cb31d84
Showing 6 changed files with 28 additions and 13 deletions.
3 changes: 2 additions & 1 deletion chatglm.cpp
@@ -663,8 +663,9 @@ std::vector<int> BaseModelForCausalLM::generate(const std::vector<int> &input_id
 
     int n_past = 0;
     const int n_ctx = input_ids.size();
+    const int max_new_tokens = (gen_config.max_new_tokens > 0) ? gen_config.max_new_tokens : gen_config.max_length;
 
-    while ((int)output_ids.size() < gen_config.max_length) {
+    while ((int)output_ids.size() < std::min(gen_config.max_length, n_ctx + max_new_tokens)) {
        int next_token_id = generate_next_token(output_ids, gen_config, n_past, n_ctx);
 
        n_past = output_ids.size();
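
The new loop bound stops generation at min(max_length, n_ctx + max_new_tokens); a non-positive max_new_tokens falls back to max_length, which preserves the old behaviour. A minimal Python sketch of that arithmetic, using a hypothetical helper and sample values that are not part of this commit, mirroring the C++ condition above and the Python one further down:

def effective_limit(max_length: int, max_new_tokens: int, n_ctx: int) -> int:
    """Total token budget: prompt tokens plus at most max_new_tokens new ones,
    never exceeding max_length; max_new_tokens <= 0 disables the per-reply cap."""
    budget = max_new_tokens if max_new_tokens > 0 else max_length
    return min(max_length, n_ctx + budget)

# Hypothetical numbers for illustration: a 100-token prompt, max_length=2048.
assert effective_limit(2048, 64, 100) == 164     # stop after 64 new tokens
assert effective_limit(2048, -1, 100) == 2048    # default: only max_length applies
assert effective_limit(2048, 4000, 100) == 2048  # max_length still wins
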
11 changes: 7 additions & 4 deletions chatglm.h
@@ -814,6 +814,7 @@ class ModelLoader {
 
 struct GenerationConfig {
     int max_length;
+    int max_new_tokens;
     int max_context_length;
     bool do_sample;
     int top_k;
@@ -822,10 +823,12 @@ struct GenerationConfig {
     float repetition_penalty;
     int num_threads;
 
-    GenerationConfig(int max_length = 2048, int max_context_length = 512, bool do_sample = true, int top_k = 0,
-                     float top_p = 0.7, float temperature = 0.95, float repetition_penalty = 1.f, int num_threads = 0)
-        : max_length(max_length), max_context_length(max_context_length), do_sample(do_sample), top_k(top_k),
-          top_p(top_p), temperature(temperature), repetition_penalty(repetition_penalty), num_threads(num_threads) {}
+    GenerationConfig(int max_length = 2048, int max_new_tokens = -1, int max_context_length = 512,
+                     bool do_sample = true, int top_k = 0, float top_p = 0.7, float temperature = 0.95,
+                     float repetition_penalty = 1.f, int num_threads = 0)
+        : max_length(max_length), max_new_tokens(max_new_tokens), max_context_length(max_context_length),
+          do_sample(do_sample), top_k(top_k), top_p(top_p), temperature(temperature),
+          repetition_penalty(repetition_penalty), num_threads(num_threads) {}
 };
 
 int get_num_physical_cores();
3 changes: 2 additions & 1 deletion chatglm_cpp/_C.pyi
@@ -66,12 +66,13 @@ class GenerationConfig:
     do_sample: bool
     max_context_length: int
     max_length: int
+    max_new_tokens: int
     num_threads: int
     repetition_penalty: float
     temperature: float
     top_k: int
     top_p: float
-    def __init__(self, max_length: int = 2048, max_context_length: int = 512, do_sample: bool = True, top_k: int = 0, top_p: float = 0.7, temperature: float = 0.95, repetition_penalty: float = 1.0, num_threads: int = 0) -> None:
+    def __init__(self, max_length: int = 2048, max_new_tokens: int = -1, max_context_length: int = 512, do_sample: bool = True, top_k: int = 0, top_p: float = 0.7, temperature: float = 0.95, repetition_penalty: float = 1.0, num_threads: int = 0) -> None:
         ...
 class InternLM20BForCausalLM(BaseModelForCausalLM):
     pass
7 changes: 6 additions & 1 deletion chatglm_cpp/__init__.py
@@ -47,6 +47,7 @@ def chat(
         messages: List[ChatMessage],
         *,
         max_length: int = 2048,
+        max_new_tokens: int = -1,
         max_context_length: int = 512,
         do_sample: bool = True,
         top_k: int = 0,
@@ -60,6 +61,7 @@
         input_ids = self.tokenizer.encode_messages(messages, max_context_length)
         gen_config = _C.GenerationConfig(
             max_length=max_length,
+            max_new_tokens=max_new_tokens,
             max_context_length=max_context_length,
             do_sample=do_sample,
             top_k=top_k,
@@ -77,6 +79,7 @@ def generate(
         prompt: str,
         *,
         max_length: int = 2048,
+        max_new_tokens: int = -1,
         max_context_length: int = 512,
         do_sample: bool = True,
         top_k: int = 0,
@@ -89,6 +92,7 @@
         input_ids = self.tokenizer.encode(prompt, max_context_length)
         gen_config = _C.GenerationConfig(
             max_length=max_length,
+            max_new_tokens=max_new_tokens,
             max_context_length=max_context_length,
             do_sample=do_sample,
             top_k=top_k,
@@ -105,8 +109,9 @@ def _stream_generate_ids(self, input_ids: List[int], gen_config: _C.GenerationCo
         input_ids = input_ids.copy()
         n_past = 0
         n_ctx = len(input_ids)
+        max_new_tokens = gen_config.max_new_tokens if gen_config.max_new_tokens > 0 else gen_config.max_length
 
-        while len(input_ids) < gen_config.max_length:
+        while len(input_ids) < min(gen_config.max_length, n_ctx + max_new_tokens):
             next_token_id = self.model.generate_next_token(input_ids, gen_config, n_past, n_ctx)
             yield next_token_id
             n_past = len(input_ids)
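
With the new keyword argument plumbed through chat() and generate(), a per-reply cap can be set directly from Python. A minimal usage sketch, assuming the package's Pipeline and ChatMessage classes and a locally converted chatglm-ggml.bin model file (assumptions, not part of this diff):

import chatglm_cpp

# Hypothetical model path; substitute your own converted GGML file.
pipeline = chatglm_cpp.Pipeline("chatglm-ggml.bin")

messages = [chatglm_cpp.ChatMessage(role="user", content="你好")]

# At most 64 new tokens are generated for this reply, regardless of prompt length,
# while max_length still bounds the total (prompt + output) as before.
output = pipeline.chat(messages, max_length=2048, max_new_tokens=64)

# `output` holds the assistant reply; printing it is enough for this sketch.
print(output)
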
7 changes: 4 additions & 3 deletions chatglm_pybind.cpp
@@ -69,10 +69,11 @@ PYBIND11_MODULE(_C, m) {
         .def_property_readonly("model_type_name", &ModelConfig::model_type_name);
 
     py::class_<GenerationConfig>(m, "GenerationConfig")
-        .def(py::init<int, int, bool, int, float, float, float, int>(), "max_length"_a = 2048,
-             "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0, "top_p"_a = 0.7, "temperature"_a = 0.95,
-             "repetition_penalty"_a = 1.0, "num_threads"_a = 0)
+        .def(py::init<int, int, int, bool, int, float, float, float, int>(), "max_length"_a = 2048,
+             "max_new_tokens"_a = -1, "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0,
+             "top_p"_a = 0.7, "temperature"_a = 0.95, "repetition_penalty"_a = 1.0, "num_threads"_a = 0)
         .def_readwrite("max_length", &GenerationConfig::max_length)
+        .def_readwrite("max_new_tokens", &GenerationConfig::max_new_tokens)
         .def_readwrite("max_context_length", &GenerationConfig::max_context_length)
         .def_readwrite("do_sample", &GenerationConfig::do_sample)
         .def_readwrite("top_k", &GenerationConfig::top_k)
10 changes: 7 additions & 3 deletions main.cpp
@@ -28,6 +28,7 @@ struct Args {
     std::string prompt = "你好";
     std::string system = "";
     int max_length = 2048;
+    int max_new_tokens = -1;
     int max_context_length = 512;
     bool interactive = false;
     int top_k = 0;
@@ -44,14 +45,15 @@ static void usage(const std::string &prog) {
              << "options:\n"
              << " -h, --help show this help message and exit\n"
              << " -m, --model PATH model path (default: chatglm-ggml.bin)\n"
-             << " --mode inference mode chose from {chat, generate} (default: chat)\n"
+             << " --mode inference mode chosen from {chat, generate} (default: chat)\n"
              << " --sync synchronized generation without streaming\n"
              << " -p, --prompt PROMPT prompt to start generation with (default: 你好)\n"
              << " --pp, --prompt_path path to the plain text file that stores the prompt\n"
              << " -s, --system SYSTEM system message to set the behavior of the assistant\n"
              << " --sp, --system_path path to the plain text file that stores the system message\n"
              << " -i, --interactive run in interactive mode\n"
              << " -l, --max_length N max total length including prompt and output (default: 2048)\n"
+             << " --max_new_tokens N max number of tokens to generate, ignoring the number of prompt tokens\n"
              << " -c, --max_context_length N\n"
              << " max context length (default: 512)\n"
              << " --top_k N top-k sampling (default: 0)\n"
@@ -97,6 +99,8 @@ static Args parse_args(const std::vector<std::string> &argv) {
            args.interactive = true;
        } else if (arg == "-l" || arg == "--max_length") {
            args.max_length = std::stoi(argv.at(++i));
+       } else if (arg == "--max_new_tokens") {
+           args.max_new_tokens = std::stoi(argv.at(++i));
        } else if (arg == "-c" || arg == "--max_context_length") {
            args.max_context_length = std::stoi(argv.at(++i));
        } else if (arg == "--top_k") {
@@ -179,8 +183,8 @@ static void chat(Args &args) {
    }
    auto streamer = std::make_unique<chatglm::StreamerGroup>(std::move(streamers));
 
-   chatglm::GenerationConfig gen_config(args.max_length, args.max_context_length, args.temp > 0, args.top_k,
-                                        args.top_p, args.temp, args.repeat_penalty, args.num_threads);
+   chatglm::GenerationConfig gen_config(args.max_length, args.max_new_tokens, args.max_context_length, args.temp > 0,
+                                        args.top_k, args.top_p, args.temp, args.repeat_penalty, args.num_threads);
 
    if (args.verbose) {
        std::cout << "system info: | "
