Skip to content

Commit

Permalink
Changed main API
Browse files Browse the repository at this point in the history
  • Loading branch information
alejandroarmas committed Apr 29, 2022
1 parent 2fd8473 commit 028e46b
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,23 @@ pred_probab = nn.Softmax(dim=1)(logits)

int main(void) {

using matrix_t = Matrix::Representation;

std::unique_ptr<matrix_t> ma = std::make_unique<matrix_t>(Matrix::Rows(1), Matrix::Columns(2000));
Matrix::Generation::Normal<0, 1> normal_distribution_init;
ma = normal_distribution_init(std::move(ma));
auto ma = std::make_shared<NeuralNetwork::Computation::Graph::Tensor>(Matrix::Rows(1), Matrix::Columns(2000));

NeuralNetwork::Sequential model;

model.add(std::make_unique<NeuralNetwork::Layer>(
std::make_unique<NeuralNetwork::MatrixMultiplyStep>(Matrix::Rows(2000), Matrix::Columns(1000), normal_distribution_init),
std::make_unique<NeuralNetwork::AddStep>(Matrix::Columns(1000), normal_distribution_init)
std::make_unique<NeuralNetwork::MatrixMultiplyStep>(Matrix::Rows(2000), Matrix::Columns(1000)),
std::make_unique<NeuralNetwork::AddStep>(Matrix::Columns(1000))
));
model.add(std::make_unique<NeuralNetwork::ActivationFunctions::ReLU>());
model.add(std::make_unique<NeuralNetwork::Layer>(
std::make_unique<NeuralNetwork::MatrixMultiplyStep>(Matrix::Rows(1000), Matrix::Columns(10), normal_distribution_init),
std::make_unique<NeuralNetwork::AddStep>(Matrix::Columns(10), normal_distribution_init)
std::make_unique<NeuralNetwork::MatrixMultiplyStep>(Matrix::Rows(1000), Matrix::Columns(10)),
std::make_unique<NeuralNetwork::AddStep>(Matrix::Columns(10))
));
model.add(std::make_unique<NeuralNetwork::ActivationFunctions::ReLU>());


auto out = model.forward(std::move(ma));
auto out = model.forward(ma);

// NeuralNetwork::Computation::Tree::ComputeOperation handler;
// handler.setNextHandler(std::make_unique<NeuralNetwork::Computation::Tree::TimerHandler>());
Expand Down

0 comments on commit 028e46b

Please sign in to comment.