fixed quantizer decode part
PABannier committed Oct 18, 2023
1 parent 672c1b7 commit 1885a72
Showing 1 changed file with 72 additions and 68 deletions.
encodec.cpp
@@ -764,96 +764,98 @@ struct ggml_cgraph * encodec_graph(
inpL = ggml_add(ctx0, inpL, out);
}

- bp = inpL;

- // // final conv
- // {
- // inpL = ggml_elu(ctx0, inpL);
+ // final conv
+ {
+ inpL = ggml_elu(ctx0, inpL);

- // encoded_inp = strided_conv_1d(
- // ctx0, inpL, model.encoder.final_conv_w, model.encoder.final_conv_b, stride);
- // }
+ encoded_inp = strided_conv_1d(
+ ctx0, inpL, model.encoder.final_conv_w, model.encoder.final_conv_b, stride);
+ }
}

// quantizer (encode)
- // struct ggml_tensor * codes;
- // {
- // const auto & hparams = model.hparams;
- // // originally, n_q = n_q or len(self.layers)
- // // for this model, n_q is at most 32, but the implementation we are comparing
- // // our model against has only 16, hence we hardcode 16 as n_q for now.
- // // const int n_q = hparams.n_q;
- // const int n_q = 16;
+ struct ggml_tensor * codes;
+ {
+ const auto & hparams = model.hparams;
+ // originally, n_q = n_q or len(self.layers)
+ // for this model, n_q is at most 32, but the implementation we are comparing
+ // our model against has only 16, hence we hardcode 16 as n_q for now.
+ // const int n_q = hparams.n_q;
+ const int n_q = 16;

- // const int seq_length = encoded_inp->ne[0];
- // codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q);
+ const int seq_length = encoded_inp->ne[0];
+ codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q);

- // struct ggml_tensor * inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp));
- // struct ggml_tensor * residual = inpL;
- // struct ggml_tensor * indices;
+ struct ggml_tensor * inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp));
+ struct ggml_tensor * residual = inpL;
+ struct ggml_tensor * indices;

- // for (int i = 0; i < n_q; i++) {
- // encodec_quant_block block = model.quantizer.blocks[i];
+ for (int i = 0; i < n_q; i++) {
+ encodec_quant_block block = model.quantizer.blocks[i];

- // // compute distance
- // // [seq_length, n_bins]
- // struct ggml_tensor * dp = ggml_scale(
- // ctx0, ggml_mul_mat(ctx0, block.embed, residual), ggml_new_f32(ctx0, -2.0f));
+ // compute distance
+ // [seq_length, n_bins]
+ struct ggml_tensor * dp = ggml_scale(
+ ctx0, ggml_mul_mat(ctx0, block.embed, residual), ggml_new_f32(ctx0, -2.0f));

- // // [n_bins]
- // struct ggml_tensor * sqr_embed = ggml_sqr(ctx0, block.embed);
- // struct ggml_tensor * sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed);
+ // [n_bins]
+ struct ggml_tensor * sqr_embed = ggml_sqr(ctx0, block.embed);
+ struct ggml_tensor * sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed);

- // // [seq_length]
- // struct ggml_tensor * sqr_inp = ggml_sqr(ctx0, residual);
- // struct ggml_tensor * sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp);
+ // [seq_length]
+ struct ggml_tensor * sqr_inp = ggml_sqr(ctx0, residual);
+ struct ggml_tensor * sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp);

- // // [seq_length, n_bins]
- // struct ggml_tensor * dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp);
- // dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist);
- // dist = ggml_scale(ctx0, dist, ggml_new_f32(ctx0, -1.0f));
+ // [seq_length, n_bins]
+ struct ggml_tensor * dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp);
+ dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist);
+ dist = ggml_scale(ctx0, dist, ggml_new_f32(ctx0, -1.0f));

- // // take the argmax over the column dimension
- // // [seq_length]
- // indices = ggml_argmax(ctx0, dist);
+ // take the argmax over the column dimension
+ // [seq_length]
+ indices = ggml_argmax(ctx0, dist);

- // // look up in embedding table
- // struct ggml_tensor * quantized = ggml_get_rows(ctx0, block.embed, indices);
+ // look up in embedding table
+ struct ggml_tensor * quantized = ggml_get_rows(ctx0, block.embed, indices);

- // residual = ggml_sub(ctx0, residual, quantized);
+ residual = ggml_sub(ctx0, residual, quantized);

- // codes = ggml_set_1d(ctx0, codes, indices, i*codes->nb[1]);
- // }
+ codes = ggml_set_1d(ctx0, codes, indices, i*codes->nb[1]);
+ }

- // }
+ }
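Note on the encode loop above: each of the n_q stages is a residual vector quantization step. It snaps the current residual onto the nearest row of its codebook, stores that row's index in row i of codes, and subtracts the chosen row before handing the residual to the next stage. The squared distance is never formed directly; writing r for one residual vector and e_j for the j-th codebook entry (symbols introduced here for illustration), the graph builds the usual expansion

    ||r - e_j||^2 = ||r||^2 - 2 (r · e_j) + ||e_j||^2

which is exactly dp (the -2 (r · e_j) term) plus sqr_inp_nrm (||r||^2) plus sqr_embed_nrm (||e_j||^2). Negating the sum and taking ggml_argmax over j then selects the closest codebook entry. The ||r||^2 term is constant across j, so it does not change the winner, but keeping it makes dist equal the (negated) true squared distances.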

- // // quantizer (decode)
- // struct ggml_tensor * quantized_out;
- // {
- // const auto & hparams = model.hparams;
- // const int hidden_dim = hparams.hidden_dim;
+ // quantizer (decode)
+ struct ggml_tensor * quantized_out;
+ {
+ const auto & hparams = model.hparams;
+ const int hidden_dim = hparams.hidden_dim;

- // const int seq_length = codes->ne[0];
- // const int n_q = codes->ne[1];
+ const int seq_length = codes->ne[0];
+ const int n_q = codes->ne[1];

- // quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
- // // if (!ggml_allocr_is_measure(ectx.allocr)) {
- // // quantized_out = ggml_set_zero(quantized_out);
- // // }
+ quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
+ ggml_allocr_alloc(allocr, quantized_out);

- // for (int i = 0; i < n_q; i++) {
- // encodec_quant_block block = model.quantizer.blocks[i];
+ if (!ggml_allocr_is_measure(allocr)) {
+ quantized_out = ggml_set_zero(quantized_out);
+ }

- // struct ggml_tensor * indices = ggml_view_1d(ctx0, codes, seq_length, i*codes->nb[1]);
- // struct ggml_tensor * quantized = ggml_get_rows(ctx0, block.embed, indices);
+ for (int i = 0; i < n_q; i++) {
+ encodec_quant_block block = model.quantizer.blocks[i];

- // quantized_out = ggml_add(ctx0, quantized_out, quantized);
- // }
+ struct ggml_tensor * indices = ggml_view_1d(ctx0, codes, seq_length, i*codes->nb[1]);
+ struct ggml_tensor * quantized = ggml_get_rows(ctx0, block.embed, indices);

- // quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out));
- // }
+ quantized_out = ggml_add(ctx0, quantized_out, quantized);
+ }

+ quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out));
+ }

+ bp = quantized_out;
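Note on the decode block above, which this commit fixes: quantized_out is now allocated through the graph allocator and only zeroed outside the allocator's measure pass, and each quantizer stage looks up its codes in its codebook and sums the embeddings. Outside of ggml, the same computation can be sketched in plain C++ as below; the names Codebook, decode_codes and the container layout are illustrative only and are not part of encodec.cpp.

#include <cstdint>
#include <vector>

// Illustrative residual-VQ decode: out[t] = sum over i of embed_i[ codes[i][t] ].
// "Codebook" and "decode_codes" are hypothetical names, not encodec.cpp API.
struct Codebook {
    int n_bins;                 // number of codebook entries
    int dim;                    // embedding dimension (hidden_dim)
    std::vector<float> embed;   // n_bins x dim, row-major
};

// codes: n_q rows of seq_length indices (one row per quantizer stage).
// Returns seq_length vectors of size dim, the summed embeddings.
static std::vector<std::vector<float>> decode_codes(
        const std::vector<std::vector<int32_t>> & codes,
        const std::vector<Codebook> & books) {
    const size_t n_q        = codes.size();                     // assumes books.size() >= n_q
    const size_t seq_length = codes.empty() ? 0 : codes[0].size();
    const int    dim        = books.empty() ? 0 : books[0].dim;

    std::vector<std::vector<float>> out(seq_length, std::vector<float>(dim, 0.0f));
    for (size_t i = 0; i < n_q; i++) {                           // one codebook per stage
        const Codebook & cb = books[i];
        for (size_t t = 0; t < seq_length; t++) {
            const int32_t idx = codes[i][t];                     // row picked by ggml_get_rows
            for (int d = 0; d < dim; d++) {
                out[t][d] += cb.embed[idx * cb.dim + d];         // running sum (ggml_add)
            }
        }
    }
    return out;
}

In the graph the same accumulation is expressed as n_q ggml_get_rows/ggml_add nodes, and the final ggml_cont(ggml_transpose(...)) simply reorders the result for the decoder that follows.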

- // // decoder
+ // decoder
// struct ggml_tensor * decoded_inp;
// struct ggml_tensor * out;
// {
@@ -876,11 +878,11 @@ struct ggml_cgraph * encodec_graph(

// // first lstm layer
// struct ggml_tensor * hs1 = forward_pass_lstm_unilayer(
- // ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b);
+ // ctx0, allocr, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b);

// // second lstm layer
// struct ggml_tensor * out = forward_pass_lstm_unilayer(
- // ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b);
+ // ctx0, allocr, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b);

// inpL = ggml_add(ctx0, inpL, out);
// }
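For context on the (still commented-out) decoder hunk above: forward_pass_lstm_unilayer presumably implements one unidirectional LSTM layer, two of which are stacked before the residual ggml_add, with the ih/hh weight and bias tensors appearing to hold the input-hidden and hidden-hidden parameters. In conventional notation (symbols introduced here for illustration, with ∘ the elementwise product), a single step is:

    i_t = σ(W_ii·x_t + b_ii + W_hi·h_{t-1} + b_hi)
    f_t = σ(W_if·x_t + b_if + W_hf·h_{t-1} + b_hf)
    g_t = tanh(W_ig·x_t + b_ig + W_hg·h_{t-1} + b_hg)
    o_t = σ(W_io·x_t + b_io + W_ho·h_{t-1} + b_ho)
    c_t = f_t ∘ c_{t-1} + i_t ∘ g_t
    h_t = o_t ∘ tanh(c_t)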
@@ -927,6 +929,8 @@ struct ggml_cgraph * encodec_graph(
// out = decoded_inp;
// }

+ // bp = out;

ggml_build_forward_expand(gf, bp);

ggml_free(ctx0);
