Skip to content

Commit

Permalink
Abstraction for computational graph node added.
Browse files Browse the repository at this point in the history
  • Loading branch information
alejandroarmas committed Apr 16, 2022
1 parent 1aedba7 commit 2604d07
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
107 changes: 107 additions & 0 deletions include/m_algorithms_register.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#ifndef TENSOR_ALGORITHM_REGISTER_H
#define TENSOR_ALGORITHM_REGISTER_H


#include "tensor.h"

#include <utility> // std::pair

/*
Visitor interface for all Computational Steps
and then Tensor inherits from that class.
Visitor Polymorphism depending on task,
reading for creating graph, writing data back,
...
*/

namespace NeuralNetwork {

namespace Computation {

namespace Graph {

class Tensor;

class RegisteredOperation : std::enable_shared_from_this<RegisteredOperation> {

using cgNode = std::shared_ptr<RegisteredOperation>;
using T = std::shared_ptr<Tensor>;
using NodePair = std::pair<cgNode, cgNode>;

public:

constexpr Matrix::Operations::Code get_operation_code(void) { return m_type; }
T share_tensor () { return result; }

static std::shared_ptr<RegisteredOperation> create(
const Matrix::Operations::Code _typ, T _res,
cgNode _op = nullptr, cgNode _op2 = nullptr) {

return std::shared_ptr<RegisteredOperation>(
new RegisteredOperation(_typ, _res, _op, _op2)
);

}

std::shared_ptr<RegisteredOperation> get_operation(void) {
return shared_from_this();
}


NodePair get_operands(void) {

if (operand && bin_operand) {
return {
this->operand->get_operation(),
this->bin_operand->get_operation()
};
}
else if (operand) {
return {
this->operand->get_operation(),
nullptr
};
}
else if (bin_operand) {
return {
nullptr,
this->bin_operand->get_operation()
};
}

return {
nullptr,
nullptr
};


}


protected:
const Matrix::Operations::Code m_type;
T result;
cgNode operand;
cgNode bin_operand;
private:
RegisteredOperation(const Matrix::Operations::Code _typ, T _res,
cgNode _op, cgNode _op2) :
m_type(_typ), result(_res),
operand(std::move(_op)),
bin_operand(std::move(_op2)) {}

};


}

}

}




#endif // TENSOR_ALGORITHM_REGISTER_H
21 changes: 21 additions & 0 deletions m_algorithms_register.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

#include "m_algorithms_register.h"

#include "m_algorithms.h"
#include <iostream>

#include <memory>

namespace NeuralNetwork {

namespace Computation {

namespace Graph {



}

}

}

0 comments on commit 2604d07

Please sign in to comment.