From 3cae680c35f9b2ebb4a0ee58756907c26cf0d82b Mon Sep 17 00:00:00 2001
From: Victor Li
Date: Thu, 14 Nov 2024 19:34:57 -0800
Subject: [PATCH] Changed ff_dim_t to use nonnegative_int, added relative_ff_dim_t that uses int

---
 lib/kernels/src/legion_dim.cc | 2 +-
 lib/kernels/test/src/test_concat_kernel.cc | 2 +-
 lib/kernels/test/src/test_transpose_kernel.cc | 3 +-
 .../src/legion_tensor_shape.cc | 5 +-
 lib/local-execution/src/ops/linear.cc | 12 +-
 lib/local-execution/src/ops/pool_2d.cc | 16 +-
 lib/local-execution/src/ops/reverse.cc | 14 +-
 lib/local-execution/src/ops/softmax.cc | 9 +-
 lib/local-execution/src/ops/split.cc | 13 +-
 .../op-attrs/dim_ordered/dim_ordered.h | 185 ++++++++++++++++--
 .../include/op-attrs/dim_ordered/enumerate.h | 2 +-
 .../dim_ordered/ff_ordered_from_map.h | 5 +-
 .../include/op-attrs/dim_ordered/get_idxs.h | 4 +-
 .../include/op-attrs/dim_ordered/slice.h | 83 +++++++-
 lib/op-attrs/include/op-attrs/ff_dim_t.h | 9 +-
 .../include/op-attrs/ff_dim_t.struct.toml | 6 +-
 .../parallel_tensor_dim_idx_t.variant.toml | 4 +-
 .../include/op-attrs/parallel_tensor_dims.h | 5 +-
 .../include/op-attrs/parallel_tensor_shape.h | 8 +-
 .../include/op-attrs/relative_ff_dim_t.h | 9 +-
 lib/op-attrs/include/op-attrs/tensor_dims.h | 4 +-
 lib/op-attrs/include/op-attrs/tensor_shape.h | 4 +-
 lib/op-attrs/src/op-attrs/ff_dim_t.cc | 13 +-
 lib/op-attrs/src/op-attrs/ops/attention.cc | 6 +-
 .../attention/multihead_attention_inputs.cc | 18 +-
 .../multihead_attention_parallel_inputs.cc | 21 +-
 lib/op-attrs/src/op-attrs/ops/batch_matmul.cc | 24 +--
 lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 12 +-
 lib/op-attrs/src/op-attrs/ops/broadcast.cc | 2 +-
 lib/op-attrs/src/op-attrs/ops/combine.cc | 8 +-
 lib/op-attrs/src/op-attrs/ops/concat.cc | 16 +-
 .../ops/conv_2d/conv_2d_input_shape.cc | 8 +-
 .../conv_2d/conv_2d_parallel_input_shape.cc | 8 +-
 lib/op-attrs/src/op-attrs/ops/embedding.cc | 6 +-
 lib/op-attrs/src/op-attrs/ops/flat.cc | 9 +-
 lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 9 +-
 lib/op-attrs/src/op-attrs/ops/linear.cc | 28 +--
 lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 16 +-
 lib/op-attrs/src/op-attrs/ops/repartition.cc | 5 +-
 lib/op-attrs/src/op-attrs/ops/softmax.cc | 3 +-
 .../src/op-attrs/parallel_tensor_dims.cc | 6 +-
 .../src/op-attrs/parallel_tensor_shape.cc | 13 +-
 .../src/op-attrs/relative_ff_dim_t.cc | 17 +-
 lib/op-attrs/src/op-attrs/tensor_dims.cc | 4 +-
 lib/op-attrs/src/op-attrs/tensor_shape.cc | 4 +-
 .../src/op-attrs/dim_ordered/enumerate.cc | 6 +-
 .../dim_ordered/ff_ordered_from_map.cc | 30 +--
 .../test/src/op-attrs/dim_ordered/slice.cc | 48 ++++-
 lib/op-attrs/test/src/op-attrs/ff_dim_t.cc | 20 ++
 .../src/op-attrs/get_incoming_tensor_roles.cc | 2 +-
 lib/op-attrs/test/src/op-attrs/ops/combine.cc | 4 +-
 lib/op-attrs/test/src/op-attrs/ops/concat.cc | 6 +-
 .../test/src/op-attrs/ops/element_binary.cc | 2 +-
 lib/op-attrs/test/src/op-attrs/ops/flat.cc | 32 +--
 .../test/src/op-attrs/ops/layer_norm.cc | 4 +-
 .../test/src/op-attrs/ops/repartition.cc | 2 +-
 lib/op-attrs/test/src/op-attrs/ops/softmax.cc | 14 +-
 .../test/src/op-attrs/pcg_operator_attrs.cc | 2 +-
 .../test/src/op-attrs/relative_ff_dim_t.cc | 36 ++++
 .../include/pcg/computation_graph_builder.h | 2 +-
 lib/pcg/src/pcg/computation_graph_builder.cc | 40 ++--
 .../parallel_computation_graph_builder.cc | 6 +-
 lib/runtime/src/ops/embedding.cc | 2 +-
 .../test/src/test_substitution.cc | 6 +-
 .../utils/nonnegative_int/nonnegative_int.h | 4 +-
 .../utils/nonnegative_int/nonnegative_int.cc | 6 +-
 .../utils/nonnegative_int/nonnegative_int.cc | 105 +++++-----
