Skip to content

Commit

Permalink
Iterator for traversing nodes in reverse order of original computation.
Browse files Browse the repository at this point in the history
  • Loading branch information
alejandroarmas committed May 15, 2022
1 parent 6d7ed4a commit fe46ef7
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 0 deletions.
66 changes: 66 additions & 0 deletions function_object_iterator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@

#include "function_object_iterator.h"




namespace NeuralNetwork {

namespace Computation {

namespace Graph {

LevelOrderIterator& LevelOrderIterator::operator++(int) noexcept {

if (!nodeStack.empty()) {


current = nodeStack.top();
nodeStack.pop();

this->_stack_children();
}
else current = TensorID(0);


return *this;
}


FunctionObject LevelOrderIterator::operator*() const noexcept {

ComputationalGraphMap& map = ComputationalGraphMap::get();
FunctionObject fn_obj = map._get_operation(current);
return fn_obj;
}


void LevelOrderIterator::_stack_children(void) noexcept {

ComputationalGraphMap& map = ComputationalGraphMap::get();
FunctionObject fn_obj = map._get_operation(current);

fn_obj.stringify_type();

for (std::size_t i = 0; const auto tid: fn_obj.serialize()) {

if (tid) {
std::cout << "tid" << i << ": " << tid->get() << std::endl;

if (i++) {
nodeStack.emplace(tid->get());
}
}

}


return;
}


} // Graph

} // Computation

} // NN
59 changes: 59 additions & 0 deletions include/function_object_iterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#ifndef FUNCTION_OBJECT_ITERATOR_H
#define FUNCTION_OBJECT_ITERATOR_H


#include "computational_graph_map.h"
#include "function_object.h"
#include "strong_types.h"

#include <stack>
#include <optional>



namespace NeuralNetwork {

namespace Computation {

namespace Graph {

/*
Breadth First Search through computational graph.
*/
class LevelOrderIterator {

public:
LevelOrderIterator(TensorID _t) noexcept : current(_t) {
if (_t.get()) this->_stack_children();
}
LevelOrderIterator(LevelOrderIterator&) = default;
LevelOrderIterator(LevelOrderIterator&&) = default;
LevelOrderIterator& operator=(const LevelOrderIterator&) = default;
LevelOrderIterator& operator=(LevelOrderIterator&&) = default;
LevelOrderIterator& operator++(int) noexcept;

FunctionObject operator*() const noexcept;

bool operator!=(const LevelOrderIterator& other) const noexcept{
return current != other.current;
}

private:
void _stack_children() noexcept;

TensorID current;
std::stack<TensorID> nodeStack;




};

}

}

}


#endif // FUNCTION_OBJECT_ITERATOR_H
66 changes: 66 additions & 0 deletions include/tensor_backwards_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#ifndef TENSOR_BACKWARDS_PASS
#define TENSOR_BACKWARDS_PASS

#include "tensor.h"

namespace NeuralNetwork {

namespace Computation {

namespace Graph {

class PrintTag;
class GradientTag;


class ReversePass {

public:
ReversePass() :
map(ComputationalGraphMap::get()) {}

void backwards(
Tensor& _t,
PrintTag _);

void backwards(
Tensor& _t,
GradientTag _);
private:
ComputationalGraphMap& map;

};



template <class StrategyType>
struct ReverseTag {

void _backwards(
Tensor& _t,
ReversePass& strat_implementation) {

return strat_implementation.backwards(
_t, *static_cast<
StrategyType const*>(this));
} };

class PrintTag : public ReverseTag<PrintTag> {
public:
PrintTag() = default;
};
class GradientTag : public ReverseTag<GradientTag> {
public:
GradientTag() = default;
};



}

}

}


#endif // TENSOR_BACKWARDS_PASS
37 changes: 37 additions & 0 deletions tensor_backwards_pass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "tensor_backwards_pass.h"
#include "m_algorithms_utilities.h"

#include <iostream>
#include <iomanip>

#include <variant>
#include <utility>
#include <stack>


namespace NeuralNetwork {

namespace Computation {

namespace Graph {


void ReversePass::backwards(Tensor& _t,
PrintTag _ ) {

for (auto it = _t.begin(); it != _t.end(); it++) {
std::cout << "Step" << std::endl;
}


}





}

}

}

0 comments on commit fe46ef7

Please sign in to comment.