-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Iterator for traversing nodes in reverse order of original computation.
- Loading branch information
1 parent
6d7ed4a
commit fe46ef7
Showing
4 changed files
with
228 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
|
||
#include "function_object_iterator.h" | ||
|
||
|
||
|
||
|
||
namespace NeuralNetwork { | ||
|
||
namespace Computation { | ||
|
||
namespace Graph { | ||
|
||
LevelOrderIterator& LevelOrderIterator::operator++(int) noexcept { | ||
|
||
if (!nodeStack.empty()) { | ||
|
||
|
||
current = nodeStack.top(); | ||
nodeStack.pop(); | ||
|
||
this->_stack_children(); | ||
} | ||
else current = TensorID(0); | ||
|
||
|
||
return *this; | ||
} | ||
|
||
|
||
FunctionObject LevelOrderIterator::operator*() const noexcept { | ||
|
||
ComputationalGraphMap& map = ComputationalGraphMap::get(); | ||
FunctionObject fn_obj = map._get_operation(current); | ||
return fn_obj; | ||
} | ||
|
||
|
||
void LevelOrderIterator::_stack_children(void) noexcept { | ||
|
||
ComputationalGraphMap& map = ComputationalGraphMap::get(); | ||
FunctionObject fn_obj = map._get_operation(current); | ||
|
||
fn_obj.stringify_type(); | ||
|
||
for (std::size_t i = 0; const auto tid: fn_obj.serialize()) { | ||
|
||
if (tid) { | ||
std::cout << "tid" << i << ": " << tid->get() << std::endl; | ||
|
||
if (i++) { | ||
nodeStack.emplace(tid->get()); | ||
} | ||
} | ||
|
||
} | ||
|
||
|
||
return; | ||
} | ||
|
||
|
||
} // Graph | ||
|
||
} // Computation | ||
|
||
} // NN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#ifndef FUNCTION_OBJECT_ITERATOR_H | ||
#define FUNCTION_OBJECT_ITERATOR_H | ||
|
||
|
||
#include "computational_graph_map.h" | ||
#include "function_object.h" | ||
#include "strong_types.h" | ||
|
||
#include <stack> | ||
#include <optional> | ||
|
||
|
||
|
||
namespace NeuralNetwork { | ||
|
||
namespace Computation { | ||
|
||
namespace Graph { | ||
|
||
/* | ||
Breadth First Search through computational graph. | ||
*/ | ||
class LevelOrderIterator { | ||
|
||
public: | ||
LevelOrderIterator(TensorID _t) noexcept : current(_t) { | ||
if (_t.get()) this->_stack_children(); | ||
} | ||
LevelOrderIterator(LevelOrderIterator&) = default; | ||
LevelOrderIterator(LevelOrderIterator&&) = default; | ||
LevelOrderIterator& operator=(const LevelOrderIterator&) = default; | ||
LevelOrderIterator& operator=(LevelOrderIterator&&) = default; | ||
LevelOrderIterator& operator++(int) noexcept; | ||
|
||
FunctionObject operator*() const noexcept; | ||
|
||
bool operator!=(const LevelOrderIterator& other) const noexcept{ | ||
return current != other.current; | ||
} | ||
|
||
private: | ||
void _stack_children() noexcept; | ||
|
||
TensorID current; | ||
std::stack<TensorID> nodeStack; | ||
|
||
|
||
|
||
|
||
}; | ||
|
||
} | ||
|
||
} | ||
|
||
} | ||
|
||
|
||
#endif // FUNCTION_OBJECT_ITERATOR_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#ifndef TENSOR_BACKWARDS_PASS | ||
#define TENSOR_BACKWARDS_PASS | ||
|
||
#include "tensor.h" | ||
|
||
namespace NeuralNetwork { | ||
|
||
namespace Computation { | ||
|
||
namespace Graph { | ||
|
||
class PrintTag; | ||
class GradientTag; | ||
|
||
|
||
class ReversePass { | ||
|
||
public: | ||
ReversePass() : | ||
map(ComputationalGraphMap::get()) {} | ||
|
||
void backwards( | ||
Tensor& _t, | ||
PrintTag _); | ||
|
||
void backwards( | ||
Tensor& _t, | ||
GradientTag _); | ||
private: | ||
ComputationalGraphMap& map; | ||
|
||
}; | ||
|
||
|
||
|
||
template <class StrategyType> | ||
struct ReverseTag { | ||
|
||
void _backwards( | ||
Tensor& _t, | ||
ReversePass& strat_implementation) { | ||
|
||
return strat_implementation.backwards( | ||
_t, *static_cast< | ||
StrategyType const*>(this)); | ||
} }; | ||
|
||
class PrintTag : public ReverseTag<PrintTag> { | ||
public: | ||
PrintTag() = default; | ||
}; | ||
class GradientTag : public ReverseTag<GradientTag> { | ||
public: | ||
GradientTag() = default; | ||
}; | ||
|
||
|
||
|
||
} | ||
|
||
} | ||
|
||
} | ||
|
||
|
||
#endif // TENSOR_BACKWARDS_PASS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include "tensor_backwards_pass.h" | ||
#include "m_algorithms_utilities.h" | ||
|
||
#include <iostream> | ||
#include <iomanip> | ||
|
||
#include <variant> | ||
#include <utility> | ||
#include <stack> | ||
|
||
|
||
namespace NeuralNetwork { | ||
|
||
namespace Computation { | ||
|
||
namespace Graph { | ||
|
||
|
||
void ReversePass::backwards(Tensor& _t, | ||
PrintTag _ ) { | ||
|
||
for (auto it = _t.begin(); it != _t.end(); it++) { | ||
std::cout << "Step" << std::endl; | ||
} | ||
|
||
|
||
} | ||
|
||
|
||
|
||
|
||
|
||
} | ||
|
||
} | ||
|
||
} |