feat: getSlug function
Az-r-ow committed Jul 6, 2024
1 parent 1ba0e4e commit a2f8390
Showing 7 changed files with 61 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/NeuralNet/Network.cpp
@@ -55,6 +55,18 @@ std::shared_ptr<Layer> Network::getOutputLayer() const {
return this->layers[this->layers.size() - 1];
}

std::string Network::getSlug() const {
std::string slug;

for (int l = 0; l < this->layers.size(); l++) {
Layer &cLayer = *this->layers[l];
slug += cLayer.getSlug() + "-";
}

if (!slug.empty()) slug.pop_back();  // remove the trailing "-" (guards the no-layer case)
return slug;
}

double Network::train(std::vector<std::vector<double>> X, std::vector<double> y,
int epochs,
std::vector<std::shared_ptr<Callback>> callbacks,
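For context, a minimal standalone sketch of the joining pattern used in Network::getSlug above: each layer's slug is appended with a "-" separator and the trailing separator is dropped. The per-layer slugs ("fltn", "dns128", "do0.3") follow the formats introduced by this commit in the layer headers below; the particular layer stack and its ordering are hypothetical.

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical per-layer slugs for a Flatten -> Dense(128) -> Dropout(0.3) stack.
  std::vector<std::string> layerSlugs = {"fltn", "dns128", "do0.3"};

  std::string slug;
  for (const std::string &s : layerSlugs) slug += s + "-";
  if (!slug.empty()) slug.pop_back();  // drop the trailing separator

  std::cout << slug << '\n';  // prints fltn-dns128-do0.3
}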
8 changes: 8 additions & 0 deletions src/NeuralNet/Network.hpp
@@ -80,6 +80,14 @@ class Network : public Model {
*/
size_t getNumLayers() const;

/**
* @brief Get the slug of the network based on its architecture
*
* @return A string representing the combined slug of the different components
* in the Network
*/
std::string getSlug() const;

/**
* @brief This method will Train the model with the given inputs and labels
*
8 changes: 8 additions & 0 deletions src/NeuralNet/layers/Dense.hpp
@@ -49,6 +49,13 @@ class Dense : public Layer {
return;
};

/**
* @brief Dense layer slug
*/
std::string getSlug() const override {
return slug + std::to_string(nNeurons);
}

/**
* @brief This method is used to feed the inputs to the layer
*
@@ -80,6 +87,7 @@
friend class Network;

double bias;
std::string slug = "dns";
Eigen::MatrixXd biases;
WEIGHT_INIT weightInit;
Eigen::MatrixXd weights;
8 changes: 8 additions & 0 deletions src/NeuralNet/layers/Dropout.hpp
@@ -36,6 +36,13 @@ class Dropout : public Layer {
this->scaleRate = 1.0 / (1.0 - rate);
};

/**
* @brief Dropout layer slug
*/
std::string getSlug() const override {
return slug + removeTrailingZeros(std::to_string(rate));
}

/**
* @brief This method is used to feed the inputs to the layer
*
@@ -49,6 +56,7 @@

private:
std::vector<std::tuple<int, int>> coordinates;
std::string slug = "do";

// non-public serialization
friend class cereal::access;
6 changes: 6 additions & 0 deletions src/NeuralNet/layers/Flatten.hpp
@@ -28,6 +28,11 @@ class Flatten : public Layer {
type = LayerType::FLATTEN;
};

/**
* @brief Flatten layer's slug
*/
std::string getSlug() const override { return slug; }

/**
* @brief This method flattens a 3D vector into a 2D Eigen::MatrixXd
*
@@ -72,6 +77,7 @@
friend class cereal::access;

std::tuple<int, int> inputShape;
std::string slug = "fltn";

template <class Archive>
void serialize(Archive &ar) {
7 changes: 7 additions & 0 deletions src/NeuralNet/layers/Layer.hpp
@@ -43,6 +43,13 @@ class Layer {
*/
int getNumNeurons() const { return nNeurons; };

/**
* @brief The slug of the layer (name + main parameter value)
*
* @return the slug
*/
virtual std::string getSlug() const = 0;

/**
* @brief Method to print layer's outputs
*/
12 changes: 12 additions & 0 deletions src/NeuralNet/utils/Functions.hpp
@@ -359,4 +359,16 @@ static void signalHandler(int signum) {
exit(signum);
};

/* STRING OPERATIONS */

// Trims trailing zeros from a decimal string (e.g. the output of
// std::to_string(double)), along with a dangling '.' if one remains.
static std::string removeTrailingZeros(std::string str) {
str.erase(str.find_last_not_of('0') + 1, std::string::npos);

if (!str.empty() && str.back() == '.') {
str.pop_back();
}

return str;
}

} // namespace NeuralNet
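A quick standalone illustration of how this helper behaves on std::to_string output, and what that means for the Dropout slug shown earlier in this commit. The trimming logic is copied inline so the snippet compiles on its own; the rate values are arbitrary examples.

#include <iostream>
#include <string>

// Same trimming logic as removeTrailingZeros, inlined for the example.
static std::string trim(std::string str) {
  str.erase(str.find_last_not_of('0') + 1, std::string::npos);
  if (!str.empty() && str.back() == '.') str.pop_back();
  return str;
}

int main() {
  std::cout << trim(std::to_string(0.3)) << '\n';   // "0.300000" -> 0.3
  std::cout << trim(std::to_string(0.25)) << '\n';  // "0.250000" -> 0.25
  std::cout << trim(std::to_string(1.0)) << '\n';   // "1.000000" -> 1

  // So a Dropout layer constructed with rate 0.3 reports the slug "do0.3".
  std::cout << "do" + trim(std::to_string(0.3)) << '\n';
}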
