Skip to content

Commit

Permalink
Smaller version of fargan
Browse files Browse the repository at this point in the history
800k parameters, 600 MFLOPS
  • Loading branch information
jmvalin committed Nov 15, 2023
1 parent b0620c0 commit 6fb5f0f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
12 changes: 6 additions & 6 deletions dnn/fargan.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,17 @@ static void compute_fargan_cond(FARGANState *st, float *cond, const float *featu
FARGAN *model;
float dense_in[NB_FEATURES+COND_NET_PEMBED_OUT_SIZE];
float conv1_in[COND_NET_FCONV1_IN_SIZE];
float conv2_in[COND_NET_FCONV2_IN_SIZE];
float fdense2_in[COND_NET_FCONV1_OUT_SIZE];
model = &st->model;
celt_assert(FARGAN_FEATURES+COND_NET_PEMBED_OUT_SIZE == model->cond_net_fdense1.nb_inputs);
celt_assert(COND_NET_FCONV1_IN_SIZE == model->cond_net_fdense1.nb_outputs);
celt_assert(COND_NET_FCONV2_IN_SIZE == model->cond_net_fconv1.nb_outputs);
celt_assert(COND_NET_FCONV1_OUT_SIZE == model->cond_net_fconv1.nb_outputs);
OPUS_COPY(&dense_in[NB_FEATURES], &model->cond_net_pembed.float_weights[IMAX(0,IMIN(period-32, 224))*COND_NET_PEMBED_OUT_SIZE], COND_NET_PEMBED_OUT_SIZE);
OPUS_COPY(dense_in, features, NB_FEATURES);

compute_generic_dense(&model->cond_net_fdense1, conv1_in, dense_in, ACTIVATION_TANH);
compute_generic_conv1d(&model->cond_net_fconv1, conv2_in, st->cond_conv1_state, conv1_in, COND_NET_FCONV1_IN_SIZE, ACTIVATION_TANH);
compute_generic_conv1d(&model->cond_net_fconv2, cond, st->cond_conv2_state, conv2_in, COND_NET_FCONV2_IN_SIZE, ACTIVATION_TANH);
compute_generic_conv1d(&model->cond_net_fconv1, fdense2_in, st->cond_conv1_state, conv1_in, COND_NET_FCONV1_IN_SIZE, ACTIVATION_TANH);
compute_generic_dense(&model->cond_net_fdense2, cond, fdense2_in, ACTIVATION_TANH);
}

static void fargan_deemphasis(float *pcm, float *deemph_mem) {
Expand Down Expand Up @@ -141,7 +141,7 @@ static void run_fargan_subframe(FARGANState *st, float *pcm, const float *cond,
void fargan_cont(FARGANState *st, const float *pcm0, const float *features0)
{
int i;
float cond[COND_NET_FCONV2_OUT_SIZE];
float cond[COND_NET_FDENSE2_OUT_SIZE];
float x0[FARGAN_CONT_SAMPLES];
float dummy[FARGAN_SUBFRAME_SIZE];
int period=0;
Expand Down Expand Up @@ -196,7 +196,7 @@ int fargan_load_model(FARGANState *st, const unsigned char *data, int len) {
static void fargan_synthesize_impl(FARGANState *st, float *pcm, const float *features)
{
int subframe;
float cond[COND_NET_FCONV2_OUT_SIZE];
float cond[COND_NET_FDENSE2_OUT_SIZE];
int period;
celt_assert(st->cont_initialized);

Expand Down
3 changes: 1 addition & 2 deletions dnn/fargan.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
#define FARGAN_NB_SUBFRAMES 4
#define FARGAN_SUBFRAME_SIZE 40
#define FARGAN_FRAME_SIZE (FARGAN_NB_SUBFRAMES*FARGAN_SUBFRAME_SIZE)
#define FARGAN_COND_SIZE (COND_NET_FCONV2_OUT_SIZE/FARGAN_NB_SUBFRAMES)
#define FARGAN_COND_SIZE (COND_NET_FDENSE2_OUT_SIZE/FARGAN_NB_SUBFRAMES)
#define FARGAN_DEEMPHASIS 0.85f

#define SIG_NET_INPUT_SIZE (FARGAN_COND_SIZE+2*FARGAN_SUBFRAME_SIZE+4)
Expand All @@ -49,7 +49,6 @@ typedef struct {
float deemph_mem;
float pitch_buf[PITCH_MAX_PERIOD];
float cond_conv1_state[COND_NET_FCONV1_STATE_SIZE];
float cond_conv2_state[COND_NET_FCONV2_STATE_SIZE];
float fwc0_mem[SIG_NET_FWC0_STATE_SIZE];
float gru1_state[SIG_NET_GRU1_STATE_SIZE];
float gru2_state[SIG_NET_GRU2_STATE_SIZE];
Expand Down
29 changes: 15 additions & 14 deletions dnn/torch/fargan/fargan.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def forward(self, x):
return out

class FWConv(nn.Module):
def __init__(self, in_size, out_size, kernel_size=3):
def __init__(self, in_size, out_size, kernel_size=2):
super(FWConv, self).__init__()

torch.manual_seed(5)
Expand Down Expand Up @@ -163,20 +163,22 @@ def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12):
self.pembed = nn.Embedding(224, pembed_dims)
self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False)
self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False)
self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False)
self.fdense2 = nn.Linear(128, 80*4, bias=False)

self.apply(init_weights)
nb_params = sum(p.numel() for p in self.parameters())
print(f"cond model: {nb_params} weights")

def forward(self, features, period):
features = features[:,2:,:]
period = period[:,2:]
p = self.pembed(period-32)
features = torch.cat((features, p), -1)
tmp = torch.tanh(self.fdense1(features))
tmp = tmp.permute(0, 2, 1)
tmp = torch.tanh(self.fconv1(tmp))
tmp = torch.tanh(self.fconv2(tmp))
tmp = tmp.permute(0, 2, 1)
tmp = torch.tanh(self.fdense2(tmp))
#tmp = torch.tanh(self.fdense2(tmp))
return tmp

Expand All @@ -190,21 +192,20 @@ def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256):
self.cond_gain_dense = nn.Linear(80, 1)

#self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False)
self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size)
self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False)
self.fwc0 = FWConv(2*self.subframe_size+80+4, 192)
self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False)
self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False)
self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False)

self.dense1_glu = GLU(self.cond_size)
self.gru1_glu = GLU(self.cond_size)
self.gru1_glu = GLU(160)
self.gru2_glu = GLU(128)
self.gru3_glu = GLU(128)
self.skip_glu = GLU(self.cond_size)
self.skip_glu = GLU(128)
#self.ptaps_dense = nn.Linear(4*self.cond_size, 5)

self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False)
self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False)
self.gain_dense_out = nn.Linear(self.cond_size, 4)
self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False)
self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False)
self.gain_dense_out = nn.Linear(192, 4)


self.apply(init_weights)
Expand Down Expand Up @@ -291,10 +292,10 @@ def forward(self, features, period, nb_frames, pre=None, states=None):
nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0

states = (
torch.zeros(batch_size, self.cond_size, device=device),
torch.zeros(batch_size, 160, device=device),
torch.zeros(batch_size, 128, device=device),
torch.zeros(batch_size, 128, device=device),
torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device)
torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device)
)

sig = torch.zeros((batch_size, 0), device=device)
Expand Down

0 comments on commit 6fb5f0f

Please sign in to comment.