Skip to content

Commit

Permalink
Gradient Instantiated with unit value. Might be useless.
Browse files Browse the repository at this point in the history
  • Loading branch information
alejandroarmas committed Jun 2, 2022
1 parent 024b8ca commit 73be0d8
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ namespace NeuralNetwork {
IsTrackable _t, IsLeaf _f, IsRecordable _r) noexcept:
stats({}),
matrix(Matrix::Representation(_l, _w)),
// grad({}),
grad(Matrix::Representation(_l, _w)),
my_tensor_id(ComputationalGraphMap::get()._obtain_tensor_id()),
is_leaf(_f.get()),
requires_grad(_t.get()), record_statistics(_r.get()) {

Matrix::Generation::Normal<0, 1> normal_distribution_init;
matrix = normal_distribution_init(matrix);

Matrix::Generation::Tester<1> unit_gen;
grad = unit_gen(grad);


}

Expand All @@ -37,17 +41,21 @@ namespace NeuralNetwork {
IsTrackable _t, IsLeaf _f, IsRecordable _r) noexcept:
stats({}),
matrix(_m),
// grad({}),
grad(Matrix::Representation(
Matrix::Rows(_m.num_rows()),
Matrix::Columns(_m.num_cols()))),
my_tensor_id(ComputationalGraphMap::get()._obtain_tensor_id()),
is_leaf(_f.get()),
requires_grad(_t.get()), record_statistics(_r.get()) {

Matrix::Generation::Tester<1> unit_gen;
grad = unit_gen(grad);
}

Tensor::Tensor(const Tensor& other) noexcept:
// stats(other.stats),
matrix(other.matrix),
// grad(other.grad),
grad(other.grad),
my_tensor_id(other.my_tensor_id),
is_leaf(other.is_leaf),
requires_grad(other.requires_grad),
Expand All @@ -60,17 +68,12 @@ namespace NeuralNetwork {
requires_grad = other.requires_grad;
// stats = other.stats;
matrix = other.matrix;
// grad = other.grad;
grad = other.grad;
return *this;
}


void Tensor::detatch_from_computational_graph() noexcept {
// Matrix::Operations::Utility::Stringify stringify;
// auto fn = Matrix::Operations::Utility::Function::from(get_operation().get_code());

// std::cout << "Freeing " << this << " Registry: O[" << my_tensor_id.get() << "] = " << std::visit(stringify, fn) << std::endl;

ComputationalGraphMap::get()._recover_tensor_id(my_tensor_id);
}

Expand Down Expand Up @@ -101,9 +104,9 @@ namespace NeuralNetwork {
return matrix;
}

// Tensor::matrix_t Tensor::get_grad(){
// return grad.value_or(Matrix::Representation());
// }
Tensor::matrix_t& Tensor::get_grad() noexcept {
return grad;
}

Matrix::Rows Tensor::num_rows(void) const noexcept {
return Matrix::Rows(matrix.num_rows());
Expand Down

0 comments on commit 73be0d8

Please sign in to comment.