diff --git a/chatglm.cpp b/chatglm.cpp
index f4322db..c846b74 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -905,7 +905,7 @@ void BaseModelForCausalLM::sampling_softmax_inplace(TokenIdScore *first, TokenId
 std::vector<int> BaseModelForCausalLM::generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
                                                 BaseStreamer *streamer) {
     CHATGLM_CHECK(gen_config.max_length <= config.max_length)
-        << "requested max_length (" << gen_config.max_length << ") is larger than model's max_length ("
+        << "Requested max_length (" << gen_config.max_length << ") exceeds pre-configured model max_length ("
         << config.max_length << ")";
 
     std::vector<int> output_ids;
@@ -1700,7 +1700,16 @@ StateDict InternLMForCausalLM::state_dict() const {
 
 // ===== pipeline =====
 
-Pipeline::Pipeline(const std::string &path) {
+Pipeline::Pipeline(const std::string &path, int max_length) {
+    auto _update_config_max_length = [](ModelConfig &config, int max_length) {
+        if (max_length > 0) {
+            CHATGLM_CHECK(max_length <= config.max_length)
+                << "Requested max_length (" << max_length << ") exceeds the max possible model sequence length ("
+                << config.max_length << ")";
+            config.max_length = max_length;
+        }
+    };
+
     mapped_file = std::make_unique<MappedFile>(path);
     ModelLoader loader(mapped_file->data, mapped_file->size);
 
@@ -1718,6 +1727,7 @@ Pipeline::Pipeline(const std::string &path) {
         // load config
         ModelConfig config(model_type, loader.read_basic(), 1e-5f, ActivationType::GELU, true, true, true, false,
                            RopeType::CHATGLM, -1, AttentionMaskType::CHATGLM);
+        _update_config_max_length(config, max_length);
 
         // load tokenizer
         int proto_size = loader.read_basic<int>();
@@ -1734,6 +1744,7 @@ Pipeline::Pipeline(const std::string &path) {
         // load config
         ModelConfig config(model_type, loader.read_basic(), 1e-5f, ActivationType::SILU, true, false, false, false,
                            RopeType::GPTJ, 2, AttentionMaskType::CAUSAL);
+        _update_config_max_length(config, max_length);
 
         // load tokenizer
         int proto_size = loader.read_basic<int>();
@@ -1758,6 +1769,7 @@ Pipeline::Pipeline(const std::string &path) {
         // load config
         ModelConfig config(model_type, loader.read_basic(), 1e-6f, ActivationType::SILU, false, false, false, false,
                            RopeType::NEOX, 1, AttentionMaskType::CAUSAL);
+        _update_config_max_length(config, max_length);
 
         // load tokenizer
         int proto_size = loader.read_basic<int>();
@@ -1774,6 +1786,7 @@ Pipeline::Pipeline(const std::string &path) {
         // load config
         ModelConfig config(model_type, loader.read_basic(), 1e-6f, ActivationType::SILU, false, false, false, true,
                            RopeType::DISABLED, -1, AttentionMaskType::CAUSAL);
+        _update_config_max_length(config, max_length);
 
         // load tokenizer
         int proto_size = loader.read_basic<int>();
@@ -1797,6 +1810,7 @@ Pipeline::Pipeline(const std::string &path) {
             config = ModelConfig(model_type, rec, 1e-6f, ActivationType::SILU, false, false, false, false,
                                  RopeType::NEOX, 1, AttentionMaskType::CAUSAL);
         }
+        _update_config_max_length(config, max_length);
 
         // load tokenizer
         int proto_size = loader.read_basic<int>();
diff --git a/chatglm.h b/chatglm.h
index a76b938..91a6edc 100644
--- a/chatglm.h
+++ b/chatglm.h
@@ -1065,7 +1065,7 @@ class InternLMForCausalLM : public BasicModelForCausalLM {
 
 class Pipeline {
   public:
-    Pipeline(const std::string &path);
+    Pipeline(const std::string &path, int max_length = -1);
 
     std::vector<int> generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
                               BaseStreamer *streamer = nullptr) const;
diff --git a/chatglm_cpp/_C.pyi b/chatglm_cpp/_C.pyi
index 3ebad41..16a0a49 100644
--- a/chatglm_cpp/_C.pyi
+++ b/chatglm_cpp/_C.pyi
@@ -174,7 +174,7 @@ class ModelType:
     def value(self) -> int:
         ...
 class Pipeline:
-    def __init__(self, path: str) -> None:
+    def __init__(self, path: str, max_length: int = -1) -> None:
         ...
     @property
     def model(self) -> BaseModelForCausalLM:
diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py
index 72fbca2..92d9b72 100644
--- a/chatglm_cpp/__init__.py
+++ b/chatglm_cpp/__init__.py
@@ -27,10 +27,14 @@ def _ensure_chat_message(message: Union[ChatMessage, Dict[str, Any]]) -> ChatMes
 
 
 class Pipeline(_C.Pipeline):
-    def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None:
+    def __init__(self, model_path: str, *, max_length: Optional[int] = None, dtype: Optional[str] = None) -> None:
+        kwargs = {}
+        if max_length is not None:
+            kwargs.update(max_length=max_length)
+
         if Path(model_path).is_file():
             # load ggml model
-            super().__init__(str(model_path))
+            super().__init__(str(model_path), **kwargs)
         else:
             # convert hf model to ggml format
             from chatglm_cpp.convert import convert
@@ -40,7 +44,7 @@ def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None:
 
             with tempfile.NamedTemporaryFile("wb") as f:
                 convert(f, model_path, dtype=dtype)
-                super().__init__(f.name)
+                super().__init__(f.name, **kwargs)
 
     def chat(
         self,
diff --git a/chatglm_pybind.cpp b/chatglm_pybind.cpp
index e20c367..8a191f7 100644
--- a/chatglm_pybind.cpp
+++ b/chatglm_pybind.cpp
@@ -160,7 +160,7 @@ PYBIND11_MODULE(_C, m) {
     // ===== Pipeline ====
 
     py::class_<Pipeline>(m, "Pipeline")
-        .def(py::init<const std::string &>(), "path"_a)
+        .def(py::init<const std::string &, int>(), "path"_a, "max_length"_a = -1)
        .def_property_readonly("model", [](const Pipeline &self) { return self.model.get(); })
        .def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer.get(); });
 }
diff --git a/main.cpp b/main.cpp
index f68e0e7..a909746 100644
--- a/main.cpp
+++ b/main.cpp
@@ -171,7 +171,7 @@ static inline void print_message(const chatglm::ChatMessage &message) {
 static void chat(Args &args) {
     ggml_time_init();
     int64_t start_load_us = ggml_time_us();
-    chatglm::Pipeline pipeline(args.model_path);
+    chatglm::Pipeline pipeline(args.model_path, args.max_length);
     int64_t end_load_us = ggml_time_us();
 
     std::string model_name = pipeline.model->config.model_type_name();
diff --git a/tests/test_chatglm_cpp.py b/tests/test_chatglm_cpp.py
index ba36cf5..45386bb 100644
--- a/tests/test_chatglm_cpp.py
+++ b/tests/test_chatglm_cpp.py
@@ -35,6 +35,19 @@ def check_pipeline(model_path, prompt, target, gen_kwargs={}):
     assert stream_output == target
 
 
+@pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found")
+def test_pipeline_options():
+    # check max_length option
+    pipeline = chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH)
+    assert pipeline.model.config.max_length == 2048
+    pipeline = chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH, max_length=234)
+    assert pipeline.model.config.max_length == 234
+
+    # check if resources are properly released
+    for _ in range(100):
+        chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH)
+
+
 @pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found")
 def test_chatglm_pipeline():
     check_pipeline(
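
Usage sketch for the new option, as exposed through the Python wrapper above. The model path below is a placeholder for any converted GGML model, and 512 is an arbitrary example value; max_length must not exceed the max_length stored in the model file, or the CHATGLM_CHECK in the new lambda fails.

    import chatglm_cpp

    # Cap the configured sequence length at load time; omitting max_length
    # (or passing None from Python) keeps the value baked into the GGML file.
    pipeline = chatglm_cpp.Pipeline("models/chatglm3-ggml.bin", max_length=512)
    assert pipeline.model.config.max_length == 512

The CLI follows the same path: main.cpp now forwards args.max_length into the Pipeline constructor, so the existing max_length argument also bounds the loaded model config rather than only the generation length.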