diff --git a/dnn/fargan.c b/dnn/fargan.c index e0fa304cd..7fda611be 100644 --- a/dnn/fargan.c +++ b/dnn/fargan.c @@ -44,17 +44,17 @@ static void compute_fargan_cond(FARGANState *st, float *cond, const float *featu FARGAN *model; float dense_in[NB_FEATURES+COND_NET_PEMBED_OUT_SIZE]; float conv1_in[COND_NET_FCONV1_IN_SIZE]; - float conv2_in[COND_NET_FCONV2_IN_SIZE]; + float fdense2_in[COND_NET_FCONV1_OUT_SIZE]; model = &st->model; celt_assert(FARGAN_FEATURES+COND_NET_PEMBED_OUT_SIZE == model->cond_net_fdense1.nb_inputs); celt_assert(COND_NET_FCONV1_IN_SIZE == model->cond_net_fdense1.nb_outputs); - celt_assert(COND_NET_FCONV2_IN_SIZE == model->cond_net_fconv1.nb_outputs); + celt_assert(COND_NET_FCONV1_OUT_SIZE == model->cond_net_fconv1.nb_outputs); OPUS_COPY(&dense_in[NB_FEATURES], &model->cond_net_pembed.float_weights[IMAX(0,IMIN(period-32, 224))*COND_NET_PEMBED_OUT_SIZE], COND_NET_PEMBED_OUT_SIZE); OPUS_COPY(dense_in, features, NB_FEATURES); compute_generic_dense(&model->cond_net_fdense1, conv1_in, dense_in, ACTIVATION_TANH); - compute_generic_conv1d(&model->cond_net_fconv1, conv2_in, st->cond_conv1_state, conv1_in, COND_NET_FCONV1_IN_SIZE, ACTIVATION_TANH); - compute_generic_conv1d(&model->cond_net_fconv2, cond, st->cond_conv2_state, conv2_in, COND_NET_FCONV2_IN_SIZE, ACTIVATION_TANH); + compute_generic_conv1d(&model->cond_net_fconv1, fdense2_in, st->cond_conv1_state, conv1_in, COND_NET_FCONV1_IN_SIZE, ACTIVATION_TANH); + compute_generic_dense(&model->cond_net_fdense2, cond, fdense2_in, ACTIVATION_TANH); } static void fargan_deemphasis(float *pcm, float *deemph_mem) { @@ -141,7 +141,7 @@ static void run_fargan_subframe(FARGANState *st, float *pcm, const float *cond, void fargan_cont(FARGANState *st, const float *pcm0, const float *features0) { int i; - float cond[COND_NET_FCONV2_OUT_SIZE]; + float cond[COND_NET_FDENSE2_OUT_SIZE]; float x0[FARGAN_CONT_SAMPLES]; float dummy[FARGAN_SUBFRAME_SIZE]; int period=0; @@ -196,7 +196,7 @@ int fargan_load_model(FARGANState *st, const unsigned char *data, int len) { static void fargan_synthesize_impl(FARGANState *st, float *pcm, const float *features) { int subframe; - float cond[COND_NET_FCONV2_OUT_SIZE]; + float cond[COND_NET_FDENSE2_OUT_SIZE]; int period; celt_assert(st->cont_initialized); diff --git a/dnn/fargan.h b/dnn/fargan.h index 1031c0054..fd5ee4f0d 100644 --- a/dnn/fargan.h +++ b/dnn/fargan.h @@ -35,7 +35,7 @@ #define FARGAN_NB_SUBFRAMES 4 #define FARGAN_SUBFRAME_SIZE 40 #define FARGAN_FRAME_SIZE (FARGAN_NB_SUBFRAMES*FARGAN_SUBFRAME_SIZE) -#define FARGAN_COND_SIZE (COND_NET_FCONV2_OUT_SIZE/FARGAN_NB_SUBFRAMES) +#define FARGAN_COND_SIZE (COND_NET_FDENSE2_OUT_SIZE/FARGAN_NB_SUBFRAMES) #define FARGAN_DEEMPHASIS 0.85f #define SIG_NET_INPUT_SIZE (FARGAN_COND_SIZE+2*FARGAN_SUBFRAME_SIZE+4) @@ -49,7 +49,6 @@ typedef struct { float deemph_mem; float pitch_buf[PITCH_MAX_PERIOD]; float cond_conv1_state[COND_NET_FCONV1_STATE_SIZE]; - float cond_conv2_state[COND_NET_FCONV2_STATE_SIZE]; float fwc0_mem[SIG_NET_FWC0_STATE_SIZE]; float gru1_state[SIG_NET_GRU1_STATE_SIZE]; float gru2_state[SIG_NET_GRU2_STATE_SIZE]; diff --git a/dnn/torch/fargan/fargan.py b/dnn/torch/fargan/fargan.py index 65f0a97b7..8dbb694d3 100644 --- a/dnn/torch/fargan/fargan.py +++ b/dnn/torch/fargan/fargan.py @@ -125,7 +125,7 @@ def forward(self, x): return out class FWConv(nn.Module): - def __init__(self, in_size, out_size, kernel_size=3): + def __init__(self, in_size, out_size, kernel_size=2): super(FWConv, self).__init__() torch.manual_seed(5) @@ -163,20 +163,22 @@ def __init__(self, feature_dim=20, cond_size=256, pembed_dims=12): self.pembed = nn.Embedding(224, pembed_dims) self.fdense1 = nn.Linear(self.feature_dim + pembed_dims, 64, bias=False) self.fconv1 = nn.Conv1d(64, 128, kernel_size=3, padding='valid', bias=False) - self.fconv2 = nn.Conv1d(128, 80*4, kernel_size=3, padding='valid', bias=False) + self.fdense2 = nn.Linear(128, 80*4, bias=False) self.apply(init_weights) nb_params = sum(p.numel() for p in self.parameters()) print(f"cond model: {nb_params} weights") def forward(self, features, period): + features = features[:,2:,:] + period = period[:,2:] p = self.pembed(period-32) features = torch.cat((features, p), -1) tmp = torch.tanh(self.fdense1(features)) tmp = tmp.permute(0, 2, 1) tmp = torch.tanh(self.fconv1(tmp)) - tmp = torch.tanh(self.fconv2(tmp)) tmp = tmp.permute(0, 2, 1) + tmp = torch.tanh(self.fdense2(tmp)) #tmp = torch.tanh(self.fdense2(tmp)) return tmp @@ -190,21 +192,20 @@ def __init__(self, subframe_size=40, nb_subframes=4, cond_size=256): self.cond_gain_dense = nn.Linear(80, 1) #self.sig_dense1 = nn.Linear(4*self.subframe_size+self.passthrough_size+self.cond_size, self.cond_size, bias=False) - self.fwc0 = FWConv(2*self.subframe_size+80+4, self.cond_size) - self.gru1 = nn.GRUCell(self.cond_size+2*self.subframe_size, self.cond_size, bias=False) - self.gru2 = nn.GRUCell(self.cond_size+2*self.subframe_size, 128, bias=False) + self.fwc0 = FWConv(2*self.subframe_size+80+4, 192) + self.gru1 = nn.GRUCell(192+2*self.subframe_size, 160, bias=False) + self.gru2 = nn.GRUCell(160+2*self.subframe_size, 128, bias=False) self.gru3 = nn.GRUCell(128+2*self.subframe_size, 128, bias=False) - self.dense1_glu = GLU(self.cond_size) - self.gru1_glu = GLU(self.cond_size) + self.gru1_glu = GLU(160) self.gru2_glu = GLU(128) self.gru3_glu = GLU(128) - self.skip_glu = GLU(self.cond_size) + self.skip_glu = GLU(128) #self.ptaps_dense = nn.Linear(4*self.cond_size, 5) - self.skip_dense = nn.Linear(2*128+2*self.cond_size+2*self.subframe_size, self.cond_size, bias=False) - self.sig_dense_out = nn.Linear(self.cond_size, self.subframe_size, bias=False) - self.gain_dense_out = nn.Linear(self.cond_size, 4) + self.skip_dense = nn.Linear(192+160+2*128+2*self.subframe_size, 128, bias=False) + self.sig_dense_out = nn.Linear(128, self.subframe_size, bias=False) + self.gain_dense_out = nn.Linear(192, 4) self.apply(init_weights) @@ -291,10 +292,10 @@ def forward(self, features, period, nb_frames, pre=None, states=None): nb_pre_frames = pre.size(1)//self.frame_size if pre is not None else 0 states = ( - torch.zeros(batch_size, self.cond_size, device=device), + torch.zeros(batch_size, 160, device=device), torch.zeros(batch_size, 128, device=device), torch.zeros(batch_size, 128, device=device), - torch.zeros(batch_size, (2*self.subframe_size+80+4)*2, device=device) + torch.zeros(batch_size, (2*self.subframe_size+80+4)*1, device=device) ) sig = torch.zeros((batch_size, 0), device=device)