Skip to content

Commit

Permalink
Added iterable object that is released by tensors for iterating throu…
Browse files Browse the repository at this point in the history
…gh gradient.
  • Loading branch information
alejandroarmas committed Jun 2, 2022
1 parent 085c620 commit 024b8ca
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions include/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,28 @@ namespace NeuralNetwork {
namespace Graph {


class MatrixParameter {

using iterator = LevelOrderIterator<ReadParameterPolicy>;

public:
explicit MatrixParameter(TensorID _tid) : id(_tid) {}

iterator begin() noexcept {
return iterator{id};
}

iterator end() noexcept {
return iterator{TensorID(0)};
}
private:
TensorID id;


};



class TensorStatistics {


Expand Down Expand Up @@ -64,7 +86,7 @@ namespace NeuralNetwork {
class Tensor {

using matrix_t = Matrix::Representation;
using iterator = LevelOrderIterator;
using iterator = LevelOrderIterator<ComputeGradientPolicy>;

public:
~Tensor() noexcept {}
Expand Down Expand Up @@ -93,7 +115,7 @@ namespace NeuralNetwork {
void become_parent() noexcept;

matrix_t& release_matrix() noexcept;
// matrix_t get_grad();
matrix_t& get_grad() noexcept;

Matrix::Rows num_rows(void) const noexcept;
Matrix::Columns num_cols(void) const noexcept;
Expand All @@ -112,9 +134,13 @@ namespace NeuralNetwork {
return iterator{TensorID(0)};
}

MatrixParameter parameters() noexcept {
return MatrixParameter{my_tensor_id};
}

private:
matrix_t matrix;
// std::optional<matrix_t> grad;
matrix_t grad;
TensorID my_tensor_id;
bool is_leaf;
bool requires_grad;
Expand Down

0 comments on commit 024b8ca

Please sign in to comment.