Tensor registration encapsulated by factory function: TensorConstructor::create().
alejandroarmas committed May 15, 2022
1 parent 8ff06d7 commit 783c2cf
Showing 2 changed files with 106 additions and 101 deletions.
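In effect, the commit replaces the two-step construction in the removed code (build the Tensor with std::make_shared, then build and attach a RegisteredOperation via OperationFactory::create and register_operation) with a single TensorConstructor::create call. A condensed before/after sketch assembled from the hunks below, shown for the binary-operator branch (not a verbatim excerpt):

// Before (assembled from removed lines):
out_tensor = std::make_shared<Tensor>(
                 std::move(out_matrix), IsTrackable(true),
                 IsLeaf(false), IsRecordable(false));
out_op = OperationFactory::create(
             op_code, *out_tensor,
             l->get_operation(), r->get_operation());
out_tensor->register_operation(out_op);

// After (assembled from added lines):
out_tensor = TensorConstructor::create(_op,
                 std::move(out_matrix),
                 l->get_tensor_id(), r->get_tensor_id(),
                 IsTrackable(true), IsLeaf(true), IsRecordable(false));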
90 changes: 45 additions & 45 deletions include/tensor_forward_wrapper.h
@@ -56,14 +56,52 @@ namespace NeuralNetwork {
};


class ComputeTag;
class RecordTag;
/*
DESCRIPTION:
    Curiously recurring Template Pattern for
accepting Strategy implementation visitor
USAGE:
if (recordTensorOperation && isBinaryOp) {
RecordBinaryTag _;
return _.compute_tensor(std::move(op_type), l, r, implementation);
}
*/
class PerformTensorStrategy {

public:
PerformTensorStrategy() = default;
template <class StrategyType>
struct StrategyTag {

template <Matrix::Operations::MatrixOperatable Operator>
std::shared_ptr<Tensor> compute_tensor(
Operator _op,
const std::shared_ptr<Tensor> l,
const std::shared_ptr<Tensor> r,
PerformTensorStrategy& strat_implementation) {

return strat_implementation.compute(
_op, l, r, *static_cast<
StrategyType const*>(this));
} };

class ComputeTag : public StrategyTag<ComputeTag> {
public:
ComputeTag() = default;
};
class RecordTag : public StrategyTag<RecordTag> {
public:
RecordTag() = default;
};

PerformTensorStrategy() :
map(ComputationalGraphMap::get()) {}

template <Matrix::Operations::MatrixOperatable Operator>
std::shared_ptr<Tensor> compute(
@@ -78,51 +116,13 @@ namespace NeuralNetwork {
const std::shared_ptr<Tensor> l,
const std::shared_ptr<Tensor> r,
RecordTag _);
private:
ComputationalGraphMap& map;

};



/*
DESCRIPTION:
    Curiously recurring Template Pattern for
accepting Strategy implementation visitor
USAGE:
if (recordTensorOperation && isBinaryOp) {
RecordBinaryTag _;
return _.compute_tensor(std::move(op_type), l, r, implementation);
}
*/
template <class StrategyType>
struct StrategyTag {

template <Matrix::Operations::MatrixOperatable Operator>
std::shared_ptr<Tensor> compute_tensor(
Operator _op,
const std::shared_ptr<Tensor> l,
const std::shared_ptr<Tensor> r,
PerformTensorStrategy& strat_implementation) {

return strat_implementation.compute(
_op, l, r, *static_cast<
StrategyType const*>(this));
} };

class ComputeTag : public StrategyTag<ComputeTag> {
public:
ComputeTag() = default;
};
class RecordTag : public StrategyTag<RecordTag> {
public:
RecordTag() = default;
};



@@ -134,4 +134,4 @@ namespace NeuralNetwork {
}


#endif // TENSOR_FORWARD_WRAPPER
#endif // TENSOR_FORWARD_WRAPPER
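For readers unfamiliar with the pattern described in the DESCRIPTION comment above: each tag dispatches through its CRTP base, whose compute_tensor casts this to the concrete tag type and passes it to PerformTensorStrategy::compute, so overload resolution selects the compute or record path at compile time. Below is a stripped-down, self-contained sketch of that mechanism; the int payload, the free-standing Strategy class, and the printed strings are placeholders for illustration, not code from this repository.

#include <iostream>

struct Strategy;   // stands in for PerformTensorStrategy

// CRTP base: forwards to the Strategy overload matching the derived tag type.
template <class StrategyType>
struct StrategyTag {
    int compute_tensor(int payload, Strategy& impl);
};

struct ComputeTag : StrategyTag<ComputeTag> {};
struct RecordTag  : StrategyTag<RecordTag>  {};

struct Strategy {
    int compute(int payload, ComputeTag) { std::cout << "compute path\n"; return payload; }
    int compute(int payload, RecordTag)  { std::cout << "record path\n";  return payload + 1; }
};

// Defined after Strategy is complete so the overloaded calls resolve here.
template <class StrategyType>
int StrategyTag<StrategyType>::compute_tensor(int payload, Strategy& impl) {
    // The cast to the derived tag type selects the overload at compile time.
    return impl.compute(payload, *static_cast<StrategyType const*>(this));
}

int main() {
    Strategy impl;
    std::cout << ComputeTag{}.compute_tensor(1, impl) << '\n';   // prints "compute path", then 1
    std::cout << RecordTag{}.compute_tensor(1, impl)  << '\n';   // prints "record path", then 2
}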
117 changes: 61 additions & 56 deletions tensor_forward_wrapper.cpp
@@ -1,13 +1,16 @@

#include "tensor_forward_wrapper.h"
#include "tensor.h"
#include "tensor_factory.h"

#include "matrix.h"
#include "generator.h"
#include "m_algorithms.h"
#include "m_algorithms_register.h"
#include "m_algorithms_utilities.h"

#include "strong_types.h"

#include <chrono>
#include <concepts>
#include <memory>
@@ -34,25 +37,30 @@ namespace NeuralNetwork {

if (recordTensorOperation) {

RecordTag _;
PerformTensorStrategy::RecordTag _;

return _.compute_tensor(op_type, l, r, implementation);
}

ComputeTag _;
PerformTensorStrategy::ComputeTag _;

return _.compute_tensor(op_type, l, nullptr, implementation);
return _.compute_tensor(op_type, l, r, implementation);

}

/*
explicit template instantiation
*/

template class TensorOp<Matrix::Operations::Unary::ReLU>;
template class TensorOp<Matrix::Operations::Unary::SoftMax>;
template class TensorOp<Matrix::Operations::Binary::HadamardProduct::Std>;
template class TensorOp<Matrix::Operations::Binary::Multiplication::ParallelDNC>;
template class TensorOp<Matrix::Operations::Binary::Multiplication::Naive>;
template class TensorOp<Matrix::Operations::Binary::Multiplication::Square>;
template class TensorOp<Matrix::Operations::Binary::Addition::Std>;
template class TensorOp<Matrix::Operations::Binary::OuterProduct::Naive>;
template class TensorOp<Matrix::Operations::Metric::CrossEntropy>;



@@ -65,46 +73,46 @@ namespace NeuralNetwork {


Matrix::Representation out_matrix;
std::shared_ptr<Tensor> out_tensor;
std::shared_ptr<RegisteredOperation> out_op;

Matrix::Operations::Utility::Codify codify;
auto op_code = codify(_op);
std::shared_ptr<Tensor> out_tensor;


if constexpr (Matrix::Operations::UnaryMatrixOperatable<Operator>) {
out_matrix = _op(l->release_matrix());
out_matrix = _op(
l->release_matrix());
}
else if constexpr (Matrix::Operations::BinaryMatrixOperatable<Operator>) {
out_matrix = _op(
l->release_matrix(),
r->release_matrix()
);
l->release_matrix(),
r->release_matrix()
);
}


out_tensor = std::make_shared<Tensor>
(std::move(out_matrix), IsTrackable(true),
IsLeaf(false), IsRecordable(false));


if constexpr (Matrix::Operations::UnaryMatrixOperatable<Operator>) {
out_op = OperationFactory::create(
op_code,
*out_tensor,
l->get_operation()
);

out_tensor = TensorConstructor::create(_op,
std::move(out_matrix),
l->get_tensor_id(),
TensorID(0),
IsTrackable(true),
IsLeaf(true),
IsRecordable(false));
l->become_parent();
}
else if constexpr (Matrix::Operations::BinaryMatrixOperatable<Operator>) {
out_op = OperationFactory::create(
op_code,
*out_tensor,
l->get_operation(),
r->get_operation()
);
}

out_tensor = TensorConstructor::create(_op,
std::move(out_matrix),
l->get_tensor_id(),
r->get_tensor_id(),
IsTrackable(true),
IsLeaf(true),
IsRecordable(false));

l->become_parent();
r->become_parent();

}

out_tensor->register_operation(out_op);

return out_tensor;

@@ -125,13 +133,9 @@ namespace NeuralNetwork {
Matrix::Operations::Utility::Stringify stringify;
_s.set_operation_string(stringify(_op));

Matrix::Operations::Utility::Codify codify;
Matrix::Operations::Code op_code = codify(_op);

Matrix::Representation out_matrix;
std::shared_ptr<Tensor> out_tensor;
std::shared_ptr<RegisteredOperation> out_op;



_s.set_matrix_start(std::chrono::steady_clock::now());

Expand All @@ -147,32 +151,33 @@ namespace NeuralNetwork {
}

_s.set_matrix_end(std::chrono::steady_clock::now());


out_tensor = std::make_shared<Tensor>(
std::move(out_matrix), IsTrackable(true),
IsLeaf(false), IsRecordable(true));



if constexpr (Matrix::Operations::UnaryMatrixOperatable<Operator>) {
out_op = OperationFactory::create(
op_code,
*out_tensor,
l->get_operation()
);

out_tensor = TensorConstructor::create(_op,
std::move(out_matrix),
l->get_tensor_id(),
TensorID(0),
IsTrackable(true),
IsLeaf(true),
IsRecordable(true));
l->become_parent();

}
else if constexpr (Matrix::Operations::BinaryMatrixOperatable<Operator>) {
out_op = OperationFactory::create(
op_code,
*out_tensor,
l->get_operation(),
r->get_operation()
);

out_tensor = TensorConstructor::create(_op,
std::move(out_matrix),
l->get_tensor_id(),
r->get_tensor_id(),
IsTrackable(true),
IsLeaf(true),
IsRecordable(true));
l->become_parent();
r->become_parent();
}


out_tensor->register_operation(out_op);

_s.set_graph_end(std::chrono::steady_clock::now());
out_tensor->stats = _s;


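A closing note on the explicit instantiation list near the top of tensor_forward_wrapper.cpp: listing template class TensorOp<...> for each operator forces the compiler to emit those specializations in this translation unit, which is the usual remedy when a template's member definitions live in a .cpp file rather than the header (the exact motivation here is an assumption; the comment in the diff only says "explicit template instantiation"). A minimal, self-contained illustration of the pattern, using placeholder names rather than code from this repository:

// header-style declaration: run() is only declared here
template <class Op>
struct Wrapper {
    int run(int x);
};

// .cpp-style definition: not visible to other translation units
template <class Op>
int Wrapper<Op>::run(int x) {
    return Op{}(x);
}

struct Negate {
    int operator()(int x) const { return -x; }
};

// Explicit instantiation: emits Wrapper<Negate>::run here so that callers in
// other translation units can link against it.
template class Wrapper<Negate>;

int main() { return Wrapper<Negate>{}.run(5) == -5 ? 0 : 1; }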