67 files changed, 719 insertions(+), 320 deletions(-) create mode 100644 lib/op-attrs/test/src/op-attrs/ff_dim_t.cc create mode 100644 lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc diff --git a/lib/kernels/src/legion_dim.cc b/lib/kernels/src/legion_dim.cc index 9ef47d40ae..142dcbcb2c 100644 --- a/lib/kernels/src/legion_dim.cc +++ b/lib/kernels/src/legion_dim.cc @@ -7,7 +7,7 @@ legion_dim_t add_to_legion_dim(legion_dim_t legion_dim, int value) { } legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, int num_dimensions) { - return legion_dim_t(num_dimensions - ff_dim.value - 1); + return legion_dim_t(num_dimensions - ff_dim.value.get_value() - 1); } } // namespace FlexFlow diff --git a/lib/kernels/test/src/test_concat_kernel.cc b/lib/kernels/test/src/test_concat_kernel.cc index bf2a521b4e..55ce7da331 100644 --- a/lib/kernels/test/src/test_concat_kernel.cc +++ b/lib/kernels/test/src/test_concat_kernel.cc @@ -7,7 +7,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test concat kernel forward and backward") { size_t num_inputs = 3; size_t size_per_input = 100; - ff_dim_t concat_axis = ff_dim_t(0); + ff_dim_t concat_axis = ff_dim_t{nonnegative_int{0}}; ManagedPerDeviceFFHandle managed_handle{}; ManagedFFStream managed_stream{}; diff --git a/lib/kernels/test/src/test_transpose_kernel.cc b/lib/kernels/test/src/test_transpose_kernel.cc index 2fc186a257..2904fa01ae 100644 --- a/lib/kernels/test/src/test_transpose_kernel.cc +++ b/lib/kernels/test/src/test_transpose_kernel.cc @@ -7,7 +7,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Test Transpose Kernel Operations") { std::size_t num_dims = 2; - std::vector perm = {ff_dim_t(0), ff_dim_t(1)}; + std::vector perm = {ff_dim_t{nonnegative_int{0}}, + ff_dim_t{nonnegative_int{1}}}; ManagedPerDeviceFFHandle managed_handle{}; ManagedFFStream managed_stream{}; diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc index bce29fafeb..b227accc2e 100644 --- a/lib/local-execution/src/legion_tensor_shape.cc +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -1,14 +1,15 @@ #include "local-execution/legion_tensor_shape.h" +#include "kernels/legion_dim.h" #include "op-attrs/tensor_shape.h" namespace FlexFlow { legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, size_t num_dims) { - return legion_dim_t(num_dims - ff_dim.value - 1); + return legion_dim_t(num_dims - ff_dim.value.get_value() - 1); } legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, TensorShape const &shape) { - return legion_dim_t(num_dims(shape) - ff_dim.value - 1); + return legion_dim_from_ff_dim(ff_dim, num_dims(shape)); } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc index f3ef022175..3e0b4672ab 100644 --- a/lib/local-execution/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -66,8 +66,8 @@ static DeviceSpecificDeviceStates auto input = acc.get_tensor(INPUT); auto weight = acc.get_tensor(WEIGHT); auto output = acc.get_tensor(OUTPUT); - int out_dim = output.shape.at(ff_dim_t{0}); - int batch_size = output.shape.at(ff_dim_t{1}); + int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}); + int batch_size = output.shape.at(ff_dim_t{nonnegative_int{1}}); float *one_ptr; @@ -96,8 +96,8 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_argument(PROFILING); auto attrs = acc.get_argument(ATTRS); - int in_dim = input.shape.at(ff_dim_t{0}) + 1; - int out_dim = output.shape.at(ff_dim_t{0}) + 1; + int 
in_dim = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; + int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; int batch_size = output.shape.get_volume() / out_dim; float const *bias_ptr = NULL; @@ -140,8 +140,8 @@ static std::optional bias_ptr = bias.get_float_ptr(); } - int in_dim = input.shape.at(ff_dim_t{0}) + 1; - int out_dim = output.shape.at(ff_dim_t{0}) + 1; + int in_dim = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; + int out_dim = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; int batch_size = output.shape.get_volume() / out_dim; return profile(backward_kernel, diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc index 33d62b713c..3ab33a2ad6 100644 --- a/lib/local-execution/src/ops/pool_2d.cc +++ b/lib/local-execution/src/ops/pool_2d.cc @@ -30,14 +30,14 @@ static DeviceSpecificDeviceStates auto input = acc.get_tensor(INPUT); auto output = acc.get_tensor(OUTPUT); - int input_w = input.shape.at(ff_dim_t(0)) + 1; - int input_h = input.shape.at(ff_dim_t(1)) + 1; - int input_c = input.shape.at(ff_dim_t(2)) + 1; - int input_n = input.shape.at(ff_dim_t(3)) + 1; - int output_w = output.shape.at(ff_dim_t(0)) + 1; - int output_h = output.shape.at(ff_dim_t(1)) + 1; - int output_c = output.shape.at(ff_dim_t(2)) + 1; - int output_n = output.shape.at(ff_dim_t(3)) + 1; + int input_w = input.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; + int input_h = input.shape.at(ff_dim_t{nonnegative_int{1}}) + 1; + int input_c = input.shape.at(ff_dim_t{nonnegative_int{2}}) + 1; + int input_n = input.shape.at(ff_dim_t{nonnegative_int{3}}) + 1; + int output_w = output.shape.at(ff_dim_t{nonnegative_int{0}}) + 1; + int output_h = output.shape.at(ff_dim_t{nonnegative_int{1}}) + 1; + int output_c = output.shape.at(ff_dim_t{nonnegative_int{2}}) + 1; + int output_n = output.shape.at(ff_dim_t{nonnegative_int{3}}) + 1; printf("init pool (input): n(%d) c(%d) h(%d) " "w(%d)\n", diff --git a/lib/local-execution/src/ops/reverse.cc b/lib/local-execution/src/ops/reverse.cc index 366a579bea..8ac4c045c7 100644 --- a/lib/local-execution/src/ops/reverse.cc +++ b/lib/local-execution/src/ops/reverse.cc @@ -53,11 +53,11 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1; for (int i = 0; i < output.shape.get_dim(); i++) { if (i < axis.value) { - in_blk_size *= output.shape.at(ff_dim_t(i)); + in_blk_size *= output.shape.at(ff_dim_t{nonnegative_int{i}}); } else if (i == axis.value) { - reverse_dim_size = output.shape.at(ff_dim_t(i)); + reverse_dim_size = output.shape.at(ff_dim_t{nonnegative_int{i}}); } else { - num_out_blks *= output.shape.at(ff_dim_t(i)); + num_out_blks *= output.shape.at(ff_dim_t{nonnegative_int{i}}); } } @@ -79,15 +79,15 @@ static std::optional auto output_grad = acc.get_tensor_grad(OUTPUT); auto attrs = acc.get_argument(ATTRS); - int axis = input_grad.shape.get_dim() - attrs.axis.value - 1; + int axis = input_grad.shape.get_dim() - attrs.axis.value.get_value() - 1; coord_t in_blk_size = 1, reverse_dim_size = 1, num_out_blks = 1; for (int i = 0; i < input_grad.shape.get_dim(); i++) { if (i < axis) { - in_blk_size *= input_grad.shape.at(ff_dim_t(i)); + in_blk_size *= input_grad.shape.at(ff_dim_t{nonnegative_int{i}}); } else if (i == axis) { - reverse_dim_size = input_grad.shape.at(ff_dim_t(i)); + reverse_dim_size = input_grad.shape.at(ff_dim_t{nonnegative_int{i}}); } else { - num_out_blks *= input_grad.shape.at(ff_dim_t(i)); + num_out_blks *= 
input_grad.shape.at(ff_dim_t{nonnegative_int{i}}); } } diff --git a/lib/local-execution/src/ops/softmax.cc b/lib/local-execution/src/ops/softmax.cc index 4c7979ae9b..8d412c739b 100644 --- a/lib/local-execution/src/ops/softmax.cc +++ b/lib/local-execution/src/ops/softmax.cc @@ -64,8 +64,13 @@ static DeviceSpecificDeviceStates int output_c = output.shape.at(legion_dim_t(2)); int output_n = output.shape.at(legion_dim_t(3)); - SoftmaxPerDeviceState per_device_state = init_kernel( - handle, attrs.dim.value, output_n, output_c, output_h, output_w); + SoftmaxPerDeviceState per_device_state = + init_kernel(handle, + attrs.dim.value.get_value(), + output_n, + output_c, + output_h, + output_w); return DeviceSpecificDeviceStates{ DeviceSpecific::create(per_device_state)}; diff --git a/lib/local-execution/src/ops/split.cc b/lib/local-execution/src/ops/split.cc index 9f039d84f8..c289bca205 100644 --- a/lib/local-execution/src/ops/split.cc +++ b/lib/local-execution/src/ops/split.cc @@ -47,11 +47,11 @@ OpTaskInvocation backward(SplitAttrs const &attrs) { void calc_block_size(coord_t &num_blocks, coord_t &block_size, ArrayShape const &array_shape, - int axis) { + ff_dim_t axis) { num_blocks = 1; block_size = 1; for (int d = 0; d < array_shape.num_elements(); d++) { - if (d <= axis) { + if (d <= axis.value.get_value()) { block_size *= array_shape.at(legion_dim_t(d)); } else { num_blocks *= array_shape.at(legion_dim_t(d)); @@ -66,12 +66,12 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto attrs = acc.get_argument(ATTRS); coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS]; - calc_block_size(num_blocks, in_block_size, input.shape, attrs.axis.value); + calc_block_size(num_blocks, in_block_size, input.shape, attrs.axis); for (int i = 0; i < attrs.splits.size(); i++) { coord_t out_num_blocks; calc_block_size( - out_num_blocks, out_block_size[i], output.shape, attrs.axis.value); + out_num_blocks, out_block_size[i], output.shape, attrs.axis); } float *output_float_ptr = output.get_float_ptr(); return profile(forward_kernel, @@ -94,12 +94,11 @@ static std::optional auto attrs = acc.get_argument(ATTRS); coord_t num_blocks, in_block_size, out_block_size[MAX_NUM_OUTPUTS]; - calc_block_size( - num_blocks, in_block_size, input_grad.shape, attrs.axis.value); + calc_block_size(num_blocks, in_block_size, input_grad.shape, attrs.axis); for (int i = 0; i < attrs.splits.size(); i++) { coord_t out_num_blocks; calc_block_size( - out_num_blocks, out_block_size[i], output_grad.shape, attrs.axis.value); + out_num_blocks, out_block_size[i], output_grad.shape, attrs.axis); } float const *output_grad_ptr = output_grad.get_float_ptr(); return profile(backward_kernel, diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 9a8d8adc0b..70083d6444 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -2,6 +2,8 @@ #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H #include "op-attrs/ff_dim_t.dtg.h" +#include "op-attrs/relative_ff_dim_t.dtg.h" +#include "utils/containers/range.h" #include "utils/fmt/vector.h" #include "utils/stack_vector.h" #include @@ -152,6 +154,169 @@ struct DimOrdered { stack_vector contents; }; +template +struct DimOrdered { + DimOrdered() {} + + DimOrdered(std::initializer_list const &l) + : contents(l.begin(), l.end()) {} + + DimOrdered(std::vector const &contents) + : contents(contents.begin(), 
contents.end()) {} + + template + DimOrdered(It begin, It end) : contents(begin, end) {} + + template + DimOrdered(stack_vector const &contents) + : contents(contents.begin(), contents.end()) {} + + T const &at(ff_dim_t idx) const { + int raw = idx.value.get_value(); + return this->contents.at(raw); + } + + T const &at(relative_ff_dim_t idx) const { + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return this->contents.at(raw); + } + + T &at(ff_dim_t idx) { + int raw = idx.value.get_value(); + return this->contents.at(raw); + } + + T &at(relative_ff_dim_t idx) { + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return this->contents.at(raw); + } + + T const &operator[](ff_dim_t idx) const { + return this->at(idx); + } + + T const &operator[](relative_ff_dim_t idx) const { + return this->at(idx); + } + + T &operator[](ff_dim_t idx) { + return this->at(idx); + } + + T &operator[](relative_ff_dim_t idx) { + return this->at(idx); + } + + bool idx_is_valid(ff_dim_t const &idx) const { + int raw = idx.value.get_value(); + return raw < this->contents.size(); + } + + bool idx_is_valid(relative_ff_dim_t const &idx) const { + int raw = idx.value; + if (raw < 0) { + raw = this->contents.size() + raw; + } + return (raw >= 0 && raw < this->contents.size()); + } + + bool operator==(DimOrdered const &other) const { + return this->contents == other.contents; + } + + bool operator!=(DimOrdered const &other) const { + return this->contents != other.contents; + } + + bool operator<(DimOrdered const &other) const { + return this->contents < other.contents; + } + + using iterator = typename stack_vector::iterator; + using const_iterator = + typename stack_vector::const_iterator; + using reverse_iterator = + typename stack_vector::reverse_iterator; + using const_reverse_iterator = + typename stack_vector::const_reverse_iterator; + using value_type = T; + using pointer = value_type *; + using const_pointer = value_type const *; + using reference = value_type &; + using const_reference = value_type const &; + + iterator begin() { + return this->contents.begin(); + } + + const_iterator begin() const { + return this->cbegin(); + } + + const_iterator cbegin() const { + return this->contents.cbegin(); + } + + iterator end() { + return this->contents.end(); + } + + const_iterator end() const { + return this->cend(); + } + + const_iterator cend() const { + return this->contents.cend(); + } + + reverse_iterator rbegin() { + return this->contents.rbegin(); + } + + const_reverse_iterator rbegin() const { + return this->crbegin(); + } + + const_reverse_iterator crbegin() const { + return this->contents.crbegin(); + } + + reverse_iterator rend() { + return this->contents.crend(); + } + + const_reverse_iterator rend() const { + return this->crend(); + } + + const_reverse_iterator crend() const { + return this->contents.crend(); + } + + size_t size() const { + return this->contents.size(); + } + + size_t empty() const { + return this->contents.empty(); + } + + size_t num_dims() const { + return this->size(); + } + + friend struct ::std::hash; + +private: + stack_vector contents; +}; + template using FFOrdered = DimOrdered; @@ -166,31 +331,15 @@ std::ostream &operator<<(std::ostream &s, FFOrdered const &v) { return (s << fmt::to_string(v)); } -template -auto inner_to_outer(FFOrdered const &ff_ordered) - -> decltype(reversed_container(ff_ordered)) { - return reversed_container(ff_ordered); -} - template std::vector inner_to_outer_idxs(FFOrdered const 
&ff_ordered) { std::vector idxs; - for (size_t i = 0; i < ff_ordered.size(); i++) { - idxs.push_back(ff_dim_t(ff_ordered.size() - i - 1)); + for (int i : range(0, ff_ordered.size())) { + idxs.push_back(ff_dim_t{nonnegative_int{ff_ordered.size() - i - 1}}); } return idxs; } -template -std::vector outer_to_inner_idxs(FFOrdered const &ff_ordered) { - return reversed(inner_to_outer_idxs(ff_ordered)); -} - -template -FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { - return ff_ordered; -} - } // namespace FlexFlow /* template */ diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h index 38e7da4bb2..9e4271a1ff 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h @@ -19,7 +19,7 @@ template std::map enumerate(FFOrdered const &ff_ordered) { std::map result; for (int raw_ff_dim : count(ff_ordered.size())) { - ff_dim_t ff_dim = ff_dim_t{raw_ff_dim}; + ff_dim_t ff_dim = ff_dim_t{nonnegative_int{raw_ff_dim}}; result.insert({ff_dim, ff_ordered.at(ff_dim)}); } return result; diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h index 79d4929797..f8f49233ec 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h @@ -3,6 +3,7 @@ #include "op-attrs/dim_ordered/dim_ordered.h" #include "op-attrs/dim_ordered/ff_ordered_of.h" +#include "op-attrs/ff_dim_t.h" namespace FlexFlow { @@ -10,7 +11,7 @@ template FFOrdered ff_ordered_from_map(std::map const &m) { std::vector raw; for (int i = 0; i < m.size(); i++) { - raw.push_back(m.at(ff_dim_t{i})); + raw.push_back(m.at(ff_dim_t{nonnegative_int{i}})); } return ff_ordered_of(raw); } @@ -19,7 +20,7 @@ template FFOrdered ff_ordered_from_map(std::unordered_map const &m) { std::vector raw; for (int i = 0; i < m.size(); i++) { - raw.push_back(m.at(ff_dim_t{i})); + raw.push_back(m.at(ff_dim_t{nonnegative_int{i}})); } return ff_ordered_of(raw); } diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h index 7343dc0e69..4e7f8530a4 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H #include "op-attrs/dim_ordered/dim_ordered.h" +#include "op-attrs/ff_dim_t.h" #include "utils/containers/count.h" #include "utils/containers/transform.h" @@ -9,7 +10,8 @@ namespace FlexFlow { template std::vector get_idxs(FFOrdered const &d) { - return transform(count(d.size()), [](int i) { return ff_dim_t{i}; }); + return transform(count(d.size()), + [](int i) { return ff_dim_t{nonnegative_int{i}}; }); } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index e4c0e8e275..d61b51680d 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -22,17 +22,29 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, } template -FFOrdered slice(FFOrdered const &d, - std::optional const &start, - std::optional const &end) { - return nonoverloaded_slice(d, start, end); +FFOrdered ff_dim_t_nonoverloaded_slice(FFOrdered const &d, + std::optional const &start, + std::optional const &end) { + auto to_raw_idx = + 
[](std::optional const &idx) -> std::optional { + return transform(idx, + [](ff_dim_t const &i) { return i.value.get_value(); }); + }; + + return FFOrdered{subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } -template -DimOrdered slice(DimOrdered const &d, - std::optional const &start, - std::optional const &end) { - return nonoverloaded_slice(d, start, end); +template +FFOrdered relative_ff_dim_t_nonoverloaded_slice( + FFOrdered const &d, + std::optional const &start, + std::optional const &end) { + auto to_raw_idx = + [](std::optional const &idx) -> std::optional { + return transform(idx, [](relative_ff_dim_t const &i) { return i.value; }); + }; + + return FFOrdered{subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } template @@ -58,6 +70,59 @@ DimOrdered d, std::optional{start}, std::optional{end}); } +template +FFOrdered slice(FFOrdered const &d, + std::nullopt_t const &start, + ff_dim_t const &end) { + return ff_dim_t_nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); +} + +template +FFOrdered slice(FFOrdered const &d, + ff_dim_t const &start, + std::nullopt_t const &end) { + return ff_dim_t_nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); +} + +template +FFOrdered + slice(FFOrdered const &d, ff_dim_t const &start, ff_dim_t const &end) { + return ff_dim_t_nonoverloaded_slice( + d, std::optional{start}, std::optional{end}); +} + +template +FFOrdered slice(FFOrdered const &d, + std::nullopt_t const &start, + relative_ff_dim_t const &end) { + return relative_ff_dim_t_nonoverloaded_slice( + d, + std::optional{start}, + std::optional{end}); +} + +template +FFOrdered slice(FFOrdered const &d, + relative_ff_dim_t const &start, + std::nullopt_t const &end) { + return relative_ff_dim_t_nonoverloaded_slice( + d, + std::optional{start}, + std::optional{end}); +} + +template +FFOrdered slice(FFOrdered const &d, + relative_ff_dim_t const &start, + relative_ff_dim_t const &end) { + return relative_ff_dim_t_nonoverloaded_slice( + d, + std::optional{start}, + std::optional{end}); +} + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.h b/lib/op-attrs/include/op-attrs/ff_dim_t.h index a9a7233fbf..6ca38276eb 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim_t.h +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.h @@ -2,12 +2,17 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_FF_DIM_T_H #include "op-attrs/ff_dim_t.dtg.h" +#include "op-attrs/relative_ff_dim_t.dtg.h" #include "rapidcheck.h" +namespace FlexFlow { +relative_ff_dim_t ff_dim_t_to_relative_ff_dim_t(ff_dim_t ff_dim); +} // namespace FlexFlow + namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::ff_dim_t> { + static Gen<::FlexFlow::ff_dim_t> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml b/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml index 441f9826ca..38f51da4a1 100644 --- a/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml +++ b/lib/op-attrs/include/op-attrs/ff_dim_t.struct.toml @@ -9,6 +9,10 @@ features = [ "fmt", ] +includes = [ + "utils/nonnegative_int/nonnegative_int.h" +] + [[fields]] name = "value" -type = "int" +type = "::FlexFlow::nonnegative_int" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml index 7e7356a5e7..fbef10b7ce 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml +++ 
b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml @@ -9,12 +9,12 @@ features = [ ] includes = [ - "op-attrs/ff_dim_t.dtg.h", + "op-attrs/relative_ff_dim_t.dtg.h", "op-attrs/replica_type.dtg.h", ] [[values]] -type = "::FlexFlow::ff_dim_t" +type = "::FlexFlow::relative_ff_dim_t" [[values]] type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 1b8361abf6..6b88a7bda1 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -31,8 +31,9 @@ int total_replica_degree(ParallelTensorDims const &); int total_shard_degree(ParallelTensorDims const &); int total_parallel_degree(ParallelTensorDims const &); -ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &, ff_dim_t); -ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &, ff_dim_t); +ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &, + relative_ff_dim_t); +ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &, relative_ff_dim_t); bool is_valid(ParallelTensorDims const &); TensorDims get_piece_dims(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 0759dc746e..0339b9b8a6 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,6 +1,7 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H +#include "op-attrs/ff_dim_t.h" #include "op-attrs/parallel_dim.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" @@ -12,13 +13,14 @@ namespace FlexFlow { int num_shard_dims(ParallelTensorShape const &); -ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); -ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, ff_dim_t); +ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &, + relative_ff_dim_t); +ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &, relative_ff_dim_t); FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); std::optional - try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); + try_get_shard_dim_at_idx(ParallelTensorShape const &, relative_ff_dim_t); ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h index 2f95eac1d4..8cd0337cc9 100644 --- a/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h +++ b/lib/op-attrs/include/op-attrs/relative_ff_dim_t.h @@ -1,13 +1,18 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_RELATIVE_FF_DIM_T_H +#include "op-attrs/ff_dim_t.dtg.h" #include "op-attrs/relative_ff_dim_t.dtg.h" #include "rapidcheck.h" +namespace FlexFlow { +ff_dim_t relative_ff_dim_t_to_ff_dim_t(relative_ff_dim_t ff_dim, int input_dim); +} // namespace FlexFlow + namespace rc { template <> -struct Arbitrary { - static Gen arbitrary(); +struct Arbitrary<::FlexFlow::relative_ff_dim_t> { + static Gen<::FlexFlow::relative_ff_dim_t> arbitrary(); }; } // namespace rc diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index ee44a39170..5e1503360b 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -9,8 +9,8 
@@ namespace FlexFlow { FFOrdered const &ff_ordered(TensorDims const &); size_t num_dims(TensorDims const &); -size_t dim_at_idx(TensorDims const &, ff_dim_t); -size_t &dim_at_idx(TensorDims &, ff_dim_t); +size_t dim_at_idx(TensorDims const &, relative_ff_dim_t); +size_t &dim_at_idx(TensorDims &, relative_ff_dim_t); bool tensor_dims_is_broadcastable_to(TensorDims const &curr, TensorDims const &goal); diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 14ee637f92..b8733cddbe 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -6,8 +6,8 @@ namespace FlexFlow { size_t num_dims(TensorShape const &); -size_t dim_at_idx(TensorShape const &, ff_dim_t); -size_t &dim_at_idx(TensorShape &, ff_dim_t); +size_t dim_at_idx(TensorShape const &, relative_ff_dim_t); +size_t &dim_at_idx(TensorShape &, relative_ff_dim_t); size_t get_num_elements(TensorShape const &); size_t get_size_in_bytes(TensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/ff_dim_t.cc b/lib/op-attrs/src/op-attrs/ff_dim_t.cc index 043d2b2309..d8e56fcade 100644 --- a/lib/op-attrs/src/op-attrs/ff_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/ff_dim_t.cc @@ -1,8 +1,15 @@ #include "op-attrs/ff_dim_t.h" +namespace FlexFlow { +relative_ff_dim_t ff_dim_t_to_relative_ff_dim_t(ff_dim_t ff_dim) { + return relative_ff_dim_t{ff_dim.value.get_value()}; +} +} // namespace FlexFlow + namespace rc { -Gen Arbitrary::arbitrary() { - return gen::construct( - gen::inRange(0, MAX_TENSOR_DIM)); +Gen<::FlexFlow::ff_dim_t> Arbitrary<::FlexFlow::ff_dim_t>::arbitrary() { + return gen::construct<::FlexFlow::ff_dim_t>( + gen::map(gen::inRange(0, MAX_TENSOR_DIM), + [](int value) { return FlexFlow::nonnegative_int(value); })); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 483d832fee..57c7105534 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -33,15 +33,15 @@ int get_oProjSize(MultiHeadAttentionAttrs const &attrs) { } int get_qSize(TensorShape const &query_shape) { - return dim_at_idx(query_shape, ff_dim_t(0)); + return dim_at_idx(query_shape, relative_ff_dim_t{0}); } int get_kSize(TensorShape const &key_shape) { - return dim_at_idx(key_shape, ff_dim_t(0)); + return dim_at_idx(key_shape, relative_ff_dim_t{0}); } int get_vSize(TensorShape const &value_shape) { - return dim_at_idx(value_shape, ff_dim_t(0)); + return dim_at_idx(value_shape, relative_ff_dim_t{0}); } int get_qSize(MultiHeadAttentionParallelInputs const &inputs) { diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc index 65feb642e1..97544d1750 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_inputs.cc @@ -31,9 +31,9 @@ tl::expected 3)); } - size_t seq_len_q = dim_at_idx(input_q, ff_dim_t{-2}); - size_t seq_len_k = dim_at_idx(input_k, ff_dim_t{-2}); - size_t seq_len_v = dim_at_idx(input_v, ff_dim_t{-2}); + size_t seq_len_q = dim_at_idx(input_q, relative_ff_dim_t{-2}); + size_t seq_len_k = dim_at_idx(input_k, relative_ff_dim_t{-2}); + size_t seq_len_v = dim_at_idx(input_v, relative_ff_dim_t{-2}); if (!all_same(seq_len_q, seq_len_k, seq_len_v)) { return tl::unexpected(fmt::format( @@ -43,9 +43,9 @@ tl::expected seq_len_v)); } - size_t batch_size_q = 
dim_at_idx(input_q, ff_dim_t{-3}); - size_t batch_size_k = dim_at_idx(input_k, ff_dim_t{-3}); - size_t batch_size_v = dim_at_idx(input_v, ff_dim_t{-3}); + size_t batch_size_q = dim_at_idx(input_q, relative_ff_dim_t{-3}); + size_t batch_size_k = dim_at_idx(input_k, relative_ff_dim_t{-3}); + size_t batch_size_v = dim_at_idx(input_v, relative_ff_dim_t{-3}); if (!all_same(batch_size_q, batch_size_k, batch_size_v)) { return tl::unexpected(fmt::format( @@ -63,9 +63,9 @@ tl::expected input_v.data_type)); } - size_t q_size = dim_at_idx(input_q, ff_dim_t{-1}); - size_t k_size = dim_at_idx(input_k, ff_dim_t{-1}); - size_t v_size = dim_at_idx(input_v, ff_dim_t{-1}); + size_t q_size = dim_at_idx(input_q, relative_ff_dim_t{-1}); + size_t k_size = dim_at_idx(input_k, relative_ff_dim_t{-1}); + size_t v_size = dim_at_idx(input_v, relative_ff_dim_t{-1}); return MultiHeadAttentionInputs{ batch_size_q, diff --git a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc index b5ddeaac30..3bd0825555 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention/multihead_attention_parallel_inputs.cc @@ -42,7 +42,7 @@ tl::expected 3)); } - ShardParallelDim seq_len_q = shard_dim_at_idx(input_q, ff_dim_t{-2}); + ShardParallelDim seq_len_q = shard_dim_at_idx(input_q, relative_ff_dim_t{-2}); if (seq_len_q.degree != 1) { return tl::unexpected( fmt::format("Query sequence length parallel degree expected to be 1, " @@ -50,7 +50,7 @@ tl::expected seq_len_q.degree)); } - ShardParallelDim seq_len_k = shard_dim_at_idx(input_k, ff_dim_t{-2}); + ShardParallelDim seq_len_k = shard_dim_at_idx(input_k, relative_ff_dim_t{-2}); if (seq_len_k.degree != 1) { return tl::unexpected( fmt::format("Key sequence length parallel degree expected to be 1, but " @@ -58,7 +58,7 @@ tl::expected seq_len_k.degree)); } - ShardParallelDim seq_len_v = shard_dim_at_idx(input_v, ff_dim_t{-2}); + ShardParallelDim seq_len_v = shard_dim_at_idx(input_v, relative_ff_dim_t{-2}); if (seq_len_v.degree != 1) { return tl::unexpected( fmt::format("Value sequence length parallel degree expected to be 1, " @@ -66,9 +66,12 @@ tl::expected seq_len_v.degree)); } - ShardParallelDim batch_size_q = shard_dim_at_idx(input_q, ff_dim_t{-3}); - ShardParallelDim batch_size_k = shard_dim_at_idx(input_k, ff_dim_t{-3}); - ShardParallelDim batch_size_v = shard_dim_at_idx(input_v, ff_dim_t{-3}); + ShardParallelDim batch_size_q = + shard_dim_at_idx(input_q, relative_ff_dim_t{-3}); + ShardParallelDim batch_size_k = + shard_dim_at_idx(input_k, relative_ff_dim_t{-3}); + ShardParallelDim batch_size_v = + shard_dim_at_idx(input_v, relative_ff_dim_t{-3}); if (!all_same( batch_size_q.degree, batch_size_k.degree, batch_size_v.degree)) { @@ -80,7 +83,7 @@ tl::expected batch_size_v.degree)); } - ShardParallelDim query_dim = shard_dim_at_idx(input_q, ff_dim_t{-1}); + ShardParallelDim query_dim = shard_dim_at_idx(input_q, relative_ff_dim_t{-1}); if (query_dim.degree > 1) { return tl::unexpected( fmt::format("Expected query tensor to have query dim parallel degree " @@ -88,7 +91,7 @@ tl::expected query_dim.degree)); } - ShardParallelDim key_dim = shard_dim_at_idx(input_k, ff_dim_t{-1}); + ShardParallelDim key_dim = shard_dim_at_idx(input_k, relative_ff_dim_t{-1}); if (key_dim.degree > 1) { return tl::unexpected( fmt::format("Expected key tensor to have key dim parallel degree 1, " @@ -96,7 +99,7 @@ tl::expected 
key_dim.degree)); } - ShardParallelDim value_dim = shard_dim_at_idx(input_v, ff_dim_t{-1}); + ShardParallelDim value_dim = shard_dim_at_idx(input_v, relative_ff_dim_t{-1}); if (value_dim.degree > 1) { return tl::unexpected( fmt::format("Expected value tensor to have value dim parallel degree " diff --git a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc index f9836bd3ed..71118db7a6 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_matmul.cc @@ -57,13 +57,13 @@ tl::expected input_rhs.data_type)); } - size_t lhs_b = dim_at_idx(input_lhs, ff_dim_t{0}); - size_t n = dim_at_idx(input_lhs, ff_dim_t{1}); - size_t lhs_m = dim_at_idx(input_lhs, ff_dim_t{2}); + size_t lhs_b = dim_at_idx(input_lhs, relative_ff_dim_t{0}); + size_t n = dim_at_idx(input_lhs, relative_ff_dim_t{1}); + size_t lhs_m = dim_at_idx(input_lhs, relative_ff_dim_t{2}); - size_t rhs_b = dim_at_idx(input_rhs, ff_dim_t{0}); - size_t rhs_m = dim_at_idx(input_rhs, ff_dim_t{1}); - size_t p = dim_at_idx(input_rhs, ff_dim_t{2}); + size_t rhs_b = dim_at_idx(input_rhs, relative_ff_dim_t{0}); + size_t rhs_m = dim_at_idx(input_rhs, relative_ff_dim_t{1}); + size_t p = dim_at_idx(input_rhs, relative_ff_dim_t{2}); if (lhs_b != rhs_b) { return tl::unexpected( @@ -111,13 +111,13 @@ tl::expected assert(get_total_parallel_degree(input_lhs) == get_total_parallel_degree(input_rhs)); - ShardParallelDim lhs_b = shard_dim_at_idx(input_lhs, ff_dim_t{0}); - ShardParallelDim n = shard_dim_at_idx(input_lhs, ff_dim_t{1}); - ShardParallelDim lhs_m = shard_dim_at_idx(input_lhs, ff_dim_t{2}); + ShardParallelDim lhs_b = shard_dim_at_idx(input_lhs, relative_ff_dim_t{0}); + ShardParallelDim n = shard_dim_at_idx(input_lhs, relative_ff_dim_t{1}); + ShardParallelDim lhs_m = shard_dim_at_idx(input_lhs, relative_ff_dim_t{2}); - ShardParallelDim rhs_b = shard_dim_at_idx(input_rhs, ff_dim_t{0}); - ShardParallelDim rhs_m = shard_dim_at_idx(input_rhs, ff_dim_t{1}); - ShardParallelDim p = shard_dim_at_idx(input_rhs, ff_dim_t{2}); + ShardParallelDim rhs_b = shard_dim_at_idx(input_rhs, relative_ff_dim_t{0}); + ShardParallelDim rhs_m = shard_dim_at_idx(input_rhs, relative_ff_dim_t{1}); + ShardParallelDim p = shard_dim_at_idx(input_rhs, relative_ff_dim_t{2}); if (lhs_b != rhs_b) { return tl::unexpected( diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index f394bb8473..472e5f1a25 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -67,7 +67,7 @@ tl::expected return tl::unexpected("No gamma weights exist for attrs.affine = false"); } - size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + size_t num_channels = dim_at_idx(input_shape, relative_ff_dim_t{1}); return TensorShape{ TensorDims{FFOrdered{ @@ -109,8 +109,12 @@ static std::optional } FFOrdered non_channel_degrees = - concat(slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), - slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); + concat(slice(input_degrees.shard_degrees, + ff_dim_t{nonnegative_int{0}}, + ff_dim_t{nonnegative_int{1}}), + slice(input_degrees.shard_degrees, + ff_dim_t{nonnegative_int{2}}, + std::nullopt)); if (any_of(non_channel_degrees, [](int degree) { return degree != 1; })) { return fmt::format("Expected parallel degree of all non-channel dimensions " @@ -152,7 +156,7 @@ tl::expected return tl::unexpected("No gamma weights exist for attrs.affine = false"); } - ff_dim_t 
channel_dim = ff_dim_t{1}; + relative_ff_dim_t channel_dim = relative_ff_dim_t{1}; return ParallelTensorDimDegrees{ SumDegree{1}, diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index aa3c95f551..31e241e27b 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -15,7 +15,7 @@ RecordFormatter as_dot(BroadcastAttrs const &attrs) { for (int i = 0; i < num_dims(attrs.target_dims); i++) { r << kv(fmt::format("target_dims[{}]", i), - dim_at_idx(attrs.target_dims, ff_dim_t{i})); + dim_at_idx(attrs.target_dims, relative_ff_dim_t{i})); } return r; diff --git a/lib/op-attrs/src/op-attrs/ops/combine.cc b/lib/op-attrs/src/op-attrs/ops/combine.cc index e41b78c5af..f27fe4d34c 100644 --- a/lib/op-attrs/src/op-attrs/ops/combine.cc +++ b/lib/op-attrs/src/op-attrs/ops/combine.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/combine.h" +#include "op-attrs/ff_dim_t.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { @@ -7,8 +8,8 @@ tl::expected get_output_shape(CombineAttrs const &attrs, ParallelTensorShape const &input) { ShardParallelDim input_dim = ({ - std::optional result = - try_get_shard_dim_at_idx(input, attrs.combine_dim); + std::optional result = try_get_shard_dim_at_idx( + input, ff_dim_t_to_relative_ff_dim_t(attrs.combine_dim)); if (!result.has_value()) { return tl::unexpected(fmt::format( "Failed to get shard dim at index {} in parallel tensor shape {}", @@ -29,7 +30,8 @@ tl::expected } ParallelTensorShape output = input; - shard_dim_at_idx(output, attrs.combine_dim).degree /= attrs.combine_degree; + shard_dim_at_idx(output, ff_dim_t_to_relative_ff_dim_t(attrs.combine_dim)) + .degree /= attrs.combine_degree; return output; } diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 74295f279e..e2738ab82b 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -46,11 +46,18 @@ tl::expected if (!returned.has_value()) { return tl::unexpected(returned.error()); } - returned.value(); + std::map returned_wrapped; + for (const auto &[key, value] : returned.value()) { + ff_dim_t new_key = ff_dim_t{nonnegative_int{key.value}}; + returned_wrapped[new_key] = value; + } + returned_wrapped; }); - std::vector axis_dim_sizes = transform( - inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); + std::vector axis_dim_sizes = + transform(inputs, [&](TensorShape const &s) { + return dim_at_idx(s, ff_dim_t_to_relative_ff_dim_t(attrs.axis)); + }); size_t output_axis_dim_size = sum(axis_dim_sizes); @@ -104,7 +111,8 @@ tl::expected }); if (!all_of(inputs, [&](ParallelTensorShape const &s) { - return shard_dim_at_idx(s, attrs.axis).degree == 1; + return shard_dim_at_idx(s, ff_dim_t_to_relative_ff_dim_t(attrs.axis)) + .degree == 1; })) { return tl::unexpected(fmt::format( "get_output_shape for Concat expected input tensors to have parallel " diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc index a8a3b10bdf..aad067feb2 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_input_shape.cc @@ -6,10 +6,10 @@ namespace FlexFlow { Conv2DInputShape parse_input_shape(TensorShape const &input) { assert(num_dims(input) == 4); - size_t num_samples = dim_at_idx(input, ff_dim_t{0}); - size_t in_channels = dim_at_idx(input, ff_dim_t{1}); - size_t in_height = 
dim_at_idx(input, ff_dim_t{2}); - size_t in_width = dim_at_idx(input, ff_dim_t{3}); + size_t num_samples = dim_at_idx(input, relative_ff_dim_t{0}); + size_t in_channels = dim_at_idx(input, relative_ff_dim_t{1}); + size_t in_height = dim_at_idx(input, relative_ff_dim_t{2}); + size_t in_width = dim_at_idx(input, relative_ff_dim_t{3}); return Conv2DInputShape{ num_samples, diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc index 98f69d14c9..8143353b2d 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d/conv_2d_parallel_input_shape.cc @@ -7,10 +7,10 @@ Conv2DParallelInputShape parse_parallel_input_shape(ParallelTensorShape const &input) { assert(num_shard_dims(input) == 4); - ShardParallelDim sample_dim = shard_dim_at_idx(input, ff_dim_t{0}); - ShardParallelDim channel_dim = shard_dim_at_idx(input, ff_dim_t{1}); - ShardParallelDim height_dim = shard_dim_at_idx(input, ff_dim_t{2}); - ShardParallelDim width_dim = shard_dim_at_idx(input, ff_dim_t{3}); + ShardParallelDim sample_dim = shard_dim_at_idx(input, relative_ff_dim_t{0}); + ShardParallelDim channel_dim = shard_dim_at_idx(input, relative_ff_dim_t{1}); + ShardParallelDim height_dim = shard_dim_at_idx(input, relative_ff_dim_t{2}); + ShardParallelDim width_dim = shard_dim_at_idx(input, relative_ff_dim_t{3}); Conv2DParallelInputShape parsed = Conv2DParallelInputShape{ sample_dim, diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index d10d52c6f5..fe557695da 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -34,7 +34,7 @@ tl::expected } TensorShape output = input; - dim_at_idx(output, ff_dim_t{-1}) = attrs.out_channels; + dim_at_idx(output, relative_ff_dim_t{-1}) = attrs.out_channels; output.data_type = attrs.data_type; return output; } @@ -73,12 +73,12 @@ tl::expected }); SumDegree sum_degree = - SumDegree{shard_dim_at_idx(input, ff_dim_t{-1}).degree}; + SumDegree{shard_dim_at_idx(input, relative_ff_dim_t{-1}).degree}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1}; FFOrdered shard_degrees = transform(input.dims.shard_dims, [](ShardParallelDim const &d) { return d.degree; }); - shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); + shard_degrees.at(relative_ff_dim_t{-1}) = get_discard_copy_degree(input); return lift_to_parallel_with_degrees( unpar, sum_degree, discard_copy_degree, shard_degrees); diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index e9833d5e3f..bc86102566 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -11,8 +11,9 @@ namespace FlexFlow { TensorShape get_output_shape(FlatAttrs const &attrs, TensorShape const &input_shape) { - FFOrdered leading_dims = - slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); + FFOrdered leading_dims = slice(ff_ordered(input_shape.dims), + ff_dim_t{nonnegative_int{0}}, + attrs.start_dim); FFOrdered flattened_dims = slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); FFOrdered trailing_dims = @@ -57,7 +58,9 @@ tl::expected /*discard_copy_degree=*/input_degrees.discard_copy_degree, /*shard_degrees=*/ concat(std::vector{ - slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), + slice(input_degrees.shard_degrees, + ff_dim_t{nonnegative_int{0}}, + 
attrs.start_dim), {product(flattened_dim_degrees)}, slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), }), diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index 0dd9ac7a17..a530fe58fc 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -73,7 +73,7 @@ tl::expected [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { - return dim_at_idx(input_shape, dim_idx); + return dim_at_idx(input_shape, ff_dim_t_to_relative_ff_dim_t(dim_idx)); }); return TensorShape{ @@ -117,7 +117,9 @@ static std::optional } if (!all_of(attrs.axes, [&](ff_dim_t axis) { - return shard_dim_at_idx(input_shape, axis).degree == 1; + return shard_dim_at_idx(input_shape, + ff_dim_t_to_relative_ff_dim_t(axis)) + .degree == 1; })) { return fmt::format("Expected parallel degree of all dimensions in " "LayerNorm axes {} to be 1, but received input shape {}", @@ -163,7 +165,8 @@ tl::expected [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); std::vector raw_weight_shard_dims = transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { - return shard_dim_at_idx(input_shape, dim_idx); + return shard_dim_at_idx(input_shape, + ff_dim_t_to_relative_ff_dim_t(dim_idx)); }); return ParallelTensorShape{ diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index feac647216..e00a47d490 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -41,7 +41,7 @@ RecordFormatter as_dot(LinearAttrs const &attrs) { tl::expected get_projection_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { - size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); + size_t in_channels = dim_at_idx(input_shape, relative_ff_dim_t{-1}); return TensorShape{ TensorDims{ @@ -64,7 +64,7 @@ tl::expected tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { TensorShape output_shape = input_shape; - output_shape.dims.ff_ordered.at(ff_dim_t{-1}) = + output_shape.dims.ff_ordered.at(relative_ff_dim_t{-1}) = size_t_from_int(attrs.out_channels); return output_shape; @@ -84,11 +84,11 @@ tl::expected SumDegree sum_degree = SumDegree{1}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{ - get_sum_degree(input) * - product( - slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))}; + get_sum_degree(input) * product(slice(ff_ordered_shard_degrees(input), + std::nullopt, + relative_ff_dim_t{-1}))}; FFOrdered shard_degrees = FFOrdered{ - shard_dim_at_idx(input, ff_dim_t{-1}).degree, + shard_dim_at_idx(input, relative_ff_dim_t{-1}).degree, get_discard_copy_degree(input), }; @@ -107,10 +107,11 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = SumDegree{ - get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree}; - DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product( - slice(ff_ordered_shard_degrees(input), std::nullopt, ff_dim_t{-1}))}; + SumDegree sum_degree = + SumDegree{get_sum_degree(input) * + shard_dim_at_idx(input, relative_ff_dim_t{-1}).degree}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{product(slice( + ff_ordered_shard_degrees(input), std::nullopt, relative_ff_dim_t{-1}))}; FFOrdered shard_degrees = FFOrdered{get_discard_copy_degree(input)}; return lift_to_parallel_with_degrees( @@ -129,11 
+130,12 @@ tl::expected result_unpar.value(); }); - SumDegree sum_degree = SumDegree{ - get_sum_degree(input) * shard_dim_at_idx(input, ff_dim_t{-1}).degree}; + SumDegree sum_degree = + SumDegree{get_sum_degree(input) * + shard_dim_at_idx(input, relative_ff_dim_t{-1}).degree}; DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{1}; FFOrdered shard_degrees = ff_ordered_shard_degrees(input); - shard_degrees.at(ff_dim_t{-1}) = get_discard_copy_degree(input); + shard_degrees.at(relative_ff_dim_t{-1}) = get_discard_copy_degree(input); return lift_to_parallel_with_degrees( unpar, sum_degree, discard_copy_degree, shard_degrees); diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index 95bcd8b336..86d287ebc8 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -22,10 +22,10 @@ tl::expected input_dims)); } - size_t num_samples = dim_at_idx(input_dims, ff_dim_t{0}); - size_t num_channels = dim_at_idx(input_dims, ff_dim_t{1}); - size_t input_h = dim_at_idx(input_dims, ff_dim_t{2}); - size_t input_w = dim_at_idx(input_dims, ff_dim_t{3}); + size_t num_samples = dim_at_idx(input_dims, relative_ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_dims, relative_ff_dim_t{1}); + size_t input_h = dim_at_idx(input_dims, relative_ff_dim_t{2}); + size_t input_w = dim_at_idx(input_dims, relative_ff_dim_t{3}); if (input_h % output_h != 0) { return tl::unexpected(fmt::format( @@ -113,10 +113,10 @@ tl::expected input_shape)); } - size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); - size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); - size_t input_height = dim_at_idx(input_shape, ff_dim_t{2}); - size_t input_width = dim_at_idx(input_shape, ff_dim_t{3}); + size_t num_samples = dim_at_idx(input_shape, relative_ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_shape, relative_ff_dim_t{1}); + size_t input_height = dim_at_idx(input_shape, relative_ff_dim_t{2}); + size_t input_width = dim_at_idx(input_shape, relative_ff_dim_t{3}); size_t output_height = (input_height + 2 * attrs.padding_h - attrs.kernel_h) / attrs.stride_h + diff --git a/lib/op-attrs/src/op-attrs/ops/repartition.cc b/lib/op-attrs/src/op-attrs/ops/repartition.cc index 37a0b8a168..eb68bdec4e 100644 --- a/lib/op-attrs/src/op-attrs/ops/repartition.cc +++ b/lib/op-attrs/src/op-attrs/ops/repartition.cc @@ -6,8 +6,9 @@ tl::expected get_output_shape(RepartitionAttrs const &attrs, ParallelTensorShape const &input_shape) { ParallelTensorShape output_shape = input_shape; - output_shape.dims.shard_dims.at(attrs.repartition_dim).degree *= - attrs.repartition_degree; + output_shape.dims.shard_dims + .at(ff_dim_t_to_relative_ff_dim_t(attrs.repartition_dim)) + .degree *= attrs.repartition_degree; return output_shape; } diff --git a/lib/op-attrs/src/op-attrs/ops/softmax.cc b/lib/op-attrs/src/op-attrs/ops/softmax.cc index 541c590cbd..38562d24c1 100644 --- a/lib/op-attrs/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/src/op-attrs/ops/softmax.cc @@ -39,7 +39,8 @@ tl::expected get_discard_copy_degree(input_shape))); } - if (shard_dim_at_idx(input_shape, attrs.dim).degree != 1) { + if (shard_dim_at_idx(input_shape, ff_dim_t_to_relative_ff_dim_t(attrs.dim)) + .degree != 1) { return tl::unexpected( fmt::format("Expected parallel degree of Softmax dimension {} to be 1, " "but received input shape {}", diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 2955545561..0bb940924a 100644 --- 
a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -96,11 +96,13 @@ bool is_valid(ParallelTensorDims const &dims) { [](ReplicaParallelDim const &d) { return is_valid(d); }); } -ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, ff_dim_t idx) { +ShardParallelDim shard_dim_at_idx(ParallelTensorDims const &d, + relative_ff_dim_t idx) { return d.shard_dims.at(idx); } -ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &d, ff_dim_t idx) { +ShardParallelDim &shard_dim_at_idx(ParallelTensorDims &d, + relative_ff_dim_t idx) { return d.shard_dims.at(idx); } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index dcc567e0ca..97b7360c84 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -41,11 +41,13 @@ bool is_valid(ParallelTensorShape const &shape) { return is_valid(shape.dims); } -ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { +ShardParallelDim shard_dim_at_idx(ParallelTensorShape const &s, + relative_ff_dim_t d) { return shard_dim_at_idx(s.dims, d); } -ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, ff_dim_t d) { +ShardParallelDim &shard_dim_at_idx(ParallelTensorShape &s, + relative_ff_dim_t d) { return shard_dim_at_idx(s.dims, d); } @@ -54,7 +56,8 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &s) { } std::optional - try_get_shard_dim_at_idx(ParallelTensorShape const &s, ff_dim_t d) { + try_get_shard_dim_at_idx(ParallelTensorShape const &s, + relative_ff_dim_t d) { if (s.dims.shard_dims.idx_is_valid(d)) { return s.dims.shard_dims.at(d); } else { @@ -122,7 +125,7 @@ TensorShape get_reduced_shape(ParallelTensorShape const &s) { ParallelDim get_parallel_dim_at_idx(ParallelTensorShape const &shape, parallel_tensor_dim_idx_t idx) { return idx.visit( - overload{[&](ff_dim_t shard_dim) { + overload{[&](relative_ff_dim_t shard_dim) { return ParallelDim{shape.dims.shard_dims.at(shard_dim)}; }, [&](ReplicaType replica_type) { @@ -138,7 +141,7 @@ std::unordered_set get_parallel_tensor_dim_indices(ParallelTensorShape const &shape) { std::unordered_set indices; extend(indices, transform(range(num_shard_dims(shape.dims)), [](int idx) { - return parallel_tensor_dim_idx_t(ff_dim_t(idx)); + return parallel_tensor_dim_idx_t(relative_ff_dim_t{idx}); })); indices.insert(parallel_tensor_dim_idx_t(ReplicaType::SUM)); indices.insert(parallel_tensor_dim_idx_t(ReplicaType::DISCARD_COPY)); diff --git a/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc index 8080da203d..09e99a4e59 100644 --- a/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc +++ b/lib/op-attrs/src/op-attrs/relative_ff_dim_t.cc @@ -1,10 +1,21 @@ #include "op-attrs/relative_ff_dim_t.h" #include "rapidcheck.h" +namespace FlexFlow { +ff_dim_t relative_ff_dim_t_to_ff_dim_t(relative_ff_dim_t ff_dim, + int input_dim) { + int raw = ff_dim.value; + if (raw < 0) { + raw = input_dim + raw; + } + return ff_dim_t{nonnegative_int{raw}}; +} +} // namespace FlexFlow + namespace rc { -Gen - Arbitrary::arbitrary() { - return gen::construct( +Gen<::FlexFlow::relative_ff_dim_t> + Arbitrary<::FlexFlow::relative_ff_dim_t>::arbitrary() { + return gen::construct<::FlexFlow::relative_ff_dim_t>( gen::inRange(-MAX_TENSOR_DIM, MAX_TENSOR_DIM)); } } // namespace rc diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 
1bb050db52..f0ac88d8e4 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -19,11 +19,11 @@ size_t num_dims(TensorDims const &dims) { return dims.ff_ordered.size(); } -size_t dim_at_idx(TensorDims const &dims, ff_dim_t idx) { +size_t dim_at_idx(TensorDims const &dims, relative_ff_dim_t idx) { return dims.ff_ordered.at(idx); } -size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { +size_t &dim_at_idx(TensorDims &dims, relative_ff_dim_t idx) { return dims.ff_ordered.at(idx); } diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index 07508e3065..70ed58aac6 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -11,11 +11,11 @@ size_t num_dims(TensorShape const &s) { return s.dims.ff_ordered.size(); } -size_t dim_at_idx(TensorShape const &s, ff_dim_t idx) { +size_t dim_at_idx(TensorShape const &s, relative_ff_dim_t idx) { return dim_at_idx(s.dims, idx); } -size_t &dim_at_idx(TensorShape &s, ff_dim_t idx) { +size_t &dim_at_idx(TensorShape &s, relative_ff_dim_t idx) { return dim_at_idx(s.dims, idx); } diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc index 180bc2a01f..bf4c33d65a 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc @@ -10,9 +10,9 @@ TEST_SUITE(FF_TEST_SUITE) { std::map result = enumerate(input); std::map correct = { - {ff_dim_t{0}, "zero"}, - {ff_dim_t{1}, "one"}, - {ff_dim_t{2}, "two"}, + {ff_dim_t{nonnegative_int{0}}, "zero"}, + {ff_dim_t{nonnegative_int{1}}, "one"}, + {ff_dim_t{nonnegative_int{2}}, "two"}, }; CHECK(result == correct); diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc index 7bc1695e5c..bba989920e 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -20,9 +20,9 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is missing keys") { SUBCASE("missing key is in middle") { T m = { - {ff_dim_t{0}, 4}, - {ff_dim_t{1}, 2}, - {ff_dim_t{3}, 5}, + {ff_dim_t{nonnegative_int{0}}, 4}, + {ff_dim_t{nonnegative_int{1}}, 2}, + {ff_dim_t{nonnegative_int{3}}, 5}, }; CHECK_THROWS(ff_ordered_from_map(m)); @@ -30,31 +30,21 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("missing key is 0 idx") { T m = { - {ff_dim_t{1}, 2}, - {ff_dim_t{2}, 7}, - {ff_dim_t{3}, 5}, + {ff_dim_t{nonnegative_int{1}}, 2}, + {ff_dim_t{nonnegative_int{2}}, 7}, + {ff_dim_t{nonnegative_int{3}}, 5}, }; CHECK_THROWS(ff_ordered_from_map(m)); } } - SUBCASE("input has negative keys") { - T m = { - {ff_dim_t{0}, 4}, - {ff_dim_t{1}, 5}, - {ff_dim_t{-1}, 2}, - }; - - CHECK_THROWS(ff_ordered_from_map(m)); - } - SUBCASE("input is valid") { T m = { - {ff_dim_t{0}, 4}, - {ff_dim_t{1}, 5}, - {ff_dim_t{2}, 2}, - {ff_dim_t{3}, 7}, + {ff_dim_t{nonnegative_int{0}}, 4}, + {ff_dim_t{nonnegative_int{1}}, 5}, + {ff_dim_t{nonnegative_int{2}}, 2}, + {ff_dim_t{nonnegative_int{3}}, 7}, }; FFOrdered result = ff_ordered_from_map(m); diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc index 8d5f247756..ccaab98860 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc @@ -4,22 +4,52 @@ using namespace ::FlexFlow; 
TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE( - "slice(DimOrdered, std::optional, std::optional)") { + TEST_CASE("slice(FFOrdered, ..., ...)") { FFOrdered d = FFOrdered{ 1, 2, 3, 4, }; + SUBCASE("ff_dim_t, ff_dim_t") { + FFOrdered result = + slice(d, ff_dim_t{nonnegative_int{1}}, ff_dim_t{nonnegative_int{3}}); + FFOrdered correct = FFOrdered{2, 3}; - FFOrdered result = slice(d, std::nullopt, ff_dim_t{-1}); - FFOrdered correct = FFOrdered{ - 1, - 2, - 3, - }; + CHECK(result == correct); + } + SUBCASE("ff_dim_t, std::nullopt_t") { + FFOrdered result = + slice(d, ff_dim_t{nonnegative_int{1}}, std::nullopt); + FFOrdered correct = FFOrdered{2, 3, 4}; + + CHECK(result == correct); + } + SUBCASE("std::nullopt_t, ff_dim_t") { + FFOrdered result = + slice(d, std::nullopt, ff_dim_t{nonnegative_int{3}}); + FFOrdered correct = FFOrdered{1, 2, 3}; + + CHECK(result == correct); + } + SUBCASE("relative_ff_dim_t, relative_ff_dim_t") { + FFOrdered result = + slice(d, relative_ff_dim_t{1}, relative_ff_dim_t{-1}); + FFOrdered correct = FFOrdered{2, 3}; + + CHECK(result == correct); + } + SUBCASE("relative_ff_dim_t, std::nullopt_t") { + FFOrdered result = slice(d, relative_ff_dim_t{-3}, std::nullopt); + FFOrdered correct = FFOrdered{2, 3, 4}; + + CHECK(result == correct); + } + SUBCASE("std::nullopt_t, relative_ff_dim_t") { + FFOrdered result = slice(d, std::nullopt, relative_ff_dim_t{-1}); + FFOrdered correct = FFOrdered{1, 2, 3}; - CHECK(result == correct); + CHECK(result == correct); + } } } diff --git a/lib/op-attrs/test/src/op-attrs/ff_dim_t.cc b/lib/op-attrs/test/src/op-attrs/ff_dim_t.cc new file mode 100644 index 0000000000..a2f24f539e --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ff_dim_t.cc @@ -0,0 +1,20 @@ +#include "op-attrs/ff_dim_t.h" +#include <doctest/doctest.h> + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("FF_DIM_T_TO_RELATIVE_FF_DIM_T") { + SUBCASE("ZERO") { + ff_dim_t ff_dim = ff_dim_t{nonnegative_int{0}}; + relative_ff_dim_t relative_ff_dim = ff_dim_t_to_relative_ff_dim_t(ff_dim); + CHECK(relative_ff_dim == relative_ff_dim_t{0}); + } + + SUBCASE("POSITIVE") { + ff_dim_t ff_dim = ff_dim_t{nonnegative_int{1}}; + relative_ff_dim_t relative_ff_dim = ff_dim_t_to_relative_ff_dim_t(ff_dim); + CHECK(relative_ff_dim == relative_ff_dim_t{1}); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc index 33cc00c6a1..4688ad4008 100644 --- a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -9,7 +9,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Concat") { int num_incoming = 4; ComputationGraphOpAttrs attrs = - ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}}}; + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{nonnegative_int{0}}}}; std::vector result = get_incoming_tensor_roles(attrs, num_incoming); diff --git a/lib/op-attrs/test/src/op-attrs/ops/combine.cc b/lib/op-attrs/test/src/op-attrs/ops/combine.cc index bf74a072e0..577961b7b1 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/combine.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/combine.cc @@ -24,7 +24,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; SUBCASE("valid") { - ff_dim_t dim = ff_dim_t{2}; + ff_dim_t dim = ff_dim_t{nonnegative_int{2}}; int degree = 3; CombineAttrs attrs = CombineAttrs{ /*repartition_dim=*/dim, @@ -44,7 +44,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("invalid") { - ff_dim_t dim = ff_dim_t{2}; + ff_dim_t dim = ff_dim_t{nonnegative_int{2}}; int degree = 4; CombineAttrs attrs =
CombineAttrs{ /*repartition_dim=*/dim, diff --git a/lib/op-attrs/test/src/op-attrs/ops/concat.cc b/lib/op-attrs/test/src/op-attrs/ops/concat.cc index 9e842c3ebe..2d9842b1dd 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/concat.cc @@ -10,7 +10,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { ConcatAttrs attrs = ConcatAttrs{ - ff_dim_t{1}, + ff_dim_t{nonnegative_int{1}}, }; SUBCASE("empty input shapes list passed") { @@ -81,7 +81,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("concat axis is out of bounds") { attrs = ConcatAttrs{ - ff_dim_t{3}, + ff_dim_t{nonnegative_int{3}}, }; std::vector input_shapes = { @@ -115,7 +115,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { ConcatAttrs attrs = ConcatAttrs{ - ff_dim_t{1}, + ff_dim_t{nonnegative_int{1}}, }; size_t dim0_size = 12; diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc index b091833f10..d5aab55cb2 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc @@ -41,7 +41,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("mismatched dim size") { TensorShape incorrect_rhs = input_lhs; - dim_at_idx(incorrect_rhs, ff_dim_t{0}) += 1; + dim_at_idx(incorrect_rhs, relative_ff_dim_t{0}) += 1; tl::expected result = get_output_shape(attrs, input_lhs, incorrect_rhs); diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc index d81ab95c35..8998dfaffd 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -20,8 +20,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten all dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{0}, - /*end_dim=*/ff_dim_t{4}, + /*start_dim=*/ff_dim_t{nonnegative_int{0}}, + /*end_dim=*/ff_dim_t{nonnegative_int{4}}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -37,8 +37,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten trailing dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{2}, - /*end_dim=*/ff_dim_t{4}, + /*start_dim=*/ff_dim_t{nonnegative_int{2}}, + /*end_dim=*/ff_dim_t{nonnegative_int{4}}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -56,8 +56,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten leading dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{0}, - /*end_dim=*/ff_dim_t{2}, + /*start_dim=*/ff_dim_t{nonnegative_int{0}}, + /*end_dim=*/ff_dim_t{nonnegative_int{2}}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -75,8 +75,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten middle dims") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{1}, - /*end_dim=*/ff_dim_t{3}, + /*start_dim=*/ff_dim_t{nonnegative_int{1}}, + /*end_dim=*/ff_dim_t{nonnegative_int{3}}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -94,8 +94,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten no dims (start_dim == end_dim)") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{2}, - /*end_dim=*/ff_dim_t{2}, + /*start_dim=*/ff_dim_t{nonnegative_int{2}}, + /*end_dim=*/ff_dim_t{nonnegative_int{2}}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -106,8 +106,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("flatten no dims (start_dim < end_dim)") { FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{2}, - /*end_dim=*/ff_dim_t{1}, + 
/*start_dim=*/ff_dim_t{nonnegative_int{2}}, + /*end_dim=*/ff_dim_t{nonnegative_int{1}}, }; TensorShape result = get_output_shape(attrs, input_shape); @@ -119,8 +119,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE( "get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { - FlatAttrs attrs = FlatAttrs{/*start_dim=*/ff_dim_t{1}, - /*end_dim=*/ff_dim_t{3}}; + FlatAttrs attrs = FlatAttrs{/*start_dim=*/ff_dim_t{nonnegative_int{1}}, + /*end_dim=*/ff_dim_t{nonnegative_int{3}}}; SUBCASE("allows shard parallelism in non-flattened dims") { ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ @@ -217,8 +217,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{1}, - /*end_dim=*/ff_dim_t{3}, + /*start_dim=*/ff_dim_t{nonnegative_int{1}}, + /*end_dim=*/ff_dim_t{nonnegative_int{3}}, }; tl::expected result = diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index f45ea91dac..b9426a89a2 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -11,7 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_layer_norm_incoming_tensor_roles(LayerNormAttrs)") { auto make_attrs = [](bool elementwise_affine) { return LayerNormAttrs{ - /*axes=*/{ff_dim_t{0}, ff_dim_t{2}}, + /*axes=*/{ff_dim_t{nonnegative_int{0}}, ff_dim_t{nonnegative_int{2}}}, elementwise_affine, /*eps=*/1.0, }; @@ -46,7 +46,7 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("shape inference (LayerNorm)") { LayerNormAttrs attrs_affine_true = LayerNormAttrs{ - /*axes=*/{ff_dim_t{1}, ff_dim_t{3}}, + /*axes=*/{ff_dim_t{nonnegative_int{1}}, ff_dim_t{nonnegative_int{3}}}, /*elementwise_affine=*/true, /*eps=*/0.1, }; diff --git a/lib/op-attrs/test/src/op-attrs/ops/repartition.cc b/lib/op-attrs/test/src/op-attrs/ops/repartition.cc index 8bc8205183..ba213f54f4 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/repartition.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/repartition.cc @@ -6,7 +6,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Repartition shape inference") { - ff_dim_t dim = ff_dim_t{2}; + ff_dim_t dim = ff_dim_t{nonnegative_int{2}}; int degree = 4; RepartitionAttrs attrs = RepartitionAttrs{ /*repartition_dim=*/dim, diff --git a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc index 65a74932cb..5808e5ef42 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc @@ -19,7 +19,7 @@ TEST_SUITE(FF_TEST_SUITE) { }; SUBCASE("attrs.dim in bounds") { - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{nonnegative_int{1}}}; tl::expected result = get_output_shape(attrs, input); @@ -29,7 +29,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("attrs.dims out of bounds") { - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{4}}; + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{nonnegative_int{4}}}; std::optional result = optional_from_expected(get_output_shape(attrs, input)); @@ -70,7 +70,7 @@ TEST_SUITE(FF_TEST_SUITE) { make_input(SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2); SUBCASE("attrs.dim in bounds") { - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{nonnegative_int{1}}}; tl::expected result = get_output_shape(attrs, par_input); @@ -81,7 +81,7 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("attrs.dims out of bounds") { - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{4}}; + SoftmaxAttrs attrs = 
SoftmaxAttrs{ff_dim_t{nonnegative_int{4}}}; std::optional result = optional_from_expected(get_output_shape(attrs, par_input)); @@ -94,7 +94,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("partition parallism in softmax dim (invalid)") { int degree1 = 2; - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{nonnegative_int{1}}}; ParallelTensorShape par_input = make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree1, 1); @@ -109,7 +109,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("sum parallelism (invalid)") { SumDegree sum_degree = SumDegree{2}; - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{nonnegative_int{1}}}; ParallelTensorShape par_input = make_input(sum_degree, DiscardCopyDegree{1}, 1, 1, 1); @@ -124,7 +124,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("discard copy parallelism (invalid)") { DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{nonnegative_int{1}}}; ParallelTensorShape par_input = make_input(SumDegree{1}, discard_copy_degree, 1, 1, 1); diff --git a/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc index ebeaec4d19..73f5f0674d 100644 --- a/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc @@ -6,7 +6,7 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("PCGOperatorAttrs to/from json") { PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{1}, + /*repartition_dim=*/ff_dim_t{nonnegative_int{1}}, /*repartition_degree=*/4, }}; nlohmann::json j = correct; diff --git a/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc b/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc new file mode 100644 index 0000000000..aa70e19962 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/relative_ff_dim_t.cc @@ -0,0 +1,36 @@ +#include "op-attrs/relative_ff_dim_t.h" +#include <doctest/doctest.h> + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RELATIVE_FF_DIM_T_TO_FF_DIM_T") { + int input_dim = 5; + + SUBCASE("ZERO") { + relative_ff_dim_t relative_ff_dim = relative_ff_dim_t{0}; + ff_dim_t ff_dim = + relative_ff_dim_t_to_ff_dim_t(relative_ff_dim, input_dim); + CHECK(ff_dim == ff_dim_t{nonnegative_int{0}}); + } + + SUBCASE("POSITIVE") { + relative_ff_dim_t relative_ff_dim = relative_ff_dim_t{1}; + ff_dim_t ff_dim = + relative_ff_dim_t_to_ff_dim_t(relative_ff_dim, input_dim); + CHECK(ff_dim == ff_dim_t{nonnegative_int{1}}); + } + + SUBCASE("NEGATIVE") { + relative_ff_dim_t relative_ff_dim = relative_ff_dim_t{-1}; + ff_dim_t ff_dim = + relative_ff_dim_t_to_ff_dim_t(relative_ff_dim, input_dim); + CHECK(ff_dim == ff_dim_t{nonnegative_int{4}}); + } + + SUBCASE("OUT OF RANGE") { + relative_ff_dim_t relative_ff_dim = relative_ff_dim_t{-10}; + CHECK_THROWS(relative_ff_dim_t_to_ff_dim_t(relative_ff_dim, input_dim)); + } + } +} diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 45cde0de57..df93f69f2e 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -116,7 +116,7 @@ struct ComputationGraphBuilder { // Add a gather layer tensor_guid_t gather(tensor_guid_t const &input, tensor_guid_t const &index, - ff_dim_t dim, + relative_ff_dim_t dim, std::optional const &name = std::nullopt); // Add a cache layer tensor_guid_t diff
--git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index dff647f5a1..39072a06ff 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -19,6 +19,7 @@ #include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/relative_ff_dim_t.h" #include "op-attrs/tensor_dims.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" @@ -480,13 +481,8 @@ tensor_guid_t ComputationGraphBuilder::embedding( tensor_guid_t ComputationGraphBuilder::gather( tensor_guid_t const &input, tensor_guid_t const &index, - ff_dim_t dim, + relative_ff_dim_t dim, std::optional const &maybe_name) { - GatherAttrs attrs = GatherAttrs{dim}; - std::string name = - maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); - - LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { throw mk_runtime_error( @@ -496,6 +492,13 @@ tensor_guid_t ComputationGraphBuilder::gather( DataType::INT32, DataType::INT64)); } + + GatherAttrs attrs = GatherAttrs{ + relative_ff_dim_t_to_ff_dim_t(dim, num_dims(this->get_shape(input)))}; + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); @@ -792,7 +795,9 @@ tensor_guid_t ComputationGraphBuilder::concat( int axis, std::optional const &maybe_name) { - ConcatAttrs attrs = ConcatAttrs{ff_dim_t{axis}}; + relative_ff_dim_t wrapped_axis = relative_ff_dim_t{axis}; + ConcatAttrs attrs = ConcatAttrs{relative_ff_dim_t_to_ff_dim_t( + wrapped_axis, num_dims(this->get_shape(inputs[0])))}; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); @@ -816,8 +821,11 @@ tensor_guid_t ComputationGraphBuilder::flat( int input_num_dims = num_dims(this->get_shape(input)); FlatAttrs attrs = FlatAttrs{ - /*start_dim=*/ff_dim_t{start_dim}, - /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, + /*start_dim=*/relative_ff_dim_t_to_ff_dim_t(relative_ff_dim_t{start_dim}, + input_num_dims), + /*end_dim=*/ + relative_ff_dim_t_to_ff_dim_t( + relative_ff_dim_t{end_dim.value_or(input_num_dims)}, input_num_dims), }; std::string name = @@ -849,8 +857,17 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( num_dims(input_shape))); } + stack_vector axes_stack; + std::transform(axes.begin(), + axes.end(), + std::back_inserter(axes_stack), + [&input_shape](int axis) { + return relative_ff_dim_t_to_ff_dim_t(relative_ff_dim_t{axis}, + num_dims(input_shape)); + }); + LayerNormAttrs attrs = LayerNormAttrs{ - stack_vector{axes.begin(), axes.end()}, + axes_stack, elementwise_affine, eps, }; @@ -906,7 +923,8 @@ tensor_guid_t ComputationGraphBuilder::softmax( input_shape)); } - SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{dim}}; + SoftmaxAttrs attrs = SoftmaxAttrs{relative_ff_dim_t_to_ff_dim_t( + relative_ff_dim_t{dim}, num_dims(input_shape))}; std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 20bd0ac92d..3f66b33b6e 100644 --- 
a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -492,7 +492,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); - parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2); + parallel_tensor_guid_t output = + b.parallel_partition(input, ff_dim_t{nonnegative_int{0}}, 2); parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { @@ -531,7 +532,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); - parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2); + parallel_tensor_guid_t output = + b.parallel_combine(input, ff_dim_t{nonnegative_int{0}}, 2); parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("incoming") { diff --git a/lib/runtime/src/ops/embedding.cc b/lib/runtime/src/ops/embedding.cc index 2370739d58..253fd3cb4f 100644 --- a/lib/runtime/src/ops/embedding.cc +++ b/lib/runtime/src/ops/embedding.cc @@ -85,7 +85,7 @@ static std::optional attrs.aggr, input.shape.get_dim(), output.shape.get_dim(), - input.shape.at(ff_dim_t(0))); + input.shape.at(ff_dim_t{nonnegative_int{0}})); } TaskImplFunction get_embedding_fwd_task_impl() { diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index 344954c553..dcb06a78fa 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -52,7 +52,8 @@ using namespace FlexFlow; // OperatorAttrAssignment op_ass_n1{ // {{OperatorAttributeKey::OP_TYPE, // AttrConstant{OperatorType::REPARTITION}}, -// {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, +// {OperatorAttributeKey::PARALLEL_DIM, +// AttrConstant{ff_dim_t{nonnegative_int{0}}}}, // {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; // // OperatorAttrAssignment op_ass_n2{ @@ -71,7 +72,8 @@ using namespace FlexFlow; // OperatorAttrAssignment op_ass_n3{ // {{OperatorAttributeKey::OP_TYPE, // AttrConstant{OperatorType::REDUCTION}}, -// {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, +// {OperatorAttributeKey::PARALLEL_DIM, +// AttrConstant{ff_dim_t{nonnegative_int{0}}}}, // {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; // // auto og = NodeLabelledOpenMultiDiGraph::create< diff --git a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h index f3fbc4556d..01bee29f63 100644 --- a/lib/utils/include/utils/nonnegative_int/nonnegative_int.h +++ b/lib/utils/include/utils/nonnegative_int/nonnegative_int.h @@ -40,6 +40,8 @@ class nonnegative_int { friend std::ostream &operator<<(std::ostream &os, nonnegative_int const &n); + friend int format_as(nonnegative_int const &); + int get_value() const; private: @@ -57,7 +59,7 @@ struct adl_serializer<::FlexFlow::nonnegative_int> { namespace std { template <> -struct hash { +struct hash<::FlexFlow::nonnegative_int> { std::size_t operator()(FlexFlow::nonnegative_int const &n) const noexcept; }; } // namespace std diff --git a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc index 3218d94ebf..2c317dce86 100644 --- a/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc +++ b/lib/utils/src/utils/nonnegative_int/nonnegative_int.cc @@ -79,6 +79,10 @@ std::ostream &operator<<(std::ostream 
&os, nonnegative_int const &n) { int nonnegative_int::get_value() const { return this->value_; } + +int format_as(nonnegative_int const &x) { + return x.get_value(); +} } // namespace FlexFlow namespace nlohmann { @@ -94,7 +98,7 @@ void adl_serializer<::FlexFlow::nonnegative_int>::to_json( } // namespace nlohmann namespace std { -std::size_t hash::operator()( +std::size_t hash<::FlexFlow::nonnegative_int>::operator()( FlexFlow::nonnegative_int const &n) const noexcept { return std::hash{}(n.get_value()); } diff --git a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc index 97ca027354..9d1218d547 100644 --- a/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc +++ b/lib/utils/test/src/utils/nonnegative_int/nonnegative_int.cc @@ -23,22 +23,22 @@ TEST_SUITE(FF_TEST_SUITE) { nonnegative_int nn_int_1b = nonnegative_int{1}; nonnegative_int nn_int_2 = nonnegative_int{2}; SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equal") { - CHECK((nn_int_1a == nn_int_1b) == true); + CHECK(nn_int_1a == nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, not equal") { - CHECK((nn_int_1a == nn_int_2) == false); + CHECK_FALSE(nn_int_1a == nn_int_2); } SUBCASE("LHS: nonnegative_int, RHS: int, equal") { - CHECK((nn_int_1a == 1) == true); + CHECK(nn_int_1a == 1); } SUBCASE("LHS: nonnegative_int, RHS: int, not equal") { - CHECK((nn_int_1a == 2) == false); + CHECK_FALSE(nn_int_1a == 2); } SUBCASE("LHS: int, RHS: nonnegative_int, equal") { - CHECK((1 == nn_int_1b) == true); + CHECK(1 == nn_int_1b); } SUBCASE("LHS: int, RHS: nonnegative_int, not equal") { - CHECK((2 == nn_int_1b) == false); + CHECK_FALSE(2 == nn_int_1b); } } @@ -47,22 +47,22 @@ TEST_SUITE(FF_TEST_SUITE) { nonnegative_int nn_int_1b = nonnegative_int{1}; nonnegative_int nn_int_2 = nonnegative_int{2}; SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equal") { - CHECK((nn_int_1a != nn_int_1b) == false); + CHECK_FALSE(nn_int_1a != nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, not equal") { - CHECK((nn_int_1a != nn_int_2) == true); + CHECK(nn_int_1a != nn_int_2); } SUBCASE("LHS: nonnegative_int, RHS: int, equal") { - CHECK((nn_int_1a != 1) == false); + CHECK_FALSE(nn_int_1a != 1); } SUBCASE("LHS: nonnegative_int, RHS: int, not equal") { - CHECK((nn_int_1a != 2) == true); + CHECK(nn_int_1a != 2); } SUBCASE("LHS: int, RHS: nonnegative_int, equal") { - CHECK((1 != nn_int_1b) == false); + CHECK_FALSE(1 != nn_int_1b); } SUBCASE("LHS: int, RHS: nonnegative_int, not equal") { - CHECK((2 != nn_int_1b) == true); + CHECK(2 != nn_int_1b); } } @@ -71,31 +71,31 @@ TEST_SUITE(FF_TEST_SUITE) { nonnegative_int nn_int_1b = nonnegative_int{1}; nonnegative_int nn_int_2 = nonnegative_int{2}; SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { - CHECK((nn_int_1a < nn_int_2) == true); + CHECK(nn_int_1a < nn_int_2); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { - CHECK((nn_int_1a < nn_int_1b) == false); + CHECK_FALSE(nn_int_1a < nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { - CHECK((nn_int_2 < nn_int_1b) == false); + CHECK_FALSE(nn_int_2 < nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: int, less than") { - CHECK((nn_int_1a < 2) == true); + CHECK(nn_int_1a < 2); } SUBCASE("LHS: nonnegative_int, RHS: int, equals") { - CHECK((nn_int_1a < 1) == false); + CHECK_FALSE(nn_int_1a < 1); } SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { - CHECK((nn_int_2 < 1) == false); + 
CHECK_FALSE(nn_int_2 < 1); } SUBCASE("LHS: int, RHS: nonnegative_int, less than") { - CHECK((1 < nn_int_2) == true); + CHECK(1 < nn_int_2); } SUBCASE("LHS: int, RHS: nonnegative_int, equals") { - CHECK((1 < nn_int_1b) == false); + CHECK_FALSE(1 < nn_int_1b); } SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { - CHECK((2 < nn_int_1b) == false); + CHECK_FALSE(2 < nn_int_1b); } } @@ -104,31 +104,31 @@ TEST_SUITE(FF_TEST_SUITE) { nonnegative_int nn_int_1b = nonnegative_int{1}; nonnegative_int nn_int_2 = nonnegative_int{2}; SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { - CHECK((nn_int_1a <= nn_int_2) == true); + CHECK(nn_int_1a <= nn_int_2); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { - CHECK((nn_int_1a <= nn_int_1b) == true); + CHECK(nn_int_1a <= nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { - CHECK((nn_int_2 <= nn_int_1b) == false); + CHECK_FALSE(nn_int_2 <= nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: int, less than") { - CHECK((nn_int_1a <= 2) == true); + CHECK(nn_int_1a <= 2); } SUBCASE("LHS: nonnegative_int, RHS: int, equals") { - CHECK((nn_int_1a <= 1) == true); + CHECK(nn_int_1a <= 1); } SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { - CHECK((nn_int_2 <= 1) == false); + CHECK_FALSE(nn_int_2 <= 1); } SUBCASE("LHS: int, RHS: nonnegative_int, less than") { - CHECK((1 <= nn_int_2) == true); + CHECK(1 <= nn_int_2); } SUBCASE("LHS: int, RHS: nonnegative_int, equals") { - CHECK((1 <= nn_int_1b) == true); + CHECK(1 <= nn_int_1b); } SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { - CHECK((2 <= nn_int_1b) == false); + CHECK_FALSE(2 <= nn_int_1b); } } @@ -137,31 +137,31 @@ TEST_SUITE(FF_TEST_SUITE) { nonnegative_int nn_int_1b = nonnegative_int{1}; nonnegative_int nn_int_2 = nonnegative_int{2}; SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { - CHECK((nn_int_1a > nn_int_2) == false); + CHECK_FALSE(nn_int_1a > nn_int_2); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { - CHECK((nn_int_1a > nn_int_1b) == false); + CHECK_FALSE(nn_int_1a > nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { - CHECK((nn_int_2 > nn_int_1b) == true); + CHECK(nn_int_2 > nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: int, less than") { - CHECK((nn_int_1a > 2) == false); + CHECK_FALSE(nn_int_1a > 2); } SUBCASE("LHS: nonnegative_int, RHS: int, equals") { - CHECK((nn_int_1a > 1) == false); + CHECK_FALSE(nn_int_1a > 1); } SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { - CHECK((nn_int_2 > 1) == true); + CHECK(nn_int_2 > 1); } SUBCASE("LHS: int, RHS: nonnegative_int, less than") { - CHECK((1 > nn_int_2) == false); + CHECK_FALSE(1 > nn_int_2); } SUBCASE("LHS: int, RHS: nonnegative_int, equals") { - CHECK((1 > nn_int_1b) == false); + CHECK_FALSE(1 > nn_int_1b); } SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { - CHECK((2 > nn_int_1b) == true); + CHECK(2 > nn_int_1b); } } @@ -170,35 +170,35 @@ TEST_SUITE(FF_TEST_SUITE) { nonnegative_int nn_int_1b = nonnegative_int{1}; nonnegative_int nn_int_2 = nonnegative_int{2}; SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, less than") { - CHECK((nn_int_1a >= nn_int_2) == false); + CHECK_FALSE(nn_int_1a >= nn_int_2); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, equals") { - CHECK((nn_int_1a >= nn_int_1b) == true); + CHECK(nn_int_1a >= nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: nonnegative_int, greater than") { - CHECK((nn_int_2 >= nn_int_1b) == true); + CHECK(nn_int_2 >= 
nn_int_1b); } SUBCASE("LHS: nonnegative_int, RHS: int, less than") { - CHECK((nn_int_1a >= 2) == false); + CHECK_FALSE(nn_int_1a >= 2); } SUBCASE("LHS: nonnegative_int, RHS: int, equals") { - CHECK((nn_int_1a >= 1) == true); + CHECK(nn_int_1a >= 1); } SUBCASE("LHS: nonnegative_int, RHS: int, greater than") { - CHECK((nn_int_2 >= 1) == true); + CHECK(nn_int_2 >= 1); } SUBCASE("LHS: int, RHS: nonnegative_int, less than") { - CHECK((1 >= nn_int_2) == false); + CHECK_FALSE(1 >= nn_int_2); } SUBCASE("LHS: int, RHS: nonnegative_int, equals") { - CHECK((1 >= nn_int_1b) == true); + CHECK(1 >= nn_int_1b); } SUBCASE("LHS: int, RHS: nonnegative_int, greater than") { - CHECK((2 >= nn_int_1b) == true); + CHECK(2 >= nn_int_1b); } } - TEST_CASE("nonnegative_int adl_serializer") { + TEST_CASE("adl_serializer") { SUBCASE("to_json") { nonnegative_int input = nonnegative_int{5}; @@ -230,7 +230,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(hash_fn(nn_int_1a) != hash_fn(nn_int_2)); } SUBCASE("Unordered set works with nonnegative_int") { - std::unordered_set nonnegative_int_set; + std::unordered_set<::FlexFlow::nonnegative_int> nonnegative_int_set; nonnegative_int_set.insert(nn_int_1a); nonnegative_int_set.insert(nn_int_1b); nonnegative_int_set.insert(nn_int_2); @@ -246,4 +246,9 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(oss.str() == "1"); } + + TEST_CASE("nonnegative int fmt::to_string") { + nonnegative_int nn_int_1 = nonnegative_int{1}; + CHECK(fmt::to_string(nn_int_1) == "1"); + } }