From 670fb62db63232ec8b913a6d97439646a64cc3f2 Mon Sep 17 00:00:00 2001 From: Victor Li Date: Wed, 30 Oct 2024 18:18:06 -0700 Subject: [PATCH] Changed ff_dim to ff_dim_t, added in nonnegative_int type --- .../local-execution/legion_tensor_shape.h | 2 +- lib/local-execution/src/ops/linear.cc | 2 +- .../op-attrs/dim_ordered/dim_ordered.h | 2 +- lib/op-attrs/include/op-attrs/ff_dim.h | 18 -- lib/op-attrs/include/op-attrs/ff_dim_t.h | 14 + ...f_dim.struct.toml => ff_dim_t.struct.toml} | 0 .../op-attrs/ops/combine_attrs.struct.toml | 4 +- .../op-attrs/ops/concat_attrs.struct.toml | 4 +- .../op-attrs/ops/flat_attrs.struct.toml | 4 +- .../op-attrs/ops/gather_attrs.struct.toml | 4 +- .../op-attrs/ops/layer_norm_attrs.struct.toml | 4 +- .../op-attrs/ops/reduce_attrs.struct.toml | 4 +- .../ops/repartition_attrs.struct.toml | 4 +- .../op-attrs/ops/reverse_attrs.struct.toml | 4 +- .../op-attrs/ops/softmax_attrs.struct.toml | 4 +- .../op-attrs/ops/split_attrs.struct.toml | 4 +- .../op-attrs/ops/transpose_attrs.struct.toml | 4 +- .../parallel_tensor_dim_idx_t.variant.toml | 2 +- .../include/op-attrs/relative_ff_dim_t.h | 14 + .../op-attrs/relative_ff_dim_t.struct.toml | 14 + lib/op-attrs/src/op-attrs/ff_dim_t.cc | 8 + .../src/op-attrs/relative_ff_dim_t.cc | 10 + .../test/src/op-attrs/dim_ordered/zip.cc | 2 +- lib/runtime/src/parallel_op_info.h | 2 +- .../operator_attribute_value.variant.toml | 2 +- .../utils/nonnegative_int/nonnegative_int.h | 65 +++++ .../utils/nonnegative_int/nonnegative_int.cc | 101 +++++++ .../utils/nonnegative_int/nonnegative_int.cc | 249 ++++++++++++++++++ 28 files changed, 504 insertions(+), 47 deletions(-) delete mode 100644 lib/op-attrs/include/op-attrs/ff_dim.h create mode 100644 lib/op-attrs/include/op-attrs/ff_dim_t.h rename lib/op-attrs/include/op-attrs/{ff_dim.struct.toml => ff_dim_t.struct.toml} (100%) create mode 100644 lib/op-attrs/include/op-attrs/relative_ff_dim_t.h create mode 100644 lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/ff_dim_t.cc create mode 100644 lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc create mode 100644 lib/utils/include/utils/nonnegative_int/nonnegative_int.h create mode 100644 lib/utils/src/utils/nonnegative_int/nonnegative_int.cc create mode 100644 lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc diff --git a/lib/local-execution/include/local-execution/legion_tensor_shape.h b/lib/local-execution/include/local-execution/legion_tensor_shape.h index 2f2ed50d41..e9b8a15876 100644 --- a/lib/local-execution/include/local-execution/legion_tensor_shape.h +++ b/lib/local-execution/include/local-execution/legion_tensor_shape.h @@ -3,7 +3,7 @@ #include "kernels/legion_dim.h" #include "op-attrs/datatype.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" #include "utils/visitable.h" diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc index 9934e2a45c..f3ef022175 100644 --- a/lib/local-execution/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -1,7 +1,7 @@ #include "linear.h" #include "kernels/linear_kernels.h" #include "local-execution/task_argument_accessor.h" -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/get_output_shapes.h" #include "utils/exception.h" #include "utils/hash-utils.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 6aa23d40fc..9a8d8adc0b 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H -#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim_t.dtg.h" #include "utils/fmt/vector.h" #include "utils/stack_vector.h" #include diff --git a/lib/op-attrs/include/op-attrs/ff_dim.h b/lib/op-attrs/include/op-attrs/ff_dim.h deleted file mode 100644 index e78ce4b51e..0000000000 --- a/lib/op-attrs/include/op-attrs/ff_dim.h +++ /dev/null @@ -1,18 +0,0 @@ - -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H - -#include "op-attrs/ff_dim.dtg.h" -#include "rapidcheck.h" - -namespace rc { -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::construct( - gen::inRange(0, MAX_TENSOR_DIM)); - } -}; -} // namespace rc - -#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_H diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.h b/lib/op-attrs/include/op-attrs/ff_dim_t.h new file mode 100644 index 0000000000..a9a7233fbf --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H + +#include "op-attrs/ff_dim_t.dtg.h" +#include "rapidcheck.h" + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H diff --git a/lib/op-attrs/include/op-attrs/ff_dim.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml similarity index 100% rename from lib/op-attrs/include/op-attrs/ff_dim.struct.toml rename to lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml diff --git a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml index 585295fe1c..e7eeedec06 100644 --- a/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/combine_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index fab8132993..f3c66d0416 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h" + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h" ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index 7349e2a8c4..301df8bca4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -11,14 +11,14 @@ features = [ includes = [ "", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.dtg.h", ] src_includes = [ "utils/fmt/optional.h", "utils/json/optional.h", "utils/rapidcheck/optional.h", - "op-attrs/ff_dim.h", + "op-attrs/ff_dim_t.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml index c8bb88dcc7..66d475aa46 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/gather_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h" + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h" ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml index ec60d39f7f..401eaeeec4 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml index 717e7954e8..88e57ef7c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reduce_attrs.struct.toml @@ -11,8 +11,8 @@ features = [ includes = [ "op-attrs/operator_type.dtg.h", - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", "utils/stack_vector.h", ] diff --git a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml index 25a33c0c15..69c4b7580f 100644 --- a/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/repartition_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml index 198346e5dd..2577ac1398 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/reverse_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml index 8b839c122a..49172f44b0 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/softmax_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml index 8cdf7728af..f10aa7c3fd 100644 --- a/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/split_attrs.struct.toml @@ -11,8 +11,8 @@ features = [ includes = [ "utils/stack_vector.h", - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index 0dc30d9a79..b1c5f60382 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/ff_dim.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.h", + "op-attrs/ff_dim_t.dtg.h", "op-attrs/dim_ordered/dim_ordered.h", ] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml index 9396cbcbe8..7e7356a5e7 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml @@ -9,7 +9,7 @@ features = [ ] includes = [ - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.dtg.h", "op-attrs/replica_type.dtg.h", ] diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h new file mode 100644 index 0000000000..2f95eac1d4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H + +#include "op-attrs/relative_ff_dim_t.dtg.h" +#include "rapidcheck.h" + +namespace rc { +template <> +struct Arbitrary { + static Gen arbitrary(); +}; +} // namespace rc + +#endif // _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml new file mode 100644 index 0000000000..a93b649052 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "relative_ff_dim_t" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "value" +type = "int" diff --git a/lib/op-attrs/src/op-attrs/ff_dim_t.cc b/lib/op-attrs/src/op-attrs/ff_dim_t.cc new file mode 100644 index 0000000000..043d2b2309 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ff_dim_t.cc @@ -0,0 +1,8 @@ +#include "op-attrs/ff_dim_t.h" + +namespace rc { +Gen Arbitrary::arbitrary() { + return gen::construct( + gen::inRange(0, MAX_TENSOR_DIM)); +} +} // namespace rc diff --git a/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc new file mode 100644 index 0000000000..8080da203d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc @@ -0,0 +1,10 @@ +#include "op-attrs/relative_ff_dim_t.h" +#include "rapidcheck.h" + +namespace rc { +Gen + Arbitrary::arbitrary() { + return gen::construct( + gen::inRange(-MAX_TENSOR_DIM, MAX_TENSOR_DIM)); +} +} // namespace rc diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc index 8e3d0f1b80..b77bb8f71e 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc @@ -1,5 +1,5 @@ #include "op-attrs/dim_ordered/zip.h" -#include "op-attrs/ff_dim.dtg.h" +#include "op-attrs/ff_dim_t.dtg.h" #include "test/utils/doctest/fmt/pair.h" #include diff --git a/lib/runtime/src/parallel_op_info.h b/lib/runtime/src/parallel_op_info.h index ebd44f012b..49ad22be74 100644 --- a/lib/runtime/src/parallel_op_info.h +++ b/lib/runtime/src/parallel_op_info.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_PARALLEL_OPS_PARALLEL_OP_INFO_H #define _FLEXFLOW_PARALLEL_OPS_PARALLEL_OP_INFO_H -#include "op-attrs/ff_dim.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/operator_type.h" #include "utils/visitable.h" #include diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 02a856f59a..8fe4a9494d 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -13,7 +13,7 @@ includes = [ "", "", "op-attrs/operator_type.dtg.h", - "op-attrs/ff_dim.dtg.h", + "op-attrs/ff_dim_t.dtg.h", "op-attrs/activation.dtg.h", "op-attrs/aggregate_op.dtg.h", "op-attrs/regularizer_attrs.dtg.h", diff --git a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h new file mode 100644 index 0000000000..f3fbc4556d --- /dev/null +++ b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h @@ -0,0 +1,65 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONNEGATIVE_INT_NONNEGATIVE_INT_H + +#include "rapidcheck.h" + +#include +#include +#include +#include +#include + +namespace FlexFlow { +class nonnegative_int { +public: + nonnegative_int() = delete; + explicit nonnegative_int(int value); + + explicit operator int() const noexcept; + + bool operator<(nonnegative_int const &other) const; + bool operator==(nonnegative_int const &other) const; + bool operator>(nonnegative_int const &other) const; + bool operator<=(nonnegative_int const &other) const; + bool operator!=(nonnegative_int const &other) const; + bool operator>=(nonnegative_int const &other) const; + + bool operator<(int const &other) const; + bool operator==(int const &other) const; + bool operator>(int const &other) const; + bool operator<=(int const &other) const; + bool operator!=(int const &other) const; + bool operator>=(int const &other) const; + + friend bool operator<(int const &lhs, nonnegative_int const &rhs); + friend bool operator==(int const &lhs, nonnegative_int const &rhs); + friend bool operator>(int const &lhs, nonnegative_int const &rhs); + friend bool operator<=(int const &lhs, nonnegative_int const &rhs); + friend bool operator!=(int const &lhs, nonnegative_int const &rhs); + friend bool operator>=(int const &lhs, nonnegative_int const &rhs); + + friend std::ostream &operator<<(std::ostream &os, nonnegative_int const &n); + + int get_value() const; + +private: + int value_; +}; +} // namespace FlexFlow + +namespace nlohmann { +template <> +struct adl_serializer<::FlexFlow::nonnegative_int> { + static ::FlexFlow::nonnegative_int from_json(json const &j); + static void to_json(json &j, ::FlexFlow::nonnegative_int t); +}; +} // namespace nlohmann + +namespace std { +template <> +struct hash { + std::size_t operator()(FlexFlow::nonnegative_int const &n) const noexcept; +}; +} // namespace std + +#endif diff --git a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc new file mode 100644 index 0000000000..3218d94ebf --- /dev/null +++ b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc @@ -0,0 +1,101 @@ +#include "utils/nonnegative_int/nonnegative_int.h" + +namespace FlexFlow { + +nonnegative_int::nonnegative_int(int value) { + if (value < 0) { + throw std::invalid_argument( + "Value of nonnegative_int type must be nonnegative."); + } + this->value_ = value; +} + +nonnegative_int::operator int() const noexcept { + return this->value_; +} + +bool nonnegative_int::operator<(nonnegative_int const &other) const { + return this->value_ < other.value_; +} +bool nonnegative_int::operator==(nonnegative_int const &other) const { + return this->value_ == other.value_; +} +bool nonnegative_int::operator>(nonnegative_int const &other) const { + return this->value_ > other.value_; +} +bool nonnegative_int::operator<=(nonnegative_int const &other) const { + return this->value_ <= other.value_; +} +bool nonnegative_int::operator!=(nonnegative_int const &other) const { + return this->value_ != other.value_; +} +bool nonnegative_int::operator>=(nonnegative_int const &other) const { + return this->value_ >= other.value_; +} + +bool nonnegative_int::operator<(int const &other) const { + return this->value_ < other; +} +bool nonnegative_int::operator==(int const &other) const { + return this->value_ == other; +} +bool nonnegative_int::operator>(int const &other) const { + return this->value_ > other; +} +bool nonnegative_int::operator<=(int const &other) const { + return this->value_ <= other; +} +bool nonnegative_int::operator!=(int const &other) const { + return this->value_ != other; +} +bool nonnegative_int::operator>=(int const &other) const { + return this->value_ >= other; +} + +bool operator<(int const &lhs, nonnegative_int const &rhs) { + return lhs < rhs.value_; +} +bool operator==(int const &lhs, nonnegative_int const &rhs) { + return lhs == rhs.value_; +} +bool operator>(int const &lhs, nonnegative_int const &rhs) { + return lhs > rhs.value_; +} +bool operator<=(int const &lhs, nonnegative_int const &rhs) { + return lhs <= rhs.value_; +} +bool operator!=(int const &lhs, nonnegative_int const &rhs) { + return lhs != rhs.value_; +} +bool operator>=(int const &lhs, nonnegative_int const &rhs) { + return lhs >= rhs.value_; +} + +std::ostream &operator<<(std::ostream &os, nonnegative_int const &n) { + os << n.value_; + return os; +} + +int nonnegative_int::get_value() const { + return this->value_; +} +} // namespace FlexFlow + +namespace nlohmann { +::FlexFlow::nonnegative_int + adl_serializer<::FlexFlow::nonnegative_int>::from_json(json const &j) { + return ::FlexFlow::nonnegative_int{j.template get()}; +} + +void adl_serializer<::FlexFlow::nonnegative_int>::to_json( + json &j, ::FlexFlow::nonnegative_int t) { + j = t.get_value(); +} +} // namespace nlohmann + +namespace std { +std::size_t hash::operator()( + FlexFlow::nonnegative_int const &n) const noexcept { + return std::hash{}(n.get_value()); +} +} // namespace std diff --git a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc new file mode 100644 index 0000000000..97ca027354 --- /dev/null +++ b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc @@ -0,0 +1,249 @@ +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("nonnegative_int initialization") { + SUBCASE("positive int initialization") { + CHECK_NOTHROW(nonnegative_int{1}); + } + + SUBCASE("zero initialization") { + CHECK_NOTHROW(nonnegative_int{0}); + } + + SUBCASE("negative int initialization") { + CHECK_THROWS(nonnegative_int{-1}); + } + } + + TEST_CASE("nonnegative_int == comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equal") { + CHECK((nn_int_1a == nn_int_1b) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, not equal") { + CHECK((nn_int_1a == nn_int_2) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equal") { + CHECK((nn_int_1a == 1) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, not equal") { + CHECK((nn_int_1a == 2) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equal") { + CHECK((1 == nn_int_1b) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, not equal") { + CHECK((2 == nn_int_1b) == false); + } + } + + TEST_CASE("nonnegative_int != comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equal") { + CHECK((nn_int_1a != nn_int_1b) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, not equal") { + CHECK((nn_int_1a != nn_int_2) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equal") { + CHECK((nn_int_1a != 1) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, not equal") { + CHECK((nn_int_1a != 2) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equal") { + CHECK((1 != nn_int_1b) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, not equal") { + CHECK((2 != nn_int_1b) == true); + } + } + + TEST_CASE("nonnegative_int < comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK((nn_int_1a < nn_int_2) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK((nn_int_1a < nn_int_1b) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK((nn_int_2 < nn_int_1b) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK((nn_int_1a < 2) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK((nn_int_1a < 1) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK((nn_int_2 < 1) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK((1 < nn_int_2) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK((1 < nn_int_1b) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK((2 < nn_int_1b) == false); + } + } + + TEST_CASE("nonnegative_int <= comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK((nn_int_1a <= nn_int_2) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK((nn_int_1a <= nn_int_1b) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK((nn_int_2 <= nn_int_1b) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK((nn_int_1a <= 2) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK((nn_int_1a <= 1) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK((nn_int_2 <= 1) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK((1 <= nn_int_2) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK((1 <= nn_int_1b) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK((2 <= nn_int_1b) == false); + } + } + + TEST_CASE("nonnegative_int > comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK((nn_int_1a > nn_int_2) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK((nn_int_1a > nn_int_1b) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK((nn_int_2 > nn_int_1b) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK((nn_int_1a > 2) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK((nn_int_1a > 1) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK((nn_int_2 > 1) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK((1 > nn_int_2) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK((1 > nn_int_1b) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK((2 > nn_int_1b) == true); + } + } + + TEST_CASE("nonnegative_int >= comparisons") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { + CHECK((nn_int_1a >= nn_int_2) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { + CHECK((nn_int_1a >= nn_int_1b) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { + CHECK((nn_int_2 >= nn_int_1b) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, less than") { + CHECK((nn_int_1a >= 2) == false); + } + SUBCASE("LHS: nonnegative_int, RHS: int, equals") { + CHECK((nn_int_1a >= 1) == true); + } + SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { + CHECK((nn_int_2 >= 1) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, less than") { + CHECK((1 >= nn_int_2) == false); + } + SUBCASE("LHS: int, RHS: nonnegative_int, equals") { + CHECK((1 >= nn_int_1b) == true); + } + SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { + CHECK((2 >= nn_int_1b) == true); + } + } + + TEST_CASE("nonnegative_int adl_serializer") { + SUBCASE("to_json") { + nonnegative_int input = nonnegative_int{5}; + + nlohmann::json result = input; + nlohmann::json correct = 5; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + nlohmann::json input = 5; + + nonnegative_int result = input.template get(); + nonnegative_int correct = nonnegative_int{5}; + + CHECK(result == correct); + } + } + + TEST_CASE("std::hash") { + nonnegative_int nn_int_1a = nonnegative_int{1}; + nonnegative_int nn_int_1b = nonnegative_int{1}; + nonnegative_int nn_int_2 = nonnegative_int{2}; + std::hash hash_fn; + SUBCASE("Identical values have the same hash") { + CHECK(hash_fn(nn_int_1a) == hash_fn(nn_int_1b)); + } + SUBCASE("Different values have different hashes") { + CHECK(hash_fn(nn_int_1a) != hash_fn(nn_int_2)); + } + SUBCASE("Unordered set works with nonnegative_int") { + std::unordered_set nonnegative_int_set; + nonnegative_int_set.insert(nn_int_1a); + nonnegative_int_set.insert(nn_int_1b); + nonnegative_int_set.insert(nn_int_2); + + CHECK(nonnegative_int_set.size() == 2); + } + } + + TEST_CASE("nonnegative int >> operator") { + nonnegative_int nn_int_1 = nonnegative_int{1}; + std::ostringstream oss; + oss << nn_int_1; + + CHECK(oss.str() == "1"); + } +}