From d1118e9f51c71e04a149c044c8d77df5a20e5d2d Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Bannier Date: Mon, 14 Oct 2024 13:55:50 +0200 Subject: [PATCH 1/6] working but needs to fix LSTM --- decoder.h | 10 +- encodec.cpp | 265 ++++++++++++++++++++-------------------------------- encoder.h | 10 +- ggml | 2 +- lstm.h | 45 ++++----- ops.cpp | 27 +----- ops.h | 2 - quantizer.h | 23 ++--- 8 files changed, 143 insertions(+), 241 deletions(-) diff --git a/decoder.h b/decoder.h index e2a61e4..7f37544 100644 --- a/decoder.h +++ b/decoder.h @@ -41,7 +41,7 @@ struct encodec_decoder { }; struct ggml_tensor *encodec_forward_decoder( - const struct encodec_decoder *decoder, struct ggml_allocr *allocr, struct ggml_context *ctx0, + const struct encodec_decoder *decoder, struct ggml_context *ctx0, struct ggml_tensor *quantized_out, const int *ratios, const int kernel_size, const int res_kernel_size, const int stride) { @@ -60,14 +60,14 @@ struct ggml_tensor *encodec_forward_decoder( const encodec_lstm lstm = decoder->lstm; // first lstm layer + char l0_prefix[7] = "dec_l0"; struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( - ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w, - lstm.l0_ih_b, lstm.l0_hh_b); + ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, l0_prefix); // second lstm layer + char l1_prefix[7] = "dec_l1"; struct ggml_tensor *out = forward_pass_lstm_unilayer( - ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w, - lstm.l1_ih_b, lstm.l1_hh_b); + ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, l1_prefix); inpL = ggml_add(ctx0, inpL, out); } diff --git a/encodec.cpp b/encodec.cpp index c8aa00d..e770f12 100644 --- a/encodec.cpp +++ b/encodec.cpp @@ -102,8 +102,8 @@ struct encodec_context { // buffer for model evaluation ggml_backend_buffer_t buf_compute; - // custom allocrator - struct ggml_allocr *allocr = NULL; + // tensor graph allocator + ggml_gallocr_t allocr = NULL; // intermediate steps struct ggml_tensor *encoded = NULL; // Encoded audio @@ -171,78 +171,10 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int auto &ctx = model.ctx; - size_t buffer_size = 0; - size_t n_tensors = 0; - - // Evaluating context size - { - const auto &hparams = model.hparams; - - const int in_channels = hparams.in_channels; - const int hidden_dim = hparams.hidden_dim; - const int n_filters = hparams.n_filters; - const int kernel_size = hparams.kernel_size; - const int res_kernel_sz = hparams.residual_kernel_size; - const int n_bins = hparams.n_bins; - const int *ratios = hparams.ratios; - const int n_lstm_layers = hparams.n_lstm_layers; - - // encoder - { - int mult = 1; // scaling factor for hidden size - - // initial conv1d layer - buffer_size += in_channels * n_filters * kernel_size * ggml_type_size(wtype); // weight - buffer_size += n_filters * ggml_type_size(GGML_TYPE_F32); // bias - - // resnet blocks - for (int i = 0; i < 4; i++) { - // conv1 - buffer_size += res_kernel_sz * (mult * n_filters) * (mult * n_filters / 2) * ggml_type_size(wtype); // weight - buffer_size += (mult * n_filters / 2) * ggml_type_size(GGML_TYPE_F32); // bias - - // conv2 - buffer_size += (mult * n_filters / 2) * (mult * n_filters) * ggml_type_size(wtype); // weight - buffer_size += (mult * n_filters) * ggml_type_size(GGML_TYPE_F32); // bias - - // shortcut - buffer_size += (mult * n_filters) * (mult * n_filters) * ggml_type_size(wtype); // weight - buffer_size += (mult * n_filters) * ggml_type_size(GGML_TYPE_F32); // bias - - // downsampling layers - buffer_size += (2 
* ratios[3 - i]) * (mult * n_filters) * (mult * n_filters * 2) * ggml_type_size(wtype);  // weight
-                buffer_size += (2 * mult * n_filters) * ggml_type_size(GGML_TYPE_F32);  // bias
-
-                mult *= 2;
-            }
-
-            // lstm
-            buffer_size += 2 * n_lstm_layers * (mult * n_filters) * (4 * mult * n_filters) * ggml_type_size(wtype);  // weight_ih and weight_hh
-            buffer_size += 2 * n_lstm_layers * (4 * mult * n_filters) * ggml_type_size(GGML_TYPE_F32);  // bias_ih and bias_hh
-
-            // final conv
-            buffer_size += kernel_size * (mult * n_filters) * hidden_dim * ggml_type_size(wtype);  // weight
-            buffer_size += hidden_dim * ggml_type_size(GGML_TYPE_F32);  // bias
-        }
-
-        // decoder mirrors the encoder (same number of parameters), just double context size
-        buffer_size *= 2;
-
-        // quantizer
-        int n_q = 32;  // 32 is an upper bound on the number of codebooks.
-        buffer_size += n_q * hidden_dim * n_bins * ggml_type_size(GGML_TYPE_F32);  // embed
-
-        buffer_size += 10ull * MB;  // object overhead
-
-        n_tensors = ((4 * 2) * 4 + 2 + 4 * n_lstm_layers + 2) * 2;  // encoder and decoder
-        n_tensors += n_q * 1;  // quantizer
-
-        // printf("%s: ggml tensor size = %d bytes\n", __func__, (int)sizeof(ggml_tensor));
-        // printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size / (1024.0 * 1024.0));
-    }
-
     // create the ggml context
     {
+        size_t n_tensors = ((4 * 2) * 4 + 2 + 4 * model.hparams.n_lstm_layers + 2) * 2;  // encoder and decoder
+        n_tensors += model.hparams.n_q * 1;  // quantizer
         struct ggml_init_params params = {
             /* .mem_size   = */ ggml_tensor_overhead() * n_tensors,
             /* .mem_buffer = */ NULL,
@@ -288,21 +220,18 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
         return false;
     }
 
-    // allocate weights buffer
-    model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);
-
-    // prepare memory for the weights
+    // create the tensors for the model
     {
-        const auto &hparams = model.hparams;
+        const auto & hparams = model.hparams;
 
-        const int in_channels = hparams.in_channels;
-        const int hidden_dim = hparams.hidden_dim;
-        const int n_filters = hparams.n_filters;
-        const int kernel_size = hparams.kernel_size;
+        const int in_channels   = hparams.in_channels;
+        const int hidden_dim    = hparams.hidden_dim;
+        const int n_filters     = hparams.n_filters;
+        const int kernel_size   = hparams.kernel_size;
         const int res_kernel_sz = hparams.residual_kernel_size;
-        const int n_q = hparams.n_q;
-        const int *ratios = hparams.ratios;
-        const int n_bins = hparams.n_bins;
+        const int n_q           = hparams.n_q;
+        const int *ratios       = hparams.ratios;
+        const int n_bins        = hparams.n_bins;
 
         // encoder
         {
@@ -469,10 +398,11 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
         }
     }
 
+    // allocate the model tensors in a backend buffer
+    model.buffer_w = ggml_backend_alloc_ctx_tensors(ctx, model.backend);
+
     // load weights
     {
-        ggml_allocr *alloc = ggml_allocr_new_from_buffer(model.buffer_w);
-
         size_t total_size = 0;
         model.n_loaded = 0;
 
@@ -529,14 +459,8 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
             return false;
         }
 
-        ggml_allocr_alloc(alloc, tensor);
-
-        if (ggml_backend_is_cpu(model.backend)
-#ifdef GGML_USE_METAL
-            || ggml_backend_is_metal(model.backend)
-#endif
-        ) {
-            // for the CPU and Metal backends, we can read directly into the device memory
+        if (ggml_backend_buffer_is_host(model.buffer_w)) {
+            // for some backends such as CPU and Metal, the tensor data is in system memory and we can read directly into it
             infile.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
        } else {
            // read into a temporary buffer first, then copy to device memory
@@ -549,7 +473,6 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
             model.n_loaded++;
         }
 
-        ggml_allocr_free(alloc);
-
         printf("%s: model size = %8.2f MB\n", __func__, total_size / 1024.0 / 1024.0);
     }
 
@@ -581,13 +504,13 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx,
 
     // since we are using ggml-alloc, this buffer only needs enough space to hold the
     // ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead() * GGML_MAX_NODES + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params ggml_params = {
-        /*.mem_size =*/buf_size,
-        /*.mem_buffer =*/buf.data(),
-        /*.no_alloc =*/true, // skip allocating as we use ggml_alloc to allocate exact memory requirements
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
     };
 
     struct ggml_context *ctx0 = ggml_init(ggml_params);
@@ -595,38 +518,34 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx,
     struct ggml_cgraph *gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor *inp = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_samples);
-    ggml_allocr_alloc(allocr, inp);
-
-    // avoid writing to tensors if we are only measuring the memory usage
-    if (!ggml_allocr_is_measure(allocr)) {
-        ggml_backend_tensor_set(inp, inp_audio, 0, n_samples * ggml_element_size(inp));
-    }
+    ggml_set_name(inp, "inp");
+    ggml_set_input(inp);
 
-    const struct encodec_encoder *encoder = &model.encoder;
-    const struct encodec_quantizer *quantizer = &model.quantizer;
-    const struct encodec_decoder *decoder = &model.decoder;
+    const struct encodec_encoder * encoder = &model.encoder;
+    const struct encodec_quantizer * quantizer = &model.quantizer;
+    const struct encodec_decoder * decoder = &model.decoder;
 
     struct ggml_tensor * encoded = encodec_forward_encoder(
-        encoder, allocr, ctx0, inp, ratios, kernel_size, res_kernel_sz, stride
-    );
+        encoder, ctx0, inp, ratios, kernel_size, res_kernel_sz, stride);
 
     struct ggml_tensor * codes = encodec_forward_quantizer_encode(
-        quantizer, allocr, ctx0, encoded, n_bins, sr, bandwidth, hop_length
-    );
+        quantizer, ctx0, encoded, n_bins, sr, bandwidth, hop_length);
 
     struct ggml_tensor * quantized = encodec_forward_quantizer_decode(
-        quantizer, allocr, ctx0, codes, hidden_dim, n_bins, sr, bandwidth, hop_length
-    );
+        quantizer, ctx0, codes, hidden_dim, n_bins, sr, bandwidth, hop_length);
 
     struct ggml_tensor * decoded = encodec_forward_decoder(
-        decoder, allocr, ctx0, quantized, ratios, kernel_size, res_kernel_sz, stride
-    );
+        decoder, ctx0, quantized, ratios, kernel_size, res_kernel_sz, stride);
 
     switch (mode) {
         case encodec_run_mode_t::FULL: {
+            ggml_set_name(decoded, "decoded");
+            ggml_set_output(decoded);
             ggml_build_forward_expand(gf, decoded);
         } break;
         case encodec_run_mode_t::ENCODE: {
+            ggml_set_name(codes, "codes");
+            ggml_set_output(codes);
             ggml_build_forward_expand(gf, codes);
         } break;
         case encodec_run_mode_t::DECODE: {
@@ -675,15 +594,13 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int3
 
     const int N = n_codes / n_q;
 
-    // since we are using ggml-alloc, this buffer only needs enough space to hold the
-    // ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead() * GGML_MAX_NODES + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params ggml_params = {
-        /*.mem_size =*/buf_size,
-        /*.mem_buffer =*/buf.data(),
-        /*.no_alloc =*/true,
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf.data(),
+        /*.no_alloc   =*/ true,
     };
 
     struct ggml_context *ctx0 = ggml_init(ggml_params);
@@ -691,26 +608,24 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int3
     struct ggml_cgraph *gf = ggml_new_graph(ctx0);
 
     struct ggml_tensor *inp_codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, N, n_q);
-    ggml_allocr_alloc(allocr, inp_codes);
-
-    // avoid writing to tensors if we are only measuring the memory usage
-    if (!ggml_allocr_is_measure(allocr)) {
-        ggml_backend_tensor_set(inp_codes, codes, 0, N * n_q * ggml_element_size(inp_codes));
-    }
+    ggml_set_name(inp_codes, "inp_codes");
+    ggml_set_input(inp_codes);
 
-    const struct encodec_quantizer *quantizer = &model.quantizer;
-    const struct encodec_decoder *decoder = &model.decoder;
+    const struct encodec_quantizer * quantizer = &model.quantizer;
+    const struct encodec_decoder * decoder = &model.decoder;
 
     struct ggml_tensor *quantized = encodec_forward_quantizer_decode(
-        quantizer, allocr, ctx0, inp_codes, hidden_dim, n_bins, sr, bandwidth, hop_length
+        quantizer, ctx0, inp_codes, hidden_dim, n_bins, sr, bandwidth, hop_length
     );
 
     struct ggml_tensor *decoded = encodec_forward_decoder(
-        decoder, allocr, ctx0, quantized, ratios, kernel_size, res_kernel_sz, stride
+        decoder, ctx0, quantized, ratios, kernel_size, res_kernel_sz, stride
     );
 
     switch (mode) {
         case encodec_run_mode_t::DECODE: {
+            ggml_set_name(decoded, "decoded");
+            ggml_set_output(decoded);
             ggml_build_forward_expand(gf, decoded);
         } break;
         default: {
@@ -727,19 +642,38 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int3
     return gf;
 }
 
+static void encodec_zero_tensor(struct ggml_cgraph *gf, const char *name) {
+    struct ggml_tensor *tensor = ggml_graph_get_tensor(gf, name);
+    ggml_set_zero(tensor);
+}
+
 bool encodec_eval_internal(struct encodec_context *ectx, const float * raw_audio,
                            const int n_samples, const int n_threads,
                            const encodec_run_mode_t mode) {
     auto & model = ectx->model;
     auto & allocr = ectx->allocr;
 
-    // reset the allocator to free all the memory allocated during the previous inference
-    ggml_allocr_reset(allocr);
-
     struct ggml_cgraph *gf = encodec_build_graph(ectx, raw_audio, n_samples, mode);
 
-    // allocate tensors
-    ggml_allocr_alloc_graph(allocr, gf);
+    // allocate the graph tensors
+    ggml_gallocr_alloc_graph(allocr, gf);
+
+    // set the graph inputs
+    struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "inp");
+    ggml_backend_tensor_set(inp, raw_audio, 0, n_samples * ggml_element_size(inp));
+
+    // make sure accumulation tensors are zeroed
+    encodec_zero_tensor(gf, "enc_l0_ht");
+    encodec_zero_tensor(gf, "enc_l1_ht");
+    encodec_zero_tensor(gf, "enc_l0_ct");
+    encodec_zero_tensor(gf, "enc_l1_ct");
+
+    encodec_zero_tensor(gf, "dec_l0_ht");
+    encodec_zero_tensor(gf, "dec_l1_ht");
+    encodec_zero_tensor(gf, "dec_l0_ct");
+    encodec_zero_tensor(gf, "dec_l1_ct");
+
+    encodec_zero_tensor(gf, "quantized_out");
 
     // run the computation
     if (ggml_backend_is_cpu(model.backend)) {
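The hunk above replaces measure-time initialization with named graph inputs: a tensor is tagged with ggml_set_input/ggml_set_name while the graph is built, then looked up by name and filled (or zeroed) once ggml_gallocr_alloc_graph has assigned its memory. A minimal sketch of the pattern, with hypothetical names rather than lines from this patch:

    // graph build time: declare an input and give it a stable name
    struct ggml_tensor *inp = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n);
    ggml_set_name(inp, "my_input");
    ggml_set_input(inp);

    // eval time: allocate the whole graph, then write the data through the backend API
    ggml_gallocr_alloc_graph(allocr, gf);
    struct ggml_tensor *t = ggml_graph_get_tensor(gf, "my_input");
    ggml_backend_tensor_set(t, data, 0, n * ggml_element_size(t));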
@@ -761,13 +695,27 @@ bool encodec_eval_internal(struct encodec_context *ectx, const int32_t *codes,
     auto & model = ectx->model;
     auto & allocr = ectx->allocr;
 
-    // reset the allocator to free all the memory allocated during the previous inference
-    ggml_allocr_reset(allocr);
-
     struct ggml_cgraph *gf = encodec_build_graph(ectx, codes, n_codes, mode);
 
-    // allocate tensors
-    ggml_allocr_alloc_graph(allocr, gf);
+    // allocate the graph tensors
+    ggml_gallocr_alloc_graph(allocr, gf);
+
+    // set the graph inputs
+    struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "inp_codes");
+    ggml_backend_tensor_set(inp, codes, 0, n_codes * ggml_element_size(inp));
+
+    // make sure accumulation tensors are zeroed
+    encodec_zero_tensor(gf, "enc_l0_ht");
+    encodec_zero_tensor(gf, "enc_l1_ht");
+    encodec_zero_tensor(gf, "enc_l0_ct");
+    encodec_zero_tensor(gf, "enc_l1_ct");
+
+    encodec_zero_tensor(gf, "dec_l0_ht");
+    encodec_zero_tensor(gf, "dec_l1_ht");
+    encodec_zero_tensor(gf, "dec_l0_ct");
+    encodec_zero_tensor(gf, "dec_l1_ct");
+
+    encodec_zero_tensor(gf, "quantized_out");
 
     // run the computation
     if (ggml_backend_is_cpu(model.backend)) {
@@ -790,21 +738,15 @@ bool encodec_eval(struct encodec_context *ectx, const float *raw_audio,
 
     // allocate the compute buffer
     {
-        // alignment required by the backend
-        size_t align = ggml_backend_get_alignment(ectx->model.backend);
-        ectx->allocr = ggml_allocr_new_measure(align);
+        // create a graph allocator with the backend's default buffer type
+        ectx->allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(ectx->model.backend));
 
         // create the graph for memory usage estimation
         struct ggml_cgraph *gf = encodec_build_graph(ectx, raw_audio, n_samples, mode);
 
-        // compute the required memory
-        size_t mem_size = ggml_allocr_alloc_graph(ectx->allocr, gf);
-
-        // recreate the allocator with the required memory
-        ggml_allocr_free(ectx->allocr);
-        ectx->buf_compute = ggml_backend_alloc_buffer(ectx->model.backend, mem_size);
-        ectx->allocr = ggml_allocr_new_from_buffer(ectx->buf_compute);
-
+        // pre-allocate the compute buffer
+        ggml_gallocr_reserve(ectx->allocr, gf);
+        size_t mem_size = ggml_gallocr_get_buffer_size(ectx->allocr, 0);
         fprintf(stderr, "%s: compute buffer size: %.2f MB\n\n", __func__, mem_size / 1024.0 / 1024.0);
     }
 
@@ -826,24 +768,19 @@ bool encodec_eval(struct encodec_context *ectx, const int32_t *codes,
 
     // allocate the compute buffer
     {
-        // alignment required by the backend
-        size_t align = ggml_backend_get_alignment(ectx->model.backend);
-        ectx->allocr = ggml_allocr_new_measure(align);
+        // create a graph allocator with the backend's default buffer type
+        ectx->allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(ectx->model.backend));
 
         // create the graph for memory usage estimation
         struct ggml_cgraph *gf = encodec_build_graph(ectx, codes, n_codes, mode);
 
-        // compute the required memory
-        size_t mem_size = ggml_allocr_alloc_graph(ectx->allocr, gf);
-
-        // recreate the allocator with the required memory
-        ggml_allocr_free(ectx->allocr);
-        ectx->buf_compute = ggml_backend_alloc_buffer(ectx->model.backend, mem_size);
-        ectx->allocr = ggml_allocr_new_from_buffer(ectx->buf_compute);
-
+        // pre-allocate the compute buffer
+        ggml_gallocr_reserve(ectx->allocr, gf);
+        size_t mem_size = ggml_gallocr_get_buffer_size(ectx->allocr, 0);
         fprintf(stderr, "%s: compute buffer size: %.2f MB\n\n", __func__, mem_size / 1024.0 / 1024.0);
     }
 
+    // encodec eval
     if (!encodec_eval_internal(ectx, codes, n_codes, n_threads, mode)) {
         fprintf(stderr, "%s: failed to run encodec eval\n", __func__);

diff --git a/encoder.h b/encoder.h
index 12e4413..15b4e3f 100644
--- a/encoder.h
+++ b/encoder.h
@@ -37,7 +37,7 @@ struct encodec_encoder {
 };
 
 struct ggml_tensor 
*encodec_forward_encoder( - const struct encodec_encoder *encoder, struct ggml_allocr *allocr, struct ggml_context *ctx0, + const struct encodec_encoder *encoder, struct ggml_context *ctx0, struct ggml_tensor *inp, const int * ratios, const int kernel_size, const int res_kernel_size, const int stride) { @@ -87,14 +87,14 @@ struct ggml_tensor *encodec_forward_encoder( const encodec_lstm lstm = encoder->lstm; // first lstm layer + char l0_prefix[7] = "enc_l0"; struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( - ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w, - lstm.l0_ih_b, lstm.l0_hh_b); + ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, l0_prefix); // second lstm layer + char l1_prefix[7] = "enc_l1"; struct ggml_tensor *out = forward_pass_lstm_unilayer( - ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w, - lstm.l1_ih_b, lstm.l1_hh_b); + ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, l1_prefix); inpL = ggml_add(ctx0, inpL, out); } diff --git a/ggml b/ggml index aa00e16..bfaa6c8 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit aa00e1676417d8007dbaa47d2ee1a6e06c60d546 +Subproject commit bfaa6c897cfebf9db144f5360ca564593cbccfe3 diff --git a/lstm.h b/lstm.h index 31c251d..1de84cb 100644 --- a/lstm.h +++ b/lstm.h @@ -20,33 +20,36 @@ struct encodec_lstm { }; struct ggml_tensor *forward_pass_lstm_unilayer(struct ggml_context *ctx0, - struct ggml_allocr *allocr, - struct ggml_tensor *inp, - struct ggml_tensor *weight_ih, - struct ggml_tensor *weight_hh, - struct ggml_tensor *bias_ih, - struct ggml_tensor *bias_hh) { - const int input_dim = inp->ne[1]; - const int hidden_dim = weight_ih->ne[1] / 4; + struct ggml_tensor *inp, + struct ggml_tensor *weight_ih, + struct ggml_tensor *weight_hh, + struct ggml_tensor *bias_ih, + struct ggml_tensor *bias_hh, + char *prefix) { const int seq_length = inp->ne[0]; + const int input_dim = inp->ne[1]; + const int hidden_dim = weight_ih->ne[1] / 4; + + char ct_name[10]; + char ht_name[10]; + + snprintf(ct_name, 10, "%s_ct", prefix); + snprintf(ht_name, 10, "%s_ht", prefix); struct ggml_tensor *hs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length); - ggml_allocr_alloc(allocr, hs); + ggml_set_input(hs); struct ggml_tensor *c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); - ggml_allocr_alloc(allocr, c_t); + ggml_set_input(c_t); + ggml_set_name(c_t, ct_name); struct ggml_tensor *h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); - ggml_allocr_alloc(allocr, h_t); - - if (!ggml_allocr_is_measure(allocr)) { - h_t = ggml_set_zero(h_t); - c_t = ggml_set_zero(c_t); - } + ggml_set_input(h_t); + ggml_set_name(h_t, ht_name); struct ggml_tensor *current = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); - for (int t = 0; t < seq_length; t++) { + for (int t = 0; t < 2; t++) { struct ggml_tensor *x_t = ggml_view_1d(ctx0, current, input_dim, t * current->nb[1]); struct ggml_tensor *inp_gates = ggml_mul_mat(ctx0, weight_ih, x_t); @@ -57,10 +60,10 @@ struct ggml_tensor *forward_pass_lstm_unilayer(struct ggml_context *ctx0, struct ggml_tensor *out_gates = ggml_add(ctx0, inp_gates, hid_gates); - struct ggml_tensor *i_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 0 * sizeof(float) * hidden_dim)); - struct ggml_tensor *f_t = encodec_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 1 * sizeof(float) * hidden_dim)); - struct ggml_tensor *g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 2 * sizeof(float) * hidden_dim)); - struct ggml_tensor *o_t = encodec_sigmoid(ctx0, 
ggml_view_1d(ctx0, out_gates, hidden_dim, 3 * sizeof(float) * hidden_dim)); + struct ggml_tensor *i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 0 * sizeof(float) * hidden_dim)); + struct ggml_tensor *f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 1 * sizeof(float) * hidden_dim)); + struct ggml_tensor *g_t = ggml_tanh(ctx0 , ggml_view_1d(ctx0, out_gates, hidden_dim, 2 * sizeof(float) * hidden_dim)); + struct ggml_tensor *o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 3 * sizeof(float) * hidden_dim)); c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t)); diff --git a/ops.cpp b/ops.cpp index 18c0acc..b245f42 100644 --- a/ops.cpp +++ b/ops.cpp @@ -7,26 +7,6 @@ #include "ops.h" -static void encodec_sigmoid_impl(struct ggml_tensor *dst, const struct ggml_tensor *src, - int ith, int nth, void *userdata) { - GGML_ASSERT(userdata == NULL); - GGML_ASSERT(ggml_are_same_shape(dst, src)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src)); - - const float *src_data = ggml_get_data_f32(src); - float *dst_data = ggml_get_data_f32(dst); - - const int ne = (int)ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = std::min(ie0 + dr, ne); - - for (int i = ie0; i < ie1; ++i) { - dst_data[i] = 1.0f / (1.0f + expf(-src_data[i])); - } -} - static int get_extra_padding_for_conv_1d(struct ggml_tensor *inp, float kernel_size, float stride, float padding_total) { float length = inp->ne[0]; @@ -35,10 +15,6 @@ static int get_extra_padding_for_conv_1d(struct ggml_tensor *inp, float kernel_s return ideal_length - length; } -struct ggml_tensor *encodec_sigmoid(struct ggml_context *ctx, struct ggml_tensor *x) { - return ggml_map_custom1(ctx, x, encodec_sigmoid_impl, GGML_N_TASKS_MAX, NULL); -} - struct ggml_tensor *pad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, int padding_left, int padding_right) { int length = inp->ne[0]; @@ -52,11 +28,10 @@ struct ggml_tensor *pad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, // constant padding struct ggml_tensor *out = ggml_new_tensor_2d(ctx0, inp->type, length + extra_pad, dim); - ggml_set_zero(out); out = ggml_set_2d(ctx0, out, inp, out->nb[1], 0); } - struct ggml_tensor *padded = ggml_pad_reflec_1d(ctx0, inp, padding_left, padding_right); + struct ggml_tensor *padded = ggml_pad_reflect_1d(ctx0, inp, padding_left, padding_right); const int end = padded->ne[0] - extra_pad; struct ggml_tensor *dest = ggml_view_2d(ctx0, padded, end, dim, padded->nb[1], 0); diff --git a/ops.h b/ops.h index 891aa90..e935b91 100644 --- a/ops.h +++ b/ops.h @@ -2,8 +2,6 @@ #include "ggml.h" -struct ggml_tensor *encodec_sigmoid(struct ggml_context *ctx, struct ggml_tensor *x); - struct ggml_tensor *pad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, int padding_left, int padding_right); diff --git a/quantizer.h b/quantizer.h index 523c594..9986561 100644 --- a/quantizer.h +++ b/quantizer.h @@ -18,7 +18,7 @@ struct encodec_quantizer { }; struct ggml_tensor *encodec_forward_quantizer_encode( - const struct encodec_quantizer *quantizer, struct ggml_allocr *allocr, struct ggml_context *ctx0, + const struct encodec_quantizer *quantizer, struct ggml_context *ctx0, struct ggml_tensor *encoded_inp, const int n_bins, const int sr, const int bandwidth, const int hop_length) { @@ -33,15 +33,7 @@ struct ggml_tensor *encodec_forward_quantizer_encode( const int seq_length = encoded_inp->ne[0]; struct ggml_tensor *codes = 
ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q);
-    ggml_allocr_alloc(allocr, codes);
-
-    struct ggml_tensor *dist_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    ggml_allocr_alloc(allocr, dist_scale);
-
-    if (!ggml_allocr_is_measure(allocr)) {
-        float s = -2.0f;
-        ggml_backend_tensor_set(dist_scale, &s, 0, sizeof(s));
-    }
+    ggml_set_input(codes);
 
     struct ggml_tensor *inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp));
     struct ggml_tensor *residual = inpL;
@@ -53,7 +45,7 @@ struct ggml_tensor *encodec_forward_quantizer_encode(
         // compute distance
         // [seq_length, n_bins]
         struct ggml_tensor *dp = ggml_scale(
-            ctx0, ggml_mul_mat(ctx0, block.embed, residual), dist_scale);
+            ctx0, ggml_mul_mat(ctx0, block.embed, residual), -2.0f);
 
         // [n_bins]
         struct ggml_tensor *sqr_embed = ggml_sqr(ctx0, block.embed);
@@ -84,7 +76,7 @@ struct ggml_tensor *encodec_forward_quantizer_encode(
 }
 
 struct ggml_tensor *encodec_forward_quantizer_decode(
-    const struct encodec_quantizer *quantizer, struct ggml_allocr *allocr, struct ggml_context *ctx0,
+    const struct encodec_quantizer *quantizer, struct ggml_context *ctx0,
     struct ggml_tensor *codes,
     const int hidden_dim, const int n_bins, const int sr, const int bandwidth,
     const int hop_length) {
@@ -101,11 +93,8 @@ struct ggml_tensor *encodec_forward_quantizer_decode(
     assert(n_q == codes->ne[1]);
 
     struct ggml_tensor *quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
-    ggml_allocr_alloc(allocr, quantized_out);
-
-    if (!ggml_allocr_is_measure(allocr)) {
-        quantized_out = ggml_set_zero(quantized_out);
-    }
+    ggml_set_input(quantized_out);
+    ggml_set_name(quantized_out, "quantized_out");
 
     for (int i = 0; i < n_q; i++) {
         encodec_quant_block block = quantizer->blocks[i];

From 13d6ed1b38eb671f66cb600d7bf4cffb4111c3c1 Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Bannier
Date: Sun, 20 Oct 2024 00:43:27 +0200
Subject: [PATCH 2/6] working lstm

---
 .gitignore     |   4 +-
 CMakeLists.txt |   3 +
 encodec.cpp    | 178 ++++++++++++++++++++++++++++++++-----------------
 lstm.h         |   2 +-
 4 files changed, 122 insertions(+), 65 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2738ef2..9bd2f1d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,4 +6,6 @@ encodec
 *.th
 
 .vscode/
-build/
\ No newline at end of file
+build/
+
+*.wav

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 693921c..df1ab72 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -8,6 +8,9 @@ endif()
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
 set(CMAKE_CXX_FLAGS_RELEASE "-O3")
+set(CMAKE_CXX_FLAGS_DEBUG "-g -O0")
+
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 
 if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
     set(ENCODEC_STANDALONE ON)

diff --git a/encodec.cpp b/encodec.cpp
index e770f12..3b8504f 100644
--- a/encodec.cpp
+++ b/encodec.cpp
@@ -1,6 +1,7 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 #include "ggml.h"
+#include "ggml/src/ggml-impl.h"
 
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
@@ -32,6 +33,7 @@
 #include "quantizer.h"
 
 #define ENCODEC_FILE_MAGIC 'ggml'
+#define ENCODEC_MAX_NODES 80000
 
 typedef enum {
     // Run the end-to-end encoder-decoder pipeline
@@ -96,9 +98,28 @@ struct encodec_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
+struct encodec_ggml_cgraph_deleter {
+    void operator()(struct ggml_cgraph * cgraph) {
+        if (cgraph->nodes)
+            free(cgraph->nodes);
+        if (cgraph->leafs)
+            free(cgraph->leafs);
+        if (cgraph->visited_hash_set.keys)
+            free(cgraph->visited_hash_set.keys);
+        if (cgraph->grads)
+            free(cgraph->grads);
+        free(cgraph);
+    }
+};
+
 struct encodec_context {
     encodec_model model;
 
+    // computational graph stored on the heap to avoid stack overflows;
+    // the graph grows with the sequence length (because of the LSTM),
+    // which requires a lot of nodes
+    std::unique_ptr<struct ggml_cgraph, encodec_ggml_cgraph_deleter> gf;
+
     // buffer for model evaluation
     ggml_backend_buffer_t buf_compute;
@@ -201,7 +222,6 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
 #ifdef GGML_USE_METAL
     if (n_gpu_layers > 0) {
         fprintf(stderr, "%s: using Metal backend\n", __func__);
-        ggml_metal_log_set_callback(ggml_log_callback_default, nullptr);
         model.backend = ggml_backend_metal_init();
         if (!model.backend) {
             fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
@@ -473,7 +493,7 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
             model.n_loaded++;
         }
 
-        printf("%s: model size = %8.2f MB\n", __func__, total_size / 1024.0 / 1024.0);
+        printf("%s: model size = %.2f MB\n", __func__, total_size / 1024.0 / 1024.0);
     }
 
     infile.close();
@@ -481,16 +501,57 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
     return true;
 }
 
-struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx,
-                                        const float * inp_audio,
-                                        const int n_samples,
-                                        const encodec_run_mode_t mode) {
+static struct ggml_cgraph * encodec_ggml_cgraph_create(size_t size) {
+    struct ggml_cgraph * cgraph = (struct ggml_cgraph *)calloc(1, sizeof(struct ggml_cgraph));
+    cgraph->size = size;
+    cgraph->n_nodes = 0;
+    cgraph->n_leafs = 0;
+    cgraph->nodes = (struct ggml_tensor **)calloc(1, size * sizeof(struct ggml_tensor *));
+    cgraph->leafs = (struct ggml_tensor **)calloc(1, size * sizeof(struct ggml_tensor *));
+
+    // next primes after powers of two
+    static const size_t primes[] = {
+        2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
+        2053, 4099, 8209, 16411, 32771, 65537, 131101,
+        262147, 524309, 1048583, 2097169, 4194319, 8388617,
+        16777259, 33554467, 67108879, 134217757, 268435459,
+        536870923, 1073741827, 2147483659
+    };
+    static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
+
+    // find the smallest prime that is larger or equal to 2*size
+    size_t l = 0;
+    size_t r = n_primes;
+    while (l < r) {
+        size_t m = (l + r)/2;
+        if (primes[m] < size * 2) {
+            l = m + 1;
+        } else {
+            r = m;
+        }
+    }
+    size_t hash_size = l < n_primes ? primes[l] : (size * 2 + 1);
+
+    cgraph->visited_hash_set.size = hash_size;
+    cgraph->visited_hash_set.keys = (struct ggml_tensor **)calloc(1, hash_size * sizeof(struct ggml_tensor *));
+    cgraph->visited_hash_set.used = (ggml_bitset_t *)calloc(1, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t));
+    cgraph->order = GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT;
+
+    return cgraph;
+}
+
+void encodec_build_graph(struct encodec_context *ectx,
+                         const float * inp_audio,
+                         const int n_samples,
+                         const encodec_run_mode_t mode) {
     assert(mode == encodec_run_mode_t::FULL || mode == encodec_run_mode_t::ENCODE);
 
     const auto & model = ectx->model;
     const auto & hparams = model.hparams;
     const auto & allocr = ectx->allocr;
 
+    auto & gf = ectx->gf;
+
     const int *ratios = hparams.ratios;
     const int kernel_size = hparams.kernel_size;
     const int res_kernel_sz = hparams.residual_kernel_size;
@@ -504,7 +565,7 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx,
 
     // since we are using ggml-alloc, this buffer only needs enough space to hold the
     // ggml_tensor and ggml_cgraph structs, but not the tensor data
-    static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead() * ENCODEC_MAX_NODES + ggml_graph_overhead();
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params ggml_params = {
@@ -515,7 +576,7 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx,
 
     struct ggml_context *ctx0 = ggml_init(ggml_params);
 
-    struct ggml_cgraph *gf = ggml_new_graph(ctx0);
+    gf = std::unique_ptr<struct ggml_cgraph, encodec_ggml_cgraph_deleter>(encodec_ggml_cgraph_create(ENCODEC_MAX_NODES));
 
     struct ggml_tensor *inp = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_samples);
     ggml_set_name(inp, "inp");
@@ -541,19 +602,18 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx,
         case encodec_run_mode_t::FULL: {
             ggml_set_name(decoded, "decoded");
             ggml_set_output(decoded);
-            ggml_build_forward_expand(gf, decoded);
+            ggml_build_forward_expand(gf.get(), decoded);
         } break;
         case encodec_run_mode_t::ENCODE: {
             ggml_set_name(codes, "codes");
             ggml_set_output(codes);
-            ggml_build_forward_expand(gf, codes);
+            ggml_build_forward_expand(gf.get(), codes);
         } break;
         case encodec_run_mode_t::DECODE: {
-            return NULL;
+            assert(false);
         } break;
         default: {
             fprintf(stderr, "%s: unknown run mode\n", __func__);
-            return NULL;
         } break;
     }
 
@@ -562,18 +622,18 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx,
     ectx->encoded = encoded;
     ectx->codes = codes;
     ectx->decoded = decoded;
-
-    return gf;
 }
 
-struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int32_t *codes,
-                                        const int n_codes, const encodec_run_mode_t mode) {
+void encodec_build_graph(struct encodec_context *ectx, const int32_t *codes,
+                         const int n_codes, const encodec_run_mode_t mode) {
     assert(mode == encodec_run_mode_t::DECODE);
 
     const auto & model = ectx->model;
     const auto & hparams = model.hparams;
     const auto & allocr = ectx->allocr;
 
+    auto & gf = ectx->gf;
+
     const int n_bins = hparams.n_bins;
     const int sr = hparams.sr;
     const int bandwidth = hparams.bandwidth;
@@ -589,12 +649,12 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int3
     if (n_codes % n_q != 0) {
         fprintf(stderr, "%s: invalid number of codes\n", __func__);
-        return NULL;
+        assert(false);
     }
 
     const int N = n_codes / n_q;
-    static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
+    static size_t buf_size = ggml_tensor_overhead() * ENCODEC_MAX_NODES + ggml_graph_overhead();
     static std::vector<uint8_t> buf(buf_size);
 
     struct ggml_init_params ggml_params = {
@@ -605,7 +665,7 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int3
 
     struct ggml_context *ctx0 = ggml_init(ggml_params);
 
-    struct ggml_cgraph *gf = ggml_new_graph(ctx0);
+    gf = std::unique_ptr<struct ggml_cgraph, encodec_ggml_cgraph_deleter>(encodec_ggml_cgraph_create(ENCODEC_MAX_NODES));
 
     struct ggml_tensor *inp_codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, N, n_q);
     ggml_set_name(inp_codes, "inp_codes");
@@ -626,11 +686,11 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int3
         case encodec_run_mode_t::DECODE: {
             ggml_set_name(decoded, "decoded");
             ggml_set_output(decoded);
-            ggml_build_forward_expand(gf, decoded);
+            ggml_build_forward_expand(gf.get(), decoded);
         } break;
         default: {
             fprintf(stderr, "%s: unknown run mode\n", __func__);
-            return NULL;
+            assert(false);
         } break;
     }
 
@@ -638,8 +698,6 @@ struct ggml_cgraph *encodec_build_graph(struct encodec_context *ectx, const int3
     ectx->codes = inp_codes;
     ectx->decoded = decoded;
-
-    return gf;
 }
 
 static void encodec_zero_tensor(struct ggml_cgraph *gf, const char *name) {
@@ -652,39 +710,36 @@ bool encodec_eval_internal(struct encodec_context *ectx, const float * raw_audio
                           const encodec_run_mode_t mode) {
     auto & model = ectx->model;
     auto & allocr = ectx->allocr;
+    auto & gf = ectx->gf;
 
-    struct ggml_cgraph *gf = encodec_build_graph(ectx, raw_audio, n_samples, mode);
+    encodec_build_graph(ectx, raw_audio, n_samples, mode);
 
     // allocate the graph tensors
-    ggml_gallocr_alloc_graph(allocr, gf);
+    ggml_gallocr_alloc_graph(allocr, gf.get());
 
     // set the graph inputs
-    struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "inp");
+    struct ggml_tensor * inp = ggml_graph_get_tensor(gf.get(), "inp");
     ggml_backend_tensor_set(inp, raw_audio, 0, n_samples * ggml_element_size(inp));
 
     // make sure accumulation tensors are zeroed
-    encodec_zero_tensor(gf, "enc_l0_ht");
-    encodec_zero_tensor(gf, "enc_l1_ht");
-    encodec_zero_tensor(gf, "enc_l0_ct");
-    encodec_zero_tensor(gf, "enc_l1_ct");
+    encodec_zero_tensor(gf.get(), "enc_l0_ht");
+    encodec_zero_tensor(gf.get(), "enc_l1_ht");
+    encodec_zero_tensor(gf.get(), "enc_l0_ct");
+    encodec_zero_tensor(gf.get(), "enc_l1_ct");
 
-    encodec_zero_tensor(gf, "dec_l0_ht");
-    encodec_zero_tensor(gf, "dec_l1_ht");
-    encodec_zero_tensor(gf, "dec_l0_ct");
-    encodec_zero_tensor(gf, "dec_l1_ct");
+    encodec_zero_tensor(gf.get(), "dec_l0_ht");
+    encodec_zero_tensor(gf.get(), "dec_l1_ht");
+    encodec_zero_tensor(gf.get(), "dec_l0_ct");
+    encodec_zero_tensor(gf.get(), "dec_l1_ct");
 
-    encodec_zero_tensor(gf, "quantized_out");
+    encodec_zero_tensor(gf.get(), "quantized_out");
 
     // run the computation
     if (ggml_backend_is_cpu(model.backend)) {
         ggml_backend_cpu_set_n_threads(model.backend, n_threads);
     }
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(model.backend)) {
-        ggml_backend_metal_set_n_cb(model.backend, n_threads);
-    }
-#endif
-    ggml_backend_graph_compute(model.backend, gf);
+
+    ggml_backend_graph_compute(model.backend, gf.get());
 
     return true;
 }
@@ -694,39 +749,36 @@ bool encodec_eval_internal(struct encodec_context *ectx, const int32_t *codes,
                           const encodec_run_mode_t mode) {
     auto & model = ectx->model;
     auto & allocr = ectx->allocr;
+    auto & gf = ectx->gf;
 
-    struct ggml_cgraph *gf = encodec_build_graph(ectx, codes, n_codes, mode);
+    encodec_build_graph(ectx, codes, n_codes, mode);
 
     // allocate the graph tensors
-    ggml_gallocr_alloc_graph(allocr, gf);
+    ggml_gallocr_alloc_graph(allocr, gf.get());
 
     // set the graph inputs
-    struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "inp_codes");
+    struct ggml_tensor * inp = ggml_graph_get_tensor(gf.get(), "inp_codes");
     ggml_backend_tensor_set(inp, codes, 0, n_codes * ggml_element_size(inp));
 
     // make sure accumulation tensors are zeroed
-    encodec_zero_tensor(gf, "enc_l0_ht");
-    encodec_zero_tensor(gf, "enc_l1_ht");
-    encodec_zero_tensor(gf, "enc_l0_ct");
-    encodec_zero_tensor(gf, "enc_l1_ct");
+    encodec_zero_tensor(gf.get(), "enc_l0_ht");
+    encodec_zero_tensor(gf.get(), "enc_l1_ht");
+    encodec_zero_tensor(gf.get(), "enc_l0_ct");
+    encodec_zero_tensor(gf.get(), "enc_l1_ct");
 
-    encodec_zero_tensor(gf, "dec_l0_ht");
-    encodec_zero_tensor(gf, "dec_l1_ht");
-    encodec_zero_tensor(gf, "dec_l0_ct");
-    encodec_zero_tensor(gf, "dec_l1_ct");
+    encodec_zero_tensor(gf.get(), "dec_l0_ht");
+    encodec_zero_tensor(gf.get(), "dec_l1_ht");
+    encodec_zero_tensor(gf.get(), "dec_l0_ct");
+    encodec_zero_tensor(gf.get(), "dec_l1_ct");
 
-    encodec_zero_tensor(gf, "quantized_out");
+    encodec_zero_tensor(gf.get(), "quantized_out");
 
     // run the computation
     if (ggml_backend_is_cpu(model.backend)) {
         ggml_backend_cpu_set_n_threads(model.backend, n_threads);
     }
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(model.backend)) {
-        ggml_backend_metal_set_n_cb(model.backend, n_threads);
-    }
-#endif
-    ggml_backend_graph_compute(model.backend, gf);
+
+    ggml_backend_graph_compute(model.backend, gf.get());
 
     return true;
 }
@@ -742,10 +794,10 @@ bool encodec_eval(struct encodec_context *ectx, const float *raw_audio,
         ectx->allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(ectx->model.backend));
 
         // create the graph for memory usage estimation
-        struct ggml_cgraph *gf = encodec_build_graph(ectx, raw_audio, n_samples, mode);
+        encodec_build_graph(ectx, raw_audio, n_samples, mode);
 
         // pre-allocate the compute buffer
-        ggml_gallocr_reserve(ectx->allocr, gf);
+        ggml_gallocr_reserve(ectx->allocr, ectx->gf.get());
         size_t mem_size = ggml_gallocr_get_buffer_size(ectx->allocr, 0);
         fprintf(stderr, "%s: compute buffer size: %.2f MB\n\n", __func__, mem_size / 1024.0 / 1024.0);
     }
@@ -772,10 +824,10 @@ bool encodec_eval(struct encodec_context *ectx, const int32_t *codes,
         ectx->allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(ectx->model.backend));
 
         // create the graph for memory usage estimation
-        struct ggml_cgraph *gf = encodec_build_graph(ectx, codes, n_codes, mode);
+        encodec_build_graph(ectx, codes, n_codes, mode);
 
         // pre-allocate the compute buffer
-        ggml_gallocr_reserve(ectx->allocr, gf);
+        ggml_gallocr_reserve(ectx->allocr, ectx->gf.get());
         size_t mem_size = ggml_gallocr_get_buffer_size(ectx->allocr, 0);
         fprintf(stderr, "%s: compute buffer size: %.2f MB\n\n", __func__, mem_size / 1024.0 / 1024.0);
     }

diff --git a/lstm.h b/lstm.h
index 1de84cb..ac23ed1 100644
--- a/lstm.h
+++ b/lstm.h
@@ -49,7 +49,7 @@ struct ggml_tensor *forward_pass_lstm_unilayer(struct ggml_context *ctx0,
 
     struct ggml_tensor *current = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
 
-    for (int t = 0; t < 2; t++) {
+    for (int t = 0; t < seq_length; t++) {
         struct ggml_tensor *x_t = ggml_view_1d(ctx0, current, input_dim, t * current->nb[1]);
 
         struct ggml_tensor *inp_gates = ggml_mul_mat(ctx0, weight_ih, x_t);
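For reference, this hunk restores the full recurrence that patch 1 had temporarily truncated to two timesteps while debugging. Each loop iteration is one standard LSTM step; a minimal sketch of a single timestep, using the same ggml calls and the fused gate layout (i, f, g, o) from lstm.h:

    // one timestep: fused input and hidden projections, then the four gates
    struct ggml_tensor *gates = ggml_add(ctx0,
        ggml_add(ctx0, ggml_mul_mat(ctx0, weight_ih, x_t), bias_ih),
        ggml_add(ctx0, ggml_mul_mat(ctx0, weight_hh, h_t), bias_hh));

    struct ggml_tensor *i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, hidden_dim, 0 * sizeof(float) * hidden_dim));
    struct ggml_tensor *f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, hidden_dim, 1 * sizeof(float) * hidden_dim));
    struct ggml_tensor *g_t = ggml_tanh   (ctx0, ggml_view_1d(ctx0, gates, hidden_dim, 2 * sizeof(float) * hidden_dim));
    struct ggml_tensor *o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, hidden_dim, 3 * sizeof(float) * hidden_dim));

    // c_t = f_t * c_{t-1} + i_t * g_t ;  h_t = o_t * tanh(c_t)
    c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t));
    h_t = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_t));

This is why the graph grows with the sequence length: every timestep adds its own nodes, which is what motivates ENCODEC_MAX_NODES and the heap-allocated graph above.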
From 61d81c53aaf9c7411ce0271cbe632057f9c4bbbe Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Bannier
Date: Sun, 20 Oct 2024 00:54:18 +0200
Subject: [PATCH 3/6] CIs?

---
 .github/workflows/build.yml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 70bf2e3..e9c4d8d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -24,6 +24,7 @@ jobs:
         id: checkout
         uses: actions/checkout@v4
         with:
+          fetch-depth: 0
           submodules: true
 
       - name: Dependencies
@@ -48,6 +49,7 @@ jobs:
         id: checkout
         uses: actions/checkout@v4
         with:
+          fetch-depth: 0
           submodules: true
 
       - name: Dependencies
@@ -75,6 +77,7 @@ jobs:
       - name: Clone
         uses: actions/checkout@v4
         with:
+          fetch-depth: 0
           submodules: recursive
 
       - name: Setup UCRT64

From c4d0ce52e1d09d6087533172b193cbb09d511d1b Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Bannier
Date: Sun, 20 Oct 2024 00:56:39 +0200
Subject: [PATCH 4/6] updated to latest CI

---
 ggml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml b/ggml
index bfaa6c8..3c29eda 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit bfaa6c897cfebf9db144f5360ca564593cbccfe3
+Subproject commit 3c29eda2f4e2b6ef581f1031d8b8891c4dd4c7df

From 039c7e4a518c262df6c1a93ff2238ecf311edb3c Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Bannier
Date: Sun, 20 Oct 2024 09:58:03 +0200
Subject: [PATCH 5/6] comment

---
 encodec.cpp | 3 +++
 utils.h     | 7 -------
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/encodec.cpp b/encodec.cpp
index 3b8504f..2a1b779 100644
--- a/encodec.cpp
+++ b/encodec.cpp
@@ -501,6 +501,9 @@ bool encodec_load_model_weights(std::ifstream &infile, encodec_model &model, int
     return true;
 }
 
+// Create a new ggml_cgraph with the given size (usually ENCODEC_MAX_NODES). We need a
+// custom function because the graph is so large that it exceeds the built-in default
+// ggml graph size.
 static struct ggml_cgraph * encodec_ggml_cgraph_create(size_t size) {
     struct ggml_cgraph * cgraph = (struct ggml_cgraph *)calloc(1, sizeof(struct ggml_cgraph));
     cgraph->size = size;
diff --git a/utils.h b/utils.h
index a5d72fa..fd628b8 100644
--- a/utils.h
+++ b/utils.h
@@ -28,10 +28,3 @@ int32_t get_num_quantizers_for_bandwidth(int bins, float frame_rate, float bandw
     int32_t n_q = MAX(1, floorf(bandwidth * 1000 / bw_per_q));
     return n_q;
 }
-
-void ggml_log_callback_default(ggml_log_level level, const char *text, void *user_data) {
-    (void)level;
-    (void)user_data;
-    fputs(text, stderr);
-    fflush(stderr);
-}

From e2c24ef7ec767d506b59fbedcb6216d954302156 Mon Sep 17 00:00:00 2001
From: Pierre-Antoine Bannier
Date: Sun, 20 Oct 2024 10:16:51 +0200
Subject: [PATCH 6/6] update ggml commit hash

---
 ggml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml b/ggml
index 3c29eda..c18f9ba 160000
--- a/ggml
+++ b/ggml
@@ -1 +1 @@
-Subproject commit 3c29eda2f4e2b6ef581f1031d8b8891c4dd4c7df
+Subproject commit c18f9baeea2f3aea1ffc4afa4ad4496e51b7ff8a
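Taken together, the series replaces the removed ggml_allocr API with ggml_gallocr plus named graph inputs/outputs, and moves the oversized computation graph onto the heap. A minimal sketch of the resulting evaluation flow as adopted by the patches (helper names such as build_graph are hypothetical placeholders, not symbols from the patch):

    // one-time setup: a graph allocator on the backend's default buffer type
    ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));

    // measure once: build the graph, then reserve the compute buffer for it
    struct ggml_cgraph *gf = build_graph(/* ... */);   // hypothetical helper
    ggml_gallocr_reserve(allocr, gf);
    size_t mem_size = ggml_gallocr_get_buffer_size(allocr, 0);

    // per evaluation: rebuild, allocate, set inputs by name, compute
    gf = build_graph(/* ... */);
    ggml_gallocr_alloc_graph(allocr, gf);
    struct ggml_tensor *inp = ggml_graph_get_tensor(gf, "inp");
    ggml_backend_tensor_set(inp, raw_audio, 0, n_samples * ggml_element_size(inp));
    ggml_backend_graph_compute(backend, gf);

Unlike the old measure-allocate-recreate dance, ggml_gallocr_reserve sizes the buffer up front and ggml_gallocr_alloc_graph reuses it on every call, which is what lets the series drop ggml_allocr_reset and the manual buf_compute management entirely.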