
Commit

metal
li-plus committed Jun 20, 2024
1 parent ffa0d77 commit 4a43a86
Showing 3 changed files with 39 additions and 60 deletions.
72 changes: 36 additions & 36 deletions chatglm.cpp
@@ -43,6 +43,10 @@
#include <ggml-cuda.h>
#endif

#ifdef GGML_USE_METAL
#include <ggml-metal.h>
#endif

namespace chatglm {

static std::string shape_to_string(ggml_tensor *tensor) {
@@ -179,27 +183,6 @@ std::vector<ChatMessage> BaseTokenizer::filter_user_assistant_messages(const std
return out;
}

// void ModelContext::init_device_context() {
// #ifdef GGML_USE_METAL
// ctx_metal = make_unique_ggml_metal_context(1);

// const size_t max_size = ggml_get_max_tensor_size(ctx_w.get());

// void *weight_data = weight_buffer.empty() ? ggml_get_mem_buffer(ctx_w.get()) : (void *)weight_buffer.data();
// size_t weight_size = weight_buffer.empty() ? ggml_get_mem_size(ctx_w.get()) : weight_buffer.size();
// CHATGLM_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "weights", weight_data, weight_size, max_size));

// CHATGLM_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "kv", ggml_get_mem_buffer(ctx_kv.get()),
// ggml_get_mem_size(ctx_kv.get()), 0));

// void *compute_data = ctx_b ? ggml_get_mem_buffer(ctx_b.get()) : compute_meta.data();
// size_t compute_size = ctx_b ? ggml_get_mem_size(ctx_b.get()) : compute_meta.size();
// CHATGLM_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "compute", compute_data, compute_size, 0));

// CHATGLM_CHECK(ggml_metal_add_buffer(ctx_metal.get(), "scratch", scratch.data, scratch.size, 0));
// #endif
// }

// ===== streamer =====

void StreamerGroup::put(const std::vector<int> &output_ids) {
@@ -602,17 +585,16 @@ static ggml_tensor *apply_rotary_emb_glm2(ModelContext *mctx, ggml_tensor *layer
ggml_tensor *half_layer_view =
ggml_view_3d(ctx, layer, rope_dim, layer->ne[1], layer->ne[2], layer->nb[1], layer->nb[2], 0);

// TODO: metal
ggml_tensor *half_layer = half_layer_view;
if (!ggml_backend_is_cpu(mctx->backend.get())) {
half_layer = ggml_cont(ctx, half_layer);
}
#ifdef GGML_USE_CUDA
half_layer = ggml_cont(ctx, half_layer);
#endif
ggml_tensor *roped_half_layer =
ggml_rope_ext_inplace(ctx, half_layer, position_ids, nullptr, rope_dim, (int)RopeType::GPTJ, 0, rope_theta,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f); // [s, #h, d]
if (!ggml_backend_is_cpu(mctx->backend.get())) {
roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
}
#ifdef GGML_USE_CUDA
roped_half_layer = ggml_cpy(ctx, roped_half_layer, half_layer_view);
#endif
ggml_build_forward_expand(mctx->gf, roped_half_layer);

return layer;
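
Note on the hunk above: apply_rotary_emb_glm2 ropes only the first rope_dim channels of each head through a strided view, and this commit narrows the contiguity workaround from a runtime "backend is not CPU" check to a compile-time GGML_USE_CUDA guard, so Metal now applies the in-place rope directly on the view. The sketch below restates that pattern with explanatory comments; it is not part of the commit, the helper name and free-standing signature are placeholders, and only the ggml calls mirror the file.

// Illustrative sketch (not part of the commit): partial RoPE on the first
// rope_dim channels, with the compile-time CUDA guard introduced above.
#include <ggml.h>

static ggml_tensor *rope_partial_inplace(ggml_context *ctx, ggml_cgraph *gf, ggml_tensor *layer,
                                         ggml_tensor *position_ids, int rope_dim, float rope_theta) {
    // strided view over the first rope_dim elements of each head: [rope_dim, #h, s]
    ggml_tensor *half_view =
        ggml_view_3d(ctx, layer, rope_dim, layer->ne[1], layer->ne[2], layer->nb[1], layer->nb[2], 0);

    ggml_tensor *half = half_view;
#ifdef GGML_USE_CUDA
    // the CUDA path does not rope this strided view in place, so rope a contiguous copy ...
    half = ggml_cont(ctx, half);
#endif
    // GPT-J style rope (mode 0, matching RopeType::GPTJ in chatglm.cpp), base frequency rope_theta
    ggml_tensor *roped = ggml_rope_ext_inplace(ctx, half, position_ids, nullptr, rope_dim, 0, 0,
                                               rope_theta, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
#ifdef GGML_USE_CUDA
    // ... and write the result back into the original view
    roped = ggml_cpy(ctx, roped, half_view);
#endif
    ggml_build_forward_expand(gf, roped);
    return layer; // channels beyond rope_dim are left un-rotated
}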
@@ -1106,11 +1088,19 @@ ggml_tensor *GLMBlock::forward(ModelContext *mctx, ggml_tensor *hidden_states, g
ChatGLMForCausalLM::ChatGLMForCausalLM(const ModelConfig &config) : BasicModelForCausalLM(config) {}

void ChatGLMForCausalLM::load_state_dict(const StateDict &sd) {
// TODO: handle metal
void *sd_buf_base = ggml_backend_buffer_get_base(sd.buf.get());
const size_t sd_buf_size = ggml_backend_buffer_get_size(sd.buf.get());
if (ggml_backend_is_cpu(mctx_->backend.get())) {
mctx_->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(
ggml_backend_buffer_get_base(sd.buf.get()), ggml_backend_buffer_get_size(sd.buf.get())));
} else {
mctx_->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(sd_buf_base, sd_buf_size));
}
#ifdef GGML_USE_METAL
else if (ggml_backend_is_metal(mctx_->backend.get())) {
const size_t max_size = ggml_get_max_tensor_size(mctx_->ctx_w.get());
mctx_->buf_w =
unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
}
#endif
else {
mctx_->buf_w =
unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx_->ctx_w.get(), mctx_->backend.get()));
}
@@ -1120,7 +1110,7 @@ void ChatGLMForCausalLM::load_state_dict(const StateDict &sd) {
const std::string &name = item.first;
ggml_tensor *self_weight = item.second;
ggml_tensor *ckpt_weight = sd.kv.at(name);
if (ggml_backend_is_cpu(mctx_->backend.get())) {
if (ggml_backend_is_cpu(mctx_->backend.get()) || ggml_cpu_has_metal()) {
ggml_backend_tensor_alloc(mctx_->buf_w.get(), self_weight, ckpt_weight->data);
} else {
ggml_backend_tensor_set(self_weight, ckpt_weight->data, 0, ggml_nbytes(self_weight));
@@ -1256,10 +1246,20 @@ bool ChatGLM2Tokenizer::is_special_id(int id) const {
ChatGLM2ForCausalLM::ChatGLM2ForCausalLM(const ModelConfig &config) : BasicModelForCausalLM(config) {}

void ChatGLM2ForCausalLM::load_state_dict(const StateDict &sd) {
void *sd_buf_base = ggml_backend_buffer_get_base(sd.buf.get());
const size_t sd_buf_size = ggml_backend_buffer_get_size(sd.buf.get());
if (ggml_backend_is_cpu(mctx_->backend.get())) {
mctx_->buf_w = unique_ggml_backend_buffer_t(ggml_backend_cpu_buffer_from_ptr(
ggml_backend_buffer_get_base(sd.buf.get()), ggml_backend_buffer_get_size(sd.buf.get())));
} else {
}
#ifdef GGML_USE_METAL
else if (ggml_backend_is_metal(mctx_->backend.get())) {
const size_t max_size = ggml_get_max_tensor_size(mctx_->ctx_w.get());
mctx_->buf_w =
unique_ggml_backend_buffer_t(ggml_backend_metal_buffer_from_ptr(sd_buf_base, sd_buf_size, max_size));
}
#endif
else {
mctx_->buf_w =
unique_ggml_backend_buffer_t(ggml_backend_alloc_ctx_tensors(mctx_->ctx_w.get(), mctx_->backend.get()));
}
@@ -1289,7 +1289,7 @@ void ChatGLM2ForCausalLM::load_state_dict(const StateDict &sd) {

CHATGLM_CHECK(ggml_nbytes(ckpt_weight) == ggml_nbytes(gate_proj) + ggml_nbytes(up_proj));

if (ggml_backend_is_cpu(mctx_->backend.get())) {
if (ggml_backend_is_cpu(mctx_->backend.get()) || ggml_cpu_has_metal()) {
ggml_backend_tensor_alloc(mctx_->buf_w.get(), gate_proj, ckpt_weight->data);
ggml_backend_tensor_alloc(mctx_->buf_w.get(), up_proj,
(char *)ckpt_weight->data + ggml_nbytes(gate_proj));
@@ -1301,7 +1301,7 @@
} else {
// normal weight
ggml_tensor *self_weight = self_sd.kv.at(name);
if (ggml_backend_is_cpu(mctx_->backend.get())) {
if (ggml_backend_is_cpu(mctx_->backend.get()) || ggml_cpu_has_metal()) {
ggml_backend_tensor_alloc(mctx_->buf_w.get(), self_weight, ckpt_weight->data);
} else {
ggml_backend_tensor_set(self_weight, ckpt_weight->data, 0, ggml_nbytes(self_weight));
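Note: both load_state_dict hunks above (ChatGLMForCausalLM and ChatGLM2ForCausalLM) now pick the weight buffer based on the active backend. Below is a condensed sketch of that dispatch; it is not part of the commit, the free-standing helper is illustrative, and it assumes the same ggml headers and GGML_USE_METAL build flag used by chatglm.cpp (the real code wraps the result in unique_ggml_backend_buffer_t).

// Illustrative sketch (not part of the commit): weight-buffer selection per backend.
#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
#ifdef GGML_USE_METAL
#include <ggml-metal.h>
#endif

static ggml_backend_buffer_t make_weight_buffer(ggml_backend_t backend, ggml_context *ctx_w,
                                                void *sd_base, size_t sd_size) {
    if (ggml_backend_is_cpu(backend)) {
        // CPU: wrap the memory-mapped state dict in place; no copy is made
        return ggml_backend_cpu_buffer_from_ptr(sd_base, sd_size);
    }
#ifdef GGML_USE_METAL
    if (ggml_backend_is_metal(backend)) {
        // Metal: unified memory lets the GPU map the same host pages in place;
        // the largest tensor size is passed so no tensor straddles a mapping split
        const size_t max_size = ggml_get_max_tensor_size(ctx_w);
        return ggml_backend_metal_buffer_from_ptr(sd_base, sd_size, max_size);
    }
#endif
    // other backends (e.g. CUDA): allocate device memory; weights are then copied
    // tensor by tensor with ggml_backend_tensor_set
    return ggml_backend_alloc_ctx_tensors(ctx_w, backend);
}

This is also why the loading loops above check ggml_backend_is_cpu(...) || ggml_cpu_has_metal(): with CPU or Metal, ggml_backend_tensor_alloc merely points each tensor at its bytes inside the mapped checkpoint, while other backends upload the data via ggml_backend_tensor_set.
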
19 changes: 1 addition & 18 deletions chatglm.h
@@ -9,10 +9,6 @@
#include <sstream>
#include <unordered_map>

// #ifdef GGML_USE_METAL
// #include <ggml-metal.h>
// #endif

namespace chatglm {

// ===== common =====
@@ -303,24 +299,11 @@ struct ggml_backend_buffer_deleter_t {

using unique_ggml_backend_buffer_t = std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter_t>;

#ifdef GGML_USE_METAL
struct ggml_metal_context_deleter_t {
void operator()(ggml_metal_context *ctx) const noexcept { ggml_metal_free(ctx); }
};

using unique_ggml_metal_context_t = std::unique_ptr<ggml_metal_context, ggml_metal_context_deleter_t>;

static inline unique_ggml_metal_context_t make_unique_ggml_metal_context(int n_cb) {
return unique_ggml_metal_context_t(ggml_metal_init(n_cb));
}
#endif

// reference: https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp
template <typename T>
struct no_init {
T value;
no_init() { /* do nothing */
}
no_init() { /* do nothing */ }
};

struct ModelContext {
8 changes: 2 additions & 6 deletions chatglm_test.cpp
@@ -958,15 +958,11 @@ static void check_chat_format(const Pipeline &pipeline) {
GenerationConfig gen_config;
gen_config.max_new_tokens = 1;
EXPECT_THROW(
{
pipeline.chat({{ChatMessage::ROLE_USER, "user"}, {ChatMessage::ROLE_USER, "user"}}, gen_config);
},
{ pipeline.chat({{ChatMessage::ROLE_USER, "user"}, {ChatMessage::ROLE_USER, "user"}}, gen_config); },
std::runtime_error);
EXPECT_THROW({ pipeline.chat({{ChatMessage::ROLE_ASSISTANT, "assistant"}}, gen_config); }, std::runtime_error);
EXPECT_THROW(
{
pipeline.chat({{ChatMessage::ROLE_USER, "user"}, {ChatMessage::ROLE_ASSISTANT, "assistant"}}, gen_config);
},
{ pipeline.chat({{ChatMessage::ROLE_USER, "user"}, {ChatMessage::ROLE_ASSISTANT, "assistant"}}, gen_config); },
std::runtime_error);
// never throw with system prompt
pipeline.chat({{ChatMessage::ROLE_SYSTEM, "system"}, {ChatMessage::ROLE_USER, "user"}}, gen_config);
