Skip to content

Commit

Permalink
Fix gated activation code (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdatkinson authored Mar 10, 2024
1 parent 74a07ce commit bc51a12
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@ void nam::wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::M
// Mix-in condition
this->_z += this->_input_mixin.process(condition);

this->_activation->apply(this->_z);


if (this->_gated)
if (!this->_gated)
{
activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, 0, channels, this->_z.cols()));
this->_activation->apply(this->_z);
}
else
{
this->_activation->apply(this->_z.topRows(channels));
activations::Activation::get_activation("Sigmoid")->apply(this->_z.bottomRows(channels));
//activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, 0, channels, this->_z.cols()));

this->_z.topRows(channels).array() *= this->_z.bottomRows(channels).array();
// this->_z.topRows(channels) = this->_z.topRows(channels).cwiseProduct(
Expand Down

0 comments on commit bc51a12

Please sign in to comment.