Skip to content

Commit

Permalink
Implemented Subtract Binary Operation, Fixed Softmax Unary algorithm …
Browse files Browse the repository at this point in the history
…bug, implemented Metric.
  • Loading branch information
alejandroarmas committed May 15, 2022
1 parent d2a9c1f commit 8ff06d7
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 7 deletions.
83 changes: 77 additions & 6 deletions include/m_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <iostream>
#include <variant>

#include <assert.h>

#include "matrix.h"
#include "m_algorithms_concepts.h"

Expand All @@ -17,9 +19,8 @@ namespace Matrix {
namespace Operations {


enum struct Code {
NOP, MULTIPLY, PLUS, ReLU, OUTER_PRODUCT, HADAMARD,

enum class Code {
NOP, MULTIPLY, PLUS, ReLU, SoftMax, OUTER_PRODUCT, HADAMARD, CROSS_ENTROPY,
};


Expand All @@ -42,7 +43,6 @@ namespace Matrix {
Implementation& Impl() const { return *static_cast<Implementation*>(const_cast<UnaryAdapter<Implementation>*>(this)); }
friend Implementation;


};


Expand All @@ -58,8 +58,65 @@ namespace Matrix {
static_assert(MatrixOperatable<ReLU>);


class SoftMax : public UnaryAdapter<ReLU> {

public:
Matrix::Representation operate(
const Matrix::Representation& m) const;
};

static_assert(MatrixOperatable<SoftMax>);


}

namespace Metric {



template <class Implementation>
class BaseOp {

public:
BaseOp() = default;
~BaseOp() = default;
Matrix::Representation operator()(
const Matrix::Representation& l,
const Matrix::Representation& r) const {

bool rows_compatable = l.num_rows() == r.num_rows();
bool cols_compatable = l.num_cols() == r.num_cols();
bool is_vector = l.num_rows() == 1 || l.num_cols() == 1;

assert(rows_compatable && cols_compatable && is_vector);

auto result = Impl().operate(l, r);

assert(result.num_rows() == 1 && result.num_cols() == 1 && "Metric Operation must return scalar.");

return result;
}
private:
Implementation& Impl() const { return *static_cast<Implementation*>(const_cast<BaseOp<Implementation>*>(this)); }
friend Implementation;

};


class CrossEntropy : public BaseOp<CrossEntropy> {
public:
Matrix::Representation operate(
const Matrix::Representation& p,
const Matrix::Representation& q) const;
};


static_assert(MatrixOperatable<CrossEntropy>);


} // Metric


namespace Binary {


Expand All @@ -68,13 +125,13 @@ namespace Matrix {

public:
BaseOp() = default;
virtual Matrix::Representation operator()(
~BaseOp() = default;
Matrix::Representation operator()(
const Matrix::Representation& l,
const Matrix::Representation& r) const {

return Impl().operate(l, r);
};
virtual ~BaseOp() = default;
private:
Implementation& Impl() const { return *static_cast<Implementation*>(const_cast<BaseOp<Implementation>*>(this)); }
friend Implementation;
Expand All @@ -100,6 +157,20 @@ namespace Matrix {

static_assert(MatrixOperatable<Addition::Std>);

namespace Subtraction {


class Std : public BaseOp<Std> {
public:
Matrix::Representation operate(
const Matrix::Representation& l,
const Matrix::Representation& r) const;
};

}

static_assert(MatrixOperatable<Subtraction::Std>);


namespace OuterProduct {

Expand Down
69 changes: 68 additions & 1 deletion m_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <cilk/cilk.h>
#include <iostream>
#include <math.h>


namespace Matrix {
Expand All @@ -28,9 +29,57 @@ namespace Matrix {
return output;
}

}

Matrix::Representation SoftMax::operate(
const Matrix::Representation& m) const{


Matrix::Representation output = Matrix::Representation(
Matrix::Rows(m.num_rows()),
Matrix::Columns(m.num_cols())
);

auto max = std::max(m.constScanStart(), m.constScanEnd());

std::transform(m.constScanStart(), m.constScanEnd(), output.scanStart(), [max](const auto val) { return val - *max; });
std::transform(output.constScanStart(), output.constScanEnd(), output.scanStart(), [](const auto val) { return exp(val);});



return output;
}



} // Unary

namespace Metric {


Matrix::Representation CrossEntropy::operate(
const Matrix::Representation& p,
const Matrix::Representation& q) const {

Matrix::Representation output = Matrix::Representation(
Matrix::Rows(1),
Matrix::Columns(1)
);


float entropy = 0;

for (auto p_i = p.constScanStart(), q_i = q.constScanStart(); q_i != q.constScanEnd(); p_i++, q_i++) {
entropy += *p_i * log(*q_i);
}


output.put(0, 0, entropy);

return output;
}

}


namespace Binary {

Expand All @@ -54,6 +103,24 @@ namespace Matrix {
}
}

namespace Subtraction {

Matrix::Representation Std::operate(
const Matrix::Representation& l,
const Matrix::Representation& r) const {

if ((l.num_rows() != r.num_rows()) && (l.num_cols() != r.num_cols())) {
throw std::length_error(Utility::debug_message_2(l, r));
}

auto output = Matrix::Representation(Rows(l.num_rows()), Columns(r.num_cols()));

std::transform(l.constScanStart(), l.constScanEnd(), r.constScanStart(), output.scanStart(), std::minus<float>());

return output;
}
}


namespace OuterProduct {

Expand Down

0 comments on commit 8ff06d7

Please sign in to comment.