Skip to content

Commit

Permalink
Added helper to differentiate matrix types of Row vector, column vect…
Browse files Browse the repository at this point in the history
…or, matrix or scalar type.
  • Loading branch information
alejandroarmas committed May 19, 2022
1 parent fbd7f45 commit 2ebab73
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 38 deletions.
9 changes: 8 additions & 1 deletion include/m_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ namespace Matrix {
static_assert(MatrixOperatable<SoftMax>);


// class Transpose : public UnaryAdapter<ReLU> {

// public:
// Matrix::Representation operate(
// const Matrix::Representation& m) const noexcept;
// };

}

namespace Metric {
Expand All @@ -92,7 +99,7 @@ namespace Matrix {

auto result = Impl().operate(l, r);

assert(result.num_rows() == 1 && result.num_cols() == 1 && "Metric Operation must return scalar.");
assert(result.get_type() == Matrix::Representation::Type::SCALAR && "Metric Operation must return scalar.");

return Matrix::Representation{result};
}
Expand Down
54 changes: 45 additions & 9 deletions include/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <memory>
#include <utility>

#include "assert.h"
#include "strong_types.h"


Expand All @@ -17,21 +18,30 @@ namespace Matrix {
using Rows = NamedType<u_int64_t, struct RowParameter>;
using Columns = NamedType<u_int64_t, struct ColumnParameter>;




class Representation {

class Representation {

public:
using matrix_iter = std::vector<float>::iterator;
using const_matrix_iter = std::vector<float>::const_iterator;

~Representation() noexcept {};


enum class Type : uint8_t {
MATRIX,
ROW_VECTOR,
COLUMN_VECTOR,
SCALAR,
};


Representation() noexcept : rows(0), columns(0) {}


explicit Representation(Rows _l, Columns _w) noexcept :
explicit Representation(Rows _l, Columns _w) noexcept :
rows(_l.get()),
columns(_w.get()),
data(std::vector<float>(_l.get() * _w.get(), 0)) {}
Expand Down Expand Up @@ -63,7 +73,8 @@ namespace Matrix {
data = std::move(_other.data);
return *this;
}


Type get_type(void) const noexcept;

bool operator==(const Matrix::Representation _other) noexcept;
bool operator!=(const Matrix::Representation _other) noexcept;
Expand All @@ -83,18 +94,43 @@ namespace Matrix {
constexpr const_matrix_iter constScanEnd() const { return data.cend(); }


friend void swap(Representation& left, Representation& right) noexcept {
std::swap(left.rows, right.rows);
std::swap(left.columns, right.columns);
std::swap(left.data, right.data);
}
// friend void swap(Representation& left, Representation& right) noexcept {
// std::swap(left.rows, right.rows);
// std::swap(left.columns, right.columns);
// std::swap(left.data, right.data);
// }

private:
u_int64_t rows;
u_int64_t columns;
std::vector<float> data;
};



// class Matrix : public Representation {

// };

// class ColumnVector;

// class RowVector : public Representation {
// RowVector() {}

// ColumnVector transpose(void) {

// }
// };

// class ColumnVector : public Representation {

// ColumnVector() {}

// RowVector transpose(void) {
// }

// };

}


Expand Down
44 changes: 16 additions & 28 deletions m_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,6 @@ namespace Matrix {
assert((l.num_rows() == r.num_rows()) && (l.num_cols() == r.num_cols()));



// if ((l.num_rows() != r.num_rows()) && (l.num_cols() != r.num_cols())) {
// throw std::length_error(Utility::debug_message_2(l, r));
// }

auto output = Matrix::Representation(Rows(l.num_rows()), Columns(r.num_cols()));

Expand Down Expand Up @@ -145,18 +141,15 @@ namespace Matrix {


#if DEBUG
if ((l.num_rows() != r.num_rows()) && (l.num_cols() != r.num_cols()))
if (l.get_type() != r.get_type() ||
l.get_type() =! Matrix::Representation::Type::COLUMN_VECTOR &&
l.get_type() =! Matrix::Representation::Type::ROW_VECTOR)
std::cout << Utility::debug_message_2(l, r) << endl;
#endif
assert((l.num_rows() == r.num_rows()) && (l.num_cols() == r.num_cols()));
assert((l.num_rows() == 1) && (l.num_cols() == 1) || (l.num_cols() == 1) && (l.num_rows() == 1) && "Operands are not Vectors.");

// if (l.num_rows() != r.num_rows() && l.num_cols() != r.num_cols()) {
// throw std::length_error(Utility::debug_message_2(l, r));
// }
// if (l.num_rows() != 1 && l.num_cols() != 1) {
// throw std::length_error("Operands are not Vectors.");
// }
assert(l.get_type() == r.get_type() &&
l.get_type() == Matrix::Representation::Type::COLUMN_VECTOR ||
l.get_type() == Matrix::Representation::Type::ROW_VECTOR &&
"Operands are not Vectors.");

u_int64_t dimension;

Expand Down Expand Up @@ -208,14 +201,15 @@ namespace Matrix {


#if DEBUG
if ((l.num_rows() != r.num_rows()) && (l.num_cols() != r.num_cols()))
if (l.get_type() != r.get_type() ||
l.get_type() =! Matrix::Representation::Type::COLUMN_VECTOR &&
l.get_type() =! Matrix::Representation::Type::ROW_VECTOR)
std::cout << Utility::debug_message_2(l, r) << endl;
#endif
assert((l.num_rows() == r.num_rows()) && (l.num_cols() == r.num_cols()));

// if ((l.num_rows() != r.num_rows()) && (l.num_cols() != r.num_cols())) {
// throw std::length_error("Matrix A not same size as Matrix B.");
// }
assert(l.get_type() == r.get_type() &&
l.get_type() == Matrix::Representation::Type::COLUMN_VECTOR ||
l.get_type() == Matrix::Representation::Type::ROW_VECTOR &&
"Operands are not Vectors.");

Matrix::Representation output = Matrix::Representation(Rows(l.num_rows()), Columns(r.num_cols()));

Expand Down Expand Up @@ -324,10 +318,7 @@ namespace Matrix {
const Matrix::Representation& l,
const Matrix::Representation& r) const noexcept {

// if (l.num_cols() != r.num_rows()) {

// throw std::length_error(Utility::debug_message(l, r));
// }

#if DEBUG
if (l.num_cols() != r.num_rows())
std::cout << Utility::debug_message(l, r) << endl;
Expand All @@ -347,10 +338,7 @@ namespace Matrix {
const Matrix::Representation& l,
const Matrix::Representation& r) const noexcept {

// if (l.num_cols() != r.num_rows()) {
// throw std::length_error(Utility::debug_message(l, r));

// }


#if DEBUG
if (l.num_cols() != r.num_rows())
Expand Down
16 changes: 16 additions & 0 deletions matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

bool Matrix::Representation::operator==(const Matrix::Representation _other) noexcept {


bool isEqual = this->data.size() == _other.data.size();

for (size_t i = 0; isEqual && i < this->data.size(); i++) {
Expand Down Expand Up @@ -51,6 +52,21 @@ void Matrix::Representation::put(u_int64_t r, u_int64_t c, float val) noexcept {
}


Matrix::Representation::Type Matrix::Representation::get_type(void) const noexcept {
bool is_row_vector = rows == 1;
bool is_column_vector = columns == 1;
bool is_scalar = is_row_vector && is_column_vector;


if (is_scalar)
return Type::SCALAR;
else if (is_column_vector)
return Type::COLUMN_VECTOR;
else if (is_row_vector)
return Type::ROW_VECTOR;
else
return Type::MATRIX;
}



Expand Down

0 comments on commit 2ebab73

Please sign in to comment.