-
Notifications
You must be signed in to change notification settings - Fork 18
/
quantizer.h
111 lines (79 loc) · 3.51 KB
/
quantizer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once
#include <cassert>
#include <vector>
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "utils.h"
// One stage of the quantizer: a single codebook.
struct encodec_quant_block {
    // Codebook tensor. The quantizer passes it to ggml_mul_mat to score
    // inputs against the entries, and to ggml_get_rows to look entries up
    // by index.
    struct ggml_tensor *embed;
};
// The full quantizer: an ordered sequence of codebook stages.
struct encodec_quantizer {
    // Stage i is read as blocks[i] by the encode/decode graph builders,
    // which walk the first n_q entries in order.
    std::vector<encodec_quant_block> blocks;
};
struct ggml_tensor *encodec_forward_quantizer_encode(
const struct encodec_quantizer *quantizer, struct ggml_context *ctx0,
struct ggml_tensor *encoded_inp, const int n_bins, const int sr, const int bandwidth,
const int hop_length) {
if (!encoded_inp) {
fprintf(stderr, "%s: null input tensor\n", __func__);
return NULL;
}
const int frame_rate = (int)ceilf(sr / hop_length);
const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth);
const int seq_length = encoded_inp->ne[0];
struct ggml_tensor *codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q);
ggml_set_input(codes);
struct ggml_tensor *inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp));
struct ggml_tensor *residual = inpL;
struct ggml_tensor *indices;
for (int i = 0; i < n_q; i++) {
encodec_quant_block block = quantizer->blocks[i];
// compute distance
// [seq_length, n_bins]
struct ggml_tensor *dp = ggml_scale(
ctx0, ggml_mul_mat(ctx0, block.embed, residual), -2.0f);
// [n_bins]
struct ggml_tensor *sqr_embed = ggml_sqr(ctx0, block.embed);
struct ggml_tensor *sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed);
// [seq_length]
struct ggml_tensor *sqr_inp = ggml_sqr(ctx0, residual);
struct ggml_tensor *sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp);
// [seq_length, n_bins]
struct ggml_tensor *dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp);
dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist);
dist = ggml_neg(ctx0, dist);
// take the argmax over the column dimension
// [seq_length]
indices = ggml_argmax(ctx0, dist);
// look up in embedding table
struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices);
residual = ggml_sub(ctx0, residual, quantized);
codes = ggml_set_1d(ctx0, codes, indices, i * codes->nb[1]);
}
return codes;
}
struct ggml_tensor *encodec_forward_quantizer_decode(
const struct encodec_quantizer *quantizer, struct ggml_context *ctx0,
struct ggml_tensor *codes, const int hidden_dim, const int n_bins, const int sr, const int bandwidth,
const int hop_length) {
if (!codes) {
fprintf(stderr, "%s: null input tensor\n", __func__);
return NULL;
}
const int seq_length = codes->ne[0];
const int frame_rate = (int)ceilf(sr / hop_length);
const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth);
assert(n_q == codes->ne[1]);
struct ggml_tensor *quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
ggml_set_input(quantized_out);
ggml_set_name(quantized_out, "quantized_out");
for (int i = 0; i < n_q; i++) {
encodec_quant_block block = quantizer->blocks[i];
struct ggml_tensor *indices = ggml_view_1d(ctx0, codes, seq_length, i * codes->nb[1]);
struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices);
quantized_out = ggml_add(ctx0, quantized_out, quantized);
}
quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out));
return quantized_out;
}