Skip to content

Commit

Permalink
operator <, >, <=, >=, == with n_type
Browse files Browse the repository at this point in the history
  • Loading branch information
EasternJournalist committed Mar 9, 2021
1 parent 0e25e38 commit 88bcafb
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
40 changes: 40 additions & 0 deletions include/DiffVar_arr.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,26 @@ namespace DiffNum {
return value == _Right.value;
}

bool operator <(const n_type _Right) const {
return value < _Right;
}

bool operator <=(const n_type _Right) const {
return value <= _Right;
}

bool operator >(const n_type _Right) const {
return value > _Right;
}

bool operator >=(const n_type _Right) const {
return value >= _Right;
}

bool operator ==(const n_type _Right) const {
return value == _Right;
}

const d_type& operator[](size_t var_idx) const {
return gradient[var_idx];
}
Expand Down Expand Up @@ -308,6 +328,26 @@ namespace DiffNum {
return getValue() == _Right.getValue();
}

bool operator <(const n_type _Right) const {
return getValue() < _Right;
}

bool operator <=(const n_type _Right) const {
return getValue() <= _Right;
}

bool operator >(const n_type _Right) const {
return getValue() > _Right;
}

bool operator >=(const n_type _Right) const {
return getValue() >= _Right;
}

bool operator ==(const n_type _Right) const {
return getValue() == _Right;
}

const d_type& operator[](size_t var_idx) const {
return gradient[var_idx];
}
Expand Down
40 changes: 40 additions & 0 deletions include/DiffVar_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,26 @@ namespace DiffNum {
return value == _Right.value;
}

__host__ __device__ bool operator <(const n_type _Right) const {
return value < _Right;
}

__host__ __device__ bool operator <=(const n_type _Right) const {
return value <= _Right;
}

__host__ __device__ bool operator >(const n_type _Right) const {
return value > _Right;
}

__host__ __device__ bool operator >=(const n_type _Right) const {
return value >= _Right;
}

__host__ __device__ bool operator ==(const n_type _Right) const {
return value == _Right;
}

__host__ __device__ const d_type& operator[](size_t var_idx) const {
return gradient[var_idx];
}
Expand Down Expand Up @@ -326,6 +346,26 @@ namespace DiffNum {
return gradient[var_idx];
}

__host__ __device__ bool operator <(const n_type _Right) const {
return getValue() < _Right;
}

__host__ __device__ bool operator <=(const n_type _Right) const {
return getValue() <= _Right;
}

__host__ __device__ bool operator >(const n_type _Right) const {
return getValue() > _Right;
}

__host__ __device__ bool operator >=(const n_type _Right) const {
return getValue() >= _Right;
}

__host__ __device__ bool operator ==(const n_type _Right) const {
return getValue() == _Right;
}

__host__ __device__ DiffVar_cuda() {}

__host__ __device__ DiffVar_cuda(const n_type value) : value(value) {
Expand Down
40 changes: 40 additions & 0 deletions include/DiffVar_vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ namespace DiffNum {
return value == _Right.value;
}

bool operator <(const n_type _Right) const {
return value < _Right;
}

bool operator <=(const n_type _Right) const {
return value <= _Right;
}

bool operator >(const n_type _Right) const {
return value > _Right;
}

bool operator >=(const n_type _Right) const {
return value >= _Right;
}

bool operator ==(const n_type _Right) const {
return value == _Right;
}

const d_type& operator[](size_t var_idx) const {
return gradient[var_idx];
}
Expand Down Expand Up @@ -339,6 +359,26 @@ namespace DiffNum {
return getValue() == _Right.getValue();
}

bool operator <(const n_type _Right) const {
return getValue() < _Right;
}

bool operator <=(const n_type _Right) const {
return getValue() <= _Right;
}

bool operator >(const n_type _Right) const {
return getValue() > _Right;
}

bool operator >=(const n_type _Right) const {
return getValue() >= _Right;
}

bool operator ==(const n_type _Right) const {
return getValue() == _Right;
}

const d_type& operator[](size_t var_idx) const {
return gradient[var_idx];
}
Expand Down

0 comments on commit 88bcafb

Please sign in to comment.