Allocate kv-cache on demand #277

Merged (1 commit, Mar 12, 2024)
18 changes: 16 additions & 2 deletions chatglm.cpp
@@ -905,7 +905,7 @@ void BaseModelForCausalLM::sampling_softmax_inplace(TokenIdScore *first, TokenId
std::vector<int> BaseModelForCausalLM::generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
BaseStreamer *streamer) {
CHATGLM_CHECK(gen_config.max_length <= config.max_length)
<< "requested max_length (" << gen_config.max_length << ") is larger than model's max_length ("
<< "Requested max_length (" << gen_config.max_length << ") exceeds pre-configured model max_length ("
<< config.max_length << ")";

std::vector<int> output_ids;
@@ -1700,7 +1700,16 @@ StateDict InternLMForCausalLM::state_dict() const {

// ===== pipeline =====

- Pipeline::Pipeline(const std::string &path) {
+ Pipeline::Pipeline(const std::string &path, int max_length) {
+ auto _update_config_max_length = [](ModelConfig &config, int max_length) {
+     if (max_length > 0) {
+         CHATGLM_CHECK(max_length <= config.max_length)
+             << "Requested max_length (" << max_length << ") exceeds the max possible model sequence length ("
+             << config.max_length << ")";
+         config.max_length = max_length;
+     }
+ };
+
mapped_file = std::make_unique<MappedFile>(path);
ModelLoader loader(mapped_file->data, mapped_file->size);

@@ -1718,6 +1727,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>(), 1e-5f, ActivationType::GELU, true, true,
true, false, RopeType::CHATGLM, -1, AttentionMaskType::CHATGLM);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1734,6 +1744,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV2>(), 1e-5f, ActivationType::SILU, true, false,
false, false, RopeType::GPTJ, 2, AttentionMaskType::CAUSAL);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1758,6 +1769,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>(), 1e-6f, ActivationType::SILU, false, false,
false, false, RopeType::NEOX, 1, AttentionMaskType::CAUSAL);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1774,6 +1786,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>(), 1e-6f, ActivationType::SILU, false, false,
false, true, RopeType::DISABLED, -1, AttentionMaskType::CAUSAL);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1797,6 +1810,7 @@ Pipeline::Pipeline(const std::string &path) {
config = ModelConfig(model_type, rec, 1e-6f, ActivationType::SILU, false, false, false, false,
RopeType::NEOX, 1, AttentionMaskType::CAUSAL);
}
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
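The new _update_config_max_length helper only overrides the limit loaded from the GGML file when a positive max_length is passed, and it rejects anything larger than that pre-configured limit, so a smaller kv-cache can be requested but never a larger one. A minimal Python sketch of the same clamping rule (the function name and exception are illustrative, not part of the library):

def clamp_max_length(config_max_length: int, requested: int = -1) -> int:
    # -1 (or any non-positive value) means "keep the limit stored in the model file"
    if requested <= 0:
        return config_max_length
    if requested > config_max_length:
        raise ValueError(
            f"Requested max_length ({requested}) exceeds the max possible "
            f"model sequence length ({config_max_length})"
        )
    return requested

# clamp_max_length(2048) -> 2048; clamp_max_length(2048, 234) -> 234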
2 changes: 1 addition & 1 deletion chatglm.h
@@ -1065,7 +1065,7 @@ class InternLMForCausalLM : public BasicModelForCausalLM<InternLMModel> {

class Pipeline {
public:
- Pipeline(const std::string &path);
+ Pipeline(const std::string &path, int max_length = -1);

std::vector<int> generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
BaseStreamer *streamer = nullptr) const;
2 changes: 1 addition & 1 deletion chatglm_cpp/_C.pyi
@@ -174,7 +174,7 @@ class ModelType:
def value(self) -> int:
...
class Pipeline:
- def __init__(self, path: str) -> None:
+ def __init__(self, path: str, max_length: int = -1) -> None:
...
@property
def model(self) -> BaseModelForCausalLM:
10 changes: 7 additions & 3 deletions chatglm_cpp/__init__.py
@@ -27,10 +27,14 @@ def _ensure_chat_message(message: Union[ChatMessage, Dict[str, Any]]) -> ChatMes


class Pipeline(_C.Pipeline):
- def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None:
+ def __init__(self, model_path: str, *, max_length: Optional[int] = None, dtype: Optional[str] = None) -> None:
+     kwargs = {}
+     if max_length is not None:
+         kwargs.update(max_length=max_length)
+
if Path(model_path).is_file():
# load ggml model
- super().__init__(str(model_path))
+ super().__init__(str(model_path), **kwargs)
else:
# convert hf model to ggml format
from chatglm_cpp.convert import convert
@@ -40,7 +44,7 @@ def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None:

with tempfile.NamedTemporaryFile("wb") as f:
convert(f, model_path, dtype=dtype)
- super().__init__(f.name)
+ super().__init__(f.name, **kwargs)

def chat(
self,
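With the wrapper change above, the Python API can cap the sequence length (and therefore the kv-cache allocation) at load time; leaving max_length unset keeps the value stored in the converted model. A short usage sketch (the model path is a placeholder):

import chatglm_cpp

# Load with a reduced context window to shrink the kv-cache allocation.
pipeline = chatglm_cpp.Pipeline("./models/chatglm-ggml.bin", max_length=512)
print(pipeline.model.config.max_length)  # 512

# Omit max_length to keep whatever limit was baked in at conversion time.
default_pipeline = chatglm_cpp.Pipeline("./models/chatglm-ggml.bin")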
2 changes: 1 addition & 1 deletion chatglm_pybind.cpp
@@ -160,7 +160,7 @@ PYBIND11_MODULE(_C, m) {
// ===== Pipeline ====

py::class_<Pipeline>(m, "Pipeline")
.def(py::init<const std::string &>(), "path"_a)
.def(py::init<const std::string &, int>(), "path"_a, "max_length"_a = -1)
.def_property_readonly("model", [](const Pipeline &self) { return self.model.get(); })
.def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer.get(); });
}
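Note that the two layers use different sentinels for "no override": the raw _C.Pipeline binding defaults max_length to -1, while the high-level wrapper takes None and simply does not forward the argument. Both calls in the sketch below therefore keep the model's own limit (paths are placeholders):

import chatglm_cpp
from chatglm_cpp import _C

_C.Pipeline("./models/chatglm-ggml.bin")           # raw binding, max_length defaults to -1
chatglm_cpp.Pipeline("./models/chatglm-ggml.bin")  # wrapper, max_length=None is not forwarded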
2 changes: 1 addition & 1 deletion main.cpp
@@ -171,7 +171,7 @@ static inline void print_message(const chatglm::ChatMessage &message) {
static void chat(Args &args) {
ggml_time_init();
int64_t start_load_us = ggml_time_us();
- chatglm::Pipeline pipeline(args.model_path);
+ chatglm::Pipeline pipeline(args.model_path, args.max_length);
int64_t end_load_us = ggml_time_us();

std::string model_name = pipeline.model->config.model_type_name();
13 changes: 13 additions & 0 deletions tests/test_chatglm_cpp.py
@@ -35,6 +35,19 @@ def check_pipeline(model_path, prompt, target, gen_kwargs={}):
assert stream_output == target


+ @pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found")
+ def test_pipeline_options():
+     # check max_length option
+     pipeline = chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH)
+     assert pipeline.model.config.max_length == 2048
+     pipeline = chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH, max_length=234)
+     assert pipeline.model.config.max_length == 234
+
+     # check if resources are properly released
+     for _ in range(100):
+         chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH)
+
+
@pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found")
def test_chatglm_pipeline():
check_pipeline(