fix: adam not considering Flatten layer
Az-r-ow committed Jul 8, 2024
1 parent 5927c91 commit c05c9e9
Showing 2 changed files with 9 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/NeuralNet/Network.cpp
@@ -351,11 +351,12 @@ void Network::updateOptimizerSetup(size_t numLayers) {

if (std::dynamic_pointer_cast<Adam>(this->optimizer)) {
// Get the number of dense layers
-    nLayers =
-        std::count_if(this->layers.begin(), this->layers.end(),
-                      [](const std::shared_ptr<Layer> &ptr) {
-                        return std::dynamic_pointer_cast<Dense>(ptr) != nullptr;
-                      });
+    nLayers = std::count_if(
+        this->layers.begin(), this->layers.end(),
+        [](const std::shared_ptr<Layer> &ptr) {
+          return std::dynamic_pointer_cast<Dense>(ptr) != nullptr ||
+                 std::dynamic_pointer_cast<Flatten>(ptr) != nullptr;
+        });
}

this->optimizer->insiderInit(nLayers);
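The count computed here is handed to insiderInit, which presumably sizes Adam's per-layer state; the commit title implies that networks containing Flatten layers previously produced a smaller count than the optimizer expected. Below is a minimal, self-contained sketch of the counting idiom. Layer, Dense, and Flatten are simplified stand-ins for the repository's classes (not the actual headers), and Dropout is a purely hypothetical layer added to show one being ignored.

#include <algorithm>
#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-ins for the library's layer hierarchy (assumption).
struct Layer { virtual ~Layer() = default; };
struct Dense : Layer {};
struct Flatten : Layer {};
struct Dropout : Layer {};  // hypothetical layer the optimizer should ignore

int main() {
  std::vector<std::shared_ptr<Layer>> layers = {
      std::make_shared<Flatten>(), std::make_shared<Dense>(),
      std::make_shared<Dropout>(), std::make_shared<Dense>()};

  // Count every layer Adam must track: Dense layers and, after this fix,
  // Flatten layers as well.
  auto nLayers = std::count_if(
      layers.begin(), layers.end(), [](const std::shared_ptr<Layer> &ptr) {
        return std::dynamic_pointer_cast<Dense>(ptr) != nullptr ||
               std::dynamic_pointer_cast<Flatten>(ptr) != nullptr;
      });

  std::cout << nLayers << '\n';  // prints 3
  return 0;
}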
3 changes: 3 additions & 0 deletions src/NeuralNet/optimizers/Adam.hpp
@@ -58,6 +58,9 @@ class Adam : public Optimizer {
v = Eigen::MatrixBase<Derived1>::Zero(param.rows(), param.cols());
}

+  assert(gradients.rows() == m.rows() && gradients.cols() == m.cols());
+  assert(gradients.rows() == v.rows() && gradients.cols() == v.cols());

// update biased first moment estimate
m = (beta1 * m).array() + ((1 - beta1) * gradients.array()).array();

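For context, the lines above belong to Adam's per-matrix update. The sketch below shows the textbook Adam step with Eigen for a single parameter matrix whose gradient and moment buffers share one shape; it mirrors the role of the code in Adam.hpp but is not the repository's implementation, and the hyperparameter defaults are the usual ones from the Adam paper.

#include <Eigen/Dense>
#include <cassert>
#include <cmath>

// Textbook Adam step for one parameter matrix (a sketch, not the
// repository's exact code). t is the 1-based step count.
void adamStep(Eigen::MatrixXd &param, const Eigen::MatrixXd &gradients,
              Eigen::MatrixXd &m, Eigen::MatrixXd &v, int t,
              double alpha = 0.001, double beta1 = 0.9, double beta2 = 0.999,
              double eps = 1e-8) {
  // Shape agreement between the gradient and the moment buffers is exactly
  // what the new asserts in Adam.hpp guard.
  assert(gradients.rows() == m.rows() && gradients.cols() == m.cols());
  assert(gradients.rows() == v.rows() && gradients.cols() == v.cols());

  // Biased first and second moment estimates.
  m = beta1 * m + (1 - beta1) * gradients;
  v = beta2 * v + (1 - beta2) * gradients.cwiseProduct(gradients);

  // Bias correction, then the parameter update.
  Eigen::ArrayXXd mHat = m.array() / (1 - std::pow(beta1, t));
  Eigen::ArrayXXd vHat = v.array() / (1 - std::pow(beta2, t));
  param.array() -= alpha * mHat / (vHat.sqrt() + eps);
}

With the new asserts in place, a mismatch between the gradient and the moment buffers (for example, moments initialized for the wrong number of layers) fails loudly at the first update instead of producing silently wrong math.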
