From 5e1f349215294ae312ccac9f2891aecacaf95e39 Mon Sep 17 00:00:00 2001 From: Reyna Abhyankar Date: Thu, 22 Aug 2024 14:59:04 -0700 Subject: [PATCH 01/12] Local execution tests (#1418) * pr for debugging kernel driver issues * Commit flake files * current kernel tests * softmax, flat, transpose kernel tests * clang formatting kernel tests * reverse, split, full dropout kernels * rest of kernel-tests * minor cleannup * Restore .proj.toml * Delete misadded directory * merge fix * more merge fixes * resolved merge conflicts with repo-refactor * code review changes * allocator updates * allocation util updates * test clean up and review fixes * fixed forward backward pass consistencies, added filler tests for all tests, other review changes * unnested test subcases and more review changes * Add == in OpTaskBinding * Add single operator test example * Finish multi operator test * added managed_stream and handle classes, other minor clean up * fix accessor and corresponding shape clarity, other clean up * merge error fixes * More aggressive subcasing * Remove comment * managed handle and stream fixes, removed datatype dispatch from cuda_helper, other clean up * managed handle and stream updates * Refactoring and split tests * Fix build * Fix build * Add cuda test suite * Remove mock * Pass task registry * Pass slots backing and task arg acc * Pass cost estimator test * Fix * PR fixes * Fixes * Add test to ci * Fix test libs * Fix build, add more fmt placeholders * Fixes * Fixes * Delete file * Fixes * Fixes * Fixes * Fix includes * Fix includes --------- Co-authored-by: Dylan Lim Co-authored-by: Dylan Lim <72822184+oOTigger@users.noreply.github.com> Co-authored-by: Colin Unger --- .github/workflows/helpers/test_libs.sh | 2 +- .github/workflows/per-lib-check.yml | 4 + .proj.toml | 4 +- cmake/flexflow-utils.cmake | 2 +- lib/kernels/CMakeLists.txt | 2 +- lib/kernels/include/kernels/accessor.h | 17 + lib/kernels/include/kernels/array_shape.h | 6 +- .../include/kernels/attention_kernels.h | 22 + lib/kernels/include/kernels/device.h | 4 +- lib/kernels/include/kernels/ff_handle.h | 3 + lib/kernels/include/kernels/legion_dim.h | 17 + lib/kernels/include/kernels/profiling.h | 14 +- .../kernels/profiling_settings.struct.toml | 18 + lib/kernels/src/accessor.cc | 29 ++ lib/kernels/src/array_shape.cc | 37 +- lib/kernels/src/cuda/cuda_helper.cu | 27 +- lib/kernels/src/cuda/ops/attention_kernels.cu | 35 ++ lib/kernels/src/device.cc | 4 + lib/kernels/src/device.h | 12 +- lib/kernels/src/ff_handle.cc | 18 + lib/kernels/src/local_cuda_allocator.cc | 2 +- lib/kernels/test/src/test_cast_kernel.cc | 1 - lib/local-execution/CMakeLists.txt | 2 + .../include/local-execution/arg_ref.h | 13 +- .../include/local-execution/concrete_arg.h | 9 + .../local-execution/cost_details.struct.toml | 1 - .../include/local-execution/device_specific.h | 22 +- ...device_specific_device_states.variant.toml | 90 ++++ .../include/local-execution/device_states.h | 40 -- .../fwd_bwd_task_impl_function.h | 32 ++ .../local-execution/init_task_impl_function.h | 33 ++ .../local-execution/local_cpu_allocator.h | 22 + .../local-execution/local_slots_backing.h | 23 +- .../local_task_argument_accessor.h | 10 + .../local-execution/local_training_backing.h | 14 +- .../include/local-execution/op_arg_ref.h | 20 +- .../op_arg_ref_type.variant.toml | 22 + .../local-execution/op_arg_spec.variant.toml | 2 +- .../local-execution/op_task_invocation.h | 17 +- .../local-execution/op_task_signature.h | 34 +- .../local-execution}/ops/attention.h | 0 ...parallel_tensor_shape_ref_type.struct.toml | 14 + .../local-execution/per_device_op_state.h | 15 + .../per_device_op_state.variant.toml | 87 ++++ .../per_device_op_state_ref_type.struct.toml | 12 + .../include/local-execution/runtime_arg_ref.h | 1 + .../local-execution/task_argument_accessor.h | 16 +- .../local-execution/task_id_t.enum.toml | 431 ++++++++++++++++ .../task_impl_function.variant.toml | 21 + .../include/local-execution/task_registry.h | 21 +- .../local-execution/task_registry.struct.toml | 34 ++ .../local-execution/task_signature_impl.h | 23 +- .../task_signature_impl.struct.toml | 20 + .../include/local-execution/tasks.h | 160 +----- .../local-execution/tracked_allocator.h | 1 + lib/local-execution/src/concrete_arg.cc | 29 ++ .../src/fwd_bwd_task_impl_function.cc | 54 ++ .../src/init_task_impl_function.cc | 47 ++ .../src/legion_tensor_shape.cc | 13 + .../src/local_cost_estimator.cc | 3 +- .../src/local_cpu_allocator.cc | 24 + .../src/local_slots_backing.cc | 70 ++- .../src/local_task_argument_accessor.cc | 34 ++ .../src/local_training_backing.cc | 111 ++-- lib/local-execution/src/op_arg_ref.cc | 2 +- .../src/{local-execution => }/op_arg_spec.cc | 0 lib/local-execution/src/op_task_invocation.cc | 31 +- lib/local-execution/src/op_task_signature.cc | 69 ++- lib/local-execution/src/ops/attention.cc | 39 +- lib/local-execution/src/ops/batch_matmul.cc | 14 +- lib/local-execution/src/ops/batch_norm.cc | 23 +- lib/local-execution/src/ops/cast.cc | 10 +- lib/local-execution/src/ops/combine.cc | 11 +- lib/local-execution/src/ops/concat.cc | 13 +- lib/local-execution/src/ops/conv_2d.cc | 21 +- lib/local-execution/src/ops/dropout.cc | 21 +- lib/local-execution/src/ops/dropout.h | 2 +- lib/local-execution/src/ops/element_binary.cc | 23 +- lib/local-execution/src/ops/element_unary.cc | 23 +- lib/local-execution/src/ops/flat.cc | 10 +- lib/local-execution/src/ops/gather.cc | 21 +- lib/local-execution/src/ops/input.cc | 9 + lib/local-execution/src/ops/input.h | 13 + lib/local-execution/src/ops/layer_norm.cc | 21 +- lib/local-execution/src/ops/linear.cc | 41 +- lib/local-execution/src/ops/noop.cc | 14 +- lib/local-execution/src/ops/noop.h | 6 +- lib/local-execution/src/ops/pool_2d.cc | 57 ++- lib/local-execution/src/ops/reduce.cc | 21 +- lib/local-execution/src/ops/reduction.cc | 10 +- lib/local-execution/src/ops/repartition.cc | 23 +- lib/local-execution/src/ops/replicate.cc | 10 +- lib/local-execution/src/ops/reshape.cc | 21 +- lib/local-execution/src/ops/reverse.cc | 10 +- lib/local-execution/src/ops/softmax.cc | 21 +- lib/local-execution/src/ops/split.cc | 10 +- lib/local-execution/src/ops/topk.cc | 21 +- lib/local-execution/src/ops/transpose.cc | 21 +- lib/local-execution/src/ops/weight.cc | 9 + lib/local-execution/src/ops/weight.h | 13 + lib/local-execution/src/per_device_state.cc | 12 + lib/local-execution/src/task_registry.cc | 59 ++- .../src/task_signature_impl.cc | 478 ++++++++++++------ lib/local-execution/src/tracked_allocator.cc | 3 +- lib/local-execution/test/CMakeLists.txt | 14 + .../test/src/test_local_cost_estimator.cc | 77 +++ .../test/src/test_local_slots_backing.cc | 273 ++++++++++ .../test/src/test_local_task_arg_accessor.cc | 143 ++++++ .../test/src/test_task_registry.cc | 131 +++++ lib/local-execution/test/src/test_utils.cc | 9 + lib/local-execution/test/src/test_utils.h | 12 + lib/op-attrs/src/op-attrs/ops/attention.cc | 48 +- lib/pcg/include/pcg/computation_graph.h | 3 + .../include/pcg/computation_graph_builder.h | 11 +- lib/pcg/src/pcg/computation_graph.cc | 10 + lib/pcg/src/pcg/computation_graph_builder.cc | 41 ++ .../src/ops/embedding.cc | 0 lib/utils/include/utils/join_strings.h | 2 +- 118 files changed, 3055 insertions(+), 882 deletions(-) create mode 100644 lib/kernels/include/kernels/profiling_settings.struct.toml create mode 100644 lib/kernels/src/ff_handle.cc create mode 100644 lib/local-execution/include/local-execution/device_specific_device_states.variant.toml delete mode 100644 lib/local-execution/include/local-execution/device_states.h create mode 100644 lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h create mode 100644 lib/local-execution/include/local-execution/init_task_impl_function.h create mode 100644 lib/local-execution/include/local-execution/local_cpu_allocator.h create mode 100644 lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml rename lib/local-execution/{src => include/local-execution}/ops/attention.h (100%) create mode 100644 lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml create mode 100644 lib/local-execution/include/local-execution/per_device_op_state.h create mode 100644 lib/local-execution/include/local-execution/per_device_op_state.variant.toml create mode 100644 lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml create mode 100644 lib/local-execution/include/local-execution/task_id_t.enum.toml create mode 100644 lib/local-execution/include/local-execution/task_impl_function.variant.toml create mode 100644 lib/local-execution/include/local-execution/task_registry.struct.toml create mode 100644 lib/local-execution/include/local-execution/task_signature_impl.struct.toml create mode 100644 lib/local-execution/src/concrete_arg.cc create mode 100644 lib/local-execution/src/fwd_bwd_task_impl_function.cc create mode 100644 lib/local-execution/src/init_task_impl_function.cc create mode 100644 lib/local-execution/src/legion_tensor_shape.cc create mode 100644 lib/local-execution/src/local_cpu_allocator.cc rename lib/local-execution/src/{local-execution => }/op_arg_spec.cc (100%) create mode 100644 lib/local-execution/src/ops/input.cc create mode 100644 lib/local-execution/src/ops/input.h create mode 100644 lib/local-execution/src/ops/weight.cc create mode 100644 lib/local-execution/src/ops/weight.h create mode 100644 lib/local-execution/src/per_device_state.cc create mode 100644 lib/local-execution/test/CMakeLists.txt create mode 100644 lib/local-execution/test/src/test_local_cost_estimator.cc create mode 100644 lib/local-execution/test/src/test_local_slots_backing.cc create mode 100644 lib/local-execution/test/src/test_local_task_arg_accessor.cc create mode 100644 lib/local-execution/test/src/test_task_registry.cc create mode 100644 lib/local-execution/test/src/test_utils.cc create mode 100644 lib/local-execution/test/src/test_utils.h rename lib/{local-execution => runtime}/src/ops/embedding.cc (100%) diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_libs.sh index 7662a7e601..69baa66364 100755 --- a/.github/workflows/helpers/test_libs.sh +++ b/.github/workflows/helpers/test_libs.sh @@ -7,7 +7,7 @@ DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")" REPO="$(realpath -- "$DIR/../../../")" TEST_LIBS=("${@/%/-tests}") -REGEX="^$(IFS='|'; echo "${TEST_LIBS[*]}")\$" +REGEX="^($(IFS='|'; echo "${TEST_LIBS[*]}"))\$" cd "$REPO/build-ci" make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "${TEST_LIBS[@]}" diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 35860fbcec..38556a3c0e 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -116,6 +116,10 @@ jobs: run: | test_libs.sh substitution-generator + - name: Test local-execution + run: | + test_libs.sh local-execution + - name: Generate code coverage run: | echo "gitwork: $GITHUB_WORKSPACE" diff --git a/.proj.toml b/.proj.toml index 14bdcdb3b7..ee91d07833 100644 --- a/.proj.toml +++ b/.proj.toml @@ -11,16 +11,18 @@ build_targets = [ "substitutions", "compiler", "substitution-generator", - "local-execution", + "local-execution", ] test_targets = [ + # "kernels-tests", "utils-tests", "op-attrs-tests", "pcg-tests", "substitutions-tests", "compiler-tests", "substitution-generator-tests", + "local-execution-tests" ] [cmake_flags_extra] diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 32798e6833..1dbd16bdb1 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -118,7 +118,7 @@ function(ff_add_test_executable) ${FF_TEST_EXEC_NAME} ${FF_TEST_EXEC_DEPS}) - target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}") + target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}" FF_CUDA_TEST_SUITE="cuda-${FF_TEST_EXEC_NAME}") define_ff_vars(${FF_TEST_EXEC_NAME}) ff_set_cxx_properties(${FF_TEST_EXEC_NAME}) diff --git a/lib/kernels/CMakeLists.txt b/lib/kernels/CMakeLists.txt index f166dd027c..8ccd7c1011 100644 --- a/lib/kernels/CMakeLists.txt +++ b/lib/kernels/CMakeLists.txt @@ -40,4 +40,4 @@ set_target_properties( CUDA_STANDARD 17 ) -add_subdirectory(test) \ No newline at end of file +add_subdirectory(test) diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 1ef121fb2a..2ee081ecbc 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -145,6 +145,22 @@ std::vector const *> GenericTensorAccessorR read_only_accessor_from_write_accessor( GenericTensorAccessorW const &write_accessor); +bool is_shape_and_dtype_equal(GenericTensorAccessorW const &acc1, + GenericTensorAccessorW const &acc2); + +bool shape_and_dtype_matches(GenericTensorAccessorW const &accessor, + ArrayShape const &expected_shape, + DataType const &expected_dtype); + +bool shape_and_dtype_matches(GenericTensorAccessorR const &accessor, + ArrayShape const &expected_shape, + DataType const &expected_dtype); + +std::pair + get_shape_and_datatype(GenericTensorAccessorR const &accessor); +std::pair + get_shape_and_datatype(GenericTensorAccessorW const &accessor); + } // namespace FlexFlow namespace FlexFlow { @@ -152,6 +168,7 @@ static_assert(is_well_behaved_value_type_no_hash::value, ""); static_assert(is_well_behaved_value_type_no_hash::value, ""); + } // namespace FlexFlow #endif diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 5427d25bc3..5de9fae7ad 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -39,7 +39,8 @@ struct ArrayShape { legion_dim_t last_idx() const; legion_dim_t neg_idx(int) const; - std::optional at_maybe(std::size_t) const; + std::optional at_maybe(legion_dim_t) const; + std::optional at_maybe(ff_dim_t) const; ArrayShape sub_shape(std::optional> start, @@ -54,6 +55,9 @@ size_t get_volume(ArrayShape const &); TensorShape get_tensor_shape(ArrayShape const &, DataType); +std::string format_as(ArrayShape const &); +std::ostream &operator<<(std::ostream &, ArrayShape const &); + } // namespace FlexFlow #endif diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index de37b4169f..575de57f09 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -25,6 +25,25 @@ struct MHAPerDeviceState { int *hiWinIdx; void *reserveSpace; Allocator allocator; + + bool operator==(MHAPerDeviceState const &other) const; + bool operator!=(MHAPerDeviceState const &other) const; + +private: + std::tuple + tie() const; }; FF_VISITABLE_STRUCT_NO_EQ(MHAPerDeviceState, @@ -43,6 +62,9 @@ FF_VISITABLE_STRUCT_NO_EQ(MHAPerDeviceState, reserveSpace, allocator); +std::string format_as(MHAPerDeviceState const &x); +std::ostream &operator<<(std::ostream &s, MHAPerDeviceState const &x); + namespace Kernels { namespace MultiHeadAttention { diff --git a/lib/kernels/include/kernels/device.h b/lib/kernels/include/kernels/device.h index c4e78821dc..cf4329774d 100644 --- a/lib/kernels/include/kernels/device.h +++ b/lib/kernels/include/kernels/device.h @@ -95,11 +95,13 @@ using coord_t = long long; exit(1); \ } while (0) +char const *getCudaErrorString(cudaError_t status); + #define checkCUDA(status) \ do { \ std::stringstream _error; \ if (status != 0) { \ - _error << "CUDA failure: " << cudaGetErrorString(status) << " (" \ + _error << "CUDA failure: " << getCudaErrorString(status) << " (" \ << status << ")"; \ FatalError(_error.str()); \ } \ diff --git a/lib/kernels/include/kernels/ff_handle.h b/lib/kernels/include/kernels/ff_handle.h index 89df04e3c1..179ce41cbf 100644 --- a/lib/kernels/include/kernels/ff_handle.h +++ b/lib/kernels/include/kernels/ff_handle.h @@ -40,6 +40,9 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(PerDeviceFFHandle, allowTensorOpMathConversion); #endif +std::string format_as(PerDeviceFFHandle const &x); +std::ostream &operator<<(std::ostream &s, PerDeviceFFHandle const &x); + } // namespace FlexFlow #endif diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index d8ffd91489..aafbd2cdcb 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -15,6 +15,23 @@ using LegionOrdered = DimOrdered; using LegionTensorDims = LegionOrdered; +template +FFOrdered + ff_ordered_from_legion_ordered(LegionOrdered const &legion_ordered) { + return FFOrdered(legion_ordered.rbegin(), legion_ordered.rend()); +} + +template +std::string format_as(LegionOrdered const &v) { + std::vector as_vec(v.cbegin(), v.cend()); + return fmt::format("", as_vec); +} + +template +std::ostream &operator<<(std::ostream &s, LegionOrdered const &v) { + return (s << fmt::to_string(v)); +} + } // namespace FlexFlow #endif diff --git a/lib/kernels/include/kernels/profiling.h b/lib/kernels/include/kernels/profiling.h index 602689d491..655d540685 100644 --- a/lib/kernels/include/kernels/profiling.h +++ b/lib/kernels/include/kernels/profiling.h @@ -2,20 +2,11 @@ #define _FLEXFLOW_KERNELS_PROFILING_H #include "device.h" +#include "kernels/profiling_settings.dtg.h" #include "utils/visitable.h" namespace FlexFlow { -struct ProfilingSettings : public use_visitable_cmp { -public: - ProfilingSettings() = delete; - ProfilingSettings(int warmup_iters, int measure_iters); - -public: - int warmup_iters; - int measure_iters; -}; - template std::optional profiling_wrapper(F const &f, bool enable_profiling, Ts &&...ts) { @@ -59,7 +50,4 @@ std::optional profiling_wrapper(F const &f, } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::ProfilingSettings, warmup_iters, measure_iters); -MAKE_VISIT_HASHABLE(::FlexFlow::ProfilingSettings); - #endif diff --git a/lib/kernels/include/kernels/profiling_settings.struct.toml b/lib/kernels/include/kernels/profiling_settings.struct.toml new file mode 100644 index 0000000000..694dfac76a --- /dev/null +++ b/lib/kernels/include/kernels/profiling_settings.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "ProfilingSettings" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "warmup_iters" +type = "int" + +[[fields]] +name = "measure_iters" +type = "int" diff --git a/lib/kernels/src/accessor.cc b/lib/kernels/src/accessor.cc index 56002718b1..a852f0d7b3 100644 --- a/lib/kernels/src/accessor.cc +++ b/lib/kernels/src/accessor.cc @@ -138,4 +138,33 @@ GenericTensorAccessorR read_only_accessor_from_write_accessor( writable.data_type, writable.shape, req(writable.ptr)}; } +bool is_shape_and_dtype_equal(GenericTensorAccessorW const &acc1, + GenericTensorAccessorW const &acc2) { + return acc1.shape == acc2.shape && acc1.data_type == acc2.data_type; +} + +bool shape_and_dtype_matches(GenericTensorAccessorW const &accessor, + ArrayShape const &expected_shape, + DataType const &expected_dtype) { + return accessor.shape == expected_shape && + accessor.data_type == expected_dtype; +} + +bool shape_and_dtype_matches(GenericTensorAccessorR const &accessor, + ArrayShape const &expected_shape, + DataType const &expected_dtype) { + return accessor.shape == expected_shape && + accessor.data_type == expected_dtype; +} + +std::pair + get_shape_and_datatype(GenericTensorAccessorR const &accessor) { + return std::make_pair(accessor.shape, accessor.data_type); +} + +std::pair + get_shape_and_datatype(GenericTensorAccessorW const &accessor) { + return std::make_pair(accessor.shape, accessor.data_type); +} + } // namespace FlexFlow diff --git a/lib/kernels/src/array_shape.cc b/lib/kernels/src/array_shape.cc index 7daf97ecd1..d5e2f1167d 100644 --- a/lib/kernels/src/array_shape.cc +++ b/lib/kernels/src/array_shape.cc @@ -39,7 +39,15 @@ std::size_t ArrayShape::num_elements() const { } std::size_t ArrayShape::operator[](legion_dim_t idx) const { - return dims[idx]; + return dims.at(idx); +} + +std::size_t ArrayShape::at(legion_dim_t idx) const { + return dims.at(idx); +} + +std::size_t ArrayShape::at(ff_dim_t idx) const { + return dims.at(legion_dim_from_ff_dim(idx, this->num_dims())); } ArrayShape ArrayShape::sub_shape( @@ -48,16 +56,37 @@ ArrayShape ArrayShape::sub_shape( NOT_IMPLEMENTED(); } -std::optional ArrayShape::at_maybe(std::size_t index) const { - if (index < dims.size()) { - return dims.at(legion_dim_t(index)); +std::optional ArrayShape::at_maybe(legion_dim_t index) const { + if (index.value < dims.size()) { + return dims.at(index); } else { return std::nullopt; } } +std::optional ArrayShape::at_maybe(ff_dim_t index) const { + return this->at_maybe(legion_dim_from_ff_dim(index, this->num_dims())); +} + size_t get_volume(ArrayShape const &shape) { return shape.get_volume(); } +TensorShape get_tensor_shape(ArrayShape const &shape, DataType dtype) { + return TensorShape{TensorDims{ff_ordered_from_legion_ordered(shape.dims)}, + dtype}; +} + +std::string format_as(ArrayShape const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +std::ostream &operator<<(std::ostream &s, ArrayShape const &x) { + return (s << fmt::to_string(x)); +} + } // namespace FlexFlow diff --git a/lib/kernels/src/cuda/cuda_helper.cu b/lib/kernels/src/cuda/cuda_helper.cu index 3488ce29af..2ff02038f4 100644 --- a/lib/kernels/src/cuda/cuda_helper.cu +++ b/lib/kernels/src/cuda/cuda_helper.cu @@ -220,25 +220,14 @@ __host__ void ffStatus_t cudnnSetTensorDescriptorFromArrayShape(cudnnTensorDescriptor_t tensor, ArrayShape const &shape) { - std::vector reversed_dims(shape.dims.begin(), shape.dims.end()); - reversed(reversed_dims); - ArrayShape flipped(reversed_dims); - - if (flipped.get_dim() == 5) { - assert(flipped[legion_dim_t(0)] == 1); - flipped = flipped.sub_shape(legion_dim_t(1), std::nullopt); - } - - assert(flipped.get_dim() > 0); - assert(flipped.get_dim() < 4); - - return cudnnSetTensor4dDescriptor(tensor, - CUDNN_TENSOR_NCHW, - CUDNN_DATA_FLOAT, - flipped.at_maybe(0).value_or(1), - flipped.at_maybe(1).value_or(2), - flipped.at_maybe(2).value_or(3), - flipped.at_maybe(3).value_or(3)); + return cudnnSetTensor4dDescriptor( + tensor, + CUDNN_TENSOR_NCHW, + CUDNN_DATA_FLOAT, + shape.at_maybe(legion_dim_t{0}).value_or(1), + shape.at_maybe(legion_dim_t{1}).value_or(1), + shape.at_maybe(legion_dim_t{2}).value_or(1), + shape.at_maybe(legion_dim_t{3}).value_or(1)); } cudnnDataType_t ff_to_cudnn_datatype(DataType type) { diff --git a/lib/kernels/src/cuda/ops/attention_kernels.cu b/lib/kernels/src/cuda/ops/attention_kernels.cu index e50f3983cc..38c32ad9e4 100644 --- a/lib/kernels/src/cuda/ops/attention_kernels.cu +++ b/lib/kernels/src/cuda/ops/attention_kernels.cu @@ -18,6 +18,41 @@ #include "kernels/device.h" namespace FlexFlow { + +bool MHAPerDeviceState::operator==(MHAPerDeviceState const &other) const { + return this->tie() == other.tie(); +} + +bool MHAPerDeviceState::operator!=(MHAPerDeviceState const &other) const { + return this->tie() != other.tie(); +} + +std:: + tuple + MHAPerDeviceState::tie() const { + return std::tie(this->handle, + this->weightSize, + this->reserveSpaceSize, + this->attnDesc, + this->qDesc, + this->kDesc, + this->vDesc, + this->oDesc, + this->devQoSeqArray, + this->devKvSeqArray, + this->loWinIdx, + this->hiWinIdx, + this->reserveSpace); +} + +std::string format_as(MHAPerDeviceState const &x) { + return fmt::format("MHAPerDeviceState"); +} + +std::ostream &operator<<(std::ostream &s, MHAPerDeviceState const &x) { + return (s << fmt::to_string(x)); +} + namespace Kernels { namespace MultiHeadAttention { diff --git a/lib/kernels/src/device.cc b/lib/kernels/src/device.cc index 0df5e84ee9..f46099c79a 100644 --- a/lib/kernels/src/device.cc +++ b/lib/kernels/src/device.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +char const *getCudaErrorString(cudaError_t status) { + return cudaGetErrorString(status); +} + ffError_t ffEventCreate(ffEvent_t *e) { #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) return cudaEventCreate(e); diff --git a/lib/kernels/src/device.h b/lib/kernels/src/device.h index 96670f712f..ceff2f92ff 100644 --- a/lib/kernels/src/device.h +++ b/lib/kernels/src/device.h @@ -7,6 +7,8 @@ #include "op-attrs/operator_type.h" #include +namespace FlexFlow { + #if defined(FF_USE_CUDA) #include #elif defined(FF_USE_HIP_CUDA) @@ -26,9 +28,6 @@ #error "Unknown device" #endif -using ::FlexFlow::DataType; -using ::FlexFlow::OperatorType; - #define checkCUDNN(status) \ do { \ std::stringstream _error; \ @@ -132,9 +131,8 @@ __host__ void updateGAS(float *para_ptr, template void print_tensor(T const *ptr, size_t num_elements, char const *prefix); -ffStatus_t - cudnnSetTensorDescriptorFromArrayShape(ffTensorDescriptor_t tensor, - FlexFlow::ArrayShape const &shape); +ffStatus_t cudnnSetTensorDescriptorFromArrayShape(ffTensorDescriptor_t tensor, + ArrayShape const &shape); ffDataType_t ff_to_cuda_datatype(DataType type); @@ -142,4 +140,6 @@ ffCudnnDataType_t ff_to_cudnn_datatype(DataType type); void handle_unimplemented_kernel(OperatorType op_type); +} // namespace FlexFlow + #endif diff --git a/lib/kernels/src/ff_handle.cc b/lib/kernels/src/ff_handle.cc new file mode 100644 index 0000000000..63ca6975fd --- /dev/null +++ b/lib/kernels/src/ff_handle.cc @@ -0,0 +1,18 @@ +#include "kernels/ff_handle.h" + +namespace FlexFlow { + +std::string format_as(PerDeviceFFHandle const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +std::ostream &operator<<(std::ostream &s, PerDeviceFFHandle const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow diff --git a/lib/kernels/src/local_cuda_allocator.cc b/lib/kernels/src/local_cuda_allocator.cc index 9e9cb19070..cdcfb017a0 100644 --- a/lib/kernels/src/local_cuda_allocator.cc +++ b/lib/kernels/src/local_cuda_allocator.cc @@ -21,7 +21,7 @@ void LocalCudaAllocator::deallocate(void *ptr) { } LocalCudaAllocator::~LocalCudaAllocator() { - for (auto ptr : ptrs) { + for (void *ptr : this->ptrs) { checkCUDA(cudaFree(ptr)); } } diff --git a/lib/kernels/test/src/test_cast_kernel.cc b/lib/kernels/test/src/test_cast_kernel.cc index 004bc9c32f..b110208bce 100644 --- a/lib/kernels/test/src/test_cast_kernel.cc +++ b/lib/kernels/test/src/test_cast_kernel.cc @@ -1,6 +1,5 @@ #include "doctest/doctest.h" #include "kernels/cast_kernels.h" -#include "kernels/cast_kernels_cpu.h" #include "test_utils.h" #include diff --git a/lib/local-execution/CMakeLists.txt b/lib/local-execution/CMakeLists.txt index 52b0f8edf7..f649f86ce3 100644 --- a/lib/local-execution/CMakeLists.txt +++ b/lib/local-execution/CMakeLists.txt @@ -14,3 +14,5 @@ ff_add_library( pcg spdlog ) + +add_subdirectory(test) diff --git a/lib/local-execution/include/local-execution/arg_ref.h b/lib/local-execution/include/local-execution/arg_ref.h index 992a7971a5..30326b0e84 100644 --- a/lib/local-execution/include/local-execution/arg_ref.h +++ b/lib/local-execution/include/local-execution/arg_ref.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LOCAL_EXECUTION_ARG_REF_H #include "kernels/ff_handle.h" -#include "local-execution/profiling.h" // #include "local-execution/serialization.h #include "utils/type_index.h" #include "utils/visitable.h" @@ -32,6 +31,14 @@ struct ArgRefSpec { return this->type_idx; } + bool operator==(ArgRefSpec const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(ArgRefSpec const &other) const { + return this->tie() != other.tie(); + } + template static ArgRefSpec create(ArgRef const &r) { // static_assert(is_serializable::value, "Type must be serializeable"); @@ -46,6 +53,10 @@ struct ArgRefSpec { std::type_index type_idx; LABEL_TYPE ref_type; + std::tuple + tie() const { + return std::tie(this->type_idx, this->ref_type); + } friend struct std::hash>; }; diff --git a/lib/local-execution/include/local-execution/concrete_arg.h b/lib/local-execution/include/local-execution/concrete_arg.h index acb5c206da..3bc2714a71 100644 --- a/lib/local-execution/include/local-execution/concrete_arg.h +++ b/lib/local-execution/include/local-execution/concrete_arg.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_CONCRETE_ARG_H #define _FLEXFLOW_LOCAL_EXECUTION_CONCRETE_ARG_H +#include "fmt/format.h" #include "local-execution/serialization.h" #include "utils/type_index.h" #include @@ -22,6 +23,9 @@ struct ConcreteArgSpec { return this->type_idx; } + bool operator==(ConcreteArgSpec const &other) const; + bool operator!=(ConcreteArgSpec const &other) const; + template static ConcreteArgSpec create(T const &t) { // static_assert(is_serializable::value, "Type must be serializable"); @@ -40,8 +44,13 @@ struct ConcreteArgSpec { std::type_index type_idx; std::shared_ptr ptr; + + std::tuple tie() const; }; +std::string format_as(ConcreteArgSpec const &); +std::ostream &operator<<(std::ostream &, ConcreteArgSpec const &); + } // namespace FlexFlow #endif diff --git a/lib/local-execution/include/local-execution/cost_details.struct.toml b/lib/local-execution/include/local-execution/cost_details.struct.toml index e0d89acfdb..d17438b9ff 100644 --- a/lib/local-execution/include/local-execution/cost_details.struct.toml +++ b/lib/local-execution/include/local-execution/cost_details.struct.toml @@ -16,4 +16,3 @@ type = "float" [[fields]] name = "total_mem_usage" type = "size_t" - diff --git a/lib/local-execution/include/local-execution/device_specific.h b/lib/local-execution/include/local-execution/device_specific.h index 25d940e089..3a36e02327 100644 --- a/lib/local-execution/include/local-execution/device_specific.h +++ b/lib/local-execution/include/local-execution/device_specific.h @@ -13,7 +13,17 @@ struct DeviceSpecific { template static DeviceSpecific create(Args &&...args) { - NOT_IMPLEMENTED(); + size_t device_idx = 0; + return DeviceSpecific(std::make_shared(std::forward(args)...), + device_idx); + } + + bool operator==(DeviceSpecific const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(DeviceSpecific const &other) const { + return this->tie() != other.tie(); } T const *get(size_t curr_device_idx) const { @@ -23,17 +33,21 @@ struct DeviceSpecific { curr_device_idx, this->device_idx); } - return this->ptr; + return (T const *)this->ptr.get(); } // TODO: can modify ptr private: - DeviceSpecific(T *ptr, size_t device_idx) + DeviceSpecific(std::shared_ptr ptr, size_t device_idx) : ptr(ptr), device_idx(device_idx) {} - T *ptr; + std::shared_ptr ptr; size_t device_idx; + + std::tuple tie() const { + return std::tie(this->ptr, this->device_idx); + } }; // manually force serialization to make DeviceSpecific trivially diff --git a/lib/local-execution/include/local-execution/device_specific_device_states.variant.toml b/lib/local-execution/include/local-execution/device_specific_device_states.variant.toml new file mode 100644 index 0000000000..5f73bbbb8e --- /dev/null +++ b/lib/local-execution/include/local-execution/device_specific_device_states.variant.toml @@ -0,0 +1,90 @@ +namespace = "FlexFlow" +name = "DeviceSpecificDeviceStates" +features = [ + "eq", +] + +includes = [ + "kernels/attention_kernels.h", + "kernels/batch_norm_kernels.h", + "kernels/conv_2d_kernels.h", + "kernels/dropout_kernels.h", + "kernels/element_binary_kernels.h", + "kernels/element_unary_kernels.h", + "kernels/gather_kernels.h", + "kernels/layer_norm_kernels.h", + "kernels/linear_kernels.h", + "kernels/partition_kernels.h", + "kernels/pool_2d_kernels.h", + "kernels/reduce_kernels.h", + "kernels/reduction_kernels.h", + "kernels/reshape_kernels.h", + "kernels/softmax_kernels.h", + "kernels/topk_kernels.h", + "kernels/transpose_kernels.h", + "local-execution/device_specific.h", +] + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::MHAPerDeviceState>" +key = "device_specific_mha_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::BatchNormPerDeviceState>" +key = "device_specific_batch_norm_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::Conv2DPerDeviceState>" +key = "device_specific_conv2d_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::DropoutPerDeviceState>" +key = "device_specific_dropout_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ElementBinaryPerDeviceState>" +key = "device_specific_element_binary_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ElementUnaryPerDeviceState>" +key = "device_specific_element_unary_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::GatherPerDeviceState>" +key = "device_specific_gather_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::LayerNormPerDeviceState>" +key = "device_specific_layer_norm_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::LinearPerDeviceState>" +key = "device_specific_linear_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::Pool2DPerDeviceState>" +key = "device_specific_pool_2d_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ReducePerDeviceState>" +key = "device_specific_reduce_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::RepartitionPerDeviceState>" +key = "device_specific_repartition_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::ReshapePerDeviceState>" +key = "device_specific_reshape_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::SoftmaxPerDeviceState>" +key = "device_specific_softmax_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::TopKPerDeviceState>" +key = "device_specific_topk_per_device_state" + +[[values]] +type = "::FlexFlow::DeviceSpecific<::FlexFlow::TransposePerDeviceState>" +key = "device_specific_transpose_per_device_state" diff --git a/lib/local-execution/include/local-execution/device_states.h b/lib/local-execution/include/local-execution/device_states.h deleted file mode 100644 index 550c64a4bd..0000000000 --- a/lib/local-execution/include/local-execution/device_states.h +++ /dev/null @@ -1,40 +0,0 @@ - -#include "kernels/attention_kernels.h" -#include "kernels/batch_norm_kernels.h" -#include "kernels/conv_2d_kernels.h" -#include "kernels/dropout_kernels.h" -#include "kernels/element_binary_kernels.h" -#include "kernels/element_unary_kernels.h" -#include "kernels/gather_kernels.h" -#include "kernels/layer_norm_kernels.h" -#include "kernels/linear_kernels.h" -#include "kernels/partition_kernels.h" -#include "kernels/pool_2d_kernels.h" -#include "kernels/reduce_kernels.h" -#include "kernels/reduction_kernels.h" -#include "kernels/reshape_kernels.h" -#include "kernels/softmax_kernels.h" -#include "kernels/topk_kernels.h" -#include "kernels/transpose_kernels.h" -#include - -namespace FlexFlow { - -using DeviceStates = std::variant; - -} diff --git a/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h b/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h new file mode 100644 index 0000000000..7f80af77f3 --- /dev/null +++ b/lib/local-execution/include/local-execution/fwd_bwd_task_impl_function.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_FWD_BWD_TASK_IMPL_FUNCTION_H + +#include "local-execution/task_argument_accessor.h" + +namespace FlexFlow { + +struct FwdBwdTaskImplFunction { + + std::optional (*function_ptr)(TaskArgumentAccessor const &); + + bool operator==(FwdBwdTaskImplFunction const &) const; + bool operator!=(FwdBwdTaskImplFunction const &) const; + bool operator<(FwdBwdTaskImplFunction const &) const; + bool operator>(FwdBwdTaskImplFunction const &) const; + bool operator<=(FwdBwdTaskImplFunction const &) const; + bool operator>=(FwdBwdTaskImplFunction const &) const; +}; + +std::string format_as(FwdBwdTaskImplFunction const &x); +std::ostream &operator<<(std::ostream &s, FwdBwdTaskImplFunction const &x); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::FwdBwdTaskImplFunction> { + size_t operator()(::FlexFlow::FwdBwdTaskImplFunction const &) const; +}; +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/init_task_impl_function.h b/lib/local-execution/include/local-execution/init_task_impl_function.h new file mode 100644 index 0000000000..b85944e13a --- /dev/null +++ b/lib/local-execution/include/local-execution/init_task_impl_function.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H +#define _FLEXFLOW_LOCAL_EXECUTION_INIT_TASK_IMPL_FUNCTION_H + +#include "local-execution/device_specific_device_states.dtg.h" +#include "local-execution/task_argument_accessor.h" + +namespace FlexFlow { + +struct InitTaskImplFunction { + + DeviceSpecificDeviceStates (*function_ptr)(TaskArgumentAccessor const &); + + bool operator==(InitTaskImplFunction const &) const; + bool operator!=(InitTaskImplFunction const &) const; + bool operator<(InitTaskImplFunction const &) const; + bool operator>(InitTaskImplFunction const &) const; + bool operator<=(InitTaskImplFunction const &) const; + bool operator>=(InitTaskImplFunction const &) const; +}; + +std::string format_as(InitTaskImplFunction const &x); +std::ostream &operator<<(std::ostream &s, InitTaskImplFunction const &x); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash<::FlexFlow::InitTaskImplFunction> { + size_t operator()(::FlexFlow::InitTaskImplFunction const &) const; +}; +} // namespace std + +#endif diff --git a/lib/local-execution/include/local-execution/local_cpu_allocator.h b/lib/local-execution/include/local-execution/local_cpu_allocator.h new file mode 100644 index 0000000000..d1e81facf2 --- /dev/null +++ b/lib/local-execution/include/local-execution/local_cpu_allocator.h @@ -0,0 +1,22 @@ +#include "kernels/allocation.h" +#include + +namespace FlexFlow { + +struct LocalCPUAllocator : public IAllocator { + LocalCPUAllocator() = default; + LocalCPUAllocator(LocalCPUAllocator const &) = delete; + LocalCPUAllocator(LocalCPUAllocator &&) = delete; + ~LocalCPUAllocator() = default; + + void *allocate(size_t) override; + void deallocate(void *) override; + +private: + std::unordered_map> ptrs; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalCPUAllocator); + +Allocator create_local_cpu_memory_allocator(); + +} // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/local_slots_backing.h b/lib/local-execution/include/local-execution/local_slots_backing.h index 46dcc3d914..6a0c28e988 100644 --- a/lib/local-execution/include/local-execution/local_slots_backing.h +++ b/lib/local-execution/include/local-execution/local_slots_backing.h @@ -3,9 +3,9 @@ #define _FLEXFLOW_LOCAL_EXECUTION_SLOT_REGISTRY_H #include "kernels/accessor.h" -#include "local-execution/device_states.h" #include "local-execution/local_task_argument_accessor.h" #include "local-execution/op_task_invocation.h" +#include "local-execution/per_device_op_state.h" #include "local-execution/runtime_arg_config.h" namespace FlexFlow { @@ -18,17 +18,24 @@ struct LocalSlotsBacking { public: void add_per_device_op_state(layer_guid_t const &, - DeviceSpecific const &); - bool is_tensor_allocated(tensor_guid_t const &) const; - GenericTensorAccessorW const &get_tensor_backing(tensor_guid_t const &, - IsGrad) const; + DeviceSpecificDeviceStates const &); + void allocate_outgoing_tensors(layer_guid_t const &, + ComputationGraph const &, + Allocator &); TensorSlotsBacking construct_tensor_slots_backing(OpTaskBinding const &, layer_guid_t const &) const; ArgSlotsBacking construct_arg_slots_backing(OpTaskBinding const &, layer_guid_t const &) const; + + ConcreteArgSpec resolve_runtime_arg_ref_spec(RuntimeArgRefSpec const &) const; ConcreteArgSpec resolve_op_arg_ref_spec(OpArgRefSpec const &, layer_guid_t const &) const; - ConcreteArgSpec resolve_runtime_arg_ref_spec(RuntimeArgRefSpec const &) const; + +private: + bool is_tensor_allocated(tensor_guid_t const &) const; + bool is_gradient_tensor_allocated(tensor_guid_t const &) const; + GenericTensorAccessorW const &get_tensor_backing(tensor_guid_t const &, + IsGrad) const; public: // tensors @@ -36,13 +43,11 @@ struct LocalSlotsBacking { TensorBackingMap gradient_tensor_mapping; std::unordered_map> input_tensor_slots; - std::unordered_map> - weight_tensor_slots; std::unordered_map> output_tensor_slots; // arguments - std::unordered_map>> + std::unordered_map per_device_op_states; RuntimeArgConfig runtime_arg_config; }; diff --git a/lib/local-execution/include/local-execution/local_task_argument_accessor.h b/lib/local-execution/include/local-execution/local_task_argument_accessor.h index 27c8af0836..1e1516a0de 100644 --- a/lib/local-execution/include/local-execution/local_task_argument_accessor.h +++ b/lib/local-execution/include/local-execution/local_task_argument_accessor.h @@ -38,6 +38,16 @@ struct LocalTaskArgumentAccessor : public ITaskArgumentAccessor { TensorSlotsBacking tensor_slots_backing; ArgSlotsBacking arg_slots_backing; }; + +using TensorSlotsBackingWithoutAddresses = std::unordered_map< + SlotGradId, + std::variant, + std::vector>>>; + +TensorSlotsBackingWithoutAddresses + get_slots_backing_without_tensor_allocation_addresses( + TensorSlotsBacking const &); + CHECK_RC_COPY_VIRTUAL_COMPLIANT(LocalTaskArgumentAccessor); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/local_training_backing.h b/lib/local-execution/include/local-execution/local_training_backing.h index 2fe3a57407..b398bb8cc3 100644 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ b/lib/local-execution/include/local-execution/local_training_backing.h @@ -14,29 +14,27 @@ struct LocalTrainingBacking { ComputationGraph const &, TensorBackingMap const &, RuntimeArgConfig const &); - ~LocalTrainingBacking() = default; void execute_init(); PerLayerElapsedTime execute_forward(); PerLayerElapsedTime execute_backward(); void execute_update(); -private: - DeviceSpecific - call_init_task_impl(task_id_t, TaskArgumentAccessor const &); - std::optional call_task_impl(task_id_t, TaskArgumentAccessor); - TaskArgumentAccessor get_task_arg_accessor(OpTaskInvocation const &, layer_guid_t const &) const; +private: + DeviceSpecificDeviceStates call_init_task_impl(task_id_t, + TaskArgumentAccessor const &); + std::optional call_task_impl(task_id_t, TaskArgumentAccessor); + +private: Allocator allocator; ComputationGraph computation_graph; TaskRegistry task_registry; LocalSlotsBacking local_slots_backing; }; -std::vector get_task_ids(ComputationGraphOpAttrs const &); - } // namespace FlexFlow #endif diff --git a/lib/local-execution/include/local-execution/op_arg_ref.h b/lib/local-execution/include/local-execution/op_arg_ref.h index 939fab21c0..20d6ccb1c5 100644 --- a/lib/local-execution/include/local-execution/op_arg_ref.h +++ b/lib/local-execution/include/local-execution/op_arg_ref.h @@ -3,30 +3,22 @@ #include "local-execution/arg_ref.h" #include "local-execution/device_specific.h" +#include "local-execution/op_arg_ref_type.dtg.h" +#include "local-execution/per_device_op_state.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -enum class OpArgRefLabel { PER_DEVICE_OP_STATE, PARALLEL_TENSOR_SHAPE }; - -struct PerDeviceOpStateRefType {}; - -struct ParallelTensorShapeRefType { - int idx; -}; - -using OpArgRefType = - std::variant; - template using OpArgRef = ArgRef; using OpArgRefSpec = ArgRefSpec; template -OpArgRef> per_device_op_state() { - OpArgRefType op_arg_ref_type = PerDeviceOpStateRefType{}; - ArgRef> arg_ref = {op_arg_ref_type}; +OpArgRef per_device_op_state() { + OpArgRefType op_arg_ref_type = OpArgRefType{PerDeviceOpStateRefType{}}; + static_assert(PerDeviceOpState::IsPartOfPerDeviceOpState_v); + ArgRef arg_ref = {op_arg_ref_type}; return arg_ref; } diff --git a/lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml b/lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml new file mode 100644 index 0000000000..cd226da161 --- /dev/null +++ b/lib/local-execution/include/local-execution/op_arg_ref_type.variant.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "OpArgRefType" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "local-execution/per_device_op_state_ref_type.dtg.h", + "local-execution/parallel_tensor_shape_ref_type.dtg.h", +] + +[[values]] +type = "::FlexFlow::PerDeviceOpStateRefType" +key = "per_device_op_state_ref_type" + +[[values]] +type = "::FlexFlow::ParallelTensorShapeRefType" +key = "parallel_tensor_shape_ref_type" diff --git a/lib/local-execution/include/local-execution/op_arg_spec.variant.toml b/lib/local-execution/include/local-execution/op_arg_spec.variant.toml index a13018e6a1..28169902ae 100644 --- a/lib/local-execution/include/local-execution/op_arg_spec.variant.toml +++ b/lib/local-execution/include/local-execution/op_arg_spec.variant.toml @@ -1,7 +1,7 @@ namespace = "FlexFlow" name = "OpArgSpec" features = [ - # "eq", + "eq", # "ord", # "hash", # "json", diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h index eafd6b80b0..73a0460554 100644 --- a/lib/local-execution/include/local-execution/op_task_invocation.h +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -11,7 +11,7 @@ #include "local-execution/profiling.h" #include "local-execution/runtime_arg_ref.h" #include "local-execution/slot_grad_id.dtg.h" -#include "local-execution/tasks.h" +#include "local-execution/task_id_t.dtg.h" #include "local-execution/variadic_tensor_ref.h" #include "op-attrs/computation_graph_op_attrs.h" #include "pcg/computation_graph.h" @@ -85,6 +85,8 @@ struct OpTaskBinding { void bind_arg(slot_id_t name, OpArgRef const &ref) { this->insert_arg_spec(name, OpArgSpec{OpArgRefSpec::create(ref)}); } + bool operator==(OpTaskBinding const &other) const; + bool operator!=(OpTaskBinding const &other) const; std::unordered_map const & get_tensor_bindings() const; @@ -93,15 +95,19 @@ struct OpTaskBinding { void bind_from_forward(OpTaskBinding const &fwd); private: - void insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec); std::unordered_map tensor_bindings; std::unordered_map arg_bindings; + +private: + void insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec); + std::tuple + tie() const; }; struct OpTaskInvocation { public: OpTaskInvocation() = delete; - OpTaskInvocation(task_id_t const &task_id, OpTaskBinding const &binding) + OpTaskInvocation(task_id_t task_id, OpTaskBinding const &binding) : task_id(task_id), binding(binding) {} public: @@ -109,11 +115,6 @@ struct OpTaskInvocation { OpTaskBinding binding; }; -OpTaskInvocation init(ComputationGraphOpAttrs const &); -OpTaskInvocation forward(ComputationGraphOpAttrs const &); -OpTaskInvocation backward(ComputationGraphOpAttrs const &); - -OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd); OpTaskBinding infer_bwd_binding(OpTaskBinding const &fwd); bool is_invocation_valid(OpTaskSignature const &sig, diff --git a/lib/local-execution/include/local-execution/op_task_signature.h b/lib/local-execution/include/local-execution/op_task_signature.h index ad5177b289..0447644354 100644 --- a/lib/local-execution/include/local-execution/op_task_signature.h +++ b/lib/local-execution/include/local-execution/op_task_signature.h @@ -7,7 +7,9 @@ #include "local-execution/serialization.h" #include "local-execution/slot_id_t.dtg.h" #include "local-execution/slot_type.dtg.h" -#include "local-execution/tasks.h" +#include "local-execution/task_id_t.dtg.h" +#include "utils/hash/unordered_map.h" +#include "utils/hash/unordered_set.h" #include "utils/type_index.h" #include "utils/visitable.h" @@ -40,10 +42,9 @@ struct OpTaskSignature { void add_output_slot(int, SlotType slot_type = SlotType::TENSOR); void add_output_slot(slot_id_t, SlotType slot_type = SlotType::TENSOR); - void add_bwd_necessary_output_slot(int, - SlotType slot_type = SlotType::TENSOR); - void add_bwd_necessary_output_slot(slot_id_t, - SlotType slot_type = SlotType::TENSOR); + void add_bwd_optional_output_slot(int, SlotType slot_type = SlotType::TENSOR); + void add_bwd_optional_output_slot(slot_id_t, + SlotType slot_type = SlotType::TENSOR); void add_weight_slot(int, SlotType slot_type = SlotType::TENSOR); void add_weight_slot(slot_id_t, SlotType slot_type = SlotType::TENSOR); @@ -96,27 +97,10 @@ struct OpTaskSignature { FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION( OpTaskSignature, type, return_value, task_arg_types, op_tensor_slots); -template -void register_task(task_id_t, - std::string const &name, - OpTaskSignature const &, - F const &func); +std::string format_as(OpTaskSignature const &x); +std::ostream &operator<<(std::ostream &s, OpTaskSignature const &x); -template -void register_task(task_id_t, - std::string const &name, - OpTaskSignature const &, - F const &func, - F const &cpu_func); - -template -OpTaskSignature init_signature(); - -template -OpTaskSignature fwd_signature(); - -template -OpTaskSignature bwd_signature(); +OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd); } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/attention.h b/lib/local-execution/include/local-execution/ops/attention.h similarity index 100% rename from lib/local-execution/src/ops/attention.h rename to lib/local-execution/include/local-execution/ops/attention.h diff --git a/lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml b/lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml new file mode 100644 index 0000000000..fe340f4451 --- /dev/null +++ b/lib/local-execution/include/local-execution/parallel_tensor_shape_ref_type.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParallelTensorShapeRefType" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +[[fields]] +name = "idx" +type = "int" diff --git a/lib/local-execution/include/local-execution/per_device_op_state.h b/lib/local-execution/include/local-execution/per_device_op_state.h new file mode 100644 index 0000000000..1edd5b6360 --- /dev/null +++ b/lib/local-execution/include/local-execution/per_device_op_state.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_PER_DEVICE_STATE_H +#define _FLEXFLOW_LOCAL_EXECUTION_PER_DEVICE_STATE_H + +#include "local-execution/device_specific_device_states.dtg.h" +#include "local-execution/per_device_op_state.dtg.h" + +namespace FlexFlow { + +PerDeviceOpState + get_device_state_from_device_specific(DeviceSpecificDeviceStates const &, + size_t device_idx); + +} + +#endif diff --git a/lib/local-execution/include/local-execution/per_device_op_state.variant.toml b/lib/local-execution/include/local-execution/per_device_op_state.variant.toml new file mode 100644 index 0000000000..f99ff10bb9 --- /dev/null +++ b/lib/local-execution/include/local-execution/per_device_op_state.variant.toml @@ -0,0 +1,87 @@ +namespace = "FlexFlow" +name = "PerDeviceOpState" +features = [] + +includes = [ + "kernels/attention_kernels.h", + "kernels/batch_norm_kernels.h", + "kernels/conv_2d_kernels.h", + "kernels/dropout_kernels.h", + "kernels/element_binary_kernels.h", + "kernels/element_unary_kernels.h", + "kernels/gather_kernels.h", + "kernels/layer_norm_kernels.h", + "kernels/linear_kernels.h", + "kernels/partition_kernels.h", + "kernels/pool_2d_kernels.h", + "kernels/reduce_kernels.h", + "kernels/reduction_kernels.h", + "kernels/reshape_kernels.h", + "kernels/softmax_kernels.h", + "kernels/topk_kernels.h", + "kernels/transpose_kernels.h", +] + +[[values]] +type = "::FlexFlow::MHAPerDeviceState" +key = "mha_per_device_state" + +[[values]] +type = "::FlexFlow::BatchNormPerDeviceState" +key = "batch_norm_per_device_state" + +[[values]] +type = "::FlexFlow::Conv2DPerDeviceState" +key = "conv2d_per_device_state" + +[[values]] +type = "::FlexFlow::DropoutPerDeviceState" +key = "dropout_per_device_state" + +[[values]] +type = "::FlexFlow::ElementBinaryPerDeviceState" +key = "element_binary_per_device_state" + +[[values]] +type = "::FlexFlow::ElementUnaryPerDeviceState" +key = "element_unary_per_device_state" + +[[values]] +type = "::FlexFlow::GatherPerDeviceState" +key = "gather_per_device_state" + +[[values]] +type = "::FlexFlow::LayerNormPerDeviceState" +key = "layer_norm_per_device_state" + +[[values]] +type = "::FlexFlow::LinearPerDeviceState" +key = "linear_per_device_state" + +[[values]] +type = "::FlexFlow::Pool2DPerDeviceState" +key = "pool_2d_per_device_state" + +[[values]] +type = "::FlexFlow::ReducePerDeviceState" +key = "reduce_per_device_state" + +[[values]] +type = "::FlexFlow::RepartitionPerDeviceState" +key = "repartition_per_device_state" + +[[values]] +type = "::FlexFlow::ReshapePerDeviceState" +key = "reshape_per_device_state" + +[[values]] +type = "::FlexFlow::SoftmaxPerDeviceState" +key = "softmax_per_device_state" + +[[values]] +type = "::FlexFlow::TopKPerDeviceState" +key = "topk_per_device_state" + +[[values]] +type = "::FlexFlow::TransposePerDeviceState" +key = "transpose_per_device_state" diff --git a/lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml b/lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml new file mode 100644 index 0000000000..e3d48a02ee --- /dev/null +++ b/lib/local-execution/include/local-execution/per_device_op_state_ref_type.struct.toml @@ -0,0 +1,12 @@ +namespace = "FlexFlow" +name = "PerDeviceOpStateRefType" + +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +fields = [] diff --git a/lib/local-execution/include/local-execution/runtime_arg_ref.h b/lib/local-execution/include/local-execution/runtime_arg_ref.h index 470f02929b..279d854a27 100644 --- a/lib/local-execution/include/local-execution/runtime_arg_ref.h +++ b/lib/local-execution/include/local-execution/runtime_arg_ref.h @@ -4,6 +4,7 @@ #include "local-execution/arg_ref.h" #include "local-execution/config.h" #include "local-execution/device_specific.h" +#include "local-execution/profiling.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/task_argument_accessor.h b/lib/local-execution/include/local-execution/task_argument_accessor.h index 7a84bfb5c3..54c8dfc5f1 100644 --- a/lib/local-execution/include/local-execution/task_argument_accessor.h +++ b/lib/local-execution/include/local-execution/task_argument_accessor.h @@ -1,19 +1,27 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_ARGUMENT_ACCESSOR_H #define _FLEXFLOW_LOCAL_EXECUTION_TASK_ARGUMENT_ACCESSOR_H +#include "local-execution/device_specific.h" #include "local-execution/itask_argument_accessor.h" +#include "local-execution/per_device_op_state.dtg.h" namespace FlexFlow { struct TaskArgumentAccessor { template - T const &get_argument(int slot) const { - return this->get_argument(slot_id_t{slot}); + T const &get_argument(slot_id_t slot) const { + if constexpr (PerDeviceOpState::IsPartOfPerDeviceOpState_v) { + PerDeviceOpState device_states = + this->ptr->get_concrete_arg(slot).get(); + return device_states.get(); + } else { + return this->ptr->get_concrete_arg(slot).get(); + } } template - T const &get_argument(slot_id_t slot) const { - return this->ptr->get_concrete_arg(slot).get(); + T const &get_argument(int slot) const { + return this->get_argument(slot_id_t{slot}); } template diff --git a/lib/local-execution/include/local-execution/task_id_t.enum.toml b/lib/local-execution/include/local-execution/task_id_t.enum.toml new file mode 100644 index 0000000000..9cbe64c268 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_id_t.enum.toml @@ -0,0 +1,431 @@ +namespace = "FlexFlow" +name = "task_id_t" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "TOP_LEVEL_TASK_ID" + +[[values]] +name = "FF_INIT_TASK_ID" + +[[values]] +name = "IMAGE_INIT_TASK_ID" + +[[values]] +name = "LABEL_INIT_TASK_ID" + +[[values]] +name = "LOAD_IMAGES_TASK_ID" + +[[values]] +name = "NORMALIZE_IMAGES_TASK_ID" + +[[values]] +name = "ELEMENTBINARY_INIT_TASK_ID" + +[[values]] +name = "ELEMENTBINARY_FWD_TASK_ID" + +[[values]] +name = "ELEMENTBINARY_BWD_TASK_ID" + +[[values]] +name = "ELEMENTUNARY_INIT_TASK_ID" + +[[values]] +name = "ELEMENTUNARY_FWD_TASK_ID" + +[[values]] +name = "ELEMENTUNARY_BWD_TASK_ID" + +[[values]] +name = "CONV2D_INIT_TASK_ID" + +[[values]] +name = "CONV2D_FWD_TASK_ID" + +[[values]] +name = "CONV2D_BWD_TASK_ID" + +[[values]] +name = "DROPOUT_INIT_TASK_ID" + +[[values]] +name = "DROPOUT_FWD_TASK_ID" + +[[values]] +name = "DROPOUT_BWD_TASK_ID" + +[[values]] +name = "EMBED_INIT_TASK_ID" + +[[values]] +name = "EMBED_FWD_TASK_ID" + +[[values]] +name = "EMBED_BWD_TASK_ID" + +[[values]] +name = "GATHER_INIT_TASK_ID" + +[[values]] +name = "GATHER_FWD_TASK_ID" + +[[values]] +name = "GATHER_BWD_TASK_ID" + +[[values]] +name = "CAST_INIT_TASK_ID" + +[[values]] +name = "CAST_FWD_TASK_ID" + +[[values]] +name = "CAST_BWD_TASK_ID" + +[[values]] +name = "POOL2D_INIT_TASK_ID" + +[[values]] +name = "POOL2D_FWD_TASK_ID" + +[[values]] +name = "POOL2D_BWD_TASK_ID" + +[[values]] +name = "BATCHNORM_INIT_TASK_ID" + +[[values]] +name = "BATCHNORM_FWD_TASK_ID" + +[[values]] +name = "BATCHNORM_BWD_TASK_ID" + +[[values]] +name = "BATCHMATMUL_INIT_TASK_ID" + +[[values]] +name = "BATCHMATMUL_FWD_TASK_ID" + +[[values]] +name = "BATCHMATMUL_BWD_TASK_ID" + +[[values]] +name = "LAYERNORM_INIT_TASK_ID" + +[[values]] +name = "LAYERNORM_FWD_TASK_ID" + +[[values]] +name = "LAYERNORM_BWD_TASK_ID" + +[[values]] +name = "LINEAR_INIT_TASK_ID" + +[[values]] +name = "LINEAR_FWD_TASK_ID" + +[[values]] +name = "LINEAR_BWD_TASK_ID" + +[[values]] +name = "FLAT_INIT_TASK_ID" + +[[values]] +name = "FLAT_FWD_TASK_ID" + +[[values]] +name = "FLAT_BWD_TASK_ID" + +[[values]] +name = "SOFTMAX_INIT_TASK_ID" + +[[values]] +name = "SOFTMAX_FWD_TASK_ID" + +[[values]] +name = "SOFTMAX_BWD_TASK_ID" + +[[values]] +name = "CONCAT_INIT_TASK_ID" + +[[values]] +name = "CONCAT_FWD_TASK_ID" + +[[values]] +name = "CONCAT_BWD_TASK_ID" + +[[values]] +name = "SPLIT_INIT_TASK_ID" + +[[values]] +name = "SPLIT_FWD_TASK_ID" + +[[values]] +name = "SPLIT_BWD_TASK_ID" + +[[values]] +name = "REDUCE_INIT_TASK_ID" + +[[values]] +name = "REDUCE_FWD_TASK_ID" + +[[values]] +name = "REDUCE_BWD_TASK_ID" + +[[values]] +name = "RESHAPE_INIT_TASK_ID" + +[[values]] +name = "RESHAPE_FWD_TASK_ID" + +[[values]] +name = "RESHAPE_BWD_TASK_ID" + +[[values]] +name = "REVERSE_INIT_TASK_ID" + +[[values]] +name = "REVERSE_FWD_TASK_ID" + +[[values]] +name = "REVERSE_BWD_TASK_ID" + +[[values]] +name = "TOPK_INIT_TASK_ID" + +[[values]] +name = "TOPK_FWD_TASK_ID" + +[[values]] +name = "TOPK_BWD_TASK_ID" + +[[values]] +name = "TRANSPOSE_INIT_TASK_ID" + +[[values]] +name = "TRANSPOSE_FWD_TASK_ID" + +[[values]] +name = "TRANSPOSE_BWD_TASK_ID" + +[[values]] +name = "ATTENTION_INIT_TASK_ID" + +[[values]] +name = "ATTENTION_FWD_TASK_ID" + +[[values]] +name = "ATTENTION_BWD_TASK_ID" + +[[values]] +name = "MSELOSS_BWD_TASK_ID" + +[[values]] +name = "FUSEDOP_INIT_TASK_ID" + +[[values]] +name = "FUSEDOP_FWD_TASK_ID" + +[[values]] +name = "FUSEDOP_BWD_TASK_ID" + +[[values]] +name = "NOOP_INIT_TASK_ID" + +[[values]] +name = "METRICS_COMP_TASK_ID" + +[[values]] +name = "UPDATE_METRICS_TASK_ID" + +[[values]] +name = "PS_PREFETCH_TASK_ID" + +[[values]] +name = "LOSS_BWD_TASK_ID" + +[[values]] +name = "SGD_UPD_PS_TASK_ID" + +[[values]] +name = "ADAM_UPD_PS_TASK_ID" + +[[values]] +name = "SGD_UPD_NCCL_TASK_ID" + +[[values]] +name = "ADAM_UPD_NCCL_TASK_ID" + +[[values]] +name = "GLOROT_INIT_TASK_ID" + +[[values]] +name = "ZERO_INIT_TASK_ID" + +[[values]] +name = "CONSTANT_INIT_TASK_ID" + +[[values]] +name = "UNIFORM_INIT_TASK_ID" + +[[values]] +name = "NORMAL_INIT_TASK_ID" + +[[values]] +name = "NCCL_GETUNIQUEID_TASK_ID" + +[[values]] +name = "NCCL_INIT_COMMS_TASK_ID" + +[[values]] +name = "STRATEGY_SEARCH_TASK_ID" + +[[values]] +name = "GRAPH_OPTIMIZE_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_INDEX_LOAD_ENTIRE_CPU_TASK_ID" + +[[values]] +name = "PY_DL_FLOAT_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "PY_DL_INT32_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "PY_DL_INT64_LOAD_BATCH_GPU_TASK_ID" + +[[values]] +name = "REPARTITION_INIT_TASK_ID" + +[[values]] +name = "REPARTITION_FWD_TASK_ID" + +[[values]] +name = "REPARTITION_BWD_TASK_ID" + +[[values]] +name = "COMBINE_INIT_TASK_ID" + +[[values]] +name = "COMBINE_FWD_TASK_ID" + +[[values]] +name = "COMBINE_BWD_TASK_ID" + +[[values]] +name = "REPLICATE_INIT_TASK_ID" + +[[values]] +name = "REPLICATE_FWD_TASK_ID" + +[[values]] +name = "REPLICATE_BWD_TASK_ID" + +[[values]] +name = "REDUCTION_INIT_TASK_ID" + +[[values]] +name = "REDUCTION_FWD_TASK_ID" + +[[values]] +name = "REDUCTION_BWD_TASK_ID" + +[[values]] +name = "PIPELINE_INIT_TASK_ID" + +[[values]] +name = "PIPELINE_FWD_TASK_ID" + +[[values]] +name = "PIPELINE_BWD_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_INIT_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_FWD_TASK_ID" + +[[values]] +name = "FUSED_PARALLELOP_BWD_TASK_ID" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_FIRST" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_1" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_2" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_3" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_4" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_5" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_6" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_7" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_8" + +[[values]] +name = "CUSTOM_GPU_TASK_ID_LAST" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_FIRST" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_1" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_2" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_3" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_4" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_5" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_6" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_7" + +[[values]] +name = "CUSTOM_CPU_TASK_ID_LAST" + +[[values]] +name = "PYTHON_TOP_LEVEL_TASK_ID" diff --git a/lib/local-execution/include/local-execution/task_impl_function.variant.toml b/lib/local-execution/include/local-execution/task_impl_function.variant.toml new file mode 100644 index 0000000000..a12be37da2 --- /dev/null +++ b/lib/local-execution/include/local-execution/task_impl_function.variant.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TaskImplFunction" +features = [ + "eq", + "fmt", + "hash", + "ord" +] + +includes = [ + "local-execution/init_task_impl_function.h", + "local-execution/fwd_bwd_task_impl_function.h", +] + +[[values]] +type = "::FlexFlow::InitTaskImplFunction" +key = "init_task_impl_function" + +[[values]] +type = "::FlexFlow::FwdBwdTaskImplFunction" +key = "fwd_bwd_task_impl_function" diff --git a/lib/local-execution/include/local-execution/task_registry.h b/lib/local-execution/include/local-execution/task_registry.h index 01b7d29b36..e00cc183da 100644 --- a/lib/local-execution/include/local-execution/task_registry.h +++ b/lib/local-execution/include/local-execution/task_registry.h @@ -2,25 +2,16 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_REGISTRY_H #define _FLEXFLOW_LOCAL_EXECUTION_TASK_REGISTRY_H -#include "local-execution/op_task_invocation.h" -#include "local-execution/task_signature_impl.h" -#include "op-attrs/operator_attrs.h" -#include "pcg/computation_graph.h" +#include "local-execution/task_registry.dtg.h" +#include "op-attrs/computation_graph_op_attrs.h" namespace FlexFlow { -struct TaskRegistry { - TaskRegistry() = default; +TaskRegistry empty_task_registry(); - void register_task(task_id_t const &, - layer_guid_t const &, - ComputationGraphOpAttrs const &attrs); - - std::unordered_map init_task_ids; - std::unordered_map forward_task_ids; - std::unordered_map backward_task_ids; - std::unordered_map task_mapping; -}; +void register_tasks_for_layer(TaskRegistry &, + layer_guid_t const &, + ComputationGraphOpAttrs const &attrs); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/task_registry.struct.toml b/lib/local-execution/include/local-execution/task_registry.struct.toml new file mode 100644 index 0000000000..308527efac --- /dev/null +++ b/lib/local-execution/include/local-execution/task_registry.struct.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "TaskRegistry" +features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "local-execution/task_signature_impl.dtg.h", + "local-execution/task_id_t.dtg.h", + "pcg/layer_guid_t.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "init_task_ids" +type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::task_id_t>>" + +[[fields]] +name = "forward_task_ids" +type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::task_id_t>>" + +[[fields]] +name = "backward_task_ids" +type = "std::unordered_map<::FlexFlow::layer_guid_t, std::optional<::FlexFlow::task_id_t>>" + +[[fields]] +name = "task_mapping" +type = "std::unordered_map<::FlexFlow::task_id_t, ::FlexFlow::TaskSignatureAndImpl>" diff --git a/lib/local-execution/include/local-execution/task_signature_impl.h b/lib/local-execution/include/local-execution/task_signature_impl.h index 659ccb23d6..98c5c0cb3b 100644 --- a/lib/local-execution/include/local-execution/task_signature_impl.h +++ b/lib/local-execution/include/local-execution/task_signature_impl.h @@ -1,24 +1,19 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_IMPL_H #define _FLEXFLOW_LOCAL_EXECUTION_TASK_SIGNATURE_IMPL_H -#include "local-execution/device_specific.h" -#include "local-execution/device_states.h" -#include "local-execution/tasks.h" -#include "task_argument_accessor.h" -#include "utils/variant.h" +#include "local-execution/op_task_invocation.h" +#include "local-execution/task_id_t.dtg.h" +#include "local-execution/task_signature_impl.dtg.h" +#include "op-attrs/computation_graph_op_attrs.h" namespace FlexFlow { -using TaskImplFunction = std::variant< - std::function(TaskArgumentAccessor const &)>, - std::function(TaskArgumentAccessor const &)>>; - -struct TaskSignatureAndImpl { - TaskImplFunction impl_function; - OpTaskSignature task_signature; -}; - TaskSignatureAndImpl get_task_sig_impl(task_id_t const &); +std::vector get_task_ids(ComputationGraphOpAttrs const &); + +OpTaskInvocation init(ComputationGraphOpAttrs const &); +OpTaskInvocation forward(ComputationGraphOpAttrs const &); +OpTaskInvocation backward(ComputationGraphOpAttrs const &); } // namespace FlexFlow diff --git a/lib/local-execution/include/local-execution/task_signature_impl.struct.toml b/lib/local-execution/include/local-execution/task_signature_impl.struct.toml new file mode 100644 index 0000000000..981794503b --- /dev/null +++ b/lib/local-execution/include/local-execution/task_signature_impl.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "TaskSignatureAndImpl" +features = [ + "eq", + "fmt", + "hash" +] + +includes = [ + "local-execution/task_impl_function.dtg.h", + "local-execution/op_task_signature.h", +] + +[[fields]] +name = "impl_function" +type = "::FlexFlow::TaskImplFunction" + +[[fields]] +name = "task_signature" +type = "::FlexFlow::OpTaskSignature" diff --git a/lib/local-execution/include/local-execution/tasks.h b/lib/local-execution/include/local-execution/tasks.h index a747c95383..4f5b26c43b 100644 --- a/lib/local-execution/include/local-execution/tasks.h +++ b/lib/local-execution/include/local-execution/tasks.h @@ -1,169 +1,13 @@ #ifndef _FLEXFLOW_LOCAL_EXECUTION_TASKS_H #define _FLEXFLOW_LOCAL_EXECUTION_TASKS_H +#include "local-execution/task_id_t.dtg.h" #include #include #include namespace FlexFlow { - -enum task_id_t { - TOP_LEVEL_TASK_ID, - FF_INIT_TASK_ID, - IMAGE_INIT_TASK_ID, - LABEL_INIT_TASK_ID, - LOAD_IMAGES_TASK_ID, - NORMALIZE_IMAGES_TASK_ID, - ELEMENTBINARY_INIT_TASK_ID, - ELEMENTBINARY_FWD_TASK_ID, - ELEMENTBINARY_BWD_TASK_ID, - ELEMENTUNARY_INIT_TASK_ID, - ELEMENTUNARY_FWD_TASK_ID, - ELEMENTUNARY_BWD_TASK_ID, - CONV2D_INIT_TASK_ID, - CONV2D_FWD_TASK_ID, - CONV2D_BWD_TASK_ID, - DROPOUT_INIT_TASK_ID, - DROPOUT_FWD_TASK_ID, - DROPOUT_BWD_TASK_ID, - EMBED_INIT_TASK_ID, - EMBED_FWD_TASK_ID, - EMBED_BWD_TASK_ID, - GATHER_INIT_TASK_ID, - GATHER_FWD_TASK_ID, - GATHER_BWD_TASK_ID, - CAST_INIT_TASK_ID, - CAST_FWD_TASK_ID, - CAST_BWD_TASK_ID, - POOL2D_INIT_TASK_ID, - POOL2D_FWD_TASK_ID, - POOL2D_BWD_TASK_ID, - BATCHNORM_INIT_TASK_ID, - BATCHNORM_FWD_TASK_ID, - BATCHNORM_BWD_TASK_ID, - BATCHMATMUL_INIT_TASK_ID, - BATCHMATMUL_FWD_TASK_ID, - BATCHMATMUL_BWD_TASK_ID, - LAYERNORM_INIT_TASK_ID, - LAYERNORM_FWD_TASK_ID, - LAYERNORM_BWD_TASK_ID, - LINEAR_INIT_TASK_ID, - LINEAR_FWD_TASK_ID, - LINEAR_BWD_TASK_ID, - FLAT_INIT_TASK_ID, - FLAT_FWD_TASK_ID, - FLAT_BWD_TASK_ID, - SOFTMAX_INIT_TASK_ID, - SOFTMAX_FWD_TASK_ID, - SOFTMAX_BWD_TASK_ID, - CONCAT_INIT_TASK_ID, - CONCAT_FWD_TASK_ID, - CONCAT_BWD_TASK_ID, - SPLIT_INIT_TASK_ID, - SPLIT_FWD_TASK_ID, - SPLIT_BWD_TASK_ID, - REDUCE_INIT_TASK_ID, - REDUCE_FWD_TASK_ID, - REDUCE_BWD_TASK_ID, - RESHAPE_INIT_TASK_ID, - RESHAPE_FWD_TASK_ID, - RESHAPE_BWD_TASK_ID, - REVERSE_INIT_TASK_ID, - REVERSE_FWD_TASK_ID, - REVERSE_BWD_TASK_ID, - TOPK_INIT_TASK_ID, - TOPK_FWD_TASK_ID, - TOPK_BWD_TASK_ID, - TRANSPOSE_INIT_TASK_ID, - TRANSPOSE_FWD_TASK_ID, - TRANSPOSE_BWD_TASK_ID, - ATTENTION_INIT_TASK_ID, - ATTENTION_FWD_TASK_ID, - ATTENTION_BWD_TASK_ID, - MSELOSS_BWD_TASK_ID, - FUSEDOP_INIT_TASK_ID, - FUSEDOP_FWD_TASK_ID, - FUSEDOP_BWD_TASK_ID, - NOOP_INIT_TASK_ID, - // Metrics tasks - METRICS_COMP_TASK_ID, - UPDATE_METRICS_TASK_ID, - // Parameter server prefetch task - PS_PREFETCH_TASK_ID, - // Loss - LOSS_BWD_TASK_ID, - // Optimizer with PS - SGD_UPD_PS_TASK_ID, - ADAM_UPD_PS_TASK_ID, - // Optimizer with NCCL - SGD_UPD_NCCL_TASK_ID, - ADAM_UPD_NCCL_TASK_ID, - // Initializer - GLOROT_INIT_TASK_ID, - ZERO_INIT_TASK_ID, - CONSTANT_INIT_TASK_ID, - UNIFORM_INIT_TASK_ID, - NORMAL_INIT_TASK_ID, - // NCCL tasks - NCCL_GETUNIQUEID_TASK_ID, - NCCL_INIT_COMMS_TASK_ID, - // Search - STRATEGY_SEARCH_TASK_ID, - // Graph - GRAPH_OPTIMIZE_TASK_ID, - // Python data loader - PY_DL_FLOAT_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT32_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT64_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_FLOAT_INDEX_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT32_INDEX_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_INT64_INDEX_LOAD_ENTIRE_CPU_TASK_ID, - PY_DL_FLOAT_LOAD_BATCH_GPU_TASK_ID, - PY_DL_INT32_LOAD_BATCH_GPU_TASK_ID, - PY_DL_INT64_LOAD_BATCH_GPU_TASK_ID, - // Parallel Ops - REPARTITION_INIT_TASK_ID, - REPARTITION_FWD_TASK_ID, - REPARTITION_BWD_TASK_ID, - COMBINE_INIT_TASK_ID, - COMBINE_FWD_TASK_ID, - COMBINE_BWD_TASK_ID, - REPLICATE_INIT_TASK_ID, - REPLICATE_FWD_TASK_ID, - REPLICATE_BWD_TASK_ID, - REDUCTION_INIT_TASK_ID, - REDUCTION_FWD_TASK_ID, - REDUCTION_BWD_TASK_ID, - PIPELINE_INIT_TASK_ID, - PIPELINE_FWD_TASK_ID, - PIPELINE_BWD_TASK_ID, - FUSED_PARALLELOP_INIT_TASK_ID, - FUSED_PARALLELOP_FWD_TASK_ID, - FUSED_PARALLELOP_BWD_TASK_ID, - // Custom tasks - CUSTOM_GPU_TASK_ID_FIRST, - CUSTOM_GPU_TASK_ID_1, - CUSTOM_GPU_TASK_ID_2, - CUSTOM_GPU_TASK_ID_3, - CUSTOM_GPU_TASK_ID_4, - CUSTOM_GPU_TASK_ID_5, - CUSTOM_GPU_TASK_ID_6, - CUSTOM_GPU_TASK_ID_7, - CUSTOM_GPU_TASK_ID_8, - CUSTOM_GPU_TASK_ID_LAST, - CUSTOM_CPU_TASK_ID_FIRST, - CUSTOM_CPU_TASK_ID_1, - CUSTOM_CPU_TASK_ID_2, - CUSTOM_CPU_TASK_ID_3, - CUSTOM_CPU_TASK_ID_4, - CUSTOM_CPU_TASK_ID_5, - CUSTOM_CPU_TASK_ID_6, - CUSTOM_CPU_TASK_ID_7, - CUSTOM_CPU_TASK_ID_LAST, - // Make sure PYTHON_TOP_LEVEL_TASK_ID is - // consistent with python/main.cc - PYTHON_TOP_LEVEL_TASK_ID = 11111, -}; +// PYTHON_TOP_LEVEL_TASK_ID = 11111, void register_flexflow_internal_tasks(); diff --git a/lib/local-execution/include/local-execution/tracked_allocator.h b/lib/local-execution/include/local-execution/tracked_allocator.h index ae7bd076ce..731e04fdc8 100644 --- a/lib/local-execution/include/local-execution/tracked_allocator.h +++ b/lib/local-execution/include/local-execution/tracked_allocator.h @@ -17,6 +17,7 @@ struct TrackedAllocator : public IAllocator { private: size_t current_mem_usage = 0; + std::unordered_map ptr_mem_usage; Allocator allocator; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(TrackedAllocator); diff --git a/lib/local-execution/src/concrete_arg.cc b/lib/local-execution/src/concrete_arg.cc new file mode 100644 index 0000000000..450d663e17 --- /dev/null +++ b/lib/local-execution/src/concrete_arg.cc @@ -0,0 +1,29 @@ +#include "local-execution/concrete_arg.h" + +namespace FlexFlow { + +bool ConcreteArgSpec::operator==(ConcreteArgSpec const &other) const { + return this->tie() == other.tie(); +} + +bool ConcreteArgSpec::operator!=(ConcreteArgSpec const &other) const { + return this->tie() != other.tie(); +} + +std::tuple ConcreteArgSpec::tie() const { + return std::tie(this->type_idx); +} + +std::string format_as(ConcreteArgSpec const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +std::ostream &operator<<(std::ostream &s, ConcreteArgSpec const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/fwd_bwd_task_impl_function.cc b/lib/local-execution/src/fwd_bwd_task_impl_function.cc new file mode 100644 index 0000000000..f85d7cec61 --- /dev/null +++ b/lib/local-execution/src/fwd_bwd_task_impl_function.cc @@ -0,0 +1,54 @@ +#include "local-execution/fwd_bwd_task_impl_function.h" + +namespace FlexFlow { + +bool FwdBwdTaskImplFunction::operator==( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr == other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator!=( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr != other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator<( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr < other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator>( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr > other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator<=( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr <= other.function_ptr; +} + +bool FwdBwdTaskImplFunction::operator>=( + FwdBwdTaskImplFunction const &other) const { + return this->function_ptr >= other.function_ptr; +} + +std::string format_as(FwdBwdTaskImplFunction const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +std::ostream &operator<<(std::ostream &s, FwdBwdTaskImplFunction const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::FwdBwdTaskImplFunction const &x) const { + return std::hash{}(x.function_ptr); +} +} // namespace std diff --git a/lib/local-execution/src/init_task_impl_function.cc b/lib/local-execution/src/init_task_impl_function.cc new file mode 100644 index 0000000000..9501f72dd6 --- /dev/null +++ b/lib/local-execution/src/init_task_impl_function.cc @@ -0,0 +1,47 @@ +#include "local-execution/init_task_impl_function.h" + +namespace FlexFlow { + +bool InitTaskImplFunction::operator==(InitTaskImplFunction const &other) const { + return this->function_ptr == other.function_ptr; +} + +bool InitTaskImplFunction::operator!=(InitTaskImplFunction const &other) const { + return this->function_ptr != other.function_ptr; +} + +bool InitTaskImplFunction::operator<(InitTaskImplFunction const &other) const { + return this->function_ptr < other.function_ptr; +} + +bool InitTaskImplFunction::operator>(InitTaskImplFunction const &other) const { + return this->function_ptr > other.function_ptr; +} + +bool InitTaskImplFunction::operator<=(InitTaskImplFunction const &other) const { + return this->function_ptr <= other.function_ptr; +} + +bool InitTaskImplFunction::operator>=(InitTaskImplFunction const &other) const { + return this->function_ptr >= other.function_ptr; +} + +std::string format_as(InitTaskImplFunction const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, InitTaskImplFunction const &x) { + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace std { +size_t hash::operator()( + ::FlexFlow::InitTaskImplFunction const &x) const { + return std::hash{}(x.function_ptr); +} +} // namespace std diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc new file mode 100644 index 0000000000..b3a045bab4 --- /dev/null +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -0,0 +1,13 @@ +#include "local-execution/legion_tensor_shape.h" + +namespace FlexFlow { + +legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, size_t num_dims) { + return legion_dim_t(num_dims - ff_dim.value - 1); +} + +legion_dim_t legion_dim_from_ff_dim(ff_dim_t ff_dim, TensorShape const &shape) { + return legion_dim_t(num_dims(shape) - ff_dim.value - 1); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index be3dfb01aa..d4e0467cbf 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -32,7 +32,8 @@ CostDetails LocalCostEstimator::estimate_cost( std::vector const &outputs, MachineView const &mv) const { - if (is_parallel_op(op) || op.has() || op.has()) { + if (is_parallel_op(op) || op.has() || op.has() || + op.has()) { return CostDetails{0, 0}; } diff --git a/lib/local-execution/src/local_cpu_allocator.cc b/lib/local-execution/src/local_cpu_allocator.cc new file mode 100644 index 0000000000..4ca5f987a8 --- /dev/null +++ b/lib/local-execution/src/local_cpu_allocator.cc @@ -0,0 +1,24 @@ +#include "local-execution/local_cpu_allocator.h" +#include "utils/containers/contains_key.h" + +namespace FlexFlow { +void *LocalCPUAllocator::allocate(size_t requested_memory_size) { + void *ptr = malloc(requested_memory_size); + this->ptrs.insert({ptr, std::unique_ptr(ptr, free)}); + return ptr; +} + +void LocalCPUAllocator::deallocate(void *ptr) { + if (contains_key(this->ptrs, ptr)) { + this->ptrs.erase(ptr); + } else { + throw std::runtime_error( + "Deallocating a pointer that was not allocated by this Allocator"); + } +} + +Allocator create_local_cpu_memory_allocator() { + return Allocator::create(); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/local_slots_backing.cc b/lib/local-execution/src/local_slots_backing.cc index a07799def6..0ec9068c6a 100644 --- a/lib/local-execution/src/local_slots_backing.cc +++ b/lib/local-execution/src/local_slots_backing.cc @@ -11,22 +11,61 @@ LocalSlotsBacking::LocalSlotsBacking(TensorBackingMap const &allocated_tensors, void LocalSlotsBacking::add_per_device_op_state( layer_guid_t const &op_guid, - DeviceSpecific const &device_state) { + DeviceSpecificDeviceStates const &device_state) { this->per_device_op_states.insert({op_guid, device_state}); } +void LocalSlotsBacking::allocate_outgoing_tensors( + layer_guid_t const &layer_guid, + ComputationGraph const &computation_graph, + Allocator &allocator) { + std::vector incoming_tensors = + get_incoming_tensors(computation_graph, layer_guid); + std::vector outgoing_tensors = + get_outgoing_tensors(computation_graph, layer_guid); + for (tensor_guid_t const &output_tensor : outgoing_tensors) { + TensorAttrs tensor_attrs = + get_tensor_attrs(computation_graph, output_tensor); + // tensor allocation + if (!is_tensor_allocated(output_tensor)) { + GenericTensorAccessorW tensor_backing = + allocator.allocate_tensor(tensor_attrs.shape); + this->tensor_mapping.insert({output_tensor, tensor_backing}); + } + + // gradient tensor allocation + if (tensor_attrs.create_gradients == CreateGrad::YES && + !is_gradient_tensor_allocated(output_tensor)) { + GenericTensorAccessorW gradient_tensor_backing = + allocator.allocate_tensor(tensor_attrs.shape); + this->gradient_tensor_mapping.insert( + {output_tensor, gradient_tensor_backing}); + } + } + + this->input_tensor_slots.insert({layer_guid, incoming_tensors}); + this->output_tensor_slots.insert({layer_guid, outgoing_tensors}); +} + bool LocalSlotsBacking::is_tensor_allocated( tensor_guid_t const &tensor_id) const { return contains_key(this->tensor_mapping, tensor_id); } +bool LocalSlotsBacking::is_gradient_tensor_allocated( + tensor_guid_t const &tensor_id) const { + return contains_key(this->gradient_tensor_mapping, tensor_id); +} + GenericTensorAccessorW const & LocalSlotsBacking::get_tensor_backing(tensor_guid_t const &tensor_id, IsGrad is_grad) const { switch (is_grad) { case IsGrad::NO: + assert(contains_key(this->tensor_mapping, tensor_id)); return this->tensor_mapping.at(tensor_id); case IsGrad::YES: + assert(contains_key(this->gradient_tensor_mapping, tensor_id)); return this->gradient_tensor_mapping.at(tensor_id); default: throw mk_runtime_error(fmt::format( @@ -43,12 +82,12 @@ TensorSlotsBacking LocalSlotsBacking::construct_tensor_slots_backing( std::vector tensor_guids; switch (tensor_spec.role) { case TensorRole::INPUT: - tensor_guids = this->input_tensor_slots.at(op_guid); - break; case TensorRole::WEIGHT: - tensor_guids = this->weight_tensor_slots.at(op_guid); + assert(contains_key(this->input_tensor_slots, op_guid)); + tensor_guids = this->input_tensor_slots.at(op_guid); break; case TensorRole::OUTPUT: + assert(contains_key(this->output_tensor_slots, op_guid)); tensor_guids = this->output_tensor_slots.at(op_guid); break; default: @@ -56,9 +95,12 @@ TensorSlotsBacking LocalSlotsBacking::construct_tensor_slots_backing( fmt::format("Invalid TensorRole")); // inserting role yields // "type_is_unformattable" error } + + assert(tensor_guids.size() > tensor_spec.idx); IsGrad is_grad = slot_grad_id.is_grad; GenericTensorAccessorW tensor_backing = this->get_tensor_backing(tensor_guids.at(tensor_spec.idx), is_grad); + mapping.insert({slot_grad_id, tensor_backing}); } return mapping; @@ -87,13 +129,22 @@ ArgSlotsBacking LocalSlotsBacking::construct_arg_slots_backing( ConcreteArgSpec LocalSlotsBacking::resolve_op_arg_ref_spec( OpArgRefSpec const &op_arg_ref_spec, layer_guid_t const &op_guid) const { - if (op_arg_ref_spec.holds>()) { - return ConcreteArgSpec::create(per_device_op_states.at(op_guid)); + if (op_arg_ref_spec.holds()) { + assert(contains_key(per_device_op_states, op_guid)); + DeviceSpecificDeviceStates device_specific = + per_device_op_states.at(op_guid); + PerDeviceOpState device_state = + get_device_state_from_device_specific(device_specific, 0); + return ConcreteArgSpec::create(device_state); } else if (op_arg_ref_spec.holds()) { ParallelTensorShapeRefType index_op_arg_ref = - std::get(op_arg_ref_spec.get_ref_type()); + op_arg_ref_spec.get_ref_type().get(); + + assert(contains_key(this->input_tensor_slots, op_guid)); std::vector input_tensor_guids = this->input_tensor_slots.at(op_guid); + + assert(input_tensor_guids.size() > index_op_arg_ref.idx); GenericTensorAccessorW tensor_backing = this->get_tensor_backing( input_tensor_guids.at(index_op_arg_ref.idx), IsGrad::NO); ParallelTensorShape shape = lift_to_parallel( @@ -107,9 +158,8 @@ ConcreteArgSpec LocalSlotsBacking::resolve_op_arg_ref_spec( ConcreteArgSpec LocalSlotsBacking::resolve_runtime_arg_ref_spec( RuntimeArgRefSpec const &runtime_arg_ref_spec) const { if (runtime_arg_ref_spec.holds>()) { - return ConcreteArgSpec::create(this->runtime_arg_config.ff_handle); - } else if (runtime_arg_ref_spec.holds()) { - return ConcreteArgSpec::create(this->runtime_arg_config.enable_profiling); + return ConcreteArgSpec::create( + *(this->runtime_arg_config.ff_handle.get(0))); } else if (runtime_arg_ref_spec.holds()) { return ConcreteArgSpec::create(this->runtime_arg_config.profiling_settings); } else { diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc index 62fe9b2d16..5d0156201e 100644 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local_task_argument_accessor.cc @@ -1,5 +1,8 @@ #include "local-execution/local_task_argument_accessor.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/transform.h" #include "utils/hash/pair.h" +#include "utils/overload.h" namespace FlexFlow { @@ -54,6 +57,37 @@ Allocator LocalTaskArgumentAccessor::get_allocator() const { return this->allocator; } +TensorSlotsBackingWithoutAddresses + get_slots_backing_without_tensor_allocation_addresses( + TensorSlotsBacking const &slots_backing) { + + TensorSlotsBackingWithoutAddresses addressless_slots_backing; + + using TensorAccessorVariant = + std::variant>; + for (auto const &slot_tensor : slots_backing) { + TensorAccessorVariant accessor_variant = slot_tensor.second; + std::visit( + overload{ + [&](GenericTensorAccessorW const &accessor) { + addressless_slots_backing.insert( + {slot_tensor.first, get_shape_and_datatype(accessor)}); + }, + [&](std::vector const &variadic_accessor) { + std::vector> + variadic_addressless_accessor = + transform(variadic_accessor, + [](GenericTensorAccessorW const &accessor) { + return get_shape_and_datatype(accessor); + }); + addressless_slots_backing.insert( + {slot_tensor.first, variadic_addressless_accessor}); + }}, + accessor_variant); + } + return addressless_slots_backing; +} + size_t LocalTaskArgumentAccessor::get_device_idx() const { return 0; } diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index 6d5a5011fd..a2ee06a95a 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -1,4 +1,5 @@ #include "local-execution/local_training_backing.h" +#include "local-execution/task_signature_impl.h" #include "utils/containers/reversed.h" #include "utils/exception.h" @@ -10,51 +11,29 @@ LocalTrainingBacking::LocalTrainingBacking( TensorBackingMap const &tensor_backing_mapping, RuntimeArgConfig const &runtime_arg_config) : allocator(allocator), computation_graph(computation_graph), - local_slots_backing(tensor_backing_mapping, runtime_arg_config) { - std::vector layers = topological_ordering(computation_graph); - for (layer_guid_t const &node : layers) { + local_slots_backing(tensor_backing_mapping, runtime_arg_config), + task_registry(empty_task_registry()) { + + for (layer_guid_t const &node : topological_ordering(computation_graph)) { ComputationGraphOpAttrs attrs = get_layer_attrs(computation_graph, node).attrs; - // register tasks - std::vector task_ids = get_task_ids(attrs); - for (task_id_t task_id : task_ids) { - this->task_registry.register_task(task_id, node, attrs); - } + // allocate outgoing tensors + this->local_slots_backing.allocate_outgoing_tensors( + node, computation_graph, this->allocator); - // insert pre-allocated tensors - this->local_slots_backing.input_tensor_slots.insert( - {node, get_incoming_tensors(computation_graph, node)}); - this->local_slots_backing.output_tensor_slots.insert( - {node, get_outgoing_tensors(computation_graph, node)}); - - // allocate new tensors - for (tensor_guid_t const &edge : - get_outgoing_tensors(computation_graph, node)) { - if (!this->local_slots_backing.is_tensor_allocated(edge)) { - TensorAttrs tensor_attrs = get_tensor_attrs(computation_graph, edge); - GenericTensorAccessorW tensor_backing = - this->allocator.allocate_tensor(tensor_attrs.shape); - this->local_slots_backing.tensor_mapping.insert({edge, tensor_backing}); - - if (tensor_attrs.create_gradients == CreateGrad::YES) { - GenericTensorAccessorW gradient_tensor_backing = - this->allocator.allocate_tensor(tensor_attrs.shape); - this->local_slots_backing.gradient_tensor_mapping.insert( - {edge, gradient_tensor_backing}); - } - } - } + // register tasks + register_tasks_for_layer(this->task_registry, node, attrs); } } -DeviceSpecific +DeviceSpecificDeviceStates LocalTrainingBacking::call_init_task_impl(task_id_t task_id, TaskArgumentAccessor const &acc) { TaskSignatureAndImpl task_sig_impl = this->task_registry.task_mapping.at(task_id); - auto fn = std::get( - TaskArgumentAccessor const &)>>(task_sig_impl.impl_function); + auto fn = + task_sig_impl.impl_function.get().function_ptr; return fn(acc); } @@ -63,24 +42,26 @@ std::optional TaskArgumentAccessor acc) { TaskSignatureAndImpl task_sig_impl = this->task_registry.task_mapping.at(task_id); - auto fn = std::get< - std::function(TaskArgumentAccessor const &)>>( - task_sig_impl.impl_function); + auto fn = + task_sig_impl.impl_function.get().function_ptr; return fn(acc); } void LocalTrainingBacking::execute_init() { for (layer_guid_t const &operator_node : topological_ordering(this->computation_graph)) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - OpTaskInvocation invocation = init(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - DeviceSpecific device_state = - this->call_init_task_impl(invocation.task_id, accessor); - this->local_slots_backing.add_per_device_op_state(operator_node, - device_state); + if (this->task_registry.init_task_ids.at(operator_node).has_value()) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = init(attrs); + TaskArgumentAccessor accessor = + this->get_task_arg_accessor(invocation, operator_node); + DeviceSpecificDeviceStates device_state = + this->call_init_task_impl(invocation.task_id, accessor); + this->local_slots_backing.add_per_device_op_state(operator_node, + device_state); + } } } @@ -88,14 +69,17 @@ PerLayerElapsedTime LocalTrainingBacking::execute_forward() { PerLayerElapsedTime per_op_elapsed_time; for (layer_guid_t const &operator_node : topological_ordering(this->computation_graph)) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - OpTaskInvocation invocation = forward(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - std::optional elapsed_time = - this->call_task_impl(invocation.task_id, accessor); - per_op_elapsed_time.insert({operator_node, elapsed_time}); + if (this->task_registry.forward_task_ids.at(operator_node).has_value()) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = forward(attrs); + TaskArgumentAccessor accessor = + this->get_task_arg_accessor(invocation, operator_node); + std::optional elapsed_time = + this->call_task_impl(invocation.task_id, accessor); + per_op_elapsed_time.insert({operator_node, elapsed_time}); + } } return per_op_elapsed_time; } @@ -104,14 +88,17 @@ PerLayerElapsedTime LocalTrainingBacking::execute_backward() { PerLayerElapsedTime per_op_elapsed_time; for (layer_guid_t const &operator_node : reversed(topological_ordering(this->computation_graph))) { - ComputationGraphOpAttrs attrs = - get_layer_attrs(this->computation_graph, operator_node).attrs; - OpTaskInvocation invocation = backward(attrs); - TaskArgumentAccessor accessor = - this->get_task_arg_accessor(invocation, operator_node); - std::optional elapsed_time = - this->call_task_impl(invocation.task_id, accessor); - per_op_elapsed_time.insert({operator_node, elapsed_time}); + if (this->task_registry.backward_task_ids.at(operator_node).has_value()) { + ComputationGraphOpAttrs attrs = + get_layer_attrs(this->computation_graph, operator_node).attrs; + + OpTaskInvocation invocation = backward(attrs); + TaskArgumentAccessor accessor = + this->get_task_arg_accessor(invocation, operator_node); + std::optional elapsed_time = + this->call_task_impl(invocation.task_id, accessor); + per_op_elapsed_time.insert({operator_node, elapsed_time}); + } } return per_op_elapsed_time; } diff --git a/lib/local-execution/src/op_arg_ref.cc b/lib/local-execution/src/op_arg_ref.cc index cad251f33e..b3d6e2f1a5 100644 --- a/lib/local-execution/src/op_arg_ref.cc +++ b/lib/local-execution/src/op_arg_ref.cc @@ -3,7 +3,7 @@ namespace FlexFlow { OpArgRef input_parallel_tensor_shape(int idx) { - OpArgRefType arg_ref_type = ParallelTensorShapeRefType{idx}; + OpArgRefType arg_ref_type = OpArgRefType{ParallelTensorShapeRefType{idx}}; ArgRef arg_ref = {arg_ref_type}; return arg_ref; } diff --git a/lib/local-execution/src/local-execution/op_arg_spec.cc b/lib/local-execution/src/op_arg_spec.cc similarity index 100% rename from lib/local-execution/src/local-execution/op_arg_spec.cc rename to lib/local-execution/src/op_arg_spec.cc diff --git a/lib/local-execution/src/op_task_invocation.cc b/lib/local-execution/src/op_task_invocation.cc index 3569bfb122..19c8894b05 100644 --- a/lib/local-execution/src/op_task_invocation.cc +++ b/lib/local-execution/src/op_task_invocation.cc @@ -36,6 +36,20 @@ void OpTaskBinding::insert_arg_spec(slot_id_t name, OpArgSpec const &arg_spec) { this->arg_bindings.insert({name, arg_spec}); } +bool OpTaskBinding::operator==(OpTaskBinding const &other) const { + return this->tie() == other.tie(); +} + +bool OpTaskBinding::operator!=(OpTaskBinding const &other) const { + return this->tie() != other.tie(); +} + +std::tuple const &, + std::unordered_map const &> + OpTaskBinding::tie() const { + return std::tie(this->tensor_bindings, this->arg_bindings); +} + std::unordered_map const & OpTaskBinding::get_tensor_bindings() const { return this->tensor_bindings; @@ -82,6 +96,9 @@ bool is_tensor_invocation_valid(OpTaskSignature const &sig, return false; } } + + // FIXME -- make sure invocation doesn't contain MORE than signature + // https://github.com/flexflow/FlexFlow/issues/1442 return true; } @@ -93,13 +110,13 @@ bool is_arg_type_invalid(std::type_index expected_arg_type, bool is_arg_invocation_valid(OpTaskSignature const &sig, OpTaskInvocation const &inv) { - auto sig_arg_types = sig.get_arg_types(); - for (auto arg_binding : inv.binding.get_arg_bindings()) { - std::type_index arg_type = sig_arg_types.at(arg_binding.first); - if (is_arg_type_invalid(arg_type, arg_binding.second)) { - return false; - } - } + // FIXME -- arg signature/invocation checking + // https://github.com/flexflow/FlexFlow/issues/1442 + // auto sig_arg_types = sig.get_arg_types(); + // for (auto arg_binding : inv.binding.get_arg_bindings()) { + // std::type_index arg_type = sig_arg_types.at(arg_binding.first); + // assert (!is_arg_type_invalid(arg_type, arg_binding.second)); + // } return true; } diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc index 3267ff592f..36a1dd708d 100644 --- a/lib/local-execution/src/op_task_signature.cc +++ b/lib/local-execution/src/op_task_signature.cc @@ -1,4 +1,6 @@ #include "local-execution/op_task_signature.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { @@ -61,18 +63,6 @@ void OpTaskSignature::add_output_slot(int name, SlotType slot_type) { } void OpTaskSignature::add_output_slot(slot_id_t name, SlotType slot_type) { - OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ - name, slot_type, TensorRole::OUTPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; - this->op_tensor_slots.insert(op_tensor_slot_spec); -} - -void OpTaskSignature::add_bwd_necessary_output_slot(int name, - SlotType slot_type) { - this->add_bwd_necessary_output_slot(slot_id_t{name}, slot_type); -} - -void OpTaskSignature::add_bwd_necessary_output_slot(slot_id_t name, - SlotType slot_type) { OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{name, slot_type, @@ -82,6 +72,18 @@ void OpTaskSignature::add_bwd_necessary_output_slot(slot_id_t name, this->op_tensor_slots.insert(op_tensor_slot_spec); } +void OpTaskSignature::add_bwd_optional_output_slot(int name, + SlotType slot_type) { + this->add_bwd_optional_output_slot(slot_id_t{name}, slot_type); +} + +void OpTaskSignature::add_bwd_optional_output_slot(slot_id_t name, + SlotType slot_type) { + OpTensorSlotSpec op_tensor_slot_spec = OpTensorSlotSpec{ + name, slot_type, TensorRole::OUTPUT, IsGrad::NO, OpSlotOptions::OPTIONAL}; + this->op_tensor_slots.insert(op_tensor_slot_spec); +} + void OpTaskSignature::add_weight_slot(int name, SlotType slot_type) { this->add_weight_slot(slot_id_t{name}, slot_type); } @@ -116,4 +118,47 @@ void OpTaskSignature::add_from_slot_spec(OpTensorSlotSpec const &spec) { this->op_tensor_slots.insert(spec); } +OpTaskSignature infer_bwd_signature(OpTaskSignature const &fwd) { + OpTaskSignature bwd = fwd; + bwd.type = OpTaskType::BWD; + for (auto const &op_tensor_slot_spec : fwd.get_tensor_slots()) { + OpSlotOptions slot_option = op_tensor_slot_spec.slot_option; + if (slot_option != OpSlotOptions::UNTRAINABLE || + slot_option != OpSlotOptions::OPTIONAL_UNTRAINABLE) { + OpTensorSlotSpec grad_spec = + OpTensorSlotSpec{op_tensor_slot_spec.name, + op_tensor_slot_spec.slot_type, + op_tensor_slot_spec.tensor_role, + IsGrad::YES, + op_tensor_slot_spec.slot_option}; + bwd.op_tensor_slots.insert(grad_spec); + } + } + + return bwd; +} + +std::unordered_set OpTaskSignature::get_tensor_slots() const { + return this->op_tensor_slots; +} + +std::unordered_map + OpTaskSignature::get_arg_types() const { + return this->task_arg_types; +} + +std::string format_as(OpTaskSignature const &x) { + std::ostringstream oss; + oss << ""; + return oss.str(); +} +std::ostream &operator<<(std::ostream &s, OpTaskSignature const &x) { + return s << fmt::to_string(x); +} + } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/attention.cc b/lib/local-execution/src/ops/attention.cc index fc3627404d..eebef9039d 100644 --- a/lib/local-execution/src/ops/attention.cc +++ b/lib/local-execution/src/ops/attention.cc @@ -13,9 +13,10 @@ * limitations under the License. */ -#include "attention.h" +#include "local-execution/ops/attention.h" #include "kernels/attention_kernels.h" #include "local-execution/op_task_signature.h" +#include "op-attrs/ops/attention.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" namespace FlexFlow { @@ -56,7 +57,7 @@ OpTaskInvocation init(MultiHeadAttentionAttrs const &attrs) { b.bind_arg(VPROJSIZE, get_vProjSize(attrs)); b.bind_arg(OPROJSIZE, get_oProjSize(attrs)); - return {ATTENTION_INIT_TASK_ID, b}; + return {task_id_t::ATTENTION_INIT_TASK_ID, b}; } OpTaskInvocation forward(MultiHeadAttentionAttrs const &attrs) { @@ -71,23 +72,24 @@ OpTaskInvocation forward(MultiHeadAttentionAttrs const &attrs) { b.bind_arg(PROFILING, profiling_settings()); b.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {ATTENTION_FWD_TASK_ID, b}; + return {task_id_t::ATTENTION_FWD_TASK_ID, b}; } OpTaskInvocation backward(MultiHeadAttentionAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {ATTENTION_BWD_TASK_ID, b}; + return {task_id_t::ATTENTION_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); Allocator allocator = acc.get_allocator(); - size_t qProjSize = acc.get_argument(QPROJSIZE); - size_t kProjSize = acc.get_argument(KPROJSIZE); - size_t vProjSize = acc.get_argument(VPROJSIZE); - size_t oProjSize = acc.get_argument(OPROJSIZE); + size_t qProjSize = acc.get_argument(QPROJSIZE); + size_t kProjSize = acc.get_argument(KPROJSIZE); + size_t vProjSize = acc.get_argument(VPROJSIZE); + size_t oProjSize = acc.get_argument(OPROJSIZE); + PerDeviceFFHandle handle = acc.get_argument(HANDLE); ParallelTensorShape query_parallel_tensor_shape = acc.get_argument(QUERY_PARALLEL_TENSOR_SHAPE); @@ -129,7 +131,8 @@ static DeviceSpecific qoSeqLength, kvSeqLength, attrs.add_bias_kv); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -140,7 +143,8 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); ProfilingSettings profiling = acc.get_argument(PROFILING); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + MHAPerDeviceState per_device_state = + acc.get_argument(PER_DEVICE_STATE); return profile(forward_kernel, profiling, @@ -166,7 +170,8 @@ static std::optional auto key_grad = acc.get_tensor_grad(KEY); auto value_grad = acc.get_tensor_grad(VALUE); - auto per_device_state = acc.get_argument(PER_DEVICE_STATE); + MHAPerDeviceState per_device_state = + acc.get_argument(PER_DEVICE_STATE); ProfilingSettings profiling = acc.get_argument(PROFILING); float *key_grad_ptr = @@ -197,13 +202,13 @@ static std::optional } TaskImplFunction get_attention_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_attention_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_attention_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_attention_init_signature() { @@ -245,7 +250,9 @@ OpTaskSignature get_attention_bwd_signature() { } std::vector get_task_ids(MultiHeadAttentionAttrs const &) { - return {ATTENTION_INIT_TASK_ID, ATTENTION_FWD_TASK_ID, ATTENTION_BWD_TASK_ID}; + return {task_id_t::ATTENTION_INIT_TASK_ID, + task_id_t::ATTENTION_FWD_TASK_ID, + task_id_t::ATTENTION_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/batch_matmul.cc b/lib/local-execution/src/ops/batch_matmul.cc index d18e58baf4..1eae409ae2 100644 --- a/lib/local-execution/src/ops/batch_matmul.cc +++ b/lib/local-execution/src/ops/batch_matmul.cc @@ -45,13 +45,13 @@ OpTaskInvocation forward(BatchMatmulAttrs const &attrs) { fwd.bind_arg(PROFILING, profiling_settings()); fwd.bind_arg(ITERATION_CONFIG, iteration_config()); - return {BATCHMATMUL_FWD_TASK_ID, fwd}; + return {task_id_t::BATCHMATMUL_FWD_TASK_ID, fwd}; } OpTaskInvocation backward(BatchMatmulAttrs const &attrs) { OpTaskBinding bwd = infer_bwd_binding(forward(attrs).binding); - return {BATCHMATMUL_BWD_TASK_ID, bwd}; + return {task_id_t::BATCHMATMUL_BWD_TASK_ID, bwd}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -153,10 +153,10 @@ static std::optional } TaskImplFunction get_batch_matmul_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_batch_matmul_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_batch_matmul_fwd_signature() { @@ -173,14 +173,14 @@ OpTaskSignature get_batch_matmul_fwd_signature() { } OpTaskSignature get_batch_matmul_bwd_signature() { - OpTaskSignature bwd = - infer_bwd_signature(fwd_signature()); + OpTaskSignature bwd = infer_bwd_signature(get_batch_matmul_fwd_signature()); return bwd; } std::vector get_task_ids(BatchMatmulAttrs const &) { - return {BATCHMATMUL_FWD_TASK_ID, BATCHMATMUL_BWD_TASK_ID}; + return {task_id_t::BATCHMATMUL_FWD_TASK_ID, + task_id_t::BATCHMATMUL_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/batch_norm.cc b/lib/local-execution/src/ops/batch_norm.cc index 5eaa264541..851566fc02 100644 --- a/lib/local-execution/src/ops/batch_norm.cc +++ b/lib/local-execution/src/ops/batch_norm.cc @@ -43,7 +43,7 @@ OpTaskInvocation init(BatchNormAttrs const &attrs) { binding.bind_arg(PROFILING, profiling_settings()); binding.bind_arg(HANDLE, ff_handle()); - return {BATCHNORM_INIT_TASK_ID, binding}; + return {task_id_t::BATCHNORM_INIT_TASK_ID, binding}; } OpTaskInvocation forward(BatchNormAttrs const &attrs) { @@ -57,16 +57,16 @@ OpTaskInvocation forward(BatchNormAttrs const &attrs) { binding.bind(BIAS, input_tensor(2)); binding.bind(OUTPUT, output_tensor(0)); - return {BATCHNORM_FWD_TASK_ID, binding}; + return {task_id_t::BATCHNORM_FWD_TASK_ID, binding}; } OpTaskInvocation backward(BatchNormAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {BATCHNORM_BWD_TASK_ID, binding}; + return {task_id_t::BATCHNORM_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { Allocator allocator = acc.get_allocator(); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -91,7 +91,8 @@ static DeviceSpecific output_w, attrs.relu); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -143,13 +144,13 @@ static std::optional } TaskImplFunction get_batch_norm_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_batch_norm_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_batch_norm_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_batch_norm_init_signature() { @@ -185,9 +186,9 @@ OpTaskSignature get_batch_norm_bwd_signature() { std::vector get_task_ids(BatchNormAttrs const &) { return { - BATCHNORM_INIT_TASK_ID, - BATCHNORM_FWD_TASK_ID, - BATCHNORM_BWD_TASK_ID, + task_id_t::BATCHNORM_INIT_TASK_ID, + task_id_t::BATCHNORM_FWD_TASK_ID, + task_id_t::BATCHNORM_BWD_TASK_ID, }; } diff --git a/lib/local-execution/src/ops/cast.cc b/lib/local-execution/src/ops/cast.cc index 7e109a2140..3e7baf49a9 100644 --- a/lib/local-execution/src/ops/cast.cc +++ b/lib/local-execution/src/ops/cast.cc @@ -34,13 +34,13 @@ OpTaskInvocation forward(CastAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {CAST_FWD_TASK_ID, binding}; + return {task_id_t::CAST_FWD_TASK_ID, binding}; } OpTaskInvocation backward(CastAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {CAST_BWD_TASK_ID, binding}; + return {task_id_t::CAST_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -79,10 +79,10 @@ static std::optional } TaskImplFunction get_cast_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_cast_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_cast_fwd_signature() { @@ -104,7 +104,7 @@ OpTaskSignature get_cast_bwd_signature() { } std::vector get_task_ids(CastAttrs const &) { - return {CAST_FWD_TASK_ID, CAST_BWD_TASK_ID}; + return {task_id_t::CAST_FWD_TASK_ID, task_id_t::CAST_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/combine.cc b/lib/local-execution/src/ops/combine.cc index a6aeaebf14..ccc82cce17 100644 --- a/lib/local-execution/src/ops/combine.cc +++ b/lib/local-execution/src/ops/combine.cc @@ -32,13 +32,13 @@ OpTaskInvocation forward(CombineAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {COMBINE_FWD_TASK_ID, binding}; + return {task_id_t::COMBINE_FWD_TASK_ID, binding}; } OpTaskInvocation backward(CombineAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {COMBINE_BWD_TASK_ID, b}; + return {task_id_t::COMBINE_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -84,4 +84,11 @@ OpTaskSignature get_combine_bwd_signature() { return bwd; } +TaskImplFunction get_combine_fwd_task_impl() { + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; +} +TaskImplFunction get_combine_bwd_task_impl() { + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; +} + }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 5dfd100d84..35f663b1cd 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -34,13 +34,13 @@ OpTaskInvocation forward(ConcatAttrs const &attrs) { binding.bind_arg(PROFILING, profiling_settings()); binding.bind_arg(ATTRS, attrs); - return {CONCAT_FWD_TASK_ID, binding}; + return {task_id_t::CONCAT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ConcatAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {CONCAT_BWD_TASK_ID, b}; + return {task_id_t::CONCAT_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -79,10 +79,10 @@ static std::optional } TaskImplFunction get_concat_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_concat_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_concat_fwd_signature() { @@ -97,14 +97,13 @@ OpTaskSignature get_concat_fwd_signature() { } OpTaskSignature get_concat_bwd_signature() { - OpTaskSignature bwd = - infer_bwd_signature(fwd_signature()); + OpTaskSignature bwd = infer_bwd_signature(get_concat_fwd_signature()); return bwd; } std::vector get_task_ids(ConcatAttrs const &) { - return {CONCAT_FWD_TASK_ID, CONCAT_BWD_TASK_ID}; + return {task_id_t::CONCAT_FWD_TASK_ID, task_id_t::CONCAT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/conv_2d.cc b/lib/local-execution/src/ops/conv_2d.cc index 7aede41355..d5c6e7f851 100644 --- a/lib/local-execution/src/ops/conv_2d.cc +++ b/lib/local-execution/src/ops/conv_2d.cc @@ -26,7 +26,7 @@ OpTaskInvocation init(Conv2DAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {CONV2D_INIT_TASK_ID, binding}; + return {task_id_t::CONV2D_INIT_TASK_ID, binding}; } OpTaskInvocation forward(Conv2DAttrs const &attrs) { @@ -42,16 +42,16 @@ OpTaskInvocation forward(Conv2DAttrs const &attrs) { binding.bind(FILTER, weight_tensor(0)); binding.bind(BIAS, weight_tensor(1)); - return {CONV2D_FWD_TASK_ID, binding}; + return {task_id_t::CONV2D_FWD_TASK_ID, binding}; } OpTaskInvocation backward(Conv2DAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {CONV2D_BWD_TASK_ID, binding}; + return {task_id_t::CONV2D_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -75,7 +75,8 @@ static DeviceSpecific output, filter.get_float_ptr(), filter_grad.get_float_ptr()); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -131,13 +132,13 @@ static std::optional } TaskImplFunction get_conv_2d_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_conv_2d_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_conv_2d_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_conv_2d_init_signature() { @@ -176,7 +177,9 @@ OpTaskSignature get_conv_2d_bwd_signature() { } std::vector get_task_ids(Conv2DAttrs const &) { - return {CONV2D_INIT_TASK_ID, CONV2D_FWD_TASK_ID, CONV2D_BWD_TASK_ID}; + return {task_id_t::CONV2D_INIT_TASK_ID, + task_id_t::CONV2D_FWD_TASK_ID, + task_id_t::CONV2D_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/dropout.cc b/lib/local-execution/src/ops/dropout.cc index 9195e859ce..cac08866cc 100644 --- a/lib/local-execution/src/ops/dropout.cc +++ b/lib/local-execution/src/ops/dropout.cc @@ -18,7 +18,7 @@ OpTaskInvocation init(DropoutAttrs const &attrs) { binding.bind_arg(FF_HANDLE, ff_handle()); binding.bind(OUTPUT, output_tensor(0)); - return {DROPOUT_INIT_TASK_ID, binding}; + return {task_id_t::DROPOUT_INIT_TASK_ID, binding}; } OpTaskInvocation forward(DropoutAttrs const &attrs) { @@ -31,16 +31,16 @@ OpTaskInvocation forward(DropoutAttrs const &attrs) { binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {DROPOUT_FWD_TASK_ID, binding}; + return {task_id_t::DROPOUT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(DropoutAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {DROPOUT_BWD_TASK_ID, b}; + return {task_id_t::DROPOUT_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); Allocator allocator = acc.get_allocator(); @@ -49,7 +49,8 @@ static DeviceSpecific DropoutPerDeviceState per_device_state = init_kernel(handle, attrs.rate, attrs.seed, output.shape, allocator); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -86,13 +87,13 @@ static std::optional } TaskImplFunction get_dropout_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_dropout_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_dropout_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_dropout_init_signature() { @@ -126,7 +127,9 @@ OpTaskSignature get_dropout_bwd_signature() { } std::vector get_task_ids(DropoutAttrs const &) { - return {DROPOUT_INIT_TASK_ID, DROPOUT_FWD_TASK_ID, DROPOUT_BWD_TASK_ID}; + return {task_id_t::DROPOUT_INIT_TASK_ID, + task_id_t::DROPOUT_FWD_TASK_ID, + task_id_t::DROPOUT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/dropout.h b/lib/local-execution/src/ops/dropout.h index 58910ab6dc..84b67a29c2 100644 --- a/lib/local-execution/src/ops/dropout.h +++ b/lib/local-execution/src/ops/dropout.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "local-execution/tasks.h" +#include "local-execution/task_id_t.dtg.h" #include "op-attrs/ops/dropout.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_binary.cc b/lib/local-execution/src/ops/element_binary.cc index bd5b415df3..48c6c699a2 100644 --- a/lib/local-execution/src/ops/element_binary.cc +++ b/lib/local-execution/src/ops/element_binary.cc @@ -27,7 +27,7 @@ OpTaskInvocation init(ElementBinaryAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {ELEMENTBINARY_INIT_TASK_ID, binding}; + return {task_id_t::ELEMENTBINARY_INIT_TASK_ID, binding}; } OpTaskInvocation forward(ElementBinaryAttrs const &attrs) { @@ -42,16 +42,16 @@ OpTaskInvocation forward(ElementBinaryAttrs const &attrs) { per_device_op_state()); binding.bind_arg(HANDLE, ff_handle()); - return {ELEMENTBINARY_FWD_TASK_ID, binding}; + return {task_id_t::ELEMENTBINARY_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ElementBinaryAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {ELEMENTBINARY_BWD_TASK_ID, b}; + return {task_id_t::ELEMENTBINARY_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto input_lhs = acc.get_tensor(LHS_INPUT); auto input_rhs = acc.get_tensor(RHS_INPUT); @@ -68,7 +68,8 @@ static DeviceSpecific input_lhs.shape, input_rhs.shape, output.shape); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -125,15 +126,15 @@ static std::optional } TaskImplFunction get_element_binary_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_element_binary_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_element_binary_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_element_binary_init_signature() { @@ -172,9 +173,9 @@ OpTaskSignature get_element_binary_bwd_signature() { } std::vector get_task_ids(ElementBinaryAttrs const &) { - return {ELEMENTBINARY_INIT_TASK_ID, - ELEMENTBINARY_FWD_TASK_ID, - ELEMENTBINARY_BWD_TASK_ID}; + return {task_id_t::ELEMENTBINARY_INIT_TASK_ID, + task_id_t::ELEMENTBINARY_FWD_TASK_ID, + task_id_t::ELEMENTBINARY_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc index 3185bbfef9..a52ebb8089 100644 --- a/lib/local-execution/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -26,7 +26,7 @@ OpTaskInvocation init(ElementUnaryAttrs const &attrs) { b.bind_arg(ATTRS, attrs); b.bind_arg(INPUT_SHAPE, input_parallel_tensor_shape(0)); - return {ELEMENTUNARY_INIT_TASK_ID, b}; + return {task_id_t::ELEMENTUNARY_INIT_TASK_ID, b}; } OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { @@ -39,16 +39,16 @@ OpTaskInvocation forward(ElementUnaryAttrs const &attrs) { b.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {ELEMENTUNARY_FWD_TASK_ID, b}; + return {task_id_t::ELEMENTUNARY_FWD_TASK_ID, b}; } OpTaskInvocation backward(ElementUnaryAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {ELEMENTUNARY_BWD_TASK_ID, b}; + return {task_id_t::ELEMENTUNARY_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); @@ -61,7 +61,8 @@ static DeviceSpecific ElementUnaryPerDeviceState per_device_state = init_kernel( get_piece_shape(input_shape), get_piece_shape(output_shape), attrs); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -112,13 +113,13 @@ static std::optional } TaskImplFunction get_element_unary_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_element_unary_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_element_unary_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_element_unary_init_signature() { @@ -152,9 +153,9 @@ OpTaskSignature get_element_unary_bwd_signature() { } std::vector get_task_ids(ElementUnaryAttrs const &) { - return {ELEMENTUNARY_INIT_TASK_ID, - ELEMENTUNARY_FWD_TASK_ID, - ELEMENTUNARY_BWD_TASK_ID}; + return {task_id_t::ELEMENTUNARY_INIT_TASK_ID, + task_id_t::ELEMENTUNARY_FWD_TASK_ID, + task_id_t::ELEMENTUNARY_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/flat.cc b/lib/local-execution/src/ops/flat.cc index 5d791e4459..3fe5029fa1 100644 --- a/lib/local-execution/src/ops/flat.cc +++ b/lib/local-execution/src/ops/flat.cc @@ -15,13 +15,13 @@ OpTaskInvocation forward(FlatAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind_arg(PROFILING, profiling_settings()); - return {FLAT_FWD_TASK_ID, binding}; + return {task_id_t::FLAT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(FlatAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {FLAT_BWD_TASK_ID, b}; + return {task_id_t::FLAT_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -53,10 +53,10 @@ static std::optional } TaskImplFunction get_flat_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_flat_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_flat_fwd_signature() { @@ -76,7 +76,7 @@ OpTaskSignature get_flat_bwd_signature() { } std::vector get_task_ids(FlatAttrs const &) { - return {FLAT_FWD_TASK_ID, FLAT_BWD_TASK_ID}; + return {task_id_t::FLAT_FWD_TASK_ID, task_id_t::FLAT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/gather.cc b/lib/local-execution/src/ops/gather.cc index 44455bc42c..a015c64f4d 100644 --- a/lib/local-execution/src/ops/gather.cc +++ b/lib/local-execution/src/ops/gather.cc @@ -34,7 +34,7 @@ OpTaskInvocation init(GatherAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {GATHER_INIT_TASK_ID, binding}; + return {task_id_t::GATHER_INIT_TASK_ID, binding}; } OpTaskInvocation forward(GatherAttrs const &attrs) { @@ -49,16 +49,16 @@ OpTaskInvocation forward(GatherAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind(INDEX, weight_tensor(0)); - return {GATHER_FWD_TASK_ID, binding}; + return {task_id_t::GATHER_FWD_TASK_ID, binding}; } OpTaskInvocation backward(GatherAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {GATHER_BWD_TASK_ID, binding}; + return {task_id_t::GATHER_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); auto index = acc.get_tensor(INDEX); @@ -80,7 +80,8 @@ static DeviceSpecific } GatherPerDeviceState per_device_state = {handle, legion_dim}; - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -121,13 +122,13 @@ static std::optional } TaskImplFunction get_gather_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_gather_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_gather_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_gather_init_signature() { @@ -165,7 +166,9 @@ OpTaskSignature get_gather_bwd_signature() { } std::vector get_task_ids(GatherAttrs const &) { - return {GATHER_INIT_TASK_ID, GATHER_FWD_TASK_ID, GATHER_BWD_TASK_ID}; + return {task_id_t::GATHER_INIT_TASK_ID, + task_id_t::GATHER_FWD_TASK_ID, + task_id_t::GATHER_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/input.cc b/lib/local-execution/src/ops/input.cc new file mode 100644 index 0000000000..56d19fa1ba --- /dev/null +++ b/lib/local-execution/src/ops/input.cc @@ -0,0 +1,9 @@ +#include "input.h" + +namespace FlexFlow { + +std::vector get_task_ids(InputAttrs const &attrs) { + return {}; +} + +}; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/input.h b/lib/local-execution/src/ops/input.h new file mode 100644 index 0000000000..97985585e1 --- /dev/null +++ b/lib/local-execution/src/ops/input.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_INPUT_H +#define _FLEXFLOW_INPUT_H + +#include "local-execution/op_task_invocation.h" +#include "op-attrs/ops/input.h" + +namespace FlexFlow { + +std::vector get_task_ids(InputAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/src/ops/layer_norm.cc b/lib/local-execution/src/ops/layer_norm.cc index 9530628d46..e99d27319c 100644 --- a/lib/local-execution/src/ops/layer_norm.cc +++ b/lib/local-execution/src/ops/layer_norm.cc @@ -46,7 +46,7 @@ OpTaskInvocation init(LayerNormAttrs const &attrs) { b.bind_arg(HANDLE, ff_handle()); b.bind_arg(ATTRS, attrs); - return {LAYERNORM_INIT_TASK_ID, b}; + return {task_id_t::LAYERNORM_INIT_TASK_ID, b}; } OpTaskInvocation forward(LayerNormAttrs const &attrs) { @@ -59,13 +59,13 @@ OpTaskInvocation forward(LayerNormAttrs const &attrs) { b.bind_arg(PROFILING, profiling_settings()); b.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {LAYERNORM_FWD_TASK_ID, b}; + return {task_id_t::LAYERNORM_FWD_TASK_ID, b}; } OpTaskInvocation backward(LayerNormAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {LAYERNORM_BWD_TASK_ID, b}; + return {task_id_t::LAYERNORM_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -112,7 +112,7 @@ static std::optional beta_grad); } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); Allocator allocator = acc.get_allocator(); @@ -141,17 +141,18 @@ static DeviceSpecific effective_batch_size, effective_num_elements, attrs.eps); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } TaskImplFunction get_layer_norm_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_layer_norm_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_layer_norm_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_layer_norm_fwd_signature() { @@ -184,7 +185,9 @@ OpTaskSignature get_layer_norm_init_signature() { } std::vector get_task_ids(LayerNormAttrs const &) { - return {LAYERNORM_INIT_TASK_ID, LAYERNORM_FWD_TASK_ID, LAYERNORM_BWD_TASK_ID}; + return {task_id_t::LAYERNORM_INIT_TASK_ID, + task_id_t::LAYERNORM_FWD_TASK_ID, + task_id_t::LAYERNORM_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/linear.cc b/lib/local-execution/src/ops/linear.cc index 599f671e92..9934e2a45c 100644 --- a/lib/local-execution/src/ops/linear.cc +++ b/lib/local-execution/src/ops/linear.cc @@ -31,7 +31,7 @@ OpTaskInvocation init(LinearAttrs const &attrs) { binding.bind(WEIGHT, weight_tensor(0)); // weight binding.bind(OUTPUT, output_tensor(0)); // output - return {LINEAR_INIT_TASK_ID, binding}; + return {task_id_t::LINEAR_INIT_TASK_ID, binding}; } OpTaskInvocation forward(LinearAttrs const &attrs) { @@ -49,16 +49,16 @@ OpTaskInvocation forward(LinearAttrs const &attrs) { per_device_op_state()); binding.bind_arg(ATTRS, attrs); - return {LINEAR_FWD_TASK_ID, binding}; + return {task_id_t::LINEAR_FWD_TASK_ID, binding}; } OpTaskInvocation backward(LinearAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {LINEAR_BWD_TASK_ID, b}; + return {task_id_t::LINEAR_BWD_TASK_ID, b}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -71,17 +71,18 @@ static DeviceSpecific float *one_ptr; - LinearPerDeviceState state = init_kernel(handle, - one_ptr, - attrs.activation, - attrs.regularizer, - attrs.use_bias, - input.data_type, - weight.data_type, - output.data_type, - batch_size, - attrs.out_channels); - return DeviceSpecific::create(state); + LinearPerDeviceState per_device_state = init_kernel(handle, + one_ptr, + attrs.activation, + attrs.regularizer, + attrs.use_bias, + input.data_type, + weight.data_type, + output.data_type, + batch_size, + attrs.out_channels); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -160,13 +161,13 @@ static std::optional } TaskImplFunction get_linear_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_linear_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_linear_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_linear_init_signature() { @@ -203,7 +204,9 @@ OpTaskSignature get_linear_bwd_signature() { } std::vector get_task_ids(LinearAttrs const &) { - return {LINEAR_INIT_TASK_ID, LINEAR_FWD_TASK_ID, LINEAR_BWD_TASK_ID}; + return {task_id_t::LINEAR_INIT_TASK_ID, + task_id_t::LINEAR_FWD_TASK_ID, + task_id_t::LINEAR_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/noop.cc b/lib/local-execution/src/ops/noop.cc index 168d547c17..e35fdec275 100644 --- a/lib/local-execution/src/ops/noop.cc +++ b/lib/local-execution/src/ops/noop.cc @@ -14,21 +14,11 @@ */ #include "noop.h" -#include "local-execution/op_task_invocation.h" -#include "utils/hash-utils.h" namespace FlexFlow { -std::optional init(NoopAttrs const &attrs) { - return std::nullopt; -} - -std::optional forward(NoopAttrs const &attrs) { - return std::nullopt; -} - -std::optional backward(NoopAttrs const &attrs) { - return std::nullopt; +std::vector get_task_ids(NoopAttrs const &attrs) { + return {}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/noop.h b/lib/local-execution/src/ops/noop.h index fab2cf1f86..959f7dc054 100644 --- a/lib/local-execution/src/ops/noop.h +++ b/lib/local-execution/src/ops/noop.h @@ -4,12 +4,12 @@ #include "local-execution/op_task_invocation.h" #include "op-attrs/ops/input.h" #include "op-attrs/ops/noop.h" +#include "op-attrs/ops/weight_attrs.dtg.h" namespace FlexFlow { -std::optional init(NoopAttrs const &); -std::optional forward(NoopAttrs const &); -std::optional backward(NoopAttrs const &); +std::vector get_task_ids(NoopAttrs const &); + } // namespace FlexFlow #endif diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc index d6f100390a..789ed2cd63 100644 --- a/lib/local-execution/src/ops/pool_2d.cc +++ b/lib/local-execution/src/ops/pool_2d.cc @@ -20,10 +20,10 @@ OpTaskInvocation init(Pool2DAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); binding.bind_arg(HANDLE, ff_handle()); - return {POOL2D_INIT_TASK_ID, binding}; + return {task_id_t::POOL2D_INIT_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -64,25 +64,26 @@ static DeviceSpecific printf("Warning: changing pool_padding_w to satisfy output_w size\n"); } - Pool2DPerDeviceState state = init_kernel(handle, - attrs.activation, - input_w, - input_h, - input_c, - input_n, - output_w, - output_h, - output_c, - output_n, - pad_h, - pad_w, - attrs.kernel_h, - attrs.kernel_w, - attrs.stride_h, - attrs.stride_w, - attrs.pool_type); - - return DeviceSpecific::create(state); + Pool2DPerDeviceState per_device_state = init_kernel(handle, + attrs.activation, + input_w, + input_h, + input_c, + input_n, + output_w, + output_h, + output_c, + output_n, + pad_h, + pad_w, + attrs.kernel_h, + attrs.kernel_w, + attrs.stride_h, + attrs.stride_w, + attrs.pool_type); + + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } OpTaskInvocation forward(Pool2DAttrs const &attrs) { @@ -94,13 +95,13 @@ OpTaskInvocation forward(Pool2DAttrs const &attrs) { binding.bind_arg(PER_DEVICE_STATE, per_device_op_state()); - return {POOL2D_FWD_TASK_ID, binding}; + return {task_id_t::POOL2D_FWD_TASK_ID, binding}; } OpTaskInvocation backward(Pool2DAttrs const &attrs) { OpTaskBinding b = infer_bwd_binding(forward(attrs).binding); - return {POOL2D_BWD_TASK_ID, b}; + return {task_id_t::POOL2D_BWD_TASK_ID, b}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -141,13 +142,13 @@ static std::optional } TaskImplFunction get_pool_2d_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_pool_2d_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_pool_2d_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_pool_2d_init_signature() { @@ -178,7 +179,9 @@ OpTaskSignature get_pool_2d_bwd_signature() { } std::vector get_task_ids(Pool2DAttrs const &) { - return {POOL2D_INIT_TASK_ID, POOL2D_FWD_TASK_ID, POOL2D_BWD_TASK_ID}; + return {task_id_t::POOL2D_INIT_TASK_ID, + task_id_t::POOL2D_FWD_TASK_ID, + task_id_t::POOL2D_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reduce.cc b/lib/local-execution/src/ops/reduce.cc index 23c05eb17e..a043d9f847 100644 --- a/lib/local-execution/src/ops/reduce.cc +++ b/lib/local-execution/src/ops/reduce.cc @@ -29,10 +29,10 @@ OpTaskInvocation init(ReduceAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REDUCE_INIT_TASK_ID, binding}; + return {task_id_t::REDUCE_INIT_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); auto attrs = acc.get_argument(ATTRS); @@ -44,7 +44,8 @@ static DeviceSpecific size_t reduction_size = input.shape.get_volume() / output.shape.get_volume(); ReducePerDeviceState per_device_state = init_kernel(handle, op_type, reduction_size, input.shape, output.shape); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } // Note: forward_kernel only needs ReducePerDeviceState, input, output @@ -58,7 +59,7 @@ OpTaskInvocation forward(ReduceAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REDUCE_FWD_TASK_ID, binding}; + return {task_id_t::REDUCE_FWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -80,7 +81,7 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { OpTaskInvocation backward(ReduceAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REDUCE_BWD_TASK_ID, binding}; + return {task_id_t::REDUCE_BWD_TASK_ID, binding}; } static std::optional @@ -101,13 +102,13 @@ static std::optional } TaskImplFunction get_reduce_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_reduce_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reduce_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reduce_init_signature() { @@ -135,7 +136,9 @@ OpTaskSignature get_reduce_bwd_signature() { } std::vector get_task_ids(ReduceAttrs const &) { - return {REDUCE_INIT_TASK_ID, REDUCE_FWD_TASK_ID, REDUCE_BWD_TASK_ID}; + return {task_id_t::REDUCE_INIT_TASK_ID, + task_id_t::REDUCE_FWD_TASK_ID, + task_id_t::REDUCE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reduction.cc b/lib/local-execution/src/ops/reduction.cc index 85d9d48148..a58d79a4f8 100644 --- a/lib/local-execution/src/ops/reduction.cc +++ b/lib/local-execution/src/ops/reduction.cc @@ -33,13 +33,13 @@ OpTaskInvocation forward(ReductionAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REDUCTION_FWD_TASK_ID, binding}; + return {task_id_t::REDUCTION_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ReductionAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REDUCTION_BWD_TASK_ID, binding}; + return {task_id_t::REDUCTION_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -74,10 +74,10 @@ static std::optional } TaskImplFunction get_reduction_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reduction_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reduction_fwd_signature() { @@ -96,7 +96,7 @@ OpTaskSignature get_reduction_bwd_signature() { } std::vector get_task_ids(ReductionAttrs const &) { - return {REDUCTION_FWD_TASK_ID, REDUCTION_BWD_TASK_ID}; + return {task_id_t::REDUCTION_FWD_TASK_ID, task_id_t::REDUCTION_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/repartition.cc b/lib/local-execution/src/ops/repartition.cc index 61050ed3d0..73692f4a13 100644 --- a/lib/local-execution/src/ops/repartition.cc +++ b/lib/local-execution/src/ops/repartition.cc @@ -31,7 +31,7 @@ OpTaskInvocation init(RepartitionAttrs const &attrs) { binding.bind_arg(HANDLE, ff_handle()); binding.bind(INPUT, input_tensor(0)); - return {REPARTITION_INIT_TASK_ID, binding}; + return {task_id_t::REPARTITION_INIT_TASK_ID, binding}; } OpTaskInvocation forward(RepartitionAttrs const &attrs) { @@ -44,16 +44,16 @@ OpTaskInvocation forward(RepartitionAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REPARTITION_FWD_TASK_ID, binding}; + return {task_id_t::REPARTITION_FWD_TASK_ID, binding}; } OpTaskInvocation backward(RepartitionAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REPARTITION_BWD_TASK_ID, binding}; + return {task_id_t::REPARTITION_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto input = acc.get_tensor(INPUT); PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -62,7 +62,8 @@ static DeviceSpecific RepartitionPerDeviceState per_device_state = init_kernel(handle, input.data_type); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -97,13 +98,13 @@ static std::optional } TaskImplFunction get_repartition_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_repartition_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_repartition_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_repartition_init_signature() { @@ -129,9 +130,9 @@ OpTaskSignature get_repartition_bwd_signature() { } std::vector get_task_ids(RepartitionAttrs const &) { - return {REPARTITION_INIT_TASK_ID, - REPARTITION_FWD_TASK_ID, - REPARTITION_BWD_TASK_ID}; + return {task_id_t::REPARTITION_INIT_TASK_ID, + task_id_t::REPARTITION_FWD_TASK_ID, + task_id_t::REPARTITION_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/replicate.cc b/lib/local-execution/src/ops/replicate.cc index b3d3a152d6..135475a711 100644 --- a/lib/local-execution/src/ops/replicate.cc +++ b/lib/local-execution/src/ops/replicate.cc @@ -35,12 +35,12 @@ OpTaskInvocation forward(ReplicateAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind_arg(ATTRS, attrs); - return {REPLICATE_FWD_TASK_ID, binding}; + return {task_id_t::REPLICATE_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ReplicateAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REPLICATE_BWD_TASK_ID, binding}; + return {task_id_t::REPLICATE_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -73,10 +73,10 @@ static std::optional } TaskImplFunction get_replicate_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_replicate_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_replicate_fwd_signature() { @@ -94,7 +94,7 @@ OpTaskSignature get_replicate_bwd_signature() { } std::vector get_task_ids(ReplicateAttrs const &) { - return {REPLICATE_FWD_TASK_ID, REPLICATE_BWD_TASK_ID}; + return {task_id_t::REPLICATE_FWD_TASK_ID, task_id_t::REPLICATE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reshape.cc b/lib/local-execution/src/ops/reshape.cc index 8d6c0d83c2..7584d405eb 100644 --- a/lib/local-execution/src/ops/reshape.cc +++ b/lib/local-execution/src/ops/reshape.cc @@ -28,7 +28,7 @@ OpTaskInvocation init(ReshapeAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); - return {RESHAPE_INIT_TASK_ID, binding}; + return {task_id_t::RESHAPE_INIT_TASK_ID, binding}; } OpTaskInvocation forward(ReshapeAttrs const &attrs) { @@ -40,21 +40,22 @@ OpTaskInvocation forward(ReshapeAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {RESHAPE_FWD_TASK_ID, binding}; + return {task_id_t::RESHAPE_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ReshapeAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {RESHAPE_BWD_TASK_ID, binding}; + return {task_id_t::RESHAPE_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto attrs = acc.get_argument(ATTRS); ReshapePerDeviceState per_device_state = init_kernel(attrs.shape.data_type); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -91,13 +92,13 @@ static std::optional } TaskImplFunction get_reshape_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_reshape_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reshape_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reshape_init_signature() { @@ -124,7 +125,9 @@ OpTaskSignature get_reshape_bwd_signature() { } std::vector get_task_ids(ReshapeAttrs const &) { - return {RESHAPE_INIT_TASK_ID, RESHAPE_FWD_TASK_ID, RESHAPE_BWD_TASK_ID}; + return {task_id_t::RESHAPE_INIT_TASK_ID, + task_id_t::RESHAPE_FWD_TASK_ID, + task_id_t::RESHAPE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/reverse.cc b/lib/local-execution/src/ops/reverse.cc index b2a5107360..366a579bea 100644 --- a/lib/local-execution/src/ops/reverse.cc +++ b/lib/local-execution/src/ops/reverse.cc @@ -34,12 +34,12 @@ OpTaskInvocation forward(ReverseAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {REVERSE_FWD_TASK_ID, binding}; + return {task_id_t::REVERSE_FWD_TASK_ID, binding}; } OpTaskInvocation backward(ReverseAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {REVERSE_BWD_TASK_ID, binding}; + return {task_id_t::REVERSE_BWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -103,10 +103,10 @@ static std::optional } TaskImplFunction get_reverse_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_reverse_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_reverse_fwd_signature() { @@ -124,7 +124,7 @@ OpTaskSignature get_reverse_bwd_signature() { } std::vector get_task_ids(ReverseAttrs const &) { - return {REVERSE_FWD_TASK_ID, REVERSE_BWD_TASK_ID}; + return {task_id_t::REVERSE_FWD_TASK_ID, task_id_t::REVERSE_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/softmax.cc b/lib/local-execution/src/ops/softmax.cc index a0b3a047a7..4c7979ae9b 100644 --- a/lib/local-execution/src/ops/softmax.cc +++ b/lib/local-execution/src/ops/softmax.cc @@ -30,7 +30,7 @@ OpTaskInvocation init(SoftmaxAttrs const &attrs) { binding.bind_arg(HANDLE, ff_handle()); binding.bind_arg(ATTRS, attrs); - return {SOFTMAX_INIT_TASK_ID, binding}; + return {task_id_t::SOFTMAX_INIT_TASK_ID, binding}; } OpTaskInvocation forward(SoftmaxAttrs const &attrs) { @@ -43,16 +43,16 @@ OpTaskInvocation forward(SoftmaxAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {SOFTMAX_FWD_TASK_ID, binding}; + return {task_id_t::SOFTMAX_FWD_TASK_ID, binding}; } OpTaskInvocation backward(SoftmaxAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {SOFTMAX_BWD_TASK_ID, binding}; + return {task_id_t::SOFTMAX_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { PerDeviceFFHandle handle = acc.get_argument(HANDLE); @@ -67,7 +67,8 @@ static DeviceSpecific SoftmaxPerDeviceState per_device_state = init_kernel( handle, attrs.dim.value, output_n, output_c, output_h, output_w); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -107,13 +108,13 @@ static std::optional } TaskImplFunction get_softmax_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_softmax_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_softmax_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_softmax_init_signature() { @@ -140,7 +141,9 @@ OpTaskSignature get_softmax_bwd_signature() { } std::vector get_task_ids(SoftmaxAttrs const &) { - return {SOFTMAX_INIT_TASK_ID, SOFTMAX_FWD_TASK_ID, SOFTMAX_BWD_TASK_ID}; + return {task_id_t::SOFTMAX_INIT_TASK_ID, + task_id_t::SOFTMAX_FWD_TASK_ID, + task_id_t::SOFTMAX_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/split.cc b/lib/local-execution/src/ops/split.cc index 59d68dc8f5..9f039d84f8 100644 --- a/lib/local-execution/src/ops/split.cc +++ b/lib/local-execution/src/ops/split.cc @@ -35,13 +35,13 @@ OpTaskInvocation forward(SplitAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {SPLIT_FWD_TASK_ID, binding}; + return {task_id_t::SPLIT_FWD_TASK_ID, binding}; } OpTaskInvocation backward(SplitAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {SPLIT_BWD_TASK_ID, binding}; + return {task_id_t::SPLIT_BWD_TASK_ID, binding}; } void calc_block_size(coord_t &num_blocks, @@ -114,10 +114,10 @@ static std::optional } TaskImplFunction get_split_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_split_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_split_fwd_signature() { @@ -135,7 +135,7 @@ OpTaskSignature get_split_bwd_signature() { } std::vector get_task_ids(SplitAttrs const &) { - return {SPLIT_FWD_TASK_ID, SPLIT_BWD_TASK_ID}; + return {task_id_t::SPLIT_FWD_TASK_ID, task_id_t::SPLIT_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/topk.cc b/lib/local-execution/src/ops/topk.cc index 1669f58c3b..7f3519529a 100644 --- a/lib/local-execution/src/ops/topk.cc +++ b/lib/local-execution/src/ops/topk.cc @@ -33,7 +33,7 @@ OpTaskInvocation init(TopKAttrs const &attrs) { binding.bind_arg(ATTRS, attrs); - return {TOPK_INIT_TASK_ID, binding}; + return {task_id_t::TOPK_INIT_TASK_ID, binding}; } OpTaskInvocation forward(TopKAttrs const &attrs) { @@ -47,22 +47,23 @@ OpTaskInvocation forward(TopKAttrs const &attrs) { binding.bind(OUTPUT, output_tensor(0)); binding.bind(INDICES, output_tensor(1)); - return {TOPK_FWD_TASK_ID, binding}; + return {task_id_t::TOPK_FWD_TASK_ID, binding}; } OpTaskInvocation backward(TopKAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {TOPK_BWD_TASK_ID, binding}; + return {task_id_t::TOPK_BWD_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto attrs = acc.get_argument(ATTRS); TopKPerDeviceState per_device_state = init_kernel(attrs.sorted); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -119,13 +120,13 @@ static std::optional } TaskImplFunction get_topk_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_topk_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_topk_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_topk_init_signature() { @@ -154,7 +155,9 @@ OpTaskSignature get_topk_bwd_signature() { } std::vector get_task_ids(TopKAttrs const &) { - return {TOPK_INIT_TASK_ID, TOPK_FWD_TASK_ID, TOPK_BWD_TASK_ID}; + return {task_id_t::TOPK_INIT_TASK_ID, + task_id_t::TOPK_FWD_TASK_ID, + task_id_t::TOPK_BWD_TASK_ID}; } }; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/transpose.cc b/lib/local-execution/src/ops/transpose.cc index 5fa57772e2..5c3c1dd1ca 100644 --- a/lib/local-execution/src/ops/transpose.cc +++ b/lib/local-execution/src/ops/transpose.cc @@ -34,16 +34,17 @@ enum Slots { OpTaskInvocation init(TransposeAttrs const &attrs) { OpTaskBinding binding; binding.bind_arg(ATTRS, attrs); - return {TRANSPOSE_INIT_TASK_ID, binding}; + return {task_id_t::TRANSPOSE_INIT_TASK_ID, binding}; } -static DeviceSpecific +static DeviceSpecificDeviceStates init_task_impl(TaskArgumentAccessor const &acc) { auto const &attrs = acc.get_argument(ATTRS); std::vector perm = inner_to_outer_idxs(attrs.perm); TransposePerDeviceState per_device_state = init_kernel(perm.size(), perm); - return DeviceSpecific::create(per_device_state); + return DeviceSpecificDeviceStates{ + DeviceSpecific::create(per_device_state)}; } OpTaskInvocation forward(TransposeAttrs const &attrs) { @@ -56,7 +57,7 @@ OpTaskInvocation forward(TransposeAttrs const &attrs) { binding.bind(INPUT, input_tensor(0)); binding.bind(OUTPUT, output_tensor(0)); - return {TRANSPOSE_FWD_TASK_ID, binding}; + return {task_id_t::TRANSPOSE_FWD_TASK_ID, binding}; } static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { @@ -95,17 +96,17 @@ static std::optional OpTaskInvocation backward(TransposeAttrs const &attrs) { OpTaskBinding binding = infer_bwd_binding(forward(attrs).binding); - return {TRANSPOSE_BWD_TASK_ID, binding}; + return {task_id_t::TRANSPOSE_BWD_TASK_ID, binding}; } TaskImplFunction get_transpose_init_task_impl() { - return init_task_impl; + return TaskImplFunction{InitTaskImplFunction{init_task_impl}}; } TaskImplFunction get_transpose_fwd_task_impl() { - return forward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{forward_task_impl}}; } TaskImplFunction get_transpose_bwd_task_impl() { - return backward_task_impl; + return TaskImplFunction{FwdBwdTaskImplFunction{backward_task_impl}}; } OpTaskSignature get_transpose_init_signature() { @@ -131,7 +132,9 @@ OpTaskSignature get_transpose_bwd_signature() { } std::vector get_task_ids(TransposeAttrs const &) { - return {TRANSPOSE_INIT_TASK_ID, TRANSPOSE_FWD_TASK_ID, TRANSPOSE_BWD_TASK_ID}; + return {task_id_t::TRANSPOSE_INIT_TASK_ID, + task_id_t::TRANSPOSE_FWD_TASK_ID, + task_id_t::TRANSPOSE_BWD_TASK_ID}; } } // namespace FlexFlow diff --git a/lib/local-execution/src/ops/weight.cc b/lib/local-execution/src/ops/weight.cc new file mode 100644 index 0000000000..5537163e85 --- /dev/null +++ b/lib/local-execution/src/ops/weight.cc @@ -0,0 +1,9 @@ +#include "weight.h" + +namespace FlexFlow { + +std::vector get_task_ids(WeightAttrs const &attrs) { + return {}; +} + +}; // namespace FlexFlow diff --git a/lib/local-execution/src/ops/weight.h b/lib/local-execution/src/ops/weight.h new file mode 100644 index 0000000000..e59a88f07d --- /dev/null +++ b/lib/local-execution/src/ops/weight.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_WEIGHT_H +#define _FLEXFLOW_WEIGHT_H + +#include "local-execution/op_task_invocation.h" +#include "op-attrs/ops/weight_attrs.dtg.h" + +namespace FlexFlow { + +std::vector get_task_ids(WeightAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/local-execution/src/per_device_state.cc b/lib/local-execution/src/per_device_state.cc new file mode 100644 index 0000000000..fa470b196d --- /dev/null +++ b/lib/local-execution/src/per_device_state.cc @@ -0,0 +1,12 @@ +#include "local-execution/per_device_op_state.h" +#include "utils/overload.h" + +namespace FlexFlow { + +PerDeviceOpState get_device_state_from_device_specific( + DeviceSpecificDeviceStates const &device_specific, size_t device_idx) { + return device_specific.visit( + [&](auto const &x) { return PerDeviceOpState{*(x.get(device_idx))}; }); +} + +} // namespace FlexFlow diff --git a/lib/local-execution/src/task_registry.cc b/lib/local-execution/src/task_registry.cc index 23a0c96d28..dad5c1fc69 100644 --- a/lib/local-execution/src/task_registry.cc +++ b/lib/local-execution/src/task_registry.cc @@ -1,31 +1,44 @@ #include "local-execution/task_registry.h" +#include "local-execution/task_signature_impl.h" namespace FlexFlow { -void TaskRegistry::register_task(task_id_t const &task_id, - layer_guid_t const &op_id, - ComputationGraphOpAttrs const &attrs) { - TaskSignatureAndImpl task_signature_impl = get_task_sig_impl(task_id); - switch (task_signature_impl.task_signature.type) { - case OpTaskType::INIT: - assert( - is_invocation_valid(task_signature_impl.task_signature, init(attrs))); - this->init_task_ids.insert({op_id, task_id}); - break; - case OpTaskType::FWD: - assert(is_invocation_valid(task_signature_impl.task_signature, - forward(attrs))); - this->forward_task_ids.insert({op_id, task_id}); - break; - case OpTaskType::BWD: - assert(is_invocation_valid(task_signature_impl.task_signature, - backward(attrs))); - this->backward_task_ids.insert({op_id, task_id}); - break; - default: - throw mk_runtime_error("Invalid OpTaskType"); +TaskRegistry empty_task_registry() { + return TaskRegistry{{}, {}, {}, {}}; +} + +void register_tasks_for_layer(TaskRegistry &task_registry, + layer_guid_t const &op_id, + ComputationGraphOpAttrs const &attrs) { + task_registry.init_task_ids.insert({op_id, std::nullopt}); + task_registry.forward_task_ids.insert({op_id, std::nullopt}); + task_registry.backward_task_ids.insert({op_id, std::nullopt}); + + // register tasks + std::vector task_ids = get_task_ids(attrs); + for (task_id_t task_id : task_ids) { + TaskSignatureAndImpl task_signature_impl = get_task_sig_impl(task_id); + switch (task_signature_impl.task_signature.type) { + case OpTaskType::INIT: + assert(is_invocation_valid(task_signature_impl.task_signature, + init(attrs))); + task_registry.init_task_ids[op_id] = task_id; + break; + case OpTaskType::FWD: + assert(is_invocation_valid(task_signature_impl.task_signature, + forward(attrs))); + task_registry.forward_task_ids[op_id] = task_id; + break; + case OpTaskType::BWD: + assert(is_invocation_valid(task_signature_impl.task_signature, + backward(attrs))); + task_registry.backward_task_ids[op_id] = task_id; + break; + default: + throw mk_runtime_error("Invalid OpTaskType"); + } + task_registry.task_mapping.insert({task_id, task_signature_impl}); } - this->task_mapping.insert({task_id, task_signature_impl}); } } // namespace FlexFlow diff --git a/lib/local-execution/src/task_signature_impl.cc b/lib/local-execution/src/task_signature_impl.cc index 62e77e3199..ca428aad25 100644 --- a/lib/local-execution/src/task_signature_impl.cc +++ b/lib/local-execution/src/task_signature_impl.cc @@ -1,5 +1,5 @@ #include "local-execution/task_signature_impl.h" -#include "ops/attention.h" +#include "local-execution/ops/attention.h" #include "ops/batch_matmul.h" #include "ops/batch_norm.h" #include "ops/cast.h" @@ -12,8 +12,10 @@ #include "ops/embedding.h" #include "ops/flat.h" #include "ops/gather.h" +#include "ops/input.h" #include "ops/layer_norm.h" #include "ops/linear.h" +#include "ops/noop.h" #include "ops/pool_2d.h" #include "ops/reduce.h" #include "ops/reduction.h" @@ -25,160 +27,344 @@ #include "ops/split.h" #include "ops/topk.h" #include "ops/transpose.h" +#include "ops/weight.h" +#include "utils/overload.h" namespace FlexFlow { TaskSignatureAndImpl get_task_sig_impl(task_id_t const &task_id) { switch (task_id) { - case ELEMENTBINARY_INIT_TASK_ID: - return {get_element_binary_init_task_impl(), - get_element_binary_init_signature()}; - case ELEMENTBINARY_FWD_TASK_ID: - return {get_element_binary_fwd_task_impl(), - get_element_binary_fwd_signature()}; - case ELEMENTBINARY_BWD_TASK_ID: - return {get_element_binary_bwd_task_impl(), - get_element_binary_bwd_signature()}; - case ELEMENTUNARY_INIT_TASK_ID: - return {get_element_unary_init_task_impl(), - get_element_unary_init_signature()}; - case ELEMENTUNARY_FWD_TASK_ID: - return {get_element_unary_fwd_task_impl(), - get_element_unary_fwd_signature()}; - case ELEMENTUNARY_BWD_TASK_ID: - return {get_element_binary_bwd_task_impl(), - get_element_binary_bwd_signature()}; - case CONV2D_INIT_TASK_ID: - return {get_conv_2d_init_task_impl(), get_conv_2d_init_signature()}; - case CONV2D_FWD_TASK_ID: - return {get_conv_2d_fwd_task_impl(), get_conv_2d_fwd_signature()}; - case CONV2D_BWD_TASK_ID: - return {get_conv_2d_bwd_task_impl(), get_conv_2d_bwd_signature()}; - case DROPOUT_INIT_TASK_ID: - return {get_dropout_init_task_impl(), get_dropout_init_signature()}; - case DROPOUT_FWD_TASK_ID: - return {get_dropout_fwd_task_impl(), get_dropout_fwd_signature()}; - case DROPOUT_BWD_TASK_ID: - return {get_dropout_bwd_task_impl(), get_dropout_bwd_signature()}; - case EMBED_FWD_TASK_ID: - return {get_embedding_fwd_task_impl(), get_embedding_fwd_signature()}; - case EMBED_BWD_TASK_ID: - return {get_embedding_bwd_task_impl(), get_embedding_bwd_signature()}; - case GATHER_INIT_TASK_ID: - return {get_gather_init_task_impl(), get_gather_init_signature()}; - case GATHER_FWD_TASK_ID: - return {get_gather_fwd_task_impl(), get_embedding_fwd_signature()}; - case GATHER_BWD_TASK_ID: - return {get_gather_bwd_task_impl(), get_embedding_bwd_signature()}; - case CAST_FWD_TASK_ID: - return {get_cast_fwd_task_impl(), get_cast_fwd_signature()}; - case CAST_BWD_TASK_ID: - return {get_cast_bwd_task_impl(), get_cast_bwd_signature()}; - case POOL2D_INIT_TASK_ID: - return {get_pool_2d_init_task_impl(), get_pool_2d_init_signature()}; - case POOL2D_FWD_TASK_ID: - return {get_pool_2d_fwd_task_impl(), get_pool_2d_fwd_signature()}; - case POOL2D_BWD_TASK_ID: - return {get_pool_2d_bwd_task_impl(), get_pool_2d_bwd_signature()}; - case BATCHNORM_INIT_TASK_ID: - return {get_batch_norm_init_task_impl(), get_batch_norm_init_signature()}; - case BATCHNORM_FWD_TASK_ID: - return {get_batch_norm_fwd_task_impl(), get_batch_norm_fwd_signature()}; - case BATCHNORM_BWD_TASK_ID: - return {get_batch_norm_bwd_task_impl(), get_batch_norm_bwd_signature()}; - case BATCHMATMUL_FWD_TASK_ID: - return {get_batch_matmul_fwd_task_impl(), - get_batch_matmul_fwd_signature()}; - case BATCHMATMUL_BWD_TASK_ID: - return {get_batch_matmul_bwd_task_impl(), - get_batch_matmul_bwd_signature()}; - case LAYERNORM_INIT_TASK_ID: - return {get_layer_norm_init_task_impl(), get_layer_norm_init_signature()}; - case LAYERNORM_FWD_TASK_ID: - return {get_layer_norm_fwd_task_impl(), get_layer_norm_init_signature()}; - case LAYERNORM_BWD_TASK_ID: - return {get_layer_norm_bwd_task_impl(), get_layer_norm_bwd_signature()}; - case LINEAR_INIT_TASK_ID: - return {get_linear_init_task_impl(), get_linear_init_signature()}; - case LINEAR_FWD_TASK_ID: - return {get_linear_fwd_task_impl(), get_linear_fwd_signature()}; - case LINEAR_BWD_TASK_ID: - return {get_linear_bwd_task_impl(), get_linear_bwd_signature()}; - case FLAT_FWD_TASK_ID: - return {get_flat_fwd_task_impl(), get_flat_fwd_signature()}; - case FLAT_BWD_TASK_ID: - return {get_flat_bwd_task_impl(), get_flat_bwd_signature()}; - case SOFTMAX_INIT_TASK_ID: - return {get_softmax_init_task_impl(), get_softmax_init_signature()}; - case SOFTMAX_FWD_TASK_ID: - return {get_softmax_fwd_task_impl(), get_softmax_fwd_signature()}; - case SOFTMAX_BWD_TASK_ID: - return {get_softmax_bwd_task_impl(), get_softmax_bwd_signature()}; - case CONCAT_FWD_TASK_ID: - return {get_concat_fwd_task_impl(), get_concat_fwd_signature()}; - case CONCAT_BWD_TASK_ID: - return {get_concat_bwd_task_impl(), get_concat_bwd_signature()}; - case SPLIT_FWD_TASK_ID: - return {get_split_fwd_task_impl(), get_split_fwd_signature()}; - case SPLIT_BWD_TASK_ID: - return {get_split_bwd_task_impl(), get_split_bwd_signature()}; - case REDUCE_INIT_TASK_ID: - return {get_reduce_init_task_impl(), get_reduce_init_signature()}; - case REDUCE_FWD_TASK_ID: - return {get_reduce_fwd_task_impl(), get_reduce_fwd_signature()}; - case REDUCE_BWD_TASK_ID: - return {get_reduce_bwd_task_impl(), get_reduce_bwd_signature()}; - case RESHAPE_INIT_TASK_ID: - return {get_reshape_init_task_impl(), get_reshape_init_signature()}; - case RESHAPE_FWD_TASK_ID: - return {get_reshape_fwd_task_impl(), get_reshape_fwd_signature()}; - case RESHAPE_BWD_TASK_ID: - return {get_reshape_bwd_task_impl(), get_reshape_bwd_signature()}; - case REVERSE_FWD_TASK_ID: - return {get_reverse_fwd_task_impl(), get_reverse_fwd_signature()}; - case REVERSE_BWD_TASK_ID: - return {get_reverse_bwd_task_impl(), get_reverse_bwd_signature()}; - case TOPK_INIT_TASK_ID: - return {get_topk_init_task_impl(), get_topk_init_signature()}; - case TOPK_FWD_TASK_ID: - return {get_topk_fwd_task_impl(), get_topk_fwd_signature()}; - case TOPK_BWD_TASK_ID: - return {get_topk_bwd_task_impl(), get_topk_bwd_signature()}; - case TRANSPOSE_INIT_TASK_ID: - return {get_transpose_init_task_impl(), get_transpose_init_signature()}; - case TRANSPOSE_FWD_TASK_ID: - return {get_transpose_fwd_task_impl(), get_transpose_fwd_signature()}; - case TRANSPOSE_BWD_TASK_ID: - return {get_transpose_bwd_task_impl(), get_transpose_bwd_signature()}; - case ATTENTION_INIT_TASK_ID: - return {get_attention_init_task_impl(), get_attention_init_signature()}; - case ATTENTION_FWD_TASK_ID: - return {get_attention_fwd_task_impl(), get_attention_fwd_signature()}; - case ATTENTION_BWD_TASK_ID: - return {get_attention_bwd_task_impl(), get_attention_bwd_signature()}; - case COMBINE_FWD_TASK_ID: - return {get_combine_fwd_task_impl(), get_combine_fwd_signature()}; - case COMBINE_BWD_TASK_ID: - return {get_combine_bwd_task_impl(), get_combine_bwd_signature()}; - case REDUCTION_FWD_TASK_ID: - return {get_reduction_fwd_task_impl(), get_reduction_fwd_signature()}; - case REDUCTION_BWD_TASK_ID: - return {get_reduction_bwd_task_impl(), get_reduction_bwd_signature()}; - case REPARTITION_INIT_TASK_ID: - return {get_repartition_init_task_impl(), - get_repartition_init_signature()}; - case REPARTITION_FWD_TASK_ID: - return {get_repartition_fwd_task_impl(), get_repartition_fwd_signature()}; - case REPARTITION_BWD_TASK_ID: - return {get_repartition_bwd_task_impl(), get_repartition_bwd_signature()}; - case REPLICATE_FWD_TASK_ID: - return {get_replicate_fwd_task_impl(), get_replicate_fwd_signature()}; - case REPLICATE_BWD_TASK_ID: - return {get_replicate_bwd_task_impl(), get_replicate_bwd_signature()}; + case task_id_t::ELEMENTBINARY_INIT_TASK_ID: + return TaskSignatureAndImpl{get_element_binary_init_task_impl(), + get_element_binary_init_signature()}; + case task_id_t::ELEMENTBINARY_FWD_TASK_ID: + return TaskSignatureAndImpl{get_element_binary_fwd_task_impl(), + get_element_binary_fwd_signature()}; + case task_id_t::ELEMENTBINARY_BWD_TASK_ID: + return TaskSignatureAndImpl{get_element_binary_bwd_task_impl(), + get_element_binary_bwd_signature()}; + case task_id_t::ELEMENTUNARY_INIT_TASK_ID: + return TaskSignatureAndImpl{get_element_unary_init_task_impl(), + get_element_unary_init_signature()}; + case task_id_t::ELEMENTUNARY_FWD_TASK_ID: + return TaskSignatureAndImpl{get_element_unary_fwd_task_impl(), + get_element_unary_fwd_signature()}; + case task_id_t::ELEMENTUNARY_BWD_TASK_ID: + return TaskSignatureAndImpl{get_element_binary_bwd_task_impl(), + get_element_binary_bwd_signature()}; + case task_id_t::CONV2D_INIT_TASK_ID: + return TaskSignatureAndImpl{get_conv_2d_init_task_impl(), + get_conv_2d_init_signature()}; + case task_id_t::CONV2D_FWD_TASK_ID: + return TaskSignatureAndImpl{get_conv_2d_fwd_task_impl(), + get_conv_2d_fwd_signature()}; + case task_id_t::CONV2D_BWD_TASK_ID: + return TaskSignatureAndImpl{get_conv_2d_bwd_task_impl(), + get_conv_2d_bwd_signature()}; + case task_id_t::DROPOUT_INIT_TASK_ID: + return TaskSignatureAndImpl{get_dropout_init_task_impl(), + get_dropout_init_signature()}; + case task_id_t::DROPOUT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_dropout_fwd_task_impl(), + get_dropout_fwd_signature()}; + case task_id_t::DROPOUT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_dropout_bwd_task_impl(), + get_dropout_bwd_signature()}; + // case task_id_t::EMBED_FWD_TASK_ID: + // return TaskSignatureAndImpl{get_embedding_fwd_task_impl(), + // get_embedding_fwd_signature()}; + // case task_id_t::EMBED_BWD_TASK_ID: + // return TaskSignatureAndImpl{get_embedding_bwd_task_impl(), + // get_embedding_bwd_signature()}; + case task_id_t::GATHER_INIT_TASK_ID: + return TaskSignatureAndImpl{get_gather_init_task_impl(), + get_gather_init_signature()}; + case task_id_t::GATHER_FWD_TASK_ID: + return TaskSignatureAndImpl{get_gather_fwd_task_impl(), + get_gather_fwd_signature()}; + case task_id_t::GATHER_BWD_TASK_ID: + return TaskSignatureAndImpl{get_gather_bwd_task_impl(), + get_gather_bwd_signature()}; + case task_id_t::CAST_FWD_TASK_ID: + return TaskSignatureAndImpl{get_cast_fwd_task_impl(), + get_cast_fwd_signature()}; + case task_id_t::CAST_BWD_TASK_ID: + return TaskSignatureAndImpl{get_cast_bwd_task_impl(), + get_cast_bwd_signature()}; + case task_id_t::POOL2D_INIT_TASK_ID: + return TaskSignatureAndImpl{get_pool_2d_init_task_impl(), + get_pool_2d_init_signature()}; + case task_id_t::POOL2D_FWD_TASK_ID: + return TaskSignatureAndImpl{get_pool_2d_fwd_task_impl(), + get_pool_2d_fwd_signature()}; + case task_id_t::POOL2D_BWD_TASK_ID: + return TaskSignatureAndImpl{get_pool_2d_bwd_task_impl(), + get_pool_2d_bwd_signature()}; + case task_id_t::BATCHNORM_INIT_TASK_ID: + return TaskSignatureAndImpl{get_batch_norm_init_task_impl(), + get_batch_norm_init_signature()}; + case task_id_t::BATCHNORM_FWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_norm_fwd_task_impl(), + get_batch_norm_fwd_signature()}; + case task_id_t::BATCHNORM_BWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_norm_bwd_task_impl(), + get_batch_norm_bwd_signature()}; + case task_id_t::BATCHMATMUL_FWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_matmul_fwd_task_impl(), + get_batch_matmul_fwd_signature()}; + case task_id_t::BATCHMATMUL_BWD_TASK_ID: + return TaskSignatureAndImpl{get_batch_matmul_bwd_task_impl(), + get_batch_matmul_bwd_signature()}; + case task_id_t::LAYERNORM_INIT_TASK_ID: + return TaskSignatureAndImpl{get_layer_norm_init_task_impl(), + get_layer_norm_init_signature()}; + case task_id_t::LAYERNORM_FWD_TASK_ID: + return TaskSignatureAndImpl{get_layer_norm_fwd_task_impl(), + get_layer_norm_init_signature()}; + case task_id_t::LAYERNORM_BWD_TASK_ID: + return TaskSignatureAndImpl{get_layer_norm_bwd_task_impl(), + get_layer_norm_bwd_signature()}; + case task_id_t::LINEAR_INIT_TASK_ID: + return TaskSignatureAndImpl{get_linear_init_task_impl(), + get_linear_init_signature()}; + case task_id_t::LINEAR_FWD_TASK_ID: + return TaskSignatureAndImpl{get_linear_fwd_task_impl(), + get_linear_fwd_signature()}; + case task_id_t::LINEAR_BWD_TASK_ID: + return TaskSignatureAndImpl{get_linear_bwd_task_impl(), + get_linear_bwd_signature()}; + case task_id_t::FLAT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_flat_fwd_task_impl(), + get_flat_fwd_signature()}; + case task_id_t::FLAT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_flat_bwd_task_impl(), + get_flat_bwd_signature()}; + case task_id_t::SOFTMAX_INIT_TASK_ID: + return TaskSignatureAndImpl{get_softmax_init_task_impl(), + get_softmax_init_signature()}; + case task_id_t::SOFTMAX_FWD_TASK_ID: + return TaskSignatureAndImpl{get_softmax_fwd_task_impl(), + get_softmax_fwd_signature()}; + case task_id_t::SOFTMAX_BWD_TASK_ID: + return TaskSignatureAndImpl{get_softmax_bwd_task_impl(), + get_softmax_bwd_signature()}; + case task_id_t::CONCAT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_concat_fwd_task_impl(), + get_concat_fwd_signature()}; + case task_id_t::CONCAT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_concat_bwd_task_impl(), + get_concat_bwd_signature()}; + case task_id_t::SPLIT_FWD_TASK_ID: + return TaskSignatureAndImpl{get_split_fwd_task_impl(), + get_split_fwd_signature()}; + case task_id_t::SPLIT_BWD_TASK_ID: + return TaskSignatureAndImpl{get_split_bwd_task_impl(), + get_split_bwd_signature()}; + case task_id_t::REDUCE_INIT_TASK_ID: + return TaskSignatureAndImpl{get_reduce_init_task_impl(), + get_reduce_init_signature()}; + case task_id_t::REDUCE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reduce_fwd_task_impl(), + get_reduce_fwd_signature()}; + case task_id_t::REDUCE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reduce_bwd_task_impl(), + get_reduce_bwd_signature()}; + case task_id_t::RESHAPE_INIT_TASK_ID: + return TaskSignatureAndImpl{get_reshape_init_task_impl(), + get_reshape_init_signature()}; + case task_id_t::RESHAPE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reshape_fwd_task_impl(), + get_reshape_fwd_signature()}; + case task_id_t::RESHAPE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reshape_bwd_task_impl(), + get_reshape_bwd_signature()}; + case task_id_t::REVERSE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reverse_fwd_task_impl(), + get_reverse_fwd_signature()}; + case task_id_t::REVERSE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reverse_bwd_task_impl(), + get_reverse_bwd_signature()}; + case task_id_t::TOPK_INIT_TASK_ID: + return TaskSignatureAndImpl{get_topk_init_task_impl(), + get_topk_init_signature()}; + case task_id_t::TOPK_FWD_TASK_ID: + return TaskSignatureAndImpl{get_topk_fwd_task_impl(), + get_topk_fwd_signature()}; + case task_id_t::TOPK_BWD_TASK_ID: + return TaskSignatureAndImpl{get_topk_bwd_task_impl(), + get_topk_bwd_signature()}; + case task_id_t::TRANSPOSE_INIT_TASK_ID: + return TaskSignatureAndImpl{get_transpose_init_task_impl(), + get_transpose_init_signature()}; + case task_id_t::TRANSPOSE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_transpose_fwd_task_impl(), + get_transpose_fwd_signature()}; + case task_id_t::TRANSPOSE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_transpose_bwd_task_impl(), + get_transpose_bwd_signature()}; + case task_id_t::ATTENTION_INIT_TASK_ID: + return TaskSignatureAndImpl{get_attention_init_task_impl(), + get_attention_init_signature()}; + case task_id_t::ATTENTION_FWD_TASK_ID: + return TaskSignatureAndImpl{get_attention_fwd_task_impl(), + get_attention_fwd_signature()}; + case task_id_t::ATTENTION_BWD_TASK_ID: + return TaskSignatureAndImpl{get_attention_bwd_task_impl(), + get_attention_bwd_signature()}; + case task_id_t::COMBINE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_combine_fwd_task_impl(), + get_combine_fwd_signature()}; + case task_id_t::COMBINE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_combine_bwd_task_impl(), + get_combine_bwd_signature()}; + case task_id_t::REDUCTION_FWD_TASK_ID: + return TaskSignatureAndImpl{get_reduction_fwd_task_impl(), + get_reduction_fwd_signature()}; + case task_id_t::REDUCTION_BWD_TASK_ID: + return TaskSignatureAndImpl{get_reduction_bwd_task_impl(), + get_reduction_bwd_signature()}; + case task_id_t::REPARTITION_INIT_TASK_ID: + return TaskSignatureAndImpl{get_repartition_init_task_impl(), + get_repartition_init_signature()}; + case task_id_t::REPARTITION_FWD_TASK_ID: + return TaskSignatureAndImpl{get_repartition_fwd_task_impl(), + get_repartition_fwd_signature()}; + case task_id_t::REPARTITION_BWD_TASK_ID: + return TaskSignatureAndImpl{get_repartition_bwd_task_impl(), + get_repartition_bwd_signature()}; + case task_id_t::REPLICATE_FWD_TASK_ID: + return TaskSignatureAndImpl{get_replicate_fwd_task_impl(), + get_replicate_fwd_signature()}; + case task_id_t::REPLICATE_BWD_TASK_ID: + return TaskSignatureAndImpl{get_replicate_bwd_task_impl(), + get_replicate_bwd_signature()}; default: throw mk_runtime_error( fmt::format("Invalid task ID")); // inserting task_id yields // "type_is_unformattable" error } } + +std::vector get_task_ids(ComputationGraphOpAttrs const &op) { + return op.visit>(overload{ + [](BatchMatmulAttrs const &attrs) { return get_task_ids(attrs); }, + [](BatchNormAttrs const &attrs) { return get_task_ids(attrs); }, + [](CastAttrs const &attrs) { return get_task_ids(attrs); }, + [](ConcatAttrs const &attrs) { return get_task_ids(attrs); }, + [](Conv2DAttrs const &attrs) { return get_task_ids(attrs); }, + [](DropoutAttrs const &attrs) { return get_task_ids(attrs); }, + [](ElementBinaryAttrs const &attrs) { return get_task_ids(attrs); }, + [](ElementUnaryAttrs const &attrs) { return get_task_ids(attrs); }, + // [](EmbeddingAttrs const & attrs) { + // return get_task_ids(attrs); + // }, + [](FlatAttrs const &attrs) { return get_task_ids(attrs); }, + [](GatherAttrs const &attrs) { return get_task_ids(attrs); }, + [](InputAttrs const &attrs) { return get_task_ids(attrs); }, + [](LayerNormAttrs const &attrs) { return get_task_ids(attrs); }, + [](LinearAttrs const &attrs) { return get_task_ids(attrs); }, + [](MultiHeadAttentionAttrs const &attrs) { return get_task_ids(attrs); }, + [](NoopAttrs const &attrs) { return get_task_ids(attrs); }, + [](Pool2DAttrs const &attrs) { return get_task_ids(attrs); }, + [](ReduceAttrs const &attrs) { return get_task_ids(attrs); }, + [](ReverseAttrs const &attrs) { return get_task_ids(attrs); }, + [](ReshapeAttrs const &attrs) { return get_task_ids(attrs); }, + [](SplitAttrs const &attrs) { return get_task_ids(attrs); }, + [](SoftmaxAttrs const &attrs) { return get_task_ids(attrs); }, + [](TopKAttrs const &attrs) { return get_task_ids(attrs); }, + [](TransposeAttrs const &attrs) { return get_task_ids(attrs); }, + [](WeightAttrs const &attrs) { return get_task_ids(attrs); }, + [](auto const &attrs) -> std::vector { + throw mk_runtime_error(fmt::format("Unhandled attr type: {}", attrs)); + }, + }); +} + +OpTaskInvocation init(ComputationGraphOpAttrs const &op) { + return op.visit(overload{ + [](BatchNormAttrs const &attrs) { return init(attrs); }, + [](Conv2DAttrs const &attrs) { return init(attrs); }, + [](DropoutAttrs const &attrs) { return init(attrs); }, + [](ElementBinaryAttrs const &attrs) { return init(attrs); }, + [](ElementUnaryAttrs const &attrs) { return init(attrs); }, + [](GatherAttrs const &attrs) { return init(attrs); }, + [](LayerNormAttrs const &attrs) { return init(attrs); }, + [](LinearAttrs const &attrs) { return init(attrs); }, + [](MultiHeadAttentionAttrs const &attrs) { return init(attrs); }, + [](Pool2DAttrs const &attrs) { return init(attrs); }, + [](ReduceAttrs const &attrs) { return init(attrs); }, + [](ReshapeAttrs const &attrs) { return init(attrs); }, + [](SoftmaxAttrs const &attrs) { return init(attrs); }, + [](TopKAttrs const &attrs) { return init(attrs); }, + [](TransposeAttrs const &attrs) { return init(attrs); }, + [](auto const &attrs) -> OpTaskInvocation { + throw mk_runtime_error(fmt::format("Unhandled attr type {}", attrs)); + }, + }); +} + +OpTaskInvocation forward(ComputationGraphOpAttrs const &op) { + return op.visit(overload{ + [](BatchMatmulAttrs const &attrs) { return forward(attrs); }, + [](BatchNormAttrs const &attrs) { return forward(attrs); }, + [](CastAttrs const &attrs) { return forward(attrs); }, + [](ConcatAttrs const &attrs) { return forward(attrs); }, + [](Conv2DAttrs const &attrs) { return forward(attrs); }, + [](DropoutAttrs const &attrs) { return forward(attrs); }, + [](ElementBinaryAttrs const &attrs) { return forward(attrs); }, + [](ElementUnaryAttrs const &attrs) { return forward(attrs); }, + // [](EmbeddingAttrs const & attrs) { + // return forward(attrs); + // }, + [](FlatAttrs const &attrs) { return forward(attrs); }, + [](GatherAttrs const &attrs) { return forward(attrs); }, + [](LayerNormAttrs const &attrs) { return forward(attrs); }, + [](LinearAttrs const &attrs) { return forward(attrs); }, + [](MultiHeadAttentionAttrs const &attrs) { return forward(attrs); }, + [](Pool2DAttrs const &attrs) { return forward(attrs); }, + [](ReduceAttrs const &attrs) { return forward(attrs); }, + [](ReverseAttrs const &attrs) { return forward(attrs); }, + [](ReshapeAttrs const &attrs) { return forward(attrs); }, + [](SplitAttrs const &attrs) { return forward(attrs); }, + [](SoftmaxAttrs const &attrs) { return forward(attrs); }, + [](TopKAttrs const &attrs) { return forward(attrs); }, + [](TransposeAttrs const &attrs) { return forward(attrs); }, + [](auto const &attrs) -> OpTaskInvocation { + throw mk_runtime_error(fmt::format("Unhandled attr type {}", attrs)); + }, + }); +} + +OpTaskInvocation backward(ComputationGraphOpAttrs const &op) { + return op.visit(overload{ + [](BatchMatmulAttrs const &attrs) { return backward(attrs); }, + [](BatchNormAttrs const &attrs) { return backward(attrs); }, + [](CastAttrs const &attrs) { return backward(attrs); }, + [](ConcatAttrs const &attrs) { return backward(attrs); }, + [](Conv2DAttrs const &attrs) { return backward(attrs); }, + [](DropoutAttrs const &attrs) { return backward(attrs); }, + [](ElementBinaryAttrs const &attrs) { return backward(attrs); }, + [](ElementUnaryAttrs const &attrs) { return backward(attrs); }, + // [](EmbeddingAttrs const & attrs) { + // return backward(attrs); + // }, + [](FlatAttrs const &attrs) { return backward(attrs); }, + [](GatherAttrs const &attrs) { return backward(attrs); }, + [](LayerNormAttrs const &attrs) { return backward(attrs); }, + [](LinearAttrs const &attrs) { return backward(attrs); }, + [](MultiHeadAttentionAttrs const &attrs) { return backward(attrs); }, + [](Pool2DAttrs const &attrs) { return backward(attrs); }, + [](ReduceAttrs const &attrs) { return backward(attrs); }, + [](ReverseAttrs const &attrs) { return backward(attrs); }, + [](ReshapeAttrs const &attrs) { return backward(attrs); }, + [](SplitAttrs const &attrs) { return backward(attrs); }, + [](SoftmaxAttrs const &attrs) { return backward(attrs); }, + [](TopKAttrs const &attrs) { return backward(attrs); }, + [](TransposeAttrs const &attrs) { return backward(attrs); }, + [](auto const &attrs) -> OpTaskInvocation { + throw mk_runtime_error(fmt::format("Unhandled attr type {}", attrs)); + }, + }); +} + } // namespace FlexFlow diff --git a/lib/local-execution/src/tracked_allocator.cc b/lib/local-execution/src/tracked_allocator.cc index 68636906c3..e6c3a11711 100644 --- a/lib/local-execution/src/tracked_allocator.cc +++ b/lib/local-execution/src/tracked_allocator.cc @@ -7,13 +7,14 @@ TrackedAllocator::TrackedAllocator(Allocator a) : allocator(a) {} void *TrackedAllocator::allocate(size_t requested_memory_size) { void *ptr = this->allocator.allocate(requested_memory_size); + this->ptr_mem_usage.insert({ptr, requested_memory_size}); this->current_mem_usage += requested_memory_size; return ptr; } void TrackedAllocator::deallocate(void *ptr) { size_t psize; - checkCUDA(cudaGetSymbolSize(&psize, ptr)); + this->ptr_mem_usage.erase(ptr); this->allocator.deallocate(ptr); this->current_mem_usage -= psize; } diff --git a/lib/local-execution/test/CMakeLists.txt b/lib/local-execution/test/CMakeLists.txt new file mode 100644 index 0000000000..930ab5c4e2 --- /dev/null +++ b/lib/local-execution/test/CMakeLists.txt @@ -0,0 +1,14 @@ +ff_add_test_executable( + NAME + local-execution-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + doctest + utils-test-common + local-execution + kernels + op-attrs +) diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc new file mode 100644 index 0000000000..2bd0acc222 --- /dev/null +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -0,0 +1,77 @@ +#include "doctest/doctest.h" +#include "kernels/local_cuda_allocator.h" +#include "kernels/managed_per_device_ff_handle.h" +#include "local-execution/local_cost_estimator.h" +#include "pcg/computation_graph_builder.h" +#include "test_utils.h" + +namespace FlexFlow { + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("Local Cost Estimator") { + // local backing initialization + ManagedPerDeviceFFHandle managed_handle{}; + + RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{ + DeviceSpecific::create(managed_handle.raw_handle()), + EnableProfiling::YES, + ProfilingSettings{/*warmup_iters=*/0, + /*measure_iters=*/1}}; + + LocalCostEstimator cost_estimator = LocalCostEstimator{runtime_arg_config}; + + SUBCASE("Estimate cost -- Attention Op") { + int embed_dim = 32; + int num_heads = 10; + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + + size_t batch_size = 40; + size_t seq_len = 48; + size_t feature_size = 36; + + DataType dtype = DataType::FLOAT; + ParallelTensorShape inputs_shape = lift_to_parallel(TensorShape{ + TensorDims{FFOrdered{batch_size, seq_len, feature_size}}, + DataType::FLOAT, + }); + + ParallelTensorShape weights_shape = throw_if_unexpected( + get_weights_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); + ParallelTensorAttrs weight_attrs = + ParallelTensorAttrs{weights_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + ParallelTensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); + ParallelTensorAttrs output_attrs = + ParallelTensorAttrs{output_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + CostDetails result = cost_estimator.estimate_cost( + PCGOperatorAttrs{attrs}, + std::vector{ + inputs_shape, inputs_shape, inputs_shape}, + std::vector{weight_attrs}, + std::vector{output_attrs}, + make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1})); + + CHECK(result.total_elapsed_time > 0); + CHECK(result.total_mem_usage > 0); + } + } +} + +} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_slots_backing.cc b/lib/local-execution/test/src/test_local_slots_backing.cc new file mode 100644 index 0000000000..542aa66087 --- /dev/null +++ b/lib/local-execution/test/src/test_local_slots_backing.cc @@ -0,0 +1,273 @@ +#include "doctest/doctest.h" +#include "kernels/attention_kernels.h" +#include "local-execution/local_cost_estimator.h" +#include "local-execution/local_cpu_allocator.h" +#include "local-execution/local_slots_backing.h" +#include "pcg/computation_graph_builder.h" +#include "test_utils.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/variant.h" +#include "utils/fmt/vector.h" + +namespace FlexFlow { + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("LocalSlotsBacking -- Attention Op") { + // allocate input memory + Allocator allocator = create_local_cpu_memory_allocator(); + int embed_dim = 32; + int num_heads = 10; + + size_t batch_size = 40; + size_t seq_len = 48; + size_t feature_size = 36; + + DataType dtype = DataType::FLOAT; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, seq_len, feature_size}}, + DataType::FLOAT, + }; + TensorShape query_shape = input_tensor_shape; + TensorShape key_shape = input_tensor_shape; + TensorShape value_shape = input_tensor_shape; + GenericTensorAccessorW query = allocator.allocate_tensor(query_shape); + GenericTensorAccessorW key = allocator.allocate_tensor(key_shape); + GenericTensorAccessorW value = allocator.allocate_tensor(value_shape); + + // build graph + ComputationGraphBuilder cg_builder; + tensor_guid_t query_guid = + cg_builder.create_tensor(query_shape, CreateGrad::YES); + tensor_guid_t key_guid = + cg_builder.create_tensor(key_shape, CreateGrad::YES); + tensor_guid_t value_guid = + cg_builder.create_tensor(value_shape, CreateGrad::YES); + + std::string layer_name = "attn1"; + tensor_guid_t output_guid = + cg_builder.multihead_attention(query_guid, + key_guid, + value_guid, + embed_dim, + num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0f, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + /*initializer=*/std::nullopt, + /*maybe_name=*/layer_name); + + layer_guid_t layer_guid = + get_layer_by_name(cg_builder.computation_graph, layer_name); + + TensorBackingMap tensor_backing_map = { + {query_guid, query}, {key_guid, key}, {value_guid, value}}; + + // runtime arg config + ProfilingSettings settings = ProfilingSettings{/*warmup_iters=*/0, + /*measure_iters=*/0}; + PerDeviceFFHandle handle = get_mock_per_device_ff_handle(); + RuntimeArgConfig runtime_arg_config = + RuntimeArgConfig{DeviceSpecific::create(handle), + EnableProfiling::NO, + settings}; + + LocalSlotsBacking local_slots_backing = {tensor_backing_map, + runtime_arg_config}; + + SUBCASE("LocalSlotsBacking::allocate_outgoing_tensors") { + auto get_result_shape_and_dtype_for_tensor_guid_and_map = + [&](tensor_guid_t t, + TensorBackingMap m) -> std::pair { + GenericTensorAccessorW accessor = m.at(t); + return get_shape_and_datatype(accessor); + }; + + SUBCASE("Input (QKV) and gradient tensors allocation") { + + // allocate all tensors from input nodes + for (layer_guid_t const &node : + topological_ordering(cg_builder.computation_graph)) { + if (node == layer_guid) { + break; + } + local_slots_backing.allocate_outgoing_tensors( + node, cg_builder.computation_graph, allocator); + } + + SUBCASE("Query grad") { + std::pair result = + get_result_shape_and_dtype_for_tensor_guid_and_map( + query_guid, local_slots_backing.gradient_tensor_mapping); + std::pair correct = {ArrayShape{query_shape}, + dtype}; + CHECK(result == correct); + } + SUBCASE("Key grad") { + std::pair result = + get_result_shape_and_dtype_for_tensor_guid_and_map( + key_guid, local_slots_backing.gradient_tensor_mapping); + std::pair correct = {ArrayShape{key_shape}, + dtype}; + CHECK(result == correct); + } + SUBCASE("Value grad") { + std::pair result = + get_result_shape_and_dtype_for_tensor_guid_and_map( + value_guid, local_slots_backing.gradient_tensor_mapping); + std::pair correct = {ArrayShape{value_shape}, + dtype}; + CHECK(result == correct); + } + } + SUBCASE("Output and gradient tensors allocation") { + local_slots_backing.allocate_outgoing_tensors( + layer_guid, cg_builder.computation_graph, allocator); + SUBCASE("Output") { + std::pair result = + get_result_shape_and_dtype_for_tensor_guid_and_map( + output_guid, local_slots_backing.tensor_mapping); + std::pair correct = { + ArrayShape{ + get_tensor_attrs(cg_builder.computation_graph, output_guid) + .shape}, + dtype}; + CHECK(result == correct); + } + SUBCASE("Output grad") { + std::pair result = + get_result_shape_and_dtype_for_tensor_guid_and_map( + output_guid, local_slots_backing.gradient_tensor_mapping); + std::pair correct = { + ArrayShape{ + get_tensor_attrs(cg_builder.computation_graph, output_guid) + .shape}, + dtype}; + CHECK(result == correct); + } + } + + SUBCASE("Tensor slots") { + local_slots_backing.allocate_outgoing_tensors( + layer_guid, cg_builder.computation_graph, allocator); + SUBCASE("Input tensor slots") { + std::vector correct_incoming_tensors = + get_incoming_tensors(cg_builder.computation_graph, layer_guid); + CHECK(correct_incoming_tensors == + local_slots_backing.input_tensor_slots.at(layer_guid)); + } + SUBCASE("Output tensor slots") { + std::vector correct_outgoing_tensors = + get_outgoing_tensors(cg_builder.computation_graph, layer_guid); + CHECK(correct_outgoing_tensors == + local_slots_backing.output_tensor_slots.at(layer_guid)); + } + } + } + + SUBCASE("Construct Slots Backings") { + enum Slots { + QUERY, + KEY, + VALUE, + WEIGHTS, + OUTPUT, + QUERY_PARALLEL_TENSOR_SHAPE, + QPROJSIZE, + ATTRS, + PROFILING, + HANDLE, + }; + MultiHeadAttentionAttrs attrs = + get_layer_attrs(cg_builder.computation_graph, layer_guid) + .attrs.get(); + OpTaskBinding binding = [&] { + OpTaskBinding b; + b.bind(QUERY, input_tensor(0)); + b.bind(KEY, input_tensor(1)); + b.bind(VALUE, input_tensor(2)); + b.bind(WEIGHTS, weight_tensor(3)); + b.bind(OUTPUT, output_tensor(0)); + + b.bind_grad(QUERY, input_tensor(0)); + + b.bind_arg(QPROJSIZE, get_qProjSize(attrs)); + b.bind_arg(ATTRS, attrs); + b.bind_arg(QUERY_PARALLEL_TENSOR_SHAPE, input_parallel_tensor_shape(0)); + b.bind_arg(PROFILING, profiling_settings()); + b.bind_arg(HANDLE, ff_handle()); + return b; + }(); + + // allocate all incoming and outgoing tensors for graph + for (layer_guid_t const &node : + topological_ordering(cg_builder.computation_graph)) { + local_slots_backing.allocate_outgoing_tensors( + node, cg_builder.computation_graph, allocator); + } + + SUBCASE("LocalSlotsBacking::construct_tensor_slots_backing") { + TensorSlotsBackingWithoutAddresses result = + get_slots_backing_without_tensor_allocation_addresses( + local_slots_backing.construct_tensor_slots_backing(binding, + layer_guid)); + TensorSlotsBackingWithoutAddresses correct = [&] { + TensorShape weights_shape = throw_if_unexpected( + get_weights_shape(attrs, query_shape, key_shape, value_shape)); + GenericTensorAccessorW weights = + allocator.allocate_tensor(weights_shape); + + TensorAttrs output_attrs = + get_tensor_attrs(cg_builder.computation_graph, output_guid); + GenericTensorAccessorW output = + allocator.allocate_tensor(output_attrs.shape); + return get_slots_backing_without_tensor_allocation_addresses( + TensorSlotsBacking{ + {SlotGradId{slot_id_t{QUERY}, IsGrad::NO}, query}, + {SlotGradId{slot_id_t{KEY}, IsGrad::NO}, key}, + {SlotGradId{slot_id_t{VALUE}, IsGrad::NO}, value}, + {SlotGradId{slot_id_t{WEIGHTS}, IsGrad::NO}, weights}, + {SlotGradId{slot_id_t{OUTPUT}, IsGrad::NO}, output}, + {SlotGradId{slot_id_t{QUERY}, IsGrad::YES}, query}}); + }(); + + CHECK(result == correct); + } + SUBCASE("LocalSlotsBacking::construct_arg_slots_backing") { + ArgSlotsBacking result = + local_slots_backing.construct_arg_slots_backing(binding, + layer_guid); + + ArgSlotsBacking correct = [&] { + ParallelTensorShape query_parallel_tensor_shape = + lift_to_parallel(query_shape); + + return ArgSlotsBacking{ + {slot_id_t{QPROJSIZE}, + ConcreteArgSpec::create(get_qProjSize(attrs))}, + {slot_id_t{ATTRS}, ConcreteArgSpec::create(attrs)}, + {slot_id_t{QUERY_PARALLEL_TENSOR_SHAPE}, + ConcreteArgSpec::create(query_parallel_tensor_shape)}, + {slot_id_t{PROFILING}, + ConcreteArgSpec::create(runtime_arg_config.profiling_settings)}, + {slot_id_t{HANDLE}, ConcreteArgSpec::create(handle)}}; + }(); + + CHECK(result == correct); + } + + SUBCASE("LocalSlotsBacking::resolve_runtime_arg_ref_spec") { + RuntimeArgRefSpec ref_spec = RuntimeArgRefSpec::create(ff_handle()); + ConcreteArgSpec arg_spec = + local_slots_backing.resolve_runtime_arg_ref_spec(ref_spec); + + PerDeviceFFHandle result_handle = arg_spec.get(); + CHECK(result_handle == handle); + } + } + } +} + +} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_task_arg_accessor.cc b/lib/local-execution/test/src/test_local_task_arg_accessor.cc new file mode 100644 index 0000000000..819d773e00 --- /dev/null +++ b/lib/local-execution/test/src/test_local_task_arg_accessor.cc @@ -0,0 +1,143 @@ +#include "doctest/doctest.h" +#include "local-execution/local_cpu_allocator.h" +#include "local-execution/local_task_argument_accessor.h" +#include "local-execution/task_signature_impl.h" + +namespace FlexFlow { + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("LocalTaskArgumentAccessor") { + Allocator allocator = create_local_cpu_memory_allocator(); + int embed_dim = 32; + int num_heads = 10; + + size_t batch_size = 40; + size_t seq_len = 48; + size_t feature_size = 36; + + DataType dtype = DataType::FLOAT; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, seq_len, feature_size}}, + DataType::FLOAT, + }; + + GenericTensorAccessorW input = + allocator.allocate_tensor(input_tensor_shape); + GenericTensorAccessorW input_grad = + allocator.allocate_tensor(input_tensor_shape); + + std::vector variadic_tensors = {input, input}; + std::vector variadic_tensors_grad = {input_grad, + input_grad}; + + enum Slots { + INPUT, + VARIADIC_TENSORS, + }; + + TensorSlotsBacking tensor_slots_backing = { + {SlotGradId{slot_id_t{INPUT}, IsGrad::NO}, input}, + {SlotGradId{slot_id_t{INPUT}, IsGrad::YES}, input_grad}, + {SlotGradId{slot_id_t{VARIADIC_TENSORS}, IsGrad::NO}, variadic_tensors}, + {SlotGradId{slot_id_t{VARIADIC_TENSORS}, IsGrad::YES}, + variadic_tensors_grad}, + }; + + LocalTaskArgumentAccessor acc = {allocator, tensor_slots_backing, {}}; + + SUBCASE("get_tensor") { + SUBCASE("get_tensor(slot_id_t, Permissions::RO, IsGrad::NO)") { + GenericTensorAccessor correct = GenericTensorAccessor{ + read_only_accessor_from_write_accessor(input)}; + GenericTensorAccessor result = + acc.get_tensor(slot_id_t{INPUT}, Permissions::RO, IsGrad::NO); + CHECK(correct == result); + } + SUBCASE("get_tensor(slot_id_t, Permissions::RO, IsGrad::YES)") { + GenericTensorAccessor correct = GenericTensorAccessor{ + read_only_accessor_from_write_accessor(input_grad)}; + GenericTensorAccessor result = + acc.get_tensor(slot_id_t{INPUT}, Permissions::RO, IsGrad::YES); + CHECK(correct == result); + } + SUBCASE("get_tensor(slot_id_t, Permissions::WO, IsGrad::NO)") { + GenericTensorAccessor correct = GenericTensorAccessor{input}; + GenericTensorAccessor result = + acc.get_tensor(slot_id_t{INPUT}, Permissions::WO, IsGrad::NO); + CHECK(correct == result); + } + SUBCASE("get_tensor(slot_id_t, Permissions::WO, IsGrad::YES)") { + GenericTensorAccessor correct = GenericTensorAccessor{input_grad}; + GenericTensorAccessor result = + acc.get_tensor(slot_id_t{INPUT}, Permissions::WO, IsGrad::YES); + CHECK(correct == result); + } + SUBCASE("get_tensor(slot_id_t, Permissions::RW, IsGrad::NO)") { + GenericTensorAccessor correct = GenericTensorAccessor{input}; + GenericTensorAccessor result = + acc.get_tensor(slot_id_t{INPUT}, Permissions::RW, IsGrad::NO); + CHECK(correct == result); + } + SUBCASE("get_tensor(slot_id_t, Permissions::RW, IsGrad::YES)") { + GenericTensorAccessor correct = GenericTensorAccessor{input_grad}; + GenericTensorAccessor result = + acc.get_tensor(slot_id_t{INPUT}, Permissions::RW, IsGrad::YES); + CHECK(correct == result); + } + } + + SUBCASE("get_variadic_tensor") { + SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RO, IsGrad::NO)") { + VariadicGenericTensorAccessor correct = + VariadicGenericTensorAccessor{std::vector{ + read_only_accessor_from_write_accessor(variadic_tensors.at(0)), + read_only_accessor_from_write_accessor( + variadic_tensors.at(1))}}; + VariadicGenericTensorAccessor result = acc.get_variadic_tensor( + slot_id_t{VARIADIC_TENSORS}, Permissions::RO, IsGrad::NO); + CHECK(result == correct); + } + SUBCASE("get_variadic_tensor(slot_id_t, Permissions::RO, IsGrad::YES)") { + VariadicGenericTensorAccessor correct = + VariadicGenericTensorAccessor{std::vector{ + read_only_accessor_from_write_accessor( + variadic_tensors_grad.at(0)), + read_only_accessor_from_write_accessor( + variadic_tensors_grad.at(1))}}; + VariadicGenericTensorAccessor result = acc.get_variadic_tensor( + slot_id_t{VARIADIC_TENSORS}, Permissions::RO, IsGrad::YES); + CHECK(result == correct); + } + SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, IsGrad::NO)") { + VariadicGenericTensorAccessor correct = + VariadicGenericTensorAccessor{variadic_tensors}; + VariadicGenericTensorAccessor result = acc.get_variadic_tensor( + slot_id_t{VARIADIC_TENSORS}, Permissions::WO, IsGrad::NO); + CHECK(result == correct); + } + SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, IsGrad::YES)") { + VariadicGenericTensorAccessor correct = + VariadicGenericTensorAccessor{variadic_tensors_grad}; + VariadicGenericTensorAccessor result = acc.get_variadic_tensor( + slot_id_t{VARIADIC_TENSORS}, Permissions::WO, IsGrad::YES); + CHECK(result == correct); + } + SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, IsGrad::NO)") { + VariadicGenericTensorAccessor correct = + VariadicGenericTensorAccessor{variadic_tensors}; + VariadicGenericTensorAccessor result = acc.get_variadic_tensor( + slot_id_t{VARIADIC_TENSORS}, Permissions::RW, IsGrad::NO); + CHECK(result == correct); + } + SUBCASE("get_variadic_tensor(slot_id_t, Permissions::WO, IsGrad::YES)") { + VariadicGenericTensorAccessor correct = + VariadicGenericTensorAccessor{variadic_tensors_grad}; + VariadicGenericTensorAccessor result = acc.get_variadic_tensor( + slot_id_t{VARIADIC_TENSORS}, Permissions::RW, IsGrad::YES); + CHECK(result == correct); + } + } + } +} + +} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_task_registry.cc b/lib/local-execution/test/src/test_task_registry.cc new file mode 100644 index 0000000000..fa3b068425 --- /dev/null +++ b/lib/local-execution/test/src/test_task_registry.cc @@ -0,0 +1,131 @@ +#include "doctest/doctest.h" +#include "kernels/local_cuda_allocator.h" +#include "local-execution/local_cost_estimator.h" +#include "local-execution/ops/attention.h" +#include "local-execution/task_signature_impl.h" +#include "pcg/computation_graph_builder.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/unordered_map.h" + +namespace FlexFlow { + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Task Registry") { + TaskRegistry task_registry = empty_task_registry(); + + layer_guid_t layer_guid = layer_guid_t{Node{0}}; + int embed_dim = 32; + int num_heads = 10; + ComputationGraphOpAttrs attrs = + ComputationGraphOpAttrs{MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }}; + + SUBCASE("register single layer") { + register_tasks_for_layer(task_registry, layer_guid, attrs); + + TaskRegistry correct_task_registry = [&] { + std::unordered_map> + init_task_ids = {{layer_guid, task_id_t::ATTENTION_INIT_TASK_ID}}; + std::unordered_map> + fwd_task_ids = {{layer_guid, task_id_t::ATTENTION_FWD_TASK_ID}}; + std::unordered_map> + bwd_task_ids = {{layer_guid, task_id_t::ATTENTION_BWD_TASK_ID}}; + std::unordered_map task_mapping = { + {task_id_t::ATTENTION_INIT_TASK_ID, + get_task_sig_impl(task_id_t::ATTENTION_INIT_TASK_ID)}, + {task_id_t::ATTENTION_FWD_TASK_ID, + get_task_sig_impl(task_id_t::ATTENTION_FWD_TASK_ID)}, + {task_id_t::ATTENTION_BWD_TASK_ID, + get_task_sig_impl(task_id_t::ATTENTION_BWD_TASK_ID)}}; + return TaskRegistry{ + init_task_ids, fwd_task_ids, bwd_task_ids, task_mapping}; + }(); + + CHECK(task_registry == correct_task_registry); + } + + SUBCASE("multiple layers same task") { + layer_guid_t other_layer_guid = layer_guid_t{Node{1}}; + register_tasks_for_layer(task_registry, layer_guid, attrs); + register_tasks_for_layer(task_registry, other_layer_guid, attrs); + + SUBCASE("layer to task ids") { + std::unordered_map> correct = { + {layer_guid, task_id_t::ATTENTION_INIT_TASK_ID}, + {other_layer_guid, task_id_t::ATTENTION_INIT_TASK_ID}, + }; + CHECK(correct == task_registry.init_task_ids); + } + + std::unordered_map correct_task_mapping = + {{task_id_t::ATTENTION_INIT_TASK_ID, + get_task_sig_impl(task_id_t::ATTENTION_INIT_TASK_ID)}, + {task_id_t::ATTENTION_FWD_TASK_ID, + get_task_sig_impl(task_id_t::ATTENTION_FWD_TASK_ID)}, + {task_id_t::ATTENTION_BWD_TASK_ID, + get_task_sig_impl(task_id_t::ATTENTION_BWD_TASK_ID)}}; + SUBCASE("task to signature+impl mapping") { + CHECK(correct_task_mapping == task_registry.task_mapping); + } + SUBCASE("different attrs, still same task fn mapping") { + int embed_dim = 100; + layer_guid_t layer_3 = layer_guid_t{Node{3}}; + ComputationGraphOpAttrs other_attrs = + ComputationGraphOpAttrs{MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }}; + register_tasks_for_layer(task_registry, layer_3, other_attrs); + + CHECK(correct_task_mapping == task_registry.task_mapping); + } + } + + SUBCASE("equality") { + TaskRegistry other_task_registry = empty_task_registry(); + SUBCASE("different attrs is still equal") { + int embed_dim = 100; + ComputationGraphOpAttrs other_attrs = + ComputationGraphOpAttrs{MultiHeadAttentionAttrs{ + /*embed_dim=*/embed_dim, + /*num_heads=*/num_heads, + /*kdim=*/embed_dim, + /*vdim=*/embed_dim, + /*dropout=*/0.0, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }}; + + register_tasks_for_layer(task_registry, layer_guid, attrs); + register_tasks_for_layer(other_task_registry, layer_guid, other_attrs); + + CHECK(task_registry == other_task_registry); + } + + SUBCASE("different layer_guid is not equal") { + register_tasks_for_layer(task_registry, layer_guid, attrs); + layer_guid_t other_layer_guid = layer_guid_t{Node{1}}; + register_tasks_for_layer(other_task_registry, other_layer_guid, attrs); + + CHECK(task_registry != other_task_registry); + } + } + } +} + +} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_utils.cc b/lib/local-execution/test/src/test_utils.cc new file mode 100644 index 0000000000..095e1272a2 --- /dev/null +++ b/lib/local-execution/test/src/test_utils.cc @@ -0,0 +1,9 @@ +#include "test_utils.h" + +namespace FlexFlow { + +PerDeviceFFHandle get_mock_per_device_ff_handle() { + return {nullptr, nullptr, nullptr, 0, false}; +} + +} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_utils.h b/lib/local-execution/test/src/test_utils.h new file mode 100644 index 0000000000..9a7b3f5991 --- /dev/null +++ b/lib/local-execution/test/src/test_utils.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LOCAL_EXECUTION_TEST_UTILS +#define _FLEXFLOW_LOCAL_EXECUTION_TEST_UTILS + +#include "kernels/ff_handle.h" + +namespace FlexFlow { + +PerDeviceFFHandle get_mock_per_device_ff_handle(); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 3e4095eca8..036daa6e67 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -43,28 +43,52 @@ int get_vSize(TensorShape const &value_shape) { return dim_at_idx(value_shape, ff_dim_t(0)); } -int get_qSize(MultiHeadAttentionParallelInputs const &) { - NOT_IMPLEMENTED(); +int get_qSize(MultiHeadAttentionParallelInputs const &inputs) { + return inputs.query_dim.size; } -int get_qSize(MultiHeadAttentionInputs const &) { - NOT_IMPLEMENTED(); +int get_qSize(MultiHeadAttentionInputs const &inputs) { + return inputs.query_size; } -int get_kSize(MultiHeadAttentionParallelInputs const &) { - NOT_IMPLEMENTED(); +int get_kSize(MultiHeadAttentionParallelInputs const &inputs) { + return inputs.key_dim.size; } -int get_kSize(MultiHeadAttentionInputs const &) { - NOT_IMPLEMENTED(); +int get_kSize(MultiHeadAttentionInputs const &inputs) { + return inputs.key_size; } -int get_vSize(MultiHeadAttentionParallelInputs const &) { - NOT_IMPLEMENTED(); +int get_vSize(MultiHeadAttentionParallelInputs const &inputs) { + return inputs.value_dim.size; } -int get_vSize(MultiHeadAttentionInputs const &) { - NOT_IMPLEMENTED(); +int get_vSize(MultiHeadAttentionInputs const &inputs) { + return inputs.value_size; +} + +int get_kvSeqLength(MultiHeadAttentionParallelInputs const &inputs) { + return inputs.sequence_dim.size; +} + +int get_kvSeqLength(MultiHeadAttentionInputs const &inputs) { + return inputs.sequence_length; +} + +int get_qoSeqLength(MultiHeadAttentionParallelInputs const &inputs) { + return inputs.sequence_dim.size; // FIXME -- assumes only prefill +} + +int get_qoSeqLength(MultiHeadAttentionInputs const &inputs) { + return inputs.sequence_length; // FIXME -- assumes only prefil +} + +int get_num_samples(MultiHeadAttentionParallelInputs const &inputs) { + return inputs.batch_dim.size; +} + +int get_num_samples(MultiHeadAttentionInputs const &inputs) { + return inputs.batch_size; } tl::expected diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index cdb20c2303..46d5b22afb 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -29,6 +29,9 @@ std::vector get_incoming_tensors(ComputationGraph const &cg, LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); +layer_guid_t get_layer_by_name(ComputationGraph const &cg, + std::string const &name); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index c52ec2d5bb..0e453c50c3 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -168,10 +168,11 @@ struct ComputationGraphBuilder { DataType dtype, std::optional const &name = std::nullopt); // Add a concat layer - tensor_guid_t concat(int n, - std::vector const &tensors, - int axis, - std::optional const &name = std::nullopt); + tensor_guid_t + concat(int n, + std::vector const &tensors, + int axis, + std::optional const &maybe_name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, @@ -224,7 +225,7 @@ struct ComputationGraphBuilder { bool add_bias_kv = false, bool add_zero_attn = false, std::optional initializer = std::nullopt, - std::optional const &name = std::nullopt); + std::optional const &maybe_name = std::nullopt); tensor_guid_t create_tensor(TensorShape const &, CreateGrad); tensor_guid_t create_weight( TensorShape const &, diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 43eb3ac42b..afa1774858 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,4 +1,5 @@ #include "pcg/computation_graph.h" +#include "utils/containers/get_only.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -54,4 +55,13 @@ LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n) { return cg.raw_graph.at(n.raw_node); } +layer_guid_t get_layer_by_name(ComputationGraph const &cg, + std::string const &name) { + std::unordered_set found = + filter(get_layers(cg), [&](layer_guid_t const &l) { + return get_layer_attrs(cg, l).name == name; + }); + return get_only(found); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 1dbe191970..5028ed5709 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -502,6 +502,47 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( return this->add_layer(layer, {input}, {}, output_shape); } +tensor_guid_t ComputationGraphBuilder::multihead_attention( + tensor_guid_t const &query, + tensor_guid_t const &key, + tensor_guid_t const &value, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool bias, + bool add_bias_kv, + bool add_zero_attn, + std::optional initializer, + std::optional const &maybe_name) { + + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{embed_dim, + num_heads, + kdim, + vdim, + dropout, + bias, + add_bias_kv, + add_zero_attn}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = throw_if_unexpected(get_output_shape( + attrs, get_shape(query), get_shape(key), get_shape(value))); + + TensorShape weights_shape = throw_if_unexpected(get_weights_shape( + attrs, get_shape(query), get_shape(key), get_shape(value))); + TensorAttrs weight_attrs = make_weight_attrs(weights_shape, initializer); + + return this->add_layer(layer, + std::vector{query, key, value}, + {weight_attrs}, + output_shape); +} + TensorShape ComputationGraphBuilder::get_broadcast_target_shape( std::vector const &) { NOT_IMPLEMENTED(); diff --git a/lib/local-execution/src/ops/embedding.cc b/lib/runtime/src/ops/embedding.cc similarity index 100% rename from lib/local-execution/src/ops/embedding.cc rename to lib/runtime/src/ops/embedding.cc diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h index db82004317..9eb717b066 100644 --- a/lib/utils/include/utils/join_strings.h +++ b/lib/utils/include/utils/join_strings.h @@ -18,7 +18,7 @@ std::string join_strings(InputIt first, if (!first_iter) { oss << delimiter; } - oss << *first; + oss << f(*first); /* break; */ first_iter = false; /* i++; */ From 1cfb07e1d6f87f1ba9a21babe0f1f5ca35070aee Mon Sep 17 00:00:00 2001 From: hsdfzhsdfz <157999659+hsdfzhsdfz@users.noreply.github.com> Date: Tue, 3 Sep 2024 14:55:08 -0400 Subject: [PATCH 02/12] Add Transformer Model PCG (#1453) * Add transformer pcg prototype * Update based on review comments * Add models makefile and test * Update * Pass test * Enhance Transformer implementation * Reflect review comments * [WIP] Save initial refactor * Implement actual encorder decoder architecture * Remove duplicated definition * Update based on review * Update argument order * Implement most of the shape inference and ComputationGraphBuilder support * Fix bug in LayerNorm shape inference tests, disable {?} doctest default * Fix transformer test * Format * Few small fixes in respose to PR comments * Add asserts for transformer layer shapes * Make config default * Small style fixes and some additional docs --------- Co-authored-by: Colin Unger Co-authored-by: Colin Unger Co-authored-by: Reyna Abhyankar --- .github/workflows/per-lib-check.yml | 8 + .proj.toml | 4 +- cmake/doctestlib.cmake | 5 + lib/CMakeLists.txt | 1 + lib/kernels/include/kernels/accessor.h | 6 + lib/kernels/src/accessor.cc | 22 ++ .../test/src/test_local_task_arg_accessor.cc | 1 + lib/models/CMakeLists.txt | 16 + lib/models/include/models/transformer.h | 47 +++ .../models/transformer_config.struct.toml | 51 ++++ lib/models/src/models/transformer.cc | 165 +++++++++++ lib/models/test/CMakeLists.txt | 12 + lib/models/test/src/models/transformer.cc | 19 ++ .../computation_graph_op_attrs.variant.toml | 2 +- lib/op-attrs/include/op-attrs/dim_ordered.h | 1 + .../op-attrs/dim_ordered/ff_ordered_of.h | 15 + .../include/op-attrs/dim_ordered/get_idxs.h | 17 ++ .../include/op-attrs/dim_ordered/zip.h | 19 ++ lib/op-attrs/include/op-attrs/get_op_type.h | 2 +- .../include/op-attrs/get_output_shapes.h | 10 +- lib/op-attrs/include/op-attrs/ops/broadcast.h | 8 +- ...truct.toml => broadcast_attrs.struct.toml} | 4 +- lib/op-attrs/include/op-attrs/ops/dropout.h | 8 +- .../include/op-attrs/ops/layer_norm.h | 16 +- lib/op-attrs/include/op-attrs/ops/softmax.h | 10 +- lib/op-attrs/include/op-attrs/tensor_dims.h | 5 + lib/op-attrs/include/op-attrs/tensor_shape.h | 5 + lib/op-attrs/src/layer_norm.cc | 3 - .../src/op-attrs/dim_ordered/ff_ordered_of.cc | 1 + .../src/op-attrs/dim_ordered/get_idxs.cc | 1 + lib/op-attrs/src/op-attrs/dim_ordered/zip.cc | 1 + lib/op-attrs/src/op-attrs/ops/broadcast.cc | 33 +++ lib/op-attrs/src/op-attrs/ops/dropout.cc | 25 +- lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 173 ++++++++++- lib/op-attrs/src/op-attrs/ops/softmax.cc | 49 ++- lib/op-attrs/src/op-attrs/tensor_dims.cc | 35 +++ lib/op-attrs/src/op-attrs/tensor_shape.cc | 33 +++ .../test/src/op-attrs/dim_ordered/zip.cc | 41 +++ lib/op-attrs/test/src/op-attrs/ops/dropout.cc | 101 +++++++ .../test/src/op-attrs/ops/layer_norm.cc | 278 ++++++++++++++++++ lib/op-attrs/test/src/op-attrs/ops/softmax.cc | 138 +++++++++ lib/op-attrs/test/src/op-attrs/tensor_dims.cc | 121 ++++++++ .../test/src/op-attrs/tensor_shape.cc | 64 ++++ lib/pcg/include/pcg/computation_graph.h | 3 + .../include/pcg/computation_graph_builder.h | 6 +- lib/pcg/src/pcg/computation_graph.cc | 6 + lib/pcg/src/pcg/computation_graph_builder.cc | 203 +++++++++++-- lib/utils/CMakeLists.txt | 1 + lib/utils/include/utils/containers/any_of.h | 18 ++ .../include/utils/containers/transform.h | 13 + .../utils/containers/unordered_multiset_of.h | 15 + lib/utils/include/utils/expected.h | 18 +- lib/utils/include/utils/fmt/expected.h | 30 +- lib/utils/include/utils/fmt/map.h | 12 + lib/utils/include/utils/fmt/multiset.h | 56 ++++ lib/utils/include/utils/fmt/optional.h | 12 + lib/utils/include/utils/fmt/pair.h | 12 + lib/utils/include/utils/fmt/set.h | 12 + lib/utils/include/utils/fmt/unordered_map.h | 12 + .../include/utils/fmt/unordered_multiset.h | 56 ++++ lib/utils/include/utils/fmt/unordered_set.h | 12 + lib/utils/include/utils/fmt/variant.h | 12 + lib/utils/include/utils/fmt/vector.h | 12 + lib/utils/include/utils/graph/query_set.h | 1 + lib/utils/include/utils/optional.h | 12 - lib/utils/include/utils/stack_vector.h | 21 +- lib/utils/src/utils/containers/any_of.cc | 1 + .../utils/containers/unordered_multiset_of.cc | 1 + lib/utils/src/utils/expected.cc | 1 + lib/utils/src/utils/fmt/multiset.cc | 1 + lib/utils/src/utils/fmt/unordered_multiset.cc | 1 + .../test/common/include/test/utils/doctest.h | 13 +- lib/utils/test/src/utils/bidict/bidict.cc | 4 +- .../bidict/try_merge_nondisjoint_bidicts.cc | 1 + lib/utils/test/src/utils/containers/any_of.cc | 36 +++ lib/utils/test/src/utils/containers/extend.cc | 2 + lib/utils/test/src/utils/containers/filter.cc | 5 + .../src/utils/containers/filtermap_keys.cc | 2 + .../src/utils/containers/filtermap_values.cc | 2 + .../utils/containers/get_element_counts.cc | 1 + .../src/utils/containers/inplace_filter.cc | 5 + .../test/src/utils/containers/intersection.cc | 2 + lib/utils/test/src/utils/containers/repeat.cc | 1 + .../test/src/utils/containers/transform.cc | 25 ++ .../try_merge_nondisjoint_unordered_maps.cc | 2 + .../utils/containers/unordered_multiset_of.cc | 15 + .../src/utils/containers/unordered_set_of.cc | 1 + .../src/utils/containers/without_order.cc | 1 + lib/utils/test/src/utils/expected.cc | 29 ++ lib/utils/test/src/utils/fmt/expected.cc | 20 ++ lib/utils/test/src/utils/fmt/unordered_set.cc | 13 +- ...plete_bipartite_composite_decomposition.cc | 1 + .../intermediate_sp_decomposition_tree.cc | 1 + .../serial_parallel_decomposition.cc | 1 + 94 files changed, 2218 insertions(+), 101 deletions(-) create mode 100644 lib/models/CMakeLists.txt create mode 100644 lib/models/include/models/transformer.h create mode 100644 lib/models/include/models/transformer_config.struct.toml create mode 100644 lib/models/src/models/transformer.cc create mode 100644 lib/models/test/CMakeLists.txt create mode 100644 lib/models/test/src/models/transformer.cc create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/zip.h rename lib/op-attrs/include/op-attrs/ops/{broadcast.struct.toml => broadcast_attrs.struct.toml} (68%) delete mode 100644 lib/op-attrs/src/layer_norm.cc create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_of.cc create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/zip.cc create mode 100644 lib/op-attrs/src/op-attrs/ops/broadcast.cc create mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/dropout.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/softmax.cc create mode 100644 lib/op-attrs/test/src/op-attrs/tensor_dims.cc create mode 100644 lib/op-attrs/test/src/op-attrs/tensor_shape.cc create mode 100644 lib/utils/include/utils/containers/any_of.h create mode 100644 lib/utils/include/utils/containers/unordered_multiset_of.h create mode 100644 lib/utils/include/utils/fmt/multiset.h create mode 100644 lib/utils/include/utils/fmt/unordered_multiset.h create mode 100644 lib/utils/src/utils/containers/any_of.cc create mode 100644 lib/utils/src/utils/containers/unordered_multiset_of.cc create mode 100644 lib/utils/src/utils/expected.cc create mode 100644 lib/utils/src/utils/fmt/multiset.cc create mode 100644 lib/utils/src/utils/fmt/unordered_multiset.cc create mode 100644 lib/utils/test/src/utils/containers/any_of.cc create mode 100644 lib/utils/test/src/utils/containers/unordered_multiset_of.cc create mode 100644 lib/utils/test/src/utils/expected.cc diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 38556a3c0e..639f4d82b5 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -92,6 +92,10 @@ jobs: run: | build_libs.sh local-execution + - name: Build models + run: | + build_libs.sh models + - name: Test utils run: | test_libs.sh utils @@ -120,6 +124,10 @@ jobs: run: | test_libs.sh local-execution + - name: Test models + run: | + test_libs.sh models + - name: Generate code coverage run: | echo "gitwork: $GITHUB_WORKSPACE" diff --git a/.proj.toml b/.proj.toml index ee91d07833..721d212e31 100644 --- a/.proj.toml +++ b/.proj.toml @@ -12,6 +12,7 @@ build_targets = [ "compiler", "substitution-generator", "local-execution", + "models", ] test_targets = [ @@ -22,7 +23,8 @@ test_targets = [ "substitutions-tests", "compiler-tests", "substitution-generator-tests", - "local-execution-tests" + "local-execution-tests", + "models-tests", ] [cmake_flags_extra] diff --git a/cmake/doctestlib.cmake b/cmake/doctestlib.cmake index 5f29d94fd0..ed78bf968b 100644 --- a/cmake/doctestlib.cmake +++ b/cmake/doctestlib.cmake @@ -8,4 +8,9 @@ else() include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake) endif() +target_compile_definitions( + doctest::doctest + INTERFACE + DOCTEST_CONFIG_REQUIRE_STRINGIFICATION_FOR_ALL_USED_TYPES +) alias_library(doctest doctest::doctest) diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index b3735ed19f..972c656126 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(utils) add_subdirectory(ffi) add_subdirectory(substitutions) add_subdirectory(substitution-generator) +add_subdirectory(models) diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 2ee081ecbc..d3221474c0 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -39,6 +39,9 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(GenericTensorAccessorW, shape, ptr); +std::string format_as(GenericTensorAccessorW const &); +std::ostream &operator<<(std::ostream &, GenericTensorAccessorW const &); + class GenericTensorAccessorR { public: template @@ -67,6 +70,9 @@ FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(GenericTensorAccessorR, shape, ptr); +std::string format_as(GenericTensorAccessorR const &); +std::ostream &operator<<(std::ostream &, GenericTensorAccessorR const &); + int32_t *get_int32_ptr(GenericTensorAccessorW const &); int64_t *get_int64_ptr(GenericTensorAccessorW const &); float *get_float_ptr(GenericTensorAccessorW const &); diff --git a/lib/kernels/src/accessor.cc b/lib/kernels/src/accessor.cc index a852f0d7b3..27b7eb390d 100644 --- a/lib/kernels/src/accessor.cc +++ b/lib/kernels/src/accessor.cc @@ -22,6 +22,17 @@ half *GenericTensorAccessorW::get_half_ptr() const { return this->get(); } +std::string format_as(GenericTensorAccessorW const &a) { + return fmt::format("", + a.data_type, + a.shape, + a.ptr); +} + +std::ostream &operator<<(std::ostream &s, GenericTensorAccessorW const &a) { + return (s << fmt::to_string(a)); +} + int32_t const *GenericTensorAccessorR::get_int32_ptr() const { return this->get(); } @@ -42,6 +53,17 @@ half const *GenericTensorAccessorR::get_half_ptr() const { return get(); } +std::string format_as(GenericTensorAccessorR const &a) { + return fmt::format("", + a.data_type, + a.shape, + a.ptr); +} + +std::ostream &operator<<(std::ostream &s, GenericTensorAccessorR const &a) { + return (s << fmt::to_string(a)); +} + int32_t *get_int32_ptr(GenericTensorAccessorW const &a) { return get(a); } diff --git a/lib/local-execution/test/src/test_local_task_arg_accessor.cc b/lib/local-execution/test/src/test_local_task_arg_accessor.cc index 819d773e00..0637faaf1c 100644 --- a/lib/local-execution/test/src/test_local_task_arg_accessor.cc +++ b/lib/local-execution/test/src/test_local_task_arg_accessor.cc @@ -2,6 +2,7 @@ #include "local-execution/local_cpu_allocator.h" #include "local-execution/local_task_argument_accessor.h" #include "local-execution/task_signature_impl.h" +#include "utils/fmt/variant.h" namespace FlexFlow { diff --git a/lib/models/CMakeLists.txt b/lib/models/CMakeLists.txt new file mode 100644 index 0000000000..7dd7f48700 --- /dev/null +++ b/lib/models/CMakeLists.txt @@ -0,0 +1,16 @@ +ff_add_library( + NAME + models + SRC_PATTERNS + src/*.cc + PUBLIC_INCLUDE + include/ + PRIVATE_INCLUDE + src/ + DEPS + op-attrs + utils + pcg +) + +add_subdirectory(test) \ No newline at end of file diff --git a/lib/models/include/models/transformer.h b/lib/models/include/models/transformer.h new file mode 100644 index 0000000000..e50fa37709 --- /dev/null +++ b/lib/models/include/models/transformer.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H + +#include "models/transformer_config.dtg.h" +#include "pcg/computation_graph_builder.h" + +namespace FlexFlow { + +// Helper functions to construct the Transformer model +tensor_guid_t create_transformer_feedforward_network(ComputationGraphBuilder &, + TransformerConfig const &, + tensor_guid_t const &); +tensor_guid_t create_transformer_encoder_layer(ComputationGraphBuilder &, + TransformerConfig const &, + tensor_guid_t const &); +tensor_guid_t create_transformer_decoder_layer(ComputationGraphBuilder &, + TransformerConfig const &, + tensor_guid_t const &, + tensor_guid_t const &); + +tensor_guid_t create_transformer_encoder(ComputationGraphBuilder &, + TransformerConfig const &, + tensor_guid_t const &); +tensor_guid_t create_transformer_decoder(ComputationGraphBuilder &, + TransformerConfig const &, + tensor_guid_t const &, + tensor_guid_t const &); + +/** + * @brief Get the base config from the Attention Is All You Need paper. + * + * @details See the first row of the Table 3 at the top of p. 9 in + * https://arxiv.org/abs/1706.03762 + */ +TransformerConfig get_default_transformer_config(); + +/** + * @brief Get the Transformer computation graph. + * + * @param TransformerConfig The config of Transformer model. + * @return ComputationGraph The PCG of a Transformer model. + */ +ComputationGraph get_transformer_computation_graph(TransformerConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/transformer_config.struct.toml b/lib/models/include/models/transformer_config.struct.toml new file mode 100644 index 0000000000..23b0478dde --- /dev/null +++ b/lib/models/include/models/transformer_config.struct.toml @@ -0,0 +1,51 @@ +namespace = "FlexFlow" +name = "TransformerConfig" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_features" +type = "size_t" + +[[fields]] +name = "sequence_length" +type = "size_t" + +[[fields]] +name = "batch_size" +type = "size_t" + +[[fields]] +name = "dim_feedforward" +type = "size_t" + +[[fields]] +name = "num_heads" +type = "size_t" + +[[fields]] +name = "num_encoder_layers" +type = "size_t" + +[[fields]] +name = "num_decoder_layers" +type = "size_t" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "layer_norm_eps" +type = "float" + +[[fields]] +name = "vocab_size" +type = "size_t" diff --git a/lib/models/src/models/transformer.cc b/lib/models/src/models/transformer.cc new file mode 100644 index 0000000000..874cd85787 --- /dev/null +++ b/lib/models/src/models/transformer.cc @@ -0,0 +1,165 @@ +#include "models/transformer.h" +#include "pcg/computation_graph.h" + +namespace FlexFlow { + +TransformerConfig get_default_transformer_config() { + return TransformerConfig{/*num_features=*/512, + /*sequence_length=*/512, + /*batch_size=*/64, + /*dim_feedforward=*/2048, + /*num_heads=*/8, + /*num_encoder_layers=*/6, + /*num_decoder_layers=*/6, + /*dropout=*/0.1, + /*layer_norm_eps=*/1e-05, + /*vocab_size=*/64}; +} + +tensor_guid_t create_feedforward_network(ComputationGraphBuilder &cgb, + TransformerConfig const &config, + tensor_guid_t const &input) { + tensor_guid_t layer1_out = cgb.dense( + input, config.dim_feedforward, Activation::RELU, /*use_bias=*/true); + tensor_guid_t dropout_out = cgb.dropout(layer1_out, config.dropout); + tensor_guid_t layer2_out = cgb.dense(dropout_out, + config.num_features, + /*activation=*/std::nullopt, + /*use_bias=*/true); + return cgb.dropout(layer2_out, config.dropout); +}; + +tensor_guid_t create_transformer_encoder_layer(ComputationGraphBuilder &cgb, + TransformerConfig const &config, + tensor_guid_t const &input) { + std::vector layer_norm_axis{2}; // Normalize the last dim + int kdim = config.dim_feedforward / config.num_heads; + int vdim = config.dim_feedforward / config.num_heads; + tensor_guid_t self_attention = cgb.multihead_attention(input, + input, + input, + config.num_features, + config.num_heads, + kdim, + vdim, + config.dropout); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, self_attention)); + + tensor_guid_t normalized = cgb.layer_norm(cgb.add(self_attention, input), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, normalized)); + + tensor_guid_t feedforward_output = + create_feedforward_network(cgb, config, normalized); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, feedforward_output)); + return cgb.layer_norm(cgb.add(normalized, feedforward_output), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); +} + +tensor_guid_t create_transformer_encoder(ComputationGraphBuilder &cgb, + TransformerConfig const &config, + tensor_guid_t const &input) { + tensor_guid_t t = input; + for (int i = 0; i < config.num_encoder_layers; i++) { + t = create_transformer_encoder_layer(cgb, config, t); + } + return t; +}; + +tensor_guid_t + create_transformer_decoder_layer(ComputationGraphBuilder &cgb, + TransformerConfig const &config, + tensor_guid_t const &input, + tensor_guid_t const &encoder_output) { + std::vector layer_norm_axis{2}; // Normalize the last dim + int kdim = config.dim_feedforward / config.num_heads; + int vdim = config.dim_feedforward / config.num_heads; + tensor_guid_t self_attention = cgb.multihead_attention(input, + input, + input, + config.num_features, + config.num_heads, + kdim, + vdim, + config.dropout); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, self_attention)); + + tensor_guid_t self_attention_normalized = + cgb.layer_norm(cgb.add(input, self_attention), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, self_attention_normalized)); + + tensor_guid_t mha = cgb.multihead_attention(input, + encoder_output, + encoder_output, + config.num_features, + config.num_heads, + kdim, + vdim, + config.dropout); + assert(are_tensor_guid_shapes_equivalent(cgb.computation_graph, input, mha)); + + tensor_guid_t mha_normalized = + cgb.layer_norm(cgb.add(self_attention_normalized, mha), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, mha_normalized)); + + tensor_guid_t feedforward_output = + create_feedforward_network(cgb, config, mha_normalized); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, feedforward_output)); + + return cgb.layer_norm(cgb.add(mha_normalized, feedforward_output), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); +} + +tensor_guid_t create_transformer_decoder(ComputationGraphBuilder &cgb, + TransformerConfig const &config, + tensor_guid_t const &input, + tensor_guid_t const &encoder_output) { + tensor_guid_t t = input; + for (int i = 0; i < config.num_decoder_layers; i++) { + t = create_transformer_decoder_layer(cgb, config, t, encoder_output); + } + return t; +} + +ComputationGraph + get_transformer_computation_graph(TransformerConfig const &config) { + ComputationGraphBuilder cgb; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + config.batch_size, config.sequence_length, config.num_features}}, + DataType::FLOAT, + }; + tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); + + tensor_guid_t encoder_output = create_transformer_encoder(cgb, config, input); + tensor_guid_t decoder_output = + create_transformer_decoder(cgb, config, input, encoder_output); + + tensor_guid_t out_prob = cgb.softmax(cgb.dense(decoder_output, + /*outDim=*/config.vocab_size, + Activation::RELU, + /*use_bias=*/true)); + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/CMakeLists.txt b/lib/models/test/CMakeLists.txt new file mode 100644 index 0000000000..9c2ae64d4e --- /dev/null +++ b/lib/models/test/CMakeLists.txt @@ -0,0 +1,12 @@ +ff_add_test_executable( + NAME + models-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + models + doctest + utils-test-common +) diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer.cc new file mode 100644 index 0000000000..2133e9965b --- /dev/null +++ b/lib/models/test/src/models/transformer.cc @@ -0,0 +1,19 @@ +#include "models/transformer.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transformer_computation_graph") { + TransformerConfig config = get_default_transformer_config(); + + ComputationGraph result = get_transformer_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 317; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml index f8811e2b0b..014526a601 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.variant.toml @@ -13,7 +13,7 @@ includes = [ "op-attrs/ops/attention_attrs.dtg.h", "op-attrs/ops/batch_matmul.dtg.h", "op-attrs/ops/batch_norm_attrs.dtg.h", - "op-attrs/ops/broadcast.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", "op-attrs/ops/cast_attrs.dtg.h", "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index dbc237a03d..6868ba083f 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H #include "op-attrs/ff_dim.dtg.h" +#include "utils/fmt/vector.h" #include "utils/json.h" #include "utils/stack_vector.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h new file mode 100644 index 0000000000..c843ed3842 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H + +#include "op-attrs/dim_ordered.h" + +namespace FlexFlow { + +template +FFOrdered ff_ordered_of(C const &c) { + return FFOrdered{c.cbegin(), c.cend()}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h new file mode 100644 index 0000000000..560862677e --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H + +#include "op-attrs/dim_ordered.h" +#include "utils/containers/count.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +std::vector get_idxs(FFOrdered const &d) { + return transform(count(d.size()), [](int i) { return ff_dim_t{i}; }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h new file mode 100644 index 0000000000..54554afb81 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H + +#include "op-attrs/dim_ordered.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/zip.h" + +namespace FlexFlow { + +template +DimOrdered> zip(DimOrdered const &lhs, + DimOrdered const &rhs) { + return DimOrdered>{ + zip(as_vector(lhs), as_vector(rhs))}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/get_op_type.h b/lib/op-attrs/include/op-attrs/get_op_type.h index 8ea7208415..b60880a98b 100644 --- a/lib/op-attrs/include/op-attrs/get_op_type.h +++ b/lib/op-attrs/include/op-attrs/get_op_type.h @@ -4,7 +4,7 @@ #include "op-attrs/ops/attention_attrs.dtg.h" #include "op-attrs/ops/batch_matmul.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" -#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/ops/broadcast_attrs.dtg.h" #include "op-attrs/ops/cast_attrs.dtg.h" #include "op-attrs/ops/combine_attrs.dtg.h" #include "op-attrs/ops/concat_attrs.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 724e499810..612c226a13 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -81,7 +81,7 @@ std::vector get_tensor_shapes_unsafe(std::vector const &); template -TensorShape get_output_shape(Attrs const &attrs, TensorShape const &) { +TensorShape get_output_shape(Attrs const &attrs, TensorShape const &shape) { NOT_IMPLEMENTED(); } @@ -110,19 +110,13 @@ template std::vector get_output_shapes(Attrs const &attrs, std::vector const &); -ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &, - std::vector const &); ParallelTensorShape get_output_shape(ConcatAttrs const &, std::vector const &); -ParallelTensorShape get_output_shape(DropoutAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(FlatAttrs const &, ParallelTensorShape const &); std::vector get_output_shapes(GatherAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(LayerNormAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(ReduceAttrs const &, @@ -131,8 +125,6 @@ ParallelTensorShape get_output_shape(ReverseAttrs const &, ParallelTensorShape const &); std::vector get_output_shapes(SplitAttrs const &, ParallelTensorShape const &); -ParallelTensorShape get_output_shape(SoftmaxAttrs const &, - ParallelTensorShape const &); ParallelTensorShape get_output_shape(TopKAttrs const &, ParallelTensorShape const &); ParallelTensorShape get_output_shape(TransposeAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index ad44060400..0a5f057578 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -1,11 +1,17 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BROADCAST_H -#include "op-attrs/ops/broadcast.dtg.h" +#include "op-attrs/ops/broadcast_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { +CHECK_VALID_OP_ATTR(BroadcastAttrs); + +tl::expected get_output_shape(BroadcastAttrs const &, + TensorShape const &); ParallelTensorShape get_output_shape(BroadcastAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml b/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.struct.toml similarity index 68% rename from lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml rename to lib/op-attrs/include/op-attrs/ops/broadcast_attrs.struct.toml index c87afa59b5..52e2ee66ca 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/broadcast_attrs.struct.toml @@ -10,9 +10,9 @@ features = [ ] includes = [ - "utils/stack_vector.h", + "op-attrs/tensor_dims.dtg.h", ] [[fields]] name = "target_dims" -type = "::FlexFlow::stack_vector" +type = "::FlexFlow::TensorDims" diff --git a/lib/op-attrs/include/op-attrs/ops/dropout.h b/lib/op-attrs/include/op-attrs/ops/dropout.h index a0493301c4..86e5db4d77 100644 --- a/lib/op-attrs/include/op-attrs/ops/dropout.h +++ b/lib/op-attrs/include/op-attrs/ops/dropout.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_DROPOUT_ATTRS_H #define _FLEXFLOW_DROPOUT_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/dropout_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(DropoutAttrs const &, - ParallelTensorShape const &); +TensorShape get_output_shape(DropoutAttrs const &, TensorShape const &); +tl::expected + get_output_shape(DropoutAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(DropoutAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 01130139f1..94f9b9e147 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -7,8 +7,20 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(LayerNormAttrs const &, - ParallelTensorShape const &); +tl::expected get_output_shape(LayerNormAttrs const &, + TensorShape const &); +tl::expected + get_gamma_weights_shape(LayerNormAttrs const &, TensorShape const &); +tl::expected + get_beta_weights_shape(LayerNormAttrs const &, TensorShape const &); + +tl::expected + get_output_shape(LayerNormAttrs const &, ParallelTensorShape const &); +tl::expected + get_gamma_weights_shape(LayerNormAttrs const &, + ParallelTensorShape const &); +tl::expected + get_beta_weights_shape(LayerNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(LayerNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/softmax.h b/lib/op-attrs/include/op-attrs/ops/softmax.h index d855716cfb..6eacc66b78 100644 --- a/lib/op-attrs/include/op-attrs/ops/softmax.h +++ b/lib/op-attrs/include/op-attrs/ops/softmax.h @@ -1,16 +1,20 @@ #ifndef _FLEXFLOW_SOFTMAX_ATTRS_H #define _FLEXFLOW_SOFTMAX_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/softmax_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(SoftmaxAttrs); -ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, - ParallelTensorShape const &input_shape); +tl::expected + get_output_shape(SoftmaxAttrs const &attrs, TensorShape const &input_shape); +tl::expected + get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index 2391197471..c8af3b02e7 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -12,6 +12,11 @@ size_t num_dims(TensorDims const &); size_t dim_at_idx(TensorDims const &, ff_dim_t); size_t &dim_at_idx(TensorDims &, ff_dim_t); +bool tensor_dims_is_broadcastable_to(TensorDims const &curr, + TensorDims const &goal); +std::optional + get_broadcast_target_dims(std::unordered_set const &); + ParallelTensorDims lift_to_parallel(TensorDims const &); ParallelTensorDims lift_to_parallel_with_degrees(TensorDims const &, diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 14ee637f92..108df58dce 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -11,6 +11,11 @@ size_t &dim_at_idx(TensorShape &, ff_dim_t); size_t get_num_elements(TensorShape const &); size_t get_size_in_bytes(TensorShape const &); +bool tensor_shape_is_broadcastable_to(TensorShape const &curr, + TensorShape const &goal); +std::optional + get_broadcast_target_shape(std::unordered_set const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/layer_norm.cc b/lib/op-attrs/src/layer_norm.cc deleted file mode 100644 index ab88de3622..0000000000 --- a/lib/op-attrs/src/layer_norm.cc +++ /dev/null @@ -1,3 +0,0 @@ -#include "op-attrs/ops/layer_norm.h" - -namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_of.cc b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_of.cc new file mode 100644 index 0000000000..8e5c2fd38a --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_of.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/ff_ordered_of.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc b/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc new file mode 100644 index 0000000000..175ae8d4bd --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/get_idxs.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/get_idxs.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc new file mode 100644 index 0000000000..208fc4a719 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/zip.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/zip.h" diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc new file mode 100644 index 0000000000..bd69864aff --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -0,0 +1,33 @@ +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/tensor_dims.h" + +namespace FlexFlow { + +tl::expected + get_output_shape(BroadcastAttrs const &attrs, + TensorShape const &input_shape) { + if (num_dims(attrs.target_dims) < num_dims(input_shape.dims)) { + return tl::unexpected(fmt::format( + "get_output_shape for Broadcast expected num_dims(input_dims) <= " + "num_dims(target_dims), but recieved input_shape {} with num dims " + "greater than target_dims {}", + input_shape, + attrs.target_dims)); + } + + if (tensor_dims_is_broadcastable_to(input_shape.dims, attrs.target_dims)) { + return TensorShape{attrs.target_dims, input_shape.data_type}; + } else { + return tl::unexpected(fmt::format( + "Input tensor shape {} is not broadcastable to target dims {}", + input_shape, + attrs.target_dims)); + } +} + +ParallelTensorShape get_output_shape(BroadcastAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/dropout.cc b/lib/op-attrs/src/op-attrs/ops/dropout.cc index adbd144f38..8852acaad1 100644 --- a/lib/op-attrs/src/op-attrs/ops/dropout.cc +++ b/lib/op-attrs/src/op-attrs/ops/dropout.cc @@ -1,10 +1,29 @@ #include "op-attrs/ops/dropout.h" +#include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(DropoutAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(DropoutAttrs const &, + TensorShape const &input_shape) { + return input_shape; +} + +tl::expected + get_output_shape(DropoutAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (get_sum_degree(input_shape) != 1) { + return tl::unexpected( + fmt::format("Expected sum degree 1, but receieved sum degree {}", + get_sum_degree(input_shape))); + } + + if (get_discard_copy_degree(input_shape) != 1) { + return tl::unexpected(fmt::format( + "Expected discard copy degree 1, but received discard copy degree {}", + get_discard_copy_degree(input_shape))); + } + + return input_shape; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index 437ba3638a..76a5e25dfc 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -1,10 +1,177 @@ #include "op-attrs/ops/layer_norm.h" +#include "op-attrs/dim_ordered/ff_ordered_of.h" +#include "op-attrs/dim_ordered/get_idxs.h" +#include "utils/containers/all_of.h" +#include "utils/containers/any_of.h" +#include "utils/containers/contains.h" +#include "utils/containers/filter.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(LayerNormAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +static std::optional + check_input_shape(LayerNormAttrs const &attrs, + TensorShape const &input_shape) { + if (any_of(attrs.axes, [&](ff_dim_t axis) { + return axis.value >= num_dims(input_shape); + })) { + return fmt::format( + "LayerNorm axes {} out-of-bounds for input tensor shape {}", + attrs.axes, + input_shape); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(LayerNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + return input_shape; +} + +tl::expected + get_gamma_weights_shape(LayerNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.elementwise_affine) { + return tl::unexpected( + "No gamma weights exist for attrs.elementwise_affine = false"); + } + + std::vector non_layer_norm_dim_idxs = filter( + get_idxs(input_shape.dims.ff_ordered), + [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); + std::vector raw_weight_dims = + transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { + return dim_at_idx(input_shape, dim_idx); + }); + + return TensorShape{ + TensorDims{ff_ordered_of(raw_weight_dims)}, + DataType::FLOAT, + }; +} + +tl::expected + get_beta_weights_shape(LayerNormAttrs const &attrs, + TensorShape const &input_shape) { + if (!attrs.elementwise_affine) { + return tl::unexpected( + "No beta weights exist for attrs.elementwise_affine = false"); + } + + return get_gamma_weights_shape(attrs, input_shape); +} + +static std::optional + check_input_shape(LayerNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + { + TensorShape reduced_shape = get_reduced_shape(input_shape); + std::optional maybe_err_msg = + check_input_shape(attrs, reduced_shape); + if (maybe_err_msg.has_value()) { + return maybe_err_msg; + } + } + + if (get_sum_degree(input_shape) != 1) { + return fmt::format("Expected sum degree 1, but receieved sum degree {}", + get_sum_degree(input_shape)); + } + + if (get_discard_copy_degree(input_shape) != 1) { + return fmt::format( + "Expected discard copy degree 1, but received discartd copy degree {}", + get_discard_copy_degree(input_shape)); + } + + if (!all_of(attrs.axes, [&](ff_dim_t axis) { + return shard_dim_at_idx(input_shape, axis).degree == 1; + })) { + return fmt::format("Expected parallel degree of all dimensions in " + "LayerNorm axes {} to be 1, but received input shape {}", + attrs.axes, + input_shape); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(LayerNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + return input_shape; +} + +tl::expected + get_gamma_weights_shape(LayerNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.elementwise_affine) { + return tl::unexpected( + "No gamma weights exist for attrs.elementwise_affine = false"); + } + + std::vector non_layer_norm_dim_idxs = filter( + get_idxs(input_shape.dims.shard_dims), + [&](ff_dim_t const &dim_idx) { return !contains(attrs.axes, dim_idx); }); + std::vector raw_weight_shard_dims = + transform(non_layer_norm_dim_idxs, [&](ff_dim_t const &dim_idx) { + return shard_dim_at_idx(input_shape, dim_idx); + }); + + return ParallelTensorShape{ + ParallelTensorDims{ + ff_ordered_of(raw_weight_shard_dims), + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; +} + +tl::expected + get_beta_weights_shape(LayerNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + if (!attrs.elementwise_affine) { + return tl::unexpected( + "No beta weights exist for attrs.elementwise_affine = false"); + } + + return get_gamma_weights_shape(attrs, input_shape); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/softmax.cc b/lib/op-attrs/src/op-attrs/ops/softmax.cc index 2d870af50e..541c590cbd 100644 --- a/lib/op-attrs/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/src/op-attrs/ops/softmax.cc @@ -1,10 +1,53 @@ #include "op-attrs/ops/softmax.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { -ParallelTensorShape get_output_shape(SoftmaxAttrs const &attrs, - ParallelTensorShape const &input_shape) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(SoftmaxAttrs const &attrs, + TensorShape const &input_shape) { + if (attrs.dim.value >= num_dims(input_shape)) { + return tl::unexpected( + fmt::format("get_output_shape for Softmax received out-of-bounds " + "attrs.dim {} for input tensor shape {}", + attrs.dim, + input_shape)); + } + + return input_shape; +} + +tl::expected + get_output_shape(SoftmaxAttrs const &attrs, + ParallelTensorShape const &input_shape) { + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + + if (get_sum_degree(input_shape) != 1) { + return tl::unexpected( + fmt::format("Expected sum degree 1, but receieved sum degree {}", + get_sum_degree(input_shape))); + } + + if (get_discard_copy_degree(input_shape) != 1) { + return tl::unexpected(fmt::format( + "Expected discard copy degree 1, but received discard copy degree {}", + get_discard_copy_degree(input_shape))); + } + + if (shard_dim_at_idx(input_shape, attrs.dim).degree != 1) { + return tl::unexpected( + fmt::format("Expected parallel degree of Softmax dimension {} to be 1, " + "but received input shape {}", + attrs.dim, + input_shape)); + } + + return input_shape; } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index 47bd8a4821..e716793a8f 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -1,7 +1,10 @@ #include "op-attrs/tensor_dims.h" +#include "op-attrs/dim_ordered/zip.h" #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.dtg.h" +#include "utils/containers/all_of.h" #include "utils/containers/as_vector.h" +#include "utils/containers/reversed.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" #include "utils/integer_conversions.h" @@ -24,6 +27,38 @@ size_t &dim_at_idx(TensorDims &dims, ff_dim_t idx) { return dims.ff_ordered.at(idx); } +bool tensor_dims_is_broadcastable_to(TensorDims const &curr, + TensorDims const &goal) { + if (num_dims(curr) > num_dims(goal)) { + return false; + } + + std::vector curr_dims = as_vector(curr.ff_ordered); + std::vector goal_dims = as_vector(goal.ff_ordered); + + for (auto const &[curr_dim, goal_dim] : + zip(reversed(curr_dims), reversed(goal_dims))) { + if (curr_dim != 1 && curr_dim != goal_dim) { + return false; + } + } + + return true; +} + +std::optional + get_broadcast_target_dims(std::unordered_set const &dims) { + for (TensorDims target_candidate : dims) { + if (all_of(dims, [&](TensorDims const &d) { + return tensor_dims_is_broadcastable_to(d, target_candidate); + })) { + return target_candidate; + } + } + + return std::nullopt; +} + ParallelTensorDims lift_to_parallel(TensorDims const &dims) { std::vector shard_degrees(num_dims(dims), 1); // 1 repeated num_dims(dims) times diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index 9d564a6d27..b604d442cb 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -1,7 +1,9 @@ #include "op-attrs/tensor_shape.h" #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.h" +#include "utils/containers/get_only.h" #include "utils/containers/product.h" +#include "utils/containers/transform.h" namespace FlexFlow { @@ -25,4 +27,35 @@ size_t get_size_in_bytes(TensorShape const &s) { return get_num_elements(s) * size_of_datatype(s.data_type); } +bool tensor_shape_is_broadcastable_to(TensorShape const &curr, + TensorShape const &goal) { + return tensor_dims_is_broadcastable_to(curr.dims, goal.dims) && + curr.data_type == goal.data_type; +} + +std::optional + get_broadcast_target_shape(std::unordered_set const &shapes) { + std::unordered_set datatypes = + transform(shapes, [](TensorShape const &s) { return s.data_type; }); + + if (datatypes.size() != 1) { + return std::nullopt; + } + + std::unordered_set shapes_dims = + transform(shapes, [](TensorShape const &s) { return s.dims; }); + + std::optional maybe_result_dims = + get_broadcast_target_dims(shapes_dims); + std::optional result = + transform(maybe_result_dims, [&](TensorDims const &result_dims) { + return TensorShape{ + result_dims, + get_only(datatypes), + }; + }); + + return result; +} + } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc new file mode 100644 index 0000000000..11e09dc43f --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc @@ -0,0 +1,41 @@ +#include "op-attrs/dim_ordered/zip.h" +#include "op-attrs/ff_dim.dtg.h" +#include "utils/fmt/pair.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("zip(DimOrdered, DimOrdered)") { + DimOrdered lhs_input = {9, 9, 8, 9}; + DimOrdered rhs_input = {"m", "m", "k", "l", "m"}; + + SUBCASE("lhs is longer") { + DimOrdered> result = + zip(lhs_input, rhs_input); + + DimOrdered> correct = { + {9, "m"}, + {9, "m"}, + {8, "k"}, + {9, "l"}, + }; + + CHECK(result == correct); + } + + SUBCASE("rhs is longer") { + DimOrdered> result = + zip(rhs_input, lhs_input); + + DimOrdered> correct = { + {"m", 9}, + {"m", 9}, + {"k", 8}, + {"l", 9}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc new file mode 100644 index 0000000000..17a68ccbc8 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc @@ -0,0 +1,101 @@ +#include "op-attrs/ops/dropout.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(DropoutAttrs, TensorShape)") { + DropoutAttrs attrs = DropoutAttrs{ + /*rate=*/0.5, + /*seed=*/1, + }; + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + }}, + DataType::FLOAT, + }; + + TensorShape result = get_output_shape(attrs, input); + TensorShape correct = input; + + CHECK(result == correct); + } + + TEST_CASE("get_output_shape(DropoutAttrs, ParallelTensorShape)") { + DropoutAttrs attrs = DropoutAttrs{ + /*rate=*/0.5, + /*seed=*/1, + }; + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + }}, + DataType::FLOAT, + }; + + TensorShape output = input; + + auto make_input = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto make_output = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + SUBCASE("partition parallelism (allowed)") { + int degree0 = 2; + int degree2 = 4; + + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2); + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = + make_output(SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2); + + CHECK(result == correct); + } + + SUBCASE("sum parallelism (not allowed)") { + SumDegree sum_degree = SumDegree{2}; + + ParallelTensorShape par_input = + make_input(sum_degree, DiscardCopyDegree{1}, 1, 1, 1); + + std::optional result = + optional_from_expected(get_output_shape(attrs, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("discard copy parallelism (not allowed)") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + ParallelTensorShape par_input = + make_input(SumDegree{1}, discard_copy_degree, 1, 1, 1); + + std::optional result = + optional_from_expected(get_output_shape(attrs, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc new file mode 100644 index 0000000000..8f5f4054d6 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -0,0 +1,278 @@ +#include "op-attrs/ops/layer_norm.h" +#include "test/utils/doctest.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("shape inference (LayerNorm)") { + LayerNormAttrs attrs_affine_true = LayerNormAttrs{ + /*axes=*/{ff_dim_t{1}, ff_dim_t{3}}, + /*elementwise_affine=*/true, + /*eps=*/0.1, + }; + + LayerNormAttrs attrs_affine_false = [&] { + LayerNormAttrs attrs = attrs_affine_true; + attrs.elementwise_affine = false; + return attrs; + }(); + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + 18, + }}, + DataType::FLOAT, + }; + + TensorShape output = input; + + TensorShape gamma = TensorShape{ + TensorDims{FFOrdered{ + 12, + 16, + }}, + DataType::FLOAT, + }; + + TensorShape beta = gamma; + + SUBCASE("get_output_shape(LayerNormAttrs, TensorShape)") { + tl::expected result = + get_output_shape(attrs_affine_true, input); + tl::expected correct = output; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(LayerNormAttrs, TensorShape)") { + SUBCASE("elementwise_affine = true") { + tl::expected result = + get_gamma_weights_shape(attrs_affine_true, input); + tl::expected correct = gamma; + + CHECK(result == correct); + } + + SUBCASE("elementwise_affine = false") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_shape(LayerNormAttrs, TensorShape)") { + SUBCASE("elementwise_affine = true") { + tl::expected result = + get_beta_weights_shape(attrs_affine_true, input); + tl::expected correct = beta; + + CHECK(result == correct); + } + + SUBCASE("elementwise_affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + auto make_input = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o0, + int o1, + int o2, + int o3) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o0, o1, o2, o3}); + }; + + auto make_output = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o0, + int o1, + int o2, + int o3) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o0, o1, o2, o3}); + }; + + auto make_gamma_weights = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o2) { + return lift_to_parallel_with_degrees( + gamma, o_sum, o_eq, FFOrdered{o0, o2}); + }; + + auto make_beta_weights = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o2) { + return lift_to_parallel_with_degrees( + beta, o_sum, o_eq, FFOrdered{o0, o2}); + }; + + SUBCASE("parallel shape inference (LayerNorm)") { + SUBCASE("partition parallelism (not in axes)") { + int degree0 = 2; + int degree2 = 3; + + ParallelTensorShape par_input = make_input( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2, 1); + + SUBCASE("get_output_shape(LayerNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_output_shape(attrs_affine_true, par_input); + tl::expected correct = make_output( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2, 1); + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + SUBCASE("elementwise_affine = true") { + tl::expected result = + get_gamma_weights_shape(attrs_affine_true, par_input); + tl::expected correct = + make_gamma_weights( + SumDegree{1}, DiscardCopyDegree{1}, degree0, degree2); + + CHECK(result == correct); + } + + SUBCASE("elementwise_affine = false") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_false, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + SUBCASE("elementwise_affine = true") { + tl::expected result = + get_beta_weights_shape(attrs_affine_true, par_input); + tl::expected correct = + make_beta_weights( + SumDegree{1}, DiscardCopyDegree{1}, degree0, degree2); + + CHECK(result == correct); + } + + SUBCASE("elementwise_affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_false, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + SUBCASE("partition parallelism (in axes)") { + int degree1 = 2; + int degree2 = 4; + + ParallelTensorShape par_input = make_input( + SumDegree{1}, DiscardCopyDegree{1}, 1, degree1, degree2, 1); + + SUBCASE("get_output_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_output_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("sum parallelism") { + SumDegree sum_degree = SumDegree{2}; + + ParallelTensorShape par_input = + make_input(sum_degree, DiscardCopyDegree{1}, 1, 1, 1, 1); + + SUBCASE("get_output_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_output_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy parallelism") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + ParallelTensorShape par_input = + make_input(SumDegree{1}, discard_copy_degree, 1, 1, 1, 1); + + SUBCASE("get_output_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_output_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE( + "get_gamma_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(LayerNormAttrs, ParallelTensorShape)") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_true, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc new file mode 100644 index 0000000000..f6a8da016f --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc @@ -0,0 +1,138 @@ +#include "op-attrs/ops/softmax.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(SoftmaxAttrs, TensorShape)") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + }}, + DataType::FLOAT, + }; + + SUBCASE("attrs.dim in bounds") { + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("attrs.dims out of bounds") { + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{4}}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(SoftmaxAttrs, ParallelTensorShape)") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + }}, + DataType::FLOAT, + }; + TensorShape output = input; + + auto make_input = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto make_output = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + output, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + SUBCASE("partition parallelism in non-softmax-dim (valid)") { + int degree0 = 2; + int degree2 = 4; + + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2); + + SUBCASE("attrs.dim in bounds") { + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = make_output( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2); + + CHECK(result == correct); + } + + SUBCASE("attrs.dims out of bounds") { + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{4}}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("partition parallism in softmax dim (invalid)") { + int degree1 = 2; + + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + + ParallelTensorShape par_input = + make_input(SumDegree{1}, DiscardCopyDegree{1}, 1, degree1, 1); + + std::optional result = + optional_from_expected(get_output_shape(attrs, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("sum parallelism (invalid)") { + SumDegree sum_degree = SumDegree{2}; + + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + + ParallelTensorShape par_input = + make_input(sum_degree, DiscardCopyDegree{1}, 1, 1, 1); + + std::optional result = + optional_from_expected(get_output_shape(attrs, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("discard copy parallelism (invalid)") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{1}}; + + ParallelTensorShape par_input = + make_input(SumDegree{1}, discard_copy_degree, 1, 1, 1); + + std::optional result = + optional_from_expected(get_output_shape(attrs, par_input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc new file mode 100644 index 0000000000..25c7eb036f --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc @@ -0,0 +1,121 @@ +#include "op-attrs/tensor_dims.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("tensor_dims_is_broadcastable_to(TensorDims, TensorDims)") { + + TensorDims goal = TensorDims{FFOrdered{1, 1, 4, 3}}; + + SUBCASE("dims match") { + bool result = tensor_dims_is_broadcastable_to(goal, goal); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("curr only needs num_dims promotion") { + TensorDims curr = TensorDims{FFOrdered{4, 3}}; + + bool result = tensor_dims_is_broadcastable_to(curr, goal); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("curr only needs dim expansion") { + TensorDims curr = TensorDims{FFOrdered{1, 1, 1, 3}}; + + bool result = tensor_dims_is_broadcastable_to(curr, goal); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("curr needs both num_dims promotion and dim expansion") { + TensorDims curr = TensorDims{FFOrdered{1, 3}}; + + bool result = tensor_dims_is_broadcastable_to(curr, goal); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("curr needs invalid dim promotion") { + TensorDims curr = TensorDims{FFOrdered{1, 1, 2, 3}}; + + bool result = tensor_dims_is_broadcastable_to(curr, goal); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("num_dims(goal) < num_dims(curr)") { + TensorDims curr = TensorDims{FFOrdered{1, 1, 10, 4, 3}}; + + bool result = tensor_dims_is_broadcastable_to(curr, goal); + bool correct = false; + + CHECK(result == correct); + } + } + + TEST_CASE("get_broadcast_target_dims(std::unordered_set)") { + TensorDims d1 = TensorDims{FFOrdered{1, 10, 4, 3}}; + + TensorDims d2 = TensorDims{FFOrdered{10, 4, 1}}; + + SUBCASE("has target in inputs") { + TensorDims d3 = TensorDims{FFOrdered{1, 1, 4, 3}}; + + std::optional result = + get_broadcast_target_dims({d1, d2, d3}); + std::optional correct = d1; + + CHECK(result == correct); + } + + SUBCASE("has no possible target") { + TensorDims d3 = TensorDims{FFOrdered{1, 1, 1, 4}}; + + std::optional result = + get_broadcast_target_dims({d1, d2, d3}); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("has possible target, but not in inputs") { + TensorDims d3 = TensorDims{FFOrdered{1, 1, 1, 4, 3}}; + + TensorDims possible_target = + TensorDims{FFOrdered{1, 1, 10, 4, 3}}; + + REQUIRE(tensor_dims_is_broadcastable_to(d1, possible_target)); + REQUIRE(tensor_dims_is_broadcastable_to(d2, possible_target)); + REQUIRE(tensor_dims_is_broadcastable_to(d3, possible_target)); + + std::optional result = + get_broadcast_target_dims({d1, d2, d3}); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("inputs is empty") { + std::optional result = get_broadcast_target_dims({}); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("all inputs are same") { + std::optional result = + get_broadcast_target_dims({d1, d1, d1, d1, d1}); + std::optional correct = d1; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc b/lib/op-attrs/test/src/op-attrs/tensor_shape.cc new file mode 100644 index 0000000000..bc715c183a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/tensor_shape.cc @@ -0,0 +1,64 @@ +#include "op-attrs/tensor_shape.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_broadcast_target_shape(std::unordered_set)") { + SUBCASE("target exists in inputs") { + DataType datatype = DataType::FLOAT; + + TensorShape s1 = TensorShape{ + TensorDims{FFOrdered{ + 1, + }}, + datatype, + }; + + TensorShape s2 = TensorShape{ + TensorDims{FFOrdered{10, 4, 3}}, + datatype, + }; + + TensorShape s3 = TensorShape{ + TensorDims{FFOrdered{ + 4, + 1, + }}, + datatype, + }; + + std::optional result = + get_broadcast_target_shape({s1, s2, s3}); + std::optional correct = s2; + + CHECK(result == correct); + } + + SUBCASE("datatypes don't match") { + TensorDims dims = TensorDims{FFOrdered{10, 4, 3}}; + + TensorShape s1 = TensorShape{ + dims, + DataType::FLOAT, + }; + + TensorShape s2 = TensorShape{ + dims, + DataType::DOUBLE, + }; + + std::optional result = get_broadcast_target_shape({s1, s2}); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("inputs is empty") { + std::optional result = get_broadcast_target_shape({}); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 46d5b22afb..088139a0f3 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -18,6 +18,9 @@ LayerAddedResult add_layer(ComputationGraph &computation_graph, std::vector const &inputs, std::vector const &outputs); TensorAttrs get_tensor_attrs(ComputationGraph const &, tensor_guid_t const &); +bool are_tensor_guid_shapes_equivalent(ComputationGraph const &cg, + tensor_guid_t const &t1, + tensor_guid_t const &t2); std::vector topological_ordering(ComputationGraph const &cg); diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 0e453c50c3..0ca330408e 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -189,7 +189,7 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a softmax layer tensor_guid_t softmax(tensor_guid_t const &input, - int dim = -1, + std::optional dim = std::nullopt, std::optional const &name = std::nullopt); // Create input tensors and constants tensor_guid_t @@ -244,7 +244,9 @@ struct ComputationGraphBuilder { private: TensorShape get_shape(tensor_guid_t const &) const; - tensor_guid_t broadcast(tensor_guid_t const &, TensorShape const &); + tensor_guid_t broadcast(tensor_guid_t const &, + TensorShape const &, + std::string const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index afa1774858..e7f5f2b737 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -25,6 +25,12 @@ TensorAttrs get_tensor_attrs(ComputationGraph const &cg, return cg.raw_graph.at(t.raw_graph_output); } +bool are_tensor_guid_shapes_equivalent(ComputationGraph const &cg, + tensor_guid_t const &t1, + tensor_guid_t const &t2) { + return get_tensor_attrs(cg, t1).shape == get_tensor_attrs(cg, t2).shape; +} + std::vector topological_ordering(ComputationGraph const &cg) { std::vector layers = get_topological_ordering(cg.raw_graph); return transform( diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 5028ed5709..b6d0e7c890 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -6,11 +6,12 @@ #include "op-attrs/ops/embedding.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "pcg/computation_graph.h" +#include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/enumerate_vector.h" #include "utils/containers/transform.h" #include "utils/expected.h" -#include "utils/fmt.h" +#include namespace FlexFlow { @@ -117,9 +118,25 @@ tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, } } -tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &, - TensorShape const &) { - NOT_IMPLEMENTED(); +tensor_guid_t + ComputationGraphBuilder::broadcast(tensor_guid_t const &input, + TensorShape const &target_shape, + std::string const &name) { + TensorShape input_shape = this->get_shape(input); + if (!tensor_shape_is_broadcastable_to(input_shape, target_shape)) { + throw mk_runtime_error(fmt::format( + "Cannot broadcast input tensor of shape {} to target shape {}", + input_shape, + target_shape)); + } + + BroadcastAttrs attrs = BroadcastAttrs{target_shape.dims}; + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + return this->add_layer(layer, {input}, {}, output_shape); } tensor_guid_t @@ -170,12 +187,16 @@ tensor_guid_t ComputationGraphBuilder::element_binary( DataType compute_type = std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); - tensor_guid_t lhs_input = this->as_type(this->broadcast(lhs, compute_shape), - compute_type, - name + "_inputl_pre_cast"); - tensor_guid_t rhs_input = this->as_type(this->broadcast(rhs, compute_shape), - compute_type, - name + "_inputr_pre_cast"); + tensor_guid_t lhs_input = this->as_type( + this->broadcast( + lhs, compute_shape, fmt::format("{}_inputl_broadcast", name)), + compute_type, + name + "_inputl_cast"); + tensor_guid_t rhs_input = this->as_type( + this->broadcast( + rhs, compute_shape, fmt::format("{}_inputr_broadcast", name)), + compute_type, + name + "_inputr_cast"); ElementBinaryAttrs attrs = ElementBinaryAttrs{op_type, compute_type, false, false}; @@ -475,11 +496,11 @@ std::vector ComputationGraphBuilder::gather( // LayerAttrs layer = {attrs, name}; // TensorShape output_shape = get_output_shape(attrs, -// get_shape(gate_preds), -// get_shape(gate_assign), -// get_shape(true_gate_assign), -// get_shape(full_gate_gradients), -// get_shape(exp_preds)); +// this->get_shape(gate_preds), +// this->get_shape(gate_assign), +// this->get_shape(true_gate_assign), +// this->get_shape(full_gate_gradients), +// this->get_shape(exp_preds)); // std::vector inputs = { // gate_preds, gate_assign, true_gate_assign, full_gate_gradients}; @@ -530,11 +551,17 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = throw_if_unexpected(get_output_shape( - attrs, get_shape(query), get_shape(key), get_shape(value))); - - TensorShape weights_shape = throw_if_unexpected(get_weights_shape( - attrs, get_shape(query), get_shape(key), get_shape(value))); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, + this->get_shape(query), + this->get_shape(key), + this->get_shape(value))); + + TensorShape weights_shape = + throw_if_unexpected(get_weights_shape(attrs, + this->get_shape(query), + this->get_shape(key), + this->get_shape(value))); TensorAttrs weight_attrs = make_weight_attrs(weights_shape, initializer); return this->add_layer(layer, @@ -544,13 +571,141 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( } TensorShape ComputationGraphBuilder::get_broadcast_target_shape( - std::vector const &) { - NOT_IMPLEMENTED(); + std::vector const &inputs) { + std::vector input_shapes = transform( + inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); + + return this->get_broadcast_target_shape(input_shapes); } TensorShape ComputationGraphBuilder::get_broadcast_target_shape( - std::vector const &) { - NOT_IMPLEMENTED(); + std::vector const &input_shapes) { + std::optional maybe_result = + ::FlexFlow::get_broadcast_target_shape(unordered_set_of(input_shapes)); + + if (maybe_result.has_value()) { + return maybe_result.value(); + } else { + throw mk_runtime_error(fmt::format( + "ComputationGraphBuilder::get_broadcast_target_shape failed to find " + "target tensor shape for input tensor shapes {}", + input_shapes)); + } +} + +tensor_guid_t ComputationGraphBuilder::dense( + tensor_guid_t const &input, + int outDim, + std::optional activation, + bool use_bias, + DataType data_type, + std::optional const &kernel_initializer, + std::optional const &bias_initializer, + std::optional const &maybe_name) { + LinearAttrs attrs = + LinearAttrs{outDim, use_bias, data_type, activation, std::nullopt}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + std::vector weights; + TensorShape kernel_shape = + throw_if_unexpected(get_kernel_shape(attrs, this->get_shape(input))); + weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + + if (use_bias) { + TensorShape bias_shape = + throw_if_unexpected(get_bias_shape(attrs, this->get_shape(input))); + weights.push_back(make_weight_attrs(bias_shape, bias_initializer)); + } + + return this->add_layer(layer, {input}, weights, output_shape); +} + +tensor_guid_t ComputationGraphBuilder::layer_norm( + tensor_guid_t const &input, + std::vector const &axes, + bool elementwise_affine, + float eps, + std::optional const &maybe_name) { + + TensorShape input_shape = this->get_shape(input); + + if (any_of(axes, + [&](size_t axis) { return axis >= num_dims(input_shape); })) { + throw mk_runtime_error(fmt::format( + "ComputationGraphBuilder::layer_norm received axes {} with " + "out-of-bound element (input tensor has num dimensions = {})", + axes, + num_dims(input_shape))); + } + + LayerNormAttrs attrs = LayerNormAttrs{ + stack_vector{axes.begin(), axes.end()}, + elementwise_affine, + eps, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + if (elementwise_affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm + + TensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{float{1}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + TensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{float{0}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } + + return this->add_layer(layer, {input}, weights, output_shape); +} + +tensor_guid_t ComputationGraphBuilder::softmax( + tensor_guid_t const &input, + std::optional maybe_dim, + std::optional const &maybe_name) { + + TensorShape input_shape = this->get_shape(input); + + int dim = maybe_dim.value_or(num_dims(input_shape) - 1); + + if (dim >= num_dims(input_shape)) { + throw mk_runtime_error( + fmt::format("ComputationGraphBuilder::softmax received out-of-bounds " + "dim {} for input tensor shape {}", + dim, + input_shape)); + } + + SoftmaxAttrs attrs = SoftmaxAttrs{ff_dim_t{dim}}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + return this->add_layer(layer, {input}, {}, output_shape); } } // namespace FlexFlow diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index a0d77b9f76..ae5e120fad 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -13,6 +13,7 @@ ff_add_library( fmt json cuda + doctest ) add_subdirectory(ffi) diff --git a/lib/utils/include/utils/containers/any_of.h b/lib/utils/include/utils/containers/any_of.h new file mode 100644 index 0000000000..1504e2aa84 --- /dev/null +++ b/lib/utils/include/utils/containers/any_of.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ANY_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ANY_OF_H + +namespace FlexFlow { + +template +bool any_of(C const &c, F const &f) { + for (auto const &v : c) { + if (f(v)) { + return true; + } + } + return false; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index c40e05b591..c4e561f059 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -4,6 +4,7 @@ #include "utils/containers/vector_transform.h" #include "utils/required_core.h" #include +#include #include #include @@ -52,6 +53,18 @@ std::unordered_map transform(std::unordered_map const &m, return result; } +template +std::optional> transform(std::optional const &o, + F &&f) { + using Return = std::invoke_result_t; + if (o.has_value()) { + Return r = f(o.value()); + return std::optional{r}; + } else { + return std::nullopt; + } +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/unordered_multiset_of.h b/lib/utils/include/utils/containers/unordered_multiset_of.h new file mode 100644 index 0000000000..3cdef3f8b9 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_multiset_of.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MULTISET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MULTISET_OF_H + +#include + +namespace FlexFlow { + +template +std::unordered_multiset unordered_multiset_of(C const &c) { + return std::unordered_multiset{std::cbegin(c), std::cend(c)}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/expected.h b/lib/utils/include/utils/expected.h index e4d633ca18..84e829ea65 100644 --- a/lib/utils/include/utils/expected.h +++ b/lib/utils/include/utils/expected.h @@ -1,17 +1,25 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_EXPECTED_H #define _FLEXFLOW_UTILS_INCLUDE_EXPECTED_H -#include "tl/expected.hpp" #include "utils/fmt.h" +#include #include +#include namespace FlexFlow { -using namespace tl; - template -unexpected error_msg(Args &&...args) { - return make_unexpected(fmt::format(std::forward(args)...)); +tl::unexpected error_msg(Args &&...args) { + return tl::make_unexpected(fmt::format(std::forward(args)...)); +} + +template +std::optional optional_from_expected(tl::expected const &x) { + if (x.has_value()) { + return x.value(); + } else { + return std::nullopt; + } } } // namespace FlexFlow diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index 5edd054ebe..21a6d28ca2 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -3,14 +3,17 @@ #include "fmt/format.h" #include "utils/check_fmtable.h" +#include #include #include namespace fmt { template -struct formatter<::tl::expected, Char> - /* std::enable_if_t>::value>> */ +struct formatter< + ::tl::expected, + Char, + std::enable_if_t>::value>> : formatter<::std::string> { template auto format(::tl::expected const &m, FormatContext &ctx) @@ -29,4 +32,27 @@ struct formatter<::tl::expected, Char> } // namespace fmt +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, tl::expected const &t) { + CHECK_FMTABLE(T); + CHECK_FMTABLE(E); + + return s << fmt::to_string(t); +} + +} // namespace FlexFlow + +namespace doctest { + +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 46bf9ca8fa..8e186928fd 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -5,6 +5,7 @@ #include "utils/containers/sorted.h" #include "utils/fmt/pair.h" #include "utils/join_strings.h" +#include #include #include @@ -47,4 +48,15 @@ std::ostream &operator<<(std::ostream &s, std::map const &m) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h new file mode 100644 index 0000000000..cff150dc29 --- /dev/null +++ b/lib/utils/include/utils/fmt/multiset.h @@ -0,0 +1,56 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MULTISET_H + +#include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::multiset, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::multiset const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = + ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); + // } + return formatter::format("{" + result + "}", ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::multiset const &x) { + CHECK_FMTABLE(T); + + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 2364e49568..45eebc2c58 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_OPTIONAL_H #include "utils/check_fmtable.h" +#include #include #include @@ -42,4 +43,15 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::optional const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index ab5ddd4e28..6f7e6f6b52 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #include "utils/check_fmtable.h" +#include #include #include @@ -39,4 +40,15 @@ std::ostream &operator<<(std::ostream &s, std::pair const &m) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::pair const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index a183d37542..1f8012f240 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -4,6 +4,7 @@ #include "utils/check_fmtable.h" #include "utils/containers/sorted.h" #include "utils/join_strings.h" +#include #include #include #include @@ -42,4 +43,15 @@ std::ostream &operator<<(std::ostream &s, std::set const &x) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 876a032fe6..75bbb4cb8a 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -5,6 +5,7 @@ #include "utils/fmt/pair.h" #include "utils/join_strings.h" #include +#include #include #include #include @@ -47,4 +48,15 @@ std::ostream &operator<<(std::ostream &s, std::unordered_map const &m) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h new file mode 100644 index 0000000000..41abbc925e --- /dev/null +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -0,0 +1,56 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_UNORDERED_MULTISET_H + +#include "utils/check_fmtable.h" +#include "utils/join_strings.h" +#include +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::unordered_multiset, + Char, + std::enable_if_t>::value>> + : formatter<::std::string> { + template + auto format(::std::unordered_multiset const &m, FormatContext &ctx) + -> decltype(ctx.out()) { + CHECK_FMTABLE(T); + + std::string result = + ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); + // } + return formatter::format("{" + result + "}", ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +template +std::ostream &operator<<(std::ostream &s, std::unordered_multiset const &x) { + CHECK_FMTABLE(T); + + return s << fmt::to_string(x); +} + +} // namespace FlexFlow + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 257545af1b..646ef0c7c5 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -4,6 +4,7 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" #include "utils/type_traits_core.h" +#include #include #include @@ -42,4 +43,15 @@ std::ostream &operator<<(std::ostream &s, std::unordered_set const &x) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 06a56417c3..867577f72a 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H +#include #include #include @@ -32,4 +33,15 @@ std::ostream &operator<<(std::ostream &s, std::variant const &v) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::variant const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 5d9ca0aeae..96526175a8 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -3,6 +3,7 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" +#include #include #include @@ -40,4 +41,15 @@ std::ostream &operator<<(std::ostream &s, std::vector const &v) { } // namespace FlexFlow +namespace doctest { + +template +struct StringMaker> { + static String convert(std::vector const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + #endif diff --git a/lib/utils/include/utils/graph/query_set.h b/lib/utils/include/utils/graph/query_set.h index 38ef031bf5..4dc1f037e4 100644 --- a/lib/utils/include/utils/graph/query_set.h +++ b/lib/utils/include/utils/graph/query_set.h @@ -7,6 +7,7 @@ #include "utils/containers/filter_keys.h" #include "utils/containers/intersection.h" #include "utils/containers/set_union.h" +#include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" #include "utils/fmt/unordered_set.h" diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3192eb22da..3448ec4e0e 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -23,18 +23,6 @@ T const &assert_unwrap(std::optional const &o) { return o.value(); } -template -std::optional> transform(std::optional const &o, - F &&f) { - using Return = std::invoke_result_t; - if (o.has_value()) { - Return r = f(o.value()); - return std::optional{r}; - } else { - return std::nullopt; - } -} - } // namespace FlexFlow namespace rc { diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 4030611714..1d654e3415 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -1,16 +1,16 @@ #ifndef _FLEXFLOW_UTILS_STACK_VECTOR_H #define _FLEXFLOW_UTILS_STACK_VECTOR_H -#include "hash-utils.h" -#include "rapidcheck.h" -#include "utils/fmt.h" -#include "utils/fmt/vector.h" +#include "utils/hash-utils.h" +#include "utils/join_strings.h" #include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include #include +#include #include +#include #include namespace FlexFlow { @@ -48,6 +48,9 @@ struct stack_vector { } } + stack_vector(std::initializer_list const &l) + : stack_vector(l.begin(), l.end()) {} + operator std::vector() const { return {this->cbegin(), this->cend()}; } @@ -293,10 +296,14 @@ struct stack_vector { return this->contents.data(); } - friend std::vector format_as(stack_vector const &v) { - CHECK_FMTABLE(std::vector); + friend std::string format_as(stack_vector const &v) { + CHECK_FMTABLE(T); - return static_cast>(v); + std::string result = + ::FlexFlow::join_strings(v.cbegin(), v.cend(), ", ", [](T const &t) { + return fmt::to_string(t); + }); + return "[" + result + "]"; } private: diff --git a/lib/utils/src/utils/containers/any_of.cc b/lib/utils/src/utils/containers/any_of.cc new file mode 100644 index 0000000000..c6784c346b --- /dev/null +++ b/lib/utils/src/utils/containers/any_of.cc @@ -0,0 +1 @@ +#include "utils/containers/any_of.h" diff --git a/lib/utils/src/utils/containers/unordered_multiset_of.cc b/lib/utils/src/utils/containers/unordered_multiset_of.cc new file mode 100644 index 0000000000..5add043c76 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_multiset_of.cc @@ -0,0 +1 @@ +#include "utils/containers/unordered_multiset_of.h" diff --git a/lib/utils/src/utils/expected.cc b/lib/utils/src/utils/expected.cc new file mode 100644 index 0000000000..ac61ea6709 --- /dev/null +++ b/lib/utils/src/utils/expected.cc @@ -0,0 +1 @@ +#include "utils/expected.h" diff --git a/lib/utils/src/utils/fmt/multiset.cc b/lib/utils/src/utils/fmt/multiset.cc new file mode 100644 index 0000000000..9f20c0d9d1 --- /dev/null +++ b/lib/utils/src/utils/fmt/multiset.cc @@ -0,0 +1 @@ +#include "utils/fmt/multiset.h" diff --git a/lib/utils/src/utils/fmt/unordered_multiset.cc b/lib/utils/src/utils/fmt/unordered_multiset.cc new file mode 100644 index 0000000000..cf463296cc --- /dev/null +++ b/lib/utils/src/utils/fmt/unordered_multiset.cc @@ -0,0 +1 @@ +#include "utils/fmt/unordered_multiset.h" diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest.h index cb386a5507..6dd21f5ce5 100644 --- a/lib/utils/test/common/include/test/utils/doctest.h +++ b/lib/utils/test/common/include/test/utils/doctest.h @@ -11,6 +11,12 @@ using namespace FlexFlow; namespace doctest { +#define CHECK_WITHOUT_STRINGIFY(...) \ + do { \ + bool result = __VA_ARGS__; \ + CHECK(result); \ + } while (0); + // template // std::string // doctest_print_container(InputIt first, @@ -65,11 +71,4 @@ namespace doctest { // } // }; -template -struct StringMaker> { - static String convert(tl::expected const &m) { - return toString(fmt::to_string(m)); - } -}; - } // namespace doctest diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index b82534a0de..5c2ffd5bba 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -1,5 +1,6 @@ #include "utils/bidict/bidict.h" #include "test/utils/doctest.h" +#include "utils/fmt/vector.h" using namespace FlexFlow; @@ -59,7 +60,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("bidict::end") { auto it = dict.end(); - CHECK(it == dict.end()); + + CHECK_WITHOUT_STRINGIFY(it == dict.end()); } SUBCASE("map_keys(bidict, F)") { diff --git a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc index 99d256251b..2eb8f869f9 100644 --- a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc @@ -1,5 +1,6 @@ #include "utils/bidict/try_merge_nondisjoint_bidicts.h" #include "test/utils/doctest.h" +#include "utils/fmt/optional.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/any_of.cc b/lib/utils/test/src/utils/containers/any_of.cc new file mode 100644 index 0000000000..24227d15d4 --- /dev/null +++ b/lib/utils/test/src/utils/containers/any_of.cc @@ -0,0 +1,36 @@ +#include "utils/containers/any_of.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("any_of(C, F)") { + SUBCASE("has element matching condition") { + std::vector input = {1, 2, 3}; + + bool result = any_of(input, [](int x) { return x > 1; }); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("does not have element matching condition") { + std::vector input = {1, 2, 3}; + + bool result = any_of(input, [](int x) { return x > 5; }); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("input is empty") { + std::vector input = {}; + + bool result = any_of(input, [](int x) { return true; }); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/extend.cc b/lib/utils/test/src/utils/containers/extend.cc index f1c4ff4415..e0d156a3fc 100644 --- a/lib/utils/test/src/utils/containers/extend.cc +++ b/lib/utils/test/src/utils/containers/extend.cc @@ -1,4 +1,6 @@ #include "utils/containers/extend.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filter.cc b/lib/utils/test/src/utils/containers/filter.cc index 5971022c63..da459094ef 100644 --- a/lib/utils/test/src/utils/containers/filter.cc +++ b/lib/utils/test/src/utils/containers/filter.cc @@ -1,5 +1,10 @@ #include "utils/containers/filter.h" #include "test/utils/all.h" +#include "utils/fmt/map.h" +#include "utils/fmt/set.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_keys.cc b/lib/utils/test/src/utils/containers/filtermap_keys.cc index 59d81ebfc6..758264627b 100644 --- a/lib/utils/test/src/utils/containers/filtermap_keys.cc +++ b/lib/utils/test/src/utils/containers/filtermap_keys.cc @@ -1,5 +1,7 @@ #include "utils/containers/filtermap_keys.h" #include "test/utils/doctest.h" +#include "utils/fmt/map.h" +#include "utils/fmt/unordered_map.h" using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_values.cc b/lib/utils/test/src/utils/containers/filtermap_values.cc index 9a23630c91..d2b6ddd220 100644 --- a/lib/utils/test/src/utils/containers/filtermap_values.cc +++ b/lib/utils/test/src/utils/containers/filtermap_values.cc @@ -1,5 +1,7 @@ #include "utils/containers/filtermap_values.h" #include "test/utils/doctest.h" +#include "utils/fmt/map.h" +#include "utils/fmt/unordered_map.h" using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/get_element_counts.cc b/lib/utils/test/src/utils/containers/get_element_counts.cc index 41c9c726d4..11e2ef7e05 100644 --- a/lib/utils/test/src/utils/containers/get_element_counts.cc +++ b/lib/utils/test/src/utils/containers/get_element_counts.cc @@ -1,4 +1,5 @@ #include "utils/containers/get_element_counts.h" +#include "utils/fmt/unordered_map.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/inplace_filter.cc b/lib/utils/test/src/utils/containers/inplace_filter.cc index 349ba21aa0..7ef9d73339 100644 --- a/lib/utils/test/src/utils/containers/inplace_filter.cc +++ b/lib/utils/test/src/utils/containers/inplace_filter.cc @@ -1,5 +1,10 @@ #include "utils/containers/inplace_filter.h" #include "test/utils/all.h" +#include "utils/fmt/map.h" +#include "utils/fmt/set.h" +#include "utils/fmt/unordered_map.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/intersection.cc b/lib/utils/test/src/utils/containers/intersection.cc index cf25c8023a..ac9acf5e2b 100644 --- a/lib/utils/test/src/utils/containers/intersection.cc +++ b/lib/utils/test/src/utils/containers/intersection.cc @@ -1,4 +1,6 @@ #include "utils/containers/intersection.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/repeat.cc b/lib/utils/test/src/utils/containers/repeat.cc index b76faeba91..50e4b3e7c5 100644 --- a/lib/utils/test/src/utils/containers/repeat.cc +++ b/lib/utils/test/src/utils/containers/repeat.cc @@ -1,4 +1,5 @@ #include "utils/containers/repeat.h" +#include "utils/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/transform.cc b/lib/utils/test/src/utils/containers/transform.cc index 903cc23df5..916bc20928 100644 --- a/lib/utils/test/src/utils/containers/transform.cc +++ b/lib/utils/test/src/utils/containers/transform.cc @@ -1,4 +1,7 @@ #include "utils/containers/transform.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" #include using namespace ::FlexFlow; @@ -26,4 +29,26 @@ TEST_SUITE(FF_TEST_SUITE) { std::string correct = "bcd"; CHECK(result == correct); } + + TEST_CASE("transform(std::optional, F)") { + SUBCASE("has value") { + std::optional input = 3; + + std::optional result = + transform(input, [](int x) { return std::to_string(x); }); + std::optional correct = "3"; + + CHECK(result == correct); + } + + SUBCASE("has nullopt") { + std::optional input = std::nullopt; + + std::optional result = + transform(input, [](int x) { return std::to_string(x); }); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } } diff --git a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc index 7b38ca7365..6aeab4ae6e 100644 --- a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc +++ b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc @@ -1,5 +1,7 @@ #include "utils/containers/try_merge_nondisjoint_unordered_maps.h" #include "test/utils/doctest.h" +#include "utils/fmt/optional.h" +#include "utils/fmt/unordered_map.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc new file mode 100644 index 0000000000..0ab0ef1446 --- /dev/null +++ b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc @@ -0,0 +1,15 @@ +#include "utils/containers/unordered_multiset_of.h" +#include "utils/fmt/unordered_multiset.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("unordered_multiset_of") { + std::vector input = {1, 2, 3, 3, 2, 3}; + std::unordered_multiset result = unordered_multiset_of(input); + std::unordered_multiset correct = {1, 2, 3, 3, 2, 3}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/unordered_set_of.cc b/lib/utils/test/src/utils/containers/unordered_set_of.cc index 1f23663fdc..d42b41dd50 100644 --- a/lib/utils/test/src/utils/containers/unordered_set_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_set_of.cc @@ -1,4 +1,5 @@ #include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_set.h" #include #include diff --git a/lib/utils/test/src/utils/containers/without_order.cc b/lib/utils/test/src/utils/containers/without_order.cc index a2cf453837..939c6ff108 100644 --- a/lib/utils/test/src/utils/containers/without_order.cc +++ b/lib/utils/test/src/utils/containers/without_order.cc @@ -1,4 +1,5 @@ #include "utils/containers/without_order.h" +#include "utils/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/expected.cc b/lib/utils/test/src/utils/expected.cc new file mode 100644 index 0000000000..14679e0d13 --- /dev/null +++ b/lib/utils/test/src/utils/expected.cc @@ -0,0 +1,29 @@ +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("optional_from_expected(tl::expected)") { + SUBCASE("has value") { + tl::expected input = 1; + + std::optional result = optional_from_expected(input); + std::optional correct = 1; + + CHECK(result == correct); + } + + SUBCASE("has unexpected") { + tl::expected input = + tl::make_unexpected("error message"); + + std::optional result = optional_from_expected(input); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/fmt/expected.cc b/lib/utils/test/src/utils/fmt/expected.cc index f21f5d22c8..fb39732761 100644 --- a/lib/utils/test/src/utils/fmt/expected.cc +++ b/lib/utils/test/src/utils/fmt/expected.cc @@ -19,4 +19,24 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } + + TEST_CASE("doctest::toString(tl::expected)") { + SUBCASE("has expected") { + tl::expected input = 3; + + doctest::String result = doctest::toString(input); + doctest::String correct = "expected(3)"; + + CHECK(result == correct); + } + + SUBCASE("has unexpected") { + tl::expected input = tl::make_unexpected("error"); + + doctest::String result = doctest::toString(input); + doctest::String correct = "unexpected(error)"; + + CHECK(result == correct); + } + } } diff --git a/lib/utils/test/src/utils/fmt/unordered_set.cc b/lib/utils/test/src/utils/fmt/unordered_set.cc index df889e429c..9dc8d236f1 100644 --- a/lib/utils/test/src/utils/fmt/unordered_set.cc +++ b/lib/utils/test/src/utils/fmt/unordered_set.cc @@ -1,6 +1,7 @@ #include "utils/fmt/unordered_set.h" #include "test/utils/doctest.h" -#include "utils/containers/get_element_counts.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/fmt/unordered_multiset.h" using namespace ::FlexFlow; @@ -9,10 +10,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set input = {0, 1, 3, 2}; std::string result = fmt::to_string(input); std::string correct = "{0, 1, 2, 3}"; - std::unordered_map result_char_counts = - get_element_counts(result); - std::unordered_map correct_char_counts = - get_element_counts(correct); - CHECK(result_char_counts == correct_char_counts); + std::unordered_multiset unordered_result = + unordered_multiset_of(result); + std::unordered_multiset unordered_correct = + unordered_multiset_of(correct); + CHECK(unordered_result == unordered_correct); } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.cc index c1fe605349..edaffb0a93 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/complete_bipartite_composite_decomposition.h" +#include "utils/fmt/optional.h" #include "utils/fmt/unordered_set.h" #include "utils/hash/unordered_set.h" #include diff --git a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc index fbc69dc7f0..4560f95ff7 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc @@ -1,4 +1,5 @@ #include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/fmt/variant.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc index 7b8548eac7..7cf17c3fee 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc @@ -1,4 +1,5 @@ #include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/fmt/unordered_set.h" #include using namespace ::FlexFlow; From 2b4106f5bb2cd737a786bc4b4ded7b237a236d18 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Thu, 5 Sep 2024 21:39:34 -0700 Subject: [PATCH 03/12] Re-enable `substitutions` (#1471) * Start on pcg builder * Add tests and some implementation for pcg builder * Add pcg tests, make dtgen constructors explicit to fix bug * Add remainder of PCG tests * Fix build issues in local-execution * Format * Address Reyna comments, add topological_order function for PCG * Pre multidigraph refactor * Removing visitable from sp code * Add open dataflow graph, start to replace pcg dataflow graph * Start refactoring substitutions * Add utility functions to support pattern matching * Pre-refactor inputs * Fix proj url * Get back to substitutions, now with unordered graph inputs * Get substitutions building * substitutions-tests now builds * Fix bug in filter, pass some initial substitution tests * Add tests for fmt::to_string, fix some substitutions bugs * Pass initial unit tests for find_pattern_matches * Start on unit tests for pcg pattern * Pass initial test for find_pattern_matches * Fix small build issue in tests * Format * Sync tests in CI with tests in proj * Fix minor build errors in kernels and local-execution * Format * Remove outdated code * More outdated code removal * More cleanup, add test for sp decomposition * Pull apart containers.h * More sp testing and fixes * Break up graph algorithms.h * Pre- full SP algo commit * Add initial implementation and tests for cbc decomposition and inverse line graph * Pass test for get_inverse_line_graph * Add new multidigraph * Fix get_inverse_line_graph to return a MultiDiGraph instead of a DiGraph * Add tests for parallel and series reduction finding * Add really rough implementation of valdez sp decomposition * Fix local-execution build * Add implementations and tests for applying series/parallel reductions * Format * Clean up sp decomposition interface and tests * Format * Add comments for top-level substitutions functions, add proj doxygen support * Start sketching out substitutions code * Fix build errors * Add ability to permute node ids * Cleanup and start to test new substitutions code * Add test case for evaluate_substitution_output * Add naive isomorphism detection code * Add graph inputs to open dataflow graph isomorphism * Add input permutation to evaluate_substitution_output * Fix permute_node_ids * Add test for permute_input_ids * Migrate over to mutable implementation of apply_substitution * Add fast isomorphism checking and an initial implementation of full substitution logic * Pass initial full substitutions test * Cleanup old isomorphism checking code * Fix post-merge bugs * Fix broken pcg builder test * Format * Reorganize code and remove some outdated code pre-code-review * Format * Address review comments * Address missed comment * Remove latex dependency to avoid CI out-of-disk-space * Format * Fix build issues * Fix incorrect test case --- flake.lock | 6 +- lib/compiler/src/machine_mapping.cc | 2 +- .../src/{ => local-execution}/op_arg_spec.cc | 0 lib/local-execution/src/ops/pool_2d.cc | 1 - lib/local-execution/src/ops/transpose.cc | 1 - lib/op-attrs/include/op-attrs/as_dot.h | 15 - .../op-attrs/computation_graph_op_attrs.h | 2 + .../include/op-attrs/dim_ordered/enumerate.h | 30 ++ .../include/op-attrs/get_output_shapes.h | 223 +--------- lib/op-attrs/include/op-attrs/is_valid.h | 59 +++ lib/op-attrs/include/op-attrs/ops/attention.h | 2 +- .../include/op-attrs/ops/batch_norm.h | 3 +- lib/op-attrs/include/op-attrs/ops/cast.h | 4 +- lib/op-attrs/include/op-attrs/ops/concat.h | 13 +- lib/op-attrs/include/op-attrs/ops/flat.h | 9 +- lib/op-attrs/include/op-attrs/ops/gather.h | 9 +- lib/op-attrs/include/op-attrs/ops/input.h | 6 +- .../include/op-attrs/ops/layer_norm.h | 5 +- lib/op-attrs/include/op-attrs/ops/linear.h | 3 + .../include/op-attrs/ops/loss_functions.h | 75 ---- .../loss_functions/loss_attrs.variant.toml | 23 ++ .../loss_functions/loss_function.enum.toml | 23 ++ .../ops/loss_functions/loss_functions.h | 20 + .../other_loss_attrs.struct.toml | 18 + ...rical_cross_entropy_loss_attrs.struct.toml | 15 + lib/op-attrs/include/op-attrs/ops/noop.h | 4 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 4 +- lib/op-attrs/include/op-attrs/ops/replicate.h | 2 +- lib/op-attrs/include/op-attrs/ops/reshape.h | 4 +- lib/op-attrs/include/op-attrs/ops/reverse.h | 4 +- lib/op-attrs/include/op-attrs/ops/split.h | 5 +- lib/op-attrs/include/op-attrs/ops/topk.h | 2 +- lib/op-attrs/include/op-attrs/ops/transpose.h | 8 +- lib/op-attrs/include/op-attrs/ops/weight.h | 18 + .../op-attrs/ops/weight_attrs.struct.toml | 9 +- .../include/op-attrs/parallel_op_attrs.h | 17 + .../op-attrs/parallel_op_attrs.variant.toml | 34 ++ .../include/op-attrs/parallel_tensor_shape.h | 2 + .../parallel_tensor_shape.struct.toml | 4 +- .../include/op-attrs/pcg_operator_attrs.h | 1 + lib/op-attrs/src/get_output_shapes.cc | 29 -- lib/op-attrs/src/op-attrs/as_dot.cc | 13 - lib/op-attrs/src/op-attrs/datatype.cc | 2 +- .../src/op-attrs/dim_ordered/enumerate.cc | 1 + .../src/op-attrs/get_output_shapes.cc | 85 ++++ lib/op-attrs/src/op-attrs/is_valid.cc | 3 + lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 5 + lib/op-attrs/src/op-attrs/ops/cast.cc | 1 + lib/op-attrs/src/op-attrs/ops/concat.cc | 10 + lib/op-attrs/src/op-attrs/ops/embedding.cc | 1 + lib/op-attrs/src/op-attrs/ops/flat.cc | 19 +- lib/op-attrs/src/op-attrs/ops/gather.cc | 12 + lib/op-attrs/src/op-attrs/ops/input.cc | 6 +- lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 2 + lib/op-attrs/src/op-attrs/ops/linear.cc | 16 + .../src/{ => op-attrs/ops}/loss_functions.cc | 17 +- lib/op-attrs/src/op-attrs/ops/noop.cc | 5 + lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 4 + lib/op-attrs/src/op-attrs/ops/reshape.cc | 5 + lib/op-attrs/src/op-attrs/ops/reverse.cc | 4 + lib/op-attrs/src/op-attrs/ops/split.cc | 5 + lib/op-attrs/src/op-attrs/ops/transpose.cc | 4 + lib/op-attrs/src/op-attrs/ops/weight.cc | 14 + .../src/op-attrs/parallel_op_attrs.cc | 37 ++ .../src/op-attrs/parallel_tensor_shape.cc | 14 + .../src/op-attrs/pcg_operator_attrs.cc | 12 + lib/op-attrs/src/op_attrs.cc | 10 - .../src/op-attrs/dim_ordered/enumerate.cc | 20 + .../test/src/op-attrs/ops/layer_norm.cc | 1 + lib/op-attrs/test/src/ops/cast.cc | 1 + .../include/pcg/computation_graph_builder.h | 9 +- .../file_format/v1/graphs/v1_dataflow_graph.h | 2 +- .../v1/graphs/v1_dataflow_graph.struct.toml | 2 +- .../v1/graphs/v1_graph_edge.struct.toml | 8 +- .../v1/graphs/v1_labelled_dataflow_graph.h | 8 +- .../v1_labelled_dataflow_graph.struct.toml | 4 +- .../generate_weight_transform.h | 16 + .../parallel_computation_graph.h | 3 - .../parallel_computation_graph_builder.h | 34 ++ .../parallel_computation_graph_edge.h | 18 + ...arallel_computation_graph_edge.struct.toml | 16 + .../parallel_tensor_guid_t.h | 13 + .../parallel_tensor_use_t.struct.toml | 16 + lib/pcg/src/pcg/computation_graph.cc | 2 +- lib/pcg/src/pcg/computation_graph_builder.cc | 21 +- .../v1/graphs/v1_dataflow_graph.cc | 15 +- .../generate_weight_transform.cc | 35 ++ .../parallel_computation_graph.cc | 2 +- .../parallel_computation_graph_builder.cc | 158 ++++++- .../parallel_computation_graph_edge.cc | 22 + .../parallel_tensor_guid_t.cc | 9 + .../parallel_computation_graph_builder.cc | 28 +- lib/substitution-generator/CMakeLists.txt | 1 + .../include/substitution-generator/json.h | 59 --- .../substitution-generator/legacy_rules.h | 59 +++ .../{json.cc => legacy_rules.cc} | 20 +- .../{json.cc => legacy_rules.cc} | 6 +- .../input_parallel_tensor_guid_t.struct.toml | 16 + .../open_parallel_tensor_guid_t.h | 29 ++ .../open_parallel_tensor_guid_t.struct.toml | 16 + .../operator_attribute_constraint.h | 18 + .../operator_attribute_expr.h | 4 +- .../operator_attribute_value.variant.toml | 5 +- .../materialize_operator_from_attrs_map.h | 15 + .../output_graph/output_graph_expr.h | 15 + .../output_graph_expr.struct.toml | 3 +- .../output_graph_expr_input.struct.toml | 16 + .../output_graph_expr_node.struct.toml | 16 + .../output_graph_expr_node_output.struct.toml | 16 + .../output_operator_attr_access.struct.toml | 4 +- .../output_operator_attribute_expr.h | 15 + .../output_operator_attrs_assignment.h | 25 ++ ...tput_operator_attrs_assignment.struct.toml | 5 + .../output_pattern_value.struct.toml | 16 + .../include/substitutions/pcg_pattern.h | 8 +- .../substitutions/pcg_pattern_builder.h | 34 ++ .../include/substitutions/pcg_pattern_match.h | 24 ++ .../pcg_pattern_match.struct.toml | 29 ++ .../sub_parallel_computation_graph.h | 47 ++- ...arallel_computation_graph_data.struct.toml | 41 ++ .../sub_parallel_computation_graph_edge.h | 24 ++ ...arallel_computation_graph_edge.struct.toml | 16 + .../include/substitutions/substitution.h | 10 +- .../substitutions/substitution.struct.toml | 12 +- .../evaluate_substitution_output.h | 31 ++ .../output_expr_to_result_sub_pcg_mapping.h | 20 + ...expr_to_result_sub_pcg_mapping.struct.toml | 23 ++ .../perform_shape_inference.h | 33 ++ .../tensor_attribute_key.enum.toml | 6 + .../tensor_pattern/tensor_attribute_pattern.h | 12 + .../unlabelled_dataflow_graph_pattern_match.h | 7 + .../unlabelled/unlabelled_graph_pattern.h | 3 +- lib/substitutions/src/substitution.cc | 387 ------------------ .../open_parallel_tensor_guid_t.cc | 16 + .../operator_pattern/get_attribute.cc | 8 +- .../operator_attribute_constraint.cc | 33 ++ .../operator_attribute_expr.cc | 4 +- .../operator_pattern/satisfies_constraint.cc | 2 +- .../materialize_operator_from_attrs_map.cc | 155 +++++++ .../output_graph/output_graph_expr.cc | 17 + .../output_operator_attribute_expr.cc | 19 + .../output_operator_attrs_assignment.cc | 42 ++ .../src/substitutions/pcg_pattern.cc | 58 ++- .../src/substitutions/pcg_pattern_builder.cc | 52 +++ .../src/substitutions/pcg_pattern_match.cc | 49 +++ .../sub_parallel_computation_graph.cc | 181 +++++++- .../sub_parallel_computation_graph_edge.cc | 38 ++ .../src/substitutions/substitution.cc | 299 +++++++------- .../evaluate_substitution_output.cc | 94 +++++ .../output_expr_to_result_sub_pcg_mapping.cc | 32 ++ .../perform_shape_inference.cc | 45 ++ .../tensor_pattern/get_attribute.cc | 1 + .../tensor_attribute_pattern.cc | 9 + .../unlabelled/find_pattern_matches.cc | 5 +- .../unlabelled/pattern_matching.cc | 12 +- .../unlabelled/unlabelled_graph_pattern.cc | 10 +- .../operator_pattern/get_attribute.cc | 34 ++ .../test/src/substitutions/pcg_pattern.cc | 102 ++--- .../test/src/substitutions/substitution.cc | 229 +++++++++++ .../evaluate_substitution_output.cc | 274 +++++++++++++ .../perform_shape_inference.cc | 173 ++++++++ .../substitutions/unlabelled/pattern_split.cc | 6 +- .../test/src/test_pattern_matches.cc | 3 +- .../algorithms/bidict_from_enumerating.h | 35 ++ .../algorithms/bidict_from_keys_and_values.h | 28 ++ .../bidict/algorithms/bidict_from_pairs.h | 17 + .../utils/bidict/algorithms/left_entries.h | 20 + .../utils/bidict/algorithms/merge_bidicts.h | 29 ++ .../utils/bidict/algorithms/right_entries.h | 20 + lib/utils/include/utils/bidict/bidict.h | 9 + lib/utils/include/utils/containers.decl.h | 7 - lib/utils/include/utils/containers.h | 31 -- .../include/utils/containers/enumerate.h | 35 +- .../utils/containers/enumerate_vector.h | 11 +- lib/utils/include/utils/containers/filtrans.h | 73 ++++ .../utils/containers/get_all_permutations.h | 106 +++++ lib/utils/include/utils/containers/keys.h | 17 +- .../include/utils/containers/merge_maps.h | 28 ++ .../include/utils/containers/set_union.h | 2 +- .../include/utils/containers/transform.h | 12 + lib/utils/include/utils/dot_file.h | 26 +- lib/utils/include/utils/exception.decl.h | 35 -- lib/utils/include/utils/exception.h | 19 +- lib/utils/include/utils/fmt.decl.h | 27 -- lib/utils/include/utils/fmt.h | 10 +- .../utils/graph/dataflow_graph/algorithms.h | 5 +- .../graph/dataflow_graph/algorithms/as_dot.h | 16 + .../dataflow_graph_isomorphism.struct.toml | 16 + .../dataflow_graphs_are_isomorphic.h | 13 + .../algorithms/find_isomorphism.h | 20 + .../algorithms/find_isomorphisms.h | 14 + .../algorithms/get_incoming_edges.h | 16 + .../algorithms/get_subgraph_outgoing_edges.h | 14 + .../dataflow_graph/dataflow_edge_query.h | 3 + .../graph/dataflow_graph/dataflow_graph.h | 3 +- .../dataflow_graph/dataflow_output_query.h | 4 + ...ordered_set_labelled_open_dataflow_graph.h | 28 +- .../algorithms/find_isomorphism.h | 30 ++ .../algorithms/is_isomorphic_under.h | 25 ++ .../labelled_dataflow_graphs_are_isomorphic.h | 18 + .../view_as_labelled_open_dataflow_graph.h | 4 +- .../i_labelled_dataflow_graph_view.h | 4 +- .../labelled_dataflow_graph_view.h | 4 +- .../algorithms/as_dot.h | 29 ++ .../algorithms/find_isomorphism.h | 39 ++ .../from_labelled_open_dataflow_graph_data.h | 36 ++ .../algorithms/get_graph_data.h | 36 ++ .../algorithms/is_isomorphic_under.h | 36 ++ ...elled_open_dataflow_graph_data.struct.toml | 41 ++ ...lled_open_dataflow_graphs_are_isomorphic.h | 18 + .../algorithms/permute_input_ids.h | 46 +++ .../algorithms/permute_node_ids.h | 54 +++ .../algorithms/rewrite_labels.h | 2 +- .../algorithms/rewrite_node_labels.h | 26 ++ .../algorithms/rewrite_value_labels.h | 27 ++ .../algorithms/with_labelling.h | 4 +- .../i_labelled_open_dataflow_graph.h | 4 + .../i_labelled_open_dataflow_graph_view.h | 6 +- .../labelled_open_dataflow_graph.h | 10 + .../labelled_open_dataflow_graph_view.h | 4 +- .../generate_new_node_id_permutation.h | 13 + .../node/algorithms/new_node.struct.toml | 16 + .../include/utils/graph/node/node_query.h | 2 + .../open_dataflow_graph/algorithms/as_dot.h | 17 + .../algorithms/find_isomorphism.h | 20 + .../algorithms/find_isomorphisms.h | 15 + .../from_open_dataflow_graph_data.h | 31 ++ .../generate_new_input_id_permutation.h | 14 + .../algorithms/get_edges.h | 12 + .../algorithms/get_graph_data.h | 13 + .../algorithms/get_incoming_edge.h | 13 + .../get_incoming_edges.h} | 14 +- .../algorithms/get_inputs.h | 14 + .../get_open_dataflow_graph_inputs.h | 13 + .../algorithms/get_open_dataflow_value_uses.h | 15 + .../algorithms/get_open_dataflow_values.h | 14 + .../algorithms/get_source_nodes.h | 12 + .../algorithms/get_subgraph_incoming_edges.h | 14 + .../get_unused_open_dataflow_graph_inputs.h | 13 + .../algorithms/is_isomorphic_under.h | 15 + .../new_dataflow_graph_input.struct.toml | 16 + .../open_dataflow_graph_data.struct.toml | 36 ++ ...pen_dataflow_graph_isomorphism.struct.toml | 21 + .../open_dataflow_graphs_are_isomorphic.h | 13 + .../algorithms/permute_input_ids.h | 15 + .../algorithms/permute_node_ids.h | 15 + .../dataflow_input_edge_query.h | 6 + .../open_dataflow_graph/open_dataflow_edge.h | 3 +- .../open_dataflow_edge_query.h | 8 + .../open_dataflow_graph/open_dataflow_value.h | 16 + lib/utils/include/utils/required_core.h | 6 +- lib/utils/include/utils/tuple.h | 2 +- .../algorithms/bidict_from_enumerating.cc | 1 + .../algorithms/bidict_from_keys_and_values.cc | 1 + .../bidict/algorithms/bidict_from_pairs.cc | 1 + .../utils/bidict/algorithms/left_entries.cc | 1 + .../utils/bidict/algorithms/merge_bidicts.cc | 1 + .../utils/bidict/algorithms/right_entries.cc | 1 + lib/utils/src/utils/containers/filtrans.cc | 1 + .../utils/containers/get_all_permutations.cc | 1 + lib/utils/src/utils/containers/merge_maps.cc | 1 + lib/utils/src/utils/exception.cc | 1 + .../utils/graph/dataflow_graph/algorithms.cc | 20 +- .../graph/dataflow_graph/algorithms/as_dot.cc | 64 +++ .../dataflow_graphs_are_isomorphic.cc | 11 + .../algorithms/find_isomorphism.cc | 20 + .../algorithms/find_isomorphisms.cc | 22 + .../algorithms/get_incoming_edges.cc | 31 ++ .../algorithms/get_subgraph_outgoing_edges.cc | 24 ++ .../algorithms/view_as_open_dataflow_graph.cc | 41 ++ .../algorithms/view_as_open_dataflow_graph.h | 32 ++ .../dataflow_graph/dataflow_edge_query.cc | 29 ++ .../dataflow_graph/dataflow_output_query.cc | 15 + .../algorithms/is_isomorphic_under.cc | 1 + ...labelled_dataflow_graphs_are_isomorphic.cc | 1 + .../algorithms/as_dot.cc | 1 + .../algorithms/find_isomorphism.cc | 1 + .../from_labelled_open_dataflow_graph_data.cc | 1 + .../algorithms/get_graph_data.cc | 1 + .../algorithms/is_isomorphic_under.cc | 1 + ...led_open_dataflow_graphs_are_isomorphic.cc | 1 + .../algorithms/permute_input_ids.cc | 1 + .../algorithms/rewrite_node_labels.cc | 1 + .../algorithms/rewrite_value_labels.cc | 1 + .../generate_new_node_id_permutation.cc | 16 + lib/utils/src/utils/graph/node/node_query.cc | 5 + .../open_dataflow_graph/algorithms/as_dot.cc | 63 +++ .../algorithms/find_isomorphism.cc | 20 + .../algorithms/find_isomorphisms.cc | 248 +++++++++++ .../from_open_dataflow_graph_data.cc | 41 ++ .../generate_new_input_id_permutation.cc | 19 + .../algorithms/get_edges.cc | 10 + .../algorithms/get_graph_data.cc | 17 + .../algorithms/get_incoming_edge.cc | 15 + .../get_incoming_edges.cc} | 39 +- .../algorithms/get_inputs.cc | 15 + .../get_open_dataflow_graph_inputs.cc | 10 + .../get_open_dataflow_value_uses.cc | 17 + .../algorithms/get_open_dataflow_values.cc | 17 + .../algorithms/get_source_nodes.cc | 16 + .../algorithms/get_subgraph_incoming_edges.cc | 29 ++ .../algorithms/get_subgraph_inputs.cc | 4 +- .../get_unused_open_dataflow_graph_inputs.cc | 15 + .../algorithms/is_isomorphic_under.cc | 30 ++ .../open_dataflow_graphs_are_isomorphic.cc | 11 + .../algorithms/permute_input_ids.cc | 44 ++ .../algorithms/permute_node_ids.cc | 72 ++++ .../dataflow_input_edge_query.cc | 27 ++ .../open_dataflow_graph/open_dataflow_edge.cc | 18 +- .../open_dataflow_edge_query.cc | 34 ++ .../open_dataflow_value.cc | 22 + .../graph/serial_parallel/graph_generation.cc | 3 +- .../algorithms/bidict_from_enumerating.cc | 38 ++ .../test/src/utils/containers/enumerate.cc | 50 +++ .../test/src/utils/containers/filtrans.cc | 57 +++ .../utils/containers/get_all_permutations.cc | 54 +++ .../utils/graph/dataflow_graph/algorithms.cc | 4 +- .../dataflow_graphs_are_isomorphic.cc | 93 +++++ .../algorithms/find_isomorphism.cc | 101 +++++ .../algorithms/get_subgraph_outgoing_edges.cc | 41 ++ .../algorithms/find_isomorphism.cc | 187 +++++++++ ...labelled_dataflow_graphs_are_isomorphic.cc | 169 ++++++++ .../algorithms/find_isomorphism.cc | 185 +++++++++ .../algorithms/is_isomorphic_under.cc | 60 +++ ...led_open_dataflow_graphs_are_isomorphic.cc | 121 ++++++ .../algorithms/find_isomorphism.cc | 134 ++++++ .../get_open_dataflow_graph_inputs.cc | 24 ++ .../get_open_dataflow_value_uses.cc | 74 ++++ .../get_unused_open_dataflow_graph_inputs.cc | 41 ++ .../open_dataflow_graphs_are_isomorphic.cc | 110 +++++ .../algorithms/permute_input_ids.cc | 79 ++++ .../algorithms/permute_node_ids.cc | 175 ++++++++ 332 files changed, 7939 insertions(+), 1396 deletions(-) rename lib/local-execution/src/{ => local-execution}/op_arg_spec.cc (100%) delete mode 100644 lib/op-attrs/include/op-attrs/as_dot.h create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h create mode 100644 lib/op-attrs/include/op-attrs/is_valid.h delete mode 100644 lib/op-attrs/include/op-attrs/ops/loss_functions.h create mode 100644 lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h create mode 100644 lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml create mode 100644 lib/op-attrs/include/op-attrs/ops/weight.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_op_attrs.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml delete mode 100644 lib/op-attrs/src/get_output_shapes.cc delete mode 100644 lib/op-attrs/src/op-attrs/as_dot.cc create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc create mode 100644 lib/op-attrs/src/op-attrs/get_output_shapes.cc create mode 100644 lib/op-attrs/src/op-attrs/is_valid.cc rename lib/op-attrs/src/{ => op-attrs/ops}/loss_functions.cc (80%) create mode 100644 lib/op-attrs/src/op-attrs/ops/weight.cc create mode 100644 lib/op-attrs/src/op-attrs/parallel_op_attrs.cc delete mode 100644 lib/op-attrs/src/op_attrs.cc create mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h create mode 100644 lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc create mode 100644 lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc delete mode 100644 lib/substitution-generator/include/substitution-generator/json.h create mode 100644 lib/substitution-generator/include/substitution-generator/legacy_rules.h rename lib/substitution-generator/src/substitution-generator/{json.cc => legacy_rules.cc} (67%) rename lib/substitution-generator/test/substitution-generator/{json.cc => legacy_rules.cc} (88%) create mode 100644 lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml create mode 100644 lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h create mode 100644 lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml create mode 100644 lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h create mode 100644 lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h create mode 100644 lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml create mode 100644 lib/substitutions/include/substitutions/pcg_pattern_builder.h create mode 100644 lib/substitutions/include/substitutions/pcg_pattern_match.h create mode 100644 lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml create mode 100644 lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml create mode 100644 lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h create mode 100644 lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml create mode 100644 lib/substitutions/include/substitutions/substitution_internal/evaluate_substitution_output.h create mode 100644 lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h create mode 100644 lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.struct.toml create mode 100644 lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h create mode 100644 lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h delete mode 100644 lib/substitutions/src/substitution.cc create mode 100644 lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc create mode 100644 lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.cc create mode 100644 lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc create mode 100644 lib/substitutions/src/substitutions/pcg_pattern_builder.cc create mode 100644 lib/substitutions/src/substitutions/pcg_pattern_match.cc create mode 100644 lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc create mode 100644 lib/substitutions/src/substitutions/substitution_internal/evaluate_substitution_output.cc create mode 100644 lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc create mode 100644 lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc create mode 100644 lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc create mode 100644 lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc create mode 100644 lib/substitutions/test/src/substitutions/substitution.cc create mode 100644 lib/substitutions/test/src/substitutions/substitution_internal/evaluate_substitution_output.cc create mode 100644 lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc create mode 100644 lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h create mode 100644 lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h create mode 100644 lib/utils/include/utils/bidict/algorithms/bidict_from_pairs.h create mode 100644 lib/utils/include/utils/bidict/algorithms/left_entries.h create mode 100644 lib/utils/include/utils/bidict/algorithms/merge_bidicts.h create mode 100644 lib/utils/include/utils/bidict/algorithms/right_entries.h create mode 100644 lib/utils/include/utils/containers/filtrans.h create mode 100644 lib/utils/include/utils/containers/get_all_permutations.h create mode 100644 lib/utils/include/utils/containers/merge_maps.h delete mode 100644 lib/utils/include/utils/exception.decl.h delete mode 100644 lib/utils/include/utils/fmt.decl.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphism.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphisms.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/get_incoming_edges.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h create mode 100644 lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h create mode 100644 lib/utils/include/utils/graph/node/algorithms/generate_new_node_id_permutation.h create mode 100644 lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_edges.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_graph_data.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h rename lib/utils/include/utils/graph/open_dataflow_graph/{algorithms.h => algorithms/get_incoming_edges.h} (53%) create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_inputs.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h create mode 100644 lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.h create mode 100644 lib/utils/src/utils/bidict/algorithms/bidict_from_enumerating.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/bidict_from_keys_and_values.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/bidict_from_pairs.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/left_entries.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/right_entries.cc create mode 100644 lib/utils/src/utils/containers/filtrans.cc create mode 100644 lib/utils/src/utils/containers/get_all_permutations.cc create mode 100644 lib/utils/src/utils/containers/merge_maps.cc create mode 100644 lib/utils/src/utils/exception.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphisms.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.cc create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.cc create mode 100644 lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.cc create mode 100644 lib/utils/src/utils/graph/node/algorithms/generate_new_node_id_permutation.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_edges.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_graph_data.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.cc rename lib/utils/src/utils/graph/open_dataflow_graph/{algorithms.cc => algorithms/get_incoming_edges.cc} (55%) create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_inputs.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc create mode 100644 lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc create mode 100644 lib/utils/test/src/utils/containers/enumerate.cc create mode 100644 lib/utils/test/src/utils/containers/filtrans.cc create mode 100644 lib/utils/test/src/utils/containers/get_all_permutations.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc create mode 100644 lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc create mode 100644 lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc create mode 100644 lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc create mode 100644 lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc create mode 100644 lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc create mode 100644 lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc create mode 100644 lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc diff --git a/flake.lock b/flake.lock index b36a96ee80..1aad68ae29 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1722405648, - "narHash": "sha256-+9cRIT+bwo7qxI966HjwR2Sw37CcXD1JlG9nw+vq2lY=", + "lastModified": 1722923482, + "narHash": "sha256-myUec+oBcnKNCqLQqSiPCyXFsIsvlrsGoj/mQFlHVrY=", "owner": "lockshaw", "repo": "proj", - "rev": "3674de6208c52f3a022e8f00660ee01d580aa466", + "rev": "c650b0e52337652ea7190131988c0370e0ee7f25", "type": "github" }, "original": { diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 12eacb2a30..af7756c635 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -12,10 +12,10 @@ #include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" #include "utils/exception.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" #include "utils/graph/serial_parallel/serial_parallel_decomposition.h" diff --git a/lib/local-execution/src/op_arg_spec.cc b/lib/local-execution/src/local-execution/op_arg_spec.cc similarity index 100% rename from lib/local-execution/src/op_arg_spec.cc rename to lib/local-execution/src/local-execution/op_arg_spec.cc diff --git a/lib/local-execution/src/ops/pool_2d.cc b/lib/local-execution/src/ops/pool_2d.cc index 789ed2cd63..33d62b713c 100644 --- a/lib/local-execution/src/ops/pool_2d.cc +++ b/lib/local-execution/src/ops/pool_2d.cc @@ -3,7 +3,6 @@ #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/pool_2d.h" -#include "utils/exception.decl.h" #include "utils/exception.h" #include "utils/hash-utils.h" diff --git a/lib/local-execution/src/ops/transpose.cc b/lib/local-execution/src/ops/transpose.cc index 5c3c1dd1ca..3e4ac15db3 100644 --- a/lib/local-execution/src/ops/transpose.cc +++ b/lib/local-execution/src/ops/transpose.cc @@ -17,7 +17,6 @@ #include "kernels/transpose_kernels.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/transpose.h" -#include "utils/exception.decl.h" using namespace FlexFlow::Kernels::Transpose; diff --git a/lib/op-attrs/include/op-attrs/as_dot.h b/lib/op-attrs/include/op-attrs/as_dot.h deleted file mode 100644 index d92557c2f4..0000000000 --- a/lib/op-attrs/include/op-attrs/as_dot.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H -#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_AS_DOT_H - -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "utils/record_formatter.h" - -namespace FlexFlow { - -RecordFormatter as_dot(ComputationGraphOpAttrs const &); -RecordFormatter as_dot(PCGOperatorAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h index 4be17798f7..03f38bb8f9 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -2,10 +2,12 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H #include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "utils/record_formatter.h" namespace FlexFlow { OperatorType get_op_type(ComputationGraphOpAttrs const &); +RecordFormatter as_dot(ComputationGraphOpAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h new file mode 100644 index 0000000000..f9f6d00532 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H + +#include "op-attrs/dim_ordered.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/count.h" + +namespace FlexFlow { + +/** + * @brief Generate a map from indices to elements of \p c. + * + * @note We return a std::map to prevent mixups of \ref ff_dim_t and + * \ref legion_dim_t. Note that std::map provides ordered iteration in + * increasing order, so iterating through the result of this function should + * function as expected. + */ +template +std::map enumerate(FFOrdered const &ff_ordered) { + std::map result; + for (int raw_ff_dim : count(ff_ordered.size())) { + ff_dim_t ff_dim = ff_dim_t{raw_ff_dim}; + result.insert({ff_dim, ff_ordered.at(ff_dim)}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/get_output_shapes.h b/lib/op-attrs/include/op-attrs/get_output_shapes.h index 612c226a13..c27bbb190f 100644 --- a/lib/op-attrs/include/op-attrs/get_output_shapes.h +++ b/lib/op-attrs/include/op-attrs/get_output_shapes.h @@ -1,228 +1,15 @@ #ifndef _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H #define _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "ops/reverse.h" -#include "tensor_shape.h" -#include "utils/containers/get_only.h" -#include "utils/optional.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include namespace FlexFlow { -template -struct has_unary_output_t : std::false_type {}; -template -struct has_unary_input_t : std::false_type {}; -template -struct has_binary_input_t : std::false_type {}; - -template -struct has_multi_output_t : std::true_type {}; -template -struct has_multi_input_t : std::true_type {}; - -template -struct has_multi_output_t< - T, - typename std::enable_if::value>::type> - : std::false_type {}; - -template -struct has_multi_input_t< - T, - typename std::enable_if<(has_unary_input_t::value || - has_binary_input_t::value)>::type> - : std::false_type {}; - -/* template struct output_type_t { using - * type = std::vector; }; */ - -template -typename std::enable_if::value, bool>::type - is_valid(T const &t, std::vector const &shapes) { - if (shapes.size() != 1) { - return false; - } - - return is_valid(t, get_only(shapes)); -} - -template -typename std::enable_if::value, bool>::type - is_valid(T const &t, std::vector const &shapes) { - if (shapes.size() != 2) { - return false; - } - - return is_valid(t, shapes.at(0), shapes.at(1)); -} - -template -typename std::enable_if<(has_unary_input_t::value && - has_unary_output_t::value), - ParallelTensorShape>::type - output_shapes(T const &t, std::vector const &shapes) { - return output_shape(t, get_only(shapes)); -} - -template -typename std::enable_if<(has_binary_input_t::value && - has_unary_output_t::value), - std::vector>::type - output_shapes(T const &t, std::vector const &shapes) { - assert(shapes.size() == 2); - - return {output_shape(t, shapes.at(0), shapes.at(1))}; -} - -TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); -std::vector - get_tensor_shapes_unsafe(std::vector const &); - -template -TensorShape get_output_shape(Attrs const &attrs, TensorShape const &shape) { - NOT_IMPLEMENTED(); -} - -template -TensorShape get_output_shape(Attrs const &attrs, - TensorShape const &, - TensorShape const &) { - NOT_IMPLEMENTED(); -} - -template -TensorShape get_output_shape(Attrs const &attrs, - std::vector const &) { - NOT_IMPLEMENTED(); -} -template -std::vector get_output_shapes(Attrs const &attrs, - TensorShape const &); -template -std::vector get_output_shapes(Attrs const &attrs, - TensorShape const &, - TensorShape const &) { - NOT_IMPLEMENTED(); -} -template -std::vector get_output_shapes(Attrs const &attrs, - std::vector const &); - -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &); -std::vector get_output_shapes(GatherAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReduceAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(ReverseAttrs const &, - ParallelTensorShape const &); -std::vector get_output_shapes(SplitAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(TopKAttrs const &, - ParallelTensorShape const &); -ParallelTensorShape get_output_shape(TransposeAttrs const &, - std::vector const &); - -struct GetOutputShapesFunctor { - GetOutputShapesFunctor(std::vector const &s) : s(s) {} - - std::vector const &s; - - template - std::vector operator()(T const &t) { - return get_output_shapes(t, s); - } -}; - -template std::vector - get_output_shapes(std::variant const &t, - std::vector const &s) { - return get_output_shape(GetOutputShapesFunctor{s}, t); -} - -template -typename std::enable_if::value, std::optional>::type - get_num_outputs(T const &) { - return std::nullopt; -} - -template -typename std::enable_if::value, std::optional>::type - get_num_outputs(T const &) { - return 1; -} - -int get_num_outputs(SplitAttrs const &attrs); - -template -bool is_valid(T const &t, std::vector const &shapes) { - auto num_outputs = get_num_outputs(t); - if (num_outputs.has_value() && shapes.size() != num_outputs.value()) { - return false; - } - - for (ParallelTensorShape const &shape : shapes) { - if (!is_valid(shape)) { - return false; - } - } - - return is_valid_internal(t, shapes); -} - -template -typename std::enable_if::value, bool>::type - is_valid_internal(T const &t, - std::vector const &shapes) { - return is_valid_internal(t, get_only(shapes)); -} - -template -typename std::enable_if::value, bool>::type - is_valid_internal(T const &t, - std::vector const &shapes) { - return is_valid_internal(t, shapes.at(0), shapes.at(1)); -} - -bool is_valid_internal(MultiHeadAttentionAttrs const &, - std::vector const &); -bool is_valid_internal(BatchMatmulAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ConcatAttrs const &, - std::vector const &); -bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ElementBinaryAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(GatherAttrs const &, - ParallelTensorShape const &, - ParallelTensorShape const &); -bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &); -bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &); + get_output_shapes(PCGOperatorAttrs const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/is_valid.h b/lib/op-attrs/include/op-attrs/is_valid.h new file mode 100644 index 0000000000..2d91307e19 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/is_valid.h @@ -0,0 +1,59 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_IS_VALID_H + +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +template +bool is_valid(T const &t, std::vector const &shapes) { + auto num_outputs = get_num_outputs(t); + if (num_outputs.has_value() && shapes.size() != num_outputs.value()) { + return false; + } + + for (ParallelTensorShape const &shape : shapes) { + if (!is_valid(shape)) { + return false; + } + } + + return is_valid_internal(t, shapes); +} + +bool is_valid_internal(MultiHeadAttentionAttrs const &, + std::vector const &); +bool is_valid_internal(BatchMatmulAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); +bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ConcatAttrs const &, + std::vector const &); +bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ElementBinaryAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); +bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(GatherAttrs const &, + ParallelTensorShape const &, + ParallelTensorShape const &); +bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &); +bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index e126c425dc..40f57d08af 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_ATTENTION_ATTRS_H #define _FLEXFLOW_ATTENTION_ATTRS_H -#include "core.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" #include "op-attrs/ops/attention_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index b9a1d87a75..8afcbb06b1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -1,12 +1,13 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H -#include "core.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" namespace FlexFlow { +TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/cast.h b/lib/op-attrs/include/op-attrs/ops/cast.h index ead779c553..f85481b45b 100644 --- a/lib/op-attrs/include/op-attrs/ops/cast.h +++ b/lib/op-attrs/include/op-attrs/ops/cast.h @@ -3,8 +3,8 @@ #include "op-attrs/ops/cast_attrs.dtg.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index 8a72708971..f3ac8494c0 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -1,13 +1,20 @@ -#ifndef _FLEXFLOW_CONCAT_ATTRS_H -#define _FLEXFLOW_CONCAT_ATTRS_H +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_CONCAT_H -#include "core.h" #include "op-attrs/ops/concat_attrs.dtg.h" +#include "op-attrs/ops/core.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); +TensorShape get_output_shape(ConcatAttrs const &, + std::vector const &); +ParallelTensorShape get_output_shape(ConcatAttrs const &, + std::vector const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index d5d9069f51..676d21c59b 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -1,14 +1,19 @@ #ifndef _FLEXFLOW_FLAT_ATTRS_H #define _FLEXFLOW_FLAT_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); +TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(FlatAttrs const &, + ParallelTensorShape const &); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/gather.h b/lib/op-attrs/include/op-attrs/ops/gather.h index 79516a8862..42efd13b60 100644 --- a/lib/op-attrs/include/op-attrs/ops/gather.h +++ b/lib/op-attrs/include/op-attrs/ops/gather.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_GATHER_ATTRS_H #define _FLEXFLOW_GATHER_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/gather_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.h" @@ -9,6 +9,13 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(GatherAttrs); +TensorShape get_output_shape(GatherAttrs const &, + TensorShape const &input, + TensorShape const &index); +ParallelTensorShape get_output_shape(GatherAttrs const &, + ParallelTensorShape const &input, + ParallelTensorShape const &index); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/include/op-attrs/ops/input.h b/lib/op-attrs/include/op-attrs/ops/input.h index 9fe0ee2c2d..fe92c77a52 100644 --- a/lib/op-attrs/include/op-attrs/ops/input.h +++ b/lib/op-attrs/include/op-attrs/ops/input.h @@ -1,15 +1,17 @@ #ifndef _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H #define _FLEXFLOW_OP_ATTRS_OPS_OP_ATTRS_INPUT_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/input_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(InputAttrs); -ParallelTensorShape get_output_shape(InputAttrs const &); +TensorShape get_output_shape(InputAttrs const &); +ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 94f9b9e147..29b0b2f514 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -1,9 +1,10 @@ #ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index dd6948165e..795ba19ae8 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -5,12 +5,15 @@ #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "utils/record_formatter.h" #include namespace FlexFlow { CHECK_VALID_OP_ATTR(LinearAttrs); +RecordFormatter as_dot(LinearAttrs const &); + tl::expected get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions.h deleted file mode 100644 index 58d372d9e5..0000000000 --- a/lib/op-attrs/include/op-attrs/ops/loss_functions.h +++ /dev/null @@ -1,75 +0,0 @@ -#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H -#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H - -#include "core.h" -#include "utils/exception.h" -#include "utils/visitable.h" -#include - -namespace FlexFlow { - -enum class LossFunction { - CATEGORICAL_CROSSENTROPY, - SPARSE_CATEGORICAL_CROSSENTROPY, - MEAN_SQUARED_ERROR_AVG_REDUCE, - MEAN_SQUARED_ERROR_SUM_REDUCE, - IDENTITY -}; - -LossFunction parse_loss_function_name(std::string const &); - -struct SparseCategoricalCrossEntropyLossAttrs { - req replace_labels; // for aggregate_spec: More predictions than labels -}; -FF_VISITABLE_STRUCT(SparseCategoricalCrossEntropyLossAttrs, replace_labels); -CHECK_VALID_OP_ATTR(SparseCategoricalCrossEntropyLossAttrs); - -struct OtherLossAttrs { - req loss_type; -}; -FF_VISITABLE_STRUCT(OtherLossAttrs, loss_type); -CHECK_VALID_OP_ATTR(OtherLossAttrs); - -using LossAttrs = - std::variant; - -LossFunction get_loss_function(OtherLossAttrs const &); -LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &); -LossFunction get_loss_function(LossAttrs const &); - -} // namespace FlexFlow - -namespace fmt { - -template <> -struct formatter<::FlexFlow::LossFunction> : formatter { - template - auto format(::FlexFlow::LossFunction d, FormatContext &ctx) const - -> decltype(ctx.out()) { - using namespace FlexFlow; - - string_view name = "unknown"; - switch (d) { - case LossFunction::CATEGORICAL_CROSSENTROPY: - name = "CategoricalCrossEntropy"; - break; - case LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY: - name = "SparseCategoricalCrossEntropy"; - break; - case LossFunction::MEAN_SQUARED_ERROR_AVG_REDUCE: - name = "MeanSquaredErrorAvgReduce"; - break; - case LossFunction::MEAN_SQUARED_ERROR_SUM_REDUCE: - name = "MeanSquaredErrorSumReduce"; - break; - case LossFunction::IDENTITY: - name = "Identity"; - break; - } - return formatter::format(name, ctx); - } -}; - -} // namespace fmt - -#endif diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml new file mode 100644 index 0000000000..17293095e4 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_attrs.variant.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LossAttrs" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.dtg.h", + "op-attrs/ops/loss_functions/other_loss_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::SparseCategoricalCrossEntropyLossAttrs" +key = "sparse_categorical_cross_entropy_loss" + +[[values]] +type = "::FlexFlow::OtherLossAttrs" +key = "other_loss" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml new file mode 100644 index 0000000000..9658202a45 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_function.enum.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "LossFunction" +features = [ + "fmt", + "hash", + "rapidcheck", + "json", +] + +[[values]] +name = "CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "SPARSE_CATEGORICAL_CROSSENTROPY" + +[[values]] +name = "MEAN_SQUARED_ERROR_AVG_REDUCE" + +[[values]] +name = "MEAN_SQUARED_ERROR_SUM_REDUCE" + +[[values]] +name = "IDENTITY" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h new file mode 100644 index 0000000000..ca8f3e6602 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/loss_functions.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H +#define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_LOSS_FUNCTIONS_H + +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "op-attrs/ops/loss_functions/loss_function.dtg.h" + +namespace FlexFlow { + +CHECK_VALID_OP_ATTR(LossAttrs); + +LossFunction parse_loss_function_name(std::string const &); + +LossFunction get_loss_function(OtherLossAttrs const &); +LossFunction get_loss_function(SparseCategoricalCrossEntropyLossAttrs const &); +LossFunction get_loss_function(LossAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml new file mode 100644 index 0000000000..284a4b1d7d --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/other_loss_attrs.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OtherLossAttrs" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [ + "op-attrs/ops/loss_functions/loss_function.dtg.h", +] + +[[fields]] +name = "loss_type" +type = "::FlexFlow::LossFunction" diff --git a/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml new file mode 100644 index 0000000000..c50b432ba2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/loss_functions/sparse_categorical_cross_entropy_loss_attrs.struct.toml @@ -0,0 +1,15 @@ +namespace = "FlexFlow" +name = "SparseCategoricalCrossEntropyLossAttrs" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[fields]] +# for aggregate_spec: More predictions than labels +name = "replace_labels" +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/ops/noop.h b/lib/op-attrs/include/op-attrs/ops/noop.h index eb01009259..2c61dff886 100644 --- a/lib/op-attrs/include/op-attrs/ops/noop.h +++ b/lib/op-attrs/include/op-attrs/ops/noop.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_OP_ATTRS_OPS_NOOP_H #define _FLEXFLOW_OP_ATTRS_OPS_NOOP_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/noop_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(NoopAttrs); +TensorShape get_output_shape(NoopAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 162f9aef05..505fdd9f8c 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_POOL_2D_ATTRS_H #define _FLEXFLOW_POOL_2D_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); +TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/replicate.h b/lib/op-attrs/include/op-attrs/ops/replicate.h index 4c46bf88a9..9104a36155 100644 --- a/lib/op-attrs/include/op-attrs/ops/replicate.h +++ b/lib/op-attrs/include/op-attrs/ops/replicate.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_REPLICATE_ATTRS_H #define _FLEXFLOW_REPLICATE_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/replicate_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/ops/reshape.h b/lib/op-attrs/include/op-attrs/ops/reshape.h index cd2ca80c3a..e87ca5c750 100644 --- a/lib/op-attrs/include/op-attrs/ops/reshape.h +++ b/lib/op-attrs/include/op-attrs/ops/reshape.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_RESHAPE_ATTRS_H #define _FLEXFLOW_RESHAPE_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/reshape_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -9,6 +9,8 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ReshapeAttrs); +TensorShape get_output_shape(ReshapeAttrs const &attrs, + TensorShape const &input_shape); ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/reverse.h b/lib/op-attrs/include/op-attrs/ops/reverse.h index adc62dc9ae..023e714c20 100644 --- a/lib/op-attrs/include/op-attrs/ops/reverse.h +++ b/lib/op-attrs/include/op-attrs/ops/reverse.h @@ -1,14 +1,16 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_REVERSE_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/reverse_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(ReverseAttrs); +TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/split.h b/lib/op-attrs/include/op-attrs/ops/split.h index 8fc2257760..e6a08d6e77 100644 --- a/lib/op-attrs/include/op-attrs/ops/split.h +++ b/lib/op-attrs/include/op-attrs/ops/split.h @@ -1,15 +1,18 @@ #ifndef _FLEXFLOW_SPLIT_ATTRS_H #define _FLEXFLOW_SPLIT_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/split_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { CHECK_VALID_OP_ATTR(SplitAttrs); +std::vector get_output_shapes(SplitAttrs const &, + TensorShape const &); std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index c6af40dd48..bd11f0ae91 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_TOPK_ATTRS_H #define _FLEXFLOW_TOPK_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" diff --git a/lib/op-attrs/include/op-attrs/ops/transpose.h b/lib/op-attrs/include/op-attrs/ops/transpose.h index 6e23d91d78..6de83ee414 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose.h +++ b/lib/op-attrs/include/op-attrs/ops/transpose.h @@ -1,16 +1,18 @@ #ifndef _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H #define _FLEXFLOW_OP_META_OPS_TRANSPOSE_ATTRS_H -#include "core.h" +#include "op-attrs/ops/core.h" #include "op-attrs/ops/transpose_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(TransposeAttrs); -ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, - ParallelTensorShape const &input_shape); +TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &); +ParallelTensorShape get_output_shape(TransposeAttrs const &, + ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/weight.h b/lib/op-attrs/include/op-attrs/ops/weight.h new file mode 100644 index 0000000000..ab97b31012 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/ops/weight.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_WEIGHT_H + +#include "op-attrs/ops/core.h" +#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" + +namespace FlexFlow { + +CHECK_VALID_OP_ATTR(WeightAttrs); + +TensorShape get_output_shape(WeightAttrs const &); +ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml index 28810a437e..c4d22a006c 100644 --- a/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/weight_attrs.struct.toml @@ -8,4 +8,11 @@ features = [ "rapidcheck", "fmt", ] -fields = [] + +includes = [ + "op-attrs/tensor_shape.dtg.h", +] + +[[fields]] +name = "tensor_shape" +type = "::FlexFlow::TensorShape" diff --git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.h b/lib/op-attrs/include/op-attrs/parallel_op_attrs.h new file mode 100644 index 0000000000..8669669f09 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_op_attrs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_OP_ATTRS_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_OP_ATTRS_H + +#include "op-attrs/parallel_op_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ParallelOpAttrs const &, + ParallelTensorShape const &); +PCGOperatorAttrs pcg_op_attrs_from_parallel_op_attrs(ParallelOpAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml b/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml new file mode 100644 index 0000000000..f1631a41f2 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_op_attrs.variant.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "ParallelOpAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/reduction_attrs.dtg.h", + "op-attrs/ops/repartition_attrs.dtg.h", + "op-attrs/ops/replicate_attrs.dtg.h", +] + +[[values]] +type = "::FlexFlow::CombineAttrs" +key = "combine_distributed" + +[[values]] +type = "::FlexFlow::ReductionAttrs" +key = "reduce_distributed" + +[[values]] +type = "::FlexFlow::RepartitionAttrs" +key = "partition_distributed" + +[[values]] +type = "::FlexFlow::ReplicateAttrs" +key = "replicate_distributed" + diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 99be635ffc..76356b39d4 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -2,6 +2,7 @@ #define _OP_META_PARALLEL_TENSOR_SHAPE_H #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" #include @@ -36,6 +37,7 @@ int get_total_parallel_degree(ParallelTensorShape const &); bool is_valid(ParallelTensorShape const &); +TensorShape require_not_parallel(ParallelTensorShape const &); TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &); std::vector get_tensor_shapes_unsafe(std::vector const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml index e6197bcd51..806af55cba 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.struct.toml @@ -10,8 +10,8 @@ features = [ ] includes = [ - "op-attrs/parallel_tensor_dims.h", - "op-attrs/datatype.h", + "op-attrs/parallel_tensor_dims.dtg.h", + "op-attrs/datatype.dtg.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h index 25be926cbe..08167fe3d9 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -10,6 +10,7 @@ bool is_parallel_op(PCGOperatorAttrs const &); OperatorType get_op_type(PCGOperatorAttrs const &); ComputationGraphOpAttrs compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); +RecordFormatter as_dot(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc deleted file mode 100644 index c20d4be34c..0000000000 --- a/lib/op-attrs/src/get_output_shapes.cc +++ /dev/null @@ -1,29 +0,0 @@ -#include "op-attrs/get_output_shapes.h" - -namespace FlexFlow { - -ParallelTensorShape as_parallel(TensorShape const &); -std::vector as_parallel(std::vector const &); - -std::vector get_output_shapes( - PCGOperatorAttrs const &op_params, - std::vector const &input_tensor_shapes) { - NOT_IMPLEMENTED(); -} - -// TensorShape get_output_shape(AggregateAttrs const &attrs, -// TensorShape const &gate_preds, -// TensorShape const &gate_assign, -// TensorShape const &true_gate_assign, -// TensorShape const &full_gate_gradients, -// std::vector const &exp_preds) { -// return get_tensor_shape_unsafe( -// get_output_shape(attrs, -// as_parallel(gate_preds), -// as_parallel(gate_assign), -// as_parallel(true_gate_assign), -// as_parallel(full_gate_gradients), -// as_parallel(exp_preds))); -// } - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/as_dot.cc b/lib/op-attrs/src/op-attrs/as_dot.cc deleted file mode 100644 index f8d05de941..0000000000 --- a/lib/op-attrs/src/op-attrs/as_dot.cc +++ /dev/null @@ -1,13 +0,0 @@ -#include "op-attrs/as_dot.h" - -namespace FlexFlow { - -RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { - NOT_IMPLEMENTED(); -} - -RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { - NOT_IMPLEMENTED(); -} - -} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/datatype.cc b/lib/op-attrs/src/op-attrs/datatype.cc index bd29c8033a..3bee05c253 100644 --- a/lib/op-attrs/src/op-attrs/datatype.cc +++ b/lib/op-attrs/src/op-attrs/datatype.cc @@ -19,7 +19,7 @@ size_t size_of_datatype(DataType data_type) { case DataType::DOUBLE: return sizeof(double); default: - throw mk_runtime_error("Unknown DataType {}", data_type); + throw mk_runtime_error(fmt::format("Unknown DataType {}", data_type)); } } diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc new file mode 100644 index 0000000000..6edd5485af --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/enumerate.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/enumerate.h" diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc new file mode 100644 index 0000000000..d91d1a1eca --- /dev/null +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -0,0 +1,85 @@ +#include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/concat.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/dropout.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" +#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/input.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/weight.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::vector + get_output_shapes(PCGOperatorAttrs const &pcg_op_attrs, + std::vector const &inputs) { + return pcg_op_attrs.visit>(overload{ + [&](BatchMatmulAttrs const &attrs) -> std::vector { + return {throw_if_unexpected( + get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; + }, + [&](BatchNormAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](CastAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](CombineAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](ConcatAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs)}; + }, + [&](Conv2DAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](DropoutAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](ElementBinaryAttrs const &attrs) -> std::vector { + return {throw_if_unexpected( + get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; + }, + [&](ElementUnaryAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](EmbeddingAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](FlatAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](GatherAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0), inputs.at(1))}; + }, + [&](InputAttrs const &attrs) -> std::vector { + return {get_output_parallel_tensor_shape(attrs)}; + }, + [&](LayerNormAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](LinearAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, + [&](ReplicateAttrs const &attrs) -> std::vector { + return {get_output_shape(attrs, inputs.at(0))}; + }, + [&](WeightAttrs const &attrs) -> std::vector { + return {get_output_parallel_tensor_shape(attrs)}; + }, + [&](auto const &attrs) -> std::vector { + NOT_IMPLEMENTED(); + }}); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/is_valid.cc b/lib/op-attrs/src/op-attrs/is_valid.cc new file mode 100644 index 0000000000..14eae33b4b --- /dev/null +++ b/lib/op-attrs/src/op-attrs/is_valid.cc @@ -0,0 +1,3 @@ +#include "op-attrs/is_valid.h" + +namespace FlexFlow {} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index 7be51efa22..b75c3521c6 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +TensorShape get_output_shape(BatchNormAttrs const &, + TensorShape const &input_shape) { + return input_shape; +} + ParallelTensorShape get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/cast.cc b/lib/op-attrs/src/op-attrs/ops/cast.cc index 444409ffcb..cfbfd61ced 100644 --- a/lib/op-attrs/src/op-attrs/ops/cast.cc +++ b/lib/op-attrs/src/op-attrs/ops/cast.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/cast.h" +#include "op-attrs/datatype.h" namespace FlexFlow { diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 065c58f365..02fee70bea 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -11,4 +11,14 @@ namespace FlexFlow { /* return valid; */ /* } */ +TensorShape get_output_shape(ConcatAttrs const &, + std::vector const &) { + NOT_IMPLEMENTED(); +} + +ParallelTensorShape get_output_shape(ConcatAttrs const &, + std::vector const &) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/embedding.cc b/lib/op-attrs/src/op-attrs/ops/embedding.cc index 4a7d4395b6..d10d52c6f5 100644 --- a/lib/op-attrs/src/op-attrs/ops/embedding.cc +++ b/lib/op-attrs/src/op-attrs/ops/embedding.cc @@ -1,6 +1,7 @@ #include "op-attrs/ops/embedding.h" #include "op-attrs/dim_ordered/slice.h" #include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/parallel_tensor_dims.h" #include "utils/containers/product.h" #include "utils/integer_conversions.h" diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index b0683c5f08..5d318207ee 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -3,15 +3,24 @@ namespace FlexFlow { -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; +TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); } -namespace Output { -constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; +ParallelTensorShape get_output_shape(FlatAttrs const &, + ParallelTensorShape const &) { + NOT_IMPLEMENTED(); } +// namespace Input { +// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, +// REPLICA = 4; +// } +// +// namespace Output { +// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; +// } +// /* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ /* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ diff --git a/lib/op-attrs/src/op-attrs/ops/gather.cc b/lib/op-attrs/src/op-attrs/ops/gather.cc index 4f2c13c794..4b1053aee1 100644 --- a/lib/op-attrs/src/op-attrs/ops/gather.cc +++ b/lib/op-attrs/src/op-attrs/ops/gather.cc @@ -2,6 +2,18 @@ namespace FlexFlow { +TensorShape get_output_shape(GatherAttrs const &, + TensorShape const &input, + TensorShape const &index) { + NOT_IMPLEMENTED(); +} + +ParallelTensorShape get_output_shape(GatherAttrs const &, + ParallelTensorShape const &input, + ParallelTensorShape const &index) { + NOT_IMPLEMENTED(); +} + /* bool GatherAttrs::is_valid(ParallelTensorShape const &lhs, * ParallelTensorShape const &rhs) const { */ /* if (lhs.num_dims() != rhs.num_dims()) { */ diff --git a/lib/op-attrs/src/op-attrs/ops/input.cc b/lib/op-attrs/src/op-attrs/ops/input.cc index 93606b603a..acc0b02e69 100644 --- a/lib/op-attrs/src/op-attrs/ops/input.cc +++ b/lib/op-attrs/src/op-attrs/ops/input.cc @@ -2,7 +2,11 @@ namespace FlexFlow { -ParallelTensorShape get_output_shape(InputAttrs const &) { +TensorShape get_output_shape(InputAttrs const &) { + NOT_IMPLEMENTED(); +} + +ParallelTensorShape get_output_parallel_tensor_shape(InputAttrs const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index 76a5e25dfc..b9603d7850 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -1,6 +1,8 @@ #include "op-attrs/ops/layer_norm.h" #include "op-attrs/dim_ordered/ff_ordered_of.h" #include "op-attrs/dim_ordered/get_idxs.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" #include "utils/containers/all_of.h" #include "utils/containers/any_of.h" #include "utils/containers/contains.h" diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index beb944d1a0..24a8250690 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -8,6 +8,22 @@ namespace FlexFlow { +RecordFormatter as_dot(LinearAttrs const &attrs) { + RecordFormatter r; + + auto kv = [](std::string const &label, auto const &val) { + RecordFormatter rr; + rr << label << fmt::to_string(val); + return rr; + }; + + r << kv("out_channels", attrs.out_channels) << kv("use_bias", attrs.use_bias) + << kv("data_type", attrs.data_type) << kv("activation", attrs.activation) + << kv("regularizer", attrs.regularizer); + + return r; +} + tl::expected get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); diff --git a/lib/op-attrs/src/loss_functions.cc b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc similarity index 80% rename from lib/op-attrs/src/loss_functions.cc rename to lib/op-attrs/src/op-attrs/ops/loss_functions.cc index 094e117d77..e756d08547 100644 --- a/lib/op-attrs/src/loss_functions.cc +++ b/lib/op-attrs/src/op-attrs/ops/loss_functions.cc @@ -1,4 +1,4 @@ -#include "op-attrs/ops/loss_functions.h" +#include "op-attrs/ops/loss_functions/loss_functions.h" #include "utils/containers/transform.h" #include #include @@ -8,20 +8,15 @@ namespace FlexFlow { LossFunction get_loss_type(OtherLossAttrs const &attrs) { return attrs.loss_type; } + LossFunction get_loss_type(SparseCategoricalCrossEntropyLossAttrs const &attrs) { return LossFunction::SPARSE_CATEGORICAL_CROSSENTROPY; } -struct GetLossFunction { - template - LossFunction operator()(T const &t) { - return get_loss_type(t); - } -}; - LossFunction get_loss_type(LossAttrs const &attrs) { - return visit(GetLossFunction{}, attrs); + return attrs.visit( + [](auto const &t) { return get_loss_type(t); }); } LossFunction parse_loss_name(std::string const &raw_name) { @@ -37,8 +32,8 @@ LossFunction parse_loss_name(std::string const &raw_name) { } else if (name == "identity") { return LossFunction::IDENTITY; } else { - throw mk_runtime_error( - "Unknown loss type {}. Please report this as an issue.", name); + throw mk_runtime_error(fmt::format( + "Unknown loss type {}. Please report this as an issue.", name)); } } diff --git a/lib/op-attrs/src/op-attrs/ops/noop.cc b/lib/op-attrs/src/op-attrs/ops/noop.cc index b2b15d820c..6ba33146e4 100644 --- a/lib/op-attrs/src/op-attrs/ops/noop.cc +++ b/lib/op-attrs/src/op-attrs/ops/noop.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +TensorShape get_output_shape(NoopAttrs const &, + TensorShape const &input_shape) { + return input_shape; +} + ParallelTensorShape get_output_shape(NoopAttrs const &, ParallelTensorShape const &input_shape) { return input_shape; diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index cf6ed177d3..e1917efd89 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/reshape.cc b/lib/op-attrs/src/op-attrs/ops/reshape.cc index 7d0600550a..6216ad8c6c 100644 --- a/lib/op-attrs/src/op-attrs/ops/reshape.cc +++ b/lib/op-attrs/src/op-attrs/ops/reshape.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +TensorShape get_output_shape(ReshapeAttrs const &attrs, + TensorShape const &input_shape) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(ReshapeAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/reverse.cc b/lib/op-attrs/src/op-attrs/ops/reverse.cc index 79b5bd50fb..c38d7e4782 100644 --- a/lib/op-attrs/src/op-attrs/ops/reverse.cc +++ b/lib/op-attrs/src/op-attrs/ops/reverse.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(ReverseAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(ReverseAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/split.cc b/lib/op-attrs/src/op-attrs/ops/split.cc index cfb4071833..a9fe691584 100644 --- a/lib/op-attrs/src/op-attrs/ops/split.cc +++ b/lib/op-attrs/src/op-attrs/ops/split.cc @@ -2,6 +2,11 @@ namespace FlexFlow { +std::vector get_output_shapes(SplitAttrs const &, + TensorShape const &) { + NOT_IMPLEMENTED(); +} + std::vector get_output_shapes(SplitAttrs const &attrs, ParallelTensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/transpose.cc b/lib/op-attrs/src/op-attrs/ops/transpose.cc index 75f7eb3c18..50e6fb35f5 100644 --- a/lib/op-attrs/src/op-attrs/ops/transpose.cc +++ b/lib/op-attrs/src/op-attrs/ops/transpose.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(TransposeAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(TransposeAttrs const &op_attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/ops/weight.cc b/lib/op-attrs/src/op-attrs/ops/weight.cc new file mode 100644 index 0000000000..f8b6b7ec49 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/ops/weight.cc @@ -0,0 +1,14 @@ +#include "op-attrs/ops/weight.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +TensorShape get_output_shape(WeightAttrs const &attrs) { + return attrs.tensor_shape; +} + +ParallelTensorShape get_output_parallel_tensor_shape(WeightAttrs const &attrs) { + return lift_to_parallel(attrs.tensor_shape); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc new file mode 100644 index 0000000000..c458d4149d --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_op_attrs.cc @@ -0,0 +1,37 @@ +#include "op-attrs/parallel_op_attrs.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" +#include "utils/overload.h" + +namespace FlexFlow { + +ParallelTensorShape get_output_shape(ParallelOpAttrs const &attrs, + ParallelTensorShape const &input_shape) { + return attrs.visit(overload{ + [&](CombineAttrs const &combine_attrs) { + return throw_if_unexpected( + get_output_shape(combine_attrs, input_shape)); + }, + [&](ReductionAttrs const &reduction_attrs) { + return throw_if_unexpected( + get_output_shape(reduction_attrs, input_shape)); + }, + [&](RepartitionAttrs const &repartition_attrs) { + return throw_if_unexpected( + get_output_shape(repartition_attrs, input_shape)); + }, + [&](ReplicateAttrs const &replicate_attrs) { + return get_output_shape(replicate_attrs, input_shape); + }, + }); +} + +PCGOperatorAttrs + pcg_op_attrs_from_parallel_op_attrs(ParallelOpAttrs const &attrs) { + return attrs.visit( + [](auto const &attrs) { return PCGOperatorAttrs{attrs}; }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 150fb6a76d..10bf5027a4 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -1,4 +1,5 @@ #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_dims.h" #include "op-attrs/tensor_dims.h" #include "utils/containers/product.h" #include "utils/containers/transform.h" @@ -74,6 +75,19 @@ ParallelTensorShape }; } +TensorShape require_not_parallel(ParallelTensorShape const &s) { + int total_degree = get_total_parallel_degree(s); + if (total_degree != 1) { + throw mk_runtime_error( + fmt::format("Error: require_not_parallel received a parallel tensor " + "shape with parallel degree {}: {}", + total_degree, + s)); + } + + return get_reduced_shape(s); +} + TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &) { NOT_IMPLEMENTED(); } diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index 74882fe9f2..0bb134da6b 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -1,5 +1,6 @@ #include "op-attrs/pcg_operator_attrs.h" #include "op-attrs/get_op_type.h" +#include "op-attrs/ops/linear.h" #include "utils/overload.h" namespace FlexFlow { @@ -64,4 +65,15 @@ ComputationGraphOpAttrs }); } +RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { + return attrs.visit(overload{ + [](LinearAttrs const &l) { return as_dot(l); }, + [&](auto const &) { + RecordFormatter r; + r << fmt::to_string(get_op_type(attrs)); + return r; + }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op_attrs.cc b/lib/op-attrs/src/op_attrs.cc deleted file mode 100644 index 6125c03a59..0000000000 --- a/lib/op-attrs/src/op_attrs.cc +++ /dev/null @@ -1,10 +0,0 @@ -/* #include "op-attrs/ops/op_attrs.h" */ - -/* namespace FlexFlow { */ - -/* int OpAttrsInterface::num_outputs(std::vector const - * &inputs) const { */ -/* return this->output_shapes(inputs).size(); */ -/* } */ - -/* } */ diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc new file mode 100644 index 0000000000..d2c758a05f --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc @@ -0,0 +1,20 @@ +#include "op-attrs/dim_ordered/enumerate.h" +#include "utils/fmt/map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("enumerate(FFOrdered)") { + FFOrdered input = {"zero", "one", "two"}; + + std::map result = enumerate(input); + std::map correct = { + {ff_dim_t{0}, "zero"}, + {ff_dim_t{1}, "one"}, + {ff_dim_t{2}, "two"}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index 8f5f4054d6..b9dd66df5d 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/layer_norm.h" +#include "op-attrs/parallel_tensor_shape.h" #include "test/utils/doctest.h" #include "utils/expected.h" #include "utils/fmt/expected.h" diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc index 086d25d042..31030ca0f9 100644 --- a/lib/op-attrs/test/src/ops/cast.cc +++ b/lib/op-attrs/test/src/ops/cast.cc @@ -1,4 +1,5 @@ #include "op-attrs/ops/cast.h" +#include "op-attrs/parallel_tensor_shape.h" #include "test/utils/doctest.h" TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 0ca330408e..c641aed6a4 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -114,11 +114,10 @@ struct ComputationGraphBuilder { std::optional const &kernel_initializer = std::nullopt, std::optional const &name = std::nullopt); // Add a gather layer - std::vector - gather(tensor_guid_t const &input, - tensor_guid_t const &index, - ff_dim_t dim, - std::optional const &name = std::nullopt); + tensor_guid_t gather(tensor_guid_t const &input, + tensor_guid_t const &index, + ff_dim_t dim, + std::optional const &name = std::nullopt); // Add a cache layer tensor_guid_t cache(tensor_guid_t const &input, diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h index 0e547e7688..05c486f0f7 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.h @@ -8,7 +8,7 @@ namespace FlexFlow { V1DataflowGraph to_v1(DataflowGraphView const &); V1DataflowGraph to_v1(DataflowGraphView const &, - std::unordered_map const &); + std::unordered_map const &); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml index dc9dc96f29..d9aade739c 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml @@ -19,7 +19,7 @@ includes = [ [[fields]] name = "nodes" -type = "std::vector" +type = "std::vector" [[fields]] name = "edges" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml index b0d2546977..752706fe1d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_graph_edge.struct.toml @@ -11,16 +11,16 @@ features = [ [[fields]] name = "srcNode" -type = "size_t" +type = "int" [[fields]] name = "srcIdx" -type = "size_t" +type = "int" [[fields]] name = "dstNode" -type = "size_t" +type = "int" [[fields]] name = "dstIdx" -type = "size_t" +type = "int" diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h index b1f96c513b..48203d73ae 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -3,7 +3,7 @@ #include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" #include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" -#include "utils/containers/enumerate.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -16,14 +16,14 @@ template V1LabelledDataflowGraph to_v1(LabelledDataflowGraphView const &g) { - bidict nodes = enumerate(get_nodes(g)); + bidict nodes = bidict_from_enumerating(get_nodes(g)); V1DataflowGraph unlabelled = to_v1(g, nodes.reversed()); - std::unordered_map node_labels = map_values( + std::unordered_map node_labels = map_values( nodes.as_unordered_map(), [&](Node const &n) { return g.at(n); }); - std::unordered_map> output_labels = + std::unordered_map> output_labels = map_values(nodes.as_unordered_map(), [&](Node const &n) { return transform(get_outputs(g, n), [&](DataflowOutput const &o) { return g.at(o); }); diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index 0a6a148159..fd8d4c39c4 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -22,11 +22,11 @@ includes = [ [[fields]] name = "node_labels" -type = "std::unordered_map" +type = "std::unordered_map" [[fields]] name = "output_labels" -type = "std::unordered_map>" +type = "std::unordered_map>" [[fields]] name = "graph" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h b/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h new file mode 100644 index 0000000000..eb4928deaa --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/generate_weight_transform.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_GENERATE_WEIGHT_TRANSFORM_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_GENERATE_WEIGHT_TRANSFORM_H + +#include "op-attrs/parallel_op_attrs.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" + +namespace FlexFlow { + +std::unordered_set + generate_weight_transform(TensorShape const ¤t, + ParallelTensorShape const &goal); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 4caaad06b2..9150681070 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -28,9 +28,6 @@ std::vector get_layer_outputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); -parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &, - parallel_tensor_guid_t const &); - ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 5b34ee641a..20e947ad58 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -84,10 +84,35 @@ struct ParallelComputationGraphBuilder { std::optional output_bias_initializer = std::nullopt, std::optional const &name = std::nullopt); + parallel_tensor_guid_t + batch_norm(parallel_tensor_guid_t const &input, + bool relu = true, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t relu(parallel_tensor_guid_t const &x, std::optional const &name = std::nullopt); + parallel_tensor_guid_t + identity(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + gelu(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + sigmoid(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + tanh(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + elu(parallel_tensor_guid_t const &x, + std::optional const &name = std::nullopt); + parallel_tensor_guid_t parallel_partition(parallel_tensor_guid_t const &x, ff_dim_t dim, @@ -137,6 +162,15 @@ struct ParallelComputationGraphBuilder { std::vector const &weights, ParallelTensorShape const &output); + parallel_tensor_guid_t + add_weight(ParallelTensorAttrs const &weight_tensor_attrs, + std::optional const &name = std::nullopt); + + parallel_tensor_guid_t + element_unary(ElementUnaryAttrs const &element_unary_attrs, + parallel_tensor_guid_t const &input, + std::optional const &name); + public: ParallelComputationGraph pcg; }; diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h new file mode 100644 index 0000000000..7aac8558e4 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +parallel_tensor_guid_t + get_parallel_tensor(ParallelComputationGraphEdge const &); +parallel_layer_guid_t get_src_layer(ParallelComputationGraphEdge const &); +parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &); +int get_dst_layer_input_idx(ParallelComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml new file mode 100644 index 0000000000..25ef3f5d27 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ParallelComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h new file mode 100644 index 0000000000..905a365b4b --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_GUID_T_H + +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +parallel_layer_guid_t get_source_layer(parallel_tensor_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml new file mode 100644 index 0000000000..6d5e007650 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "parallel_tensor_use_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_input" +type = "::FlexFlow::DataflowInput" diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index e7f5f2b737..deaa440ef8 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -53,7 +53,7 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n) { - return transform(get_inputs(cg.raw_graph, n.raw_node), + return transform(get_input_values(cg.raw_graph, n.raw_node), [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index b6d0e7c890..3f2feaf619 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -2,13 +2,24 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/enumerate_vector.h" +#include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/expected.h" #include @@ -49,7 +60,7 @@ std::vector ComputationGraphBuilder::add_layer( return fmt::format("{}.weights[{}]", layer_name, weight_idx); }); LayerAttrs weight_layer_attrs = LayerAttrs{ - ComputationGraphOpAttrs{WeightAttrs{}}, + ComputationGraphOpAttrs{WeightAttrs{weight_tensor_attrs.shape}}, weight_name, }; std::vector weight_layer_inputs = {}; @@ -451,7 +462,7 @@ tensor_guid_t ComputationGraphBuilder::embedding( return this->add_layer(layer, {input}, {weight_attrs}, output_shape); } -std::vector ComputationGraphBuilder::gather( +tensor_guid_t ComputationGraphBuilder::gather( tensor_guid_t const &input, tensor_guid_t const &index, ff_dim_t dim, @@ -469,10 +480,10 @@ std::vector ComputationGraphBuilder::gather( DataType::INT32, DataType::INT64); } - std::vector output_shapes = - get_output_shapes(attrs, this->get_shape(input), this->get_shape(index)); + TensorShape output_shape = + get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, {}, output_shapes); + return this->add_layer(layer, {input}, {}, output_shape); } /* std::vector diff --git a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc index 787ce5bf7d..cf150a339f 100644 --- a/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc +++ b/lib/pcg/src/pcg/file_format/v1/graphs/v1_dataflow_graph.cc @@ -1,4 +1,5 @@ #include "pcg/file_format/v1/graphs/v1_dataflow_graph.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" #include "utils/containers/enumerate.h" #include "utils/containers/sorted.h" #include "utils/containers/values.h" @@ -9,17 +10,19 @@ namespace FlexFlow { V1DataflowGraph to_v1(DataflowGraphView const &g) { - return to_v1(g, enumerate(get_nodes(g)).reversed()); + bidict node_enumeration_bidict = + bidict_from_enumerating(get_nodes(g)); + std::unordered_map node_enumeration = + node_enumeration_bidict.reversed().as_unordered_map(); + return to_v1(g, node_enumeration); } V1DataflowGraph to_v1(DataflowGraphView const &g, - std::unordered_map const &nodes) { + std::unordered_map const &nodes) { std::unordered_set edges; for (DataflowEdge const &e : get_edges(g)) { - edges.insert(V1GraphEdge{nodes.at(e.src.node), - size_t_from_int(e.src.idx), - nodes.at(e.dst.node), - size_t_from_int(e.dst.idx)}); + edges.insert(V1GraphEdge{ + nodes.at(e.src.node), e.src.idx, nodes.at(e.dst.node), e.dst.idx}); } return V1DataflowGraph{ diff --git a/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc new file mode 100644 index 0000000000..dadad6277f --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/generate_weight_transform.cc @@ -0,0 +1,35 @@ +#include "pcg/parallel_computation_graph/generate_weight_transform.h" +#include "op-attrs/dim_ordered/enumerate.h" +#include "op-attrs/parallel_tensor_shape.h" + +namespace FlexFlow { + +std::unordered_set + generate_weight_transform(TensorShape const ¤t, + ParallelTensorShape const &goal) { + std::unordered_set result; + + int sum_degree = get_sum_degree(goal); + if (sum_degree != 1) { + throw mk_runtime_error( + fmt::format("generate_weight_transform currently only supports " + "sum_degree = 1, but received {}", + sum_degree)); + } + + int discard_copy_degree = get_discard_copy_degree(goal); + if (discard_copy_degree != 1) { + result.insert(ParallelOpAttrs{ReplicateAttrs{discard_copy_degree}}); + } + + for (auto const &[shard_dim, shard_degree] : + enumerate(ff_ordered_shard_degrees(goal))) { + if (shard_degree != 1) { + result.insert(ParallelOpAttrs{RepartitionAttrs{shard_dim, shard_degree}}); + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 831287567d..5b178160cd 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -44,7 +44,7 @@ std::vector get_layer_inputs(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { return transform( - get_inputs(pcg.raw_graph, l.raw_graph_node), + get_input_values(pcg.raw_graph, l.raw_graph_node), [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index b632c984bc..8290a2ff94 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,6 +1,8 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/parallel_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" +#include "pcg/parallel_computation_graph/generate_weight_transform.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/enumerate_vector.h" @@ -326,11 +328,28 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( return this->add_layer(layer, {query, key, value}, weights, output_shape); } -parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( +parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, + bool relu, std::optional const &maybe_name) { - ElementUnaryAttrs attrs = ElementUnaryAttrs{OperatorType::RELU, std::nullopt}; + BatchNormAttrs attrs = BatchNormAttrs{relu}; + + std::string name = + maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); + + ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + + ParallelTensorShape output_shape = + get_output_shape(attrs, this->get_shape(input)); + + return this->add_layer(layer, {input}, {}, {output_shape}); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::element_unary( + ElementUnaryAttrs const &attrs, + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); @@ -343,6 +362,78 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( return this->add_layer(layer, {input}, {}, {output_shape}); } +parallel_tensor_guid_t ParallelComputationGraphBuilder::relu( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::RELU, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::identity( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::IDENTITY, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::gelu( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::GELU, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::sigmoid( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::SIGMOID, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::tanh( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::TANH, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + +parallel_tensor_guid_t ParallelComputationGraphBuilder::elu( + parallel_tensor_guid_t const &input, + std::optional const &maybe_name) { + + ElementUnaryAttrs attrs = ElementUnaryAttrs{ + OperatorType::ELU, + std::nullopt, + }; + + return this->element_unary(attrs, input, maybe_name); +} + parallel_tensor_guid_t ParallelComputationGraphBuilder::parallel_partition( parallel_tensor_guid_t const &input, ff_dim_t dim, @@ -441,6 +532,54 @@ ParallelTensorShape ParallelComputationGraphBuilder::get_shape( return get_parallel_tensor_attrs(this->pcg, t).shape; } +parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight( + ParallelTensorAttrs const &weight_tensor_attrs, + std::optional const &weight_name) { + ParallelTensorShape par_weight_shape = weight_tensor_attrs.shape; + TensorShape unpar_weight_shape = get_reduced_shape(weight_tensor_attrs.shape); + + ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{unpar_weight_shape}}, + weight_name, + }; + + std::vector weight_layer_inputs = {}; + std::vector weight_output_attrs = {weight_tensor_attrs}; + + DataflowOutput current_raw_weight_tensor = get_only( + this->pcg.raw_graph + .add_node( + weight_layer_attrs, weight_layer_inputs, weight_output_attrs) + .outputs); + ParallelTensorShape current_shape = lift_to_parallel(unpar_weight_shape); + + for (ParallelOpAttrs const ¶llel_op_attr : + generate_weight_transform(unpar_weight_shape, par_weight_shape)) { + ParallelTensorShape output_shape = + get_output_shape(parallel_op_attr, current_shape); + ParallelTensorAttrs output_attrs = ParallelTensorAttrs{ + output_shape, + std::nullopt, + std::nullopt, + CreateGrad::YES, + }; + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + pcg_op_attrs_from_parallel_op_attrs(parallel_op_attr), + std::nullopt, + }; + current_raw_weight_tensor = get_only( + this->pcg.raw_graph + .add_node(layer_attrs, {current_raw_weight_tensor}, {output_attrs}) + .outputs); + current_shape = output_shape; + } + + assert(current_shape == par_weight_shape); + + return parallel_tensor_guid_t{current_raw_weight_tensor}; +} + std::vector ParallelComputationGraphBuilder::add_layer( ParallelLayerAttrs const &layer, std::vector const &inputs, @@ -455,18 +594,9 @@ std::vector ParallelComputationGraphBuilder::add_layer( transform(layer.name, [&](std::string const &layer_name) { return fmt::format("{}.weights[{}]", layer_name, weight_idx); }); - ParallelLayerAttrs weight_layer_attrs = ParallelLayerAttrs{ - PCGOperatorAttrs{WeightAttrs{}}, - weight_name, - }; - std::vector weight_layer_inputs = {}; - std::vector weight_output_attrs = { - weight_tensor_attrs}; - raw_weight_tensors.push_back(get_only(this->pcg.raw_graph - .add_node(weight_layer_attrs, - weight_layer_inputs, - weight_output_attrs) - .outputs)); + + raw_weight_tensors.push_back( + this->add_weight(weight_tensor_attrs, weight_name).raw_graph_output); } std::vector raw_inputs = diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc new file mode 100644 index 0000000000..dca8154eb4 --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_edge.cc @@ -0,0 +1,22 @@ +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" + +namespace FlexFlow { + +parallel_tensor_guid_t + get_parallel_tensor(ParallelComputationGraphEdge const &e) { + return parallel_tensor_guid_t{e.raw_edge.src}; +} + +parallel_layer_guid_t get_src_layer(ParallelComputationGraphEdge const &e) { + return parallel_layer_guid_t{e.raw_edge.src.node}; +} + +parallel_layer_guid_t get_dst_layer(ParallelComputationGraphEdge const &e) { + return parallel_layer_guid_t{e.raw_edge.dst.node}; +} + +int get_dst_layer_input_idx(ParallelComputationGraphEdge const &e) { + return e.raw_edge.dst.idx; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc new file mode 100644 index 0000000000..ad4eae041f --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_guid_t.cc @@ -0,0 +1,9 @@ +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" + +namespace FlexFlow { + +parallel_layer_guid_t get_source_layer(parallel_tensor_guid_t const &t) { + return parallel_layer_guid_t{t.raw_graph_output.node}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index db01728cf0..440f735e80 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -2,6 +2,7 @@ #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" #include "test/utils/doctest.h" #include "utils/containers/count.h" #include "utils/containers/generate_map.h" @@ -39,7 +40,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t rhs = b.create_input_tensor(rhs_shape); parallel_tensor_guid_t out = b.add(lhs, rhs); - parallel_layer_guid_t layer = get_source_layer(b.pcg, out); + parallel_layer_guid_t layer = get_source_layer(out); SUBCASE("inputs") { std::vector result = @@ -102,7 +103,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t b_tensor = b.create_input_tensor(b_shape); parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor); - parallel_layer_guid_t layer = get_source_layer(b.pcg, out); + parallel_layer_guid_t layer = get_source_layer(out); SUBCASE("inputs") { std::vector result = @@ -145,7 +146,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType output_datatype = DataType::DOUBLE; parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.cast(input, output_datatype); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -205,7 +206,7 @@ TEST_SUITE(FF_TEST_SUITE) { [&](parallel_layer_guid_t const &l) { return get_parallel_layer_attrs(b.pcg, l); }); - CHECK_MESSAGE(layers.size() == 4, "Incorrect layers ", layers); + CHECK_MESSAGE(layers.size() == 6, "Incorrect layers ", layers); auto num_attrs_of_type = [&](OperatorType op_type) -> int { return count(values(layers), [&](ParallelLayerAttrs const &l) { @@ -222,6 +223,9 @@ TEST_SUITE(FF_TEST_SUITE) { int num_conv_attrs = num_attrs_of_type(OperatorType::CONV2D); CHECK(num_conv_attrs == 1); + int num_replicate_attrs = num_attrs_of_type(OperatorType::REPLICATE); + CHECK(num_replicate_attrs == 2); + parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( as_vector(items(layers)), [](std::pair const &kv) @@ -307,7 +311,7 @@ TEST_SUITE(FF_TEST_SUITE) { Activation::RELU, /*use_bias=*/true, DataType::FLOAT); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -350,7 +354,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*outDim=*/8, AggregateOp::SUM, DataType::FLOAT); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -400,7 +404,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t value = b.create_input_tensor(value_shape); parallel_tensor_guid_t output = b.multihead_attention(query, key, value, embed_dim, num_heads); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -441,7 +445,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.relu(input); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -480,7 +484,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -519,7 +523,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -558,7 +562,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_replicate(input, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = @@ -597,7 +601,7 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t input = b.create_input_tensor(input_shape); parallel_tensor_guid_t output = b.parallel_reduce(input, 2); - parallel_layer_guid_t layer = get_source_layer(b.pcg, output); + parallel_layer_guid_t layer = get_source_layer(output); SUBCASE("inputs") { std::vector result = diff --git a/lib/substitution-generator/CMakeLists.txt b/lib/substitution-generator/CMakeLists.txt index 41005e6a4e..1db0d888ba 100644 --- a/lib/substitution-generator/CMakeLists.txt +++ b/lib/substitution-generator/CMakeLists.txt @@ -11,6 +11,7 @@ ff_add_library( utils op-attrs pcg + substitutions ) # add_subdirectory(ffi) diff --git a/lib/substitution-generator/include/substitution-generator/json.h b/lib/substitution-generator/include/substitution-generator/json.h deleted file mode 100644 index 5563d8a835..0000000000 --- a/lib/substitution-generator/include/substitution-generator/json.h +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef _FLEXFLOW_SUBSTITUTION_LOADER_H -#define _FLEXFLOW_SUBSTITUTION_LOADER_H - -#include "substitution-generator/legacy_operator_type.dtg.h" -#include "substitution-generator/legacy_pm_parameter.dtg.h" -#include -#include -#include - -namespace FlexFlow { - -struct Parameter { - LegacyPMParameter key; - int value; -}; -void from_json(nlohmann::json const &j, Parameter &p); - -struct Tensor { - int opId; - int tsId; -}; -void from_json(nlohmann::json const &j, Tensor &t); - -struct Operator { - LegacyOperatorType op_type; - std::vector input; - std::vector para; - - std::optional at(LegacyPMParameter key) const; -}; -void from_json(nlohmann::json const &j, Operator &t); - -struct MapOutput { - int dstOpId; - int dstTsId; - int srcOpId; - int srcTsId; -}; -void from_json(nlohmann::json const &j, MapOutput &t); - -struct Rule { - std::string name; - std::vector srcOp; - std::vector dstOp; - std::vector mappedOutput; -}; -void from_json(nlohmann::json const &j, Rule &t); - -struct RuleCollection { - std::vector rules; -}; -void from_json(nlohmann::json const &j, RuleCollection &c); - -RuleCollection load_rule_collection(std::istream &s); -RuleCollection load_rule_collection_from_path(std::string const &path); - -} // namespace FlexFlow - -#endif // _FLEXFLOW_SUBSTITUTION_LOADER_H diff --git a/lib/substitution-generator/include/substitution-generator/legacy_rules.h b/lib/substitution-generator/include/substitution-generator/legacy_rules.h new file mode 100644 index 0000000000..a0e0a9790a --- /dev/null +++ b/lib/substitution-generator/include/substitution-generator/legacy_rules.h @@ -0,0 +1,59 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_RULES_H +#define _FLEXFLOW_LIB_SUBSTITUTION_GENERATOR_INCLUDE_SUBSTITUTION_GENERATOR_LEGACY_RULES_H + +#include "substitution-generator/legacy_operator_type.dtg.h" +#include "substitution-generator/legacy_pm_parameter.dtg.h" +#include +#include +#include + +namespace FlexFlow { + +struct LegacyParameter { + LegacyPMParameter key; + int value; +}; +void from_json(nlohmann::json const &j, LegacyParameter &p); + +struct LegacyTensor { + int opId; + int tsId; +}; +void from_json(nlohmann::json const &j, LegacyTensor &t); + +struct LegacyOperator { + LegacyOperatorType op_type; + std::vector input; + std::vector para; + + std::optional at(LegacyPMParameter key) const; +}; +void from_json(nlohmann::json const &j, LegacyOperator &t); + +struct LegacyMapOutput { + int dstOpId; + int dstTsId; + int srcOpId; + int srcTsId; +}; +void from_json(nlohmann::json const &j, LegacyMapOutput &t); + +struct LegacyRule { + std::string name; + std::vector srcOp; + std::vector dstOp; + std::vector mappedOutput; +}; +void from_json(nlohmann::json const &j, LegacyRule &t); + +struct LegacyRuleCollection { + std::vector rules; +}; +void from_json(nlohmann::json const &j, LegacyRuleCollection &c); + +LegacyRuleCollection load_rule_collection(std::istream &s); +LegacyRuleCollection load_rule_collection_from_path(std::string const &path); + +} // namespace FlexFlow + +#endif // _FLEXFLOW_SUBSTITUTION_LOADER_H diff --git a/lib/substitution-generator/src/substitution-generator/json.cc b/lib/substitution-generator/src/substitution-generator/legacy_rules.cc similarity index 67% rename from lib/substitution-generator/src/substitution-generator/json.cc rename to lib/substitution-generator/src/substitution-generator/legacy_rules.cc index 940ecb3e36..157f062cbf 100644 --- a/lib/substitution-generator/src/substitution-generator/json.cc +++ b/lib/substitution-generator/src/substitution-generator/legacy_rules.cc @@ -1,4 +1,4 @@ -#include "substitution-generator/json.h" +#include "substitution-generator/legacy_rules.h" #include #include #include @@ -7,12 +7,12 @@ using json = nlohmann::json; namespace FlexFlow { -void from_json(json const &j, Parameter &p) { +void from_json(json const &j, LegacyParameter &p) { j.at("key").get_to(p.key); j.at("value").get_to(p.value); } -void from_json(json const &j, Tensor &t) { +void from_json(json const &j, LegacyTensor &t) { j.at("opId").get_to(t.opId); j.at("tsId").get_to(t.tsId); } @@ -29,38 +29,38 @@ void from_json(json const &j, Tensor &t) { /* return value; */ /* } */ -void from_json(json const &j, Operator &o) { +void from_json(json const &j, LegacyOperator &o) { j.at("type").get_to(o.op_type); j.at("input").get_to(o.input); j.at("para").get_to(o.para); } -void from_json(json const &j, MapOutput &m) { +void from_json(json const &j, LegacyMapOutput &m) { j.at("dstOpId").get_to(m.dstOpId); j.at("dstTsId").get_to(m.dstTsId); j.at("srcOpId").get_to(m.srcOpId); j.at("srcTsId").get_to(m.srcTsId); } -void from_json(json const &j, Rule &r) { +void from_json(json const &j, LegacyRule &r) { j.at("name").get_to(r.name); j.at("srcOp").get_to(r.srcOp); j.at("dstOp").get_to(r.dstOp); j.at("mappedOutput").get_to(r.mappedOutput); } -void from_json(json const &j, RuleCollection &c) { +void from_json(json const &j, LegacyRuleCollection &c) { j.at("rule").get_to(c.rules); } -RuleCollection load_rule_collection(std::istream &s) { +LegacyRuleCollection load_rule_collection(std::istream &s) { json j; s >> j; - RuleCollection rule_collection = j; + LegacyRuleCollection rule_collection = j; return rule_collection; } -RuleCollection load_rule_collection_from_path(std::string const &path) { +LegacyRuleCollection load_rule_collection_from_path(std::string const &path) { std::ifstream input(path); return load_rule_collection(input); } diff --git a/lib/substitution-generator/test/substitution-generator/json.cc b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc similarity index 88% rename from lib/substitution-generator/test/substitution-generator/json.cc rename to lib/substitution-generator/test/substitution-generator/legacy_rules.cc index befdaf1308..4dd9bb8cc4 100644 --- a/lib/substitution-generator/test/substitution-generator/json.cc +++ b/lib/substitution-generator/test/substitution-generator/legacy_rules.cc @@ -1,4 +1,4 @@ -#include "substitution-generator/json.h" +#include "substitution-generator/legacy_rules.h" #include "doctest/doctest.h" using namespace FlexFlow; @@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) { {"type", "OP_EW_ADD"}, }; - Operator o; + LegacyOperator o; from_json(j, o); CHECK(o.op_type == LegacyOperatorType::EW_ADD); @@ -28,7 +28,7 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("deserialize full file") { - RuleCollection collection = + LegacyRuleCollection collection = load_rule_collection_from_path("graph_subst_3_v2.json"); CHECK(collection.rules.size() == 640); } diff --git a/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml new file mode 100644 index 0000000000..dd2e850aed --- /dev/null +++ b/lib/substitutions/include/substitutions/input_parallel_tensor_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "input_parallel_tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_graph_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h new file mode 100644 index 0000000000..ad60d50db1 --- /dev/null +++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPEN_PARALLEL_TENSOR_GUID_T_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPEN_PARALLEL_TENSOR_GUID_T_H + +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "substitutions/input_parallel_tensor_guid_t.dtg.h" +#include "substitutions/open_parallel_tensor_guid_t.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_closed(parallel_tensor_guid_t); +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_input(input_parallel_tensor_guid_t); + +template > +Ret visit_open_parallel_tensor_guid(open_parallel_tensor_guid_t t, F f) { + return t.raw_open_dataflow_value.visit(overload{ + [&](DataflowOutput const &o) { return f(parallel_tensor_guid_t{o}); }, + [&](DataflowGraphInput const &i) { + return f(input_parallel_tensor_guid_t{i}); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml new file mode 100644 index 0000000000..f07dc12d62 --- /dev/null +++ b/lib/substitutions/include/substitutions/open_parallel_tensor_guid_t.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "open_parallel_tensor_guid_t" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +] + +[[fields]] +name = "raw_open_dataflow_value" +type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h new file mode 100644 index 0000000000..4affdd697f --- /dev/null +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_constraint.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OPERATOR_PATTERN_OPERATOR_ATTRIBUTE_CONSTRAINT_H + +#include "substitutions/operator_pattern/operator_attribute_constraint.dtg.h" + +namespace FlexFlow { + +OperatorAttributeConstraint op_type_equals_constraint(OperatorType); + +OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey, + OperatorAttributeValue const &); +OperatorAttributeConstraint + make_equals_constraint(OperatorAttributeExpr const &, + OperatorAttributeValue const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h index e63c03207b..a6324863a6 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_expr.h @@ -9,8 +9,8 @@ namespace FlexFlow { std::optional - evaluate_attribute_expr(PCGOperatorAttrs const &attrs, - OperatorAttributeExpr const &expr); + evaluate_attribute_expr(OperatorAttributeExpr const &expr, + PCGOperatorAttrs const &attrs); } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index da2feb1903..7df65ef361 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -35,6 +35,9 @@ type = "int" [[values]] type = "bool" +[[values]] +type = "float" + [[values]] type = "std::vector" @@ -45,7 +48,7 @@ type = "std::vector<::FlexFlow::ff_dim_t>" type = "::FlexFlow::OperatorType" [[values]] -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" [[values]] type = "::FlexFlow::ff_dim_t" diff --git a/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h b/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h new file mode 100644 index 0000000000..cc2fac4805 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/materialize_operator_from_attrs_map.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_MATERIALIZE_OPERATOR_FROM_ATTRS_MAP_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_MATERIALIZE_OPERATOR_FROM_ATTRS_MAP_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_key.dtg.h" +#include "substitutions/operator_pattern/operator_attribute_value.dtg.h" + +namespace FlexFlow { + +PCGOperatorAttrs materialize_operator_from_attrs_map( + std::unordered_map const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h new file mode 100644 index 0000000000..e550767292 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_GRAPH_EXPR_H + +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/output_graph/output_graph_expr_node.dtg.h" +#include "substitutions/output_graph/output_graph_expr_node_output.dtg.h" + +namespace FlexFlow { + +std::vector + get_node_outputs(OutputGraphExpr const &, OutputGraphExprNode const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml index 5caeff92f5..9ad65369a9 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr.struct.toml @@ -5,8 +5,9 @@ features = [] includes = [ "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h", "substitutions/output_graph/output_operator_attrs_assignment.dtg.h", + "", ] [[fields]] name = "raw_graph" -type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::nullopt_t>" +type = "::FlexFlow::LabelledOpenDataflowGraph<::FlexFlow::OutputOperatorAttrsAssignment, std::monostate>" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml new file mode 100644 index 0000000000..fe7a861f0a --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_input.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputGraphExprInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_dataflow_graph_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml new file mode 100644 index 0000000000..37c2a1f563 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputGraphExprNode" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h" +] + +[[fields]] +name = "raw_graph_node" +type = "::FlexFlow::Node" diff --git a/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml new file mode 100644 index 0000000000..7a2072e385 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_graph_expr_node_output.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputGraphExprNodeOutput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_output.dtg.h", +] + +[[fields]] +name = "raw_dataflow_output" +type = "::FlexFlow::DataflowOutput" diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml index 5527635a2e..e856249e50 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attr_access.struct.toml @@ -8,13 +8,13 @@ features = [ ] includes = [ - "utils/graph/node/node.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", "substitutions/operator_pattern/operator_attribute_expr.dtg.h", ] [[fields]] name = "node" -type = "::FlexFlow::Node" +type = "::FlexFlow::PatternNode" # NOTE(@wmdi) I am not sure whether these should be part of attribute expr. [[fields]] diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h new file mode 100644 index 0000000000..cba095b444 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attribute_expr.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRIBUTE_EXPR_H + +#include "output_operator_attribute_expr.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" + +namespace FlexFlow { + +OperatorAttributeValue evaluate_output_operator_attribute_expr( + OutputOperatorAttributeExpr const &, + std::unordered_map const &node_match); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h new file mode 100644 index 0000000000..60540c0711 --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_OUTPUT_GRAPH_OUTPUT_OPERATOR_ATTRS_ASSIGNMENT_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.dtg.h" +#include "substitutions/unlabelled/pattern_node.dtg.h" + +namespace FlexFlow { + +OutputOperatorAttrsAssignment output_operator_clone_node(PatternNode const &); + +PCGOperatorAttrs materialize_output_operator_from_attrs_assignment( + OutputOperatorAttrsAssignment const &attrs_assignment, + std::unordered_map const &node_match); + +std::pair + copy_attr_from_pattern_node(OperatorAttributeKey key, + PatternNode const &pattern_node); +std::pair + set_attr_to_constant(OperatorAttributeKey key, + OperatorAttributeValue const &value); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml index ac91e9f146..d712ea96f7 100644 --- a/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml +++ b/lib/substitutions/include/substitutions/output_graph/output_operator_attrs_assignment.struct.toml @@ -11,6 +11,7 @@ features = [ includes = [ "substitutions/operator_pattern/operator_attribute_key.dtg.h", "substitutions/output_graph/output_operator_attribute_expr.dtg.h", + "substitutions/unlabelled/pattern_node.dtg.h", "", ] @@ -19,6 +20,10 @@ src_includes = [ "utils/fmt/unordered_map.h", ] +# [[fields]] +# name = "clone_operator" +# type = "std::optional" + # NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can # define the assignment for each operator type. [[fields]] diff --git a/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml b/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml new file mode 100644 index 0000000000..e29eef4cdd --- /dev/null +++ b/lib/substitutions/include/substitutions/output_graph/output_pattern_value.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "OutputPatternValue" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", +] + +[[fields]] +name = "raw_dataflow_value" +type = "::FlexFlow::OpenDataflowValue" diff --git a/lib/substitutions/include/substitutions/pcg_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h index 593f0ddc9e..7342e8169f 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -2,18 +2,18 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_H #include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/unlabelled/pattern_matching.h" #include "substitutions/unlabelled/pattern_node.dtg.h" #include "substitutions/unlabelled/pattern_value.dtg.h" -#include "substitutions/unlabelled/unlabelled_graph_pattern.dtg.h" namespace FlexFlow { /** * @brief Find all locations in \p pcg that match \p pattern */ -std::vector +std::vector find_pattern_matches(PCGPattern const &pattern, SubParallelComputationGraph const &pcg); @@ -24,10 +24,12 @@ TensorAttributePattern get_tensor_pattern(PCGPattern const &, OperatorAttributePattern get_operator_pattern(PCGPattern const &, PatternNode const &); std::unordered_set get_inputs(PCGPattern const &); +std::vector get_pattern_node_outputs(PCGPattern const &, + PatternNode const &); bool assignment_satisfies(SubParallelComputationGraph const &, PCGPattern const &, - UnlabelledDataflowGraphPatternMatch const &); + PCGPatternMatch const &); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/pcg_pattern_builder.h b/lib/substitutions/include/substitutions/pcg_pattern_builder.h new file mode 100644 index 0000000000..4c91dd07af --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern_builder.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_BUILDER_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_BUILDER_H + +#include "substitutions/operator_pattern/operator_attribute_pattern.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" +#include "substitutions/unlabelled/pattern_value.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" + +namespace FlexFlow { + +struct PCGPatternBuilder { + PCGPatternBuilder(); + + PatternValue add_input(); + PatternValue add_input(TensorAttributePattern const &); + + std::vector + add_operator(OperatorAttributePattern const &, + std::vector const &inputs, + std::vector const &outputs); + PatternValue add_operator(OperatorAttributePattern const &, + std::vector const &inputs, + TensorAttributePattern const &output); + + PCGPattern get_pattern() const; + +private: + LabelledOpenDataflowGraph g; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern_match.h b/lib/substitutions/include/substitutions/pcg_pattern_match.h new file mode 100644 index 0000000000..388377d70c --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern_match.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_MATCH_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_PCG_PATTERN_MATCH_H + +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/unlabelled/pattern_node_output.dtg.h" +#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" + +namespace FlexFlow { + +bidict + get_output_mapping_for_pcg_pattern_match( + PCGPatternMatch const &match, + PCGPattern const &pattern, + SubParallelComputationGraph const &spcg); + +UnlabelledDataflowGraphPatternMatch + get_unlabelled_pattern_match(PCGPatternMatch const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml b/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml new file mode 100644 index 0000000000..f45bedd2be --- /dev/null +++ b/lib/substitutions/include/substitutions/pcg_pattern_match.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "PCGPatternMatch" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "substitutions/unlabelled/pattern_node.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/open_parallel_tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "node_assignment" +type = "::FlexFlow::bidict<::FlexFlow::PatternNode, ::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "input_assignment" +type = "std::unordered_map<::FlexFlow::PatternInput, ::FlexFlow::open_parallel_tensor_guid_t>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 42d85dc549..00032045c0 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -2,20 +2,26 @@ #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_H #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include "substitutions/open_parallel_tensor_guid_t.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/sub_parallel_computation_graph_data.dtg.h" +#include "substitutions/sub_parallel_computation_graph_edge.dtg.h" namespace FlexFlow { std::unordered_set get_parallel_layers(SubParallelComputationGraph const &sub_pcg); ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &, - Node const &); + parallel_layer_guid_t const &); PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &, - Node const &); + parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &, - OpenDataflowValue const &); + open_parallel_tensor_guid_t const &); SubParallelComputationGraph sub_pcg_from_full_pcg(ParallelComputationGraph const &); ParallelComputationGraph @@ -25,6 +31,41 @@ parallel_layer_guid_t get_parallel_layer_by_name(SubParallelComputationGraph const &pcg, std::string const &name); +std::vector + get_layer_inputs(SubParallelComputationGraph const &, + parallel_layer_guid_t const &); +std::vector + get_layer_outputs(SubParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set get_subgraph_incoming_edges( + SubParallelComputationGraph const &, + std::unordered_set const &); +std::unordered_set get_subgraph_outgoing_edges( + SubParallelComputationGraph const &, + std::unordered_set const &); + +std::unordered_set get_subgraph_incoming_edges( + SubParallelComputationGraph const &, + std::unordered_set const &); + +std::unordered_set + get_parallel_tensor_uses(SubParallelComputationGraph const &, + open_parallel_tensor_guid_t const &); + +SubParallelComputationGraphData + get_sub_pcg_data(SubParallelComputationGraph const &); +SubParallelComputationGraph + sub_pcg_from_graph_data(SubParallelComputationGraphData const &); +bool sub_pcgs_are_isomorphic(SubParallelComputationGraph const &, + SubParallelComputationGraph const &); + +SubParallelComputationGraph + without_layer_names(SubParallelComputationGraph const &); + +std::string as_dot(SubParallelComputationGraph const &); +void debug_print_dot(SubParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml new file mode 100644 index 0000000000..537af231bf --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_data.struct.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraphData" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/open_parallel_tensor_guid_t.dtg.h", + "substitutions/input_parallel_tensor_guid_t.dtg.h", + "substitutions/sub_parallel_computation_graph_edge.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::ParallelLayerAttrs>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::SubParallelComputationGraphEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::input_parallel_tensor_guid_t>" + +[[fields]] +name = "value_data" +type = "std::unordered_map<::FlexFlow::open_parallel_tensor_guid_t, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h new file mode 100644 index 0000000000..15cbb6127c --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUB_PARALLEL_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" +#include "substitutions/open_parallel_tensor_guid_t.dtg.h" +#include "substitutions/sub_parallel_computation_graph_edge.dtg.h" + +namespace FlexFlow { + +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_dst(parallel_tensor_guid_t const &tensor, + parallel_layer_guid_t const &layer, + int input_idx); +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_use(open_parallel_tensor_guid_t const &tensor, + parallel_tensor_use_t const &use); +open_parallel_tensor_guid_t + get_parallel_tensor(SubParallelComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml new file mode 100644 index 0000000000..6d8f72bae8 --- /dev/null +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "SubParallelComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::OpenDataflowEdge" diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index 4d3473997b..7b4e5e6912 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H #define _FLEXFLOW_SUBSTITUTIONS_SUBSTITUTION_H -#include "sub_parallel_computation_graph.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/substitution.dtg.h" -#include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" namespace FlexFlow { @@ -23,8 +23,8 @@ namespace FlexFlow { bool is_valid_substitution(Substitution const &); /** - * @brief Applies substitution to sub_pcg at the location specified by match, - * returning the resulting SubParallelComputationGraph + * @brief Applies \p substitution to \p sub_pcg at the location specified by \p + * match, returning the resulting SubParallelComputationGraph * * @param sub_pcg * @param substitution @@ -39,7 +39,7 @@ bool is_valid_substitution(Substitution const &); SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &sub_pcg, Substitution const &substitution, - UnlabelledDataflowGraphPatternMatch const &match); + PCGPatternMatch const &match); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/substitution.struct.toml b/lib/substitutions/include/substitutions/substitution.struct.toml index f370ef80fd..49bef62747 100644 --- a/lib/substitutions/include/substitutions/substitution.struct.toml +++ b/lib/substitutions/include/substitutions/substitution.struct.toml @@ -5,6 +5,10 @@ features = [] includes = [ "substitutions/pcg_pattern.dtg.h", "substitutions/output_graph/output_graph_expr.dtg.h", + "substitutions/output_graph/output_graph_expr_input.dtg.h", + "substitutions/output_graph/output_graph_expr_node_output.dtg.h", + "substitutions/unlabelled/pattern_input.dtg.h", + "substitutions/unlabelled/pattern_node_output.dtg.h", ] [[fields]] @@ -16,9 +20,9 @@ name = "output_graph_expr" type = "::FlexFlow::OutputGraphExpr" [[fields]] -name = "input_edge_match_to_output" -type = "::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, ::FlexFlow::OpenDataflowValue>" +name = "inputs_mapping" +type = "::FlexFlow::bidict<::FlexFlow::PatternInput, ::FlexFlow::OutputGraphExprInput>" [[fields]] -name = "output_edge_match_to_output" -type = "::FlexFlow::bidict<::FlexFlow::DataflowOutput, ::FlexFlow::DataflowOutput>" +name = "outputs_mapping" +type = "::FlexFlow::bidict<::FlexFlow::PatternNodeOutput, ::FlexFlow::OutputGraphExprNodeOutput>" diff --git a/lib/substitutions/include/substitutions/substitution_internal/evaluate_substitution_output.h b/lib/substitutions/include/substitutions/substitution_internal/evaluate_substitution_output.h new file mode 100644 index 0000000000..a0461b075b --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/evaluate_substitution_output.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_EVALUATE_SUBSTITUTION_OUTPUT_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_EVALUATE_SUBSTITUTION_OUTPUT_H + +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution.dtg.h" +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.dtg.h" +#include + +namespace FlexFlow { + +/** + * @brief Takes a SubParallelComputationGraph and a PCGPatternMatch where a + * Substitution applies and evaluates the Substitution's OutputGraphExpr + * (producing another SubParallelComputationGraph) using the information from + * the matched nodes. + * + * @details Exists only to enable apply_substitution(SubParallelComputationGraph + * const &, Substitution const &, PCGPatternMatch const &) + * + * @note The resulting SubParallelComputationGraph has new node ids, i.e., does + * not have the same node ids as the OutputGraphExpr + */ +std::pair + evaluate_substitution_output(SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h new file mode 100644 index 0000000000..603cb670bf --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_OUTPUT_EXPR_TO_RESULT_SUB_PCG_MAPPING_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_OUTPUT_EXPR_TO_RESULT_SUB_PCG_MAPPING_H + +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "substitutions/output_graph/output_graph_expr.dtg.h" +#include "substitutions/output_graph/output_graph_expr_node_output.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.dtg.h" + +namespace FlexFlow { + +bidict + get_output_graph_expr_output_mapping( + OutputExprToResultSubPCGMapping const &, + OutputGraphExpr const &, + SubParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.struct.toml b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.struct.toml new file mode 100644 index 0000000000..1fac79a91d --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OutputExprToResultSubPCGMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "substitutions/input_parallel_tensor_guid_t.dtg.h", + "substitutions/output_graph/output_graph_expr_node.dtg.h", + "substitutions/output_graph/output_graph_expr_input.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::OutputGraphExprNode>" + +[[fields]] +name = "input_mapping" +type = "::FlexFlow::bidict<::FlexFlow::input_parallel_tensor_guid_t, ::FlexFlow::OutputGraphExprInput>" diff --git a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h new file mode 100644 index 0000000000..de9d1cd78a --- /dev/null +++ b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H + +#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" + +namespace FlexFlow { + +/** + * @brief Takes a SubParallelComputationGraph but without ParallelTensorShape + * annotations on its OpenDataflowValue%s and uses shape inference to fill them + * in. + * + * @details The OutputGraphExpr of a Substitution only computes + * PCGOperatorAttr%s, not ParallelTensorShape%s, under the theory that shapes + * can be inferred by parallel shape inference. The responsibility of this + * function is to traverse the result of evaluating the OutputGraphExpr + * (resulting from evaluate_substitution_output) + * and annotate each of the OpenDataflowValue%s with the inferred shape. + * + * Exists only to enable apply_substitution(SubParallelComputationGraph const &, + * Substitution const &, PCGPatternMatch const &) + */ +LabelledOpenDataflowGraphView + perform_shape_inference( + LabelledOpenDataflowGraphView const + &g, + std::unordered_map const + &input_shapes); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml index 3df36d13ac..541888038b 100644 --- a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_key.enum.toml @@ -12,3 +12,9 @@ name = "DIM_SIZES" [[values]] name = "DIM_DEGREES" + +[[values]] +name = "DISCARD_COPY_DEGREE_DIM" + +[[values]] +name = "SUM_DEGREE_DIM" diff --git a/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h new file mode 100644 index 0000000000..5b7ebf4ef8 --- /dev/null +++ b/lib/substitutions/include/substitutions/tensor_pattern/tensor_attribute_pattern.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H + +#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h" + +namespace FlexFlow { + +TensorAttributePattern tensor_attribute_pattern_match_all(); + +} // namespace FlexFlow + +#endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h index 262ae64bf8..09d6a12716 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_UNLABELLED_UNLABELLED_DATAFLOW_GRAPH_PATTERN_MATCH_H +#include "substitutions/pcg_pattern.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/unlabelled/pattern_value.dtg.h" #include "substitutions/unlabelled/unlabelled_dataflow_graph_pattern_match.dtg.h" #include @@ -20,6 +22,11 @@ std::optional bidict const &merged_graph_values_to_inputs_of_2); +std::unordered_map + get_output_assignment(SubParallelComputationGraph const &, + PCGPattern const &, + UnlabelledDataflowGraphPatternMatch const &); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h index 95277edfc3..949fbf455b 100644 --- a/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h +++ b/lib/substitutions/include/substitutions/unlabelled/unlabelled_graph_pattern.h @@ -17,7 +17,8 @@ std::unordered_set get_values(UnlabelledGraphPattern const &); std::vector get_topological_ordering(UnlabelledGraphPattern const &); -std::unordered_set get_inputs(UnlabelledGraphPattern const &); +std::unordered_set + get_graph_inputs(UnlabelledGraphPattern const &); std::unordered_set get_edges(UnlabelledGraphPattern const &); diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc deleted file mode 100644 index 94993f3c90..0000000000 --- a/lib/substitutions/src/substitution.cc +++ /dev/null @@ -1,387 +0,0 @@ -#include "substitutions/substitution.h" - -namespace FlexFlow { - -/* struct DeriveValidOperatorAttributeExpr { */ -/* template */ -/* std::unordered_set> */ -/* operator()(T const &t) { */ -/* return derive_valid_operator_attribute_expr(t); */ -/* } */ - -/* std::unordered_set> */ -/* derive_valid_operator_attribute_expr(OperatorAttributeKey const &key) { - */ -/* return {key}; */ -/* } */ - -/* std::unordered_set> */ -/* derive_valid_operator_attribute_expr( */ -/* ListIndexAccess const &access) { */ -/* return {access, access.attribute_key}; */ -/* } */ - -/* std::unordered_set> */ -/* derive_valid_operator_attribute_expr( */ -/* ListSize const &ls) { */ -/* return {ls, ls.attribute_key}; */ -/* } */ -/* }; */ - -/* std::unordered_set> */ -/* get_valid_operator_attribute_exprs(OperatorPattern const &pattern) { */ -/* return set_union(transform( */ -/* pattern.attribute_constraints, [](OperatorAttributeConstraint const &t) - * { */ -/* return visit(DeriveValidOperatorAttributeExpr{}, t.attribute_expr); - */ -/* })); */ -/* } */ - -/* bool is_valid_operator_attribute_expr( */ -/* OperatorPattern const &pattern, */ -/* AttributeExpr const &expr) { */ -/* return contains(get_valid_operator_attribute_exprs(pattern), expr); */ -/* } */ - -/* struct IsValidOperatorAttributeExprFunctor { */ -/* GraphPattern const &graph_pattern; */ - -/* template */ -/* bool operator()(T const &t) const { */ -/* return is_valid(t); */ -/* } */ - -/* bool is_valid(OperatorAttrAccess const &t) const { */ -/* return is_valid_operator_attribute_expr(graph_pattern.value().at(t.node), - */ -/* t.attr_expr); */ -/* } */ - -/* bool is_valid(AttrConstant const &t) const { */ -/* return true; */ -/* } */ -/* }; */ - -/* bool is_valid_operator_attribute_expr(GraphPattern const &pattern, */ -/* OperatorAttributeExpr const &expr) { */ -/* return visit(IsValidOperatorAttributeExprFunctor{pattern}, expr); */ -/* } */ - -/* bool is_valid_substitution(Substitution const &s) { */ -/* for (Node const &node : get_nodes(s.output_graph_expr.value())) { */ -/* for (OperatorAttributeExpr expr : */ -/* values(s.output_graph_expr.value().at(node).assignments)) { */ -/* if (!is_valid_operator_attribute_expr(s.input_graph, expr)) { */ -/* return false; */ -/* } */ -/* } */ -/* } */ -/* return true; */ -/* } */ - -/* struct EvaluateOperatorAttributeExpr { */ -/* SubParallelComputationGraph const &graph; */ -/* MultiDiGraphPatternMatch const &match; */ - -/* template */ -/* OperatorAttributeValue operator()(T const &t) { */ -/* return evaluate(t); */ -/* } */ - -/* OperatorAttributeValue evaluate(OperatorAttrAccess const &t) { */ -/* Node node_in_pattern = t.node; */ -/* Node node_in_pcg = match.node_assignment.at_l(node_in_pattern); */ -/* return evaluate_attribute_expr(graph.at(node_in_pcg), - * t.attr_expr).value(); */ -/* } */ - -/* OperatorAttributeValue evaluate(AttrConstant const &t) { */ -/* return t.value; */ -/* } */ -/* }; */ - -/* OperatorAttributeValue */ -/* evaluate_graph_attribute_expr(SubParallelComputationGraph const &g, */ -/* MultiDiGraphPatternMatch const &match, */ -/* OperatorAttributeExpr const &expr) { */ -/* return visit(EvaluateOperatorAttributeExpr{g, match}, expr); */ -/* } */ - -/* Operator get_operator_attrs(SubParallelComputationGraph const &graph, */ -/* MultiDiGraphPatternMatch const &match, */ -/* OperatorAttrAssignment const &assignment) { */ -/* std::unordered_map - * assignments; */ -/* for (auto const &[key, expr] : assignment.assignments) { */ -/* OperatorAttributeValue value = */ -/* evaluate_graph_attribute_expr(graph, match, expr); */ -/* assignments.emplace(key, value); */ -/* } */ -/* assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); */ -/* assert(std::holds_alternative( */ -/* assignments.at(OperatorAttributeKey::OP_TYPE))); */ -/* OperatorType op_type = */ -/* std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); - */ -/* switch (op_type) { */ -/* case OperatorType::BATCHMATMUL: */ -/* return Operator{ */ -/* BatchMatmulAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::A_SEQ_LENGTH_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, */ -/* std::nullopt}; */ -/* case OperatorType::BATCHNORM: */ -/* return Operator{BatchNormAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::RELU))}, */ -/* std::nullopt}; */ -/* case OperatorType::CAST: */ -/* return Operator{CastAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::DATA_TYPE))}, - */ -/* std::nullopt}; */ -/* case OperatorType::CONCAT: */ -/* return Operator{ */ -/* ConcatAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::AXIS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::CONV2D: */ -/* return Operator{ */ -/* Conv2DAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::GROUPS)), */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ -/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::DROPOUT: */ -/* return Operator{DropoutAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::RATE)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::SEED))}, */ -/* std::nullopt}; */ -/* case OperatorType::EW_ADD: */ -/* case OperatorType::EW_DIV: */ -/* case OperatorType::EW_EQUAL: */ -/* case OperatorType::EW_GREATER: */ -/* case OperatorType::EW_LESS: */ -/* case OperatorType::EW_MAX: */ -/* case OperatorType::EW_MIN: */ -/* case OperatorType::EW_MUL: */ -/* case OperatorType::EW_SUB: */ -/* return Operator{ */ -/* ElementBinaryAttrs{op_type, */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::DATA_TYPE)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::SCALAR_ADD: */ -/* case OperatorType::SCALAR_FLOOR_DIV: */ -/* case OperatorType::SCALAR_MULTIPLY: */ -/* case OperatorType::SCALAR_SUB: */ -/* case OperatorType::SCALAR_TRUE_DIV: */ -/* return Operator{ */ -/* ElementScalarUnaryAttrs{ */ -/* op_type, */ -/* std::get(assignments.at(OperatorAttributeKey::SCALAR))}, - */ -/* std::nullopt}; */ -/* case OperatorType::EXP: */ -/* case OperatorType::IDENTITY: */ -/* case OperatorType::GELU: */ -/* case OperatorType::RSQRT: */ -/* case OperatorType::POW: */ -/* case OperatorType::SIN: */ -/* case OperatorType::COS: */ -/* return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; */ -/* case OperatorType::EMBEDDING: */ -/* return Operator{ */ -/* EmbeddingAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::AGGR)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::OP_TYPE))}, */ -/* std::nullopt}; */ -/* case OperatorType::FLAT: */ -/* return Operator{FlatAttrs{}, std::nullopt}; */ -/* case OperatorType::GATHER: */ -/* return Operator{GatherAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::DIM))}, */ -/* std::nullopt}; */ -/* case OperatorType::INPUT: */ -/* return Operator{InputAttrs{}, std::nullopt}; */ -/* case OperatorType::LAYERNORM: */ -/* return Operator{ */ -/* LayerNormAttrs{ */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::AXES)), */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - */ -/* std::get(assignments.at(OperatorAttributeKey::EPSILON))}, - */ -/* std::nullopt}; */ -/* case OperatorType::LINEAR: */ -/* return Operator{ */ -/* LinearAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::DATA_TYPE)), */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ACTIVATION)), */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::REGULARIZER))}, */ -/* std::nullopt}; */ -/* case OperatorType::MULTIHEAD_ATTENTION: */ -/* return Operator{ */ -/* MultiHeadAttentionAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - */ -/* std::get(assignments.at(OperatorAttributeKey::VDIM)), */ -/* std::get(assignments.at(OperatorAttributeKey::DROPOUT)), - */ -/* std::get(assignments.at(OperatorAttributeKey::BIAS)), */ -/* std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, */ -/* std::nullopt}; */ -/* case OperatorType::NOOP: */ -/* return Operator{NoopAttrs{}, std::nullopt}; */ -/* case OperatorType::POOL2D: */ -/* return Operator{ */ -/* Pool2DAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_H)), - */ -/* std::get(assignments.at(OperatorAttributeKey::PADDING_W)), - */ -/* std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - */ -/* std::get( */ -/* assignments.at(OperatorAttributeKey::ACTIVATION))}, */ -/* std::nullopt}; */ -/* case OperatorType::REDUCE_ARGMAX: */ -/* case OperatorType::REDUCE_ARGMIN: */ -/* case OperatorType::REDUCE_MAX: */ -/* case OperatorType::REDUCE_MEAN: */ -/* case OperatorType::REDUCE_MIN: */ -/* case OperatorType::REDUCE_PROD: */ -/* case OperatorType::REDUCE_SUM: */ -/* return Operator{ */ -/* ReduceAttrs{ */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::AXES)), */ -/* op_type, */ -/* std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::REVERSE: */ -/* return Operator{ReverseAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::AXIS))}, */ -/* std::nullopt}; */ -/* case OperatorType::RESHAPE: */ -/* return Operator{ReshapeAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::SHAPE))}, */ -/* std::nullopt}; */ -/* case OperatorType::SPLIT: */ -/* return Operator{ */ -/* SplitAttrs{ */ -/* std::get>( */ -/* assignments.at(OperatorAttributeKey::SPLITS)), */ -/* std::get(assignments.at(OperatorAttributeKey::AXIS))}, - */ -/* std::nullopt}; */ -/* case OperatorType::SOFTMAX: */ -/* return Operator{SoftmaxAttrs{std::get( */ -/* assignments.at(OperatorAttributeKey::DIM))}, */ -/* std::nullopt}; */ -/* case OperatorType::TOPK: */ -/* return Operator{ */ -/* TopKAttrs{ */ -/* std::get(assignments.at(OperatorAttributeKey::K)), */ -/* std::get(assignments.at(OperatorAttributeKey::SORTED))}, - */ -/* std::nullopt}; */ -/* case OperatorType::TRANSPOSE: */ -/* return Operator{ */ -/* TransposeAttrs{std::get>( */ -/* assignments.at(OperatorAttributeKey::PERMUTATION))}, */ -/* std::nullopt}; */ -/* case OperatorType::COMBINE: */ -/* return Operator{CombineAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), - */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, - */ -/* std::nullopt}; */ -/* case OperatorType::REDUCTION: */ -/* return Operator{ */ -/* ReductionAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ -/* std::nullopt}; */ -/* case OperatorType::REPARTITION: */ -/* return Operator{ */ -/* RepartitionAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ -/* std::nullopt}; */ -/* case OperatorType::REPLICATE: */ -/* return Operator{ */ -/* ReplicateAttrs{std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DIM)), */ -/* std::get(assignments.at( */ -/* OperatorAttributeKey::PARALLEL_DEGREE))}, */ -/* std::nullopt}; */ -/* default: */ -/* throw mk_runtime_error("Unknown Operator"); */ -/* } */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc b/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc new file mode 100644 index 0000000000..76329229a4 --- /dev/null +++ b/lib/substitutions/src/substitutions/open_parallel_tensor_guid_t.cc @@ -0,0 +1,16 @@ +#include "substitutions/open_parallel_tensor_guid_t.h" + +namespace FlexFlow { + +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_closed(parallel_tensor_guid_t t) { + return open_parallel_tensor_guid_t{OpenDataflowValue{t.raw_graph_output}}; +} + +open_parallel_tensor_guid_t + open_parallel_tensor_guid_from_input(input_parallel_tensor_guid_t i) { + return open_parallel_tensor_guid_t{ + OpenDataflowValue{i.raw_dataflow_graph_input}}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index b12564faf0..26f8ff5062 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -190,9 +190,7 @@ std::optional get_attribute(LinearAttrs const &p, case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; case OperatorAttributeKey::USE_BIAS: - return bool(p.use_bias); // NOTE(@wmd): Without casting to bool, it will - // return an OperatorAttributeValue with - // underlying type int. Might be a req issue. + return p.use_bias; case OperatorAttributeKey::DATA_TYPE: return p.data_type; case OperatorAttributeKey::ACTIVATION: @@ -213,6 +211,8 @@ std::optional return p.num_heads; case OperatorAttributeKey::USE_BIAS: return p.bias; + case OperatorAttributeKey::DROPOUT: + return p.dropout; default: return std::nullopt; } @@ -248,7 +248,7 @@ std::optional get_attribute(Pool2DAttrs const &p, case OperatorAttributeKey::POOL_TYPE: return p.pool_type; case OperatorAttributeKey::ACTIVATION: - return p.activation; + return std::optional{p.activation}; default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc new file mode 100644 index 0000000000..5ab528ed3d --- /dev/null +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_constraint.cc @@ -0,0 +1,33 @@ +#include "substitutions/operator_pattern/operator_attribute_constraint.h" + +namespace FlexFlow { + +OperatorAttributeConstraint op_type_equals_constraint(OperatorType op_type) { + return OperatorAttributeConstraint{ + ConstraintType::EQUAL, + OperatorAttributeExpr{OperatorAttributeKey::OP_TYPE}, + OperatorAttributeValue{op_type}, + }; +} + +OperatorAttributeConstraint + op_attr_key_equals(OperatorAttributeKey key, + OperatorAttributeValue const &val) { + return OperatorAttributeConstraint{ + ConstraintType::EQUAL, + OperatorAttributeExpr{key}, + OperatorAttributeValue{val}, + }; +} + +OperatorAttributeConstraint + make_equals_constraint(OperatorAttributeExpr const &expr, + OperatorAttributeValue const &val) { + return OperatorAttributeConstraint{ + ConstraintType::EQUAL, + expr, + val, + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc index 4a55fa3de3..20f32b129f 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/operator_attribute_expr.cc @@ -7,8 +7,8 @@ namespace FlexFlow { std::optional - evaluate_attribute_expr(PCGOperatorAttrs const &attrs, - OperatorAttributeExpr const &expr) { + evaluate_attribute_expr(OperatorAttributeExpr const &expr, + PCGOperatorAttrs const &attrs) { return expr.visit>(overload{ [&](OperatorAttributeKey const &k) { return get_attribute(attrs, k); }, [&](OperatorAttributeListSize const &k) { diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index ae42515cc8..194ae49255 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -7,7 +7,7 @@ bool operator_satisfies_constraint( PCGOperatorAttrs const &attrs, OperatorAttributeConstraint const &constraint) { std::optional expr_val = - evaluate_attribute_expr(attrs, constraint.attribute_expr); + evaluate_attribute_expr(constraint.attribute_expr, attrs); if (!expr_val.has_value()) { return false; diff --git a/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc new file mode 100644 index 0000000000..7d65f687c8 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/materialize_operator_from_attrs_map.cc @@ -0,0 +1,155 @@ +#include "substitutions/output_graph/materialize_operator_from_attrs_map.h" +#include "utils/containers/contains_key.h" +#include "utils/fmt/unordered_map.h" + +namespace FlexFlow { + +struct Accessor { + Accessor( + std::unordered_map const &m) + : m(m) {} + + std::unordered_map const &m; + + template + T const &get(OperatorAttributeKey k) const { + if (contains_key(this->m, k)) { + return this->m.at(k).get(); + } else { + throw mk_runtime_error( + fmt::format("Could not find key {} in attrs map: {}", k, this->m)); + } + } +}; + +PCGOperatorAttrs materialize_operator_from_attrs_map( + std::unordered_map const + &attrs) { + OperatorType op_type = + attrs.at(OperatorAttributeKey::OP_TYPE).get(); + + Accessor acc = Accessor{attrs}; + + switch (op_type) { + case OperatorType::MULTIHEAD_ATTENTION: + return PCGOperatorAttrs{MultiHeadAttentionAttrs{ + /*embed_dim=*/acc.get(OperatorAttributeKey::EMBED_DIM), + /*num_heads=*/acc.get(OperatorAttributeKey::NUM_HEADS), + /*kdim=*/acc.get(OperatorAttributeKey::KDIM), + /*vdim=*/acc.get(OperatorAttributeKey::VDIM), + /*dropout=*/acc.get(OperatorAttributeKey::DROPOUT), + /*bias=*/acc.get(OperatorAttributeKey::BIAS), + /*add_bias_kv=*/acc.get(OperatorAttributeKey::ADD_BIAS_KV), + /*add_zero_attn=*/acc.get(OperatorAttributeKey::ADD_ZERO_ATTN), + }}; + case OperatorType::POOL2D: + return PCGOperatorAttrs{Pool2DAttrs{ + /*kernel_h=*/acc.get(OperatorAttributeKey::KERNEL_H), + /*kernel_w=*/acc.get(OperatorAttributeKey::KERNEL_W), + /*stride_h=*/acc.get(OperatorAttributeKey::STRIDE_H), + /*stride_w=*/acc.get(OperatorAttributeKey::STRIDE_W), + /*padding_h=*/acc.get(OperatorAttributeKey::PADDING_H), + /*padding_w=*/acc.get(OperatorAttributeKey::PADDING_W), + /*pool_type=*/acc.get(OperatorAttributeKey::POOL_TYPE), + /*activation=*/ + acc.get>(OperatorAttributeKey::ACTIVATION) + .value(), + }}; + case OperatorType::NOOP: + case OperatorType::INPUT: + case OperatorType::WEIGHT: + case OperatorType::CONV2D: + case OperatorType::DROPOUT: + case OperatorType::LINEAR: + return PCGOperatorAttrs{LinearAttrs{ + /*out_channels=*/acc.get(OperatorAttributeKey::OUT_CHANNELS), + /*use_bias=*/acc.get(OperatorAttributeKey::USE_BIAS), + /*data_type=*/acc.get(OperatorAttributeKey::DATA_TYPE), + /*activation=*/ + acc.get>(OperatorAttributeKey::ACTIVATION), + /*regularizer=*/ + acc.get>( + OperatorAttributeKey::REGULARIZER), + }}; + case OperatorType::BATCHMATMUL: + case OperatorType::SCALAR_MULTIPLY: + case OperatorType::SCALAR_ADD: + case OperatorType::SCALAR_FLOOR_DIV: + case OperatorType::SCALAR_TRUE_DIV: + case OperatorType::SCALAR_SUB: + case OperatorType::RELU: + case OperatorType::IDENTITY: + case OperatorType::SIGMOID: + case OperatorType::TANH: + case OperatorType::ELU: + case OperatorType::FLAT: + case OperatorType::SOFTMAX: + case OperatorType::BATCHNORM: + case OperatorType::CONCAT: + case OperatorType::SPLIT: + case OperatorType::EMBEDDING: + case OperatorType::CACHE: + case OperatorType::RESHAPE: + case OperatorType::REVERSE: + case OperatorType::TRANSPOSE: + case OperatorType::EW_ADD: + case OperatorType::EW_MUL: + case OperatorType::MATMUL: + case OperatorType::MUL: + case OperatorType::ENLARGE: + case OperatorType::SQUEEZE: + case OperatorType::UNSQUEEZE: + case OperatorType::EW_SUB: + case OperatorType::EW_DIV: + case OperatorType::EW_EQUAL: + case OperatorType::EW_GREATER: + case OperatorType::EW_LESS: + case OperatorType::EW_MAX: + case OperatorType::EW_MIN: + case OperatorType::REDUCE_ARGMAX: + case OperatorType::REDUCE_ARGMIN: + case OperatorType::REDUCE_MAX: + case OperatorType::REDUCE_MEAN: + case OperatorType::REDUCE_MIN: + case OperatorType::REDUCE_PROD: + case OperatorType::REDUCE_SUM: + case OperatorType::PAD: + case OperatorType::SHAPE: + case OperatorType::SIZE: + case OperatorType::TOPK: + case OperatorType::WHERE: + case OperatorType::CEIL: + case OperatorType::CAST: + case OperatorType::EXP: + case OperatorType::ROUND: + case OperatorType::LOG: + case OperatorType::LOGICAL_NOT: + case OperatorType::SQRT: + case OperatorType::SIN: + case OperatorType::COS: + case OperatorType::LEAKYRELU: + case OperatorType::SLICE: + case OperatorType::RESIZE: + case OperatorType::PRELU: + case OperatorType::GELU: + case OperatorType::FUSED: + case OperatorType::RSQRT: + case OperatorType::POW: + case OperatorType::MEAN: + case OperatorType::LAYERNORM: + case OperatorType::GATHER: + case OperatorType::BROADCAST: + case OperatorType::REPARTITION: + case OperatorType::COMBINE: + case OperatorType::REPLICATE: + case OperatorType::REDUCTION: + case OperatorType::BATCH: + case OperatorType::PIPELINE: + case OperatorType::FUSED_PARALLEL: + default: + throw mk_runtime_error( + fmt::format("Unsupported operator type {}", op_type)); + } +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc new file mode 100644 index 0000000000..3d6aadc795 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_graph_expr.cc @@ -0,0 +1,17 @@ +#include "substitutions/output_graph/output_graph_expr.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms.h" + +namespace FlexFlow { + +std::vector + get_node_outputs(OutputGraphExpr const &g, OutputGraphExprNode const &n) { + std::vector raw_outputs = + get_outputs(g.raw_graph, n.raw_graph_node); + + return transform(raw_outputs, [](DataflowOutput const &o) { + return OutputGraphExprNodeOutput{o}; + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.cc new file mode 100644 index 0000000000..e7cfcf232c --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attribute_expr.cc @@ -0,0 +1,19 @@ +#include "substitutions/output_graph/output_operator_attribute_expr.h" +#include "substitutions/operator_pattern/operator_attribute_expr.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OperatorAttributeValue evaluate_output_operator_attribute_expr( + OutputOperatorAttributeExpr const &expr, + std::unordered_map const &node_match) { + return expr.visit(overload{ + [&](OutputOperatorAttrAccess const &a) { + return evaluate_attribute_expr(a.attr_expr, node_match.at(a.node)) + .value(); + }, + [](AttrConstant const &c) { return c.value; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc new file mode 100644 index 0000000000..fa247cd151 --- /dev/null +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc @@ -0,0 +1,42 @@ +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/output_graph/materialize_operator_from_attrs_map.h" +#include "substitutions/output_graph/output_operator_attribute_expr.h" +#include "utils/containers/map_values.h" + +namespace FlexFlow { + +OutputOperatorAttrsAssignment output_operator_clone_node(PatternNode const &) { + NOT_IMPLEMENTED(); +} + +PCGOperatorAttrs materialize_output_operator_from_attrs_assignment( + OutputOperatorAttrsAssignment const &attrs_assignment, + std::unordered_map const &node_match) { + std::unordered_map attr_map = + map_values(attrs_assignment.assignments, + [&](OutputOperatorAttributeExpr const &expr) { + return evaluate_output_operator_attribute_expr(expr, + node_match); + }); + + return materialize_operator_from_attrs_map(attr_map); +} + +std::pair + copy_attr_from_pattern_node(OperatorAttributeKey key, + PatternNode const &pattern_node) { + return {key, + OutputOperatorAttributeExpr{OutputOperatorAttrAccess{ + pattern_node, OperatorAttributeExpr{key}}}}; +} + +std::pair + set_attr_to_constant(OperatorAttributeKey key, + OperatorAttributeValue const &value) { + return { + key, + OutputOperatorAttributeExpr{AttrConstant{value}}, + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index 4591e644bb..e53877006d 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -1,8 +1,12 @@ #include "substitutions/pcg_pattern.h" #include "substitutions/operator_pattern/satisfies_pattern.h" +#include "substitutions/pcg_pattern_match.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/satisfies_pattern.h" #include "substitutions/unlabelled/pattern_value.h" +#include "utils/containers/map_values.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms.h" namespace FlexFlow { @@ -12,22 +16,37 @@ static MatchAdditionalCriterion return MatchAdditionalCriterion{ [&](PatternNode const &patternNode, Node const &pcgNode) { return operator_satisfies_pattern( - get_operator_attrs(pcg, pcgNode), + get_operator_attrs(pcg, parallel_layer_guid_t{pcgNode}), get_operator_pattern(pattern, patternNode)); }, [&](PatternValue const &patternValue, OpenDataflowValue const &pcgValue) { return parallel_tensor_satisfies_pattern( - get_parallel_tensor_attrs(pcg, pcgValue), + get_parallel_tensor_attrs(pcg, + open_parallel_tensor_guid_t{pcgValue}), get_tensor_pattern(pattern, patternValue)); }}; } -std::vector +std::vector find_pattern_matches(PCGPattern const &pattern, SubParallelComputationGraph const &pcg) { - return find_pattern_matches(get_unlabelled_pattern(pattern), - pcg.raw_graph, - pcg_pattern_criteria(pattern, pcg)); + std::vector unlabelled_matches = + find_pattern_matches(get_unlabelled_pattern(pattern), + pcg.raw_graph, + pcg_pattern_criteria(pattern, pcg)); + auto pcg_match_from_unlabelled_match = + [](UnlabelledDataflowGraphPatternMatch const &m) { + return PCGPatternMatch{ + map_values(m.node_assignment, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + map_values(m.input_assignment, + [](OpenDataflowValue const &i) { + return open_parallel_tensor_guid_t{i}; + }), + }; + }; + + return transform(unlabelled_matches, pcg_match_from_unlabelled_match); } UnlabelledGraphPattern get_unlabelled_pattern(PCGPattern const &p) { @@ -44,14 +63,25 @@ OperatorAttributePattern get_operator_pattern(PCGPattern const &p, return p.raw_graph.at(n.raw_node); } -bool assignment_satisfies( - SubParallelComputationGraph const &pcg, - PCGPattern const &pattern, - UnlabelledDataflowGraphPatternMatch const &patternMatch) { - return unlabelled_pattern_does_match(get_unlabelled_pattern(pattern), - pcg.raw_graph, - patternMatch, - pcg_pattern_criteria(pattern, pcg)); +std::vector + get_pattern_node_outputs(PCGPattern const &pattern, + PatternNode const &node) { + std::vector raw_outputs = + get_outputs(pattern.raw_graph, node.raw_node); + + return transform(raw_outputs, [](DataflowOutput const &o) { + return PatternNodeOutput{o}; + }); +} + +bool assignment_satisfies(SubParallelComputationGraph const &pcg, + PCGPattern const &pattern, + PCGPatternMatch const &pattern_match) { + return unlabelled_pattern_does_match( + get_unlabelled_pattern(pattern), + pcg.raw_graph, + get_unlabelled_pattern_match(pattern_match), + pcg_pattern_criteria(pattern, pcg)); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern_builder.cc b/lib/substitutions/src/substitutions/pcg_pattern_builder.cc new file mode 100644 index 0000000000..e81671f08a --- /dev/null +++ b/lib/substitutions/src/substitutions/pcg_pattern_builder.cc @@ -0,0 +1,52 @@ +#include "substitutions/pcg_pattern_builder.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" +#include "substitutions/unlabelled/pattern_value.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" + +namespace FlexFlow { + +PCGPatternBuilder::PCGPatternBuilder() + : g(LabelledOpenDataflowGraph:: + create>()) {} + +// PatternValue add_input() { +// return tensor_attribute_pattern_match_all(); +// } +// +// PatternValue PCGPatternBuilder::add_input(TensorAttributePattern const &p) { +// return PatternValue{PatternInput{this->g.add_input(p)}}; +// } +// +// std::vector +// PCGPatternBuilder::add_operator(OperatorAttributePattern const &p, +// std::vector const +// &inputs, +// std::vector +// const &outputs) { +// NodeAddedResult node_added_result = this->g.add_node(p, +// transform(inputs, +// raw_open_dataflow_value_from_pattern_value), +// outputs); +// return transform(node_added_result.outputs, +// pattern_value_from_raw_open_dataflow_value); +// } +// +// PatternValue PCGPatternBuilder::add_operator(OperatorAttributePattern const +// &p, +// std::vector const +// &inputs, TensorAttributePattern +// const &output) { +// return get_only(this->add_operator(p, inputs, {output})); +// } +// +// +// PCGPattern PCGPatternBuilder::get_pattern() const { +// return PCGPattern{this->g}; +// } + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc new file mode 100644 index 0000000000..f1f4e31d57 --- /dev/null +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -0,0 +1,49 @@ +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/containers/map_values.h" +#include "utils/containers/zip.h" + +namespace FlexFlow { + +bidict + get_output_mapping_for_pcg_pattern_match( + PCGPatternMatch const &match, + PCGPattern const &pattern, + SubParallelComputationGraph const &spcg) { + bidict result; + + for (auto const &[pattern_node, matched_layer] : match.node_assignment) { + std::vector matched_layer_output_tensors = + get_layer_outputs(spcg, matched_layer); + std::vector pattern_node_outputs = + get_pattern_node_outputs(pattern, pattern_node); + + assert(matched_layer_output_tensors.size() == pattern_node_outputs.size()); + + bidict mapping = + bidict_from_keys_and_values(pattern_node_outputs, + matched_layer_output_tensors); + + result = merge_bidicts(result, mapping); + } + + return result; +} + +UnlabelledDataflowGraphPatternMatch + get_unlabelled_pattern_match(PCGPatternMatch const &match) { + return UnlabelledDataflowGraphPatternMatch{ + map_values( + match.node_assignment, + [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), + map_values(match.input_assignment, + [](open_parallel_tensor_guid_t const &i) { + return i.raw_open_dataflow_value; + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 2f050ce45e..0bbe0e97a7 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -1,32 +1,43 @@ #include "substitutions/sub_parallel_computation_graph.h" +#include "op-attrs/pcg_operator_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/values.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" namespace FlexFlow { std::unordered_set get_parallel_layers(SubParallelComputationGraph const &sub_pcg) { - return get_parallel_layers(pcg_from_sub_pcg_by_dropping_inputs(sub_pcg)); + return transform(get_nodes(sub_pcg.raw_graph), + [](Node const &n) { return parallel_layer_guid_t{n}; }); } ParallelLayerAttrs get_parallel_layer_attrs(SubParallelComputationGraph const &spcg, - Node const &n) { - return spcg.raw_graph.at(n); + parallel_layer_guid_t const &layer) { + return spcg.raw_graph.at(layer.raw_graph_node); } PCGOperatorAttrs get_operator_attrs(SubParallelComputationGraph const &spcg, - Node const &n) { + parallel_layer_guid_t const &n) { return get_parallel_layer_attrs(spcg, n).op_attrs; } ParallelTensorAttrs get_parallel_tensor_attrs(SubParallelComputationGraph const &spcg, - OpenDataflowValue const &v) { - return spcg.raw_graph.at(v); + open_parallel_tensor_guid_t const &v) { + return spcg.raw_graph.at(v.raw_open_dataflow_value); } SubParallelComputationGraph @@ -58,4 +69,162 @@ parallel_layer_guid_t name); } +std::vector + get_layer_inputs(SubParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer) { + return transform(get_inputs(pcg.raw_graph, layer.raw_graph_node), + [](OpenDataflowValue const &v) { + return open_parallel_tensor_guid_t{v}; + }); +} + +std::vector + get_layer_outputs(SubParallelComputationGraph const &pcg, + parallel_layer_guid_t const &layer) { + return transform( + get_outputs(pcg.raw_graph, layer.raw_graph_node), + [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); +} + +std::unordered_set get_subgraph_outgoing_edges( + SubParallelComputationGraph const &spcg, + std::unordered_set const &layers) { + std::unordered_set raw_edges = get_subgraph_outgoing_edges( + spcg.raw_graph, transform(layers, [](parallel_layer_guid_t const &l) { + return l.raw_graph_node; + })); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_incoming_edges( + SubParallelComputationGraph const &spcg, + std::unordered_set const &subgraph) { + std::unordered_set raw_subgraph = + transform(subgraph, [](parallel_layer_guid_t const &l) { + return l.raw_graph_node; + }); + std::unordered_set raw_incoming_edges = + get_subgraph_incoming_edges(spcg.raw_graph, raw_subgraph); + + return transform(raw_incoming_edges, [](OpenDataflowEdge const &e) { + return SubParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set + get_parallel_tensor_uses(SubParallelComputationGraph const &spcg, + open_parallel_tensor_guid_t const &t) { + std::unordered_set raw_uses = + get_open_dataflow_value_uses(spcg.raw_graph, t.raw_open_dataflow_value); + return transform(raw_uses, [](DataflowInput const &i) { + return parallel_tensor_use_t{i}; + }); +} + +SubParallelComputationGraphData + get_sub_pcg_data(SubParallelComputationGraph const &pcg) { + LabelledOpenDataflowGraphData + raw_data = get_graph_data(pcg.raw_graph); + + return SubParallelComputationGraphData{ + map_keys(raw_data.node_data, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + transform(raw_data.edges, + [](OpenDataflowEdge const &e) { + return SubParallelComputationGraphEdge{e}; + }), + transform(raw_data.inputs, + [](DataflowGraphInput const &i) { + return input_parallel_tensor_guid_t{i}; + }), + map_keys(raw_data.value_data, + [](OpenDataflowValue const &v) { + return open_parallel_tensor_guid_t{v}; + }), + }; +} + +SubParallelComputationGraph + sub_pcg_from_graph_data(SubParallelComputationGraphData const &data) { + LabelledOpenDataflowGraphData + raw_data = LabelledOpenDataflowGraphData{ + map_keys( + data.node_data, + [](parallel_layer_guid_t const &l) { return l.raw_graph_node; }), + transform(data.edges, + [](SubParallelComputationGraphEdge const &e) { + return e.raw_edge; + }), + transform(data.inputs, + [](input_parallel_tensor_guid_t const &i) { + return i.raw_dataflow_graph_input; + }), + map_keys(data.value_data, + [](open_parallel_tensor_guid_t const &t) { + return t.raw_open_dataflow_value; + }), + }; + + return SubParallelComputationGraph{ + from_labelled_open_dataflow_graph_data(raw_data), + }; +} + +SubParallelComputationGraph + without_layer_names(SubParallelComputationGraph const &spcg) { + return SubParallelComputationGraph{ + rewrite_node_labels( + spcg.raw_graph, + [](Node const &n, ParallelLayerAttrs const &old_attrs) { + ParallelLayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + }), + }; +} + +bool sub_pcgs_are_isomorphic(SubParallelComputationGraph const &lhs, + SubParallelComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + +std::string as_dot(SubParallelComputationGraph const &spcg) { + std::function get_node_label = + [](ParallelLayerAttrs const &a) -> std::string { + RecordFormatter r = as_dot(a.op_attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + std::function get_input_label = + [](ParallelTensorAttrs const &a) -> std::string { + RecordFormatter r; + + r << fmt::to_string(a.shape); + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + return as_dot(spcg.raw_graph, get_node_label, get_input_label); +} + +void debug_print_dot(SubParallelComputationGraph const &spcg) { + std::cout << as_dot(spcg) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc new file mode 100644 index 0000000000..bb8cb449bc --- /dev/null +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph_edge.cc @@ -0,0 +1,38 @@ +#include "substitutions/sub_parallel_computation_graph_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" + +namespace FlexFlow { + +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_dst(parallel_tensor_guid_t const &tensor, + parallel_layer_guid_t const &layer, + int input_idx) { + return SubParallelComputationGraphEdge{ + OpenDataflowEdge{ + DataflowEdge{ + tensor.raw_graph_output, + DataflowInput{ + layer.raw_graph_node, + input_idx, + }, + }, + }, + }; +} + +SubParallelComputationGraphEdge + subpcg_edge_from_tensor_and_use(open_parallel_tensor_guid_t const &tensor, + parallel_tensor_use_t const &use) { + return SubParallelComputationGraphEdge{ + open_dataflow_edge_from_src_and_dst(tensor.raw_open_dataflow_value, + use.raw_dataflow_input), + }; +} + +open_parallel_tensor_guid_t + get_parallel_tensor(SubParallelComputationGraphEdge const &e) { + OpenDataflowValue raw_value = get_open_dataflow_edge_src(e.raw_edge); + return open_parallel_tensor_guid_t{raw_value}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution.cc b/lib/substitutions/src/substitutions/substitution.cc index b4e6709a73..22e15cb01a 100644 --- a/lib/substitutions/src/substitutions/substitution.cc +++ b/lib/substitutions/src/substitutions/substitution.cc @@ -1,154 +1,169 @@ #include "substitutions/substitution.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph_edge.h" +#include "substitutions/substitution_internal/evaluate_substitution_output.h" +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/values.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/graph/node/algorithms.h" +#include "utils/overload.h" namespace FlexFlow { -/* struct AddMappedEdgeFunctor { */ -/* bidict const &node_mapping; */ -/* SubParallelComputationGraph &new_pcg; */ - -/* template */ -/* void operator()(T const &t) { */ -/* return add_mapped_edge(t); */ -/* } */ - -/* void add_mapped_edge(InputMultiDiEdge const &e) { */ -/* new_pcg.add_edge(InputMultiDiEdge{ */ -/* node_mapping.at_l(e.dst), new_pcg.add_node_port(), e.uid}); */ -/* } */ - -/* void add_mapped_edge(OutputMultiDiEdge const &e) { */ -/* new_pcg.add_edge(OutputMultiDiEdge{ */ -/* node_mapping.at_l(e.src), new_pcg.add_node_port(), e.uid}); */ -/* } */ - -/* void add_mapped_edge(MultiDiEdge const &e) { */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(e.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ -/* }; */ - -/* struct AddNewEdgeFunctor { */ -/* SubParallelComputationGraph const &old_pcg; */ -/* SubParallelComputationGraph &new_pcg; */ -/* MultiDiGraphPatternMatch const &match; */ -/* bidict node_mapping; */ - -/* template */ -/* void operator()(TO const &old_edge, TN const &new_edge) { */ -/* return add_new_edge(old_edge, new_edge); */ -/* } */ - -/* void add_new_edge(InputMultiDiEdge const &old_edge, */ -/* InputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(InputMultiDiEdge{node_mapping.at_l(new_edge.dst), */ -/* new_pcg.add_node_port(), */ -/* old_edge.uid}); */ -/* } */ - -/* void add_new_edge(MultiDiEdge const &old_edge, */ -/* InputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(new_edge.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(old_edge.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ - -/* void add_new_edge(OutputMultiDiEdge const &old_edge, */ -/* OutputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(OutputMultiDiEdge{node_mapping.at_l(new_edge.src), */ -/* new_pcg.add_node_port(), */ -/* old_edge.uid}); */ -/* } */ - -/* void add_new_edge(MultiDiEdge const &old_edge, */ -/* OutputMultiDiEdge const &new_edge) { */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(old_edge.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(new_edge.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ - -/* void add_new_edge(InputMultiDiEdge const &, OutputMultiDiEdge const &) { */ -/* assert(false); */ -/* } */ - -/* void add_new_edge(OpenMultiDiEdge const &, MultiDiEdge const &) { */ -/* assert(false); */ -/* } */ - -/* void add_new_edge(OutputMultiDiEdge const &, InputMultiDiEdge const &) { */ -/* assert(false); */ -/* } */ -/* }; */ - -/* SubParallelComputationGraph */ -/* apply_substitution(SubParallelComputationGraph const &pcg, */ -/* Substitution const &substitution, */ -/* MultiDiGraphPatternMatch const &match) { */ -/* SubParallelComputationGraph new_pcg = */ -/* OutputLabelledOpenMultiDiGraph::template - * create< */ -/* UnorderedOutputLabelledOpenMultiDiGraph>(); */ -/* bidict node_mapping; // Refactor it with global nodes */ -/* for (Node const &node : get_nodes(pcg)) { */ -/* if (!contains_r(match.node_assignment, node)) { */ -/* node_mapping.equate(node, new_pcg.add_node(pcg.at(node))); */ -/* } */ -/* } */ -/* for (OpenMultiDiEdge const &edge : get_edges(pcg)) { */ -/* if (!contains_r(match.edge_assignment, edge)) { */ -/* visit(AddMappedEdgeFunctor{node_mapping, new_pcg}, edge); */ -/* } */ -/* } */ -/* for (Node const &output_node : */ -/* get_nodes(substitution.output_graph_expr.value())) { */ -/* Operator new_op = get_operator_attrs( */ -/* pcg, match, substitution.output_graph_expr.value().at(output_node)); - */ -/* Node new_node = new_pcg.add_node(new_op); */ -/* node_mapping.equate(output_node, new_node); */ -/* } */ -/* for (OpenMultiDiEdge const &output_edge : */ -/* get_edges(substitution.output_graph_expr.value())) { */ -/* if (std::holds_alternative(output_edge)) { */ -/* InputMultiDiEdge e = std::get(output_edge); */ -/* OpenMultiDiEdge original_edge = */ -/* match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); */ -/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ -/* original_edge, */ -/* output_edge); */ -/* } else if (std::holds_alternative(output_edge)) { */ -/* OutputMultiDiEdge e = std::get(output_edge); */ -/* OpenMultiDiEdge original_edge = */ -/* match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); */ -/* visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, */ -/* original_edge, */ -/* output_edge); */ -/* } else { */ -/* assert(std::holds_alternative(output_edge)); */ -/* MultiDiEdge e = std::get(output_edge); */ -/* new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), */ -/* new_pcg.add_node_port(), */ -/* node_mapping.at_l(e.src), */ -/* new_pcg.add_node_port()}); */ -/* } */ -/* } */ - -/* return new_pcg; */ -/* } */ - bool is_valid_substitution(Substitution const &) { NOT_IMPLEMENTED(); } SubParallelComputationGraph - apply_substitution(SubParallelComputationGraph const &, - Substitution const &, - UnlabelledDataflowGraphPatternMatch const &) { - NOT_IMPLEMENTED(); + apply_substitution(SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match) { + auto substitution_output_result = + evaluate_substitution_output(spcg, sub, match); + SubParallelComputationGraph substitution_output_graph = + substitution_output_result.first; + OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping = + substitution_output_result.second; + + SubParallelComputationGraphData output_graph_data = + get_sub_pcg_data(substitution_output_graph); + SubParallelComputationGraphData pre_data = get_sub_pcg_data(spcg); + + std::unordered_set pre_nodes = + keys(pre_data.node_data); + std::unordered_set matched_nodes = + unordered_set_of(values(match.node_assignment)); + std::unordered_set post_nodes_from_original_graph = + set_minus(pre_nodes, matched_nodes); + + std::unordered_map post_node_data = + [&] { + std::unordered_map + post_node_data_from_orig = restrict_keys( + pre_data.node_data, post_nodes_from_original_graph); + std::unordered_map + post_node_data_from_sub = output_graph_data.node_data; + + return merge_maps(post_node_data_from_orig, post_node_data_from_sub); + }(); + + std::unordered_set post_edges = [&] { + std::unordered_set post_edges_from_orig = + filter(pre_data.edges, [&](SubParallelComputationGraphEdge const &e) { + if (e.raw_edge.has()) { + return true; + } else { + DataflowEdge dfe = e.raw_edge.get(); + parallel_layer_guid_t src = parallel_layer_guid_t{dfe.src.node}; + parallel_layer_guid_t dst = parallel_layer_guid_t{dfe.dst.node}; + return !(contains(matched_nodes, src) || + contains(matched_nodes, dst)); + } + }); + + std::unordered_set post_edges_from_sub = + filter(output_graph_data.edges, + [&](SubParallelComputationGraphEdge const &e) { + return !e.raw_edge.has(); + }); + + bidict + output_orig_pattern_mapping = get_output_mapping_for_pcg_pattern_match( + match, sub.pcg_pattern, spcg); + bidict + output_post_outexpr_mapping = get_output_graph_expr_output_mapping( + output_expr_to_result_sub_pcg_mapping, + sub.output_graph_expr, + substitution_output_graph); + + std::unordered_set incoming_to_sub_edges; + for (auto const &[pattern_input, base_graph_tensor] : + match.input_assignment) { + OutputGraphExprInput output_expr_input = + sub.inputs_mapping.at_l(pattern_input); + input_parallel_tensor_guid_t output_graph_input = + output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( + output_expr_input); + std::unordered_set uses = get_parallel_tensor_uses( + substitution_output_graph, + open_parallel_tensor_guid_from_input(output_graph_input)); + for (parallel_tensor_use_t const &use : uses) { + SubParallelComputationGraphEdge new_edge = + subpcg_edge_from_tensor_and_use(base_graph_tensor, use); + incoming_to_sub_edges.insert(new_edge); + } + } + + std::unordered_set outgoing_from_sub_edges; + for (ParallelComputationGraphEdge const &outgoing_edge : + get_subgraph_outgoing_edges(spcg, matched_nodes)) { + parallel_tensor_guid_t original_tensor = + get_parallel_tensor(outgoing_edge); + PatternNodeOutput pattern_tensor = + output_orig_pattern_mapping.at_r(original_tensor); + OutputGraphExprNodeOutput output_graph_tensor = + sub.outputs_mapping.at_l(pattern_tensor); + parallel_tensor_guid_t new_tensor = + output_post_outexpr_mapping.at_r(output_graph_tensor); + + SubParallelComputationGraphEdge new_edge = + subpcg_edge_from_tensor_and_dst( + new_tensor, + get_dst_layer(outgoing_edge), + get_dst_layer_input_idx(outgoing_edge)); + outgoing_from_sub_edges.insert(new_edge); + } + + return set_union(std::vector{ + post_edges_from_orig, + post_edges_from_sub, + incoming_to_sub_edges, + outgoing_from_sub_edges, + }); + }(); + + std::unordered_set post_inputs = + pre_data.inputs; + + std::unordered_map + post_value_data = [&] { + std::unordered_map + post_value_data_from_orig = filter_keys( + pre_data.value_data, [&](open_parallel_tensor_guid_t const &t) { + return visit_open_parallel_tensor_guid( + t, + overload{ + [&](parallel_tensor_guid_t const &t) { + return contains(post_nodes_from_original_graph, + get_source_layer(t)); + }, + [](input_parallel_tensor_guid_t const &) { + return true; + }, + }); + }); + + std::unordered_map + post_value_data_from_sub = output_graph_data.value_data; + return merge_maps(post_value_data_from_orig, post_value_data_from_sub); + }(); + + SubParallelComputationGraphData post_data = SubParallelComputationGraphData{ + post_node_data, + post_edges, + post_inputs, + post_value_data, + }; + + return sub_pcg_from_graph_data(post_data); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution_internal/evaluate_substitution_output.cc b/lib/substitutions/src/substitutions/substitution_internal/evaluate_substitution_output.cc new file mode 100644 index 0000000000..186e2fc03a --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution_internal/evaluate_substitution_output.cc @@ -0,0 +1,94 @@ +#include "substitutions/substitution_internal/evaluate_substitution_output.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution_internal/perform_shape_inference.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/map_values.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" +#include "utils/graph/node/algorithms/generate_new_node_id_permutation.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" + +namespace FlexFlow { + +std::pair + evaluate_substitution_output(SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match) { + std::unordered_map node_match = + map_values(match.node_assignment.as_unordered_map(), + [&](parallel_layer_guid_t const &n) { + return get_operator_attrs(spcg, n); + }); + + bidict new_node_id_permutation = + generate_new_node_id_permutation(sub.output_graph_expr.raw_graph); + bidict new_input_id_permutation = + generate_new_input_id_permutation(sub.output_graph_expr.raw_graph); + LabelledOpenDataflowGraphView + permuted = + permute_input_ids(permute_node_ids(sub.output_graph_expr.raw_graph, + new_node_id_permutation), + new_input_id_permutation); + + LabelledOpenDataflowGraphView + without_shapes = rewrite_node_labels( + permuted, + [&](Node const &n, OutputOperatorAttrsAssignment const &attrs) { + return ParallelLayerAttrs{ + materialize_output_operator_from_attrs_assignment(attrs, + node_match), + std::nullopt, + }; + }); + + bidict result_input_map = + map_keys(map_values(new_input_id_permutation, + [](DataflowGraphInput const &i) { + return OutputGraphExprInput{i}; + }), + [](NewDataflowGraphInput const &i) { + return input_parallel_tensor_guid_t{i.raw_input}; + }); + + bidict result_node_map = map_keys( + map_values(new_node_id_permutation, + [](Node const &n) { return OutputGraphExprNode{n}; }), + [](NewNode const &n) { return parallel_layer_guid_t{n.raw_node}; }); + + std::unordered_map input_shapes = + map_values(map_keys(match.input_assignment, + [&](PatternInput const &i) { + return result_input_map + .at_r(sub.inputs_mapping.at_l(i)) + .raw_dataflow_graph_input; + }), + [&](open_parallel_tensor_guid_t const &v) { + return spcg.raw_graph.at(v.raw_open_dataflow_value).shape; + }); + LabelledOpenDataflowGraphView + with_shapes = perform_shape_inference(without_shapes, input_shapes); + LabelledOpenDataflowGraphView + with_attrs = rewrite_value_labels( + with_shapes, + [](OpenDataflowValue const &, ParallelTensorShape const &s) { + return ParallelTensorAttrs{ + s, + std::nullopt, + std::nullopt, + CreateGrad::YES, + }; + }); + + return std::make_pair(SubParallelComputationGraph{with_attrs}, + OutputExprToResultSubPCGMapping{ + result_node_map, + result_input_map, + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc new file mode 100644 index 0000000000..083334f0db --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc @@ -0,0 +1,32 @@ +#include "substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.h" +#include "substitutions/output_graph/output_graph_expr.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/merge_bidicts.h" + +namespace FlexFlow { + +bidict + get_output_graph_expr_output_mapping( + OutputExprToResultSubPCGMapping const &m, + OutputGraphExpr const &output_graph_expr, + SubParallelComputationGraph const &spcg) { + bidict result; + + for (auto const &[parallel_layer, output_graph_expr_node] : m.node_mapping) { + std::vector layer_outputs = + get_layer_outputs(spcg, parallel_layer); + std::vector output_graph_expr_outputs = + get_node_outputs(output_graph_expr, output_graph_expr_node); + + bidict + mapping_for_layer = bidict_from_keys_and_values( + layer_outputs, output_graph_expr_outputs); + + result = merge_bidicts(result, mapping_for_layer); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc new file mode 100644 index 0000000000..0bde326bd1 --- /dev/null +++ b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -0,0 +1,45 @@ +#include "substitutions/substitution_internal/perform_shape_inference.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" + +namespace FlexFlow { + +LabelledOpenDataflowGraphView + perform_shape_inference( + LabelledOpenDataflowGraphView const + &g, + std::unordered_map const + &input_shapes) { + + std::unordered_map inferred = + map_keys(input_shapes, [](DataflowGraphInput const &i) { + return OpenDataflowValue{i}; + }); + + for (Node const &n : get_topological_ordering(g)) { + std::vector input_shapes = + transform(get_inputs(g, n), + [&](OpenDataflowValue const &v) { return inferred.at(v); }); + + std::vector output_shapes = + get_output_shapes(g.at(n).op_attrs, input_shapes); + + std::vector outputs = get_outputs(g, n); + + for (auto const &[output, shape] : zip(outputs, output_shapes)) { + inferred.insert({OpenDataflowValue{output}, shape}); + } + } + + return rewrite_value_labels( + g, [&](OpenDataflowValue const &v, std::monostate const &) { + return inferred.at(v); + }); +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 8a71d92e0e..05f21247c7 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -1,4 +1,5 @@ #include "substitutions/tensor_pattern/get_attribute.h" +#include "op-attrs/parallel_tensor_dims.h" #include "utils/containers/as_vector.h" #include "utils/containers/transform.h" #include "utils/integer_conversions.h" diff --git a/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc new file mode 100644 index 0000000000..794ab5abda --- /dev/null +++ b/lib/substitutions/src/substitutions/tensor_pattern/tensor_attribute_pattern.cc @@ -0,0 +1,9 @@ +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" + +namespace FlexFlow { + +TensorAttributePattern tensor_attribute_pattern_match_all() { + return TensorAttributePattern{{}}; +} + +} // namespace FlexFlow diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index fb01733bae..a7ebc0bff7 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -10,7 +10,7 @@ #include "utils/containers/zip.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" namespace FlexFlow { @@ -37,7 +37,8 @@ static std::optional std::vector pattern_node_inputs = get_inputs_to_pattern_node(pattern, pattern_node); - std::unordered_set pattern_graph_inputs = get_inputs(pattern); + std::unordered_set pattern_graph_inputs = + get_graph_inputs(pattern); assert(unordered_set_of(pattern_node_inputs) == transform(pattern_graph_inputs, diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 31c4a23e7e..304bb8cf46 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -5,11 +5,14 @@ #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/standard_pattern_edge.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" @@ -22,8 +25,7 @@ namespace FlexFlow { OpenDataflowSubgraphResult subgraph_matched(OpenDataflowGraphView const &g, UnlabelledDataflowGraphPatternMatch const &match) { - std::unordered_set matched_nodes = - keys(match.node_assignment.reversed()); + std::unordered_set matched_nodes = right_entries(match.node_assignment); return get_subgraph(g, matched_nodes); } @@ -149,8 +151,8 @@ bool unlabelled_pattern_does_match( OpenDataflowSubgraphResult subgraph_result = subgraph_matched(graph, match); OpenDataflowGraphView matched_subgraph = subgraph_result.graph; - assert(keys(match.node_assignment) == get_nodes(pattern)); - assert(keys(match.node_assignment.reversed()) == get_nodes(matched_subgraph)); + assert(left_entries(match.node_assignment) == get_nodes(pattern)); + assert(right_entries(match.node_assignment) == get_nodes(matched_subgraph)); MatchAdditionalCriterion through_subgraph_operation = MatchAdditionalCriterion{ diff --git a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index db49e01611..84e0d91fee 100644 --- a/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -5,7 +5,10 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" @@ -29,8 +32,9 @@ std::unordered_set get_values(UnlabelledGraphPattern const &p) { pattern_value_from_raw_open_dataflow_value); } -std::unordered_set get_inputs(UnlabelledGraphPattern const &p) { - return transform(get_inputs(p.raw_graph), +std::unordered_set + get_graph_inputs(UnlabelledGraphPattern const &p) { + return transform(get_open_dataflow_graph_inputs(p.raw_graph), [](DataflowGraphInput const &i) { return PatternInput{i}; }); } diff --git a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc new file mode 100644 index 0000000000..70e960bc73 --- /dev/null +++ b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc @@ -0,0 +1,34 @@ +#include "substitutions/operator_pattern/get_attribute.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_attribute(LinearAttrs, OperatorAttributeKey)") { + int out_channels = 16; + bool use_bias = true; + std::optional activation = Activation::GELU; + std::optional regularizer = RegularizerAttrs{ + L1RegularizerAttrs{ + 0.5, + }, + }; + + LinearAttrs attrs = LinearAttrs{ + out_channels, + use_bias, + DataType::FLOAT, + activation, + regularizer, + }; + + SUBCASE("USE_BIAS") { + std::optional result = + get_attribute(attrs, OperatorAttributeKey::USE_BIAS); + std::optional correct = + OperatorAttributeValue{use_bias}; + CHECK(result == correct); + CHECK(result.value().has()); + } + } +} diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 8631d574f8..6922798a97 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -1,10 +1,12 @@ -#include "utils/containers/get_only.h" -#define DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#include "substitutions/pcg_pattern.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -#include "substitutions/pcg_pattern.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.h" #include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" #include "test/utils/doctest.h" +#include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" using namespace ::FlexFlow; @@ -79,18 +81,20 @@ TEST_SUITE(FF_TEST_SUITE) { OperatorAttributePattern, TensorAttributePattern>>(); - TensorAttributePattern pattern_tensor_a = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_b = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_c = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_x = TensorAttributePattern{{}}; - TensorAttributePattern pattern_tensor_y = TensorAttributePattern{{}}; - - OperatorAttributePattern op_pattern_1 = - OperatorAttributePattern{{OperatorAttributeConstraint{ - ConstraintType::EQUAL, - OperatorAttributeExpr{OperatorAttributeKey::OP_TYPE}, - OperatorAttributeValue{OperatorType::LINEAR}, - }}}; + TensorAttributePattern pattern_tensor_a = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_b = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_c = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_x = + tensor_attribute_pattern_match_all(); + TensorAttributePattern pattern_tensor_y = + tensor_attribute_pattern_match_all(); + + OperatorAttributePattern op_pattern_1 = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + }}; OperatorAttributePattern op_pattern_2 = op_pattern_1; @@ -116,42 +120,38 @@ TEST_SUITE(FF_TEST_SUITE) { PCGPattern pattern = PCGPattern{g}; - std::unordered_set result = - unordered_set_of( - find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); - - UnlabelledDataflowGraphPatternMatch match1 = - UnlabelledDataflowGraphPatternMatch{ - bidict{ - {op_pattern_1_node, x_matmul.raw_graph_node}, - {op_pattern_2_node, y_matmul.raw_graph_node}, - }, - bidict{ - {PatternInput{pt_a}, - OpenDataflowValue{a_tensor.raw_graph_output}}, - {PatternInput{pt_b}, - OpenDataflowValue{x_weights.raw_graph_output}}, - {PatternInput{pt_c}, - OpenDataflowValue{y_weights.raw_graph_output}}, - }}; - - UnlabelledDataflowGraphPatternMatch match2 = - UnlabelledDataflowGraphPatternMatch{ - bidict{ - {op_pattern_1_node, y_matmul.raw_graph_node}, - {op_pattern_2_node, x_matmul.raw_graph_node}, - }, - bidict{ - {PatternInput{pt_a}, - OpenDataflowValue{a_tensor.raw_graph_output}}, - {PatternInput{pt_b}, - OpenDataflowValue{y_weights.raw_graph_output}}, - {PatternInput{pt_c}, - OpenDataflowValue{x_weights.raw_graph_output}}, - }}; - - std::unordered_set correct = {match1, - match2}; + std::unordered_set result = unordered_set_of( + find_pattern_matches(pattern, sub_pcg_from_full_pcg(pcg))); + + PCGPatternMatch match1 = + PCGPatternMatch{bidict{ + {op_pattern_1_node, x_matmul}, + {op_pattern_2_node, y_matmul}, + }, + bidict{ + {PatternInput{pt_a}, + open_parallel_tensor_guid_from_closed(a_tensor)}, + {PatternInput{pt_b}, + open_parallel_tensor_guid_from_closed(x_weights)}, + {PatternInput{pt_c}, + open_parallel_tensor_guid_from_closed(y_weights)}, + }}; + + PCGPatternMatch match2 = + PCGPatternMatch{bidict{ + {op_pattern_1_node, y_matmul}, + {op_pattern_2_node, x_matmul}, + }, + bidict{ + {PatternInput{pt_a}, + open_parallel_tensor_guid_from_closed(a_tensor)}, + {PatternInput{pt_b}, + open_parallel_tensor_guid_from_closed(y_weights)}, + {PatternInput{pt_c}, + open_parallel_tensor_guid_from_closed(x_weights)}, + }}; + + std::unordered_set correct = {match1, match2}; CHECK(result == correct); } diff --git a/lib/substitutions/test/src/substitutions/substitution.cc b/lib/substitutions/test/src/substitutions/substitution.cc new file mode 100644 index 0000000000..87ffc01f0b --- /dev/null +++ b/lib/substitutions/test/src/substitutions/substitution.cc @@ -0,0 +1,229 @@ +#include "substitutions/substitution.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.h" +#include "substitutions/output_graph/output_graph_expr_node.dtg.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/pcg_pattern_builder.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + // TEST_CASE("is_valid_substitution") { + // FAIL("TODO"); + // } + + TEST_CASE("evaluate_substitution_output(SubParallelComputationGraph, " + "Substituion, PCGPatternMatch)") { + // Currently Substitution creation is very verbose. + // This is being addressed in + // https://github.com/flexflow/FlexFlow/issues/1473. + auto pattern_g = LabelledOpenDataflowGraph:: + create>(); + + PatternInput pattern_i_activation = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + PatternInput pattern_i_weights = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + + OperatorAttributePattern mm_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + op_attr_key_equals( + OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{std::optional{std::nullopt}}), + }}; + NodeAddedResult mm_added = pattern_g.add_node( + mm_pattern, + {OpenDataflowValue{pattern_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{pattern_i_weights.raw_dataflow_graph_input}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_mm_node = PatternNode{mm_added.node}; + DataflowOutput mm_output = get_only(mm_added.outputs); + + OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::RELU), + }}; + NodeAddedResult relu_added = + pattern_g.add_node(relu_pattern, + {OpenDataflowValue{mm_output}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_relu_node = PatternNode{relu_added.node}; + DataflowOutput relu_output = get_only(relu_added.outputs); + + LabelledOpenDataflowGraph + output_g = LabelledOpenDataflowGraph:: + create>(); + + OutputGraphExprInput output_i_activation = + OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput output_i_weights = + OutputGraphExprInput{output_g.add_input({})}; + + OutputOperatorAttrsAssignment fused_mm_relu_attrs_assignment = + OutputOperatorAttrsAssignment{{ + set_attr_to_constant(OperatorAttributeKey::OP_TYPE, + OperatorAttributeValue{OperatorType::LINEAR}), + copy_attr_from_pattern_node(OperatorAttributeKey::OUT_CHANNELS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::USE_BIAS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::DATA_TYPE, + pattern_mm_node), + set_attr_to_constant(OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{Activation::RELU}), + copy_attr_from_pattern_node(OperatorAttributeKey::REGULARIZER, + pattern_mm_node), + }}; + NodeAddedResult fused_mm_relu_added = output_g.add_node( + fused_mm_relu_attrs_assignment, + {OpenDataflowValue{output_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{output_i_weights.raw_dataflow_graph_input}}, + {{}}); + OutputGraphExprNode fused_mm_relu_node = + OutputGraphExprNode{fused_mm_relu_added.node}; + DataflowOutput fused_mm_relu_output = get_only(fused_mm_relu_added.outputs); + + Substitution sub = Substitution{ + PCGPattern{pattern_g}, + OutputGraphExpr{output_g}, + bidict{ + { + pattern_i_activation, + output_i_activation, + }, + { + pattern_i_weights, + output_i_weights, + }, + }, + bidict{ + { + PatternNodeOutput{relu_output}, + OutputGraphExprNodeOutput{fused_mm_relu_output}, + }, + }, + }; + + int in_channels = 24; + int batch_size = 4; + int batch_degree = 2; + std::string mm_match = "mm_match"; + std::string relu_match = "relu_match"; + + SubParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }); + t = b.dense(t, + /*outDim=*/16, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/mm_match); + t = b.relu(t, + /*name=*/relu_match); + t = b.dense(t, + /*outDim=*/8, + /*activation=*/Activation::RELU); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + PCGPatternMatch match = [&] { + parallel_layer_guid_t mm_match_layer = + get_parallel_layer_by_name(pcg, mm_match); + parallel_layer_guid_t relu_match_layer = + get_parallel_layer_by_name(pcg, relu_match); + open_parallel_tensor_guid_t mm_match_layer_input_activations = + get_layer_inputs(pcg, mm_match_layer).at(0); + open_parallel_tensor_guid_t mm_match_layer_input_weights = + get_layer_inputs(pcg, mm_match_layer).at(1); + + return PCGPatternMatch{ + bidict{ + {pattern_mm_node, mm_match_layer}, + {pattern_relu_node, relu_match_layer}, + }, + std::unordered_map{ + { + PatternInput{pattern_i_activation}, + mm_match_layer_input_activations, + }, + { + PatternInput{pattern_i_weights}, + mm_match_layer_input_weights, + }}, + }; + }(); + + SubParallelComputationGraph result = apply_substitution(pcg, sub, match); + + SubParallelComputationGraph correct = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }); + t = b.dense(t, + /*outDim=*/16, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12, + /*activation=*/Activation::RELU, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/std::nullopt); + t = b.dense(t, + /*outDim=*/8, + /*activation=*/Activation::RELU); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + // since the new nodes produced by the substitution have new ids, it's + // easier/more correct to check that the graphs are isomorphic rather than + // checking their exact graph data + CHECK(sub_pcgs_are_isomorphic(result, correct)); + } +} diff --git a/lib/substitutions/test/src/substitutions/substitution_internal/evaluate_substitution_output.cc b/lib/substitutions/test/src/substitutions/substitution_internal/evaluate_substitution_output.cc new file mode 100644 index 0000000000..52b54b32fb --- /dev/null +++ b/lib/substitutions/test/src/substitutions/substitution_internal/evaluate_substitution_output.cc @@ -0,0 +1,274 @@ +#include "substitutions/substitution_internal/evaluate_substitution_output.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/operator_pattern/operator_attribute_constraint.h" +#include "substitutions/output_graph/output_operator_attrs_assignment.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/tensor_pattern/tensor_attribute_pattern.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("evaluate_substitution_output") { + // Currently Substitution creation is very verbose. + // This is being addressed in + // https://github.com/flexflow/FlexFlow/issues/1473. + auto pattern_g = LabelledOpenDataflowGraph:: + create>(); + + PatternInput pattern_i_activation = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + PatternInput pattern_i_weights = + PatternInput{pattern_g.add_input(tensor_attribute_pattern_match_all())}; + + OperatorAttributePattern mm_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::LINEAR), + op_attr_key_equals( + OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{std::optional{std::nullopt}}), + }}; + NodeAddedResult mm_added = pattern_g.add_node( + mm_pattern, + {OpenDataflowValue{pattern_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{pattern_i_weights.raw_dataflow_graph_input}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_mm_node = PatternNode{mm_added.node}; + DataflowOutput mm_output = get_only(mm_added.outputs); + + OperatorAttributePattern relu_pattern = OperatorAttributePattern{{ + op_type_equals_constraint(OperatorType::RELU), + }}; + NodeAddedResult relu_added = + pattern_g.add_node(relu_pattern, + {OpenDataflowValue{mm_output}}, + {tensor_attribute_pattern_match_all()}); + PatternNode pattern_relu_node = PatternNode{relu_added.node}; + DataflowOutput relu_output = get_only(relu_added.outputs); + + LabelledOpenDataflowGraph + output_g = LabelledOpenDataflowGraph:: + create>(); + + OutputGraphExprInput output_i_activation = + OutputGraphExprInput{output_g.add_input({})}; + OutputGraphExprInput output_i_weights = + OutputGraphExprInput{output_g.add_input({})}; + + OutputOperatorAttrsAssignment fused_mm_relu_attrs_assignment = + OutputOperatorAttrsAssignment{{ + set_attr_to_constant(OperatorAttributeKey::OP_TYPE, + OperatorAttributeValue{OperatorType::LINEAR}), + copy_attr_from_pattern_node(OperatorAttributeKey::OUT_CHANNELS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::USE_BIAS, + pattern_mm_node), + copy_attr_from_pattern_node(OperatorAttributeKey::DATA_TYPE, + pattern_mm_node), + set_attr_to_constant(OperatorAttributeKey::ACTIVATION, + OperatorAttributeValue{Activation::RELU}), + copy_attr_from_pattern_node(OperatorAttributeKey::REGULARIZER, + pattern_mm_node), + }}; + NodeAddedResult fused_mm_relu_added = output_g.add_node( + fused_mm_relu_attrs_assignment, + {OpenDataflowValue{output_i_activation.raw_dataflow_graph_input}, + OpenDataflowValue{output_i_weights.raw_dataflow_graph_input}}, + {{}}); + OutputGraphExprNode fused_mm_relu_node = + OutputGraphExprNode{fused_mm_relu_added.node}; + DataflowOutput fused_mm_relu_output = get_only(fused_mm_relu_added.outputs); + + Substitution sub = Substitution{ + PCGPattern{pattern_g}, + OutputGraphExpr{output_g}, + bidict{ + { + pattern_i_activation, + output_i_activation, + }, + { + pattern_i_weights, + output_i_weights, + }, + }, + bidict{ + { + PatternNodeOutput{relu_output}, + OutputGraphExprNodeOutput{fused_mm_relu_output}, + }, + }, + }; + + int in_channels = 24; + int batch_size = 4; + int batch_degree = 2; + std::string mm_match = "mm_match"; + std::string relu_match = "relu_match"; + + SubParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + parallel_tensor_guid_t t = b.create_input_tensor(ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }); + t = b.dense(t, + /*outDim=*/16, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/mm_match); + t = b.relu(t, + /*name=*/relu_match); + t = b.dense(t, + /*outDim=*/8, + /*activation=*/Activation::RELU); + + return sub_pcg_from_full_pcg(b.pcg); + }(); + + parallel_layer_guid_t mm_match_layer = + get_parallel_layer_by_name(pcg, mm_match); + parallel_layer_guid_t relu_match_layer = + get_parallel_layer_by_name(pcg, relu_match); + open_parallel_tensor_guid_t mm_match_layer_input_activations = + get_layer_inputs(pcg, mm_match_layer).at(0); + open_parallel_tensor_guid_t mm_match_layer_input_weights = + get_layer_inputs(pcg, mm_match_layer).at(1); + + PCGPatternMatch match = PCGPatternMatch{ + bidict{ + {pattern_mm_node, mm_match_layer}, + {pattern_relu_node, relu_match_layer}, + }, + std::unordered_map{ + { + PatternInput{pattern_i_activation}, + mm_match_layer_input_activations, + }, + { + PatternInput{pattern_i_weights}, + mm_match_layer_input_weights, + }}, + }; + + SUBCASE("evaluate_substitution_output") { + std::pair + result = evaluate_substitution_output(pcg, sub, match); + + SubParallelComputationGraph result_graph = result.first; + bidict result_node_map = + result.second.node_mapping; + bidict + result_input_map = result.second.input_mapping; + + LinearAttrs correct_result_fused_mm_relu_attrs = LinearAttrs{ + 12, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + ParallelTensorAttrs correct_result_i_activation_attrs = + get_parallel_tensor_attrs(pcg, mm_match_layer_input_activations); + ParallelTensorAttrs correct_result_i_weights_attrs = + get_parallel_tensor_attrs(pcg, mm_match_layer_input_weights); + ParallelTensorAttrs correct_result_fused_mm_relu_output_attrs = + get_parallel_tensor_attrs( + pcg, + open_parallel_tensor_guid_from_closed( + get_only(get_layer_outputs(pcg, relu_match_layer)))); + + parallel_layer_guid_t result_fused_mm_relu_node = + result_node_map.at_r(fused_mm_relu_node); + parallel_tensor_guid_t result_fused_mm_relu_output = + get_only(get_layer_outputs(result_graph, result_fused_mm_relu_node)); + input_parallel_tensor_guid_t result_i_activation = + result_input_map.at_r(output_i_activation); + input_parallel_tensor_guid_t result_i_weights = + result_input_map.at_r(output_i_weights); + + SubParallelComputationGraphData correct_graph_data = + SubParallelComputationGraphData{ + std::unordered_map{{ + result_fused_mm_relu_node, + ParallelLayerAttrs{ + PCGOperatorAttrs{correct_result_fused_mm_relu_attrs}, + /*name=*/std::nullopt, + }, + }}, + std::unordered_set{ + SubParallelComputationGraphEdge{ + OpenDataflowEdge{ + DataflowInputEdge{ + result_i_activation.raw_dataflow_graph_input, + DataflowInput{ + result_fused_mm_relu_node.raw_graph_node, + 0, + }, + }, + }, + }, + SubParallelComputationGraphEdge{ + OpenDataflowEdge{ + DataflowInputEdge{ + result_i_weights.raw_dataflow_graph_input, + DataflowInput{ + result_fused_mm_relu_node.raw_graph_node, + 1, + }, + }, + }, + }, + }, + std::unordered_set{ + result_i_activation, + result_i_weights, + }, + std::unordered_map{ + { + open_parallel_tensor_guid_from_input(result_i_activation), + correct_result_i_activation_attrs, + }, + { + open_parallel_tensor_guid_from_input(result_i_weights), + correct_result_i_weights_attrs, + }, + { + open_parallel_tensor_guid_from_closed( + result_fused_mm_relu_output), + correct_result_fused_mm_relu_output_attrs, + }}}; + + SubParallelComputationGraphData result_graph_data = + get_sub_pcg_data(result_graph); + + CHECK(result_graph_data == correct_graph_data); + } + } +} diff --git a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc new file mode 100644 index 0000000000..0bf1c21e7f --- /dev/null +++ b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -0,0 +1,173 @@ +#include "substitutions/substitution_internal/perform_shape_inference.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("perform_shape_inference") { + auto g = + LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + int in_channels = 24; + int out_channels = 16; + int batch_size = 4; + int batch_degree = 2; + + DataflowGraphInput i0 = g.add_input({}); + ParallelTensorShape i0_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{size_t_from_int(batch_size), batch_degree}, + ShardParallelDim{size_t_from_int(in_channels), 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + bool use_bias = false; + LinearAttrs n1_op_attrs = LinearAttrs{ + out_channels, + use_bias, + DataType::FLOAT, + std::nullopt, + std::nullopt, + }; + ParallelLayerAttrs n1_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + n1_op_attrs, + }, + std::nullopt, + }; + + ElementUnaryAttrs n2_op_attrs = ElementUnaryAttrs{ + OperatorType::RELU, + std::nullopt, + }; + ParallelLayerAttrs n2_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + n2_op_attrs, + }, + std::nullopt, + }; + + ParallelTensorShape n1_output_shape = + throw_if_unexpected(get_output_shape(n1_op_attrs, i0_shape)); + ParallelTensorShape n1_weight_shape = + throw_if_unexpected(get_kernel_shape(n1_op_attrs, i0_shape)); + ParallelTensorShape n2_output_shape = + throw_if_unexpected(get_output_shape(n2_op_attrs, n1_output_shape)); + + ParallelLayerAttrs n1_weight_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + WeightAttrs{get_reduced_shape(n1_weight_shape)}, + }, + std::nullopt, + }; + + ParallelLayerAttrs n1_weight_replicate_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{batch_degree}, + }, + std::nullopt, + }; + + NodeAddedResult n1_weight_added_result = + g.add_node(n1_weight_attrs, {}, {{}}); + Node n1_weight_node = n1_weight_added_result.node; + DataflowOutput n1_weight = get_only(n1_weight_added_result.outputs); + + NodeAddedResult n1_weight_replicate_added_result = g.add_node( + n1_weight_replicate_attrs, {OpenDataflowValue{n1_weight}}, {{}}); + Node n1_weight_replicate_node = n1_weight_replicate_added_result.node; + DataflowOutput n1_weight_replicated = + get_only(n1_weight_replicate_added_result.outputs); + + NodeAddedResult n1_added_result = g.add_node( + n1_attrs, + {OpenDataflowValue{i0}, OpenDataflowValue{n1_weight_replicated}}, + {{}}); + Node n1 = n1_added_result.node; + DataflowOutput o1 = get_only(n1_added_result.outputs); + + NodeAddedResult n2_added_result = + g.add_node(n2_attrs, {OpenDataflowValue{o1}}, {{}}); + Node n2 = n2_added_result.node; + DataflowOutput o2 = get_only(n2_added_result.outputs); + + std::unordered_map input_shapes = { + {i0, i0_shape}, + }; + + LabelledOpenDataflowGraphView + result = perform_shape_inference(g, input_shapes); + + LabelledOpenDataflowGraphData + result_data = get_graph_data(result); + + LabelledOpenDataflowGraphData + correct_data = LabelledOpenDataflowGraphData{ + { + {n1, n1_attrs}, + {n2, n2_attrs}, + {n1_weight_node, n1_weight_attrs}, + {n1_weight_replicate_node, n1_weight_replicate_attrs}, + }, + { + OpenDataflowEdge{ + DataflowInputEdge{ + i0, + DataflowInput{n1, 0}, + }, + }, + OpenDataflowEdge{DataflowEdge{ + DataflowOutput{n1_weight_node, 0}, + DataflowInput{n1_weight_replicate_node, 0}, + }}, + OpenDataflowEdge{ + DataflowEdge{ + DataflowOutput{n1_weight_replicate_node, 0}, + DataflowInput{n1, 1}, + }, + }, + OpenDataflowEdge{DataflowEdge{ + DataflowOutput{n1, 0}, + DataflowInput{n2, 0}, + }}, + }, + {i0}, + {{ + OpenDataflowValue{i0}, + i0_shape, + }, + { + OpenDataflowValue{DataflowOutput{n1_weight_node, 0}}, + lift_to_parallel(get_reduced_shape(n1_weight_shape)), + }, + { + OpenDataflowValue{DataflowOutput{n1_weight_replicate_node, 0}}, + n1_weight_shape, + }, + { + OpenDataflowValue{DataflowOutput{n1, 0}}, + n1_output_shape, + }, + { + OpenDataflowValue{DataflowOutput{n2, 0}}, + n2_output_shape, + }}}; + + CHECK(result_data == correct_data); + } +} diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc index 341cb23c29..6621145d39 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc @@ -61,7 +61,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("full_pattern_values_to_subpattern_2_inputs") { bidict result = split_result.full_pattern_values_to_subpattern_2_inputs; - PatternInput i0 = get_only(get_inputs(split_result.subpattern_2)); + PatternInput i0 = get_only(get_graph_inputs(split_result.subpattern_2)); bidict correct = { {pv0, i0}, }; @@ -117,7 +117,7 @@ TEST_SUITE(FF_TEST_SUITE) { split_result.full_pattern_values_to_subpattern_1_inputs; bidict correct = { {PatternValue{pi0}, - get_only(get_inputs(split_result.subpattern_1))}, + get_only(get_graph_inputs(split_result.subpattern_1))}, }; CHECK(result == correct); } @@ -126,7 +126,7 @@ TEST_SUITE(FF_TEST_SUITE) { split_result.full_pattern_values_to_subpattern_2_inputs; bidict correct = { {PatternValue{pi1}, - get_only(get_inputs(split_result.subpattern_2))}, + get_only(get_graph_inputs(split_result.subpattern_2))}, }; CHECK(result == correct); } diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index b2f4103c6a..9478195523 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -7,7 +7,8 @@ #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h new file mode 100644 index 0000000000..86ef6c4b4d --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_enumerating.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_ENUMERATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_ENUMERATING_H + +#include "utils/bidict/bidict.h" +#include + +namespace FlexFlow { + +template +bidict bidict_from_enumerating(std::unordered_set const &s) { + bidict result; + int idx = 0; + for (T const &t : s) { + result.equate(idx, t); + idx++; + } + + return result; +} + +template +bidict bidict_from_enumerating(std::set const &s) { + bidict result; + int idx = 0; + for (T const &t : s) { + result.equate(idx, t); + idx++; + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h new file mode 100644 index 0000000000..47af03591a --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_keys_and_values.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_KEYS_AND_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_KEYS_AND_VALUES_H + +#include "utils/bidict/algorithms/bidict_from_pairs.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/zip.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict bidict_from_keys_and_values(std::vector const &ls, + std::vector const &rs) { + size_t l_size = ls.size(); + size_t r_size = rs.size(); + if (l_size != r_size) { + throw mk_runtime_error(fmt::format( + "recieved keys (of size {}) not matching values (of size {})", + l_size, + r_size)); + } + + return bidict_from_pairs(zip(ls, rs)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/bidict_from_pairs.h b/lib/utils/include/utils/bidict/algorithms/bidict_from_pairs.h new file mode 100644 index 0000000000..e33ab68f60 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/bidict_from_pairs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FROM_PAIRS_H + +#include "utils/bidict/bidict.h" + +namespace FlexFlow { + +template +bidict bidict_from_pairs(C const &c) { + return bidict{c.begin(), c.end()}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/left_entries.h b/lib/utils/include/utils/bidict/algorithms/left_entries.h new file mode 100644 index 0000000000..a3fab172b1 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/left_entries.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_LEFT_ENTRIES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_LEFT_ENTRIES_H + +#include "utils/bidict/bidict.h" +#include + +namespace FlexFlow { + +template +std::unordered_set left_entries(bidict const &b) { + std::unordered_set result; + for (auto const &[l, _] : b) { + result.insert(l); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h new file mode 100644 index 0000000000..d388e35d75 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H + +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" + +namespace FlexFlow { + +template +bidict merge_bidicts(bidict const &lhs, bidict const &rhs) { + assert(are_disjoint(left_entries(lhs), left_entries(rhs))); + assert(are_disjoint(right_entries(lhs), right_entries(rhs))); + + bidict result; + for (auto const &kv : lhs) { + result.equate(kv.first, kv.second); + } + for (auto const &kv : rhs) { + result.equate(kv.first, kv.second); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/right_entries.h b/lib/utils/include/utils/bidict/algorithms/right_entries.h new file mode 100644 index 0000000000..ec0e822c74 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/right_entries.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_RIGHT_ENTRIES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_RIGHT_ENTRIES_H + +#include "utils/bidict/bidict.h" +#include + +namespace FlexFlow { + +template +std::unordered_set right_entries(bidict const &b) { + std::unordered_set result; + for (auto const &[_, r] : b) { + result.insert(r); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index eaecb6e405..8b19313002 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H #include "utils/fmt/unordered_map.h" +#include "utils/hash/unordered_map.h" #include #include #include @@ -22,6 +23,10 @@ struct bidict { } } + bool contains(L const &l, R const &r) const { + return this->contains_l(l) && this->at_l(l) == r; + } + bool contains_l(L const &l) const { return fwd_map.find(l) != fwd_map.end(); } @@ -85,6 +90,10 @@ struct bidict { return fwd_map.size(); } + bool empty() const { + return this->size() == 0; + } + using const_iterator = typename std::unordered_map::const_iterator; using value_type = std::pair; using reference = value_type &; diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 81fdff8a40..937ed51af2 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -41,13 +41,6 @@ template std::unordered_map restrict_keys(std::unordered_map const &m, std::unordered_set const &mask); -template -std::unordered_map merge_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs); - -template -bidict merge_maps(bidict const &lhs, bidict const &rhs); - template std::optional at_idx(std::vector const &v, size_t idx); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 6164699f2e..7c0490fa2a 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -88,37 +88,6 @@ std::optional index_of(Container const &c, Element const &e) { } } -template -std::unordered_map merge_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); - - std::unordered_map result; - for (auto const &kv : lhs) { - result.insert(kv); - } - for (auto const &kv : rhs) { - result.insert(kv); - } - - return result; -} - -template -bidict merge_maps(bidict const &lhs, bidict const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); - - bidict result; - for (auto const &kv : lhs) { - result.equate(kv.first, kv.second); - } - for (auto const &kv : rhs) { - result.equate(kv.first, kv.second); - } - - return result; -} - template std::function lookup_in(std::unordered_map const &m) { return [&m](K const &k) -> V { return m.at(k); }; diff --git a/lib/utils/include/utils/containers/enumerate.h b/lib/utils/include/utils/containers/enumerate.h index c9c5f4e97b..e3722e52c6 100644 --- a/lib/utils/include/utils/containers/enumerate.h +++ b/lib/utils/include/utils/containers/enumerate.h @@ -1,25 +1,46 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H -#include "utils/bidict/bidict.h" #include "utils/containers/enumerate_vector.h" +#include #include +#include namespace FlexFlow { +/** + * @brief Generate a map from indices to elements of \p c. + * + * @note We return a std::map rather than a + * std::vector> for consistency + * with enumerate(FFOrdered const &). Note that std::map + * provides ordered iteration in increasing order, so iterating through + * the result of this function should still function as expected. + */ template -bidict enumerate(std::vector const &c) { +std::map enumerate(std::vector const &c) { return enumerate_vector(c); } +/** + * @brief Choose an arbitrary ordering of the elements of \p c and + * return a map from indices of this ordering to elements of \p c. + + * + * @note We return a std::map rather than a + * std::vector> for consistency + * with enumerate(FFOrdered const &). Note that std::map + * provides ordered iteration in increasing order, so iterating through + * the result of this function should still function as expected. + */ template -bidict enumerate(std::unordered_set const &c) { - bidict m; - size_t idx = 0; +std::map enumerate(std::unordered_set const &c) { + std::map result; + int idx = 0; for (auto const &v : c) { - m.equate(idx++, v); + result.insert({idx++, v}); } - return m; + return result; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h index 8d36a5fe3b..11ee8d2352 100644 --- a/lib/utils/include/utils/containers/enumerate_vector.h +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -1,16 +1,17 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H +#include #include #include namespace FlexFlow { template -std::vector> enumerate_vector(std::vector const &v) { - std::vector> result; +std::map enumerate_vector(std::vector const &v) { + std::map result; for (int i = 0; i < v.size(); i++) { - result.push_back({i, v.at(i)}); + result.insert({i, v.at(i)}); } return result; } diff --git a/lib/utils/include/utils/containers/filtrans.h b/lib/utils/include/utils/containers/filtrans.h new file mode 100644 index 0000000000..be1b5093c9 --- /dev/null +++ b/lib/utils/include/utils/containers/filtrans.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTRANS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FILTRANS_H + +#include "utils/type_traits_core.h" +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct unwrap_optional { + static_assert("T is not a std::optional!"); +}; + +template +struct unwrap_optional> : type_identity {}; + +template +using unwrap_optional_t = typename unwrap_optional::type; + +template >> +std::vector filtrans(std::vector const &v, F f) { + std::vector result; + + for (In const &i : v) { + std::optional o = f(i); + if (o.has_value()) { + result.push_back(o.value()); + } + } + + return result; +} + +template >> +std::unordered_set filtrans(std::unordered_set const &s, F f) { + std::unordered_set result; + + for (In const &i : s) { + std::optional o = f(i); + if (o.has_value()) { + result.insert(o.value()); + } + } + + return result; +} + +template >> +std::set filtrans(std::set const &s, F f) { + std::set result; + + for (In const &i : s) { + std::optional o = f(i); + if (o.has_value()) { + result.insert(o.value()); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_permutations.h b/lib/utils/include/utils/containers/get_all_permutations.h new file mode 100644 index 0000000000..b7e797dad2 --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_permutations.h @@ -0,0 +1,106 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_H + +#include "utils/containers/sorted.h" +#include +#include +#include + +namespace FlexFlow { + +template +struct permutations_container { +public: + template + permutations_container(It start, It end) : current(start, end) { + std::sort(this->current.begin(), this->current.end()); + } + + struct iterator { + public: + using difference_type = long; + using value_type = std::vector; + using pointer = std::vector const *; + using reference = std::vector const &; + using iterator_category = std::input_iterator_tag; + + public: + explicit iterator(permutations_container const &c, bool done) + : c(c), done(done) {} + + iterator &operator++() { + assert(!this->done); + + this->done = !std::next_permutation(this->c.current.begin(), + this->c.current.end()); + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + ++(*this); + return retval; + } + + bool operator==(iterator other) const { + return &this->c == &other.c && this->done == other.done; + } + + bool operator!=(iterator other) const { + return &this->c != &other.c || this->done != other.done; + } + + reference operator*() const { + return this->c.current; + } + + private: + permutations_container const &c; + bool done; + }; + + using const_iterator = iterator; + using value_type = typename iterator::value_type; + using difference_type = typename iterator::difference_type; + using pointer = typename iterator::pointer; + using reference = typename iterator::reference; + using const_reference = typename iterator::reference; + + iterator begin() const { + return iterator(*this, false); + } + + iterator end() const { + return iterator(*this, true); + } + + const_iterator cbegin() const { + return iterator(*this, false); + } + + const_iterator cend() const { + return iterator(*this, true); + } + +private: + mutable std::vector current; +}; + +/** + * @brief Lazily compute all permutations of the elements of in the input + * container. + * + * @note In cases where an element appears multiple times in the input + * (e.g., std::vector{1, 2, 2}), duplicate permutations are removed + * (i.e., {2, 1, 2} is only returned once, not twice), so it is + * possible for this function to return fewer than (but no more than) + * n! permutations. + */ +template +permutations_container get_all_permutations(C const &c) { + return permutations_container(c.cbegin(), c.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/keys.h b/lib/utils/include/utils/containers/keys.h index c1c8af54cc..e14612541e 100644 --- a/lib/utils/include/utils/containers/keys.h +++ b/lib/utils/include/utils/containers/keys.h @@ -1,13 +1,24 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_KEYS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_KEYS_H +#include +#include #include namespace FlexFlow { -template -std::unordered_set keys(C const &c) { - std::unordered_set result; +template +std::unordered_set keys(std::unordered_map const &c) { + std::unordered_set result; + for (auto const &kv : c) { + result.insert(kv.first); + } + return result; +} + +template +std::unordered_set keys(std::map const &c) { + std::unordered_set result; for (auto const &kv : c) { result.insert(kv.first); } diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h new file mode 100644 index 0000000000..653c9d24f1 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_MAPS_H + +#include "utils/containers/are_disjoint.h" +#include "utils/containers/keys.h" +#include + +namespace FlexFlow { + +template +std::unordered_map merge_maps(std::unordered_map const &lhs, + std::unordered_map const &rhs) { + assert(are_disjoint(keys(lhs), keys(rhs))); + + std::unordered_map result; + for (auto const &kv : lhs) { + result.insert(kv); + } + for (auto const &kv : rhs) { + result.insert(kv); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/set_union.h b/lib/utils/include/utils/containers/set_union.h index 0f5d6d5157..0f7b895f7a 100644 --- a/lib/utils/include/utils/containers/set_union.h +++ b/lib/utils/include/utils/containers/set_union.h @@ -16,7 +16,7 @@ std::unordered_set set_union(std::unordered_set const &l, template std::unordered_set set_union(C const &sets) { std::unordered_set result; - for (std::unordered_set const &s : sets) { + for (auto const &s : sets) { for (T const &element : s) { result.insert(element); } diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index c4e561f059..ec3d5f5612 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -5,6 +5,7 @@ #include "utils/required_core.h" #include #include +#include #include #include @@ -32,6 +33,17 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { return result; } +template ()(std::declval()))> +std::set transform(std::set const &v, F const &f) { + std::set result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + template std::string transform(std::string const &s, F const &f) { std::string result; diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot_file.h index 1fd9813646..214e6eeddc 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot_file.h @@ -94,13 +94,29 @@ class DotFile { this->get_ostream() << "}" << std::endl; } - void add_edge(T const &src, T const &dst) { + void add_edge(T const &src, + T const &dst, + std::optional const &src_field = std::nullopt, + std::optional const &dst_field = std::nullopt) { this->reserve_node(src); this->reserve_node(dst); - auto src_name = this->get_node_name(this->node_ids.at(src)); - auto dst_name = this->get_node_name(this->node_ids.at(dst)); - this->get_ostream() << " " << src_name << " -> " << dst_name << ";" - << std::endl; + + auto get_field_suffix = + [](std::optional const &field) -> std::string { + if (field.has_value()) { + return (":" + field.value()); + } else { + return ""; + } + }; + + std::string src_name = this->get_node_name(this->node_ids.at(src)); + + std::string dst_name = this->get_node_name(this->node_ids.at(dst)); + + this->get_ostream() << " " << src_name << get_field_suffix(src_field) + << " -> " << dst_name << get_field_suffix(dst_field) + << ";" << std::endl; } void close() { for (size_t subgraph = 0; subgraph < this->subgraph_id; subgraph++) { diff --git a/lib/utils/include/utils/exception.decl.h b/lib/utils/include/utils/exception.decl.h deleted file mode 100644 index 93c450294b..0000000000 --- a/lib/utils/include/utils/exception.decl.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_DECL_H - -#include "utils/fmt.decl.h" -#include -#include - -namespace FlexFlow { - -#ifdef FF_REQUIRE_IMPLEMENTED -#define NOT_IMPLEMENTED() \ - static_assert(false, \ - "Function " __FUNC__ " not yet implemented " __FILE__ \ - ":" __LINE__); -#else -#define NOT_IMPLEMENTED() \ - throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); -#endif - -class not_implemented : public std::logic_error { -public: - not_implemented(std::string const &function_name, - std::string const &file_name, - int line); -}; - -template -T throw_if_unexpected(tl::expected const &r); - -template -std::runtime_error mk_runtime_error(fmt::format_string fmt_str, - T &&...args); -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index a00d2dba2b..20a8098040 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -1,13 +1,30 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_H #define _FLEXFLOW_UTILS_INCLUDE_EXCEPTION_H -#include "utils/exception.decl.h" #include "utils/fmt.h" +#include #include #include namespace FlexFlow { +#ifdef FF_REQUIRE_IMPLEMENTED +#define NOT_IMPLEMENTED() \ + static_assert(false, \ + "Function " __FUNC__ " not yet implemented " __FILE__ \ + ":" __LINE__); +#else +#define NOT_IMPLEMENTED() \ + throw not_implemented(__PRETTY_FUNCTION__, __FILE__, __LINE__); +#endif + +class not_implemented : public std::logic_error { +public: + not_implemented(std::string const &function_name, + std::string const &file_name, + int line); +}; + template T throw_if_unexpected(tl::expected const &r) { if (r.has_value()) { diff --git a/lib/utils/include/utils/fmt.decl.h b/lib/utils/include/utils/fmt.decl.h deleted file mode 100644 index 26193ae416..0000000000 --- a/lib/utils/include/utils/fmt.decl.h +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_FMT_DECL_H - -#include "fmt/format.h" -#include "utils/check_fmtable.h" -#include -#include -#include -#include - -#define DELEGATE_OSTREAM(...) \ - template <> \ - struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} - -namespace FlexFlow { - -template -struct delegate_ostream_operator : std::false_type {}; - -template -typename std::enable_if>::value, - std::ostream &>::type - operator<<(std::ostream &s, T); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index f1d4a9f2d9..ee008f7bfe 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -1,16 +1,24 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_FMT_H #define _FLEXFLOW_UTILS_INCLUDE_FMT_H -#include "utils/fmt.decl.h" +#include "utils/check_fmtable.h" #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include #include #include #include #include +#define DELEGATE_OSTREAM(...) \ + template <> \ + struct delegate_ostream_operator<__VA_ARGS__> : std::true_type {} + namespace FlexFlow { +template +struct delegate_ostream_operator : std::false_type {}; + template typename std::enable_if>::value, std::ostream &>::type diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h index db868a59f4..d50facee57 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms.h @@ -7,9 +7,10 @@ namespace FlexFlow { std::unordered_set get_edges(DataflowGraphView const &); -std::vector get_incoming_edges(DataflowGraphView const &, +std::vector get_input_values(DataflowGraphView const &, Node const &); -std::vector get_inputs(DataflowGraphView const &, Node const &); +std::vector get_dataflow_inputs(DataflowGraphView const &, + Node const &); std::vector get_outputs(DataflowGraphView const &, Node const &); std::unordered_set diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h new file mode 100644 index 0000000000..6c9626ce00 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/as_dot.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H + +#include "utils/dot_file.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::string as_dot(DataflowGraphView const &); +void as_dot(DotFile &, + DataflowGraphView const &, + std::function const &get_node_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml new file mode 100644 index 0000000000..082c25f6ea --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DataflowGraphIsomorphism" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..914f8553dc --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_ARE_ISOMORPHIC_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +bool dataflow_graphs_are_isomorphic(DataflowGraphView const &, + DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..de78f9bec3 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" +#include + +namespace FlexFlow { + +/** + * @brief Find a valid isomorphism between \p src and \p dst, if one exists + * + * @note If multiple isomorphisms exist, an arbitrary one is returned + */ +std::optional + find_isomorphism(DataflowGraphView const &, DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphisms.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphisms.h new file mode 100644 index 0000000000..dda69ea69a --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/find_isomorphisms.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + find_isomorphisms(DataflowGraphView const &, DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_incoming_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_incoming_edges.h new file mode 100644 index 0000000000..a4cd27bf9d --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_incoming_edges.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::vector get_incoming_edges(DataflowGraphView const &, + Node const &); +std::unordered_set + get_incoming_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h new file mode 100644 index 0000000000..f26ea20473 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h index febec3d14d..b1bade4254 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_edge_query.h @@ -10,6 +10,9 @@ DataflowEdgeQuery dataflow_edge_query_all(); DataflowEdgeQuery dataflow_edge_query_none(); bool dataflow_edge_query_includes_dataflow_edge(DataflowEdgeQuery const &, DataflowEdge const &); +DataflowEdgeQuery dataflow_edge_query_for_edge(DataflowEdge const &); +DataflowEdgeQuery dataflow_edge_query_all_outgoing_from(DataflowOutput const &); +DataflowEdgeQuery dataflow_edge_query_all_incoming_to(DataflowInput const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h index 7974c033c3..6a1898dd13 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_graph.h @@ -29,8 +29,7 @@ struct DataflowGraph : virtual public DataflowGraphView { } template - static typename std::enable_if::value, - DataflowGraph>::type + static std::enable_if_t, DataflowGraph> create_copy_of(DataflowGraphView const &view) { cow_ptr_t impl = make_cow_ptr(); impl.get_mutable()->inplace_materialize_from(view); diff --git a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h index 7ed54a5c27..fc1a222f1e 100644 --- a/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h +++ b/lib/utils/include/utils/graph/dataflow_graph/dataflow_output_query.h @@ -10,6 +10,10 @@ DataflowOutputQuery dataflow_output_query_all(); DataflowOutputQuery dataflow_output_query_none(); bool dataflow_output_query_includes_dataflow_output(DataflowOutputQuery const &, DataflowOutput const &); +DataflowOutputQuery dataflow_output_query_for_output(DataflowOutput const &); +std::unordered_set + apply_dataflow_output_query(DataflowOutputQuery const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index ad1b5f3bf5..f1063c1f21 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -16,6 +16,10 @@ #include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_source.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" @@ -109,11 +113,11 @@ struct UnorderedSetLabelledOpenDataflowGraph final return this->inputs; } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->nodes.at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const override { + ValueLabel at(OpenDataflowValue const &v) const override { return this->values.at(v); } @@ -136,6 +140,26 @@ struct UnorderedSetLabelledOpenDataflowGraph final }); } + virtual void inplace_materialize_from( + LabelledOpenDataflowGraphView const &view) + override { + + std::unordered_map nodes = generate_map( + get_nodes(view), [&](Node const &n) { return view.at(n); }); + std::unordered_set edges = get_edges(view); + std::unordered_set inputs = + ::FlexFlow::get_open_dataflow_graph_inputs(view); + + std::unordered_map values = + generate_map(get_open_dataflow_values(view), + [&](OpenDataflowValue const &v) { return view.at(v); }); + + this->inputs = inputs; + this->nodes = nodes; + this->edges = edges; + this->values = values; + } + UnorderedSetLabelledOpenDataflowGraph *clone() const override { return new UnorderedSetLabelledOpenDataflowGraph{ this->node_source, diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..2d4e6b11e9 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +std::optional find_isomorphism( + LabelledDataflowGraphView const &src, + LabelledDataflowGraphView const &dst) { + std::optional open_isomorphism = + find_isomorphism(view_as_labelled_open_dataflow_graph(src), + view_as_labelled_open_dataflow_graph(dst)); + + return transform(open_isomorphism, + [](OpenDataflowGraphIsomorphism const &open) { + assert(open.input_mapping.empty()); + return DataflowGraphIsomorphism{open.node_mapping}; + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h new file mode 100644 index 0000000000..4f42653380 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H + +#include "utils/graph/dataflow_graph/algorithms/dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +bool is_isomorphic_under( + LabelledDataflowGraphView const &src, + LabelledDataflowGraphView const &dst, + DataflowGraphIsomorphism const &candidate_isomorphism) { + return is_isomorphic_under(view_as_labelled_open_dataflow_graph(src), + view_as_labelled_open_dataflow_graph(dst), + OpenDataflowGraphIsomorphism{ + candidate_isomorphism.node_mapping, + {}, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..d399c5fcdb --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H + +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" + +namespace FlexFlow { + +template +bool labelled_dataflow_graphs_are_isomorphic( + LabelledDataflowGraph const &src, + LabelledDataflowGraph const &dst) { + return find_isomorphism(src, dst).has_value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h index 13e75efdd6..f1cdfd9690 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h @@ -33,11 +33,11 @@ struct LabelledDataflowGraphAsOpenView final return {}; } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->g.at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const override { + ValueLabel at(OpenDataflowValue const &v) const override { return this->g.at(v.get()); } diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h index 9f0fc0f30d..f7bbbd9964 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph_view.h @@ -8,8 +8,8 @@ namespace FlexFlow { template struct ILabelledDataflowGraphView : virtual public IDataflowGraphView { public: - virtual NodeLabel const &at(Node const &) const = 0; - virtual OutputLabel const &at(DataflowOutput const &) const = 0; + virtual NodeLabel at(Node const &) const = 0; + virtual OutputLabel at(DataflowOutput const &) const = 0; virtual ~ILabelledDataflowGraphView() = default; }; diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h index a6a6b9d061..61e0677061 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/labelled_dataflow_graph_view.h @@ -16,10 +16,10 @@ struct LabelledDataflowGraphView : virtual public DataflowGraphView { LabelledDataflowGraphView & operator=(LabelledDataflowGraphView const &) = default; - NodeLabel const &at(Node const &n) const { + NodeLabel at(Node const &n) const { return this->get_interface().at(n); } - OutputLabel const &at(DataflowOutput const &o) const { + OutputLabel at(DataflowOutput const &o) const { return this->get_interface().at(o); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h new file mode 100644 index 0000000000..6faddcdfcb --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" + +namespace FlexFlow { + +template +std::string as_dot( + LabelledOpenDataflowGraphView const &g, + std::function const &get_node_label, + std::function const &get_input_label) { + std::function unlabelled_get_node_label = + [&](Node const &n) -> std::string { return get_node_label(g.at(n)); }; + + std::function + unlabelled_get_input_label = [&](DataflowGraphInput const &i) { + return get_input_label(g.at(OpenDataflowValue{i})); + }; + + return as_dot(static_cast(g), + unlabelled_get_node_label, + unlabelled_get_input_label); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..a1d6e9e37a --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/containers/as_vector.h" +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/zip.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +/** + * @brief Finds an isomorphism between \p src and \p dst, if one exists. + * + * @note If multiple isomorphisms exist, an arbitrary one is returned. + */ +template +std::optional find_isomorphism( + LabelledOpenDataflowGraphView const &src, + LabelledOpenDataflowGraphView const &dst) { + std::unordered_set unlabelled_isomorphisms = + find_isomorphisms(static_cast(src), + static_cast(dst)); + + for (OpenDataflowGraphIsomorphism const &candidate_isomorphism : + unlabelled_isomorphisms) { + if (is_isomorphic_under(src, dst, candidate_isomorphism)) { + return candidate_isomorphism; + } + } + + return std::nullopt; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h new file mode 100644 index 0000000000..106d500464 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_LABELLED_OPEN_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_LABELLED_OPEN_DATAFLOW_GRAPH_DATA_H + +#include "utils/containers/filtrans.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphView + from_labelled_open_dataflow_graph_data( + LabelledOpenDataflowGraphData const &data) { + std::unordered_set values = keys(data.value_data); + std::unordered_set outputs = + filtrans(values, try_get_dataflow_output); + + OpenDataflowGraphData unlabelled_data = OpenDataflowGraphData{ + keys(data.node_data), + data.edges, + data.inputs, + outputs, + }; + + return with_labelling(from_open_dataflow_graph_data(unlabelled_data), + data.node_data, + data.value_data); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h new file mode 100644 index 0000000000..ec8f025ac3 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.dtg.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphData get_graph_data( + LabelledOpenDataflowGraphView const &g) { + + std::unordered_map node_data = + generate_map(get_nodes(g), [&](Node const &n) { return g.at(n); }); + + std::unordered_set edges = get_edges(g); + + std::unordered_set inputs = g.get_inputs(); + + std::unordered_map value_data = + generate_map(get_open_dataflow_values(g), + [&](OpenDataflowValue const &v) { return g.at(v); }); + + return LabelledOpenDataflowGraphData{ + node_data, + edges, + inputs, + value_data, + }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h new file mode 100644 index 0000000000..ecf9c22143 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" + +namespace FlexFlow { + +template +bool is_isomorphic_under( + LabelledOpenDataflowGraphView const &src, + LabelledOpenDataflowGraphView const &dst, + OpenDataflowGraphIsomorphism const &candidate_isomorphism) { + + bidict node_permutation = + map_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { + return NewNode{dst_node}; + }).reversed(); + bidict input_permutation = + map_values(candidate_isomorphism.input_mapping, + [](DataflowGraphInput const &dst_input) { + return NewDataflowGraphInput{dst_input}; + }) + .reversed(); + return get_graph_data(permute_input_ids( + permute_node_ids(src, node_permutation), input_permutation)) == + get_graph_data(dst); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml new file mode 100644 index 0000000000..082b61e691 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graph_data.struct.toml @@ -0,0 +1,41 @@ +namespace = "FlexFlow" +name = "LabelledOpenDataflowGraphData" +features = [ + "eq", + "hash", + "fmt", +] + +template_params = ["NodeLabel", "ValueLabel"] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h", + "", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/hash/unordered_set.h", + "utils/fmt/unordered_map.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "node_data" +type = "std::unordered_map<::FlexFlow::Node, NodeLabel>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "value_data" +type = "std::unordered_map<::FlexFlow::OpenDataflowValue, ValueLabel>" diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..b3a71235cc --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_LABELLED_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" + +namespace FlexFlow { + +template +bool labelled_open_dataflow_graphs_are_isomorphic( + LabelledOpenDataflowGraphView const &lhs, + LabelledOpenDataflowGraphView const &rhs) { + return find_isomorphism(lhs, rhs).has_value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h new file mode 100644 index 0000000000..88132e0a79 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h @@ -0,0 +1,46 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H + +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphView permute_input_ids( + LabelledOpenDataflowGraphView const &g, + bidict const &input_mapping) { + + OpenDataflowGraphView permuted = + permute_input_ids(static_cast(g), input_mapping); + + auto old_value_from_new = [&](OpenDataflowValue const &new_value) { + return new_value.visit(overload{ + [](DataflowOutput const &o) { return OpenDataflowValue{o}; }, + [&](DataflowGraphInput const &new_i) { + return OpenDataflowValue{ + input_mapping.at_l(NewDataflowGraphInput{new_i}), + }; + }, + }); + }; + + std::unordered_map node_labels = + generate_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); + + std::unordered_map value_labels = + generate_map(get_open_dataflow_values(permuted), + [&](OpenDataflowValue const &new_value) { + return g.at(old_value_from_new(new_value)); + }); + + return with_labelling(permuted, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h new file mode 100644 index 0000000000..2d1dd03755 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h @@ -0,0 +1,54 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +LabelledOpenDataflowGraphView permute_node_ids( + LabelledOpenDataflowGraphView const &g, + bidict const &new_node_tofrom_old_node) { + OpenDataflowGraphView permuted = permute_node_ids( + static_cast(g), new_node_tofrom_old_node); + + auto old_node_from_new = [&](Node const &new_node) { + return new_node_tofrom_old_node.at_l(NewNode{new_node}); + }; + + auto old_value_from_new = [&](OpenDataflowValue const &new_value) { + return new_value.visit(overload{ + [&](DataflowOutput const &new_o) { + return OpenDataflowValue{ + DataflowOutput{ + old_node_from_new(new_o.node), + new_o.idx, + }, + }; + }, + [](DataflowGraphInput const &i) { return OpenDataflowValue{i}; }, + }); + }; + + std::unordered_map node_labels = + generate_map(get_nodes(permuted), [&](Node const &new_node) { + return g.at(old_node_from_new(new_node)); + }); + + std::unordered_map value_labels = + generate_map(get_open_dataflow_values(permuted), + [&](OpenDataflowValue const &new_value) { + return g.at(old_value_from_new(new_value)); + }); + + return with_labelling(permuted, node_labels, value_labels); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h index 2849bfa72f..92938d7142 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h @@ -4,7 +4,7 @@ #include "utils/containers/generate_map.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h new file mode 100644 index 0000000000..eb39c4fe6a --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template > +LabelledOpenDataflowGraphView rewrite_node_labels( + LabelledOpenDataflowGraphView const &g, F f) { + return rewrite_labels( + g, + overload{ + [&](Node const &n, NodeLabel const &l) { return f(n, l); }, + [](OpenDataflowValue const &v, ValueLabel const &l) { return l; }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h new file mode 100644 index 0000000000..c0582d8e3d --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_VALUE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_VALUE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template < + typename NodeLabel, + typename ValueLabel, + typename F, + typename NewValueLabel = + std::invoke_result_t> +LabelledOpenDataflowGraphView rewrite_value_labels( + LabelledOpenDataflowGraphView const &g, F f) { + return rewrite_labels(g, + overload{ + [](Node const &n, NodeLabel const &l) { return l; }, + [&](OpenDataflowValue const &v, + ValueLabel const &l) { return f(v, l); }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h index e95781af6e..3697ab0f93 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h @@ -36,11 +36,11 @@ struct OpenDataflowGraphLabellingWrapper final return this->unlabelled.get_inputs(); } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->node_labels.at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const override { + ValueLabel at(OpenDataflowValue const &v) const override { return this->value_labels.at(v); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h index a4a3fc0bea..01777909cd 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph.h @@ -4,6 +4,7 @@ #include "utils/graph/dataflow_graph/node_added_result.dtg.h" #include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" namespace FlexFlow { @@ -18,6 +19,9 @@ struct ILabelledOpenDataflowGraph virtual DataflowGraphInput add_input(ValueLabel const &value_label) = 0; + virtual void inplace_materialize_from( + LabelledOpenDataflowGraphView const &) = 0; + // NodeAddedResult add_node(NodeLabel const &node_label, // std::vector const &inputs, // std::vector const &output_labels) diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h index 58137704e6..a59ce72896 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/i_labelled_open_dataflow_graph_view.h @@ -12,10 +12,10 @@ struct ILabelledOpenDataflowGraphView : virtual public ILabelledDataflowGraphView, virtual public IOpenDataflowGraphView { public: - virtual NodeLabel const &at(Node const &) const override = 0; - virtual ValueLabel const &at(OpenDataflowValue const &) const = 0; + virtual NodeLabel at(Node const &) const override = 0; + virtual ValueLabel at(OpenDataflowValue const &) const = 0; - ValueLabel const &at(DataflowOutput const &o) const override final { + ValueLabel at(DataflowOutput const &o) const override final { return this->at(OpenDataflowValue{o}); } diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h index 76877e245a..375e40d5ea 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h @@ -34,6 +34,16 @@ struct LabelledOpenDataflowGraph return LabelledOpenDataflowGraph(make_cow_ptr()); } + template + static std::enable_if_t, + LabelledOpenDataflowGraph> + create_copy_of( + LabelledOpenDataflowGraphView const &view) { + cow_ptr_t impl = make_cow_ptr(); + impl.get_mutable()->inplace_materialize_from(view); + return LabelledOpenDataflowGraph(std::move(impl)); + } + protected: using LabelledOpenDataflowGraphView:: LabelledOpenDataflowGraphView; diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h index 6e08b10a29..935f615ec8 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h @@ -20,11 +20,11 @@ struct LabelledOpenDataflowGraphView LabelledOpenDataflowGraphView & operator=(LabelledOpenDataflowGraphView const &) = default; - NodeLabel const &at(Node const &n) const { + NodeLabel at(Node const &n) const { return this->get_interface().at(n); } - ValueLabel const &at(OpenDataflowValue const &v) const { + ValueLabel at(OpenDataflowValue const &v) const { return this->get_interface().at(v); } diff --git a/lib/utils/include/utils/graph/node/algorithms/generate_new_node_id_permutation.h b/lib/utils/include/utils/graph/node/algorithms/generate_new_node_id_permutation.h new file mode 100644 index 0000000000..42ae867883 --- /dev/null +++ b/lib/utils/include/utils/graph/node/algorithms/generate_new_node_id_permutation.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_ALGORITHMS_GENERATE_NEW_NODE_ID_PERMUTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_NODE_ALGORITHMS_GENERATE_NEW_NODE_ID_PERMUTATION_H + +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/node/graph_view.h" + +namespace FlexFlow { + +bidict generate_new_node_id_permutation(GraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml b/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml new file mode 100644 index 0000000000..f3b8244573 --- /dev/null +++ b/lib/utils/include/utils/graph/node/algorithms/new_node.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "NewNode" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", +] + +[[fields]] +name = "raw_node" +type = "::FlexFlow::Node" diff --git a/lib/utils/include/utils/graph/node/node_query.h b/lib/utils/include/utils/graph/node/node_query.h index b7d754ceac..2ec8958083 100644 --- a/lib/utils/include/utils/graph/node/node_query.h +++ b/lib/utils/include/utils/graph/node/node_query.h @@ -8,6 +8,8 @@ namespace FlexFlow { NodeQuery node_query_all(); NodeQuery query_intersection(NodeQuery const &, NodeQuery const &); NodeQuery query_union(NodeQuery const &, NodeQuery const &); +std::unordered_set apply_node_query(NodeQuery const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h new file mode 100644 index 0000000000..4c600637aa --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/as_dot.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_AS_DOT_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::string as_dot(OpenDataflowGraphView const &); +std::string + as_dot(OpenDataflowGraphView const &, + std::function const &get_node_label, + std::function const + &get_input_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h new file mode 100644 index 0000000000..4c1ec38b89 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +/** + * @brief Find a valid isomorphism between \p src and \p dst, if one exists + * + * @note If multiple isomorphisms exist, an arbitrary one is returned + */ +std::optional + find_isomorphism(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h new file mode 100644 index 0000000000..022fc5b9fd --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISMS_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + find_isomorphisms(OpenDataflowGraphView const &, + OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h new file mode 100644 index 0000000000..1fbbea21b0 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_OPEN_DATAFLOW_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FROM_OPEN_DATAFLOW_GRAPH_DATA_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +struct FromOpenDataflowGraphDataView final + : virtual public IOpenDataflowGraphView { + FromOpenDataflowGraphDataView(OpenDataflowGraphData const &); + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(OpenDataflowEdgeQuery const &) const override; + std::unordered_set + query_outputs(DataflowOutputQuery const &) const override; + std::unordered_set get_inputs() const override; + + FromOpenDataflowGraphDataView *clone() const override; + +private: + OpenDataflowGraphData data; +}; + +OpenDataflowGraphView + from_open_dataflow_graph_data(OpenDataflowGraphData const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h new file mode 100644 index 0000000000..803b5c849b --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GENERATE_NEW_INPUT_ID_PERMUTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GENERATE_NEW_INPUT_ID_PERMUTATION_H + +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +bidict + generate_new_input_id_permutation(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_edges.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_edges.h new file mode 100644 index 0000000000..0710b3d970 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_edges.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_EDGES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_graph_data.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_graph_data.h new file mode 100644 index 0000000000..6bb4f123df --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_graph_data.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_GRAPH_DATA_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowGraphData get_graph_data(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h new file mode 100644 index 0000000000..84e0f57e3d --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGE_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowEdge get_incoming_edge(OpenDataflowGraphView const &, + DataflowInput const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h similarity index 53% rename from lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h rename to lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h index 9ba22394b2..22d66a0c0f 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h @@ -1,23 +1,17 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_EDGES_H #include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" namespace FlexFlow { -std::unordered_set get_edges(OpenDataflowGraphView const &); -std::unordered_set - get_inputs(OpenDataflowGraphView const &); -std::vector get_inputs(OpenDataflowGraphView const &, - Node const &); +std::unordered_set + get_incoming_edges(OpenDataflowGraphView const &); std::vector get_incoming_edges(OpenDataflowGraphView const &, Node const &); std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &, std::unordered_set const &); -std::unordered_set - get_open_dataflow_values(OpenDataflowGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_inputs.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_inputs.h new file mode 100644 index 0000000000..ae596010f8 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_inputs.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_INPUTS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +std::vector get_inputs(OpenDataflowGraphView const &, + Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h new file mode 100644 index 0000000000..98231c8f8c --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_GRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_GRAPH_INPUTS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_graph_inputs(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h new file mode 100644 index 0000000000..bd7749a172 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUE_USES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUE_USES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_value_uses(OpenDataflowGraphView const &view, + OpenDataflowValue const &value); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h new file mode 100644 index 0000000000..5d8f58540e --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_OPEN_DATAFLOW_VALUES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_values(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h new file mode 100644 index 0000000000..a89b4e1bc1 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SOURCE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SOURCE_NODES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_source_nodes(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h new file mode 100644 index 0000000000..0df5f8458c --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(OpenDataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h new file mode 100644 index 0000000000..2325dcfbda --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_UNUSED_OPEN_DATAFLOW_GRAPH_INPUTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_GET_UNUSED_OPEN_DATAFLOW_GRAPH_INPUTS_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_unused_open_dataflow_graph_inputs(OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h new file mode 100644 index 0000000000..9ee5ac0790 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_IS_ISOMORPHIC_UNDER_H + +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +bool is_isomorphic_under(OpenDataflowGraphView const &, + OpenDataflowGraphView const &, + OpenDataflowGraphIsomorphism const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml new file mode 100644 index 0000000000..76b062e211 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "NewDataflowGraphInput" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "raw_input" +type = "::FlexFlow::DataflowGraphInput" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml new file mode 100644 index 0000000000..467ca73b3f --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.struct.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "OpenDataflowGraphData" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", + "utils/graph/dataflow_graph/dataflow_output.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "nodes" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "edges" +type = "std::unordered_set<::FlexFlow::OpenDataflowEdge>" + +[[fields]] +name = "inputs" +type = "std::unordered_set<::FlexFlow::DataflowGraphInput>" + +[[fields]] +name = "outputs" +type = "std::unordered_set<::FlexFlow::DataflowOutput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml new file mode 100644 index 0000000000..bafe3c7117 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_isomorphism.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "OpenDataflowGraphIsomorphism" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/bidict/bidict.h", + "utils/graph/node/node.dtg.h", + "utils/graph/open_dataflow_graph/dataflow_graph_input.dtg.h", +] + +[[fields]] +name = "node_mapping" +type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::Node>" + +[[fields]] +name = "input_mapping" +type = "::FlexFlow::bidict<::FlexFlow::DataflowGraphInput, ::FlexFlow::DataflowGraphInput>" diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h new file mode 100644 index 0000000000..6e27e55802 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_OPEN_DATAFLOW_GRAPHS_ARE_ISOMORPHIC_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +bool open_dataflow_graphs_are_isomorphic(OpenDataflowGraphView const &, + OpenDataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h new file mode 100644 index 0000000000..36add91574 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_INPUT_IDS_H + +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowGraphView permute_input_ids( + OpenDataflowGraphView const &, + bidict const &input_mapping); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h new file mode 100644 index 0000000000..64293383c6 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H + +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +OpenDataflowGraphView + permute_node_ids(OpenDataflowGraphView const &, + bidict const &new_node_tofrom_old_node); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h index 1189757c0e..78099fec57 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/dataflow_input_edge_query.h @@ -10,6 +10,12 @@ DataflowInputEdgeQuery dataflow_input_edge_query_all(); DataflowInputEdgeQuery dataflow_input_edge_query_none(); bool dataflow_input_edge_query_includes(DataflowInputEdgeQuery const &, DataflowInputEdge const &); +DataflowInputEdgeQuery + dataflow_input_edge_query_for_edge(DataflowInputEdge const &); +DataflowInputEdgeQuery + dataflow_input_edge_query_all_outgoing_from(DataflowGraphInput const &); +DataflowInputEdgeQuery + dataflow_input_edge_query_all_incoming_to(DataflowInput const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h index 3289ea48ae..09499f8e5f 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge.h @@ -8,7 +8,8 @@ namespace FlexFlow { Node get_open_dataflow_edge_dst_node(OpenDataflowEdge const &); int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &); -OpenDataflowValue get_open_dataflow_edge_source(OpenDataflowEdge const &); +DataflowInput get_open_dataflow_edge_dst(OpenDataflowEdge const &); +OpenDataflowValue get_open_dataflow_edge_src(OpenDataflowEdge const &); OpenDataflowEdge open_dataflow_edge_from_src_and_dst(OpenDataflowValue const &src, DataflowInput const &dst); diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h index 46630a2625..ae6e30549b 100644 --- a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_edge_query.h @@ -3,6 +3,7 @@ #include "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" namespace FlexFlow { @@ -10,6 +11,13 @@ OpenDataflowEdgeQuery open_dataflow_edge_query_all(); OpenDataflowEdgeQuery open_dataflow_edge_query_none(); bool open_dataflow_edge_query_includes(OpenDataflowEdgeQuery const &q, OpenDataflowEdge const &); +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_outgoing_from(OpenDataflowValue const &); +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_incoming_to(DataflowInput const &); +std::unordered_set apply_open_dataflow_edge_query( + OpenDataflowEdgeQuery const &, + std::unordered_set const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.h b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.h new file mode 100644 index 0000000000..d106205a07 --- /dev/null +++ b/lib/utils/include/utils/graph/open_dataflow_graph/open_dataflow_value.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_VALUE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_OPEN_DATAFLOW_GRAPH_OPEN_DATAFLOW_VALUE_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_value.dtg.h" +#include + +namespace FlexFlow { + +std::optional + try_get_dataflow_output(OpenDataflowValue const &); +std::optional + try_get_dataflow_graph_input(OpenDataflowValue const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/required_core.h b/lib/utils/include/utils/required_core.h index 76f03549a4..7a7abcd2c4 100644 --- a/lib/utils/include/utils/required_core.h +++ b/lib/utils/include/utils/required_core.h @@ -1,10 +1,10 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_CORE_H -#include "fmt.decl.h" #include "hash-utils.h" #include "test_types.h" #include "type_traits_core.h" +#include #include namespace FlexFlow { @@ -191,7 +191,9 @@ template using req = required; template -struct delegate_ostream_operator> : std::true_type {}; +std::ostream &operator<<(std::ostream &s, required const &t) { + return (s << fmt::to_string(t)); +} template struct remove_req { diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index 71c369df6a..afc16d4c4b 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_TUPLE_H #define _FLEXFLOW_UTILS_TUPLE_H -#include "utils/exception.decl.h" +#include "utils/exception.h" #include "utils/type_traits_core.h" #include #include diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_from_enumerating.cc b/lib/utils/src/utils/bidict/algorithms/bidict_from_enumerating.cc new file mode 100644 index 0000000000..350f08600c --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/bidict_from_enumerating.h" diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_from_keys_and_values.cc b/lib/utils/src/utils/bidict/algorithms/bidict_from_keys_and_values.cc new file mode 100644 index 0000000000..34562f40c1 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_from_keys_and_values.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_from_pairs.cc b/lib/utils/src/utils/bidict/algorithms/bidict_from_pairs.cc new file mode 100644 index 0000000000..c8a27b8143 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_from_pairs.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/bidict_from_pairs.h" diff --git a/lib/utils/src/utils/bidict/algorithms/left_entries.cc b/lib/utils/src/utils/bidict/algorithms/left_entries.cc new file mode 100644 index 0000000000..a2c19de124 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/left_entries.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/left_entries.h" diff --git a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc new file mode 100644 index 0000000000..f70be2355f --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/merge_bidicts.h" diff --git a/lib/utils/src/utils/bidict/algorithms/right_entries.cc b/lib/utils/src/utils/bidict/algorithms/right_entries.cc new file mode 100644 index 0000000000..2f517a0af6 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/right_entries.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/right_entries.h" diff --git a/lib/utils/src/utils/containers/filtrans.cc b/lib/utils/src/utils/containers/filtrans.cc new file mode 100644 index 0000000000..a57a743ef0 --- /dev/null +++ b/lib/utils/src/utils/containers/filtrans.cc @@ -0,0 +1 @@ +#include "utils/containers/filtrans.h" diff --git a/lib/utils/src/utils/containers/get_all_permutations.cc b/lib/utils/src/utils/containers/get_all_permutations.cc new file mode 100644 index 0000000000..0fa4e16f08 --- /dev/null +++ b/lib/utils/src/utils/containers/get_all_permutations.cc @@ -0,0 +1 @@ +#include "utils/containers/get_all_permutations.h" diff --git a/lib/utils/src/utils/containers/merge_maps.cc b/lib/utils/src/utils/containers/merge_maps.cc new file mode 100644 index 0000000000..a36217fbeb --- /dev/null +++ b/lib/utils/src/utils/containers/merge_maps.cc @@ -0,0 +1 @@ +#include "utils/containers/merge_maps.h" diff --git a/lib/utils/src/utils/exception.cc b/lib/utils/src/utils/exception.cc new file mode 100644 index 0000000000..9bbf780fd8 --- /dev/null +++ b/lib/utils/src/utils/exception.cc @@ -0,0 +1 @@ +#include "utils/exception.h" diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc index 64af07636a..f0e52d6fc2 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms.cc @@ -1,6 +1,7 @@ #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/containers/sorted_by.h" #include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_output_query.h" @@ -10,23 +11,16 @@ std::unordered_set get_edges(DataflowGraphView const &g) { return g.query_edges(dataflow_edge_query_all()); } -std::vector get_incoming_edges(DataflowGraphView const &g, +std::vector get_input_values(DataflowGraphView const &g, Node const &n) { - return sorted_by(g.query_edges(DataflowEdgeQuery{ - query_set::matchall(), - query_set::matchall(), - {n}, - query_set::matchall(), - }), - [](DataflowEdge const &l, DataflowEdge const &r) { - return l.dst.idx < r.dst.idx; - }); + return transform(get_incoming_edges(g, n), + [](DataflowEdge const &e) { return e.src; }); } -std::vector get_inputs(DataflowGraphView const &g, - Node const &n) { +std::vector get_dataflow_inputs(DataflowGraphView const &g, + Node const &n) { return transform(get_incoming_edges(g, n), - [](DataflowEdge const &e) { return e.src; }); + [](DataflowEdge const &e) { return e.dst; }); } std::vector get_outputs(DataflowGraphView const &g, diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc new file mode 100644 index 0000000000..47c30ce998 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/as_dot.cc @@ -0,0 +1,64 @@ +#include "utils/graph/dataflow_graph/algorithms/as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +// WARN(@lockshaw): doing this all with string ids is ugly and error prone, +// as it requires duplicating the stringification logic across functions. +// +// Fixing this is tracked in issue +std::string as_dot(DataflowGraphView const &g) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + std::function get_node_label = + [](Node const &n) -> std::string { + return fmt::format("n{}", n.raw_uid); + }; + as_dot(dot, g, get_node_label); + + dot.close(); + return oss.str(); +} + +void as_dot(DotFile &dot, + DataflowGraphView const &g, + std::function const &get_node_label) { + auto get_node_name = [](Node n) { return fmt::format("n{}", n.raw_uid); }; + + auto get_input_field = [](int idx) { return fmt::format("i{}", idx); }; + + auto get_output_field = [](int idx) { return fmt::format("o{}", idx); }; + + for (Node const &n : get_nodes(g)) { + std::vector n_inputs = get_dataflow_inputs(g, n); + std::vector n_outputs = get_outputs(g, n); + + RecordFormatter inputs_record; + for (DataflowInput const &i : n_inputs) { + inputs_record << fmt::format("<{}>{}", get_input_field(i.idx), i.idx); + } + + RecordFormatter outputs_record; + for (DataflowOutput const &o : n_outputs) { + outputs_record << fmt::format("<{}>{}", get_output_field(o.idx), o.idx); + } + + RecordFormatter rec; + rec << inputs_record << get_node_label(n) << outputs_record; + + dot.add_record_node(get_node_name(n), rec); + } + + for (DataflowEdge const &e : get_edges(g)) { + dot.add_edge(get_node_name(e.src.node), + get_node_name(e.dst.node), + get_output_field(e.src.idx), + get_input_field(e.dst.idx)); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..ac7f9967be --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,11 @@ +#include "utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h" +#include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" + +namespace FlexFlow { + +bool dataflow_graphs_are_isomorphic(DataflowGraphView const &src, + DataflowGraphView const &dst) { + return find_isomorphism(src, dst).has_value(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..d06a64597e --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,20 @@ +#include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_first.h" +#include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" + +namespace FlexFlow { + +std::optional + find_isomorphism(DataflowGraphView const &src, + DataflowGraphView const &dst) { + std::unordered_set all_isomorphisms = + find_isomorphisms(src, dst); + + if (all_isomorphisms.empty()) { + return std::nullopt; + } else { + return get_first(all_isomorphisms); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphisms.cc new file mode 100644 index 0000000000..0e0210e5a2 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphisms.cc @@ -0,0 +1,22 @@ +#include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" + +namespace FlexFlow { + +std::unordered_set + find_isomorphisms(DataflowGraphView const &src, + DataflowGraphView const &dst) { + std::unordered_set open_isomorphisms = + find_isomorphisms(view_as_open_dataflow_graph(src), + view_as_open_dataflow_graph(dst)); + + return transform(open_isomorphisms, + [](OpenDataflowGraphIsomorphism const &open) { + assert(open.input_mapping.empty()); + return DataflowGraphIsomorphism{open.node_mapping}; + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc new file mode 100644 index 0000000000..9500836db1 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_incoming_edges.cc @@ -0,0 +1,31 @@ +#include "utils/graph/dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/containers/sorted_by.h" + +namespace FlexFlow { + +std::vector get_incoming_edges(DataflowGraphView const &g, + Node const &n) { + return sorted_by(g.query_edges(DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + {n}, + query_set::matchall(), + }), + [](DataflowEdge const &l, DataflowEdge const &r) { + return l.dst.idx < r.dst.idx; + }); +} + +std::unordered_set + get_incoming_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + DataflowEdgeQuery query = DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }; + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..c442a26dab --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + + std::unordered_set all_nodes = get_nodes(g); + query_set dst_query = query_set{set_minus(all_nodes, ns)}; + + DataflowEdgeQuery query = DataflowEdgeQuery{ + query_set{ns}, + query_set::matchall(), + dst_query, + query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.cc new file mode 100644 index 0000000000..0fd0b85b71 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.cc @@ -0,0 +1,41 @@ +#include "utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +ViewDataflowGraphAsOpen::ViewDataflowGraphAsOpen(DataflowGraphView const &g) + : g(g) {} + +std::unordered_set + ViewDataflowGraphAsOpen::query_nodes(NodeQuery const &q) const { + return this->g.query_nodes(q); +} + +std::unordered_set + ViewDataflowGraphAsOpen::query_edges(OpenDataflowEdgeQuery const &q) const { + std::unordered_set closed_edges = + this->g.query_edges(q.standard_edge_query); + + return transform(closed_edges, + [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }); +} + +std::unordered_set + ViewDataflowGraphAsOpen::query_outputs(DataflowOutputQuery const &q) const { + return this->g.query_outputs(q); +} + +std::unordered_set + ViewDataflowGraphAsOpen::get_inputs() const { + return {}; +} + +ViewDataflowGraphAsOpen *ViewDataflowGraphAsOpen::clone() const { + return new ViewDataflowGraphAsOpen{this->g}; +} + +OpenDataflowGraphView view_as_open_dataflow_graph(DataflowGraphView const &g) { + return OpenDataflowGraphView::create(g); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h new file mode 100644 index 0000000000..bec9d0e019 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/view_as_open_dataflow_graph.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_OPEN_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_SRC_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_VIEW_AS_OPEN_DATAFLOW_GRAPH_H + +#include "utils/graph/open_dataflow_graph/open_dataflow_graph_view.h" + +namespace FlexFlow { + +struct ViewDataflowGraphAsOpen final : public IOpenDataflowGraphView { +public: + ViewDataflowGraphAsOpen() = delete; + ViewDataflowGraphAsOpen(DataflowGraphView const &); + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(OpenDataflowEdgeQuery const &) const override; + std::unordered_set + query_outputs(DataflowOutputQuery const &) const override; + std::unordered_set get_inputs() const override; + + ViewDataflowGraphAsOpen *clone() const override; + + ~ViewDataflowGraphAsOpen() = default; + +private: + DataflowGraphView g; +}; + +OpenDataflowGraphView view_as_open_dataflow_graph(DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc index e30dc41c1f..2196f7a028 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_edge_query.cc @@ -27,4 +27,33 @@ bool dataflow_edge_query_includes_dataflow_edge(DataflowEdgeQuery const &q, includes(q.dst_idxs, e.dst.idx); } +DataflowEdgeQuery dataflow_edge_query_for_edge(DataflowEdge const &e) { + return DataflowEdgeQuery{ + query_set{e.src.node}, + query_set{e.src.idx}, + query_set{e.dst.node}, + query_set{e.dst.idx}, + }; +} + +DataflowEdgeQuery + dataflow_edge_query_all_outgoing_from(DataflowOutput const &src) { + return DataflowEdgeQuery{ + query_set{src.node}, + query_set{src.idx}, + query_set::matchall(), + query_set::matchall(), + }; +} + +DataflowEdgeQuery + dataflow_edge_query_all_incoming_to(DataflowInput const &dst) { + return DataflowEdgeQuery{ + query_set::matchall(), + query_set::matchall(), + query_set{dst.node}, + query_set{dst.idx}, + }; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc index b8d89a250d..64df4c77f2 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/dataflow_output_query.cc @@ -21,4 +21,19 @@ bool dataflow_output_query_includes_dataflow_output( return includes(q.nodes, o.node) && includes(q.output_idxs, o.idx); } +DataflowOutputQuery dataflow_output_query_for_output(DataflowOutput const &o) { + return DataflowOutputQuery{ + query_set{o.node}, + query_set{o.idx}, + }; +} + +std::unordered_set + apply_dataflow_output_query(DataflowOutputQuery const &q, + std::unordered_set const &os) { + return filter(os, [&](DataflowOutput const &o) { + return dataflow_output_query_includes_dataflow_output(q, o); + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..88ec6d141a --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..9fa68e58b2 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc new file mode 100644 index 0000000000..78dbed5262 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..c53cd4cd15 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.cc new file mode 100644 index 0000000000..49d3a663d9 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/from_labelled_open_dataflow_graph_data.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.cc new file mode 100644 index 0000000000..854f55732f --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..32a6da0bb5 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..1dcbbdc1e6 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.cc new file mode 100644 index 0000000000..2a5fe55809 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.cc new file mode 100644 index 0000000000..d2252d91e9 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" diff --git a/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.cc b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.cc new file mode 100644 index 0000000000..655988fb28 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_value_labels.h" diff --git a/lib/utils/src/utils/graph/node/algorithms/generate_new_node_id_permutation.cc b/lib/utils/src/utils/graph/node/algorithms/generate_new_node_id_permutation.cc new file mode 100644 index 0000000000..256dccd185 --- /dev/null +++ b/lib/utils/src/utils/graph/node/algorithms/generate_new_node_id_permutation.cc @@ -0,0 +1,16 @@ +#include "utils/graph/node/algorithms/generate_new_node_id_permutation.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node_source.h" + +namespace FlexFlow { + +bidict generate_new_node_id_permutation(GraphView const &g) { + NodeSource node_source; + return generate_bidict( + get_nodes(g), + [&](Node const &) { return NewNode{node_source.new_node()}; }) + .reversed(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/node/node_query.cc b/lib/utils/src/utils/graph/node/node_query.cc index c74457465c..834086a733 100644 --- a/lib/utils/src/utils/graph/node/node_query.cc +++ b/lib/utils/src/utils/graph/node/node_query.cc @@ -28,4 +28,9 @@ NodeQuery query_union(NodeQuery const &lhs, NodeQuery const &rhs) { NOT_IMPLEMENTED(); } +std::unordered_set apply_node_query(NodeQuery const &query, + std::unordered_set const &ns) { + return apply_query(query.nodes, ns); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc new file mode 100644 index 0000000000..9077ea5f9a --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/as_dot.cc @@ -0,0 +1,63 @@ +#include "utils/graph/open_dataflow_graph/algorithms/as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/as_dot.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" + +namespace FlexFlow { + +std::string as_dot(OpenDataflowGraphView const &g) { + std::function get_node_label = [](Node const &n) { + return fmt::format("n{}", n.raw_uid); + }; + + std::function get_input_label = + [](DataflowGraphInput const &i) { return fmt::format("i{}", i.idx); }; + + return as_dot(g, get_node_label, get_input_label); +} + +// WARN(@lockshaw): doing this all with string ids is ugly and error prone, +// as it requires duplicating the stringification logic across functions. +// +// Fixing this is tracked in issue +// https://github.com/flexflow/FlexFlow/issues/1476 +std::string + as_dot(OpenDataflowGraphView const &g, + std::function const &get_node_label, + std::function const + &get_input_label) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + as_dot(dot, static_cast(g), get_node_label); + + auto get_node_name = [](Node n) { return fmt::format("n{}", n.raw_uid); }; + + auto get_input_field = [](int idx) { return fmt::format("i{}", idx); }; + + auto get_output_field = [](int idx) { return fmt::format("o{}", idx); }; + + auto get_graph_input_name = [](DataflowGraphInput i) { + return fmt::format("gi{}", i.idx); + }; + + for (DataflowGraphInput const &i : get_open_dataflow_graph_inputs(g)) { + dot.add_node(get_graph_input_name(i), + {{"style", "dashed"}, {"label", get_input_label(i)}}); + } + + for (DataflowInputEdge const &e : get_incoming_edges(g)) { + dot.add_edge(get_graph_input_name(e.src), + get_node_name(e.dst.node), + std::nullopt, + get_input_field(e.dst.idx)); + } + + dot.close(); + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..d622497629 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,20 @@ +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_first.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" + +namespace FlexFlow { + +std::optional + find_isomorphism(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst) { + std::unordered_set all_isomorphisms = + find_isomorphisms(src, dst); + + if (all_isomorphisms.empty()) { + return std::nullopt; + } else { + return get_first(all_isomorphisms); + } +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc new file mode 100644 index 0000000000..d95a9b9565 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -0,0 +1,248 @@ +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" +#include "utils/bidict/algorithms/bidict_from_keys_and_values.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/get_first.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/values.h" +#include "utils/containers/zip.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include + +namespace FlexFlow { + +static std::optional + find_isomorphism_under_sink_node_mapping( + OpenDataflowGraphView const &src_g, + OpenDataflowGraphView const &dst_g, + bidict const &sink_node_mapping, + bidict const + &unused_graph_inputs_mapping) { + { + std::unordered_set already_mapped_src_nodes = + left_entries(sink_node_mapping); + std::unordered_set src_g_sink_nodes = get_sinks(src_g); + assert(already_mapped_src_nodes == src_g_sink_nodes); + } + + { + std::unordered_set already_mapped_dst_nodes = + right_entries(sink_node_mapping); + std::unordered_set dst_g_sink_nodes = get_sinks(dst_g); + assert(already_mapped_dst_nodes == dst_g_sink_nodes); + } + + { + std::unordered_set already_mapped_src_inputs = + right_entries(unused_graph_inputs_mapping); + std::unordered_set src_g_unused_inputs = + get_unused_open_dataflow_graph_inputs(src_g); + assert(already_mapped_src_inputs == src_g_unused_inputs); + } + + { + std::unordered_set already_mapped_dst_inputs = + right_entries(unused_graph_inputs_mapping); + std::unordered_set dst_g_unused_inputs = + get_unused_open_dataflow_graph_inputs(dst_g); + assert(already_mapped_dst_inputs == dst_g_unused_inputs); + } + + std::optional result = + OpenDataflowGraphIsomorphism{ + {}, + unused_graph_inputs_mapping, + }; + + auto fail = [&]() -> void { result = std::nullopt; }; + + auto has_failed = [&]() -> bool { return result == std::nullopt; }; + + std::function unify_nodes; + std::function + unify_edges; + std::function + unify_graph_inputs; + std::function + unify_values; + std::function + unify_outputs; + + unify_outputs = [&](DataflowOutput const &src_output, + DataflowOutput const &dst_output) { + if (has_failed()) { + return; + } + + if (src_output.idx != dst_output.idx) { + result = std::nullopt; + return; + } + + unify_nodes(src_output.node, dst_output.node); + }; + + unify_values = [&](OpenDataflowValue const &src_val, + OpenDataflowValue const &dst_val) { + if (has_failed()) { + return; + } + + if (src_val.index() != dst_val.index()) { + fail(); + return; + } + + if (src_val.has()) { + unify_outputs(src_val.get(), + dst_val.get()); + } else { + unify_graph_inputs(src_val.get(), + dst_val.get()); + } + }; + + unify_graph_inputs = [&](DataflowGraphInput const &src, + DataflowGraphInput const &dst) { + if (has_failed()) { + return; + } + + if (result->input_mapping.contains_l(src) && + result->input_mapping.at_l(src) != dst) { + fail(); + return; + } + if (result->input_mapping.contains_r(dst) && + result->input_mapping.at_r(dst) != src) { + fail(); + return; + } + + result->input_mapping.equate(src, dst); + }; + + unify_edges = [&](OpenDataflowEdge const &src_edge, + OpenDataflowEdge const &dst_edge) { + if (has_failed()) { + return; + } + + assert(get_open_dataflow_edge_dst(src_edge).idx == + get_open_dataflow_edge_dst(dst_edge).idx); + assert( + get_open_dataflow_edge_dst(src_edge).node == + result->node_mapping.at_r(get_open_dataflow_edge_dst(dst_edge).node)); + + unify_values(get_open_dataflow_edge_src(src_edge), + get_open_dataflow_edge_src(dst_edge)); + }; + + unify_nodes = [&](Node const &src_node, Node const &dst_node) { + if (has_failed()) { + return; + } + + if (result->node_mapping.contains(src_node, dst_node)) { + return; + } + + if (result->node_mapping.contains_l(src_node) && + result->node_mapping.at_l(src_node) != dst_node) { + fail(); + return; + } + if (result->node_mapping.contains_r(dst_node) && + result->node_mapping.at_r(dst_node) != src_node) { + fail(); + return; + } + + result->node_mapping.equate(src_node, dst_node); + + std::vector src_incoming_edges = + get_incoming_edges(src_g, src_node); + std::vector dst_incoming_edges = + get_incoming_edges(dst_g, dst_node); + + if (src_incoming_edges.size() != dst_incoming_edges.size()) { + fail(); + return; + } + + for (auto const &[src_edge, dst_edge] : + zip(src_incoming_edges, dst_incoming_edges)) { + unify_edges(src_edge, dst_edge); + } + }; + + for (auto const &[src_node, dst_node] : sink_node_mapping) { + unify_nodes(src_node, dst_node); + } + + return result; +} + +std::unordered_set + find_isomorphisms(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst) { + std::unordered_set result; + + std::vector src_sink_nodes = as_vector(get_sinks(src)); + std::unordered_set dst_sink_nodes = get_sinks(dst); + + if (src_sink_nodes.size() != dst_sink_nodes.size()) { + return {}; + } + + std::vector src_unused_graph_inputs = + as_vector(get_unused_open_dataflow_graph_inputs(src)); + std::unordered_set dst_unused_graph_inputs = + get_unused_open_dataflow_graph_inputs(dst); + + if (src_unused_graph_inputs.size() != dst_unused_graph_inputs.size()) { + return {}; + } + + for (std::vector const &dst_sink_nodes : + get_all_permutations(dst_sink_nodes)) { + + bidict sink_node_mapping = + bidict_from_keys_and_values(src_sink_nodes, dst_sink_nodes); + + for (std::vector const &dst_unused_graph_inputs : + get_all_permutations(dst_unused_graph_inputs)) { + + bidict + unused_graph_inputs_mapping = bidict_from_keys_and_values( + src_unused_graph_inputs, dst_unused_graph_inputs); + + std::optional found = + find_isomorphism_under_sink_node_mapping( + src, dst, sink_node_mapping, unused_graph_inputs_mapping); + + if (found.has_value()) { + assert(is_isomorphic_under(src, dst, found.value())); + + result.insert(found.value()); + } + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.cc new file mode 100644 index 0000000000..c4b5befcbc --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.cc @@ -0,0 +1,41 @@ +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +FromOpenDataflowGraphDataView::FromOpenDataflowGraphDataView( + OpenDataflowGraphData const &data) + : data(data) {} + +std::unordered_set + FromOpenDataflowGraphDataView::query_nodes(NodeQuery const &q) const { + return apply_node_query(q, this->data.nodes); +} + +std::unordered_set FromOpenDataflowGraphDataView::query_edges( + OpenDataflowEdgeQuery const &q) const { + return apply_open_dataflow_edge_query(q, this->data.edges); +} + +std::unordered_set FromOpenDataflowGraphDataView::query_outputs( + DataflowOutputQuery const &q) const { + return apply_dataflow_output_query(q, this->data.outputs); +} + +std::unordered_set + FromOpenDataflowGraphDataView::get_inputs() const { + return this->data.inputs; +} + +FromOpenDataflowGraphDataView *FromOpenDataflowGraphDataView::clone() const { + return new FromOpenDataflowGraphDataView{this->data}; +} + +OpenDataflowGraphView + from_open_dataflow_graph_data(OpenDataflowGraphData const &data) { + return OpenDataflowGraphView::create(data); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.cc new file mode 100644 index 0000000000..7d9a3e3a0e --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.cc @@ -0,0 +1,19 @@ +#include "utils/graph/open_dataflow_graph/algorithms/generate_new_input_id_permutation.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" + +namespace FlexFlow { + +bidict + generate_new_input_id_permutation(OpenDataflowGraphView const &g) { + DataflowGraphInputSource input_source; + return generate_bidict(get_open_dataflow_graph_inputs(g), + [&](DataflowGraphInput const &) { + return NewDataflowGraphInput{ + input_source.new_dataflow_graph_input()}; + }) + .reversed(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_edges.cc new file mode 100644 index 0000000000..610239feff --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_edges.cc @@ -0,0 +1,10 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(OpenDataflowGraphView const &g) { + return g.query_edges(open_dataflow_edge_query_all()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_graph_data.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_graph_data.cc new file mode 100644 index 0000000000..3199be92f9 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_graph_data.cc @@ -0,0 +1,17 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_edges.h" + +namespace FlexFlow { + +OpenDataflowGraphData get_graph_data(OpenDataflowGraphView const &g) { + return OpenDataflowGraphData{ + get_nodes(g), + get_edges(g), + g.get_inputs(), + get_all_dataflow_outputs(g), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.cc new file mode 100644 index 0000000000..ac1aae1168 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.cc @@ -0,0 +1,15 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edge.h" +#include "utils/containers/get_only.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +OpenDataflowEdge get_incoming_edge(OpenDataflowGraphView const &g, + DataflowInput const &i) { + OpenDataflowEdgeQuery query = open_dataflow_edge_query_all_incoming_to(i); + std::unordered_set query_result = g.query_edges(query); + + return get_only(query_result); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc similarity index 55% rename from lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc rename to lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc index ff5451d239..cad00c71e1 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc @@ -1,28 +1,23 @@ -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/containers/generate_map.h" -#include "utils/containers/group_by.h" #include "utils/containers/sorted_by.h" #include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" -#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" namespace FlexFlow { -std::unordered_set get_edges(OpenDataflowGraphView const &g) { - return g.query_edges(open_dataflow_edge_query_all()); -} - -std::unordered_set - get_inputs(OpenDataflowGraphView const &g) { - return g.get_inputs(); -} +std::unordered_set + get_incoming_edges(OpenDataflowGraphView const &g) { + std::unordered_set raw_edges = + g.query_edges(OpenDataflowEdgeQuery{ + dataflow_input_edge_query_all(), + dataflow_edge_query_none(), + }); -std::vector get_inputs(OpenDataflowGraphView const &g, - Node const &n) { - return transform(get_incoming_edges(g, n), [](OpenDataflowEdge const &e) { - return get_open_dataflow_edge_source(e); + return transform(raw_edges, [](OpenDataflowEdge const &e) { + return e.get(); }); } @@ -54,14 +49,4 @@ std::unordered_map> [&](Node const &n) { return get_incoming_edges(g, n); }); } -std::unordered_set - get_open_dataflow_values(OpenDataflowGraphView const &g) { - return set_union( - transform( - unordered_set_of(g.get_inputs()), - [](DataflowGraphInput const &gi) { return OpenDataflowValue{gi}; }), - transform(get_all_dataflow_outputs(g), - [](DataflowOutput const &o) { return OpenDataflowValue{o}; })); -} - } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_inputs.cc new file mode 100644 index 0000000000..f4e23e04f4 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_inputs.cc @@ -0,0 +1,15 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" +#include "utils/containers/transform.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" + +namespace FlexFlow { + +std::vector get_inputs(OpenDataflowGraphView const &g, + Node const &n) { + return transform(get_incoming_edges(g, n), [](OpenDataflowEdge const &e) { + return get_open_dataflow_edge_src(e); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..78c7677de9 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc @@ -0,0 +1,10 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_graph_inputs(OpenDataflowGraphView const &g) { + return g.get_inputs(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc new file mode 100644 index 0000000000..12795b8f7e --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc @@ -0,0 +1,17 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" +#include "utils/containers/transform.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_value_uses(OpenDataflowGraphView const &view, + OpenDataflowValue const &value) { + std::unordered_set edges = + view.query_edges(open_dataflow_edge_query_all_outgoing_from(value)); + + return transform(edges, get_open_dataflow_edge_dst); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.cc new file mode 100644 index 0000000000..0aa1bdb054 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.cc @@ -0,0 +1,17 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_open_dataflow_values(OpenDataflowGraphView const &g) { + return set_union( + transform( + unordered_set_of(g.get_inputs()), + [](DataflowGraphInput const &gi) { return OpenDataflowValue{gi}; }), + transform(get_all_dataflow_outputs(g), + [](DataflowOutput const &o) { return OpenDataflowValue{o}; })); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.cc new file mode 100644 index 0000000000..14099e1c64 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_source_nodes.cc @@ -0,0 +1,16 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_source_nodes.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" + +namespace FlexFlow { + +std::unordered_set get_source_nodes(OpenDataflowGraphView const &g) { + auto is_source_node = [&](Node const &n) { + std::vector incoming_edges = get_incoming_edges(g, n); + return incoming_edges.empty(); + }; + + return filter(get_nodes(g), is_source_node); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..95a8e095fc --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,29 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(OpenDataflowGraphView const &g, + std::unordered_set const &ns) { + std::unordered_set nodes_not_in_ns = set_minus(get_nodes(g), ns); + + OpenDataflowEdgeQuery query = OpenDataflowEdgeQuery{ + DataflowInputEdgeQuery{ + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }, + DataflowEdgeQuery{ + query_set{nodes_not_in_ns}, + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }, + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index 8c07f4bfdb..4ade34941c 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -3,7 +3,7 @@ #include "utils/containers/extend.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" -#include "utils/graph/open_dataflow_graph/algorithms.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/overload.h" @@ -28,7 +28,7 @@ std::unordered_set extend(relevant_edges, filter(incoming, comes_from_outside_subgraph)); } - return transform(relevant_edges, get_open_dataflow_edge_source); + return transform(relevant_edges, get_open_dataflow_edge_src); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..8fbe7ae5bc --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc @@ -0,0 +1,15 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" + +namespace FlexFlow { + +std::unordered_set + get_unused_open_dataflow_graph_inputs(OpenDataflowGraphView const &g) { + return filter( + get_open_dataflow_graph_inputs(g), [&](DataflowGraphInput const &i) { + return get_open_dataflow_value_uses(g, OpenDataflowValue{i}).empty(); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..77e23d9c87 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1,30 @@ +#include "utils/graph/open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/graph/node/algorithms/new_node.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/new_dataflow_graph_input.dtg.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" + +namespace FlexFlow { + +bool is_isomorphic_under( + OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst, + OpenDataflowGraphIsomorphism const &candidate_isomorphism) { + + bidict node_permutation = + map_values(candidate_isomorphism.node_mapping, [](Node const &dst_node) { + return NewNode{dst_node}; + }).reversed(); + bidict input_permutation = + map_values(candidate_isomorphism.input_mapping, + [](DataflowGraphInput const &dst_input) { + return NewDataflowGraphInput{dst_input}; + }) + .reversed(); + return get_graph_data(permute_input_ids( + permute_node_ids(src, node_permutation), input_permutation)) == + get_graph_data(dst); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..af56db2de3 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,11 @@ +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h" +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" + +namespace FlexFlow { + +bool open_dataflow_graphs_are_isomorphic(OpenDataflowGraphView const &src, + OpenDataflowGraphView const &dst) { + return find_isomorphism(src, dst).has_value(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc new file mode 100644 index 0000000000..c9c60edae3 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc @@ -0,0 +1,44 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/containers/transform.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OpenDataflowGraphView permute_input_ids( + OpenDataflowGraphView const &g, + bidict const &input_mapping) { + auto new_input_from_old = + [&](DataflowGraphInput const &old_input) -> DataflowGraphInput { + return input_mapping.at_r(old_input).raw_input; + }; + + auto new_edge_from_old = [&](OpenDataflowEdge const &e) { + return e.visit(overload{ + [&](DataflowInputEdge const &input_e) { + return OpenDataflowEdge{ + DataflowInputEdge{ + new_input_from_old(input_e.src), + input_e.dst, + }, + }; + }, + [&](DataflowEdge const &standard_e) { + return OpenDataflowEdge{standard_e}; + }, + }); + }; + + OpenDataflowGraphData old_data = get_graph_data(g); + OpenDataflowGraphData permuted_data = OpenDataflowGraphData{ + old_data.nodes, + transform(old_data.edges, new_edge_from_old), + transform(old_data.inputs, new_input_from_old), + old_data.outputs, + }; + + return from_open_dataflow_graph_data(permuted_data); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc new file mode 100644 index 0000000000..ab05cbbdc3 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc @@ -0,0 +1,72 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/transform.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/node/node_source.h" +#include "utils/graph/open_dataflow_graph/algorithms/from_open_dataflow_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/query_set.h" +#include "utils/overload.h" + +namespace FlexFlow { + +OpenDataflowGraphView + permute_node_ids(OpenDataflowGraphView const &g, + bidict const &new_node_tofrom_old_node) { + auto new_node_from_old = [&](Node const &n) -> Node { + return new_node_tofrom_old_node.at_r(n).raw_node; + }; + + auto new_output_from_old = [&](DataflowOutput const &o) -> DataflowOutput { + return DataflowOutput{ + new_node_from_old(o.node), + o.idx, + }; + }; + + auto new_input_from_old = [&](DataflowInput const &i) -> DataflowInput { + return DataflowInput{ + new_node_from_old(i.node), + i.idx, + }; + }; + + auto new_edge_from_old = [&](OpenDataflowEdge const &e) { + return e.visit(overload{ + [&](DataflowInputEdge const &input_e) { + return OpenDataflowEdge{ + DataflowInputEdge{ + input_e.src, + new_input_from_old(input_e.dst), + }, + }; + }, + [&](DataflowEdge const &standard_e) { + return OpenDataflowEdge{ + DataflowEdge{ + new_output_from_old(standard_e.src), + new_input_from_old(standard_e.dst), + }, + }; + }, + }); + }; + + OpenDataflowGraphData old_data = get_graph_data(g); + + OpenDataflowGraphData permuted_data = OpenDataflowGraphData{ + transform(old_data.nodes, new_node_from_old), + transform(old_data.edges, new_edge_from_old), + old_data.inputs, + transform(old_data.outputs, new_output_from_old), + }; + + return from_open_dataflow_graph_data(permuted_data); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc index 19da98aabd..8736f2d157 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/dataflow_input_edge_query.cc @@ -23,4 +23,31 @@ bool dataflow_input_edge_query_includes(DataflowInputEdgeQuery const &q, includes(q.dst_idxs, e.dst.idx); } +DataflowInputEdgeQuery + dataflow_input_edge_query_for_edge(DataflowInputEdge const &e) { + return DataflowInputEdgeQuery{ + query_set{e.src}, + query_set{e.dst.node}, + query_set{e.dst.idx}, + }; +} + +DataflowInputEdgeQuery + dataflow_input_edge_query_all_outgoing_from(DataflowGraphInput const &src) { + return DataflowInputEdgeQuery{ + query_set{src}, + query_set::matchall(), + query_set::matchall(), + }; +} + +DataflowInputEdgeQuery + dataflow_input_edge_query_all_incoming_to(DataflowInput const &dst) { + return DataflowInputEdgeQuery{ + query_set::matchall(), + query_set{dst.node}, + query_set{dst.idx}, + }; +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc index e3311e4d18..d5e5b614af 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge.cc @@ -4,21 +4,21 @@ namespace FlexFlow { Node get_open_dataflow_edge_dst_node(OpenDataflowEdge const &e) { - return e.visit(overload{ - [](DataflowEdge const &e) { return e.dst.node; }, - [](DataflowInputEdge const &e) { return e.dst.node; }, - }); + return get_open_dataflow_edge_dst(e).node; } int get_open_dataflow_edge_dst_idx(OpenDataflowEdge const &e) { - return e.visit(overload{ - [](DataflowEdge const &e) { return e.dst.idx; }, - [](DataflowInputEdge const &e) { return e.dst.idx; }, + return get_open_dataflow_edge_dst(e).idx; +} + +DataflowInput get_open_dataflow_edge_dst(OpenDataflowEdge const &e) { + return e.visit(overload{ + [](DataflowEdge const &e) { return e.dst; }, + [](DataflowInputEdge const &e) { return e.dst; }, }); } -OpenDataflowValue - get_open_dataflow_edge_source(OpenDataflowEdge const &open_e) { +OpenDataflowValue get_open_dataflow_edge_src(OpenDataflowEdge const &open_e) { return open_e.visit(overload{ [](DataflowEdge const &e) { return OpenDataflowValue{e.src}; }, [](DataflowInputEdge const &e) { return OpenDataflowValue{e.src}; }, diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc index 4d12889a1e..4882c3e143 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_edge_query.cc @@ -32,4 +32,38 @@ bool open_dataflow_edge_query_includes(OpenDataflowEdgeQuery const &q, }); } +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_outgoing_from(OpenDataflowValue const &src) { + return src.visit(overload{ + [](DataflowOutput const &o) { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_none(), + dataflow_edge_query_all_outgoing_from(o), + }; + }, + [](DataflowGraphInput const &i) { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_all_outgoing_from(i), + dataflow_edge_query_none(), + }; + }, + }); +} + +OpenDataflowEdgeQuery + open_dataflow_edge_query_all_incoming_to(DataflowInput const &dst) { + return OpenDataflowEdgeQuery{ + dataflow_input_edge_query_all_incoming_to(dst), + dataflow_edge_query_all_incoming_to(dst), + }; +} + +std::unordered_set apply_open_dataflow_edge_query( + OpenDataflowEdgeQuery const &q, + std::unordered_set const &es) { + return filter(es, [&](OpenDataflowEdge const &e) { + return open_dataflow_edge_query_includes(q, e); + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.cc b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.cc new file mode 100644 index 0000000000..25f13fd298 --- /dev/null +++ b/lib/utils/src/utils/graph/open_dataflow_graph/open_dataflow_value.cc @@ -0,0 +1,22 @@ +#include "utils/graph/open_dataflow_graph/open_dataflow_value.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + try_get_dataflow_output(OpenDataflowValue const &v) { + return v.visit>(overload{ + [](DataflowOutput const &o) { return o; }, + [](DataflowGraphInput const &i) { return std::nullopt; }, + }); +} + +std::optional + try_get_dataflow_graph_input(OpenDataflowValue const &v) { + return v.visit>(overload{ + [](DataflowOutput const &o) { return std::nullopt; }, + [](DataflowGraphInput const &i) { return i; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc b/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc index 347c906bd7..4c9eb9d3ef 100644 --- a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc +++ b/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc @@ -7,7 +7,8 @@ namespace FlexFlow { void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { for (Node const &node : get_nodes(ext)) { - g.add_node_unsafe(node, get_inputs(ext, node), get_outputs(ext, node)); + g.add_node_unsafe( + node, get_input_values(ext, node), get_outputs(ext, node)); } } diff --git a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc new file mode 100644 index 0000000000..6e3ac8c155 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -0,0 +1,38 @@ +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("bidict_from_enumerating(std::unordered_set)") { + std::unordered_set input = {"zero", "one", "two"}; + + bidict result = bidict_from_enumerating(input); + + std::unordered_set result_left_entries = left_entries(result); + std::unordered_set correct_left_entries = {0, 1, 2}; + CHECK(result_left_entries == correct_left_entries); + + std::unordered_set result_right_entries = + right_entries(result); + std::unordered_set correct_right_entries = input; + CHECK(result_right_entries == correct_right_entries); + } + + TEST_CASE("bidict_from_enumerating(std::set)") { + std::set input = {"a", "c", "b"}; + + bidict correct = { + {0, "a"}, + {1, "b"}, + {2, "c"}, + }; + + bidict result = bidict_from_enumerating(input); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc new file mode 100644 index 0000000000..2be5f1ef93 --- /dev/null +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -0,0 +1,50 @@ +#include "utils/containers/enumerate.h" +#include "utils/containers/as_vector.h" +#include "utils/fmt/map.h" +#include "utils/fmt/pair.h" +#include "utils/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("enumerate(std::vector)") { + std::vector input = {"zero", "one", "two", "three"}; + + std::map correct = { + {0, "zero"}, + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + + std::map result = enumerate(input); + + CHECK(result == correct); + + SUBCASE("check iteration order") { + std::vector> iterated_result = + as_vector(result); + std::vector> correct_iteration_order = { + {0, "zero"}, + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + + CHECK(iterated_result == correct_iteration_order); + } + } + + TEST_CASE("enumerate(std::unordered_set)") { + std::unordered_set input = {"zero", "one", "two", "three"}; + + std::map correct = { + {0, "zero"}, + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + } +} diff --git a/lib/utils/test/src/utils/containers/filtrans.cc b/lib/utils/test/src/utils/containers/filtrans.cc new file mode 100644 index 0000000000..b8bb832b06 --- /dev/null +++ b/lib/utils/test/src/utils/containers/filtrans.cc @@ -0,0 +1,57 @@ +#include "utils/containers/filtrans.h" +#include "utils/fmt/set.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filtrans(std::vector, F)") { + std::vector input = {1, 2, 3, 2, 4}; + std::vector result = + filtrans(input, [](int x) -> std::optional { + if ((x % 2) == 0) { + return std::to_string(x); + } else { + return std::nullopt; + } + }); + + std::vector correct = {"2", "2", "4"}; + + CHECK(result == correct); + } + + TEST_CASE("filtrans(std::unordered_set, F)") { + std::unordered_set input = {1, 2, 3, 4}; + std::unordered_set result = + filtrans(input, [](int x) -> std::optional { + if ((x % 2) == 0) { + return std::to_string(x); + } else { + return std::nullopt; + } + }); + + std::unordered_set correct = {"2", "4"}; + + CHECK(result == correct); + } + + TEST_CASE("filtrans(std::set, F)") { + std::set input = {1, 2, 3, 4}; + std::set result = + filtrans(input, [](int x) -> std::optional { + if ((x % 2) == 0) { + return std::to_string(x); + } else { + return std::nullopt; + } + }); + + std::set correct = {"2", "4"}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_permutations.cc b/lib/utils/test/src/utils/containers/get_all_permutations.cc new file mode 100644 index 0000000000..5f22266809 --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_all_permutations.cc @@ -0,0 +1,54 @@ +#include "utils/containers/get_all_permutations.h" +#include "utils/containers/as_vector.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/fmt/vector.h" +#include "utils/hash/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_all_permutations") { + SUBCASE("input size 1") { + std::vector input = {1}; + + std::unordered_multiset> result = + unordered_multiset_of(get_all_permutations(input)); + std::unordered_multiset> correct = {{1}}; + + CHECK(result == correct); + } + + SUBCASE("input size 3") { + std::vector input = {2, 1, 3}; + + std::unordered_multiset> result = + unordered_multiset_of(get_all_permutations(input)); + std::unordered_multiset> correct = { + {1, 2, 3}, + {1, 3, 2}, + {2, 1, 3}, + {2, 3, 1}, + {3, 1, 2}, + {3, 2, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("elements repeated") { + std::vector input = {1, 2, 2}; + + std::unordered_multiset> result = + unordered_multiset_of(get_all_permutations(input)); + std::unordered_multiset> correct = { + {1, 2, 2}, + {2, 1, 2}, + {2, 2, 1}, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc index f716f73a03..25f990f80e 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms.cc @@ -28,8 +28,8 @@ TEST_SUITE(FF_TEST_SUITE) { Node n4 = n4_added.node; DataflowOutput o4 = get_only(n4_added.outputs); - SUBCASE("get_inputs") { - std::vector result = get_inputs(g, n4); + SUBCASE("get_input_values") { + std::vector result = get_input_values(g, n4); std::vector correct = {o1, o2, o3}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..f991b4a65e --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,93 @@ +#include "utils/graph/dataflow_graph/algorithms/dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "dataflow_graphs_are_isomorphic(DataflowGraphView, DataflowGraphView)") { + auto g1 = DataflowGraph::create(); + + NodeAddedResult g1_n1_added = g1.add_node({}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node({g1_n1_output}, 1); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = DataflowGraph::create(); + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = true; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node({g2_n1_output, g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = false; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of src and sink " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 1); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of internal " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node({g2_n2_output}, 1); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..160e4c4f73 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,101 @@ +#include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism(DataflowGraphView, DataflowGraphView)") { + auto g1 = DataflowGraph::create(); + + NodeAddedResult g1_n1_added = g1.add_node({}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node({g1_n1_output}, 1); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = DataflowGraph::create(); + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct_isomorphism = + DataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + }; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node({g2_n1_output, g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct_isomorphism = + std::nullopt; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + + SUBCASE("input graphs are not isomorphic (different number of src and sink " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 0); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct_isomorphism = + std::nullopt; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + + SUBCASE("input graphs are not isomorphic (different number of internal " + "nodes)") { + NodeAddedResult g2_n1_added = g2.add_node({}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node({g2_n1_output}, 1); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node({g2_n2_output}, 1); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct_isomorphism = + std::nullopt; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct_isomorphism); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..7e02686dde --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,41 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_outgoing_edges(DataflowGraphView, std::unordered_set") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o1, o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + std::unordered_set input_node_set = {n2, n3}; + + std::unordered_set result = + get_subgraph_outgoing_edges(g, input_node_set); + + std::unordered_set correct = { + DataflowEdge{o2, DataflowInput{n4, 1}}, + DataflowEdge{o3, DataflowInput{n4, 2}}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..1ac1b7ff01 --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,187 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism(LabelledDataflowGraphView, " + "LabelledDataflowGraphView)") { + auto g1 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + auto g2 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("duplicate labels") { + std::string node_label = "n"; + int value_label = 1; + + NodeAddedResult g1_n1_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = + g1.add_node(node_label, {g1_n1_output, g1_n2_output}, {value_label}); + Node g1_n3_node = g1_n3_added.node; + + NodeAddedResult g2_n1_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = + g2.add_node(node_label, {g2_n1_output, g2_n2_output}, {value_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = + DataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + {g1_n3_node, g2_n3_node}, + }, + }; + + std::optional result = find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("non-duplicate labels") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + std::string n3_label = "n3"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + int n3_output_label = 4; + + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, {}, {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = g1.add_node( + n3_label, {g1_n1_output, g1_n2_output}, {n3_output_label}); + Node g1_n3_node = g1_n3_added.node; + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = + DataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + {g1_n3_node, g2_n3_node}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(mismatched_node_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {mismatched_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs " + "not isomorphic)") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n2_output, g2_n1_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..12950b8ad2 --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,169 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/labelled_dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("labelled_dataflow_graphs_are_isomorphic(LabelledDataflowGraphView," + " LabelledDataflowGraphView)") { + auto g1 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + auto g2 = LabelledDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("duplicate labels") { + std::string node_label = "n"; + int value_label = 1; + + NodeAddedResult g1_n1_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node(node_label, {}, {value_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = + g1.add_node(node_label, {g1_n1_output, g1_n2_output}, {value_label}); + Node g1_n3_node = g1_n3_added.node; + + NodeAddedResult g2_n1_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = g2.add_node(node_label, {}, {value_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = + g2.add_node(node_label, {g2_n1_output, g2_n2_output}, {value_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = true; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("non-duplicate labels") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + std::string n3_label = "n3"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + int n3_output_label = 4; + + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, {}, {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + DataflowOutput g1_n2_output = get_only(g1_n2_added.outputs); + + NodeAddedResult g1_n3_added = g1.add_node( + n3_label, {g1_n1_output, g1_n2_output}, {n3_output_label}); + Node g1_n3_node = g1_n3_added.node; + + SUBCASE("input graphs are isomorphic") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = true; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(mismatched_node_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {mismatched_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n1_output, g2_n2_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs " + "not isomorphic)") { + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, {}, {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + DataflowOutput g2_n2_output = get_only(g2_n2_added.outputs); + + NodeAddedResult g2_n3_added = g2.add_node( + n3_label, {g2_n2_output, g2_n1_output}, {n3_output_label}); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + + bool result = labelled_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..c83366e78c --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,185 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism") { + auto g1 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + auto g2 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("duplicate labels") { + std::string node_label = "n"; + int value_label = 2; + + DataflowGraphInput g1_i1 = g1.add_input(value_label); + NodeAddedResult g1_n1_added = + g1.add_node(node_label, {OpenDataflowValue{g1_i1}}, {value_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = g1.add_node( + node_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {value_label}); + Node g1_n2_node = g1_n2_added.node; + + DataflowGraphInput g2_i1 = g2.add_input(value_label); + NodeAddedResult g2_n1_added = + g2.add_node(node_label, {OpenDataflowValue{g2_i1}}, {value_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + node_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {value_label}); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("differing labels") { + std::string n1_label = "n1"; + std::string n2_label = "n2"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + + DataflowGraphInput g1_i1 = g1.add_input(i1_label); + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {OpenDataflowValue{g1_i1}}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = g1.add_node( + n2_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + mismatched_node_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {mismatched_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched input label)") { + int mismatched_input_label = 10000; + + DataflowGraphInput g2_i1 = g2.add_input(mismatched_input_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs " + "not isomorphic)") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, {OpenDataflowValue{g2_n1_output}}, {n2_output_label}); + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc new file mode 100644 index 0000000000..0f59392fcc --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.cc @@ -0,0 +1,60 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_isomorphic_under") { + auto g1 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + std::string n1_label = "n1"; + std::string n2_label = "n2"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + + DataflowGraphInput g1_i1 = g1.add_input(i1_label); + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {OpenDataflowValue{g1_i1}}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = + g2.add_node(n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + + OpenDataflowGraphIsomorphism correct_isomorphism = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + bool result = is_isomorphic_under(g1, g2, correct_isomorphism); + + CHECK(result); + } +} diff --git a/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..9f8d5eb08a --- /dev/null +++ b/lib/utils/test/src/utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,121 @@ +#include "utils/graph/labelled_open_dataflow_graph/algorithms/labelled_open_dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("labelled_open_dataflow_graphs_are_isomorphic") { + auto g1 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + std::string n1_label = "n1"; + std::string n2_label = "n2"; + int i1_label = 1; + int n1_output_label = 2; + int n2_output_label = 3; + + DataflowGraphInput g1_i1 = g1.add_input(i1_label); + NodeAddedResult g1_n1_added = + g1.add_node(n1_label, {OpenDataflowValue{g1_i1}}, {n1_output_label}); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + NodeAddedResult g1_n2_added = + g1.add_node(n2_label, + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, + {n2_output_label}); + Node g1_n2_node = g1_n2_added.node; + + auto g2 = LabelledOpenDataflowGraph::create< + UnorderedSetLabelledOpenDataflowGraph>(); + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + Node g2_n2_node = g2_n2_added.node; + + bool correct = true; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched node labels)") { + std::string mismatched_node_label = "mismatched_node_label"; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + "mismatched_label", + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched output label)") { + int mismatched_output_label = 20000; + + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = g2.add_node( + n1_label, {OpenDataflowValue{g2_i1}}, {mismatched_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (mismatched input label)") { + int mismatched_input_label = 10000; + + DataflowGraphInput g2_i1 = g2.add_input(mismatched_input_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, + {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (underlying unlabelled graphs not " + "isomorphic)") { + DataflowGraphInput g2_i1 = g2.add_input(i1_label); + NodeAddedResult g2_n1_added = + g2.add_node(n1_label, {OpenDataflowValue{g2_i1}}, {n1_output_label}); + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + n2_label, {OpenDataflowValue{g2_n1_output}}, {n2_output_label}); + + bool correct = false; + bool result = labelled_open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc new file mode 100644 index 0000000000..78aaa8d9fc --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc @@ -0,0 +1,134 @@ +#include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_isomorphism(OpenDataflowGraphView, OpenDataflowGraphView)") { + auto g1 = OpenDataflowGraph::create(); + auto g2 = OpenDataflowGraph::create(); + + SUBCASE("input graphs are empty") { + std::optional correct = + OpenDataflowGraphIsomorphism{ + {}, + {}, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not empty") { + DataflowGraphInput g1_i1 = g1.add_input(); + NodeAddedResult g1_n1_added = g1.add_node({OpenDataflowValue{g1_i1}}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node( + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, 1); + Node g1_n2_node = g1_n2_added.node; + + SUBCASE("one graph is empty") { + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = + OpenDataflowGraphIsomorphism{ + bidict{ + {g1_n1_node, g2_n1_node}, + {g1_n2_node, g2_n2_node}, + }, + bidict{ + {g1_i1, g2_i1}, + }, + }; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of graph " + "inputs)") { + DataflowGraphInput g2_i1 = g2.add_input(); + DataflowGraphInput g2_i2 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_n1_output}, OpenDataflowValue{g2_n1_output}}, + 1); + Node g2_n2_node = g2_n2_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different numbers of nodes)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 0); + Node g2_n3_node = g2_n3_added.node; + + std::optional correct = std::nullopt; + + std::optional result = + find_isomorphism(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..ff75e8fe48 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.cc @@ -0,0 +1,24 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_open_dataflow_graph_inputs(OpenDataflowGraphView)") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + DataflowGraphInput i1 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({}, 1); + + std::unordered_set result = + get_open_dataflow_graph_inputs(g); + std::unordered_set correct = {i0, i1}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc new file mode 100644 index 0000000000..7496c3009d --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.cc @@ -0,0 +1,74 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_value_uses.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_open_dataflow_value_uses(OpenDataflowGraphView, " + "OpenDataflowValue)") { + SUBCASE("value is a DataflowGraphInput") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + DataflowGraphInput i1 = g.add_input(); + + NodeAddedResult n0_added = g.add_node( + {OpenDataflowValue{i0}, OpenDataflowValue{i1}, OpenDataflowValue{i0}}, + 1); + Node n0 = n0_added.node; + DataflowOutput o0 = get_only(n0_added.outputs); + + NodeAddedResult n1_added = g.add_node( + {OpenDataflowValue{i1}, OpenDataflowValue{o0}, OpenDataflowValue{i0}}, + 1); + Node n1 = n1_added.node; + + std::unordered_set correct = { + DataflowInput{n0, 0}, + DataflowInput{n0, 2}, + DataflowInput{n1, 2}, + }; + + std::unordered_set result = + get_open_dataflow_value_uses(g, OpenDataflowValue{i0}); + + CHECK(result == correct); + } + + SUBCASE("value is a DataflowOutput") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({OpenDataflowValue{i0}}, 2); + Node n0 = n0_added.node; + DataflowOutput o0_0 = n0_added.outputs.at(0); + DataflowOutput o0_1 = n0_added.outputs.at(1); + + NodeAddedResult n1_added = g.add_node({OpenDataflowValue{i0}, + OpenDataflowValue{o0_1}, + OpenDataflowValue{o0_0}}, + 1); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = + g.add_node({OpenDataflowValue{o0_1}, OpenDataflowValue{i0}}, 1); + Node n2 = n2_added.node; + + std::unordered_set correct = { + DataflowInput{n1, 1}, + DataflowInput{n2, 0}, + }; + + std::unordered_set result = + get_open_dataflow_value_uses(g, OpenDataflowValue{o0_1}); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc new file mode 100644 index 0000000000..ddd6d74119 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.cc @@ -0,0 +1,41 @@ +#include "utils/graph/open_dataflow_graph/algorithms/get_unused_open_dataflow_graph_inputs.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_unused_open_dataflow_graph_inputs(OpenDataflowGraphView)") { + auto g = OpenDataflowGraph::create(); + SUBCASE("unused inputs exist") { + DataflowGraphInput g_i1 = g.add_input(); + DataflowGraphInput g_i2 = g.add_input(); + DataflowGraphInput g_i3 = g.add_input(); + + NodeAddedResult g_n1_added = g.add_node({OpenDataflowValue{g_i2}}, 1); + + std::unordered_set result = + get_unused_open_dataflow_graph_inputs(g); + + std::unordered_set correct = {g_i1, g_i3}; + + CHECK(result == correct); + } + + SUBCASE("unused inputs don't exist") { + DataflowGraphInput g_i1 = g.add_input(); + DataflowGraphInput g_i2 = g.add_input(); + + NodeAddedResult g_n1_added = + g.add_node({OpenDataflowValue{g_i1}, OpenDataflowValue{g_i2}}, 1); + + std::unordered_set result = + get_unused_open_dataflow_graph_inputs(g); + + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc new file mode 100644 index 0000000000..bdb1bb4814 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.cc @@ -0,0 +1,110 @@ +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graphs_are_isomorphic.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("open_dataflow_graphs_are_isomorphic(OpenDataflowGraphView, " + "OpenDataflowGraphView)") { + auto g1 = OpenDataflowGraph::create(); + auto g2 = OpenDataflowGraph::create(); + + SUBCASE("input graphs are empty") { + bool correct = true; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not empty") { + DataflowGraphInput g1_i1 = g1.add_input(); + NodeAddedResult g1_n1_added = g1.add_node({OpenDataflowValue{g1_i1}}, 1); + Node g1_n1_node = g1_n1_added.node; + DataflowOutput g1_n1_output = get_only(g1_n1_added.outputs); + + NodeAddedResult g1_n2_added = g1.add_node( + {OpenDataflowValue{g1_i1}, OpenDataflowValue{g1_n1_output}}, 1); + Node g1_n2_node = g1_n2_added.node; + + SUBCASE("one input graph is empty") { + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are isomorphic") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = true; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different number of graph " + "inputs)") { + DataflowGraphInput g2_i1 = g2.add_input(); + DataflowGraphInput g2_i2 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different connectivity)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_n1_output}, OpenDataflowValue{g2_n1_output}}, + 1); + Node g2_n2_node = g2_n2_added.node; + + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + + SUBCASE("input graphs are not isomorphic (different numbers of nodes)") { + DataflowGraphInput g2_i1 = g2.add_input(); + NodeAddedResult g2_n1_added = + g2.add_node({OpenDataflowValue{g2_i1}}, 1); + Node g2_n1_node = g2_n1_added.node; + DataflowOutput g2_n1_output = get_only(g2_n1_added.outputs); + NodeAddedResult g2_n2_added = g2.add_node( + {OpenDataflowValue{g2_i1}, OpenDataflowValue{g2_n1_output}}, 1); + Node g2_n2_node = g2_n2_added.node; + + NodeAddedResult g2_n3_added = g2.add_node({}, 0); + Node g2_n3_node = g2_n3_added.node; + + bool correct = false; + bool result = open_dataflow_graphs_are_isomorphic(g1, g2); + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc new file mode 100644 index 0000000000..b565e46e67 --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_input_ids.cc @@ -0,0 +1,79 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_input_ids.h" +#include "utils/containers/get_only.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/open_dataflow_graph/algorithms/open_dataflow_graph_data.dtg.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("permute_input_ids(OpenDataflowGraphView, " + "bidict)") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + DataflowGraphInput i1 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({OpenDataflowValue{i0}}, 1); + Node n0 = n0_added.node; + DataflowOutput n0_output = get_only(n0_added.outputs); + + NodeAddedResult n1_added = g.add_node({OpenDataflowValue{n0_output}}, 1); + Node n1 = n1_added.node; + DataflowOutput n1_output = get_only(n1_added.outputs); + + DataflowGraphInput new_i0 = DataflowGraphInput{6}; + DataflowGraphInput new_i1 = DataflowGraphInput{7}; + + bidict input_mapping = { + {NewDataflowGraphInput{new_i0}, i0}, + {NewDataflowGraphInput{new_i1}, i1}, + }; + + OpenDataflowGraphView result = permute_input_ids(g, input_mapping); + OpenDataflowGraphData result_data = get_graph_data(result); + + OpenDataflowGraphData correct_data = OpenDataflowGraphData{ + {n0, n1}, + { + OpenDataflowEdge{ + DataflowInputEdge{ + new_i0, + DataflowInput{ + n0, + 0, + }, + }, + }, + OpenDataflowEdge{ + DataflowEdge{ + DataflowOutput{ + n0, + 0, + }, + DataflowInput{ + n1, + 0, + }, + }, + }, + }, + {new_i0, new_i1}, + { + DataflowOutput{ + n0, + 0, + }, + DataflowOutput{ + n1, + 0, + }, + }, + }; + + CHECK(result_data == correct_data); + } +} diff --git a/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc new file mode 100644 index 0000000000..36bcd16dad --- /dev/null +++ b/lib/utils/test/src/utils/graph/open_dataflow_graph/algorithms/permute_node_ids.cc @@ -0,0 +1,175 @@ +#include "utils/graph/open_dataflow_graph/algorithms/permute_node_ids.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_edge_query.h" +#include "utils/graph/dataflow_graph/dataflow_output_query.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/open_dataflow_graph/algorithms/get_graph_data.h" +#include "utils/graph/open_dataflow_graph/dataflow_input_edge_query.h" +#include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("permute_node_ids(OpenDataflowGraphView, bidict)") { + OpenDataflowGraph g = + OpenDataflowGraph::create(); + + DataflowGraphInput i0 = g.add_input(); + + NodeAddedResult n0_added = g.add_node({OpenDataflowValue{i0}}, 1); + Node n0 = n0_added.node; + DataflowOutput n0_output = get_only(n0_added.outputs); + + NodeAddedResult n1_added = + g.add_node({OpenDataflowValue{i0}, OpenDataflowValue{n0_output}}, 1); + Node n1 = n1_added.node; + DataflowOutput n1_output = get_only(n1_added.outputs); + + Node new_node0 = Node{5}; + Node new_node1 = Node{6}; + + bidict node_mapping = { + {NewNode{new_node0}, n0}, + {NewNode{new_node1}, n1}, + }; + + OpenDataflowGraphView result = permute_node_ids(g, node_mapping); + OpenDataflowGraphData result_data = get_graph_data(result); + + OpenDataflowGraphData correct_data = OpenDataflowGraphData{ + {new_node0, new_node1}, + { + OpenDataflowEdge{ + DataflowInputEdge{ + i0, + DataflowInput{ + new_node0, + 0, + }, + }, + }, + OpenDataflowEdge{ + DataflowInputEdge{ + i0, + DataflowInput{ + new_node1, + 0, + }, + }, + }, + OpenDataflowEdge{ + DataflowEdge{ + DataflowOutput{ + new_node0, + 0, + }, + DataflowInput{ + new_node1, + 1, + }, + }, + }, + }, + {i0}, + { + DataflowOutput{ + new_node0, + 0, + }, + DataflowOutput{ + new_node1, + 0, + }, + }, + }; + + CHECK(result_data == correct_data); + + // because get_graph_data only uses matchall nodes which don't require as + // much updating, we also add test cases for the query methods with concrete + // queries to check the through-node-permutation querying logic + SUBCASE("query_nodes(NodeQuery)") { + SUBCASE("check access to old nodes") { + std::unordered_set result_nodes = + result.query_nodes(NodeQuery{n0}); + std::unordered_set correct = {}; + CHECK(result_nodes == correct); + } + + SUBCASE("check access to new nodes") { + std::unordered_set result_nodes = + result.query_nodes(NodeQuery{new_node0}); + std::unordered_set correct = {new_node0}; + CHECK(result_nodes == correct); + } + } + + SUBCASE("query_edges(OpenDataflowEdgeQuery)") { + SUBCASE("check access to old edges") { + OpenDataflowEdgeQuery query = OpenDataflowEdgeQuery{ + dataflow_input_edge_query_for_edge( + DataflowInputEdge{i0, DataflowInput{n0, 0}}), + dataflow_edge_query_for_edge( + DataflowEdge{n0_output, DataflowInput{n1, 1}}), + }; + std::unordered_set result_nodes = + result.query_edges(query); + std::unordered_set correct = {}; + CHECK(result_nodes == correct); + } + + SUBCASE("check access to new edges") { + DataflowEdge new_standard_edge = DataflowEdge{ + DataflowOutput{new_node0, 0}, + DataflowInput{new_node1, 1}, + }; + DataflowInputEdge new_input_edge = DataflowInputEdge{ + i0, + DataflowInput{new_node0, 0}, + }; + OpenDataflowEdgeQuery query = OpenDataflowEdgeQuery{ + dataflow_input_edge_query_for_edge(new_input_edge), + dataflow_edge_query_for_edge(new_standard_edge), + }; + + std::unordered_set result_nodes = + result.query_edges(query); + std::unordered_set correct = { + OpenDataflowEdge{new_standard_edge}, + OpenDataflowEdge{new_input_edge}, + }; + + CHECK(result_nodes == correct); + } + } + + SUBCASE("query_outputs(DataflowOutputQuery)") { + SUBCASE("check access to old outputs") { + DataflowOutput old_output = n0_output; + + DataflowOutputQuery query = + dataflow_output_query_for_output(old_output); + std::unordered_set result_outputs = + result.query_outputs(query); + + std::unordered_set correct = {}; + + CHECK(result_outputs == correct); + } + + SUBCASE("check access to new outputs") { + DataflowOutput new_output = DataflowOutput{new_node0, 0}; + + DataflowOutputQuery query = + dataflow_output_query_for_output(new_output); + std::unordered_set result_outputs = + result.query_outputs(query); + + std::unordered_set correct = {new_output}; + + CHECK(result_outputs == correct); + } + } + } +} From ae8bbf915b3de5a376d741b516944b75bb8cdfb3 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 14 Sep 2024 15:10:02 -0700 Subject: [PATCH 04/12] Add tool for exporting and visualizing model architectures and SP decompositions (#1490) * Start on pcg builder * Add tests and some implementation for pcg builder * Add pcg tests, make dtgen constructors explicit to fix bug * Add remainder of PCG tests * Fix build issues in local-execution * Format * Address Reyna comments, add topological_order function for PCG * Pre multidigraph refactor * Removing visitable from sp code * Add open dataflow graph, start to replace pcg dataflow graph * Start refactoring substitutions * Add utility functions to support pattern matching * Pre-refactor inputs * Fix proj url * Get back to substitutions, now with unordered graph inputs * Get substitutions building * substitutions-tests now builds * Fix bug in filter, pass some initial substitution tests * Add tests for fmt::to_string, fix some substitutions bugs * Pass initial unit tests for find_pattern_matches * Start on unit tests for pcg pattern * Pass initial test for find_pattern_matches * Fix small build issue in tests * Format * Sync tests in CI with tests in proj * Fix minor build errors in kernels and local-execution * Format * Remove outdated code * More outdated code removal * More cleanup, add test for sp decomposition * Pull apart containers.h * More sp testing and fixes * Break up graph algorithms.h * Pre- full SP algo commit * Add initial implementation and tests for cbc decomposition and inverse line graph * Pass test for get_inverse_line_graph * Add new multidigraph * Fix get_inverse_line_graph to return a MultiDiGraph instead of a DiGraph * Add tests for parallel and series reduction finding * Add really rough implementation of valdez sp decomposition * Fix local-execution build * Add implementations and tests for applying series/parallel reductions * Add transformer pcg prototype * Format * Clean up sp decomposition interface and tests * Format * Add comments for top-level substitutions functions, add proj doxygen support * Start sketching out substitutions code * Fix build errors * Add ability to permute node ids * Cleanup and start to test new substitutions code * Add test case for evaluate_substitution_output * Add naive isomorphism detection code * Add graph inputs to open dataflow graph isomorphism * Add input permutation to evaluate_substitution_output * Update based on review comments * Fix permute_node_ids * Add test for permute_input_ids * Add models makefile and test * Update * Pass test * Enhance Transformer implementation * Reflect review comments * Migrate over to mutable implementation of apply_substitution * Add fast isomorphism checking and an initial implementation of full substitution logic * Pass initial full substitutions test * Cleanup old isomorphism checking code * Fix post-merge bugs * [WIP] Save initial refactor * Fix broken pcg builder test * Format * Reorganize code and remove some outdated code pre-code-review * Format * Implement actual encorder decoder architecture * Remove duplicated definition * Update based on review * Update argument order * Address review comments * Address missed comment * Remove latex dependency to avoid CI out-of-disk-space * Format * Implement most of the shape inference and ComputationGraphBuilder support * Fix bug in LayerNorm shape inference tests, disable {?} doctest default * Fix transformer test * Format * Get initial export-model-arch binary building * Actually dump valid json from export-model-arch * Some minor polishing of export-model-arch * Add binary sp tree logic and start on sp decomposition of computation graphs * Fix sp decomposition export of transformer (a lot of cleanup now needed) * Flesh out export-model-arch CLI and features * Format * Add split_test model * Add single_operator model for testing * Add substitution-to-dot and export-model-arch build to CI * Cleanup generic_binary_sp_decomposition_tree * Format * Fix substitution-to-dot name in CI * Add testing for cli_get_help_message * Add testing for cli_parse * Add missing include in export_model_arch * Add basic test for cli_parse on raw argv * Rename serial-parallel -> series-parallel * Add a bunch of testing for new code * Add tests for computation graph sp decomposition * Format * Fix build error in export-model-arch --------- Co-authored-by: hsdfzhsdfz --- .../{build_libs.sh => build_target.sh} | 0 .../helpers/{test_libs.sh => test_target.sh} | 0 .github/workflows/per-lib-check.yml | 42 +- .proj.toml | 2 + CMakeLists.txt | 1 + bin/CMakeLists.txt | 6 +- bin/export-model-arch/CMakeLists.txt | 12 + .../json_sp_model_export.struct.toml | 27 + .../src/export_model_arch.cc | 208 +++++++ .../CMakeLists.txt | 0 .../substitution_to_dot.cc | 21 +- cmake/flexflow-utils.cmake | 5 + lib/compiler/include/compiler/graph_utils.h | 6 +- .../include/compiler/machine_mapping.h | 3 +- .../compiler/optimal_cost_state.struct.toml | 6 +- ...omputation_graph_binary_sp_decomposition.h | 30 + ..._graph_binary_sp_decomposition.struct.toml | 22 + ...tion_graph_series_parallel_decomposition.h | 17 + ...mputation_graph_binary_sp_decomposition.cc | 90 +++ ...ion_graph_series_parallel_decomposition.cc | 98 ++++ lib/compiler/src/graph_utils.cc | 14 +- lib/compiler/src/machine_mapping.cc | 67 +-- lib/compiler/test/CMakeLists.txt | 1 + ...ion_graph_series_parallel_decomposition.cc | 340 ++++++++++++ lib/compiler/test/src/test_generator.h | 10 +- lib/kernels/include/kernels/accessor.h | 16 +- lib/kernels/include/kernels/array_shape.h | 2 +- .../include/kernels/attention_kernels.h | 1 - .../include/kernels/batch_matmul_kernels.h | 1 - .../include/kernels/initializer_kernels.h | 1 + lib/kernels/src/allocation.cc | 1 + lib/kernels/src/cpu/initializer_kernels.cc | 2 +- lib/kernels/src/cuda/embedding_kernels.cu | 8 +- lib/kernels/src/cuda/ops/combine_kernels.cu | 2 +- .../src/cuda/ops/element_unary_kernels.cu | 12 +- lib/kernels/src/cuda/ops/partition_kernels.cu | 12 +- lib/kernels/src/cuda/ops/reduction_kernels.cu | 2 +- lib/kernels/src/cuda/ops/replicate_kernels.cu | 2 +- lib/kernels/src/cuda/ops/reshape_kernels.cu | 4 +- lib/kernels/src/hip/ops/replicate_kernels.cpp | 19 +- lib/kernels/src/hip/ops/reshape_kernels.cpp | 4 +- .../include/local-execution/cost_estimate.h | 4 +- .../local-execution/legion_tensor_shape.h | 3 +- .../local-execution/local_slots_backing.h | 3 + .../local-execution/local_training_backing.h | 1 + .../include/local-execution/op_arg_ref.h | 2 +- .../local-execution/op_task_invocation.h | 4 - .../include/local-execution/sim_environment.h | 2 +- .../local-execution/task_registry.struct.toml | 1 + .../src/legion_tensor_shape.cc | 1 + .../src/local_cost_estimator.cc | 2 +- .../src/local_slots_backing.cc | 2 + .../src/local_training_backing.cc | 1 + lib/local-execution/src/op_task_signature.cc | 1 + lib/local-execution/src/ops/element_unary.cc | 1 + .../test/src/test_local_cost_estimator.cc | 6 +- .../test/src/test_local_slots_backing.cc | 22 +- .../test/src/test_local_task_arg_accessor.cc | 4 +- .../test/src/test_task_registry.cc | 4 +- lib/models/CMakeLists.txt | 3 +- .../include/models/split_test/split_test.h | 19 + .../models/{ => transformer}/transformer.h | 6 +- .../transformer_config.struct.toml | 0 .../src/models/split_test/split_test.cc | 39 ++ .../models/{ => transformer}/transformer.cc | 10 +- lib/models/test/src/models/transformer.cc | 4 +- lib/op-attrs/include/op-attrs/datatype.h | 9 +- .../op-attrs/datatype_value.variant.toml | 25 + lib/op-attrs/include/op-attrs/dim_ordered.h | 7 +- .../include/op-attrs/dim_ordered/slice.h | 4 +- .../include/op-attrs/dim_ordered/transform.h | 4 +- .../include/op-attrs/dim_ordered/zip.h | 4 +- lib/op-attrs/include/op-attrs/ops/broadcast.h | 3 + .../op-attrs/ops/conv_2d_attrs.struct.toml | 3 +- .../ops/element_unary_attrs.struct.toml | 6 +- .../op-attrs/ops/embedding_attrs.struct.toml | 2 + .../op-attrs/ops/linear_attrs.struct.toml | 4 +- lib/op-attrs/include/op-attrs/tensor_shape.h | 5 - .../op-attrs/computation_graph_op_attrs.cc | 15 + lib/op-attrs/src/op-attrs/ops/broadcast.cc | 18 + .../src/op-attrs/parallel_tensor_dims.cc | 4 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 8 +- lib/op-attrs/src/op-attrs/tensor_shape.cc | 31 -- lib/op-attrs/test/src/datatype.cc | 4 +- lib/op-attrs/test/src/dim_ordered/slice.cc | 4 +- .../src/op-attrs/dim_ordered/enumerate.cc | 2 +- .../test/src/op-attrs/dim_ordered/zip.cc | 2 +- lib/op-attrs/test/src/op-attrs/ops/dropout.cc | 1 + .../test/src/op-attrs/ops/layer_norm.cc | 2 +- lib/op-attrs/test/src/op-attrs/ops/softmax.cc | 1 + lib/op-attrs/test/src/op-attrs/tensor_dims.cc | 1 + .../test/src/op-attrs/tensor_shape.cc | 64 --- lib/op-attrs/test/src/ops/attention.cc | 5 +- lib/op-attrs/test/src/ops/batch_matmul.cc | 5 +- lib/op-attrs/test/src/ops/cast.cc | 5 +- lib/op-attrs/test/src/ops/combine.cc | 5 +- lib/op-attrs/test/src/ops/element_binary.cc | 5 +- lib/op-attrs/test/src/ops/element_unary.cc | 5 +- lib/op-attrs/test/src/ops/embedding.cc | 5 +- lib/op-attrs/test/src/ops/linear.cc | 5 +- lib/op-attrs/test/src/ops/reduction.cc | 5 +- lib/op-attrs/test/src/ops/repartition.cc | 5 +- lib/op-attrs/test/src/ops/replicate.cc | 4 +- lib/op-attrs/test/src/test_operator_attrs.cc | 14 +- .../test/src/test_regularizer_attrs.cc | 4 +- lib/pcg/CMakeLists.txt | 1 + lib/pcg/include/pcg/computation_graph.h | 14 + .../computation_graph_edge.h | 14 + .../computation_graph_edge.struct.toml | 16 + .../include/pcg/computation_graph_builder.h | 35 +- lib/pcg/include/pcg/file_format/file_format.h | 19 - .../include/pcg/file_format/keyed_variant.h | 29 +- .../pcg/file_format/v1/data_type_value.h | 2 +- lib/pcg/include/pcg/file_format/v1/graphs.h | 26 - .../v1/graphs/v1_dataflow_graph.struct.toml | 9 +- .../v1/graphs/v1_labelled_dataflow_graph.h | 18 +- .../v1_labelled_dataflow_graph.struct.toml | 12 +- lib/pcg/include/pcg/file_format/v1/v1.h | 9 - .../pcg/file_format/v1/v1_computation_graph.h | 17 + .../v1/v1_computation_graph.struct.toml | 18 + .../v1/v1_parallel_computation_graph.h | 13 + .../v1_parallel_computation_graph.struct.toml | 18 + .../constant_initializer_attrs.struct.toml | 7 +- lib/pcg/include/pcg/layer_attrs.struct.toml | 2 +- .../parallel_layer_attrs.struct.toml | 2 + .../parallel_tensor_attrs.struct.toml | 2 + lib/pcg/include/pcg/tensor_attrs.struct.toml | 1 + lib/pcg/src/file_format.cc | 14 - lib/pcg/src/file_format/v1/graphs.cc | 16 - lib/pcg/src/pcg/computation_graph.cc | 104 +++- .../computation_graph_edge.cc | 15 + lib/pcg/src/pcg/computation_graph_builder.cc | 165 ++++-- .../file_format/v1/v1_computation_graph.cc | 24 + .../v1/v1_parallel_computation_graph.cc | 12 + .../file_format/v1/v1_computation_graph.cc | 30 + .../v1/v1_parallel_computation_graph.cc | 36 ++ .../initializers/uniform_initializer_attrs.cc | 4 +- .../parallel_computation_graph_builder.cc | 6 +- .../src/test_computation_graph_builder.cc | 2 +- lib/pcg/test/src/test_machine_view.cc | 2 +- lib/pcg/test/src/test_strided_rectangle.cc | 2 +- lib/runtime/src/accessor.cc | 6 +- .../operator_attribute_value.variant.toml | 1 + .../sub_parallel_computation_graph.h | 4 - .../operator_pattern/get_attribute.cc | 4 +- .../tensor_pattern/get_attribute.cc | 6 +- .../operator_pattern/get_attribute.cc | 1 + .../test/src/substitutions/pcg_pattern.cc | 2 +- .../substitutions/unlabelled/pattern_split.cc | 2 +- .../unlabelled/unlabelled_graph_pattern.cc | 2 +- .../test/src/test_pattern_matches.cc | 4 +- lib/utils/CMakeLists.txt | 1 - .../utils/cli/cli_argument_key.variant.toml | 19 + .../utils/cli/cli_flag_key.struct.toml | 13 + .../utils/cli/cli_flag_spec.struct.toml | 28 + .../include/utils/cli/cli_get_help_message.h | 13 + lib/utils/include/utils/cli/cli_parse.h | 19 + .../include/utils/cli/cli_parse_result.h | 14 + .../utils/cli/cli_parse_result.struct.toml | 27 + .../cli_positional_argument_key.struct.toml | 13 + .../cli_positional_argument_spec.struct.toml | 31 ++ lib/utils/include/utils/cli/cli_spec.h | 20 + .../include/utils/cli/cli_spec.struct.toml | 29 + lib/utils/include/utils/containers.decl.h | 3 - lib/utils/include/utils/containers.h | 5 - .../utils/containers/enumerate_vector.h | 1 - lib/utils/include/utils/containers/foldl1.h | 29 + lib/utils/include/utils/containers/foldr1.h | 28 + .../include/utils/containers/generate_map.h | 4 +- .../include/utils/containers/get_first.h | 6 + lib/utils/include/utils/containers/maximum.h | 20 + .../include/utils/containers/multiset_union.h | 48 ++ .../utils/containers/require_no_duplicates.h | 40 ++ lib/utils/include/utils/containers/reversed.h | 11 +- .../include/utils/containers/set_minus.h | 10 + lib/utils/include/utils/containers/set_of.h | 19 + .../include/utils/containers/to_uppercase.h | 12 + .../containers/{as_vector.h => vector_of.h} | 6 +- lib/utils/include/utils/fmt/expected.h | 14 +- lib/utils/include/utils/fmt/map.h | 12 - lib/utils/include/utils/fmt/multiset.h | 12 - lib/utils/include/utils/fmt/optional.h | 12 - lib/utils/include/utils/fmt/pair.h | 12 - lib/utils/include/utils/fmt/set.h | 12 - lib/utils/include/utils/fmt/unordered_map.h | 12 - .../include/utils/fmt/unordered_multiset.h | 12 - lib/utils/include/utils/fmt/unordered_set.h | 12 - lib/utils/include/utils/fmt/variant.h | 12 - lib/utils/include/utils/fmt/vector.h | 12 - .../algorithms/get_subgraph_incoming_edges.h | 14 + .../get_cbc_decomposition.h | 3 + .../is_complete_bipartite_digraph.h | 14 + .../graph/digraph/algorithms/digraph_as_dot.h | 14 + .../digraph/algorithms/digraph_has_edge.h | 12 + .../algorithms/get_subgraph_outgoing_edges.h | 14 + .../algorithms/get_subgraph_successors.h | 14 + .../digraph/algorithms/transitive_closure.h | 12 + .../unordered_set_undirected_graph.h | 37 ++ .../algorithms/find_isomorphism.h | 1 - .../include/utils/graph/node/node.struct.toml | 1 + .../get_serial_parallel_decomposition.h | 17 - .../intermediate_sp_decomposition_tree.h | 13 - .../serial_parallel_decomposition.h | 22 - .../binary_sp_decomposition_tree.h | 23 + .../binary_sp_decomposition_tree.struct.toml | 22 + .../fmt.h | 63 +++ .../generic_binary_sp_decomposition_tree.h | 155 ++++++ .../get.h | 15 + .../get_leaves.h | 40 ++ .../get_left_child.h | 44 ++ .../get_node_type.h | 29 + .../get_num_tree_nodes.h | 40 ++ .../get_right_child.h | 44 ++ .../hash.h | 34 ++ .../generic_binary_sp_decomposition_tree/is.h | 25 + .../is_binary_sp_tree_left_associative.h | 34 ++ .../is_binary_sp_tree_right_associative.h | 34 ++ .../json.h | 103 ++++ .../make.h | 39 ++ .../require.h | 28 + .../transform.h | 43 ++ .../visit.h | 37 ++ ...eft_associative_binary_sp_tree_from_nary.h | 14 + .../nary_sp_tree_from_binary.h | 14 + ...ght_associative_binary_sp_tree_from_nary.h | 14 + .../get_series_parallel_decomposition.h | 17 + .../graph_generation.h | 12 +- .../intermediate_sp_decomposition_tree.h | 17 + ...rmediate_sp_decomposition_tree.struct.toml | 2 +- .../parallel_reduction.h | 6 +- .../parallel_reduction.struct.toml | 0 .../series_parallel_decomposition.h | 22 + ...eries_parallel_decomposition.variant.toml} | 6 +- .../series_parallel_splits.h} | 32 +- .../series_reduction.h | 6 +- .../series_reduction.struct.toml | 0 .../sink_settings.enum.toml | 0 .../source_settings.enum.toml | 0 .../sp_decomposition_tree_node_type.enum.toml | 17 + .../split_type.enum.toml | 2 +- .../graph/undirected/algorithms/get_edges.h | 12 + .../algorithms/get_neighboring_nodes.h | 13 + .../graph/undirected/i_undirected_graph.h | 2 +- .../graph/undirected/undirected_edge_query.h | 2 + lib/utils/include/utils/hash/multiset.h | 20 + .../include/utils/hash/unordered_multiset.h | 20 + .../include/utils/json/check_is_jsonable.h | 17 + .../utils/json/is_json_deserializable.h | 25 + .../include/utils/json/is_json_serializable.h | 24 + lib/utils/include/utils/json/is_jsonable.h | 18 + lib/utils/include/utils/json/optional.h | 33 ++ lib/utils/include/utils/json/variant.h | 89 +++ .../utils/{json.h => json/visitable.h} | 125 +---- lib/utils/include/utils/optional.h | 23 +- lib/utils/include/utils/rapidcheck/optional.h | 21 + lib/utils/include/utils/required.h | 10 +- lib/utils/include/utils/stack_string.h | 6 +- lib/utils/include/utils/stack_vector.h | 6 +- .../src/utils/cli/cli_get_help_message.cc | 101 ++++ lib/utils/src/utils/cli/cli_parse.cc | 96 ++++ lib/utils/src/utils/cli/cli_parse_result.cc | 14 + lib/utils/src/utils/cli/cli_spec.cc | 37 ++ lib/utils/src/utils/containers/as_vector.cc | 1 - .../src/utils/containers/enumerate_vector.cc | 1 + lib/utils/src/utils/containers/foldl1.cc | 1 + lib/utils/src/utils/containers/foldr1.cc | 1 + .../utils/containers/get_element_counts.cc | 4 +- lib/utils/src/utils/containers/maximum.cc | 1 + .../src/utils/containers/multiset_union.cc | 1 + .../utils/containers/require_no_duplicates.cc | 1 + lib/utils/src/utils/containers/set_of.cc | 1 + .../src/utils/containers/to_uppercase.cc | 10 + lib/utils/src/utils/containers/vector_of.cc | 1 + lib/utils/src/utils/graph/algorithms.cc | 4 - .../algorithms/get_subgraph_incoming_edges.cc | 24 + .../get_cbc_decomposition.cc | 37 +- .../is_complete_bipartite_digraph.cc | 29 + .../digraph/algorithms/digraph_as_dot.cc | 32 ++ .../digraph/algorithms/digraph_has_edge.cc | 13 + .../algorithms/get_imm_dominators_map.cc | 6 +- .../algorithms/get_subgraph_outgoing_edges.cc | 16 + .../algorithms/get_subgraph_successors.cc | 16 + .../digraph/algorithms/transitive_closure.cc | 51 ++ .../algorithms/transitive_reduction.cc | 64 ++- .../graph/instances/adjacency_digraph.cc | 6 +- .../unordered_set_undirected_graph.cc | 58 ++ .../algorithms/get_edge_counts.cc | 4 +- .../algorithms/find_isomorphisms.cc | 6 +- .../intermediate_sp_decomposition_tree.cc | 48 -- .../binary_sp_decomposition_tree.cc | 43 ++ .../fmt.cc | 1 + .../generic_binary_sp_decomposition_tree.cc | 1 + .../get.cc | 1 + .../get_leaves.cc | 1 + .../get_left_child.cc | 1 + .../get_node_type.cc | 1 + .../get_num_tree_nodes.cc | 1 + .../get_right_child.cc | 1 + .../hash.cc | 1 + .../is.cc | 1 + .../is_binary_sp_tree_left_associative.cc | 1 + .../is_binary_sp_tree_right_associative.cc | 1 + .../json.cc | 1 + .../make.cc | 1 + .../require.cc | 1 + .../transform.cc | 1 + .../visit.cc | 1 + ...ft_associative_binary_sp_tree_from_nary.cc | 75 +++ .../nary_sp_tree_from_binary.cc | 12 + ...ht_associative_binary_sp_tree_from_nary.cc | 72 +++ .../get_series_parallel_decomposition.cc} | 48 +- .../graph_generation.cc | 12 +- .../intermediate_sp_decomposition_tree.cc | 84 +++ .../parallel_reduction.cc | 2 +- .../series_parallel_decomposition.cc} | 45 +- .../series_parallel_splits.cc} | 30 +- .../series_reduction.cc | 2 +- .../graph/undirected/algorithms/get_edges.cc | 10 + .../algorithms/get_neighboring_nodes.cc | 19 + .../graph/undirected/undirected_edge_query.cc | 4 + lib/utils/src/utils/hash/multiset.cc | 1 + .../src/utils/hash/unordered_multiset.cc | 1 + lib/utils/src/utils/json/check_is_jsonable.cc | 1 + .../src/utils/json/is_json_deserializable.cc | 1 + .../src/utils/json/is_json_serializable.cc | 1 + lib/utils/src/utils/json/is_jsonable.cc | 1 + lib/utils/src/utils/json/optional.cc | 1 + lib/utils/src/utils/rapidcheck/optional.cc | 1 + .../test/common/include/test/utils/all.h | 2 - .../check_without_stringify.h} | 0 .../include/test/utils/doctest/fmt/expected.h | 18 + .../include/test/utils/doctest/fmt/map.h | 18 + .../include/test/utils/doctest/fmt/multiset.h | 18 + .../include/test/utils/doctest/fmt/optional.h | 18 + .../include/test/utils/doctest/fmt/pair.h | 18 + .../include/test/utils/doctest/fmt/set.h | 18 + .../test/utils/doctest/fmt/unordered_map.h | 18 + .../utils/doctest/fmt/unordered_multiset.h | 18 + .../test/utils/doctest/fmt/unordered_set.h | 18 + .../include/test/utils/doctest/fmt/variant.h | 18 + .../include/test/utils/doctest/fmt/vector.h | 18 + lib/utils/test/common/src/common.cc | 1 - .../src/test/utils/doctest/fmt/expected.cc | 1 + .../common/src/test/utils/doctest/fmt/map.cc | 1 + .../src/test/utils/doctest/fmt/multiset.cc | 1 + .../src/test/utils/doctest/fmt/optional.cc | 1 + .../common/src/test/utils/doctest/fmt/pair.cc | 1 + .../common/src/test/utils/doctest/fmt/set.cc | 1 + .../test/utils/doctest/fmt/unordered_map.cc | 1 + .../utils/doctest/fmt/unordered_multiset.cc | 1 + .../test/utils/doctest/fmt/unordered_set.cc | 1 + .../src/test/utils/doctest/fmt/variant.cc | 1 + .../src/test/utils/doctest/fmt/vector.cc | 1 + lib/utils/test/src/test_algorithms.cc | 2 +- lib/utils/test/src/test_containers.cc | 6 +- .../src/test_deduplicated_priority_queue.cc | 2 +- lib/utils/test/src/test_disjoint_set.cc | 2 +- lib/utils/test/src/test_dot_file.cc | 2 +- lib/utils/test/src/test_format.cc | 2 +- lib/utils/test/src/test_hash.cc | 2 +- lib/utils/test/src/test_multidigraph.cc | 2 +- lib/utils/test/src/test_random_utils.cc | 2 +- lib/utils/test/src/test_sequence.cc | 2 +- lib/utils/test/src/test_stack_map.cc | 2 +- lib/utils/test/src/test_stack_string.cc | 2 +- lib/utils/test/src/test_stack_vector.cc | 2 +- lib/utils/test/src/test_tuple.cc | 2 +- lib/utils/test/src/test_type_index.cc | 2 +- lib/utils/test/src/test_undirected_graph.cc | 3 +- lib/utils/test/src/test_variant.cc | 2 +- lib/utils/test/src/test_vector.cc | 2 +- .../algorithms/bidict_from_enumerating.cc | 2 +- lib/utils/test/src/utils/bidict/bidict.cc | 6 +- .../bidict/try_merge_nondisjoint_bidicts.cc | 4 +- .../src/utils/cli/cli_get_help_message.cc | 519 ++++++++++++++++++ lib/utils/test/src/utils/cli/cli_parse.cc | 477 ++++++++++++++++ .../test/src/utils/containers/contains_key.cc | 5 +- .../test/src/utils/containers/enumerate.cc | 26 +- lib/utils/test/src/utils/containers/extend.cc | 4 +- lib/utils/test/src/utils/containers/filter.cc | 12 +- .../src/utils/containers/filtermap_keys.cc | 6 +- .../src/utils/containers/filtermap_values.cc | 6 +- .../test/src/utils/containers/filtrans.cc | 6 +- lib/utils/test/src/utils/containers/foldl1.cc | 27 + lib/utils/test/src/utils/containers/foldr1.cc | 27 + .../utils/containers/get_all_permutations.cc | 5 +- .../utils/containers/get_element_counts.cc | 2 +- .../src/utils/containers/inplace_filter.cc | 13 +- .../test/src/utils/containers/intersection.cc | 4 +- .../test/src/utils/containers/maximum.cc | 60 ++ .../src/utils/containers/multiset_union.cc | 29 + lib/utils/test/src/utils/containers/repeat.cc | 2 +- .../utils/containers/require_no_duplicates.cc | 62 +++ .../test/src/utils/containers/reversed.cc | 27 + .../test/src/utils/containers/to_uppercase.cc | 15 + .../test/src/utils/containers/transform.cc | 6 +- .../try_merge_nondisjoint_unordered_maps.cc | 6 +- .../utils/containers/unordered_multiset_of.cc | 2 +- .../src/utils/containers/unordered_set_of.cc | 2 +- .../test/src/utils/containers/vector_of.cc | 17 + .../src/utils/containers/without_order.cc | 2 +- lib/utils/test/src/utils/expected.cc | 4 +- lib/utils/test/src/utils/fmt/expected.cc | 23 +- lib/utils/test/src/utils/fmt/map.cc | 2 +- lib/utils/test/src/utils/fmt/optional.cc | 2 +- lib/utils/test/src/utils/fmt/pair.cc | 2 +- lib/utils/test/src/utils/fmt/set.cc | 2 +- lib/utils/test/src/utils/fmt/unordered_map.cc | 3 +- lib/utils/test/src/utils/fmt/unordered_set.cc | 4 +- lib/utils/test/src/utils/fmt/variant.cc | 2 +- lib/utils/test/src/utils/fmt/vector.cc | 2 +- lib/utils/test/src/utils/graph/cow_ptr_t.cc | 2 +- .../algorithms/get_subgraph_incoming_edges.cc | 43 ++ .../algorithms/get_subgraph_outgoing_edges.cc | 3 +- .../unordered_open_dataflow_graph.cc | 4 +- .../get_cbc_decomposition.cc | 45 ++ .../is_complete_bipartite_graph.cc | 175 ++++++ .../get_inverse_line_graph.cc | 23 + .../graph/digraph/algorithms/is_acyclic.cc | 1 + .../digraph/algorithms/transitive_closure.cc | 50 ++ .../algorithms/transitive_reduction.cc | 62 +++ .../fmt.cc | 51 ++ .../get_leaves.cc | 86 +++ .../get_left_child.cc | 41 ++ .../get_num_tree_nodes.cc | 85 +++ .../get_right_child.cc | 41 ++ .../hash.cc | 117 ++++ .../is_binary_sp_tree_left_associative.cc | 102 ++++ .../is_binary_sp_tree_right_associative.cc | 102 ++++ .../json.cc | 131 +++++ .../transform.cc | 28 + ...ft_associative_binary_sp_tree_from_nary.cc | 95 ++++ .../nary_sp_tree_from_binary.cc | 132 +++++ ...ht_associative_binary_sp_tree_from_nary.cc | 93 ++++ .../get_series_parallel_decomposition.cc} | 108 ++-- .../intermediate_sp_decomposition_tree.cc | 8 +- .../parallel_reduction.cc | 2 +- .../series_parallel_decomposition.cc} | 68 +-- .../series_reduction.cc | 2 +- lib/utils/test/src/utils/hash/multiset.cc | 34 ++ .../test/src/utils/hash/unordered_multiset.cc | 34 ++ lib/utils/test/src/utils/json/optional.cc | 49 ++ .../src/utils/{ => rapidcheck}/optional.cc | 7 +- 443 files changed, 8032 insertions(+), 1245 deletions(-) rename .github/workflows/helpers/{build_libs.sh => build_target.sh} (100%) rename .github/workflows/helpers/{test_libs.sh => test_target.sh} (100%) create mode 100644 bin/export-model-arch/CMakeLists.txt create mode 100644 bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml create mode 100644 bin/export-model-arch/src/export_model_arch.cc rename bin/{substitutions-to-dot => substitution-to-dot}/CMakeLists.txt (100%) rename bin/{substitutions-to-dot => substitution-to-dot}/substitution_to_dot.cc (89%) create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h create mode 100644 lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc create mode 100644 lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc create mode 100644 lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc create mode 100644 lib/models/include/models/split_test/split_test.h rename lib/models/include/models/{ => transformer}/transformer.h (90%) rename lib/models/include/models/{ => transformer}/transformer_config.struct.toml (100%) create mode 100644 lib/models/src/models/split_test/split_test.cc rename lib/models/src/models/{ => transformer}/transformer.cc (95%) create mode 100644 lib/op-attrs/include/op-attrs/datatype_value.variant.toml delete mode 100644 lib/op-attrs/test/src/op-attrs/tensor_shape.cc create mode 100644 lib/pcg/include/pcg/computation_graph/computation_graph_edge.h create mode 100644 lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml delete mode 100644 lib/pcg/include/pcg/file_format/file_format.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/graphs.h delete mode 100644 lib/pcg/include/pcg/file_format/v1/v1.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml delete mode 100644 lib/pcg/src/file_format.cc delete mode 100644 lib/pcg/src/file_format/v1/graphs.cc create mode 100644 lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc create mode 100644 lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc create mode 100644 lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc create mode 100644 lib/utils/include/utils/cli/cli_argument_key.variant.toml create mode 100644 lib/utils/include/utils/cli/cli_flag_key.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_flag_spec.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_get_help_message.h create mode 100644 lib/utils/include/utils/cli/cli_parse.h create mode 100644 lib/utils/include/utils/cli/cli_parse_result.h create mode 100644 lib/utils/include/utils/cli/cli_parse_result.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml create mode 100644 lib/utils/include/utils/cli/cli_spec.h create mode 100644 lib/utils/include/utils/cli/cli_spec.struct.toml create mode 100644 lib/utils/include/utils/containers/foldl1.h create mode 100644 lib/utils/include/utils/containers/foldr1.h create mode 100644 lib/utils/include/utils/containers/maximum.h create mode 100644 lib/utils/include/utils/containers/multiset_union.h create mode 100644 lib/utils/include/utils/containers/require_no_duplicates.h create mode 100644 lib/utils/include/utils/containers/set_of.h create mode 100644 lib/utils/include/utils/containers/to_uppercase.h rename lib/utils/include/utils/containers/{as_vector.h => vector_of.h} (54%) create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h create mode 100644 lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h delete mode 100644 lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h create mode 100644 lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/graph_generation.h (56%) create mode 100644 lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/intermediate_sp_decomposition_tree.struct.toml (90%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.h (70%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.struct.toml (100%) create mode 100644 lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h rename lib/utils/include/utils/graph/{serial_parallel/serial_parallel_decomposition.variant.toml => series_parallel/series_parallel_decomposition.variant.toml} (62%) rename lib/utils/include/utils/graph/{serial_parallel/serial_parallel_splits.h => series_parallel/series_parallel_splits.h} (59%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/series_reduction.h (77%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/series_reduction.struct.toml (100%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/sink_settings.enum.toml (100%) rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/source_settings.enum.toml (100%) create mode 100644 lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml rename lib/utils/include/utils/graph/{serial_parallel => series_parallel}/split_type.enum.toml (90%) create mode 100644 lib/utils/include/utils/graph/undirected/algorithms/get_edges.h create mode 100644 lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h create mode 100644 lib/utils/include/utils/hash/multiset.h create mode 100644 lib/utils/include/utils/hash/unordered_multiset.h create mode 100644 lib/utils/include/utils/json/check_is_jsonable.h create mode 100644 lib/utils/include/utils/json/is_json_deserializable.h create mode 100644 lib/utils/include/utils/json/is_json_serializable.h create mode 100644 lib/utils/include/utils/json/is_jsonable.h create mode 100644 lib/utils/include/utils/json/optional.h create mode 100644 lib/utils/include/utils/json/variant.h rename lib/utils/include/utils/{json.h => json/visitable.h} (52%) create mode 100644 lib/utils/include/utils/rapidcheck/optional.h create mode 100644 lib/utils/src/utils/cli/cli_get_help_message.cc create mode 100644 lib/utils/src/utils/cli/cli_parse.cc create mode 100644 lib/utils/src/utils/cli/cli_parse_result.cc create mode 100644 lib/utils/src/utils/cli/cli_spec.cc delete mode 100644 lib/utils/src/utils/containers/as_vector.cc create mode 100644 lib/utils/src/utils/containers/enumerate_vector.cc create mode 100644 lib/utils/src/utils/containers/foldl1.cc create mode 100644 lib/utils/src/utils/containers/foldr1.cc create mode 100644 lib/utils/src/utils/containers/maximum.cc create mode 100644 lib/utils/src/utils/containers/multiset_union.cc create mode 100644 lib/utils/src/utils/containers/require_no_duplicates.cc create mode 100644 lib/utils/src/utils/containers/set_of.cc create mode 100644 lib/utils/src/utils/containers/to_uppercase.cc create mode 100644 lib/utils/src/utils/containers/vector_of.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc create mode 100644 lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc delete mode 100644 lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc rename lib/utils/src/utils/graph/{serial_parallel/get_serial_parallel_decomposition.cc => series_parallel/get_series_parallel_decomposition.cc} (62%) rename lib/utils/src/utils/graph/{serial_parallel => series_parallel}/graph_generation.cc (79%) create mode 100644 lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc rename lib/utils/src/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.cc (93%) rename lib/utils/src/utils/graph/{serial_parallel/serial_parallel_decomposition.cc => series_parallel/series_parallel_decomposition.cc} (52%) rename lib/utils/src/utils/graph/{serial_parallel/serial_parallel_splits.cc => series_parallel/series_parallel_splits.cc} (65%) rename lib/utils/src/utils/graph/{serial_parallel => series_parallel}/series_reduction.cc (97%) create mode 100644 lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc create mode 100644 lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc create mode 100644 lib/utils/src/utils/hash/multiset.cc create mode 100644 lib/utils/src/utils/hash/unordered_multiset.cc create mode 100644 lib/utils/src/utils/json/check_is_jsonable.cc create mode 100644 lib/utils/src/utils/json/is_json_deserializable.cc create mode 100644 lib/utils/src/utils/json/is_json_serializable.cc create mode 100644 lib/utils/src/utils/json/is_jsonable.cc create mode 100644 lib/utils/src/utils/json/optional.cc create mode 100644 lib/utils/src/utils/rapidcheck/optional.cc delete mode 100644 lib/utils/test/common/include/test/utils/all.h rename lib/utils/test/common/include/test/utils/{doctest.h => doctest/check_without_stringify.h} (100%) create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/expected.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/map.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/optional.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/pair.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/set.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/variant.h create mode 100644 lib/utils/test/common/include/test/utils/doctest/fmt/vector.h delete mode 100644 lib/utils/test/common/src/common.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/map.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/set.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc create mode 100644 lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc create mode 100644 lib/utils/test/src/utils/cli/cli_get_help_message.cc create mode 100644 lib/utils/test/src/utils/cli/cli_parse.cc create mode 100644 lib/utils/test/src/utils/containers/foldl1.cc create mode 100644 lib/utils/test/src/utils/containers/foldr1.cc create mode 100644 lib/utils/test/src/utils/containers/maximum.cc create mode 100644 lib/utils/test/src/utils/containers/multiset_union.cc create mode 100644 lib/utils/test/src/utils/containers/require_no_duplicates.cc create mode 100644 lib/utils/test/src/utils/containers/reversed.cc create mode 100644 lib/utils/test/src/utils/containers/to_uppercase.cc create mode 100644 lib/utils/test/src/utils/containers/vector_of.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc create mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc rename lib/utils/test/src/utils/graph/{serial_parallel/get_serial_parallel_decomposition.cc => series_parallel/get_series_parallel_decomposition.cc} (50%) rename lib/utils/test/src/utils/graph/{serial_parallel => series_parallel}/intermediate_sp_decomposition_tree.cc (83%) rename lib/utils/test/src/utils/graph/{serial_parallel => series_parallel}/parallel_reduction.cc (99%) rename lib/utils/test/src/utils/graph/{serial_parallel/serial_parallel_decomposition.cc => series_parallel/series_parallel_decomposition.cc} (66%) rename lib/utils/test/src/utils/graph/{serial_parallel => series_parallel}/series_reduction.cc (99%) create mode 100644 lib/utils/test/src/utils/hash/multiset.cc create mode 100644 lib/utils/test/src/utils/hash/unordered_multiset.cc create mode 100644 lib/utils/test/src/utils/json/optional.cc rename lib/utils/test/src/utils/{ => rapidcheck}/optional.cc (67%) diff --git a/.github/workflows/helpers/build_libs.sh b/.github/workflows/helpers/build_target.sh similarity index 100% rename from .github/workflows/helpers/build_libs.sh rename to .github/workflows/helpers/build_target.sh diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_target.sh similarity index 100% rename from .github/workflows/helpers/test_libs.sh rename to .github/workflows/helpers/test_target.sh diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index 639f4d82b5..a5ac6fd29f 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -62,71 +62,79 @@ jobs: - name: Build utils run: | - build_libs.sh utils + build_target.sh utils - name: Build op-attrs run: | - build_libs.sh op-attrs + build_target.sh op-attrs - name: Build pcg run: | - build_libs.sh pcg + build_target.sh pcg - name: Build kernels run: | - build_libs.sh kernels + build_target.sh kernels - name: Build substitutions run: | - build_libs.sh substitutions + build_target.sh substitutions - name: Build compiler run: | - build_libs.sh compiler + build_target.sh compiler - name: Build substitution-generator run: | - build_libs.sh substitution-generator + build_target.sh substitution-generator - name: Build local-execution run: | - build_libs.sh local-execution + build_target.sh local-execution - name: Build models run: | - build_libs.sh models + build_target.sh models + + - name: Build substitution-to-dot + run: | + build_target.sh substitution-to-dot + + - name: Build export-model-arch + run: | + build_target.sh export-model-arch - name: Test utils run: | - test_libs.sh utils + test_target.sh utils - name: Test op-attrs run: | - test_libs.sh op-attrs + test_target.sh op-attrs - name: Test pcg run: | - test_libs.sh pcg + test_target.sh pcg - name: Test substitutions run: | - test_libs.sh substitutions + test_target.sh substitutions # - name: Test compiler # run: | - # test_libs.sh compiler + # test_target.sh compiler - name: Test substitution-generator run: | - test_libs.sh substitution-generator + test_target.sh substitution-generator - name: Test local-execution run: | - test_libs.sh local-execution + test_target.sh local-execution - name: Test models run: | - test_libs.sh models + test_target.sh models - name: Generate code coverage run: | diff --git a/.proj.toml b/.proj.toml index 721d212e31..5592f184ad 100644 --- a/.proj.toml +++ b/.proj.toml @@ -13,6 +13,8 @@ build_targets = [ "substitution-generator", "local-execution", "models", + "export-model-arch", + "substitution-to-dot", ] test_targets = [ diff --git a/CMakeLists.txt b/CMakeLists.txt index a518931ac5..792126449b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,6 +74,7 @@ option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" ON) option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) +option(FF_BUILD_BIN_EXPORT_MODEL_ARCH "build export-model-arch utility" ON) set(FF_CUDA_ARCH "autodetect" CACHE STRING "Target CUDA Arch") if (FF_CUDA_ARCH STREQUAL "") diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index fcc19b33b9..1cd7068cfd 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -7,9 +7,13 @@ if(FF_BUILD_SUBSTITUTION_TOOL) endif() if(FF_BUILD_VISUALIZATION_TOOL) - add_subdirectory(substitutions-to-dot) + add_subdirectory(substitution-to-dot) endif() if(FF_BUILD_ARG_PARSER) add_subdirectory(arg_parser) endif() + +if(FF_BUILD_BIN_EXPORT_MODEL_ARCH) + add_subdirectory(export-model-arch) +endif() diff --git a/bin/export-model-arch/CMakeLists.txt b/bin/export-model-arch/CMakeLists.txt new file mode 100644 index 0000000000..b931668594 --- /dev/null +++ b/bin/export-model-arch/CMakeLists.txt @@ -0,0 +1,12 @@ +ff_add_executable( + NAME + export-model-arch + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + include/ + DEPS + utils + models + compiler +) diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml new file mode 100644 index 0000000000..efaf368bc8 --- /dev/null +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "JsonSPModelExport" +features = [ + "eq", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_computation_graph.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h", +] + +[[fields]] +name = "sp_decomposition" +type = "::FlexFlow::GenericBinarySPDecompositionTree" + +[[fields]] +name = "computation_graph" +type = "::FlexFlow::V1ComputationGraph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc new file mode 100644 index 0000000000..ccc720ed14 --- /dev/null +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -0,0 +1,208 @@ +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "utils/cli/cli_get_help_message.h" +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_parse_result.h" +#include "utils/cli/cli_spec.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" + +using namespace ::FlexFlow; + +ComputationGraph get_single_operator_computation_graph() { + ComputationGraphBuilder b; + + size_t batch_size = 8; + size_t in_channels = 16; + size_t out_channels = 12; + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + batch_size, + in_channels, + out_channels, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + InitializerAttrs kernel_initializer = + InitializerAttrs{GlorotUniformAttrs{/*seed=*/12}}; + InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}}; + tensor_guid_t output = b.dense(input, + in_channels, + Activation::RELU, + /*use_bias=*/true, + DataType::FLOAT, + kernel_initializer, + bias_initializer, + "my_example_operator"); + + return b.computation_graph; +} + +ComputationGraph get_default_transformer_computation_graph() { + TransformerConfig config = get_default_transformer_config(); + ComputationGraph cg = get_transformer_computation_graph(config); + + return cg; +} + +tl::expected + get_model_computation_graph(std::string const &model_name) { + if (model_name == "transformer") { + return get_default_transformer_computation_graph(); + } else if (model_name == "split_test") { + int batch_size = 8; + return get_split_test_computation_graph(batch_size); + } else if (model_name == "single_operator") { + return get_single_operator_computation_graph(); + } else { + return tl::unexpected(fmt::format("Unknown model name: {}", model_name)); + } +} + +tl::expected + get_sp_model_export(std::string const &model_name) { + ComputationGraph computation_graph = ({ + tl::expected result = + get_model_computation_graph(model_name); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); + }); + + ComputationGraphBinarySPDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_right_assoc_binary_sp_decomposition( + computation_graph); + if (!result.has_value()) { + return tl::unexpected("Failed to generate series-parallel decomposition " + "of computation graph."); + } + result.value(); + }); + + std::pair> v1_result = + to_v1_including_node_numbering(computation_graph); + V1ComputationGraph v1_cg = v1_result.first; + bidict layer_numbering = v1_result.second; + GenericBinarySPDecompositionTree v1_sp_decomposition = + transform(sp_decomposition.raw_tree, + [&](layer_guid_t const &l) { return layer_numbering.at_r(l); }); + + return JsonSPModelExport{ + v1_sp_decomposition, + v1_cg, + }; +} + +int main(int argc, char **argv) { + CLISpec cli = empty_cli_spec(); + + CLIArgumentKey arg_key_help = cli_add_help_flag(cli); + + CLIArgumentKey key_sp_decomposition = + cli_add_flag(cli, + CLIFlagSpec{"sp-decomposition", + std::nullopt, + "also output a series parallel decomposition of " + "the model's computation graph"}); + + CLIArgumentKey key_dot = cli_add_flag( + cli, + CLIFlagSpec{ + "dot", + std::nullopt, + "output a dot representation of the model's computation graph"}); + + CLIArgumentKey key_preprocessed_dot = cli_add_flag( + cli, + CLIFlagSpec{"preprocessed-dot", + std::nullopt, + "output a dot representation of model's computation graph " + "for preprocessed to help check series-parallel structure"}); + + std::vector model_options = { + "transformer", "split_test", "single_operator"}; + CLIArgumentKey key_model_name = cli_add_positional_argument( + cli, + CLIPositionalArgumentSpec{ + "model", model_options, "name of the model to export"}); + + assert(argc >= 1); + std::string prog_name = argv[0]; + + CLIParseResult parsed = ({ + tl::expected result = + cli_parse(cli, argc, argv); + if (!result.has_value()) { + std::string error_msg = result.error(); + std::cerr << cli_get_help_message(prog_name, cli); + std::cerr << std::endl; + std::cerr << "error: " << error_msg << std::endl; + return 1; + } + + result.value(); + }); + + bool help = cli_get_flag(parsed, arg_key_help); + if (help) { + std::cerr << cli_get_help_message(prog_name, cli); + return 1; + } + + std::string model_name = cli_get_argument(parsed, key_model_name); + bool sp_decompositition = cli_get_flag(parsed, key_sp_decomposition); + bool dot = cli_get_flag(parsed, key_dot); + bool preprocessed_dot = cli_get_flag(parsed, key_preprocessed_dot); + + auto handle_error = [](auto const &result) { + if (!result.has_value()) { + std::cerr << "error: " << result.error() << std::endl; + exit(1); + } + + return result.value(); + }; + + if (dot) { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + + std::cout << as_dot(cg) << std::endl; + return 0; + } + + if (preprocessed_dot) { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + std::string rendered = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + + std::cout << rendered << std::endl; + return 0; + } + + nlohmann::json json_output; + if (sp_decompositition) { + JsonSPModelExport model_export = + handle_error(get_sp_model_export(model_name)); + + json_output = model_export; + } else { + ComputationGraph cg = handle_error(get_model_computation_graph(model_name)); + + json_output = to_v1(cg); + } + std::cout << json_output.dump(2) << std::endl; + + return 0; +} diff --git a/bin/substitutions-to-dot/CMakeLists.txt b/bin/substitution-to-dot/CMakeLists.txt similarity index 100% rename from bin/substitutions-to-dot/CMakeLists.txt rename to bin/substitution-to-dot/CMakeLists.txt diff --git a/bin/substitutions-to-dot/substitution_to_dot.cc b/bin/substitution-to-dot/substitution_to_dot.cc similarity index 89% rename from bin/substitutions-to-dot/substitution_to_dot.cc rename to bin/substitution-to-dot/substitution_to_dot.cc index 49a199ddd3..1b5f715bcd 100644 --- a/bin/substitutions-to-dot/substitution_to_dot.cc +++ b/bin/substitution-to-dot/substitution_to_dot.cc @@ -1,4 +1,4 @@ -#include "substitution-generator/json.h" +#include "substitution-generator/legacy_rules.h" #include "utils/dot_file.h" #include #include @@ -24,10 +24,11 @@ int main(int argc, char **argv) { std::string json_path(argv[1]); std::string rule_name(argv[2]); - RuleCollection rule_collection = load_rule_collection_from_path(json_path); + LegacyRuleCollection rule_collection = + load_rule_collection_from_path(json_path); - std::optional found = std::nullopt; - for (Rule const &r : rule_collection.rules) { + std::optional found = std::nullopt; + for (LegacyRule const &r : rule_collection.rules) { if (r.name == rule_name) { found = r; break; @@ -39,7 +40,7 @@ int main(int argc, char **argv) { return 1; } - Rule r = found.value(); + LegacyRule r = found.value(); using Node = std::tuple; @@ -82,14 +83,14 @@ int main(int argc, char **argv) { }; for (int i = 0; i < r.srcOp.size(); i++) { - Operator const &o = r.srcOp[i]; + LegacyOperator const &o = r.srcOp[i]; Node srcOpNode = {NodeType::SRC, i, 0}; { dot.add_node(srcOpNode, label_map(fmt::to_string(o.op_type), srcOpNode)); dot.add_node_to_subgraph(srcOpNode, src_body_subgraph); } - for (Tensor const &t : o.input) { + for (LegacyTensor const &t : o.input) { if (t.opId < 0) { assert(t.tsId == 0); Node inputOpNode = {NodeType::SRC_INPUT_TENSOR, t.opId, 0}; @@ -106,14 +107,14 @@ int main(int argc, char **argv) { } } for (int j = 0; j < r.dstOp.size(); j++) { - Operator const &o = r.dstOp[j]; + LegacyOperator const &o = r.dstOp[j]; Node dstOpNode = {NodeType::DST, j, 0}; { dot.add_node(dstOpNode, label_map(fmt::to_string(o.op_type), dstOpNode)); dot.add_node_to_subgraph(dstOpNode, dst_body_subgraph); } - for (Tensor const &t : o.input) { + for (LegacyTensor const &t : o.input) { if (t.opId < 0) { assert(t.tsId == 0); Node inputOpNode = {NodeType::DST_INPUT_TENSOR, t.opId, 0}; @@ -128,7 +129,7 @@ int main(int argc, char **argv) { } } } - for (MapOutput const &mo : r.mappedOutput) { + for (LegacyMapOutput const &mo : r.mappedOutput) { Node srcOutputNode = {NodeType::SRC_OUTPUT_TENSOR, mo.srcOpId, mo.srcTsId}; Node dstOutputNode = {NodeType::DST_OUTPUT_TENSOR, mo.dstOpId, mo.dstTsId}; { diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 1dbd16bdb1..90e100bb1b 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -149,6 +149,11 @@ function(ff_add_executable) ${FF_EXEC_NAME} ${SRC}) + target_include_directories( + ${FF_EXEC_NAME} + PRIVATE + ${FF_EXEC_PRIVATE_INCLUDE}) + target_link_libraries( ${FF_EXEC_NAME} ${FF_EXEC_DEPS}) diff --git a/lib/compiler/include/compiler/graph_utils.h b/lib/compiler/include/compiler/graph_utils.h index 1370357837..75fd369434 100644 --- a/lib/compiler/include/compiler/graph_utils.h +++ b/lib/compiler/include/compiler/graph_utils.h @@ -5,12 +5,12 @@ #include "pcg/computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg); +SeriesParallelDecomposition + get_series_parallel_decomposition(ParallelComputationGraph const &pcg); ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 5d17cbb373..3774f2cd52 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -9,7 +9,8 @@ #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/visitable.h" namespace FlexFlow { diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml index 50496f661b..036647c0b1 100644 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ b/lib/compiler/include/compiler/optimal_cost_state.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h", + "utils/graph/series_parallel/series_parallel_decomposition.dtg.h", "pcg/machine_specification.dtg.h", "pcg/machine_view.dtg.h", "utils/graph/node/node.dtg.h", @@ -21,7 +21,7 @@ includes = [ [[fields]] name = "subgraph" -type = "::FlexFlow::SerialParallelDecomposition" +type = "::FlexFlow::SeriesParallelDecomposition" [[fields]] name = "resource" @@ -33,4 +33,4 @@ type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" [[fields]] name = "frontier_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h new file mode 100644 index 0000000000..3032e3efe9 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.dtg.h" +#include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &); +ComputationGraphBinarySPDecomposition + get_left_child(ComputationGraphBinarySPDecomposition const &); +ComputationGraphBinarySPDecomposition + get_right_child(ComputationGraphBinarySPDecomposition const &); +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &); +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &); +bool is_left_associative(ComputationGraphBinarySPDecomposition const &); +bool is_right_associative(ComputationGraphBinarySPDecomposition const &); +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml new file mode 100644 index 0000000000..147b1e3acf --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySPDecomposition" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h new file mode 100644 index 0000000000..e85843ed26 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_COMPUTATION_GRAPH_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &); +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc new file mode 100644 index 0000000000..63054385ac --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc @@ -0,0 +1,90 @@ +#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" + +namespace FlexFlow { + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &d) { + return get_node_type(d.raw_tree); +} + +ComputationGraphBinarySPDecomposition + get_left_child(ComputationGraphBinarySPDecomposition const &d) { + return ComputationGraphBinarySPDecomposition{ + get_left_child(d.raw_tree), + }; +} + +ComputationGraphBinarySPDecomposition + get_right_child(ComputationGraphBinarySPDecomposition const &d) { + return ComputationGraphBinarySPDecomposition{ + get_right_child(d.raw_tree), + }; +} + +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { + return require_node(d.raw_tree); +} + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + left_associative_binary_sp_tree_from_nary(sp_decomposition); + + return ComputationGraphBinarySPDecomposition{transform( + raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; +} + +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + right_associative_binary_sp_tree_from_nary(sp_decomposition); + + return ComputationGraphBinarySPDecomposition{transform( + raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; +} + +bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { + return is_binary_sp_tree_left_associative(d.raw_tree); +} + +bool is_right_associative(ComputationGraphBinarySPDecomposition const &d) { + return is_binary_sp_tree_right_associative(d.raw_tree); +} + +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &d) { + return get_leaves(d.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..184ad93f4d --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,98 @@ +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "op-attrs/computation_graph_op_attrs.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph/computation_graph_edge.h" +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/record_formatter.h" + +namespace FlexFlow { + +std::string render_preprocessed_computation_graph_for_sp_decomposition( + ComputationGraph const &cg) { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), [&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + // dot has is incapable of rendering the number of edges in the all-to-all + // connection, so for visualization purposes we instead insert a "fake" node + // to reduce the n^2 edges to 2*n edges + DiGraph preprocessed_digraph = + materialize_digraph_view(cg.raw_graph); + Node fake_node = preprocessed_digraph.add_node(); + for (layer_guid_t const &src : weight_and_input_layers) { + preprocessed_digraph.add_edge(DirectedEdge{src.raw_node, fake_node}); + } + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + preprocessed_digraph.add_edge(DirectedEdge{fake_node, dst.raw_node}); + } + + std::function get_node_label = + [&](Node const &n) -> std::string { + if (n == fake_node) { + return "FAKE"; + } + LayerAttrs a = cg.raw_graph.at(n); + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + std::string preprocessed_dot = digraph_as_dot( + transitive_reduction(preprocessed_digraph), get_node_label); + + return preprocessed_dot; +} + +std::optional + get_computation_graph_series_parallel_decomposition( + ComputationGraph const &cg) { + + { + DiGraphView unpreprocessed_digraph = cg.raw_graph; + std::optional unpreprocessed_sp_decomposition = + get_series_parallel_decomposition(unpreprocessed_digraph); + if (unpreprocessed_sp_decomposition.has_value()) { + return unpreprocessed_sp_decomposition.value(); + } + } + + DiGraphView preprocessed_digraph = [&] { + std::unordered_set weight_and_input_layers = + filter(get_layers(cg), [&](layer_guid_t const &l) { + ComputationGraphOpAttrs op_attrs = get_layer_attrs(cg, l).attrs; + return op_attrs.has() || op_attrs.has(); + }); + + std::unordered_set weight_and_input_layer_successors = + get_subgraph_successors(cg, weight_and_input_layers); + + DiGraph digraph = materialize_digraph_view(cg.raw_graph); + for (layer_guid_t const &src : weight_and_input_layers) { + for (layer_guid_t const &dst : weight_and_input_layer_successors) { + digraph.add_edge(DirectedEdge{src.raw_node, dst.raw_node}); + } + } + + return digraph; + }(); + + return get_series_parallel_decomposition(preprocessed_digraph); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 08db219a21..a19c5e8597 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -4,13 +4,13 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "utils/containers/without_order.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { -SerialParallelDecomposition - get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { +SeriesParallelDecomposition + get_series_parallel_decomposition(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); - // return get_serial_parallel_decomposition(pcg.raw_graph); + // return get_series_parallel_decomposition(pcg.raw_graph); } ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { @@ -126,11 +126,11 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { // } // }; -// std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { +// std::unordered_set get_nodes(SeriesParallelDecomposition const &sp) { // return std::visit(GetNodes{}, sp.raw_variant); // } -// std::unordered_set get_nodes(SerialSplit const &serial) { +// std::unordered_set get_nodes(SeriesSplit const &serial) { // return set_union( // transform(serial.children, [](std::variant const // child) { @@ -140,7 +140,7 @@ SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { // std::unordered_set get_nodes(ParallelSplit const ¶llel) { // return set_union( -// transform(parallel.children, [](std::variant const +// transform(parallel.children, [](std::variant const // child) { // return std::visit(GetNodes{}, child); // })); diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index af7756c635..fddd825109 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -8,18 +8,19 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers.h" #include "utils/containers/are_disjoint.h" -#include "utils/containers/as_vector.h" #include "utils/containers/contains_key.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps.h" +#include "utils/containers/require_no_duplicates.h" +#include "utils/containers/vector_of.h" #include "utils/exception.h" #include "utils/graph/graph_split.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/serial_parallel_splits.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" namespace FlexFlow { @@ -83,39 +84,43 @@ std::vector> } // We may replace this by having unflattened AST -std::pair - decompose(SerialSplit const &serial) { +std::pair + decompose(SeriesSplit const &serial) { if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; + return {widen(serial.children[0]), + widen(serial.children[1])}; } - SerialSplit decompn1 = serial; + SeriesSplit decompn1 = serial; decompn1.children.pop_back(); - return {SerialParallelDecomposition(decompn1), - widen(serial.children.back())}; + return {SeriesParallelDecomposition(decompn1), + widen(serial.children.back())}; } -std::pair +std::pair decompose(ParallelSplit const ¶llel) { if (parallel.children.size() == 2) { - std::vector children = - transform(as_vector(parallel.children), [&](auto const &child) { - return widen(child); + std::vector children = + transform(vector_of(parallel.children), [&](auto const &child) { + return widen(child); }); return {children[0], children[1]}; } ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); + std::variant child = *parallel.children.begin(); decompn1.children.erase(child); - return {SerialParallelDecomposition(decompn1), - widen(child)}; + return {SeriesParallelDecomposition(decompn1), + widen(child)}; } GraphSplit - get_graph_split(SerialParallelDecomposition const &pre_decomposition, - SerialParallelDecomposition const &post_decomposition) { - return GraphSplit{get_nodes(pre_decomposition), - get_nodes(post_decomposition)}; + get_graph_split(SeriesParallelDecomposition const &pre_decomposition, + SeriesParallelDecomposition const &post_decomposition) { + std::unordered_set pre_nodes = + require_no_duplicates(get_nodes(pre_decomposition)); + std::unordered_set post_nodes = + require_no_duplicates(get_nodes(post_decomposition)); + assert(are_disjoint(pre_nodes, post_nodes)); + return GraphSplit{pre_nodes, post_nodes}; } float estimate_cost(SubParallelComputationGraph const &g, @@ -181,7 +186,7 @@ struct MachineMappingSearcher { template OptimalCostResult operator()(T const &t) { - OptimalCostState state{SerialParallelDecomposition{t}, + OptimalCostState state{SeriesParallelDecomposition{t}, resource, given_machine_views, frontier_machine_views}; @@ -202,13 +207,13 @@ struct MachineMappingSearcher { OptimalCostResult optimal_cost(SubParallelComputationGraph const &g, MachineSpecification resource, - SerialParallelDecomposition const &sp_decomposition) { + SeriesParallelDecomposition const &sp_decomposition) { return std::visit(OptimalCostFunctor(this, g, resource, {}, {}), sp_decomposition.raw_variant); } OptimalCostResult optimal_cost( - SerialSplit const &serial, + SeriesSplit const &serial, SubParallelComputationGraph const &g, MachineSpecification const &resource, std::unordered_map const &given_machine_views, @@ -218,8 +223,8 @@ struct MachineMappingSearcher { // OptimalCostResult optimal_result = OptimalCostResult::infinity(); // auto decomposed = decompose(serial); - // SerialParallelDecomposition pre_decompn = decomposed.first; - // SerialParallelDecomposition post_decompn = decomposed.second; + // SeriesParallelDecomposition pre_decompn = decomposed.first; + // SeriesParallelDecomposition post_decompn = decomposed.second; // GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); // SubParallelComputationGraph pre_graph = @@ -273,8 +278,8 @@ struct MachineMappingSearcher { NOT_IMPLEMENTED(); // auto decomposed = decompose(parallel); - // SerialParallelDecomposition decompn1 = decomposed.first; - // SerialParallelDecomposition decompn2 = decomposed.second; + // SeriesParallelDecomposition decompn1 = decomposed.first; + // SeriesParallelDecomposition decompn2 = decomposed.second; // GraphSplit graph_split = get_graph_split(decompn1, decompn2); // SubParallelComputationGraph g1 = get_subgraph(g, graph_split.first), @@ -350,8 +355,8 @@ OptimalCostResult optimal_cost( CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - SerialParallelDecomposition sp_decomposition = - get_serial_parallel_decomposition(g); + SeriesParallelDecomposition sp_decomposition = + get_series_parallel_decomposition(g); SubParallelComputationGraph subpcg = pcg_to_subpcg(g); MachineMappingSearcher searcher( cost_estimator, allowed_machine_views, cached_subgraph_costs); diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index 13b1fd3b83..3399a45f0f 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -10,4 +10,5 @@ ff_add_test_executable( compiler doctest utils-test-common + models ) diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc new file mode 100644 index 0000000000..ab537e73de --- /dev/null +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -0,0 +1,340 @@ +#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/split_test/split_test.h" +#include "models/transformer/transformer.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_computation_graph_series_parallel_decomposition(ComputationGraph)") { + SUBCASE("empty computation graph") { + ComputationGraph cg = make_empty_computation_graph(); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + // technically an empty graph is non-SP + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("just a single input") { + std::string input_layer_name = "my input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{input_layer.raw_node}; + + CHECK(result == correct); + } + + SUBCASE("single operator plus inputs and weights") { + std::string input_layer_name = "my input"; + std::string projection_weights_layer_name = "my projection weights"; + std::string bias_weights_layer_name = "my bias weights"; + std::string operator_name = "my operator"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_layer_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/operator_name, + /*projection_name=*/projection_weights_layer_name, + /*bias_name=*/bias_weights_layer_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_layer_name); + layer_guid_t projection_weights_layer = + get_layer_by_name(cg, projection_weights_layer_name); + layer_guid_t bias_weights_layer = + get_layer_by_name(cg, bias_weights_layer_name); + layer_guid_t operator_layer = get_layer_by_name(cg, operator_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ + ParallelSplit{ + input_layer.raw_node, + projection_weights_layer.raw_node, + bias_weights_layer.raw_node, + }, + operator_layer.raw_node, + }}; + + CHECK(result == correct); + } + + SUBCASE("SP without weight nodes but non-SP with weight nodes") { + // A minimal computation graph where without weights (w1 and w2) the + // computation graph is series-parallel, but with weight nodes it is not + // + // w1 input w2 + // \ / \ / + // op1 op2 + + std::string w1_name = "w1"; + std::string input_name = "input"; + std::string w2_name = "w2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op1_name, + /*projection_name=*/w1_name); + b.dense(input, + /*outDim=*/14, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/op2_name, + /*projection_name=*/w2_name); + + return b.computation_graph; + }(); + + layer_guid_t w1 = get_layer_by_name(cg, w1_name); + layer_guid_t input = get_layer_by_name(cg, input_name); + layer_guid_t w2 = get_layer_by_name(cg, w2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ + ParallelSplit{ + w1.raw_node, + input.raw_node, + w2.raw_node, + }, + ParallelSplit{ + op1.raw_node, + op2.raw_node, + }, + }}; + } + + SUBCASE("SP with or without preprocessing, but preprocessing would SP " + "decomposition") { + // computation graph: + // + // input1 input2 + // | | + // op1 op2 + + std::string input1_name = "input1"; + std::string input2_name = "input2"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + tensor_guid_t input2 = + b.create_input(input_shape, CreateGrad::YES, input2_name); + + b.relu(input1, op1_name); + b.relu(input2, op2_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t input2 = get_layer_by_name(cg, input2_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{ + SeriesSplit{ + input1.raw_node, + op1.raw_node, + }, + SeriesSplit{ + input2.raw_node, + op2.raw_node, + }, + }}; + } + + SUBCASE("not SP with or without weight nodes") { + // computation graph: + // + // input1 + // / \ + // op1 op2 + // | \ | + // | \ | + // op3 op4 + + std::string input1_name = "input1"; + std::string op1_name = "op1"; + std::string op2_name = "op2"; + std::string op3_name = "op3"; + std::string op4_name = "op4"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + tensor_guid_t input1 = + b.create_input(input_shape, CreateGrad::YES, input1_name); + + tensor_guid_t op1_output = b.relu(input1, op1_name); + tensor_guid_t op2_output = b.relu(input1, op2_name); + b.relu(op1_output, op3_name); + b.add(op1_output, op2_output, op4_name); + + return b.computation_graph; + }(); + + layer_guid_t input1 = get_layer_by_name(cg, input1_name); + layer_guid_t op1 = get_layer_by_name(cg, op1_name); + layer_guid_t op2 = get_layer_by_name(cg, op2_name); + layer_guid_t op3 = get_layer_by_name(cg, op3_name); + layer_guid_t op4 = get_layer_by_name(cg, op4_name); + + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + std::optional correct = std::nullopt; + } + + SUBCASE("real models") { + SUBCASE("split_test") { + ComputationGraph cg = + get_split_test_computation_graph(/*batch_size=*/8); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } + } + } + + TEST_CASE("render_preprocessed_computation_graph_for_sp_decomposition(" + "ComputationGraph)") { + // currently there's not really a good way to test this, and its arguable + // how much its output really should be validated as its primarily for + // visualization and so there's not really a strict definition of + // correctness, so for now we just run it on some models and make sure it + // doesn't crash. Don't use this as an example. + + SUBCASE("basic single-operator model") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT}; + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + + b.dense(input, /*outDim=*/14); + + return b.computation_graph; + }(); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("split_test") { + ComputationGraph cg = get_split_test_computation_graph(/*batch_size=*/8); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("transformer") { + ComputationGraph cg = + get_transformer_computation_graph(get_default_transformer_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + } +} diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h index d6b8222968..9f5a768b27 100644 --- a/lib/compiler/test/src/test_generator.h +++ b/lib/compiler/test/src/test_generator.h @@ -46,7 +46,7 @@ using namespace FlexFlow; // namespace rc { // Gen serialParallelMultiDiGraph() { -// return gen::map(gen::arbitrary(), +// return gen::map(gen::arbitrary(), // multidigraph_from_sp_decomposition); // } @@ -113,12 +113,12 @@ using namespace FlexFlow; // }; // template <> -// struct Arbitrary { -// static Gen arbitrary() { +// struct Arbitrary { +// static Gen arbitrary() { // return gen::mapcat(gen::arbitrary(), [](bool is_serial) { -// return is_serial ? gen::construct( +// return is_serial ? gen::construct( // gen::arbitrary()) -// : gen::construct( +// : gen::construct( // gen::arbitrary()); // }); // } diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index d3221474c0..f523520f9f 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -16,7 +16,7 @@ class GenericTensorAccessorW { template typename data_type_enum_to_class
::type *get() const { if (this->data_type == DT) { - return static_cast *>(this->ptr); + return static_cast *>(this->ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", this->data_type, DT); @@ -47,7 +47,7 @@ class GenericTensorAccessorR { template typename data_type_enum_to_class
::type const *get() const { if (this->data_type == DT) { - return static_cast const *>(this->ptr); + return static_cast const *>(this->ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", this->data_type, DT); @@ -94,7 +94,7 @@ template typename data_type_enum_to_class
::type * get(GenericTensorAccessorW const &a) { if (a.data_type == DT) { - return static_cast *>(a.ptr); + return static_cast *>(a.ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", a.data_type, DT); @@ -102,9 +102,9 @@ typename data_type_enum_to_class
::type * } template -std::vector *> +std::vector *> get(std::vector const &accs) { - std::vector *> out; + std::vector *> out; for (auto acc : accs) { out.push_back(get
(acc)); } @@ -115,7 +115,7 @@ template typename data_type_enum_to_class
::type const * get(GenericTensorAccessorR const &a) { if (a.data_type == DT) { - return static_cast const *>(a.ptr); + return static_cast const *>(a.ptr); } else { throw mk_runtime_error( "Invalid access data type ({} != {})", a.data_type, DT); @@ -139,9 +139,9 @@ std::vector get_half_ptrs(std::vector const &); template -std::vector const *> +std::vector const *> get(std::vector const &accs) { - std::vector const *> out; + std::vector const *> out; for (auto acc : accs) { out.push_back(get
(acc)); } diff --git a/lib/kernels/include/kernels/array_shape.h b/lib/kernels/include/kernels/array_shape.h index 5de9fae7ad..96a3b3b281 100644 --- a/lib/kernels/include/kernels/array_shape.h +++ b/lib/kernels/include/kernels/array_shape.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_KERNELS_ARRAY_SHAPE_H #include "legion_dim.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" #include "utils/visitable.h" #include diff --git a/lib/kernels/include/kernels/attention_kernels.h b/lib/kernels/include/kernels/attention_kernels.h index 575de57f09..eb5a1b8198 100644 --- a/lib/kernels/include/kernels/attention_kernels.h +++ b/lib/kernels/include/kernels/attention_kernels.h @@ -5,7 +5,6 @@ #include "kernels/allocation.h" #include "kernels/device.h" #include "kernels/ff_handle.h" -#include "op-attrs/ops/attention.h" #include namespace FlexFlow { diff --git a/lib/kernels/include/kernels/batch_matmul_kernels.h b/lib/kernels/include/kernels/batch_matmul_kernels.h index 38be2118fa..bfd72647b0 100644 --- a/lib/kernels/include/kernels/batch_matmul_kernels.h +++ b/lib/kernels/include/kernels/batch_matmul_kernels.h @@ -4,7 +4,6 @@ #include "device.h" #include "kernels/allocation.h" #include "kernels/ff_handle.h" -#include "utils/visitable.h" namespace FlexFlow { namespace Kernels { diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 14bb9d2cd2..52609a303f 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -3,6 +3,7 @@ #include "accessor.h" #include "kernels/cpu.h" +#include "op-attrs/datatype_value.dtg.h" #include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/src/allocation.cc b/lib/kernels/src/allocation.cc index a892e14a54..ccd88580db 100644 --- a/lib/kernels/src/allocation.cc +++ b/lib/kernels/src/allocation.cc @@ -1,4 +1,5 @@ #include "kernels/allocation.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/kernels/src/cpu/initializer_kernels.cc b/lib/kernels/src/cpu/initializer_kernels.cc index f3b4c9b8fd..91f4f46ef8 100644 --- a/lib/kernels/src/cpu/initializer_kernels.cc +++ b/lib/kernels/src/cpu/initializer_kernels.cc @@ -24,7 +24,7 @@ struct ConstantInitKernel { void operator()(GenericTensorAccessorW const &tensor, DataTypeValue value) const { auto arr = get
(tensor); - auto unwrapped_value = get>(value); + auto unwrapped_value = value.get>(); for (size_t i = 0; i < get_volume(tensor.shape); i++) { arr[i] = unwrapped_value; } diff --git a/lib/kernels/src/cuda/embedding_kernels.cu b/lib/kernels/src/cuda/embedding_kernels.cu index 371b45f760..e6a614ba70 100644 --- a/lib/kernels/src/cuda/embedding_kernels.cu +++ b/lib/kernels/src/cuda/embedding_kernels.cu @@ -358,7 +358,7 @@ struct ForwardKernel { weight.data_type == DataType::DOUBLE); if (!aggr.has_value()) { - embed_forward_no_aggr, real_type> + embed_forward_no_aggr, real_type_t> <<, real_type> + embed_forward_with_aggr, real_type_t> <<, real_type> + embed_backward_no_aggr, real_type_t> <<, real_type> + embed_backward_with_aggr, real_type_t> <<> + add_kernel> <<>>( input_grad.get
(), output_grad.get
(), num_elements); } diff --git a/lib/kernels/src/cuda/ops/element_unary_kernels.cu b/lib/kernels/src/cuda/ops/element_unary_kernels.cu index 3eb9c486f2..a35d28fa8c 100644 --- a/lib/kernels/src/cuda/ops/element_unary_kernels.cu +++ b/lib/kernels/src/cuda/ops/element_unary_kernels.cu @@ -267,16 +267,16 @@ struct ForwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_forward_kernel> + elewise_scalar_unary_forward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, input.get(), output.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_forward_kernel> + elewise_unary_forward_kernel> <<>>( num_elements, op_type, input.get(), output.get()); } @@ -313,10 +313,10 @@ struct BackwardKernel { } else if (use_scalar(op_type)) { assert(scalar.has_value()); size_t num_elements = input.shape.num_elements(); - elewise_scalar_unary_backward_kernel> + elewise_scalar_unary_backward_kernel> <<>>( num_elements, - static_cast>(scalar.value()), + static_cast>(scalar.value()), op_type, output.get(), output_grad.get(), @@ -324,7 +324,7 @@ struct BackwardKernel { input_grad.get()); } else { size_t num_elements = input.shape.num_elements(); - elewise_unary_backward_kernel> + elewise_unary_backward_kernel> <<>>( num_elements, op_type, diff --git a/lib/kernels/src/cuda/ops/partition_kernels.cu b/lib/kernels/src/cuda/ops/partition_kernels.cu index e356f83d2a..1d07efb5fa 100644 --- a/lib/kernels/src/cuda/ops/partition_kernels.cu +++ b/lib/kernels/src/cuda/ops/partition_kernels.cu @@ -41,12 +41,12 @@ struct BackwardKernel { RepartitionPerDeviceState const &m, GenericTensorAccessorW const &input_grad, GenericTensorAccessorR const &output_grad) { - add_kernel><<>>(input_grad.get(), - output_grad.get(), - input_grad.shape.num_elements()); + add_kernel><<>>(input_grad.get(), + output_grad.get(), + input_grad.shape.num_elements()); } }; diff --git a/lib/kernels/src/cuda/ops/reduction_kernels.cu b/lib/kernels/src/cuda/ops/reduction_kernels.cu index 992d27fe60..0c6ba7d8e3 100644 --- a/lib/kernels/src/cuda/ops/reduction_kernels.cu +++ b/lib/kernels/src/cuda/ops/reduction_kernels.cu @@ -42,7 +42,7 @@ struct ForwardKernel { size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - reduction_forward_kernel> + reduction_forward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/replicate_kernels.cu b/lib/kernels/src/cuda/ops/replicate_kernels.cu index 0c87418f58..76bfbe2658 100644 --- a/lib/kernels/src/cuda/ops/replicate_kernels.cu +++ b/lib/kernels/src/cuda/ops/replicate_kernels.cu @@ -54,7 +54,7 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - replicate_backward_kernel> + replicate_backward_kernel> <<>>( input.get(), output.get(), diff --git a/lib/kernels/src/cuda/ops/reshape_kernels.cu b/lib/kernels/src/cuda/ops/reshape_kernels.cu index c4da408952..5b7843a3a5 100644 --- a/lib/kernels/src/cuda/ops/reshape_kernels.cu +++ b/lib/kernels/src/cuda/ops/reshape_kernels.cu @@ -45,14 +45,14 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - apply_add_with_scale> + apply_add_with_scale> <<>>(input.get(), output.get(), input.shape.num_elements(), - static_cast>(alpha)); + static_cast>(alpha)); } }; diff --git a/lib/kernels/src/hip/ops/replicate_kernels.cpp b/lib/kernels/src/hip/ops/replicate_kernels.cpp index 9a5fc813c3..8d27bb1908 100644 --- a/lib/kernels/src/hip/ops/replicate_kernels.cpp +++ b/lib/kernels/src/hip/ops/replicate_kernels.cpp @@ -55,15 +55,16 @@ struct BackwardKernel { GenericTensorAccessorR const &output, size_t num_replicas) { size_t total_elements = input.shape.num_elements() * num_replicas; - hipLaunchKernelGGL(HIP_KERNEL_NAME(replicate_backward_kernel>), - GET_BLOCKS(total_elements), - CUDA_NUM_THREADS, - 0, - stream, - input.get(), - output.get(), - input.shape.num_elements(), - num_replicas); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(replicate_backward_kernel>), + GET_BLOCKS(total_elements), + CUDA_NUM_THREADS, + 0, + stream, + input.get(), + output.get(), + input.shape.num_elements(), + num_replicas); } } diff --git a/lib/kernels/src/hip/ops/reshape_kernels.cpp b/lib/kernels/src/hip/ops/reshape_kernels.cpp index 941495c0fd..47978a5f4a 100644 --- a/lib/kernels/src/hip/ops/reshape_kernels.cpp +++ b/lib/kernels/src/hip/ops/reshape_kernels.cpp @@ -47,7 +47,7 @@ struct BackwardKernel { GenericTensorAccessorW const &input, GenericTensorAccessorR const &output) { float alpha = 1.0f; - hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), + hipLaunchKernelGGL(HIP_KERNEL_NAME(apply_add_with_scale>), GET_BLOCKS(input.shape.num_elements()), CUDA_NUM_THREADS, 0, @@ -55,7 +55,7 @@ struct BackwardKernel { input.get(), output.get(), input.shape.num_elements(), - static_cast> alpha); + static_cast> alpha); } } diff --git a/lib/local-execution/include/local-execution/cost_estimate.h b/lib/local-execution/include/local-execution/cost_estimate.h index 33954827bd..31503e0da9 100644 --- a/lib/local-execution/include/local-execution/cost_estimate.h +++ b/lib/local-execution/include/local-execution/cost_estimate.h @@ -4,8 +4,8 @@ #include "local-execution/cost_details.dtg.h" #include "local-execution/local_training_backing.h" -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" diff --git a/lib/local-execution/include/local-execution/legion_tensor_shape.h b/lib/local-execution/include/local-execution/legion_tensor_shape.h index f1d2ad252a..2f2ed50d41 100644 --- a/lib/local-execution/include/local-execution/legion_tensor_shape.h +++ b/lib/local-execution/include/local-execution/legion_tensor_shape.h @@ -4,8 +4,9 @@ #include "kernels/legion_dim.h" #include "op-attrs/datatype.h" #include "op-attrs/ff_dim.h" -#include "op-attrs/tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "utils/stack_vector.h" +#include "utils/visitable.h" #include namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_slots_backing.h b/lib/local-execution/include/local-execution/local_slots_backing.h index 6a0c28e988..5b826c7022 100644 --- a/lib/local-execution/include/local-execution/local_slots_backing.h +++ b/lib/local-execution/include/local-execution/local_slots_backing.h @@ -7,6 +7,9 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/per_device_op_state.h" #include "local-execution/runtime_arg_config.h" +#include "pcg/computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" +#include "pcg/tensor_guid_t.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/local_training_backing.h b/lib/local-execution/include/local-execution/local_training_backing.h index b398bb8cc3..6789624076 100644 --- a/lib/local-execution/include/local-execution/local_training_backing.h +++ b/lib/local-execution/include/local-execution/local_training_backing.h @@ -3,6 +3,7 @@ #include "local-execution/local_slots_backing.h" #include "local-execution/task_registry.h" +#include "pcg/computation_graph.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_arg_ref.h b/lib/local-execution/include/local-execution/op_arg_ref.h index 20d6ccb1c5..102a8d4362 100644 --- a/lib/local-execution/include/local-execution/op_arg_ref.h +++ b/lib/local-execution/include/local-execution/op_arg_ref.h @@ -5,7 +5,7 @@ #include "local-execution/device_specific.h" #include "local-execution/op_arg_ref_type.dtg.h" #include "local-execution/per_device_op_state.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/op_task_invocation.h b/lib/local-execution/include/local-execution/op_task_invocation.h index 73a0460554..0f351c3a0e 100644 --- a/lib/local-execution/include/local-execution/op_task_invocation.h +++ b/lib/local-execution/include/local-execution/op_task_invocation.h @@ -13,10 +13,6 @@ #include "local-execution/slot_grad_id.dtg.h" #include "local-execution/task_id_t.dtg.h" #include "local-execution/variadic_tensor_ref.h" -#include "op-attrs/computation_graph_op_attrs.h" -#include "pcg/computation_graph.h" -#include "utils/bidict/bidict.h" -#include "utils/stack_map.h" #include #include #include diff --git a/lib/local-execution/include/local-execution/sim_environment.h b/lib/local-execution/include/local-execution/sim_environment.h index 3ba17ea3ff..7c81cba408 100644 --- a/lib/local-execution/include/local-execution/sim_environment.h +++ b/lib/local-execution/include/local-execution/sim_environment.h @@ -7,7 +7,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/task_argument_accessor.h" #include "local-execution/task_signature_impl.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/machine_view.h" #include diff --git a/lib/local-execution/include/local-execution/task_registry.struct.toml b/lib/local-execution/include/local-execution/task_registry.struct.toml index 308527efac..ada467a67d 100644 --- a/lib/local-execution/include/local-execution/task_registry.struct.toml +++ b/lib/local-execution/include/local-execution/task_registry.struct.toml @@ -15,6 +15,7 @@ includes = [ src_includes = [ "utils/hash/unordered_map.h", "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", ] [[fields]] diff --git a/lib/local-execution/src/legion_tensor_shape.cc b/lib/local-execution/src/legion_tensor_shape.cc index b3a045bab4..bce29fafeb 100644 --- a/lib/local-execution/src/legion_tensor_shape.cc +++ b/lib/local-execution/src/legion_tensor_shape.cc @@ -1,4 +1,5 @@ #include "local-execution/legion_tensor_shape.h" +#include "op-attrs/tensor_shape.h" namespace FlexFlow { diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index d4e0467cbf..5203991f25 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -51,7 +51,7 @@ CostDetails LocalCostEstimator::estimate_cost( for (ParallelTensorShape const &input : inputs) { TensorShape tensor_shape = get_piece_shape(input); tensor_guid_t tensor_id = - cg_builder.create_tensor(tensor_shape, CreateGrad::YES); + cg_builder.create_input(tensor_shape, CreateGrad::YES); GenericTensorAccessorW tensor_backing = allocator.allocate_tensor(tensor_shape); tensor_backing_map.insert({tensor_id, tensor_backing}); diff --git a/lib/local-execution/src/local_slots_backing.cc b/lib/local-execution/src/local_slots_backing.cc index 0ec9068c6a..ac35d63c0b 100644 --- a/lib/local-execution/src/local_slots_backing.cc +++ b/lib/local-execution/src/local_slots_backing.cc @@ -1,4 +1,6 @@ #include "local-execution/local_slots_backing.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" #include "utils/containers/contains_key.h" #include "utils/overload.h" diff --git a/lib/local-execution/src/local_training_backing.cc b/lib/local-execution/src/local_training_backing.cc index a2ee06a95a..0fdf1761e3 100644 --- a/lib/local-execution/src/local_training_backing.cc +++ b/lib/local-execution/src/local_training_backing.cc @@ -1,5 +1,6 @@ #include "local-execution/local_training_backing.h" #include "local-execution/task_signature_impl.h" +#include "pcg/computation_graph.h" #include "utils/containers/reversed.h" #include "utils/exception.h" diff --git a/lib/local-execution/src/op_task_signature.cc b/lib/local-execution/src/op_task_signature.cc index 36a1dd708d..932b330453 100644 --- a/lib/local-execution/src/op_task_signature.cc +++ b/lib/local-execution/src/op_task_signature.cc @@ -1,4 +1,5 @@ #include "local-execution/op_task_signature.h" +#include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" #include "utils/fmt/unordered_set.h" diff --git a/lib/local-execution/src/ops/element_unary.cc b/lib/local-execution/src/ops/element_unary.cc index a52ebb8089..4ee609bd6c 100644 --- a/lib/local-execution/src/ops/element_unary.cc +++ b/lib/local-execution/src/ops/element_unary.cc @@ -1,6 +1,7 @@ #include "element_unary.h" #include "kernels/element_unary_kernels.h" #include "op-attrs/get_output_shapes.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/hash-utils.h" namespace FlexFlow { diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc index 2bd0acc222..4c01df53e9 100644 --- a/lib/local-execution/test/src/test_local_cost_estimator.cc +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -2,10 +2,12 @@ #include "kernels/local_cuda_allocator.h" #include "kernels/managed_per_device_ff_handle.h" #include "local-execution/local_cost_estimator.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/computation_graph_builder.h" #include "test_utils.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_CUDA_TEST_SUITE) { TEST_CASE("Local Cost Estimator") { @@ -73,5 +75,3 @@ TEST_SUITE(FF_CUDA_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_slots_backing.cc b/lib/local-execution/test/src/test_local_slots_backing.cc index 542aa66087..1ec441fbca 100644 --- a/lib/local-execution/test/src/test_local_slots_backing.cc +++ b/lib/local-execution/test/src/test_local_slots_backing.cc @@ -1,15 +1,19 @@ -#include "doctest/doctest.h" #include "kernels/attention_kernels.h" #include "local-execution/local_cost_estimator.h" #include "local-execution/local_cpu_allocator.h" #include "local-execution/local_slots_backing.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "pcg/computation_graph.h" #include "pcg/computation_graph_builder.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/variant.h" +#include "test/utils/doctest/fmt/vector.h" #include "test_utils.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" +#include -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("LocalSlotsBacking -- Attention Op") { @@ -37,11 +41,11 @@ TEST_SUITE(FF_TEST_SUITE) { // build graph ComputationGraphBuilder cg_builder; tensor_guid_t query_guid = - cg_builder.create_tensor(query_shape, CreateGrad::YES); + cg_builder.create_input(query_shape, CreateGrad::YES); tensor_guid_t key_guid = - cg_builder.create_tensor(key_shape, CreateGrad::YES); + cg_builder.create_input(key_shape, CreateGrad::YES); tensor_guid_t value_guid = - cg_builder.create_tensor(value_shape, CreateGrad::YES); + cg_builder.create_input(value_shape, CreateGrad::YES); std::string layer_name = "attn1"; tensor_guid_t output_guid = @@ -269,5 +273,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_local_task_arg_accessor.cc b/lib/local-execution/test/src/test_local_task_arg_accessor.cc index 0637faaf1c..f52fccb1ed 100644 --- a/lib/local-execution/test/src/test_local_task_arg_accessor.cc +++ b/lib/local-execution/test/src/test_local_task_arg_accessor.cc @@ -4,7 +4,7 @@ #include "local-execution/task_signature_impl.h" #include "utils/fmt/variant.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("LocalTaskArgumentAccessor") { @@ -140,5 +140,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/local-execution/test/src/test_task_registry.cc b/lib/local-execution/test/src/test_task_registry.cc index fa3b068425..e18b7ea2de 100644 --- a/lib/local-execution/test/src/test_task_registry.cc +++ b/lib/local-execution/test/src/test_task_registry.cc @@ -7,7 +7,7 @@ #include "utils/fmt/optional.h" #include "utils/fmt/unordered_map.h" -namespace FlexFlow { +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Task Registry") { @@ -127,5 +127,3 @@ TEST_SUITE(FF_TEST_SUITE) { } } } - -} // namespace FlexFlow diff --git a/lib/models/CMakeLists.txt b/lib/models/CMakeLists.txt index 7dd7f48700..4f4b22ed47 100644 --- a/lib/models/CMakeLists.txt +++ b/lib/models/CMakeLists.txt @@ -11,6 +11,7 @@ ff_add_library( op-attrs utils pcg + rapidcheck ) -add_subdirectory(test) \ No newline at end of file +add_subdirectory(test) diff --git a/lib/models/include/models/split_test/split_test.h b/lib/models/include/models/split_test/split_test.h new file mode 100644 index 0000000000..b03e45b2d2 --- /dev/null +++ b/lib/models/include/models/split_test/split_test.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_SPLIT_TEST_SPLIT_TEST_H + +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the computation graph of the old FlexFlow test model + * split_test + * + * @note This is a tiny model developed for testing the original Unity + * implementation. It is not a "real" model and has never been trained. + */ +ComputationGraph get_split_test_computation_graph(int batch_size); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/transformer.h b/lib/models/include/models/transformer/transformer.h similarity index 90% rename from lib/models/include/models/transformer.h rename to lib/models/include/models/transformer/transformer.h index e50fa37709..385100a4c9 100644 --- a/lib/models/include/models/transformer.h +++ b/lib/models/include/models/transformer/transformer.h @@ -1,7 +1,7 @@ -#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H -#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_H +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_TRANSFORMER_TRANSFORMER_H -#include "models/transformer_config.dtg.h" +#include "models/transformer/transformer_config.dtg.h" #include "pcg/computation_graph_builder.h" namespace FlexFlow { diff --git a/lib/models/include/models/transformer_config.struct.toml b/lib/models/include/models/transformer/transformer_config.struct.toml similarity index 100% rename from lib/models/include/models/transformer_config.struct.toml rename to lib/models/include/models/transformer/transformer_config.struct.toml diff --git a/lib/models/src/models/split_test/split_test.cc b/lib/models/src/models/split_test/split_test.cc new file mode 100644 index 0000000000..118f94ec06 --- /dev/null +++ b/lib/models/src/models/split_test/split_test.cc @@ -0,0 +1,39 @@ +#include "models/split_test/split_test.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +ComputationGraph get_split_test_computation_graph(int batch_size) { + ComputationGraphBuilder cgb; + + int layer_dim1 = 256; + int layer_dim2 = 128; + int layer_dim3 = 64; + int layer_dim4 = 32; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(batch_size), + size_t_from_int(layer_dim1), + }}, + DataType::FLOAT, + }; + + tensor_guid_t t = cgb.create_input(input_shape, CreateGrad::YES); + t = cgb.dense(t, layer_dim2); + t = cgb.relu(t); + tensor_guid_t t1 = cgb.dense(t, layer_dim3); + tensor_guid_t t2 = cgb.dense(t, layer_dim3); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t1 = cgb.dense(t, layer_dim4); + t2 = cgb.dense(t, layer_dim4); + t = cgb.add(t1, t2); + t = cgb.relu(t); + t = cgb.softmax(t); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/src/models/transformer.cc b/lib/models/src/models/transformer/transformer.cc similarity index 95% rename from lib/models/src/models/transformer.cc rename to lib/models/src/models/transformer/transformer.cc index 874cd85787..e179359940 100644 --- a/lib/models/src/models/transformer.cc +++ b/lib/models/src/models/transformer/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" namespace FlexFlow { @@ -100,7 +100,7 @@ tensor_guid_t assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention_normalized)); - tensor_guid_t mha = cgb.multihead_attention(input, + tensor_guid_t mha = cgb.multihead_attention(self_attention_normalized, encoder_output, encoder_output, config.num_features, @@ -149,11 +149,13 @@ ComputationGraph config.batch_size, config.sequence_length, config.num_features}}, DataType::FLOAT, }; - tensor_guid_t input = cgb.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES, "input"); + tensor_guid_t target = + cgb.create_input(input_shape, CreateGrad::YES, "target"); tensor_guid_t encoder_output = create_transformer_encoder(cgb, config, input); tensor_guid_t decoder_output = - create_transformer_decoder(cgb, config, input, encoder_output); + create_transformer_decoder(cgb, config, target, encoder_output); tensor_guid_t out_prob = cgb.softmax(cgb.dense(decoder_output, /*outDim=*/config.vocab_size, diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer.cc index 2133e9965b..20274c4151 100644 --- a/lib/models/test/src/models/transformer.cc +++ b/lib/models/test/src/models/transformer.cc @@ -1,4 +1,4 @@ -#include "models/transformer.h" +#include "models/transformer/transformer.h" #include "pcg/computation_graph.h" #include @@ -12,7 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("num layers") { int result_num_layers = get_layers(result).size(); - int correct_num_layers = 317; + int correct_num_layers = 258; CHECK(result_num_layers == correct_num_layers); } } diff --git a/lib/op-attrs/include/op-attrs/datatype.h b/lib/op-attrs/include/op-attrs/datatype.h index 6204b9ca49..5af00fb510 100644 --- a/lib/op-attrs/include/op-attrs/datatype.h +++ b/lib/op-attrs/include/op-attrs/datatype.h @@ -47,14 +47,7 @@ typename data_type_enum_to_class
::type cast_to(T t) { } template -using real_type = typename data_type_enum_to_class
::type; - -using DataTypeValue = std::variant, - real_type, - real_type, - real_type, - /* real_type, */ - real_type>; +using real_type_t = typename data_type_enum_to_class
::type; size_t size_of_datatype(DataType); diff --git a/lib/op-attrs/include/op-attrs/datatype_value.variant.toml b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml new file mode 100644 index 0000000000..3386e9d131 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/datatype_value.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "DataTypeValue" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +type = "float" + +[[values]] +type = "double" + +[[values]] +type = "int32_t" + +[[values]] +type = "int64_t" + +[[values]] +type = "bool" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered.h index 6868ba083f..34d186e74e 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered.h @@ -3,8 +3,8 @@ #include "op-attrs/ff_dim.dtg.h" #include "utils/fmt/vector.h" -#include "utils/json.h" #include "utils/stack_vector.h" +#include namespace FlexFlow { @@ -202,11 +202,12 @@ FFOrdered const &outer_to_inner(FFOrdered const &ff_ordered) { namespace nlohmann { template struct adl_serializer<::FlexFlow::DimOrdered> { - static ::FlexFlow::DimOrdered from_json(json const &j) { + static ::FlexFlow::DimOrdered from_json(nlohmann::json const &j) { return {j.template get>()}; } - static void to_json(json &j, ::FlexFlow::DimOrdered const &x) { + static void to_json(nlohmann::json &j, + ::FlexFlow::DimOrdered const &x) { j = std::vector{x.cbegin(), x.cend()}; } }; diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index f3dfe5d199..d39bac1bde 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" #include "utils/containers/subvec.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/optional.h" namespace FlexFlow { @@ -18,7 +18,7 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, }; return DimOrdered{ - subvec(as_vector(d), to_raw_idx(start), to_raw_idx(end))}; + subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } template diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h index 3a31ea511d..ae6e552243 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" namespace FlexFlow { @@ -12,7 +12,7 @@ DimOrdered> transform(DimOrdered const &d, F f) { using Out = std::invoke_result_t; - return DimOrdered{vector_transform(as_vector(d), f)}; + return DimOrdered{vector_transform(vector_of(d), f)}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h index 54554afb81..023dcfc586 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H #include "op-attrs/dim_ordered.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" namespace FlexFlow { @@ -11,7 +11,7 @@ template DimOrdered> zip(DimOrdered const &lhs, DimOrdered const &rhs) { return DimOrdered>{ - zip(as_vector(lhs), as_vector(rhs))}; + zip(vector_of(lhs), vector_of(rhs))}; } } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/broadcast.h b/lib/op-attrs/include/op-attrs/ops/broadcast.h index 0a5f057578..4fd7d49234 100644 --- a/lib/op-attrs/include/op-attrs/ops/broadcast.h +++ b/lib/op-attrs/include/op-attrs/ops/broadcast.h @@ -5,11 +5,14 @@ #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" +#include "utils/record_formatter.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(BroadcastAttrs); +RecordFormatter as_dot(BroadcastAttrs const &); + tl::expected get_output_shape(BroadcastAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(BroadcastAttrs const &, diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml index 2fb385b64d..5bef144cd9 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d_attrs.struct.toml @@ -12,11 +12,12 @@ features = [ includes = [ "", "op-attrs/activation.dtg.h", - "utils/json.h", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] fields = [ diff --git a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml index 4b9c8a9f45..403bb87592 100644 --- a/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/element_unary_attrs.struct.toml @@ -11,12 +11,14 @@ features = [ ] includes = [ - "utils/json.h", - "op-attrs/operator_type.h", + "op-attrs/operator_type.dtg.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml index 38d5a4371e..66d6f99253 100644 --- a/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/embedding_attrs.struct.toml @@ -17,6 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml index eaa34cc496..0a35a6c5ec 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/linear_attrs.struct.toml @@ -13,11 +13,13 @@ includes = [ "op-attrs/datatype.dtg.h", "op-attrs/activation.dtg.h", "op-attrs/regularizer_attrs.dtg.h", - "utils/json.h", + "", ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/tensor_shape.h b/lib/op-attrs/include/op-attrs/tensor_shape.h index 108df58dce..14ee637f92 100644 --- a/lib/op-attrs/include/op-attrs/tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/tensor_shape.h @@ -11,11 +11,6 @@ size_t &dim_at_idx(TensorShape &, ff_dim_t); size_t get_num_elements(TensorShape const &); size_t get_size_in_bytes(TensorShape const &); -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - TensorShape const &goal); -std::optional - get_broadcast_target_shape(std::unordered_set const &); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index 166416cbad..054930cebd 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -1,5 +1,8 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_op_type.h" +#include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/linear.h" +#include "utils/overload.h" namespace FlexFlow { @@ -8,4 +11,16 @@ OperatorType get_op_type(ComputationGraphOpAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } +RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { + return attrs.visit(overload{ + [](LinearAttrs const &l) { return as_dot(l); }, + [](BroadcastAttrs const &a) { return as_dot(a); }, + [&](auto const &) { + RecordFormatter r; + r << fmt::to_string(get_op_type(attrs)); + return r; + }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/broadcast.cc b/lib/op-attrs/src/op-attrs/ops/broadcast.cc index bd69864aff..aa3c95f551 100644 --- a/lib/op-attrs/src/op-attrs/ops/broadcast.cc +++ b/lib/op-attrs/src/op-attrs/ops/broadcast.cc @@ -1,8 +1,26 @@ #include "op-attrs/ops/broadcast.h" #include "op-attrs/tensor_dims.h" +#include "utils/record_formatter.h" namespace FlexFlow { +RecordFormatter as_dot(BroadcastAttrs const &attrs) { + RecordFormatter r; + + auto kv = [](std::string const &label, auto const &val) { + RecordFormatter rr; + rr << label << fmt::to_string(val); + return rr; + }; + + for (int i = 0; i < num_dims(attrs.target_dims); i++) { + r << kv(fmt::format("target_dims[{}]", i), + dim_at_idx(attrs.target_dims, ff_dim_t{i})); + } + + return r; +} + tl::expected get_output_shape(BroadcastAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 73c0068826..4bce5449f4 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -4,9 +4,9 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/product.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -35,7 +35,7 @@ int total_replica_degree(ParallelTensorDims const &dims) { } int total_shard_degree(ParallelTensorDims const &dims) { - return product(transform(as_vector(dims.shard_dims), + return product(transform(vector_of(dims.shard_dims), [](ShardParallelDim const &d) { return d.degree; })); } diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index e716793a8f..ba7d6e8357 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -3,9 +3,9 @@ #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.dtg.h" #include "utils/containers/all_of.h" -#include "utils/containers/as_vector.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/integer_conversions.h" @@ -33,8 +33,8 @@ bool tensor_dims_is_broadcastable_to(TensorDims const &curr, return false; } - std::vector curr_dims = as_vector(curr.ff_ordered); - std::vector goal_dims = as_vector(goal.ff_ordered); + std::vector curr_dims = vector_of(curr.ff_ordered); + std::vector goal_dims = vector_of(goal.ff_ordered); for (auto const &[curr_dim, goal_dim] : zip(reversed(curr_dims), reversed(goal_dims))) { @@ -72,7 +72,7 @@ ParallelTensorDims DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees) { std::vector lifted = - transform(zip(as_vector(dims.ff_ordered), as_vector(shard_degrees)), + transform(zip(vector_of(dims.ff_ordered), vector_of(shard_degrees)), [](std::pair const &p) { size_t size = p.first; int degree = p.second; diff --git a/lib/op-attrs/src/op-attrs/tensor_shape.cc b/lib/op-attrs/src/op-attrs/tensor_shape.cc index b604d442cb..07508e3065 100644 --- a/lib/op-attrs/src/op-attrs/tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/tensor_shape.cc @@ -27,35 +27,4 @@ size_t get_size_in_bytes(TensorShape const &s) { return get_num_elements(s) * size_of_datatype(s.data_type); } -bool tensor_shape_is_broadcastable_to(TensorShape const &curr, - TensorShape const &goal) { - return tensor_dims_is_broadcastable_to(curr.dims, goal.dims) && - curr.data_type == goal.data_type; -} - -std::optional - get_broadcast_target_shape(std::unordered_set const &shapes) { - std::unordered_set datatypes = - transform(shapes, [](TensorShape const &s) { return s.data_type; }); - - if (datatypes.size() != 1) { - return std::nullopt; - } - - std::unordered_set shapes_dims = - transform(shapes, [](TensorShape const &s) { return s.dims; }); - - std::optional maybe_result_dims = - get_broadcast_target_dims(shapes_dims); - std::optional result = - transform(maybe_result_dims, [&](TensorDims const &result_dims) { - return TensorShape{ - result_dims, - get_only(datatypes), - }; - }); - - return result; -} - } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/datatype.cc index cc7e496c60..d45c156d59 100644 --- a/lib/op-attrs/test/src/datatype.cc +++ b/lib/op-attrs/test/src/datatype.cc @@ -1,6 +1,8 @@ #include "op-attrs/datatype.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("can_promote_datatype_from_to(DataType, DataType)") { diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/dim_ordered/slice.cc index 8640b077dc..8d5f247756 100644 --- a/lib/op-attrs/test/src/dim_ordered/slice.cc +++ b/lib/op-attrs/test/src/dim_ordered/slice.cc @@ -1,5 +1,7 @@ #include "op-attrs/dim_ordered/slice.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE( diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc index d2c758a05f..180bc2a01f 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/enumerate.cc @@ -1,5 +1,5 @@ #include "op-attrs/dim_ordered/enumerate.h" -#include "utils/fmt/map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc index 11e09dc43f..8e3d0f1b80 100644 --- a/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/zip.cc @@ -1,6 +1,6 @@ #include "op-attrs/dim_ordered/zip.h" #include "op-attrs/ff_dim.dtg.h" -#include "utils/fmt/pair.h" +#include "test/utils/doctest/fmt/pair.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc index 17a68ccbc8..7580de24e5 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/dropout.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/dropout.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/dropout.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index b9dd66df5d..cbcebdbce1 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -1,9 +1,9 @@ #include "op-attrs/ops/layer_norm.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include "utils/fmt/optional.h" +#include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc index f6a8da016f..65a74932cb 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/softmax.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/softmax.cc @@ -1,5 +1,6 @@ #include "op-attrs/ops/softmax.h" #include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/expected.h" #include "utils/fmt/expected.h" #include diff --git a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc index 25c7eb036f..60d87300c1 100644 --- a/lib/op-attrs/test/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/test/src/op-attrs/tensor_dims.cc @@ -1,4 +1,5 @@ #include "op-attrs/tensor_dims.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc b/lib/op-attrs/test/src/op-attrs/tensor_shape.cc deleted file mode 100644 index bc715c183a..0000000000 --- a/lib/op-attrs/test/src/op-attrs/tensor_shape.cc +++ /dev/null @@ -1,64 +0,0 @@ -#include "op-attrs/tensor_shape.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_broadcast_target_shape(std::unordered_set)") { - SUBCASE("target exists in inputs") { - DataType datatype = DataType::FLOAT; - - TensorShape s1 = TensorShape{ - TensorDims{FFOrdered{ - 1, - }}, - datatype, - }; - - TensorShape s2 = TensorShape{ - TensorDims{FFOrdered{10, 4, 3}}, - datatype, - }; - - TensorShape s3 = TensorShape{ - TensorDims{FFOrdered{ - 4, - 1, - }}, - datatype, - }; - - std::optional result = - get_broadcast_target_shape({s1, s2, s3}); - std::optional correct = s2; - - CHECK(result == correct); - } - - SUBCASE("datatypes don't match") { - TensorDims dims = TensorDims{FFOrdered{10, 4, 3}}; - - TensorShape s1 = TensorShape{ - dims, - DataType::FLOAT, - }; - - TensorShape s2 = TensorShape{ - dims, - DataType::DOUBLE, - }; - - std::optional result = get_broadcast_target_shape({s1, s2}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - - SUBCASE("inputs is empty") { - std::optional result = get_broadcast_target_shape({}); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - } -} diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/ops/attention.cc index ade219a6a9..2fb804ca8c 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ b/lib/op-attrs/test/src/ops/attention.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " diff --git a/lib/op-attrs/test/src/ops/batch_matmul.cc b/lib/op-attrs/test/src/ops/batch_matmul.cc index 3ff02ccece..56a2e3fa52 100644 --- a/lib/op-attrs/test/src/ops/batch_matmul.cc +++ b/lib/op-attrs/test/src/ops/batch_matmul.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/batch_matmul.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_output_shape(BatchMatmulAttrs, TensorShape)") { diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/ops/cast.cc index 31030ca0f9..c7395316ad 100644 --- a/lib/op-attrs/test/src/ops/cast.cc +++ b/lib/op-attrs/test/src/ops/cast.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/cast.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Cast shape inference") { diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/ops/combine.cc index ac18bbc798..bf74a072e0 100644 --- a/lib/op-attrs/test/src/ops/combine.cc +++ b/lib/op-attrs/test/src/ops/combine.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/combine.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Combine shape inference") { diff --git a/lib/op-attrs/test/src/ops/element_binary.cc b/lib/op-attrs/test/src/ops/element_binary.cc index 0ed695eb89..b091833f10 100644 --- a/lib/op-attrs/test/src/ops/element_binary.cc +++ b/lib/op-attrs/test/src/ops/element_binary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_binary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("EWAdd shape inference") { diff --git a/lib/op-attrs/test/src/ops/element_unary.cc b/lib/op-attrs/test/src/ops/element_unary.cc index 4239782d55..94c382356e 100644 --- a/lib/op-attrs/test/src/ops/element_unary.cc +++ b/lib/op-attrs/test/src/ops/element_unary.cc @@ -1,6 +1,9 @@ #include "op-attrs/ops/element_unary.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ReLU shape inference") { diff --git a/lib/op-attrs/test/src/ops/embedding.cc b/lib/op-attrs/test/src/ops/embedding.cc index 9180f7055d..134737f6c0 100644 --- a/lib/op-attrs/test/src/ops/embedding.cc +++ b/lib/op-attrs/test/src/ops/embedding.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/embedding.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Sum embedding shape inference") { diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/ops/linear.cc index 0d23dc35df..f838ff4285 100644 --- a/lib/op-attrs/test/src/ops/linear.cc +++ b/lib/op-attrs/test/src/ops/linear.cc @@ -1,7 +1,10 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/parallel_tensor_shape.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" #include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Linear shape inference") { diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/ops/reduction.cc index 59ed5bb5ee..0d1c8bdf98 100644 --- a/lib/op-attrs/test/src/ops/reduction.cc +++ b/lib/op-attrs/test/src/ops/reduction.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/reduction.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Reduction shape inference") { diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/ops/repartition.cc index af28a6d471..8bc8205183 100644 --- a/lib/op-attrs/test/src/ops/repartition.cc +++ b/lib/op-attrs/test/src/ops/repartition.cc @@ -1,5 +1,8 @@ #include "op-attrs/ops/repartition.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Repartition shape inference") { diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/ops/replicate.cc index a0ec40cc14..60a1018479 100644 --- a/lib/op-attrs/test/src/ops/replicate.cc +++ b/lib/op-attrs/test/src/ops/replicate.cc @@ -1,5 +1,7 @@ #include "op-attrs/ops/replicate.h" -#include "test/utils/doctest.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Replicate shape inference") { diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc index f485b07b02..20825f5d73 100644 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ b/lib/op-attrs/test/src/test_operator_attrs.cc @@ -1,8 +1,8 @@ #include "op-attrs/computation_graph_op_attrs.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" -#include "utils/json.h" #include #include +#include #include using namespace ::FlexFlow; @@ -10,16 +10,16 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { BatchNormAttrs correct = BatchNormAttrs{true}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + BatchNormAttrs result = j.get(); CHECK(result == correct); } TEST_CASE("ComputationGraphAttrs to/from json") { ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{true}}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + ComputationGraphOpAttrs result = j.get(); CHECK(result == correct); } @@ -29,8 +29,8 @@ TEST_SUITE(FF_TEST_SUITE) { /*repartition_dim=*/ff_dim_t{1}, /*repartition_degree=*/4, }}; - json j = correct; - auto result = j.get(); + nlohmann::json j = correct; + PCGOperatorAttrs result = j.get(); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/test_regularizer_attrs.cc index 35851463bb..6e172d1e8e 100644 --- a/lib/op-attrs/test/src/test_regularizer_attrs.cc +++ b/lib/op-attrs/test/src/test_regularizer_attrs.cc @@ -1,6 +1,8 @@ #include "op-attrs/regularizer_attrs.dtg.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/pcg/CMakeLists.txt b/lib/pcg/CMakeLists.txt index e1875ca694..e6eb182740 100644 --- a/lib/pcg/CMakeLists.txt +++ b/lib/pcg/CMakeLists.txt @@ -10,6 +10,7 @@ ff_add_library( DEPS op-attrs utils + rapidcheck ) add_subdirectory(ffi) diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 088139a0f3..499b26af89 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #include "pcg/computation_graph.dtg.h" +#include "pcg/computation_graph/computation_graph_edge.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" #include "pcg/layer_guid_t.dtg.h" #include "pcg/tensor_attrs.dtg.h" @@ -30,11 +31,24 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n); +std::unordered_set + get_subgraph_incoming_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_outgoing_edges(ComputationGraph const &, + std::unordered_set const &); +std::unordered_set + get_subgraph_successors(ComputationGraph const &, + std::unordered_set const &); + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); layer_guid_t get_layer_by_name(ComputationGraph const &cg, std::string const &name); +std::string as_dot(ComputationGraph const &); +void debug_print_dot(ComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h new file mode 100644 index 0000000000..2a9a9ee04a --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_COMPUTATION_GRAPH_EDGE_H + +#include "pcg/computation_graph/computation_graph_edge.dtg.h" +#include "pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +layer_guid_t get_computation_graph_edge_src_layer(ComputationGraphEdge const &); +layer_guid_t get_computation_graph_edge_dst_layer(ComputationGraphEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml new file mode 100644 index 0000000000..311c47d277 --- /dev/null +++ b/lib/pcg/include/pcg/computation_graph/computation_graph_edge.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "ComputationGraphEdge" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/dataflow_graph/dataflow_edge.dtg.h", +] + +[[fields]] +name = "raw_edge" +type = "::FlexFlow::DataflowEdge" diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index c641aed6a4..a35763cacc 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -159,9 +159,12 @@ struct ComputationGraphBuilder { std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &projection_initializer = + std::nullopt, std::optional const &bias_initializer = std::nullopt, - std::optional const &name = std::nullopt); + std::optional const &name = std::nullopt, + std::optional const &projection_name = std::nullopt, + std::optional const &bias_name = std::nullopt); // Add a cast layer tensor_guid_t cast(tensor_guid_t const &input, DataType dtype, @@ -225,12 +228,16 @@ struct ComputationGraphBuilder { bool add_zero_attn = false, std::optional initializer = std::nullopt, std::optional const &maybe_name = std::nullopt); - tensor_guid_t create_tensor(TensorShape const &, CreateGrad); + tensor_guid_t + create_input(TensorShape const &, + CreateGrad, + std::optional const &maybe_name = std::nullopt); tensor_guid_t create_weight( TensorShape const &, - bool create_grad = true, + CreateGrad create_grad = CreateGrad::YES, std::optional const &initializer = std::nullopt, - std::optional sync_type = std::nullopt); + std::optional sync_type = std::nullopt, + std::optional const &name = std::nullopt); std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; @@ -243,9 +250,8 @@ struct ComputationGraphBuilder { private: TensorShape get_shape(tensor_guid_t const &) const; - tensor_guid_t broadcast(tensor_guid_t const &, - TensorShape const &, - std::string const &); + tensor_guid_t + broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); @@ -259,13 +265,22 @@ struct ComputationGraphBuilder { std::vector const &weights, std::vector const &outputs); + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + TensorShape const &output); + tensor_guid_t add_layer(LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, TensorShape const &output); - TensorShape get_broadcast_target_shape(std::vector const &); - TensorShape get_broadcast_target_shape(std::vector const &); + tensor_guid_t add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output); + + TensorDims get_broadcast_target_dims(std::vector const &); + TensorDims get_broadcast_target_dims(std::vector const &); tensor_guid_t element_binary(OperatorType, diff --git a/lib/pcg/include/pcg/file_format/file_format.h b/lib/pcg/include/pcg/file_format/file_format.h deleted file mode 100644 index 823846754c..0000000000 --- a/lib/pcg/include/pcg/file_format/file_format.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_FILE_FORMAT_H - -#include "graphs.h" -#include "utils/json.h" - -namespace FlexFlow { - -enum class FileFormatVersion { - V1, - UNSTABLE, -}; - -json to_json(ComputationGraph const &, FileFormatVersion); -ComputationGraph from_json(json const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/keyed_variant.h b/lib/pcg/include/pcg/file_format/keyed_variant.h index 11044de12b..5e29d8c252 100644 --- a/lib/pcg/include/pcg/file_format/keyed_variant.h +++ b/lib/pcg/include/pcg/file_format/keyed_variant.h @@ -1,10 +1,11 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_KEYED_VARIANT_H -#include "utils/json.h" +#include "utils/json/is_jsonable.h" #include "utils/sequence.h" #include "utils/strong_typedef.h" #include "utils/variant.h" +#include namespace FlexFlow { @@ -29,9 +30,9 @@ struct KeyedVariant { }; struct ToJsonFunctor { - ToJsonFunctor(json &j) : j(j) {} + ToJsonFunctor(nlohmann::json &j) : j(j) {} - json &j; + nlohmann::json &j; template void operator()(T const &t) { @@ -42,20 +43,20 @@ struct ToJsonFunctor { }; template -void to_json(json &j, KeyedVariant const &v) { +void to_json(nlohmann::json &j, KeyedVariant const &v) { static_assert(is_jsonable::value, ""); K key = static_cast(v.value.index()); j["type"] = key; - json &jj = j["value"]; + nlohmann::json &jj = j["value"]; visit(ToJsonFunctor{j["value"]}, v.value); } template struct FromJsonFunctor { - FromJsonFunctor(json const &j, int idx) : j(j), idx(idx) {} + FromJsonFunctor(nlohmann::json const &j, int idx) : j(j), idx(idx) {} - json const &j; + nlohmann::json const &j; int idx; template @@ -68,31 +69,31 @@ struct FromJsonFunctor { template std::string get_json_name(T const &t) { - return json{t}.get(); + return nlohmann::json{t}.get(); } template struct FromJsonMoveOnlyFunctor { - FromJsonMoveOnlyFunctor(json const &j, Key const &key) : j(j) {} + FromJsonMoveOnlyFunctor(nlohmann::json const &j, Key const &key) : j(j) {} - json const &j; + nlohmann::json const &j; Key const &key; template Variant operator()(std::integral_constant const &) const { - return j.get::type>(); + return j.get::type>(); } }; template -Variant from_json_moveonly(json const &j, K const &key) { +Variant from_json_moveonly(nlohmann::json const &j, K const &key) { FromJsonMoveOnlyFunctor func(j); return seq_get(func, idx, seq_count_t::value>{}); } template typename std::enable_if::value>::type - from_json(json const &j, KeyedVariant &v) { + from_json(nlohmann::json const &j, KeyedVariant &v) { K key = j.at("type").get(); std::string key_string = j.at("type").get(); @@ -100,7 +101,7 @@ typename std::enable_if::value>::type } template -KeyedVariant keyed_variant_from_json(json const &j) { +KeyedVariant keyed_variant_from_json(nlohmann::json const &j) { K key = j.at("type").get(); return KeyedVariant{ diff --git a/lib/pcg/include/pcg/file_format/v1/data_type_value.h b/lib/pcg/include/pcg/file_format/v1/data_type_value.h index 6e4e5abc54..ec3910aab3 100644 --- a/lib/pcg/include/pcg/file_format/v1/data_type_value.h +++ b/lib/pcg/include/pcg/file_format/v1/data_type_value.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_DATA_TYPE_H #include "utils/fp16.h" -#include "utils/json.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/file_format/v1/graphs.h b/lib/pcg/include/pcg/file_format/v1/graphs.h deleted file mode 100644 index 702c79c2b6..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/graphs.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_GRAPHS_H - -#include "pcg/computation_graph.dtg.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h" -#include "pcg/layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" -#include "pcg/tensor_attrs.dtg.h" -#include "utils/json.h" - -namespace FlexFlow { - -using V1ComputationGraph = V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ComputationGraph); -V1ComputationGraph to_v1(ComputationGraph const &); - -using V1ParallelComputationGraph = - V1LabelledDataflowGraph; -CHECK_IS_JSONABLE(V1ParallelComputationGraph); -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml index d9aade739c..c332b6b41d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1DataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -13,8 +13,13 @@ includes = [ "", "", "pcg/file_format/v1/graphs/v1_graph_edge.dtg.h", +] + +src_includes = [ "utils/fmt/vector.h", + "utils/hash/vector.h", "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h index 48203d73ae..fc9dfcef9a 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h @@ -13,8 +13,9 @@ namespace FlexFlow { template -V1LabelledDataflowGraph - to_v1(LabelledDataflowGraphView const &g) { +std::pair, bidict> + to_v1_including_node_numbering( + LabelledDataflowGraphView const &g) { bidict nodes = bidict_from_enumerating(get_nodes(g)); @@ -29,8 +30,17 @@ V1LabelledDataflowGraph [&](DataflowOutput const &o) { return g.at(o); }); }); - return V1LabelledDataflowGraph{ - node_labels, output_labels, unlabelled}; + return { + V1LabelledDataflowGraph{ + node_labels, output_labels, unlabelled}, + nodes, + }; +} + +template +V1LabelledDataflowGraph + to_v1(LabelledDataflowGraphView const &g) { + return to_v1_including_node_numbering(g).first; } } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml index fd8d4c39c4..b440d0f03d 100644 --- a/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml +++ b/lib/pcg/include/pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.struct.toml @@ -1,9 +1,9 @@ namespace = "FlexFlow" name = "V1LabelledDataflowGraph" features = [ - # "eq", + "eq", # "ord", - # "hash", + "hash", "json", # "rapidcheck", "fmt", @@ -20,6 +20,13 @@ includes = [ "pcg/file_format/v1/graphs/v1_graph_output.dtg.h", ] +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + [[fields]] name = "node_labels" type = "std::unordered_map" @@ -31,4 +38,3 @@ type = "std::unordered_map>" [[fields]] name = "graph" type = "::FlexFlow::V1DataflowGraph" - diff --git a/lib/pcg/include/pcg/file_format/v1/v1.h b/lib/pcg/include/pcg/file_format/v1/v1.h deleted file mode 100644 index e2557af4f5..0000000000 --- a/lib/pcg/include/pcg/file_format/v1/v1.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_H - -#include "graphs.h" -#include "pcg/computation_graph.h" - -namespace FlexFlow {} - -#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h new file mode 100644 index 0000000000..5590d6999b --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_COMPUTATION_GRAPH_H + +#include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_computation_graph.dtg.h" +#include "pcg/layer_guid_t.dtg.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &); + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml new file mode 100644 index 0000000000..0d7135ec74 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/layer_attrs.dtg.h", + "pcg/tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::LayerAttrs, ::FlexFlow::TensorAttrs>" diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h new file mode 100644 index 0000000000..aceb59f5af --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_PARALLEL_COMPUTATION_GRAPH_H + +#include "pcg/file_format/v1/v1_parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml new file mode 100644 index 0000000000..16be4a9561 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_parallel_computation_graph.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "V1ParallelComputationGraph" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", + "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.dtg.h", +] + +[[fields]] +name = "raw_graph" +type = "::FlexFlow::V1LabelledDataflowGraph<::FlexFlow::ParallelLayerAttrs, ::FlexFlow::ParallelTensorAttrs>" diff --git a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml index 12917d0989..4e3c31bd36 100644 --- a/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml +++ b/lib/pcg/include/pcg/initializers/constant_initializer_attrs.struct.toml @@ -10,12 +10,7 @@ features = [ ] includes = [ - "op-attrs/datatype.h", - "utils/json.h", -] - -src_includes = [ - "utils/fmt/variant.h", + "op-attrs/datatype_value.dtg.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/layer_attrs.struct.toml b/lib/pcg/include/pcg/layer_attrs.struct.toml index d062f6cd78..8290795174 100644 --- a/lib/pcg/include/pcg/layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/layer_attrs.struct.toml @@ -13,11 +13,11 @@ includes = [ "op-attrs/computation_graph_op_attrs.dtg.h", "utils/stack_string.h", "", - "utils/json.h" ] src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 60cfc426cc..4d61f24d37 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -17,6 +17,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml index d9e6cf113b..323932fec6 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_attrs.struct.toml @@ -19,6 +19,8 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] diff --git a/lib/pcg/include/pcg/tensor_attrs.struct.toml b/lib/pcg/include/pcg/tensor_attrs.struct.toml index c0b89cfc99..7f16e60914 100644 --- a/lib/pcg/include/pcg/tensor_attrs.struct.toml +++ b/lib/pcg/include/pcg/tensor_attrs.struct.toml @@ -19,6 +19,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", ] [[fields]] diff --git a/lib/pcg/src/file_format.cc b/lib/pcg/src/file_format.cc deleted file mode 100644 index bb01ac2dbf..0000000000 --- a/lib/pcg/src/file_format.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "pcg/file_format/v1/v1.h" - -namespace FlexFlow { - -/* void thing() { */ -/* static_assert(is_visitable::value, ""); */ - -/* json j; */ -/* auto g = j.get(); */ - -/* /1* IllBehaved v = j.get(); *1/ */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/pcg/src/file_format/v1/graphs.cc b/lib/pcg/src/file_format/v1/graphs.cc deleted file mode 100644 index de8d5dddb4..0000000000 --- a/lib/pcg/src/file_format/v1/graphs.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "pcg/file_format/v1/graphs.h" -#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" -#include "utils/graph/algorithms.h" -#include "utils/integer_conversions.h" - -namespace FlexFlow { - -V1ComputationGraph to_v1(ComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { - return to_v1(g.raw_graph); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index deaa440ef8..cf4b1496cf 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,11 +1,18 @@ #include "pcg/computation_graph.h" +#include "op-attrs/computation_graph_op_attrs.h" #include "utils/containers/get_only.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" +#include "utils/record_formatter.h" namespace FlexFlow { @@ -20,6 +27,23 @@ std::unordered_set get_layers(ComputationGraph const &cg) { [&](Node const &n) { return layer_guid_t{n}; }); } +LayerAddedResult add_layer(ComputationGraph &computation_graph, + LayerAttrs const &attrs, + std::vector const &inputs, + std::vector const &outputs) { + std::vector raw_inputs = transform( + inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); + + NodeAddedResult added = + computation_graph.raw_graph.add_node(attrs, raw_inputs, outputs); + + return LayerAddedResult{ + layer_guid_t{added.node}, + transform(added.outputs, + [](DataflowOutput const &o) { return tensor_guid_t{o}; }), + }; +} + TensorAttrs get_tensor_attrs(ComputationGraph const &cg, tensor_guid_t const &t) { return cg.raw_graph.at(t.raw_graph_output); @@ -39,8 +63,7 @@ std::vector topological_ordering(ComputationGraph const &cg) { std::vector reverse_topological_ordering(ComputationGraph const &cg) { - std::vector layers = - reversed>(get_topological_ordering(cg.raw_graph)); + std::vector layers = reversed(get_topological_ordering(cg.raw_graph)); return transform( layers, [&](Node const &e) -> layer_guid_t { return layer_guid_t{e}; }); } @@ -57,6 +80,47 @@ std::vector get_incoming_tensors(ComputationGraph const &cg, [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } +std::unordered_set get_subgraph_incoming_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_incoming_edges = + get_subgraph_incoming_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_incoming_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_outgoing_edges( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_outgoing_edges = + get_subgraph_outgoing_edges(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_outgoing_edges, [](DataflowEdge const &e) { + return ComputationGraphEdge{e}; + }); +} + +std::unordered_set get_subgraph_successors( + ComputationGraph const &cg, + std::unordered_set const &subgraph_nodes) { + + std::unordered_set raw_subgraph_nodes = transform( + subgraph_nodes, [](layer_guid_t const &l) { return l.raw_node; }); + std::unordered_set raw_successors = + get_subgraph_successors(cg.raw_graph, raw_subgraph_nodes); + + return transform(raw_successors, + [](Node const &n) { return layer_guid_t{n}; }); +} + LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n) { return cg.raw_graph.at(n.raw_node); } @@ -70,4 +134,40 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, return get_only(found); } +std::string as_dot(ComputationGraph const &cg) { + std::function get_node_label = + [](LayerAttrs const &a) -> std::string { + RecordFormatter r = as_dot(a.attrs); + + if (a.name.has_value()) { + RecordFormatter rr; + rr << "Name" << a.name.value(); + r << rr; + } + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + std::function get_input_label = + [](TensorAttrs const &a) -> std::string { + RecordFormatter r; + + r << fmt::to_string(a.shape); + + std::ostringstream oss; + oss << r; + return oss.str(); + }; + + return as_dot(view_as_labelled_open_dataflow_graph(cg.raw_graph), + get_node_label, + get_input_label); +} + +void debug_print_dot(ComputationGraph const &cg) { + std::cout << as_dot(cg) << std::endl; +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc new file mode 100644 index 0000000000..0efa0620c4 --- /dev/null +++ b/lib/pcg/src/pcg/computation_graph/computation_graph_edge.cc @@ -0,0 +1,15 @@ +#include "pcg/computation_graph/computation_graph_edge.h" + +namespace FlexFlow { + +layer_guid_t + get_computation_graph_edge_src_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.src.node}; +} + +layer_guid_t + get_computation_graph_edge_dst_layer(ComputationGraphEdge const &e) { + return layer_guid_t{e.raw_edge.dst.node}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 3f2feaf619..e0b6935a6d 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -15,6 +15,7 @@ #include "op-attrs/ops/linear.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/tensor_dims.h" #include "pcg/computation_graph.h" #include "utils/containers/any_of.h" #include "utils/containers/concat_vectors.h" @@ -26,6 +27,16 @@ namespace FlexFlow { +static TensorAttrs make_weight_attrs( + TensorShape const &shape, + std::optional const &initializer_attrs) { + return TensorAttrs{shape, initializer_attrs, std::nullopt, CreateGrad::YES}; +} + +static TensorAttrs make_output_attrs(TensorShape const &shape) { + return TensorAttrs{shape, std::nullopt, std::nullopt, CreateGrad::YES}; +} + ComputationGraphBuilder::ComputationGraphBuilder() : computation_graph(make_empty_computation_graph()) {} @@ -33,13 +44,31 @@ TensorShape ComputationGraphBuilder::get_shape(tensor_guid_t const &t) const { return get_tensor_attrs(this->computation_graph, t).shape; } -tensor_guid_t ComputationGraphBuilder::create_tensor(TensorShape const &shape, - CreateGrad create_grad) { +tensor_guid_t ComputationGraphBuilder::create_input( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &maybe_name) { TensorAttrs tensor_attrs = TensorAttrs{shape, std::nullopt, std::nullopt, create_grad}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, - std::nullopt, + maybe_name, + }; + + return this->add_layer(layer_attrs, {}, {}, tensor_attrs); +} + +tensor_guid_t ComputationGraphBuilder::create_weight( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &initializer, + std::optional param_sync, + std::optional const &maybe_name) { + TensorAttrs tensor_attrs = + TensorAttrs{shape, initializer, param_sync, create_grad}; + LayerAttrs layer_attrs = LayerAttrs{ + ComputationGraphOpAttrs{InputAttrs{}}, + maybe_name, }; return this->add_layer(layer_attrs, {}, {}, tensor_attrs); @@ -98,9 +127,31 @@ std::vector ComputationGraphBuilder::add_layer( std::vector const &weights, std::vector const &outputs) { return this->add_layer( - layer, inputs, weights, transform(outputs, [](TensorShape const &s) { - return TensorAttrs{s, std::nullopt, std::nullopt, CreateGrad::YES}; - })); + layer, inputs, weights, transform(outputs, make_output_attrs)); +} + +tensor_guid_t ComputationGraphBuilder::add_layer( + LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + TensorShape const &output_shape) { + + TensorAttrs output_attrs = make_output_attrs(output_shape); + LayerAddedResult added = + ::FlexFlow::add_layer(this->computation_graph, + layer, + concat_vectors(inputs, weights), + {output_attrs}); + return get_only(added.outputs); +} + +tensor_guid_t + ComputationGraphBuilder::add_layer(LayerAttrs const &layer, + std::vector const &inputs, + TensorShape const &output_shape) { + + std::vector weights = {}; + return this->add_layer(layer, inputs, weights, output_shape); } tensor_guid_t @@ -129,25 +180,28 @@ tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, } } -tensor_guid_t - ComputationGraphBuilder::broadcast(tensor_guid_t const &input, - TensorShape const &target_shape, - std::string const &name) { +tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input, + TensorDims const &target_dims, + std::string const &name) { TensorShape input_shape = this->get_shape(input); - if (!tensor_shape_is_broadcastable_to(input_shape, target_shape)) { + if (input_shape.dims == target_dims) { + return input; + } + + if (!tensor_dims_is_broadcastable_to(input_shape.dims, target_dims)) { throw mk_runtime_error(fmt::format( - "Cannot broadcast input tensor of shape {} to target shape {}", - input_shape, - target_shape)); + "Cannot broadcast input tensor of dims {} to target dims {}", + input_shape.dims, + target_dims)); } - BroadcastAttrs attrs = BroadcastAttrs{target_shape.dims}; + BroadcastAttrs attrs = BroadcastAttrs{target_dims}; LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t @@ -184,7 +238,7 @@ tensor_guid_t ComputationGraphBuilder::element_unary( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::element_binary( @@ -194,18 +248,18 @@ tensor_guid_t ComputationGraphBuilder::element_binary( std::optional const &maybe_name) { std::string name = maybe_name.value_or(get_default_name(op_type)); - TensorShape compute_shape = this->get_broadcast_target_shape({lhs, rhs}); + TensorDims compute_dims = this->get_broadcast_target_dims({lhs, rhs}); DataType compute_type = std::max(this->get_shape(lhs).data_type, this->get_shape(rhs).data_type); tensor_guid_t lhs_input = this->as_type( this->broadcast( - lhs, compute_shape, fmt::format("{}_inputl_broadcast", name)), + lhs, compute_dims, fmt::format("{}_inputl_broadcast", name)), compute_type, name + "_inputl_cast"); tensor_guid_t rhs_input = this->as_type( this->broadcast( - rhs, compute_shape, fmt::format("{}_inputr_broadcast", name)), + rhs, compute_dims, fmt::format("{}_inputr_broadcast", name)), compute_type, name + "_inputr_cast"); @@ -217,7 +271,7 @@ tensor_guid_t ComputationGraphBuilder::element_binary( TensorShape output_shape = throw_if_unexpected(get_output_shape( attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); - return this->add_layer(layer, {lhs_input, rhs_input}, {}, output_shape); + return this->add_layer(layer, {lhs_input, rhs_input}, output_shape); } tensor_guid_t @@ -359,12 +413,6 @@ tensor_guid_t return this->element_unary(OperatorType::ELU, input, std::nullopt, name); } -static TensorAttrs make_weight_attrs( - TensorShape const &shape, - std::optional const &initializer_attrs) { - return TensorAttrs{shape, initializer_attrs, std::nullopt, CreateGrad::YES}; -} - tensor_guid_t ComputationGraphBuilder::conv2d( tensor_guid_t const &x, int outChannels, @@ -431,7 +479,7 @@ tensor_guid_t ComputationGraphBuilder::dropout( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::embedding( @@ -483,7 +531,7 @@ tensor_guid_t ComputationGraphBuilder::gather( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } /* std::vector @@ -531,7 +579,7 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -581,26 +629,26 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( output_shape); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( std::vector const &inputs) { - std::vector input_shapes = transform( - inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); + std::vector inputs_dims = transform( + inputs, [&](tensor_guid_t const &t) { return this->get_shape(t).dims; }); - return this->get_broadcast_target_shape(input_shapes); + return this->get_broadcast_target_dims(inputs_dims); } -TensorShape ComputationGraphBuilder::get_broadcast_target_shape( - std::vector const &input_shapes) { - std::optional maybe_result = - ::FlexFlow::get_broadcast_target_shape(unordered_set_of(input_shapes)); +TensorDims ComputationGraphBuilder::get_broadcast_target_dims( + std::vector const &inputs_dims) { + std::optional maybe_result = + ::FlexFlow::get_broadcast_target_dims(unordered_set_of(inputs_dims)); if (maybe_result.has_value()) { return maybe_result.value(); } else { throw mk_runtime_error(fmt::format( - "ComputationGraphBuilder::get_broadcast_target_shape failed to find " - "target tensor shape for input tensor shapes {}", - input_shapes)); + "ComputationGraphBuilder::get_broadcast_target_dims failed to find " + "target tensor dims for input tensor dims {}", + inputs_dims)); } } @@ -610,9 +658,11 @@ tensor_guid_t ComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, - std::optional const &maybe_name) { + std::optional const &maybe_name, + std::optional const &projection_name, + std::optional const &bias_name) { LinearAttrs attrs = LinearAttrs{outDim, use_bias, data_type, activation, std::nullopt}; @@ -623,15 +673,30 @@ tensor_guid_t ComputationGraphBuilder::dense( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - std::vector weights; - TensorShape kernel_shape = + std::vector weights; + + TensorShape projection_shape = throw_if_unexpected(get_kernel_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + + tensor_guid_t projection_weights = + this->create_weight(projection_shape, + CreateGrad::YES, + projection_initializer, + /*sync_type=*/std::nullopt, + projection_name); + + weights.push_back(projection_weights); if (use_bias) { TensorShape bias_shape = throw_if_unexpected(get_bias_shape(attrs, this->get_shape(input))); - weights.push_back(make_weight_attrs(bias_shape, bias_initializer)); + + tensor_guid_t bias_weights = this->create_weight(bias_shape, + CreateGrad::YES, + bias_initializer, + /*sync_type=*/std::nullopt, + bias_name); + weights.push_back(bias_weights); } return this->add_layer(layer, {input}, weights, output_shape); @@ -677,13 +742,13 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( TensorShape gamma_shape = throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); InitializerAttrs gamma_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{1}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); TensorShape beta_shape = throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); InitializerAttrs beta_initializer = - InitializerAttrs{ConstantInitializerAttrs{float{0}}}; + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } @@ -716,7 +781,7 @@ tensor_guid_t ComputationGraphBuilder::softmax( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, {}, output_shape); + return this->add_layer(layer, {input}, output_shape); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..975e92dfb7 --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,24 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ComputationGraph to_v1(ComputationGraph const &g) { + return V1ComputationGraph{ + to_v1(g.raw_graph), + }; +} + +std::pair> + to_v1_including_node_numbering(ComputationGraph const &cg) { + std::pair, bidict> + raw = + to_v1_including_node_numbering(cg.raw_graph); + V1ComputationGraph v1_cg = V1ComputationGraph{raw.first}; + bidict v1_node_ids = + map_values(raw.second, [](Node const &n) { return layer_guid_t{n}; }); + + return {v1_cg, v1_node_ids}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..9da58fcf6e --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,12 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/file_format/v1/graphs/v1_labelled_dataflow_graph.h" + +namespace FlexFlow { + +V1ParallelComputationGraph to_v1(ParallelComputationGraph const &g) { + return V1ParallelComputationGraph{ + to_v1(g.raw_graph), + }; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc new file mode 100644 index 0000000000..8336d81bb4 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_computation_graph.cc @@ -0,0 +1,30 @@ +#include "pcg/file_format/v1/v1_computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ComputationGraph") { + ComputationGraph cg = [] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 12, + 16, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + tensor_guid_t mm_output = b.dense(input, 8); + tensor_guid_t relu_output = b.relu(mm_output); + + return b.computation_graph; + }(); + + V1ComputationGraph v1_cg = to_v1(cg); + nlohmann::json j = v1_cg; + } +} diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc new file mode 100644 index 0000000000..8ce25c4bc5 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_parallel_computation_graph.cc @@ -0,0 +1,36 @@ +#include "pcg/file_format/v1/v1_parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("V1ParallelComputationGraph") { + ParallelComputationGraph pcg = [] { + ParallelComputationGraphBuilder b; + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + parallel_tensor_guid_t input = b.create_input_tensor(input_shape); + parallel_tensor_guid_t mm_output = b.dense(input, 8); + parallel_tensor_guid_t relu_output = b.relu(mm_output); + + return b.pcg; + }(); + + V1ParallelComputationGraph v1_pcg = to_v1(pcg); + nlohmann::json j = v1_pcg; + } +} diff --git a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc index 0b75e3ae1a..703c129da4 100644 --- a/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc +++ b/lib/pcg/test/src/pcg/initializers/uniform_initializer_attrs.cc @@ -1,6 +1,8 @@ #include "pcg/initializers/uniform_initializer_attrs.h" -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 440f735e80..f46f267859 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -3,7 +3,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" -#include "test/utils/doctest.h" #include "utils/containers/count.h" #include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" @@ -12,6 +11,9 @@ #include "utils/containers/values.h" #include "utils/containers/without_nullopts.h" #include "utils/hash/pair.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { @@ -227,7 +229,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(num_replicate_attrs == 2); parallel_layer_guid_t conv_guid = get_only(without_nullopts(transform( - as_vector(items(layers)), + vector_of(items(layers)), [](std::pair const &kv) -> std::optional { if (get_op_type(kv.second) == OperatorType::CONV2D) { diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/test_computation_graph_builder.cc index 936c2de00d..ff169d8312 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/test_computation_graph_builder.cc @@ -15,7 +15,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - tensor_guid_t input = b.create_tensor(input_shape, CreateGrad::YES); + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); tensor_guid_t output = b.conv2d(input, /*outChannels=*/5, /*kernelH=*/3, diff --git a/lib/pcg/test/src/test_machine_view.cc b/lib/pcg/test/src/test_machine_view.cc index 70fe958d8c..25c6e21b87 100644 --- a/lib/pcg/test/src/test_machine_view.cc +++ b/lib/pcg/test/src/test_machine_view.cc @@ -1,7 +1,7 @@ #include "pcg/machine_view.h" #include "pcg/strided_rectangle.h" #include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/pcg/test/src/test_strided_rectangle.cc b/lib/pcg/test/src/test_strided_rectangle.cc index 2fe3005b15..ac6af9fa19 100644 --- a/lib/pcg/test/src/test_strided_rectangle.cc +++ b/lib/pcg/test/src/test_strided_rectangle.cc @@ -1,6 +1,6 @@ #include "pcg/strided_rectangle.h" #include "pcg/strided_rectangle_side.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/runtime/src/accessor.cc b/lib/runtime/src/accessor.cc index 44ad8ab40d..84573fb4aa 100644 --- a/lib/runtime/src/accessor.cc +++ b/lib/runtime/src/accessor.cc @@ -129,7 +129,7 @@ struct GetTensorPointerWOFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void *)helperGetTensorPointerWO>( + return (void *)helperGetTensorPointerWO>( region, req, fid, ctx, runtime); } }; @@ -141,7 +141,7 @@ struct GetTensorPointerROFunctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void const *)helperGetTensorPointerRO>( + return (void const *)helperGetTensorPointerRO>( region, req, fid, ctx, runtime); } }; @@ -153,7 +153,7 @@ struct GetTensorPointerRWFUnctor { FieldID fid, Context ctx, Runtime *runtime) const { - return (void *)helperGetTensorPointerRW>( + return (void *)helperGetTensorPointerRW>( region, req, fid, ctx, runtime); } }; diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index 7df65ef361..ad36f1bc4b 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -25,6 +25,7 @@ includes = [ src_includes = [ "utils/fmt/optional.h", + "utils/json/optional.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index 00032045c0..2d76352ccf 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -45,10 +45,6 @@ std::unordered_set get_subgraph_outgoing_edges( SubParallelComputationGraph const &, std::unordered_set const &); -std::unordered_set get_subgraph_incoming_edges( - SubParallelComputationGraph const &, - std::unordered_set const &); - std::unordered_set get_parallel_tensor_uses(SubParallelComputationGraph const &, open_parallel_tensor_guid_t const &); diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index 26f8ff5062..a18737085a 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -1,6 +1,6 @@ #include "substitutions/operator_pattern/get_attribute.h" #include "op-attrs/get_op_type.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { @@ -364,7 +364,7 @@ std::optional get_attribute(TransposeAttrs const &p, case OperatorAttributeKey::OP_TYPE: return get_op_type(p); case OperatorAttributeKey::PERMUTATION: - return as_vector(p.perm); + return vector_of(p.perm); default: return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc index 05f21247c7..286bc69b84 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/get_attribute.cc @@ -1,7 +1,7 @@ #include "substitutions/tensor_pattern/get_attribute.h" #include "op-attrs/parallel_tensor_dims.h" -#include "utils/containers/as_vector.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -11,13 +11,13 @@ TensorAttributeValue get_attribute(ParallelTensorAttrs const &attrs, switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector sizes = - transform(as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + transform(vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return d.size; }); return TensorAttributeValue{sizes}; } case TensorAttributeKey::DIM_DEGREES: { std::vector degrees = transform( - as_vector(ff_ordered_shard_dims(attrs.shape.dims)), + vector_of(ff_ordered_shard_dims(attrs.shape.dims)), [](ShardParallelDim const &d) { return size_t_from_int(d.degree); }); return TensorAttributeValue{degrees}; } diff --git a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc index 70e960bc73..95b61e0ef4 100644 --- a/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/test/src/substitutions/operator_pattern/get_attribute.cc @@ -1,4 +1,5 @@ #include "substitutions/operator_pattern/get_attribute.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 6922798a97..4f56a76d0d 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -5,9 +5,9 @@ #include "substitutions/operator_pattern/operator_attribute_constraint.h" #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/tensor_pattern/tensor_attribute_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc index 6621145d39..e4d763d9c3 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/pattern_split.cc @@ -1,10 +1,10 @@ #include "substitutions/unlabelled/pattern_split.h" #include "substitutions/unlabelled/pattern_value.h" #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc index 3475c10235..e0805dbfd4 100644 --- a/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc +++ b/lib/substitutions/test/src/substitutions/unlabelled/unlabelled_graph_pattern.cc @@ -1,8 +1,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" -#include "test/utils/doctest.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" +#include using namespace ::FlexFlow; diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index 9478195523..aeedd65f82 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -1,9 +1,6 @@ -#include "doctest/doctest.h" -#include "rapidcheck.h" #include "substitutions/unlabelled/find_pattern_matches.h" #include "substitutions/unlabelled/match_additional_criterion.h" #include "substitutions/unlabelled/pattern_matching.h" -#include "test/utils/all.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -13,6 +10,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.h" #include "utils/graph/open_dataflow_graph/open_dataflow_graph.h" #include "utils/overload.h" +#include using namespace FlexFlow; diff --git a/lib/utils/CMakeLists.txt b/lib/utils/CMakeLists.txt index ae5e120fad..a0d77b9f76 100644 --- a/lib/utils/CMakeLists.txt +++ b/lib/utils/CMakeLists.txt @@ -13,7 +13,6 @@ ff_add_library( fmt json cuda - doctest ) add_subdirectory(ffi) diff --git a/lib/utils/include/utils/cli/cli_argument_key.variant.toml b/lib/utils/include/utils/cli/cli_argument_key.variant.toml new file mode 100644 index 0000000000..be118160ce --- /dev/null +++ b/lib/utils/include/utils/cli/cli_argument_key.variant.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "CLIArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_positional_argument_key.dtg.h", + "utils/cli/cli_flag_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::CLIPositionalArgumentKey" + +[[values]] +type = "::FlexFlow::CLIFlagKey" diff --git a/lib/utils/include/utils/cli/cli_flag_key.struct.toml b/lib/utils/include/utils/cli/cli_flag_key.struct.toml new file mode 100644 index 0000000000..790a752911 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_key.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "CLIFlagKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_flag_spec.struct.toml b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml new file mode 100644 index 0000000000..66a47de067 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_flag_spec.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "CLIFlagSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "long_flag" +type = "std::string" + +[[fields]] +name = "short_flag" +type = "std::optional" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_get_help_message.h b/lib/utils/include/utils/cli/cli_get_help_message.h new file mode 100644 index 0000000000..d51579a8e2 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_get_help_message.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_GET_HELP_MESSAGE_H + +#include "utils/cli/cli_spec.dtg.h" + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse.h b/lib/utils/include/utils/cli/cli_parse.h new file mode 100644 index 0000000000..3c91a8423b --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_H + +#include "utils/cli/cli_parse_result.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg); +tl::expected + cli_parse(CLISpec const &, std::vector const &); +tl::expected + cli_parse(CLISpec const &, int argc, char const *const *argv); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.h b/lib/utils/include/utils/cli/cli_parse_result.h new file mode 100644 index 0000000000..155caac7ae --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_PARSE_RESULT_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_parse_result.dtg.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &, CLIArgumentKey const &); +std::string cli_get_argument(CLIParseResult const &, CLIArgumentKey const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_parse_result.struct.toml b/lib/utils/include/utils/cli/cli_parse_result.struct.toml new file mode 100644 index 0000000000..b63da7be14 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_parse_result.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "CLIParseResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "utils/cli/cli_flag_key.dtg.h", + "utils/cli/cli_positional_argument_key.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "flags" +type = "std::unordered_map<::FlexFlow::CLIFlagKey, bool>" + +[[fields]] +name = "positional_arguments" +type = "std::unordered_map<::FlexFlow::CLIPositionalArgumentKey, std::string>" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml new file mode 100644 index 0000000000..d571d0deb3 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_key.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentKey" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "raw_idx" +type = "int" diff --git a/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml new file mode 100644 index 0000000000..b1e74701ee --- /dev/null +++ b/lib/utils/include/utils/cli/cli_positional_argument_spec.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "CLIPositionalArgumentSpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "name" +type = "std::string" + +[[fields]] +name = "choices" +type = "std::optional>" + +[[fields]] +name = "description" +type = "std::optional" diff --git a/lib/utils/include/utils/cli/cli_spec.h b/lib/utils/include/utils/cli/cli_spec.h new file mode 100644 index 0000000000..2c0df08c55 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CLI_CLI_SPEC_H + +#include "utils/cli/cli_argument_key.dtg.h" +#include "utils/cli/cli_flag_spec.dtg.h" +#include "utils/cli/cli_spec.dtg.h" +#include + +namespace FlexFlow { + +CLISpec empty_cli_spec(); +std::vector cli_get_flag_keys(CLISpec const &); +CLIArgumentKey cli_add_help_flag(CLISpec &); +CLIArgumentKey cli_add_flag(CLISpec &, CLIFlagSpec const &); +CLIArgumentKey cli_add_positional_argument(CLISpec &, + CLIPositionalArgumentSpec const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/cli/cli_spec.struct.toml b/lib/utils/include/utils/cli/cli_spec.struct.toml new file mode 100644 index 0000000000..9f64f62c15 --- /dev/null +++ b/lib/utils/include/utils/cli/cli_spec.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "CLISpec" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/cli/cli_flag_spec.dtg.h", + "utils/cli/cli_positional_argument_spec.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "flags" +type = "std::vector<::FlexFlow::CLIFlagSpec>" + +[[fields]] +name = "positional_arguments" +type = "std::vector<::FlexFlow::CLIPositionalArgumentSpec>" diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 937ed51af2..20ab6ce440 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -74,9 +74,6 @@ bool are_all_same(C const &c); template std::function compare_by(F const &f); -template -typename C::value_type maximum(C const &v); - template T reversed(T const &t); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 7c0490fa2a..f60ef77cda 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -179,11 +179,6 @@ std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; } -template -typename C::value_type maximum(C const &v) { - return *std::max_element(v.begin(), v.end()); -} - template std::vector value_all(std::vector> const &v) { return transform(v, [](std::optional const &element) { diff --git a/lib/utils/include/utils/containers/enumerate_vector.h b/lib/utils/include/utils/containers/enumerate_vector.h index 11ee8d2352..700106ea3f 100644 --- a/lib/utils/include/utils/containers/enumerate_vector.h +++ b/lib/utils/include/utils/containers/enumerate_vector.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ENUMERATE_VECTOR_H #include -#include #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/foldl1.h b/lib/utils/include/utils/containers/foldl1.h new file mode 100644 index 0000000000..f542f8cf00 --- /dev/null +++ b/lib/utils/include/utils/containers/foldl1.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDL1_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +T foldl1(std::vector const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldl1 expected non-empty vector, but receieved empty vector")); + } + + auto it = vec.cbegin(); + T result = *it; + it++; + + for (; it != vec.cend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/foldr1.h b/lib/utils/include/utils/containers/foldr1.h new file mode 100644 index 0000000000..4a7e8e098c --- /dev/null +++ b/lib/utils/include/utils/containers/foldr1.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FOLDR1_H + +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +T foldr1(std::vector const &vec, F f) { + if (vec.empty()) { + throw mk_runtime_error(fmt::format( + "foldr1 expected non-empty vector, but receieved empty vector")); + } + + auto it = vec.crbegin(); + T result = *it; + it++; + for (; it != vec.crend(); it++) { + result = f(result, *it); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/generate_map.h b/lib/utils/include/utils/containers/generate_map.h index 1afa534a19..53b2a590c5 100644 --- a/lib/utils/include/utils/containers/generate_map.h +++ b/lib/utils/include/utils/containers/generate_map.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_MAP_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_type.h" +#include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" #include "utils/type_traits_core.h" #include @@ -17,7 +17,7 @@ std::unordered_map generate_map(C const &c, F const &f) { static_assert(is_hashable_v, "Key type should be hashable (but is not)"); auto transformed = - vector_transform(as_vector(c), [&](K const &k) -> std::pair { + vector_transform(vector_of(c), [&](K const &k) -> std::pair { return {k, f(k)}; }); return {transformed.cbegin(), transformed.cend()}; diff --git a/lib/utils/include/utils/containers/get_first.h b/lib/utils/include/utils/containers/get_first.h index ce2a483401..a616c44c20 100644 --- a/lib/utils/include/utils/containers/get_first.h +++ b/lib/utils/include/utils/containers/get_first.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H +#include #include namespace FlexFlow { @@ -10,6 +11,11 @@ T get_first(std::unordered_set const &s) { return *s.cbegin(); } +template +T get_first(std::set const &s) { + return *s.cbegin(); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h new file mode 100644 index 0000000000..634bb61bc1 --- /dev/null +++ b/lib/utils/include/utils/containers/maximum.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H + +#include +#include + +namespace FlexFlow { + +template +std::optional maximum(C const &v) { + if (v.empty()) { + return std::nullopt; + } + + return *std::max_element(std::cbegin(v), std::cend(v)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/multiset_union.h b/lib/utils/include/utils/containers/multiset_union.h new file mode 100644 index 0000000000..6f2b2a7889 --- /dev/null +++ b/lib/utils/include/utils/containers/multiset_union.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MULTISET_UNION_H + +#include +#include + +namespace FlexFlow { + +template +std::unordered_multiset + multiset_union(std::unordered_multiset const &lhs, + std::unordered_multiset const &rhs) { + std::unordered_multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::multiset multiset_union(std::multiset const &lhs, + std::multiset const &rhs) { + std::multiset result = lhs; + + for (T const &t : rhs) { + result.insert(t); + } + + return result; +} + +template +std::unordered_multiset multiset_union(C const &c) { + std::unordered_multiset result; + for (auto const &s : c) { + for (T const &element : s) { + result.insert(element); + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_no_duplicates.h b/lib/utils/include/utils/containers/require_no_duplicates.h new file mode 100644 index 0000000000..0cbe361bdd --- /dev/null +++ b/lib/utils/include/utils/containers/require_no_duplicates.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_NO_DUPLICATES_H + +#include "utils/exception.h" +#include "utils/fmt/multiset.h" +#include "utils/fmt/unordered_multiset.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set + require_no_duplicates(std::unordered_multiset const &s) { + std::unordered_set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +template +std::set require_no_duplicates(std::multiset const &s) { + std::set result{s.cbegin(), s.cend()}; + + if (result.size() != s.size()) { + throw mk_runtime_error(fmt::format( + "require_no_duplicates encountered duplicate in set {}", s)); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/reversed.h b/lib/utils/include/utils/containers/reversed.h index 621eee9519..902b247469 100644 --- a/lib/utils/include/utils/containers/reversed.h +++ b/lib/utils/include/utils/containers/reversed.h @@ -1,15 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_H +#include + namespace FlexFlow { template -T reversed(T const &t) { - T r; - for (auto i = t.cend() - 1; i >= t.begin(); i--) { - r.push_back(*i); - } - return r; +std::vector reversed(std::vector const &t) { + std::vector result(std::crbegin(t), std::crend(t)); + return result; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/set_minus.h b/lib/utils/include/utils/containers/set_minus.h index 6efa2f0a84..fdd1f11995 100644 --- a/lib/utils/include/utils/containers/set_minus.h +++ b/lib/utils/include/utils/containers/set_minus.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_MINUS_H +#include #include namespace FlexFlow { @@ -15,6 +16,15 @@ std::unordered_set set_minus(std::unordered_set const &l, return result; } +template +std::set set_minus(std::set const &l, std::set const &r) { + std::set result = l; + for (T const &t : r) { + result.erase(t); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/set_of.h b/lib/utils/include/utils/containers/set_of.h new file mode 100644 index 0000000000..14658209aa --- /dev/null +++ b/lib/utils/include/utils/containers/set_of.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SET_OF_H + +#include + +namespace FlexFlow { + +template +std::set set_of(C const &c) { + std::set result; + for (T const &t : c) { + result.insert(t); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/to_uppercase.h b/lib/utils/include/utils/containers/to_uppercase.h new file mode 100644 index 0000000000..a2dc7786f9 --- /dev/null +++ b/lib/utils/include/utils/containers/to_uppercase.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TO_UPPERCASE_H + +#include + +namespace FlexFlow { + +std::string to_uppercase(std::string const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/as_vector.h b/lib/utils/include/utils/containers/vector_of.h similarity index 54% rename from lib/utils/include/utils/containers/as_vector.h rename to lib/utils/include/utils/containers/vector_of.h index fafa1dc799..7fb903b4a8 100644 --- a/lib/utils/include/utils/containers/as_vector.h +++ b/lib/utils/include/utils/containers/vector_of.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_AS_VECTOR_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_OF_H #include namespace FlexFlow { template -std::vector as_vector(C const &c) { +std::vector vector_of(C const &c) { std::vector result(c.cbegin(), c.cend()); return result; } diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index 21a6d28ca2..4170882ae6 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -1,9 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_EXPECTED_H -#include "fmt/format.h" #include "utils/check_fmtable.h" -#include +#include #include #include @@ -44,15 +43,4 @@ std::ostream &operator<<(std::ostream &s, tl::expected const &t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(tl::expected const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 8e186928fd..46bf9ca8fa 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -5,7 +5,6 @@ #include "utils/containers/sorted.h" #include "utils/fmt/pair.h" #include "utils/join_strings.h" -#include #include #include @@ -48,15 +47,4 @@ std::ostream &operator<<(std::ostream &s, std::map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h index cff150dc29..616b784aac 100644 --- a/lib/utils/include/utils/fmt/multiset.h +++ b/lib/utils/include/utils/fmt/multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -42,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 45eebc2c58..2364e49568 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_OPTIONAL_H #include "utils/check_fmtable.h" -#include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::optional const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index 6f7e6f6b52..ab5ddd4e28 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #include "utils/check_fmtable.h" -#include #include #include @@ -40,15 +39,4 @@ std::ostream &operator<<(std::ostream &s, std::pair const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::pair const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index 1f8012f240..a183d37542 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/containers/sorted.h" #include "utils/join_strings.h" -#include #include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 75bbb4cb8a..876a032fe6 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -5,7 +5,6 @@ #include "utils/fmt/pair.h" #include "utils/join_strings.h" #include -#include #include #include #include @@ -48,15 +47,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_map const &m) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_map const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index 41abbc925e..deb03a04d4 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -42,15 +41,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_multiset const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_multiset const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 646ef0c7c5..257545af1b 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -4,7 +4,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" #include "utils/type_traits_core.h" -#include #include #include @@ -43,15 +42,4 @@ std::ostream &operator<<(std::ostream &s, std::unordered_set const &x) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::unordered_set const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 867577f72a..06a56417c3 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_VARIANT_H -#include #include #include @@ -33,15 +32,4 @@ std::ostream &operator<<(std::ostream &s, std::variant const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::variant const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 96526175a8..5d9ca0aeae 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -3,7 +3,6 @@ #include "utils/check_fmtable.h" #include "utils/join_strings.h" -#include #include #include @@ -41,15 +40,4 @@ std::ostream &operator<<(std::ostream &s, std::vector const &v) { } // namespace FlexFlow -namespace doctest { - -template -struct StringMaker> { - static String convert(std::vector const &m) { - return toString(fmt::to_string(m)); - } -}; - -} // namespace doctest - #endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h new file mode 100644 index 0000000000..2ed0bc02be --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_SUBGRAPH_INCOMING_EDGES_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(DataflowGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h index fc372f68aa..afc9c47c1c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h @@ -6,6 +6,9 @@ namespace FlexFlow { +std::optional + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &, std::vector const &); std::optional get_cbc_decomposition(DiGraphView const &); diff --git a/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h new file mode 100644 index 0000000000..3066886e37 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_COMPLETE_BIPARTITE_COMPOSITE_IS_COMPLETE_BIPARTITE_DIGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &); +bool is_complete_bipartite_digraph(DiGraphView const &, + std::unordered_set const &srcs); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h new file mode 100644 index 0000000000..ee533a1180 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_as_dot.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_DIGRAPH_AS_DOT_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &, + std::function const &get_node_label); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h new file mode 100644 index 0000000000..87d0d3143a --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/digraph_has_edge.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_HAS_EDGE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &, DirectedEdge const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h new file mode 100644 index 0000000000..6d98c5c20d --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_OUTGOING_EDGES_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_outgoing_edges(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h new file mode 100644 index 0000000000..2c48d327c4 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_subgraph_successors.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_SUBGRAPH_SUCCESSORS_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h new file mode 100644 index 0000000000..c9751124c8 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/transitive_closure.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_CLOSURE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h new file mode 100644 index 0000000000..db2526f973 --- /dev/null +++ b/lib/utils/include/utils/graph/instances/unordered_set_undirected_graph.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_UNDIRECTED_GRAPH_H + +#include "utils/graph/node/node_source.h" +#include "utils/graph/undirected/i_undirected_graph.h" + +namespace FlexFlow { + +struct UnorderedSetUndirectedGraph final : public IUndirectedGraph { +public: + UnorderedSetUndirectedGraph(); + + Node add_node() override; + void add_node_unsafe(Node const &) override; + void remove_node_unsafe(Node const &) override; + void add_edge(UndirectedEdge const &) override; + void remove_edge(UndirectedEdge const &) override; + + std::unordered_set query_nodes(NodeQuery const &) const override; + std::unordered_set + query_edges(UndirectedEdgeQuery const &) const override; + + UnorderedSetUndirectedGraph *clone() const override; + +private: + UnorderedSetUndirectedGraph(NodeSource const &, + std::unordered_set const &, + std::unordered_set const &); + + NodeSource node_source; + std::unordered_set nodes; + std::unordered_set edges; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h index a1d6e9e37a..8306dad1ec 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_FIND_ISOMORPHISM_H -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/zip.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/is_isomorphic_under.h" diff --git a/lib/utils/include/utils/graph/node/node.struct.toml b/lib/utils/include/utils/graph/node/node.struct.toml index 0b6f348ddf..d5c22e5d3d 100644 --- a/lib/utils/include/utils/graph/node/node.struct.toml +++ b/lib/utils/include/utils/graph/node/node.struct.toml @@ -5,6 +5,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h deleted file mode 100644 index be6b9ce12c..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/get_serial_parallel_decomposition.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GET_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/digraph/digraph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include "utils/optional.h" -#include -#include - -namespace FlexFlow { - -std::optional - get_serial_parallel_decomposition(DiGraphView const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h deleted file mode 100644 index 6285d7ae1f..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" - -namespace FlexFlow { - -std::variant - flatten_ast(std::variant const &ast); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h b/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h deleted file mode 100644 index 7d8efc96f2..0000000000 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIAL_PARALLEL_DECOMPOSITION_H - -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.dtg.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" -#include - -namespace FlexFlow { - -std::variant internal_to_final_ast( - std::variant const &ast); -SerialParallelDecomposition - to_final_ast(std::variant const &); - -std::unordered_set get_nodes(SerialParallelDecomposition const &sp); -std::unordered_set get_nodes(SerialSplit const &); -std::unordered_set get_nodes(ParallelSplit const &); -std::unordered_set get_nodes(Node const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h new file mode 100644 index 0000000000..b1607e7a76 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include + +namespace FlexFlow { + +BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &, + BinarySPDecompositionTree const &); +BinarySPDecompositionTree + make_parallel_split(BinarySPDecompositionTree const &, + BinarySPDecompositionTree const &); +BinarySPDecompositionTree make_leaf_node(Node const &); + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); + +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml new file mode 100644 index 0000000000..1241311150 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", +] + +[[fields]] +name = "raw_tree" +type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h new file mode 100644 index 0000000000..42d71ce54e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h @@ -0,0 +1,63 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::string format_as(GenericBinarySeriesSplit const &s) { + return fmt::format("", + get_left_child(s), + get_right_child(s)); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinarySeriesSplit const &x) { + return (s << fmt::to_string(x)); +} + +template +std::string format_as(GenericBinaryParallelSplit const &s) { + return fmt::format("", + get_left_child(s), + get_right_child(s)); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinaryParallelSplit const &x) { + return (s << fmt::to_string(x)); +} + +template +std::string format_as(GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return fmt::format("", s); + }, + [](GenericBinaryParallelSplit const &s) { + return fmt::format("", s); + }, + [](T const &t) { + return fmt::format("", t); + }, + }); +} + +template +std::ostream &operator<<(std::ostream &s, + GenericBinarySPDecompositionTree const &t) { + return (s << fmt::to_string(t)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h new file mode 100644 index 0000000000..74f5ba5d8a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h @@ -0,0 +1,155 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H + +#include +#include +#include + +namespace FlexFlow { + +template +struct GenericBinarySPDecompositionTree; + +template +struct GenericBinarySeriesSplit { +public: + GenericBinarySeriesSplit() = delete; + explicit GenericBinarySeriesSplit( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) + : left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) {} + + GenericBinarySeriesSplit(GenericBinarySeriesSplit const &) = default; + + bool operator==(GenericBinarySeriesSplit const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinarySeriesSplit const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinarySeriesSplit const &other) const { + return this->tie() < other.tie(); + } + +public: + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; + +private: + std::tuple const &, + GenericBinarySPDecompositionTree const &> + tie() const { + return std::tie(*this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct GenericBinaryParallelSplit { +public: + GenericBinaryParallelSplit() = delete; + explicit GenericBinaryParallelSplit( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) + : left_child_ptr( + std::make_shared>(lhs)), + right_child_ptr( + std::make_shared>(rhs)) {} + + GenericBinaryParallelSplit(GenericBinaryParallelSplit const &) = default; + + bool operator==(GenericBinaryParallelSplit const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinaryParallelSplit const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinaryParallelSplit const &other) const { + return this->tie() < other.tie(); + } + +public: + std::shared_ptr> left_child_ptr; + std::shared_ptr> right_child_ptr; + +private: + std::tuple const &, + GenericBinarySPDecompositionTree const &> + tie() const { + return std::tie(*this->left_child_ptr, *this->right_child_ptr); + } + + friend std::hash; +}; + +template +struct GenericBinarySPDecompositionTree { +public: + GenericBinarySPDecompositionTree() = delete; + explicit GenericBinarySPDecompositionTree( + GenericBinarySeriesSplit const &s) + : root{s} {} + + explicit GenericBinarySPDecompositionTree( + GenericBinaryParallelSplit const &s) + : root{s} {} + + explicit GenericBinarySPDecompositionTree(T const &t) : root{t} {} + + GenericBinarySPDecompositionTree(GenericBinarySPDecompositionTree const &) = + default; + + bool operator==(GenericBinarySPDecompositionTree const &other) const { + return this->tie() == other.tie(); + } + + bool operator!=(GenericBinarySPDecompositionTree const &other) const { + return this->tie() != other.tie(); + } + + bool operator<(GenericBinarySPDecompositionTree const &other) const { + return this->tie() < other.tie(); + } + +public: + std::variant, GenericBinaryParallelSplit, T> + root; + +private: + std::tuple tie() const { + return std::tie(this->root); + } + + friend std::hash; +}; + +} // namespace FlexFlow + +// namespace rc { +// +// template <> +// struct Arbitrary<::FlexFlow::BinarySeriesSplit> { +// static Gen<::FlexFlow::BinarySeriesSplit> arbitrary(); +// }; +// +// template <> +// struct Arbitrary<::FlexFlow::GenericBinaryParallelSplit> { +// static Gen<::FlexFlow::GenericBinaryParallelSplit> arbitrary(); +// }; +// +// template <> +// struct Arbitrary<::FlexFlow::GenericBinarySPDecompositionTree> { +// static Gen<::FlexFlow::GenericBinarySPDecompositionTree> arbitrary(); +// }; +// +// } // namespace rc + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h new file mode 100644 index 0000000000..c6c1186d3d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +TT const &get(GenericBinarySPDecompositionTree const &t) { + return std::get(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h new file mode 100644 index 0000000000..51e1e20bac --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H + +#include "utils/containers/multiset_union.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](T const &t) { return std::unordered_multiset{t}; }, + [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, + [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, + }); +} + +template +std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { + return multiset_union(get_leaves(get_left_child(s)), + get_leaves(get_right_child(s))); +} + +template +std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { + return multiset_union(get_leaves(get_left_child(p)), + get_leaves(get_right_child(p))); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h new file mode 100644 index 0000000000..46a460b64e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySeriesSplit const &s) { + return *s.left_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinaryParallelSplit const &p) { + return *p.left_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_left_child(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return get_left_child(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_left_child(p); + }, + [](T const &t) -> GenericBinarySPDecompositionTree { + throw mk_runtime_error( + "get_left_child incorrectly called on leaf node"); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h new file mode 100644 index 0000000000..883acda480 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +SPDecompositionTreeNodeType + get_node_type(GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](GenericBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](GenericBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](T const &) { return SPDecompositionTreeNodeType::NODE; }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..7c6d28d7b4 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -0,0 +1,40 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { + return visit(tt, + overload{ + [](T const &t) { return 1; }, + [](GenericBinarySeriesSplit const &s) { + return get_num_tree_nodes(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_num_tree_nodes(p); + }, + }); +} + +template +int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { + return 1 + get_num_tree_nodes(get_left_child(s)) + + get_num_tree_nodes(get_right_child(s)); +} + +template +int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { + return 1 + get_num_tree_nodes(get_left_child(p)) + + get_num_tree_nodes(get_right_child(p)); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h new file mode 100644 index 0000000000..f0bfba43a2 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h @@ -0,0 +1,44 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySeriesSplit const &s) { + return *s.right_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinaryParallelSplit const &p) { + return *p.right_child_ptr; +} + +template +GenericBinarySPDecompositionTree + get_right_child(GenericBinarySPDecompositionTree const &tt) { + return visit>( + tt, + overload{ + [](GenericBinarySeriesSplit const &s) { + return get_right_child(s); + }, + [](GenericBinaryParallelSplit const &p) { + return get_right_child(p); + }, + [](T const &t) -> GenericBinarySPDecompositionTree { + throw mk_runtime_error( + "get_right_child incorrectly called on leaf node"); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h new file mode 100644 index 0000000000..983dc4a572 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/hash-utils.h" +#include "utils/hash/tuple.h" + +namespace std { + +template +struct hash<::FlexFlow::GenericBinarySeriesSplit> { + size_t operator()(::FlexFlow::GenericBinarySeriesSplit const &s) const { + return get_std_hash(s.tie()); + } +}; + +template +struct hash<::FlexFlow::GenericBinaryParallelSplit> { + size_t operator()(::FlexFlow::GenericBinaryParallelSplit const &s) const { + return get_std_hash(s.tie()); + } +}; + +template +struct hash<::FlexFlow::GenericBinarySPDecompositionTree> { + size_t operator()( + ::FlexFlow::GenericBinarySPDecompositionTree const &s) const { + return get_std_hash(s.tie()); + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h new file mode 100644 index 0000000000..8086f38244 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +bool is_series_split(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative>(t.root); +} + +template +bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative>(t.root); +} + +template +bool is_leaf(GenericBinarySPDecompositionTree const &t) { + return std::holds_alternative(t.root); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h new file mode 100644 index 0000000000..3ffa63753a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_left_associative( + GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](T const &) { return true; }, + [](GenericBinarySeriesSplit const &s) { + return !is_series_split(get_right_child(s)) && + is_binary_sp_tree_left_associative(get_left_child(s)) && + is_binary_sp_tree_left_associative(get_right_child(s)); + }, + [](GenericBinaryParallelSplit const &p) { + return !is_parallel_split(get_right_child(p)) && + is_binary_sp_tree_left_associative(get_left_child(p)) && + is_binary_sp_tree_left_associative(get_right_child(p)); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h new file mode 100644 index 0000000000..d88459b432 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template +bool is_binary_sp_tree_right_associative( + GenericBinarySPDecompositionTree const &tt) { + return visit( + tt, + overload{ + [](T const &t) { return true; }, + [](GenericBinarySeriesSplit const &s) { + return !is_series_split(get_left_child(s)) && + is_binary_sp_tree_right_associative(get_left_child(s)) && + is_binary_sp_tree_right_associative(get_right_child(s)); + }, + [](GenericBinaryParallelSplit const &p) { + return !is_parallel_split(get_left_child(p)) && + is_binary_sp_tree_right_associative(get_left_child(p)) && + is_binary_sp_tree_right_associative(get_right_child(p)); + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h new file mode 100644 index 0000000000..4f1f8266e1 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h @@ -0,0 +1,103 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::GenericBinarySeriesSplit> { + static ::FlexFlow::GenericBinarySeriesSplit from_json(json const &j) { + return ::FlexFlow::GenericBinarySeriesSplit{ + j.at("left_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + j.at("right_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + }; + } + + static void to_json(json &j, + ::FlexFlow::GenericBinarySeriesSplit const &v) { + j["__type"] = "GenericBinarySeriesSplit"; + j["left_child"] = get_left_child(v); + j["right_child"] = get_right_child(v); + } +}; + +template +struct adl_serializer<::FlexFlow::GenericBinaryParallelSplit> { + static ::FlexFlow::GenericBinaryParallelSplit from_json(json const &j) { + return ::FlexFlow::GenericBinaryParallelSplit{ + j.at("left_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + j.at("right_child") + .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), + }; + } + + static void to_json(json &j, + ::FlexFlow::GenericBinaryParallelSplit const &v) { + j["__type"] = "GenericBinaryParallelSplit"; + j["left_child"] = get_left_child(v); + j["right_child"] = get_right_child(v); + } +}; + +template +struct adl_serializer<::FlexFlow::GenericBinarySPDecompositionTree> { + static ::FlexFlow::GenericBinarySPDecompositionTree + from_json(json const &j) { + std::string key = j.at("type").get(); + + if (key == "series") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get<::FlexFlow::GenericBinarySeriesSplit>(), + }; + } else if (key == "parallel") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get<::FlexFlow::GenericBinaryParallelSplit>(), + }; + } else if (key == "leaf") { + return ::FlexFlow::GenericBinarySPDecompositionTree{ + j.at("value").get(), + }; + } else { + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown json type key: {}", key)); + } + } + + static void + to_json(json &j, + ::FlexFlow::GenericBinarySPDecompositionTree const &v) { + j["__type"] = "GenericBinarySPDecompositionTree"; + ::FlexFlow::visit( + v, + ::FlexFlow::overload{ + [&](::FlexFlow::GenericBinarySeriesSplit const &s) { + j["type"] = "series"; + j["value"] = s; + return std::monostate{}; + }, + [&](::FlexFlow::GenericBinaryParallelSplit const &p) { + j["type"] = "parallel"; + j["value"] = p; + return std::monostate{}; + }, + [&](T const &t) { + j["type"] = "leaf"; + j["value"] = t; + return std::monostate{}; + }, + }); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h new file mode 100644 index 0000000000..f55b71146a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +GenericBinarySPDecompositionTree make_generic_binary_series_split( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{ + lhs, + rhs, + }, + }; +} + +template +GenericBinarySPDecompositionTree make_generic_binary_parallel_split( + GenericBinarySPDecompositionTree const &lhs, + GenericBinarySPDecompositionTree const &rhs) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{ + lhs, + rhs, + }, + }; +} + +template +GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(T const &t) { + return GenericBinarySPDecompositionTree{t}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h new file mode 100644 index 0000000000..a8de1ee8f8 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" + +namespace FlexFlow { + +template +GenericBinarySeriesSplit const & + require_series(GenericBinarySPDecompositionTree const &t) { + return get>(t); +} + +template +GenericBinaryParallelSplit const & + require_parallel(GenericBinarySPDecompositionTree const &t) { + return get>(t); +} + +template +T const &require_node(GenericBinarySPDecompositionTree const &t) { + return get(t); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h new file mode 100644 index 0000000000..4d7fa05960 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" + +namespace FlexFlow { + +template > +GenericBinarySPDecompositionTree + transform(GenericBinarySPDecompositionTree const &tt, F f) { + return visit>( + tt, + overload{ + [&](GenericBinarySeriesSplit const &s) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }, + }; + }, + [&](GenericBinaryParallelSplit const &s) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{ + transform(get_left_child(s), f), + transform(get_right_child(s), f), + }, + }; + }, + [&](T const &t) { + return GenericBinarySPDecompositionTree{ + f(t), + }; + }, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h new file mode 100644 index 0000000000..0d9503e59f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H + +#include "utils/exception.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +template +Result visit(GenericBinarySPDecompositionTree const &tt, F f) { + if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative>(tt.root)) { + return f(std::get>(tt.root)); + } else if (std::holds_alternative(tt.root)) { + return f(std::get(tt.root)); + } else { + throw mk_runtime_error( + "Unexpected case in visit(GenericBinarySPDecompositionTree)"); + } + + // return std::visit(tt.root, overload { + // [&](GenericBinarySeriesSplit const &s) -> Result { + // return f(s); + // }, + // [&](GenericBinaryParallelSplit const &p) -> Result { + // return f(p); + // }, + // [&](T const &t) -> Result { + // return f(t); + // }, + // }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..183ece3a89 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_LEFT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h new file mode 100644 index 0000000000..f5174aee56 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_NARY_SP_TREE_FROM_BINARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h new file mode 100644 index 0000000000..e01ec0bdde --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_RIGHT_ASSOCIATIVE_BINARY_SP_TREE_FROM_NARY_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h new file mode 100644 index 0000000000..f2a006d899 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/get_series_parallel_decomposition.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GET_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/optional.h" +#include +#include + +namespace FlexFlow { + +std::optional + get_series_parallel_decomposition(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h b/lib/utils/include/utils/graph/series_parallel/graph_generation.h similarity index 56% rename from lib/utils/include/utils/graph/serial_parallel/graph_generation.h rename to lib/utils/include/utils/graph/series_parallel/graph_generation.h index fac9c98db2..f18fd63d24 100644 --- a/lib/utils/include/utils/graph/serial_parallel/graph_generation.h +++ b/lib/utils/include/utils/graph/series_parallel/graph_generation.h @@ -1,23 +1,23 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_GRAPH_GENERATION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_GRAPH_GENERATION_H #include "utils/graph/dataflow_graph/dataflow_graph.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" namespace FlexFlow { void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext); -void serial_extend(DataflowGraph &g, DataflowGraphView const &ext); +void series_extend(DataflowGraph &g, DataflowGraphView const &ext); -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph parallel_composition(DataflowGraphView const &g1, DataflowGraphView const &g2); DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition); + SeriesParallelDecomposition const &sp_decomposition); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h new file mode 100644 index 0000000000..1283a6df3a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_INTERMEDIATE_SP_DECOMPOSITION_TREE_H + +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" + +namespace FlexFlow { + +std::variant + flatten_ast(std::variant const &ast); + +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml rename to lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml index 08f03ed12a..e7666fcd3f 100644 --- a/lib/utils/include/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.struct.toml +++ b/lib/utils/include/utils/graph/series_parallel/intermediate_sp_decomposition_tree.struct.toml @@ -8,7 +8,7 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/split_type.dtg.h", + "utils/graph/series_parallel/split_type.dtg.h", "", "", "utils/graph/node/node.dtg.h", diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h similarity index 70% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.h index 71cc5e3998..3fc1347ee5 100644 --- a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.h @@ -1,8 +1,8 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_PARALLEL_REDUCTION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_PARALLEL_REDUCTION_H #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/parallel_reduction.dtg.h" +#include "utils/graph/series_parallel/parallel_reduction.dtg.h" #include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/parallel_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/parallel_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h new file mode 100644 index 0000000000..52d2cb7236 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_PARALLEL_DECOMPOSITION_H + +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +std::variant internal_to_final_ast( + std::variant const &ast); +SeriesParallelDecomposition + to_final_ast(std::variant const &); + +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp); +std::unordered_multiset get_nodes(SeriesSplit const &); +std::unordered_multiset get_nodes(ParallelSplit const &); +std::unordered_multiset get_nodes(Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml similarity index 62% rename from lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml rename to lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml index f816abfbb4..921499ebd1 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_decomposition.variant.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.variant.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "SerialParallelDecomposition" +name = "SeriesParallelDecomposition" features = [ "eq", "hash", @@ -7,12 +7,12 @@ features = [ ] includes = [ - "utils/graph/serial_parallel/serial_parallel_splits.h", + "utils/graph/series_parallel/series_parallel_splits.h", "utils/graph/node/node.dtg.h", ] [[values]] -type = "::FlexFlow::SerialSplit" +type = "::FlexFlow::SeriesSplit" [[values]] type = "::FlexFlow::ParallelSplit" diff --git a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h similarity index 59% rename from lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h rename to lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 081137e513..18434d2b67 100644 --- a/lib/utils/include/utils/graph/serial_parallel/serial_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #include "utils/graph/node/node.dtg.h" #include @@ -7,18 +7,18 @@ namespace FlexFlow { -struct SerialSplit; +struct SeriesSplit; struct ParallelSplit; -struct SerialSplit { +struct SeriesSplit { public: - SerialSplit() = delete; - explicit SerialSplit(std::vector> const &); - explicit SerialSplit( + SeriesSplit() = delete; + explicit SeriesSplit(std::vector> const &); + explicit SeriesSplit( std::initializer_list> const &); - bool operator==(SerialSplit const &) const; - bool operator!=(SerialSplit const &) const; + bool operator==(SeriesSplit const &) const; + bool operator!=(SeriesSplit const &) const; public: std::vector> children; @@ -28,16 +28,16 @@ struct SerialSplit { Tie tie() const; }; -std::string format_as(SerialSplit const &); -std::ostream &operator<<(std::ostream &, SerialSplit const &); +std::string format_as(SeriesSplit const &); +std::ostream &operator<<(std::ostream &, SeriesSplit const &); } // namespace FlexFlow namespace std { template <> -struct hash<::FlexFlow::SerialSplit> { - size_t operator()(::FlexFlow::SerialSplit const &) const; +struct hash<::FlexFlow::SeriesSplit> { + size_t operator()(::FlexFlow::SeriesSplit const &) const; }; } // namespace std @@ -48,15 +48,15 @@ struct ParallelSplit { public: ParallelSplit() = delete; explicit ParallelSplit( - std::unordered_set> const &); + std::unordered_multiset> const &); explicit ParallelSplit( - std::initializer_list> const &); + std::initializer_list> const &); bool operator==(ParallelSplit const &) const; bool operator!=(ParallelSplit const &) const; public: - std::unordered_set> children; + std::unordered_multiset> children; private: using Tie = std::tuple; diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h b/lib/utils/include/utils/graph/series_parallel/series_reduction.h similarity index 77% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.h rename to lib/utils/include/utils/graph/series_parallel/series_reduction.h index c9bae58546..a7d53fecfc 100644 --- a/lib/utils/include/utils/graph/serial_parallel/series_reduction.h +++ b/lib/utils/include/utils/graph/series_parallel/series_reduction.h @@ -1,9 +1,9 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLEL_SERIES_REDUCTION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_SERIES_REDUCTION_H #include "utils/graph/multidigraph/multidiedge.dtg.h" #include "utils/graph/multidigraph/multidigraph.h" -#include "utils/graph/serial_parallel/series_reduction.dtg.h" +#include "utils/graph/series_parallel/series_reduction.dtg.h" namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/series_reduction.struct.toml rename to lib/utils/include/utils/graph/series_parallel/series_reduction.struct.toml diff --git a/lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/sink_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/sink_settings.enum.toml diff --git a/lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml b/lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml similarity index 100% rename from lib/utils/include/utils/graph/serial_parallel/source_settings.enum.toml rename to lib/utils/include/utils/graph/series_parallel/source_settings.enum.toml diff --git a/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml new file mode 100644 index 0000000000..2050800cbd --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_decomposition_tree_node_type.enum.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SPDecompositionTreeNodeType" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "SERIES" + +[[values]] +name = "PARALLEL" + +[[values]] +name = "NODE" diff --git a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml similarity index 90% rename from lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml rename to lib/utils/include/utils/graph/series_parallel/split_type.enum.toml index 96d85f0e12..c1a1cb5978 100644 --- a/lib/utils/include/utils/graph/serial_parallel/split_type.enum.toml +++ b/lib/utils/include/utils/graph/series_parallel/split_type.enum.toml @@ -8,7 +8,7 @@ features = [ ] [[values]] -name = "SERIAL" +name = "SERIES" [[values]] name = "PARALLEL" diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h new file mode 100644 index 0000000000..3e951b1db1 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_edges.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_EDGES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h new file mode 100644 index 0000000000..bc605360d2 --- /dev/null +++ b/lib/utils/include/utils/graph/undirected/algorithms/get_neighboring_nodes.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_ALGORITHMS_GET_NEIGHBORING_NODES_H + +#include "utils/graph/undirected/undirected_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &, + Node const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h index 1662ec6d8c..4761275031 100644 --- a/lib/utils/include/utils/graph/undirected/i_undirected_graph.h +++ b/lib/utils/include/utils/graph/undirected/i_undirected_graph.h @@ -15,7 +15,7 @@ struct IUndirectedGraph : public IUndirectedGraphView { virtual std::unordered_set query_nodes(NodeQuery const &query) const = 0; - virtual IUndirectedGraph *clone() const override = 0; + virtual IUndirectedGraph *clone() const = 0; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h index 9aa0f189ec..65939acc87 100644 --- a/lib/utils/include/utils/graph/undirected/undirected_edge_query.h +++ b/lib/utils/include/utils/graph/undirected/undirected_edge_query.h @@ -1,11 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_UNDIRECTED_UNDIRECTED_EDGE_QUERY_H +#include "utils/graph/undirected/undirected_edge.h" #include "utils/graph/undirected/undirected_edge_query.dtg.h" namespace FlexFlow { UndirectedEdgeQuery undirected_edge_query_all(); +bool matches_edge(UndirectedEdgeQuery const &, UndirectedEdge const &); UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &, UndirectedEdgeQuery const &); diff --git a/lib/utils/include/utils/hash/multiset.h b/lib/utils/include/utils/hash/multiset.h new file mode 100644 index 0000000000..4695b89165 --- /dev/null +++ b/lib/utils/include/utils/hash/multiset.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/hash/unordered_multiset.h b/lib/utils/include/utils/hash/unordered_multiset.h new file mode 100644 index 0000000000..b19c76bfef --- /dev/null +++ b/lib/utils/include/utils/hash/unordered_multiset.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_HASH_UNORDERED_MULTISET_H + +#include "utils/hash-utils.h" +#include + +namespace std { + +template +struct hash> { + size_t operator()(std::unordered_multiset const &s) const { + size_t result = 0; + ::FlexFlow::unordered_container_hash(result, s); + return result; + } +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/json/check_is_jsonable.h b/lib/utils/include/utils/json/check_is_jsonable.h new file mode 100644 index 0000000000..41a64a1b83 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_jsonable.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSONABLE(TYPENAME) \ + static_assert(is_json_serializable::value, \ + #TYPENAME " should be json serializeable"); \ + static_assert(is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_deserializable.h b/lib/utils/include/utils/json/is_json_deserializable.h new file mode 100644 index 0000000000..9e6625428b --- /dev/null +++ b/lib/utils/include/utils/json/is_json_deserializable.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_DESERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_deserializable : std::false_type {}; + +template +struct is_json_deserializable< + T, + void_t().get())>> + : std::true_type {}; + +template +inline constexpr bool is_json_deserializable_v = + is_json_deserializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_json_serializable.h b/lib/utils/include/utils/json/is_json_serializable.h new file mode 100644 index 0000000000..926a8037d4 --- /dev/null +++ b/lib/utils/include/utils/json/is_json_serializable.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSON_SERIALIZABLE_H + +#include "utils/type_traits_core.h" +#include +#include + +namespace FlexFlow { + +template +struct is_json_serializable : std::false_type {}; + +template +struct is_json_serializable< + T, + void_t() = std::declval())>> + : std::true_type {}; + +template +inline constexpr bool is_json_serializable_v = is_json_serializable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/is_jsonable.h b/lib/utils/include/utils/json/is_jsonable.h new file mode 100644 index 0000000000..2c8c103650 --- /dev/null +++ b/lib/utils/include/utils/json/is_jsonable.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_IS_JSONABLE_H + +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +template +struct is_jsonable + : std::conjunction, is_json_deserializable> {}; + +template +inline constexpr bool is_jsonable_v = is_jsonable::value; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/optional.h b/lib/utils/include/utils/json/optional.h new file mode 100644 index 0000000000..c88dd24a15 --- /dev/null +++ b/lib/utils/include/utils/json/optional.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_OPTIONAL_H + +#include "utils/json/is_jsonable.h" +#include +#include + +namespace nlohmann { + +template +struct adl_serializer< + std::optional, + typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { + static void to_json(json &j, std::optional const &t) { + if (t.has_value()) { + j = t.value(); + } else { + j = nullptr; + } + } + + static void from_json(json const &j, std::optional &t) { + if (j == nullptr) { + t = std::nullopt; + } else { + t = j.get(); + } + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json/variant.h b/lib/utils/include/utils/json/variant.h new file mode 100644 index 0000000000..fe2c3f3b6c --- /dev/null +++ b/lib/utils/include/utils/json/variant.h @@ -0,0 +1,89 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_VARIANT_H + +#include "utils/json/is_jsonable.h" +#include + +namespace FlexFlow { + +struct VariantToJsonFunctor { + VariantToJsonFunctor(nlohmann::json &j) : j(j) {} + + nlohmann::json &j; + + template + void operator()(T const &t) { + static_assert(is_jsonable::value, ""); + + j = t; + } +}; + +template +void variant_to_json(json &j, std::variant const &v) { + json jval; + visit(::FlexFlow::VariantToJsonFunctor{jval}, v); + j["value"] = jval; + j["index"] = v.index(); +} + +template +std::optional variant_from_json_impl(json const &j) { + using Type = typename std::variant_alternative::type; + + if (j.at("index").get() == Idx) { + return j.at("value").get(); + } + return std::nullopt; +} + +template +std::optional variant_from_json_impl(json const &j, + std::index_sequence) { + // If there were no errors when parsing, all but one element of the array + // will be nullopt. This is because each call to variant_from_json_impl will + // have a unique index and exactly one of them will match the index in the + // json object. + std::array, sizeof...(Is)> results{ + variant_from_json_impl(j)...}; + for (std::optional &maybe : results) { + if (maybe) { + return maybe.value(); + } + } + return std::nullopt; +} + +template +std::variant variant_from_json(json const &j) { + using Variant = std::variant; + std::optional result = variant_from_json_impl( + j, std::make_index_sequence()); + if (!result.has_value()) { + throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", + j.at("index").get()); + } + return result.value(); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer, + typename std::enable_if<::FlexFlow::elements_satisfy< + ::FlexFlow::is_json_serializable, + std::variant>::value>::type> { + static void to_json(json &j, std::variant const &v) { + return ::FlexFlow::variant_to_json(j, v); + } + + static std::variant from_json(json const &j) { + return ::FlexFlow::variant_from_json(j); + } +}; + +} // namespace nlohmann + +#endif diff --git a/lib/utils/include/utils/json.h b/lib/utils/include/utils/json/visitable.h similarity index 52% rename from lib/utils/include/utils/json.h rename to lib/utils/include/utils/json/visitable.h index f56917e329..abc20065de 100644 --- a/lib/utils/include/utils/json.h +++ b/lib/utils/include/utils/json/visitable.h @@ -1,6 +1,9 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_JSON_H +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/json_core.h" #include "utils/optional.h" #include "utils/sequence.h" @@ -10,33 +13,6 @@ namespace FlexFlow { -template -struct is_json_serializable : std::false_type {}; - -template -struct is_json_serializable< - T, - void_t() = std::declval())>> - : std::true_type {}; - -template -struct is_json_deserializable : std::false_type {}; - -template -struct is_json_deserializable().get())>> - : std::true_type {}; - -template -struct is_jsonable - : conjunction, is_json_deserializable> {}; - -#define CHECK_IS_JSONABLE(TYPENAME) \ - static_assert(is_json_serializable::value, \ - #TYPENAME " should be json serializeable"); \ - static_assert(is_json_deserializable::value, \ - #TYPENAME " should be json deserializeable") - struct json_serialization_visitor { json_serialization_visitor() = delete; json_serialization_visitor(json &j) : j(j) {} @@ -134,66 +110,6 @@ T moveonly_visit_json_deserialize(json const &j) { return visitable_from_tuple(tuple_from_json(j)); } -struct VariantToJsonFunctor { - VariantToJsonFunctor(json &j) : j(j) {} - - json &j; - - template - void operator()(T const &t) { - static_assert(is_jsonable::value, ""); - - j = t; - } -}; - -template -void variant_to_json(json &j, std::variant const &v) { - json jval; - visit(::FlexFlow::VariantToJsonFunctor{jval}, v); - j["value"] = jval; - j["index"] = v.index(); -} - -template -std::optional variant_from_json_impl(json const &j) { - using Type = typename std::variant_alternative::type; - - if (j.at("index").get() == Idx) { - return j.at("value").get(); - } - return std::nullopt; -} - -template -std::optional variant_from_json_impl(json const &j, - std::index_sequence) { - // If there were no errors when parsing, all but one element of the array - // will be nullopt. This is because each call to variant_from_json_impl will - // have a unique index and exactly one of them will match the index in the - // json object. - std::array, sizeof...(Is)> results{ - variant_from_json_impl(j)...}; - for (std::optional &maybe : results) { - if (maybe) { - return maybe.value(); - } - } - return std::nullopt; -} - -template -std::variant variant_from_json(json const &j) { - using Variant = std::variant; - std::optional result = variant_from_json_impl( - j, std::make_index_sequence()); - if (!result.has_value()) { - throw ::FlexFlow::mk_runtime_error("Invalid type {} found in json", - j.at("index").get()); - } - return result.value(); -} - } // namespace FlexFlow namespace nlohmann { @@ -231,41 +147,6 @@ struct adl_serializer< } }; -template -struct adl_serializer< - std::optional, - typename std::enable_if<::FlexFlow::is_jsonable::value>::type> { - static void to_json(json &j, std::optional const &t) { - if (t.has_value()) { - to_json(j, t.value()); - } else { - j = nullptr; - } - } - - static void from_json(json const &j, std::optional &t) { - if (j == nullptr) { - t = std::nullopt; - } else { - t = j.get(); - } - } -}; - -template -struct adl_serializer, - typename std::enable_if<::FlexFlow::elements_satisfy< - ::FlexFlow::is_json_serializable, - std::variant>::value>::type> { - static void to_json(json &j, std::variant const &v) { - return ::FlexFlow::variant_to_json(j, v); - } - - static std::variant from_json(json const &j) { - return ::FlexFlow::variant_from_json(j); - } -}; - } // namespace nlohmann #endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3448ec4e0e..3ec165d595 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -7,6 +7,15 @@ namespace FlexFlow { +template +T or_else(std::optional const &o, F &&f) { + if (o.has_value()) { + return o.value(); + } else { + return f(); + } +} + template T const &unwrap(std::optional const &o, F const &f) { if (o.has_value()) { @@ -25,18 +34,4 @@ T const &assert_unwrap(std::optional const &o) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::map( - gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { - return m ? std::optional(std::move(*m)) : std::optional(); - }); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/include/utils/rapidcheck/optional.h b/lib/utils/include/utils/rapidcheck/optional.h new file mode 100644 index 0000000000..edb28fdb81 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/optional.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_OPTIONAL_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::map( + gen::maybe(std::move(gen::arbitrary())), [](Maybe &&m) { + return m ? std::optional(std::move(*m)) : std::optional(); + }); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/required.h b/lib/utils/include/utils/required.h index 9cdd7918dd..d16b67ba86 100644 --- a/lib/utils/include/utils/required.h +++ b/lib/utils/include/utils/required.h @@ -1,9 +1,13 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_REQUIRED_H -#include "utils/json.h" +#include "utils/fmt/vector.h" +#include "utils/json/is_json_deserializable.h" +#include "utils/json/is_json_serializable.h" +#include "utils/json/is_jsonable.h" #include "utils/required_core.h" #include "utils/type_traits.h" +#include namespace FlexFlow { @@ -14,11 +18,11 @@ static_assert(is_list_initializable, int>::value, ""); namespace nlohmann { template struct adl_serializer<::FlexFlow::req> { - static ::FlexFlow::req from_json(json const &j) { + static ::FlexFlow::req from_json(nlohmann::json const &j) { return {j.template get()}; } - static void to_json(json &j, ::FlexFlow::req const &t) { + static void to_json(nlohmann::json &j, ::FlexFlow::req const &t) { j = static_cast(t); } }; diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 19743b8301..7a936ebd7b 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -4,9 +4,9 @@ #include "fmt/core.h" #include "stack_vector.h" #include "utils/fmt.h" -#include "utils/json.h" #include "utils/type_traits.h" #include +#include #include #include @@ -70,13 +70,13 @@ template using stack_string = stack_basic_string; template -void to_json(json &j, stack_string const &v) { +void to_json(nlohmann::json &j, stack_string const &v) { std::string as_string = v; j = as_string; } template -void from_json(json const &j, stack_string &v) { +void from_json(nlohmann::json const &j, stack_string &v) { std::string as_string; j.get_to(as_string); v = stack_string{as_string}; diff --git a/lib/utils/include/utils/stack_vector.h b/lib/utils/include/utils/stack_vector.h index 1d654e3415..7a7bce7afc 100644 --- a/lib/utils/include/utils/stack_vector.h +++ b/lib/utils/include/utils/stack_vector.h @@ -3,12 +3,12 @@ #include "utils/hash-utils.h" #include "utils/join_strings.h" -#include "utils/json.h" #include "utils/test_types.h" #include "utils/type_traits.h" #include #include #include +#include #include #include #include @@ -326,13 +326,13 @@ std::ostream &operator<<(std::ostream &s, stack_vector const &v) { } template -void to_json(json &j, stack_vector const &v) { +void to_json(nlohmann::json &j, stack_vector const &v) { std::vector as_vec(v.begin(), v.end()); j = as_vec; } template -void from_json(json const &j, stack_vector &v) { +void from_json(nlohmann::json const &j, stack_vector &v) { std::vector as_vec; j.get_to(as_vec); v = stack_vector{as_vec.begin(), as_vec.end()}; diff --git a/lib/utils/src/utils/cli/cli_get_help_message.cc b/lib/utils/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..03c53c9356 --- /dev/null +++ b/lib/utils/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,101 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/maximum.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" +#include "utils/join_strings.h" +#include + +namespace FlexFlow { + +std::string cli_get_help_message(std::string const &program_name, + CLISpec const &cli) { + auto render_pos_arg = [](CLIPositionalArgumentSpec const &pos_arg_spec) { + if (pos_arg_spec.choices.has_value()) { + return "{" + join_strings(pos_arg_spec.choices.value(), ",") + "}"; + } else { + return pos_arg_spec.name; + } + }; + + auto render_flag_option_column_key = [](CLIFlagSpec const &flag_spec) { + std::ostringstream oss; + if (flag_spec.short_flag.has_value()) { + oss << "-" << flag_spec.short_flag.value() << ", "; + } + oss << "--" << flag_spec.long_flag; + return oss.str(); + }; + + std::ostringstream oss; + + oss << "usage: " << program_name; + for (CLIFlagSpec const &flag_spec : cli.flags) { + if (flag_spec.short_flag.has_value()) { + oss << " [-" << flag_spec.short_flag.value() << "]"; + } else { + oss << " [--" << flag_spec.long_flag << "]"; + } + } + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << " " << render_pos_arg(pos_arg_spec); + } + + oss << std::endl; + + std::vector all_arg_columns = concat_vectors(std::vector{ + transform(cli.positional_arguments, render_pos_arg), + transform(cli.flags, render_flag_option_column_key), + }); + std::vector all_arg_column_widths = + transform(all_arg_columns, [](std::string const &s) { return s.size(); }); + + if (!all_arg_columns.empty()) { + int max_column_width = + std::min(int_from_size_t(maximum(all_arg_column_widths).value()), 20); + + auto render_column = [&](std::string const &key, + std::optional const &description) { + if (description.has_value()) { + if (key.size() > max_column_width) { + return " " + key + "\n" + std::string(24, ' ') + description.value(); + } else { + } + return fmt::format( + " {:<{}} {}", key, max_column_width, description.value()); + } else { + return fmt::format(" {}", key); + } + }; + + if (!cli.positional_arguments.empty()) { + oss << std::endl; + oss << "positional arguments:" << std::endl; + + if (!cli.positional_arguments.empty()) { + for (CLIPositionalArgumentSpec const &pos_arg_spec : + cli.positional_arguments) { + oss << render_column(render_pos_arg(pos_arg_spec), + pos_arg_spec.description) + << std::endl; + } + } + } + + if (!cli.flags.empty()) { + oss << std::endl; + oss << "options:" << std::endl; + + for (CLIFlagSpec const &flag_spec : cli.flags) { + oss << render_column(render_flag_option_column_key(flag_spec), + flag_spec.description) + << std::endl; + } + } + } + + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse.cc b/lib/utils/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..07982c0c2d --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse.cc @@ -0,0 +1,96 @@ +#include "utils/cli/cli_parse.h" +#include "utils/cli/cli_spec.h" +#include "utils/containers/contains.h" +#include "utils/containers/enumerate.h" +#include "utils/containers/generate_map.h" + +namespace FlexFlow { + +tl::expected cli_parse_flag(CLISpec const &cli, + std::string const &arg) { + for (auto const &[idx, flag_spec] : enumerate(cli.flags)) { + CLIFlagKey key = CLIFlagKey{idx}; + if (("--" + flag_spec.long_flag) == arg) { + return key; + } + + if (flag_spec.short_flag.has_value()) { + if ((std::string{"-"} + flag_spec.short_flag.value()) == arg) { + return key; + } + } + } + + return tl::unexpected(fmt::format("Encountered unknown flag {}", arg)); +} + +tl::expected + cli_parse(CLISpec const &cli, std::vector const &args) { + CLIParseResult result = CLIParseResult{ + generate_map(cli_get_flag_keys(cli), + [](CLIFlagKey const &) { return false; }), + {}, + }; + + int consumed_positional_args = 0; + auto parse_positional_arg = + [&](std::string const &arg) -> std::optional { + if (consumed_positional_args >= cli.positional_arguments.size()) { + return fmt::format("Too many positional arguments: expected {}", + cli.positional_arguments.size()); + } + + CLIPositionalArgumentSpec arg_spec = + cli.positional_arguments.at(consumed_positional_args); + + if (arg_spec.choices.has_value() && + !contains(arg_spec.choices.value(), arg)) { + return fmt::format( + "Invalid option for positional argument \"{}\": \"{}\"", + arg_spec.name, + arg); + } + + result.positional_arguments.insert( + {CLIPositionalArgumentKey{consumed_positional_args}, arg}); + consumed_positional_args++; + + return std::nullopt; + }; + + for (int i = 1; i < args.size(); i++) { + std::string arg = args.at(i); + + if (!arg.empty() && arg.at(0) == '-') { + tl::expected parsed_flag = + cli_parse_flag(cli, arg); + + if (parsed_flag.has_value()) { + result.flags.at(parsed_flag.value()) = true; + } + } else { + std::optional maybe_err_msg = parse_positional_arg(arg); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + } + + if (consumed_positional_args != cli.positional_arguments.size()) { + return tl::unexpected( + fmt::format("Not enough positional arguments: found {}, expected {}", + consumed_positional_args, + cli.positional_arguments.size())); + } + + return result; +} + +tl::expected + cli_parse(CLISpec const &cli, int argc, char const *const *argv) { + std::vector args = {argv, argv + argc}; + + return cli_parse(cli, args); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse_result.cc b/lib/utils/src/utils/cli/cli_parse_result.cc new file mode 100644 index 0000000000..6682a7a6eb --- /dev/null +++ b/lib/utils/src/utils/cli/cli_parse_result.cc @@ -0,0 +1,14 @@ +#include "utils/cli/cli_parse_result.h" + +namespace FlexFlow { + +bool cli_get_flag(CLIParseResult const &result, CLIArgumentKey const &key) { + return result.flags.at(key.get()); +} + +std::string cli_get_argument(CLIParseResult const &result, + CLIArgumentKey const &key) { + return result.positional_arguments.at(key.get()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_spec.cc b/lib/utils/src/utils/cli/cli_spec.cc new file mode 100644 index 0000000000..ca51cfe57f --- /dev/null +++ b/lib/utils/src/utils/cli/cli_spec.cc @@ -0,0 +1,37 @@ +#include "utils/cli/cli_spec.h" +#include "utils/containers/count.h" +#include "utils/containers/transform.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +CLISpec empty_cli_spec() { + return CLISpec{{}, {}}; +} + +std::vector cli_get_flag_keys(CLISpec const &cli) { + return transform(count(cli.flags.size()), + [](int idx) { return CLIFlagKey{idx}; }); +} + +CLIArgumentKey cli_add_help_flag(CLISpec &cli) { + CLIFlagSpec help_flag = + CLIFlagSpec{"help", 'h', "show this help message and exit"}; + return cli_add_flag(cli, help_flag); +} + +CLIArgumentKey cli_add_flag(CLISpec &cli, CLIFlagSpec const &flag_spec) { + cli.flags.push_back(flag_spec); + + return CLIArgumentKey{CLIFlagKey{int_from_size_t(cli.flags.size()) - 1}}; +} + +CLIArgumentKey + cli_add_positional_argument(CLISpec &cli, + CLIPositionalArgumentSpec const &arg) { + cli.positional_arguments.push_back(arg); + return CLIArgumentKey{CLIPositionalArgumentKey{ + int_from_size_t(cli.positional_arguments.size()) - 1}}; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/as_vector.cc b/lib/utils/src/utils/containers/as_vector.cc deleted file mode 100644 index 9c7b63ca58..0000000000 --- a/lib/utils/src/utils/containers/as_vector.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/as_vector.h" diff --git a/lib/utils/src/utils/containers/enumerate_vector.cc b/lib/utils/src/utils/containers/enumerate_vector.cc new file mode 100644 index 0000000000..d4fd131af2 --- /dev/null +++ b/lib/utils/src/utils/containers/enumerate_vector.cc @@ -0,0 +1 @@ +#include "utils/containers/enumerate_vector.h" diff --git a/lib/utils/src/utils/containers/foldl1.cc b/lib/utils/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..c6cdd0eec9 --- /dev/null +++ b/lib/utils/src/utils/containers/foldl1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldl1.h" diff --git a/lib/utils/src/utils/containers/foldr1.cc b/lib/utils/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..9d00d81565 --- /dev/null +++ b/lib/utils/src/utils/containers/foldr1.cc @@ -0,0 +1 @@ +#include "utils/containers/foldr1.h" diff --git a/lib/utils/src/utils/containers/get_element_counts.cc b/lib/utils/src/utils/containers/get_element_counts.cc index 9840ed34d8..ac8e289523 100644 --- a/lib/utils/src/utils/containers/get_element_counts.cc +++ b/lib/utils/src/utils/containers/get_element_counts.cc @@ -1,10 +1,10 @@ #include "utils/containers/get_element_counts.h" -#include "utils/containers/as_vector.h" +#include "utils/containers/vector_of.h" namespace FlexFlow { std::unordered_map get_element_counts(std::string const &s) { - return get_element_counts(as_vector(s)); + return get_element_counts(vector_of(s)); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/maximum.cc b/lib/utils/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..51d92cf951 --- /dev/null +++ b/lib/utils/src/utils/containers/maximum.cc @@ -0,0 +1 @@ +#include "utils/containers/maximum.h" diff --git a/lib/utils/src/utils/containers/multiset_union.cc b/lib/utils/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..a053d05fa6 --- /dev/null +++ b/lib/utils/src/utils/containers/multiset_union.cc @@ -0,0 +1 @@ +#include "utils/containers/multiset_union.h" diff --git a/lib/utils/src/utils/containers/require_no_duplicates.cc b/lib/utils/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..b1d21ad832 --- /dev/null +++ b/lib/utils/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1 @@ +#include "utils/containers/require_no_duplicates.h" diff --git a/lib/utils/src/utils/containers/set_of.cc b/lib/utils/src/utils/containers/set_of.cc new file mode 100644 index 0000000000..3a12ee539d --- /dev/null +++ b/lib/utils/src/utils/containers/set_of.cc @@ -0,0 +1 @@ +#include "utils/containers/set_of.h" diff --git a/lib/utils/src/utils/containers/to_uppercase.cc b/lib/utils/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..6c02b5a109 --- /dev/null +++ b/lib/utils/src/utils/containers/to_uppercase.cc @@ -0,0 +1,10 @@ +#include "utils/containers/to_uppercase.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +std::string to_uppercase(std::string const &s) { + return transform(s, [](char c) -> char { return std::toupper(c); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/vector_of.cc b/lib/utils/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..b997076511 --- /dev/null +++ b/lib/utils/src/utils/containers/vector_of.cc @@ -0,0 +1 @@ +#include "utils/containers/vector_of.h" diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index 323f444a22..6ed41daf43 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -219,10 +219,6 @@ std::unordered_set get_endpoints(UndirectedEdge const &e) { // return g.query_edges(MultiDiEdgeQuery::all()); // } -std::unordered_set get_edges(UndirectedGraphView const &g) { - return g.query_edges(undirected_edge_query_all()); -} - // std::unordered_set get_edges(OpenMultiDiGraphView const &g) // { // return g.query_edges(OpenMultiDiEdgeQuery::all()); diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..d17a84dd12 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_incoming_edges(DataflowGraphView const &g, + std::unordered_set const &ns) { + + std::unordered_set all_nodes = get_nodes(g); + query_set src_query = query_set{set_minus(all_nodes, ns)}; + + DataflowEdgeQuery query = DataflowEdgeQuery{ + src_query, + query_set::matchall(), + query_set{ns}, + query_set::matchall(), + }; + + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 011d8b3ed9..8afe7da926 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -3,8 +3,13 @@ #include "utils/containers/extend.h" #include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" +#include "utils/containers/set_of.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/set.h" +#include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" #include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/graph/digraph/algorithms/get_outgoing_edges.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" @@ -12,23 +17,35 @@ #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "utils/graph/node/algorithms.h" #include "utils/hash/unordered_set.h" +#include namespace FlexFlow { std::optional - get_cbc_decomposition(DiGraphView const &g) { + get_cbc_decomposition_with_edge_order_internal( + DiGraphView const &g, std::vector const &edge_order) { // implementation of the algorithm from https://doi.org/10.1145/800135.804393 // top left of page 8, second paragraph + std::queue edges_to_process; + for (DirectedEdge const &e : edge_order) { + edges_to_process.push(e); + } + std::unordered_set already_in_a_head = {}; std::unordered_set already_in_a_tail = {}; - std::unordered_set edges_to_process = get_edges(g); + + std::unordered_set already_processed = {}; CompleteBipartiteCompositeDecomposition result = CompleteBipartiteCompositeDecomposition{{}}; while (!edges_to_process.empty()) { - DirectedEdge e = get_first(edges_to_process); + DirectedEdge e = edges_to_process.front(); + edges_to_process.pop(); + if (contains(already_processed, e)) { + continue; + } std::unordered_set head = get_predecessors(g, e.dst); std::unordered_set tail = get_successors(g, e.src); @@ -39,6 +56,12 @@ std::optional std::unordered_set from_head_to_tail = g.query_edges(DirectedEdgeQuery{head, tail}); + + DiGraphView subgraph = get_subgraph(g, set_union(head, tail)); + if (!is_complete_bipartite_digraph(subgraph, head)) { + return std::nullopt; + } + if (set_union(values(get_outgoing_edges(g, head))) != from_head_to_tail) { return std::nullopt; } @@ -47,7 +70,7 @@ std::optional } result.subgraphs.insert(BipartiteComponent{head, tail}); - edges_to_process = set_minus(edges_to_process, from_head_to_tail); + already_processed = set_union(already_processed, from_head_to_tail); extend(already_in_a_head, head); extend(already_in_a_tail, tail); } @@ -58,4 +81,10 @@ std::optional return result; } +std::optional + get_cbc_decomposition(DiGraphView const &g) { + std::vector edge_order = vector_of(get_edges(g)); + return get_cbc_decomposition_with_edge_order_internal(g, edge_order); +} + } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc new file mode 100644 index 0000000000..2eab8371b2 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -0,0 +1,29 @@ +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/containers/get_first.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +bool is_complete_bipartite_digraph(DiGraphView const &g) { + return is_complete_bipartite_digraph(g, get_sources(g)); +} + +bool is_complete_bipartite_digraph(DiGraphView const &g, + std::unordered_set const &srcs) { + std::unordered_set sinks = set_minus(get_nodes(g), srcs); + + std::unordered_set edges = get_edges(g); + + std::unordered_set expected_edges; + for (Node const &src : srcs) { + for (Node const &sink : sinks) { + expected_edges.insert(DirectedEdge{src, sink}); + } + } + + return edges == expected_edges; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc new file mode 100644 index 0000000000..ad7830cc76 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_as_dot.cc @@ -0,0 +1,32 @@ +#include "utils/graph/digraph/algorithms/digraph_as_dot.h" +#include "utils/dot_file.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::string digraph_as_dot( + DiGraphView const &g, + std::function const &get_node_label) { + std::ostringstream oss; + DotFile dot = DotFile{oss}; + + auto get_node_name = [](Node const &n) { + return fmt::format("n{}", n.raw_uid); + }; + + for (Node const &n : get_nodes(g)) { + RecordFormatter rec; + rec << get_node_label(n); + dot.add_record_node(get_node_name(n), rec); + } + + for (DirectedEdge const &e : get_edges(g)) { + dot.add_edge(get_node_name(e.src), get_node_name(e.dst)); + } + + dot.close(); + return oss.str(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc new file mode 100644 index 0000000000..5c790abb8c --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/digraph_has_edge.cc @@ -0,0 +1,13 @@ +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" + +namespace FlexFlow { + +bool digraph_has_edge(DiGraphView const &g, DirectedEdge const &e) { + return !g.query_edges(DirectedEdgeQuery{ + query_set{e.src}, + query_set{e.dst}, + }) + .empty(); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc index 2e570cbdf9..34cc7fcc6f 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/get_imm_dominators_map.h" -#include "utils/containers/as_vector.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" #include "utils/containers/generate_map.h" @@ -7,6 +6,7 @@ #include "utils/containers/get_only.h" #include "utils/containers/keys.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms/get_dominators_map.h" #include "utils/graph/node/algorithms.h" @@ -22,8 +22,8 @@ std::unordered_map> std::unordered_set n_dominators = node_to_its_dominators.at(n); n_dominators.erase(n); std::vector recursive_dominator_list = concat_vectors( - transform(as_vector(n_dominators), [&](Node const &dominator) { - return as_vector(node_to_its_dominators.at(dominator)); + transform(vector_of(n_dominators), [&](Node const &dominator) { + return vector_of(node_to_its_dominators.at(dominator)); })); std::unordered_map dominator_counts = get_element_counts(recursive_dominator_list); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc new file mode 100644 index 0000000000..f19deb3046 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" +#include "utils/containers/set_minus.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_set get_subgraph_outgoing_edges( + DiGraphView const &g, std::unordered_set const &subgraph_nodes) { + std::unordered_set external_nodes = + set_minus(get_nodes(g), subgraph_nodes); + DirectedEdgeQuery query = DirectedEdgeQuery{query_set{subgraph_nodes}, + query_set{external_nodes}}; + return g.query_edges(query); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc new file mode 100644 index 0000000000..e860fb11b1 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_subgraph_successors.cc @@ -0,0 +1,16 @@ +#include "utils/graph/digraph/algorithms/get_subgraph_successors.h" +#include "utils/graph/digraph/algorithms/get_subgraph_outgoing_edges.h" + +namespace FlexFlow { + +std::unordered_set + get_subgraph_successors(DiGraphView const &g, + std::unordered_set const &subgraph_nodes) { + std::unordered_set successors = + transform(get_subgraph_outgoing_edges(g, subgraph_nodes), + [](DirectedEdge const &e) { return e.dst; }); + + return successors; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 0000000000..3efea1c138 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,51 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/digraph/algorithms.h" +#include "utils/graph/digraph/algorithms/digraph_has_edge.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +DiGraphView transitive_closure(DiGraphView const &g) { + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + std::unordered_set edges = get_edges(g); + + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } + + DiGraph result = materialize_digraph_view(g); + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; + result.add_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(j)}); + } + } + } + } + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index 10ffe4fc33..97a2439263 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,7 +1,12 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/vector_of.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_closure.h" #include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -24,29 +29,60 @@ DirectedEdgeMaskView *DirectedEdgeMaskView::clone() const { } DiGraphView transitive_reduction(DiGraphView const &g) { - std::unordered_set edge_mask = get_edges(g); + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. + // + // transitive_closure inlined to avoid any drifts in node numbering + // between transitive_closure and transitive_reduction + + bidict nodes = bidict_from_enumerating(get_nodes(g)); + int num_nodes = nodes.size(); + + std::vector edge_matrix(num_nodes * num_nodes, false); + + auto has_edge = [&](int src_idx, + int dst_idx) -> std::vector::reference { + return edge_matrix[src_idx * num_nodes + dst_idx]; + }; + + for (DirectedEdge const &e : get_edges(g)) { + has_edge(nodes.at_r(e.src), nodes.at_r(e.dst)) = true; + } - while (true) { - std::unordered_set new_edge_mask = edge_mask; - for (DirectedEdge const &e1 : edge_mask) { - for (DirectedEdge const &e2 : edge_mask) { - if (e1.dst == e2.src && e1 != e2) { - DirectedEdge trans_edge = DirectedEdge{e1.src, e2.dst}; - if (contains(new_edge_mask, trans_edge)) { - new_edge_mask.erase(trans_edge); + // compute transitive closure + // see https://cs.winona.edu/lin/cs440/ch08-2.pdf slide 8-8 + for (int k = 0; k < num_nodes; k++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, k)) { + for (int j = 0; j < num_nodes; j++) { + if (has_edge(k, j)) { + has_edge(i, j) = true; } } } } + } - if (new_edge_mask == edge_mask) { - break; - } else { - edge_mask = new_edge_mask; + DiGraph result = materialize_digraph_view(g); + // compute transitive reduction + // see https://stackoverflow.com/a/6702198 + std::unordered_set edge_mask = get_edges(g); + for (int j = 0; j < num_nodes; j++) { + for (int i = 0; i < num_nodes; i++) { + if (has_edge(i, j)) { + for (int k = 0; k < num_nodes; k++) { + if (has_edge(j, k)) { + has_edge(i, k) = false; + result.remove_edge(DirectedEdge{nodes.at_l(i), nodes.at_l(k)}); + } + } + } } } - return DiGraphView::create(g, edge_mask); + return result; } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc index 34a8eff503..68ef12c49e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_digraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_digraph.cc @@ -38,11 +38,7 @@ void AdjacencyDiGraph::add_edge(DirectedEdge const &e) { } void AdjacencyDiGraph::remove_edge(DirectedEdge const &e) { - std::unordered_set &m = this->adjacency.at(e.src); - auto iter = m.find(e.dst); - if (iter != m.end()) { - m.erase(iter); - } + this->adjacency.at(e.src).erase(e.dst); } std::unordered_set diff --git a/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc new file mode 100644 index 0000000000..6f6722f635 --- /dev/null +++ b/lib/utils/src/utils/graph/instances/unordered_set_undirected_graph.cc @@ -0,0 +1,58 @@ +#include "utils/graph/instances/unordered_set_undirected_graph.h" +#include "utils/graph/node/node_query.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph() {} + +UnorderedSetUndirectedGraph::UnorderedSetUndirectedGraph( + NodeSource const &node_source, + std::unordered_set const &nodes, + std::unordered_set const &edges) + : node_source(node_source), nodes(nodes), edges(edges) {} + +Node UnorderedSetUndirectedGraph::add_node() { + Node new_node = this->node_source.new_node(); + this->nodes.insert(new_node); + return new_node; +} + +void UnorderedSetUndirectedGraph::add_node_unsafe(Node const &n) { + this->nodes.insert(n); +} + +void UnorderedSetUndirectedGraph::remove_node_unsafe(Node const &n) { + this->nodes.erase(n); +} + +void UnorderedSetUndirectedGraph::add_edge(UndirectedEdge const &e) { + assert(contains(this->nodes, e.bigger)); + assert(contains(this->nodes, e.smaller)); + this->edges.insert(e); +} + +void UnorderedSetUndirectedGraph::remove_edge(UndirectedEdge const &e) { + this->edges.erase(e); +} + +std::unordered_set + UnorderedSetUndirectedGraph::query_nodes(NodeQuery const &q) const { + return apply_node_query(q, this->nodes); +} + +std::unordered_set UnorderedSetUndirectedGraph::query_edges( + UndirectedEdgeQuery const &q) const { + return filter(this->edges, + [&](UndirectedEdge const &e) { return matches_edge(q, e); }); +} + +UnorderedSetUndirectedGraph *UnorderedSetUndirectedGraph::clone() const { + return new UnorderedSetUndirectedGraph{ + this->node_source, + this->nodes, + this->edges, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc index 47096d492c..53497a715d 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_edge_counts.cc @@ -1,7 +1,7 @@ #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_element_counts.h" #include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" @@ -10,7 +10,7 @@ namespace FlexFlow { std::unordered_map get_edge_counts(MultiDiGraphView const &g) { return get_element_counts( - transform(as_vector(get_edges(g)), + transform(vector_of(get_edges(g)), [&](MultiDiEdge const &e) { return get_directed_edge(g, e); })); } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index d95a9b9565..1dd5353301 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -2,12 +2,12 @@ #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/containers/as_vector.h" #include "utils/containers/get_all_permutations.h" #include "utils/containers/get_first.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/values.h" +#include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" @@ -201,7 +201,7 @@ std::unordered_set OpenDataflowGraphView const &dst) { std::unordered_set result; - std::vector src_sink_nodes = as_vector(get_sinks(src)); + std::vector src_sink_nodes = vector_of(get_sinks(src)); std::unordered_set dst_sink_nodes = get_sinks(dst); if (src_sink_nodes.size() != dst_sink_nodes.size()) { @@ -209,7 +209,7 @@ std::unordered_set } std::vector src_unused_graph_inputs = - as_vector(get_unused_open_dataflow_graph_inputs(src)); + vector_of(get_unused_open_dataflow_graph_inputs(src)); std::unordered_set dst_unused_graph_inputs = get_unused_open_dataflow_graph_inputs(dst); diff --git a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc deleted file mode 100644 index 6384bd9159..0000000000 --- a/lib/utils/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" -#include "utils/containers/extend.h" - -namespace FlexFlow { - -struct FlattenAST { - void add_flattened_child_to_parent( - IntermediateSpDecompositionTree &parent, - std::variant const &child) { - if (std::holds_alternative(child)) { - parent.children.push_back(child); - return; - } - - IntermediateSpDecompositionTree child_node = - std::get(child); - - if (parent.type == child_node.type) { - extend(parent.children, child_node.children); - } else { - parent.children.push_back(child); - } - } - - std::variant - operator()(IntermediateSpDecompositionTree const &ast_node) { - IntermediateSpDecompositionTree result(ast_node.type, {}); - for (std::variant const &child : - ast_node.children) { - std::variant flattened_child = - flatten_ast(child); - add_flattened_child_to_parent(result, flattened_child); - } - return result; - } - - std::variant - operator()(Node const &ast_node) { - return ast_node; - } -}; - -std::variant flatten_ast( - std::variant const &ast) { - return std::visit(FlattenAST{}, ast); -} - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc new file mode 100644 index 0000000000..18d1f922c6 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -0,0 +1,43 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" + +namespace FlexFlow { + +BinarySPDecompositionTree + make_series_split(BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{ + make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), + }; +} + +BinarySPDecompositionTree + make_parallel_split(BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{ + make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + }; +} + +BinarySPDecompositionTree make_leaf_node(Node const &n) { + return BinarySPDecompositionTree{ + make_generic_binary_sp_leaf(n), + }; +} + +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tt) { + return is_binary_sp_tree_left_associative(tt.raw_tree); +} + +bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tt) { + return is_binary_sp_tree_right_associative(tt.raw_tree); +} + +std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { + return get_leaves(tt.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc new file mode 100644 index 0000000000..4cd7206408 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc new file mode 100644 index 0000000000..3a4dbad8ec --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc new file mode 100644 index 0000000000..4ee18af5be --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..71b67acc54 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc new file mode 100644 index 0000000000..227e5bd79c --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc new file mode 100644 index 0000000000..1618128226 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..05ec6b5925 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc new file mode 100644 index 0000000000..f168ba1e2f --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc new file mode 100644 index 0000000000..75c472c435 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc new file mode 100644 index 0000000000..3da024743c --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..8fe9397003 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..d202f55964 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc new file mode 100644 index 0000000000..b569ff9265 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc new file mode 100644 index 0000000000..fb1532b3ef --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc new file mode 100644 index 0000000000..3fee45fcf5 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..cabd66cff7 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc new file mode 100644 index 0000000000..25409333f2 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -0,0 +1 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..02e541b7e4 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,75 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldl1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function( + std::variant const &)> + from_series_child; + std::function( + std::variant const &)> + from_parallel_child; + + auto from_node = [](Node const &n) -> GenericBinarySPDecompositionTree { + return GenericBinarySPDecompositionTree{n}; + }; + + auto from_series = + [&](SeriesSplit const &s) -> GenericBinarySPDecompositionTree { + std::vector> children = + transform(s.children, from_series_child); + return foldl1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{accum, x}, + }; + }); + }; + + auto from_parallel = + [&](ParallelSplit const &s) -> GenericBinarySPDecompositionTree { + std::vector> children = + transform(vector_of(s.children), from_parallel_child); + return foldl1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{accum, x}}; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return BinarySPDecompositionTree{ + nary.visit>(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..3b8affd16d --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,12 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" + +namespace FlexFlow { + +SeriesParallelDecomposition + nary_sp_tree_from_binary(BinarySPDecompositionTree const &binary) { + return to_final_ast(from_binary_sp_tree(binary)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..673a4118a6 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,72 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/containers/foldr1.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/overload.h" + +namespace FlexFlow { + +BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( + SeriesParallelDecomposition const &nary) { + std::function( + std::variant const &)> + from_series_child; + std::function( + std::variant const &)> + from_parallel_child; + + auto from_node = [](Node const &n) { + return GenericBinarySPDecompositionTree{n}; + }; + + auto from_series = [&](SeriesSplit const &s) { + std::vector> children = + transform(s.children, from_series_child); + return foldr1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinarySeriesSplit{x, accum}}; + }); + }; + + auto from_parallel = [&](ParallelSplit const &s) { + std::vector> children = + transform(vector_of(s.children), from_parallel_child); + return foldr1(children, + [](GenericBinarySPDecompositionTree const &accum, + GenericBinarySPDecompositionTree const &x) { + return GenericBinarySPDecompositionTree{ + GenericBinaryParallelSplit{x, accum}}; + }); + }; + + from_parallel_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + }, + v); + }; + + from_series_child = [&](std::variant const &v) + -> GenericBinarySPDecompositionTree { + return std::visit( + overload{ + [&](Node const &n) { return from_node(n); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }, + v); + }; + + return BinarySPDecompositionTree{ + nary.visit>(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc similarity index 62% rename from lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 632f5245db..ab231f256c 100644 --- a/lib/utils/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -1,23 +1,28 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" #include "utils/containers/get_only.h" #include "utils/containers/map_values.h" #include "utils/containers/transform.h" #include "utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/serial_parallel/parallel_reduction.h" -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_reduction.h" namespace FlexFlow { -std::optional - get_serial_parallel_decomposition(DiGraphView const &g) { +std::optional + get_series_parallel_decomposition(DiGraphView const &g) { + + DiGraphView transitively_reduced = transitive_reduction(g); InverseLineGraphResult inverse_line_graph_result = ({ std::optional maybe_line_graph = - get_inverse_line_graph(g); + get_inverse_line_graph(transitively_reduced); if (!maybe_line_graph.has_value()) { return std::nullopt; } @@ -27,14 +32,11 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); - std::unordered_map> - ttsp_edge_to_sp_tree = map_values( - inverse_line_graph_result.inverse_edge_to_line_node_bidict - .as_unordered_map(), - [](Node const &n) { - return std::variant{n}; - }); + std::unordered_map + ttsp_edge_to_sp_tree = + map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return make_leaf_node(n); }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -44,11 +46,8 @@ std::optional ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = make_parallel_split( + ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -63,11 +62,8 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - std::variant new_tree = - IntermediateSpDecompositionTree{ - SplitType::SERIAL, - {ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)}, - }; + BinarySPDecompositionTree new_tree = make_series_split( + ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -83,7 +79,7 @@ std::optional MultiDiEdge e = get_only(get_edges(ttsp)); if (ttsp.get_multidiedge_src(e) != ttsp.get_multidiedge_dst(e)) { - return to_final_ast(ttsp_edge_to_sp_tree.at(e)); + return nary_sp_tree_from_binary(ttsp_edge_to_sp_tree.at(e)); } } } diff --git a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc similarity index 79% rename from lib/utils/src/utils/graph/serial_parallel/graph_generation.cc rename to lib/utils/src/utils/graph/series_parallel/graph_generation.cc index 4c9eb9d3ef..7070d04c4a 100644 --- a/lib/utils/src/utils/graph/serial_parallel/graph_generation.cc +++ b/lib/utils/src/utils/graph/series_parallel/graph_generation.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/graph_generation.h" +#include "utils/graph/series_parallel/graph_generation.h" #include "utils/graph/dataflow_graph/algorithms.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -12,7 +12,7 @@ void parallel_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { } } -void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { +void series_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { // TODO(@lockshaw): This function signature is impossible to implement in // general, as there is no guarantee that the graph view ext actually has // source nodes with inputs Either the signature should be changed, or an @@ -22,11 +22,11 @@ void serial_extend_unsafe(DataflowGraph &g, DataflowGraphView const &ext) { NOT_IMPLEMENTED(); } -DataflowGraph serial_composition(DataflowGraphView const &g1, +DataflowGraph series_composition(DataflowGraphView const &g1, DataflowGraphView const &g2) { DataflowGraph g = DataflowGraph::create_copy_of(g1); - serial_extend_unsafe(g, g2); + series_extend_unsafe(g, g2); return g; } @@ -39,8 +39,8 @@ DataflowGraph parallel_composition(DataflowGraphView const &g1, } DataflowGraph dataflow_graph_from_sp_decomposition( - SerialParallelDecomposition const &sp_decomposition) { - // TODO(@lockshaw): see existing concerns about serial_extend_unsafe + SeriesParallelDecomposition const &sp_decomposition) { + // TODO(@lockshaw): see existing concerns about series_extend_unsafe NOT_IMPLEMENTED(); } diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc new file mode 100644 index 0000000000..48c936ec39 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -0,0 +1,84 @@ +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/containers/extend.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/overload.h" + +namespace FlexFlow { + +struct FlattenAST { + void add_flattened_child_to_parent( + IntermediateSpDecompositionTree &parent, + std::variant const &child) { + if (std::holds_alternative(child)) { + parent.children.push_back(child); + return; + } + + IntermediateSpDecompositionTree child_node = + std::get(child); + + if (parent.type == child_node.type) { + extend(parent.children, child_node.children); + } else { + parent.children.push_back(child); + } + } + + std::variant + operator()(IntermediateSpDecompositionTree const &ast_node) { + IntermediateSpDecompositionTree result(ast_node.type, {}); + for (std::variant const &child : + ast_node.children) { + std::variant flattened_child = + flatten_ast(child); + add_flattened_child_to_parent(result, flattened_child); + } + return result; + } + + std::variant + operator()(Node const &ast_node) { + return ast_node; + } +}; + +std::variant flatten_ast( + std::variant const &ast) { + return std::visit(FlattenAST{}, ast); +} + +std::variant + from_binary_sp_tree(GenericBinarySPDecompositionTree const &binary) { + return visit>( + binary, + overload{ + [](Node const &n) { return n; }, + [](GenericBinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + from_binary_sp_tree(get_left_child(s)), + from_binary_sp_tree(get_right_child(s)), + }, + }; + }, + [](GenericBinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(get_left_child(p)), + from_binary_sp_tree(get_right_child(p)), + }, + }; + }, + }); +} + +std::variant + from_binary_sp_tree(BinarySPDecompositionTree const &binary) { + return from_binary_sp_tree(binary.raw_tree); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 93% rename from lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc index 30aa10edd7..12a6630bf0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 52% rename from lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 666bf40f10..e697533054 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,18 +1,20 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" #include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/hash/unordered_set.h" #include "utils/variant.h" namespace FlexFlow { struct ToFinalAST { - std::variant + std::variant operator()(IntermediateSpDecompositionTree const &node) { - if (node.type == SplitType::SERIAL) { - return SerialSplit{transform( + if (node.type == SplitType::SERIES) { + return SeriesSplit{transform( node.children, [](std::variant const &s) { return narrow>( @@ -20,54 +22,55 @@ struct ToFinalAST { .value(); })}; } else { - return ParallelSplit{unordered_set_of(transform( + return ParallelSplit{unordered_multiset_of(transform( node.children, [](std::variant const &s) { - return narrow>( + return narrow>( internal_to_final_ast(s)) .value(); }))}; } } - std::variant operator()(Node const &node) { + std::variant operator()(Node const &node) { return node; } }; -std::variant internal_to_final_ast( +std::variant internal_to_final_ast( std::variant const &ast) { return std::visit(ToFinalAST{}, flatten_ast(ast)); } -SerialParallelDecomposition to_final_ast( +SeriesParallelDecomposition to_final_ast( std::variant const &ast) { - return std::visit([](auto &&x) { return SerialParallelDecomposition{x}; }, + return std::visit([](auto &&x) { return SeriesParallelDecomposition{x}; }, internal_to_final_ast(ast)); } -std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { - return sp.visit>( +std::unordered_multiset get_nodes(SeriesParallelDecomposition const &sp) { + return sp.visit>( [](auto &&t) { return get_nodes(t); }); } -std::unordered_set get_nodes(SerialSplit const &serial) { - return set_union(transform( +std::unordered_multiset get_nodes(SeriesSplit const &serial) { + return multiset_union(transform( serial.children, [](std::variant const &child) - -> std::unordered_set { + -> std::unordered_multiset { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(ParallelSplit const ¶llel) { - return set_union(transform( - parallel.children, [](std::variant const &child) { +std::unordered_multiset get_nodes(ParallelSplit const ¶llel) { + return multiset_union(transform( + vector_of(parallel.children), + [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); } -std::unordered_set get_nodes(Node const &node) { +std::unordered_multiset get_nodes(Node const &node) { return {node}; } diff --git a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc similarity index 65% rename from lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc rename to lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc index 8fa42d4b22..0e04a4f904 100644 --- a/lib/utils/src/utils/graph/serial_parallel/serial_parallel_splits.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc @@ -1,47 +1,47 @@ -#include "utils/graph/serial_parallel/serial_parallel_splits.h" -#include "utils/fmt/unordered_set.h" +#include "utils/graph/series_parallel/series_parallel_splits.h" +#include "utils/fmt/unordered_multiset.h" #include "utils/fmt/variant.h" #include "utils/fmt/vector.h" #include "utils/hash-utils.h" -#include "utils/hash/unordered_set.h" +#include "utils/hash/unordered_multiset.h" #include "utils/hash/vector.h" namespace FlexFlow { -SerialSplit::SerialSplit( +SeriesSplit::SeriesSplit( std::vector> const &children) : children(children) {} -SerialSplit::SerialSplit( +SeriesSplit::SeriesSplit( std::initializer_list> const &children) : children(children) {} -bool SerialSplit::operator==(SerialSplit const &other) const { +bool SeriesSplit::operator==(SeriesSplit const &other) const { return this->tie() == other.tie(); } -bool SerialSplit::operator!=(SerialSplit const &other) const { +bool SeriesSplit::operator!=(SeriesSplit const &other) const { return this->tie() != other.tie(); } -SerialSplit::Tie SerialSplit::tie() const { +SeriesSplit::Tie SeriesSplit::tie() const { return std::tie(this->children); } -std::string format_as(SerialSplit const &split) { - return fmt::format("", split.children); +std::string format_as(SeriesSplit const &split) { + return fmt::format("", split.children); } -std::ostream &operator<<(std::ostream &s, SerialSplit const &split) { +std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { return s << fmt::to_string(split); } ParallelSplit::ParallelSplit( - std::unordered_set> const &children) + std::unordered_multiset> const &children) : children(children) {} ParallelSplit::ParallelSplit( - std::initializer_list> const &children) + std::initializer_list> const &children) : children(children) {} bool ParallelSplit::operator==(ParallelSplit const &other) const { @@ -68,8 +68,8 @@ std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { namespace std { -size_t hash<::FlexFlow::SerialSplit>::operator()( - ::FlexFlow::SerialSplit const &s) const { +size_t hash<::FlexFlow::SeriesSplit>::operator()( + ::FlexFlow::SeriesSplit const &s) const { size_t result = 0; ::FlexFlow::hash_combine(result, s.children); return result; diff --git a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc similarity index 97% rename from lib/utils/src/utils/graph/serial_parallel/series_reduction.cc rename to lib/utils/src/utils/graph/series_parallel/series_reduction.cc index e26f460e0e..7300c93fb0 100644 --- a/lib/utils/src/utils/graph/serial_parallel/series_reduction.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/require_same.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/multidigraph/algorithms/get_incoming_edges.h" diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc new file mode 100644 index 0000000000..8ae825c1ab --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_edges.cc @@ -0,0 +1,10 @@ +#include "utils/graph/undirected/algorithms/get_edges.h" +#include "utils/graph/undirected/undirected_edge_query.h" + +namespace FlexFlow { + +std::unordered_set get_edges(UndirectedGraphView const &g) { + return g.query_edges(undirected_edge_query_all()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc new file mode 100644 index 0000000000..3c05b9d5d5 --- /dev/null +++ b/lib/utils/src/utils/graph/undirected/algorithms/get_neighboring_nodes.cc @@ -0,0 +1,19 @@ +#include "utils/graph/undirected/algorithms/get_neighboring_nodes.h" +#include "utils/containers/vector_of.h" + +namespace FlexFlow { + +std::unordered_set get_neighboring_nodes(UndirectedGraphView const &g, + Node const &n) { + std::unordered_set edges = + g.query_edges(UndirectedEdgeQuery{query_set{n}}); + + std::unordered_set result = + set_union(transform(vector_of(edges), [](UndirectedEdge const &e) { + return std::unordered_set{e.bigger, e.smaller}; + })); + result.erase(n); + return result; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc index 5c41eef7da..3cccf1c6eb 100644 --- a/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc +++ b/lib/utils/src/utils/graph/undirected/undirected_edge_query.cc @@ -6,6 +6,10 @@ UndirectedEdgeQuery undirected_edge_query_all() { return UndirectedEdgeQuery{matchall()}; } +bool matches_edge(UndirectedEdgeQuery const &q, UndirectedEdge const &e) { + return includes(q.nodes, e.bigger) && includes(q.nodes, e.smaller); +} + UndirectedEdgeQuery query_intersection(UndirectedEdgeQuery const &lhs, UndirectedEdgeQuery const &rhs) { return UndirectedEdgeQuery{ diff --git a/lib/utils/src/utils/hash/multiset.cc b/lib/utils/src/utils/hash/multiset.cc new file mode 100644 index 0000000000..d84ca7d614 --- /dev/null +++ b/lib/utils/src/utils/hash/multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/multiset.h" diff --git a/lib/utils/src/utils/hash/unordered_multiset.cc b/lib/utils/src/utils/hash/unordered_multiset.cc new file mode 100644 index 0000000000..7f6f73f428 --- /dev/null +++ b/lib/utils/src/utils/hash/unordered_multiset.cc @@ -0,0 +1 @@ +#include "utils/hash/unordered_multiset.h" diff --git a/lib/utils/src/utils/json/check_is_jsonable.cc b/lib/utils/src/utils/json/check_is_jsonable.cc new file mode 100644 index 0000000000..1e78fdb21f --- /dev/null +++ b/lib/utils/src/utils/json/check_is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_jsonable.h" diff --git a/lib/utils/src/utils/json/is_json_deserializable.cc b/lib/utils/src/utils/json/is_json_deserializable.cc new file mode 100644 index 0000000000..17df41433d --- /dev/null +++ b/lib/utils/src/utils/json/is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/is_json_serializable.cc b/lib/utils/src/utils/json/is_json_serializable.cc new file mode 100644 index 0000000000..883ee9f51a --- /dev/null +++ b/lib/utils/src/utils/json/is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/is_json_serializable.h" diff --git a/lib/utils/src/utils/json/is_jsonable.cc b/lib/utils/src/utils/json/is_jsonable.cc new file mode 100644 index 0000000000..3f819f8556 --- /dev/null +++ b/lib/utils/src/utils/json/is_jsonable.cc @@ -0,0 +1 @@ +#include "utils/json/is_jsonable.h" diff --git a/lib/utils/src/utils/json/optional.cc b/lib/utils/src/utils/json/optional.cc new file mode 100644 index 0000000000..c8f0fd2e3c --- /dev/null +++ b/lib/utils/src/utils/json/optional.cc @@ -0,0 +1 @@ +#include "utils/json/optional.h" diff --git a/lib/utils/src/utils/rapidcheck/optional.cc b/lib/utils/src/utils/rapidcheck/optional.cc new file mode 100644 index 0000000000..6d62532e7e --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/optional.cc @@ -0,0 +1 @@ +#include "utils/rapidcheck/optional.h" diff --git a/lib/utils/test/common/include/test/utils/all.h b/lib/utils/test/common/include/test/utils/all.h deleted file mode 100644 index ced1c9ce38..0000000000 --- a/lib/utils/test/common/include/test/utils/all.h +++ /dev/null @@ -1,2 +0,0 @@ -#include "test/utils/doctest.h" -#include "test/utils/rapidcheck.h" diff --git a/lib/utils/test/common/include/test/utils/doctest.h b/lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h similarity index 100% rename from lib/utils/test/common/include/test/utils/doctest.h rename to lib/utils/test/common/include/test/utils/doctest/check_without_stringify.h diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h new file mode 100644 index 0000000000..8333ac4777 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/expected.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_EXPECTED_H + +#include "utils/fmt/expected.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(tl::expected const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h new file mode 100644 index 0000000000..d20dbe6943 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MAP_H + +#include "utils/fmt/map.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h new file mode 100644 index 0000000000..b26eee28ba --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_MULTISET_H + +#include "utils/fmt/multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h new file mode 100644 index 0000000000..519cde7d74 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_OPTIONAL_H + +#include "utils/fmt/optional.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::optional const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h new file mode 100644 index 0000000000..db0ed24f13 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/pair.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_PAIR_H + +#include "utils/fmt/pair.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::pair const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h new file mode 100644 index 0000000000..3dd386645c --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_SET_H + +#include "utils/fmt/set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h new file mode 100644 index 0000000000..4fd5d15009 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MAP_H + +#include "utils/fmt/unordered_map.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_map const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h new file mode 100644 index 0000000000..94dae42239 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_multiset.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_MULTISET_H + +#include "utils/fmt/unordered_multiset.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_multiset const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h new file mode 100644 index 0000000000..441590365d --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/unordered_set.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_UNORDERED_SET_H + +#include "utils/fmt/unordered_set.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::unordered_set const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h new file mode 100644 index 0000000000..c30862274a --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/variant.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VARIANT_H + +#include "utils/fmt/variant.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::variant const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h new file mode 100644 index 0000000000..56198a7558 --- /dev/null +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/vector.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H +#define _FLEXFLOW_LIB_UTILS_TEST_COMMON_INCLUDE_TEST_UTILS_DOCTEST_FMT_VECTOR_H + +#include "utils/fmt/vector.h" +#include + +namespace doctest { + +template +struct StringMaker> { + static String convert(std::vector const &m) { + return toString(fmt::to_string(m)); + } +}; + +} // namespace doctest + +#endif diff --git a/lib/utils/test/common/src/common.cc b/lib/utils/test/common/src/common.cc deleted file mode 100644 index 51e981b1f5..0000000000 --- a/lib/utils/test/common/src/common.cc +++ /dev/null @@ -1 +0,0 @@ -#include "test/utils/all.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc new file mode 100644 index 0000000000..1cff2195db --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/expected.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/expected.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc new file mode 100644 index 0000000000..976e65cfca --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc new file mode 100644 index 0000000000..9c5b2f4d1e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc new file mode 100644 index 0000000000..8a3f7f158e --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/optional.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc new file mode 100644 index 0000000000..106fb1c900 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/pair.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/pair.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc new file mode 100644 index 0000000000..9ec70698bc --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc new file mode 100644 index 0000000000..b893e632ed --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_map.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_map.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc new file mode 100644 index 0000000000..55d2e69056 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_multiset.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_multiset.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc new file mode 100644 index 0000000000..13ad811e63 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/unordered_set.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/unordered_set.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc new file mode 100644 index 0000000000..b6cc4f54e4 --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/variant.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/variant.h" diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc new file mode 100644 index 0000000000..0102cd86da --- /dev/null +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/vector.cc @@ -0,0 +1 @@ +#include "test/utils/doctest/fmt/vector.h" diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index a1dd75504e..44f602f3bc 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -1,10 +1,10 @@ -#include "test/utils/doctest.h" #include "utils/graph/algorithms.h" #include "utils/graph/construction.h" #include "utils/graph/hashmap_undirected_graph.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/undirected.h" #include +#include #include #include #include diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index af7792dc6d..dca500ced5 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/containers.h" +#include #include #include #include @@ -275,9 +275,9 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == std::vector({2, 4, 6})); } - TEST_CASE("as_vector") { + TEST_CASE("vector_of") { std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); + std::vector result = vector_of(s); CHECK(result == std::vector({3, 2, 1})); } diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/test_deduplicated_priority_queue.cc index 66cfd395bc..048e95acb7 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/test_deduplicated_priority_queue.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/deduplicated_priority_queue.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index 80fcf87d6b..65037be3dd 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/disjoint_set.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/test_dot_file.cc index ed4c32bb1c..e409572511 100644 --- a/lib/utils/test/src/test_dot_file.cc +++ b/lib/utils/test/src/test_dot_file.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/dot_file.h" +#include #include TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/test_format.cc index eeed2eae81..f0d396a123 100644 --- a/lib/utils/test/src/test_format.cc +++ b/lib/utils/test/src/test_format.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/record_formatter.h" +#include std::string formatRecord(RecordFormatter const &formatter) { std::ostringstream oss; diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc index b38c43fe30..decf405e7a 100644 --- a/lib/utils/test/src/test_hash.cc +++ b/lib/utils/test/src/test_hash.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/hash-utils.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 90e1bb2187..cc7ac1de32 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -1,7 +1,7 @@ -#include "test/utils/doctest.h" #include "utils/graph/adjacency_multidigraph.h" #include "utils/graph/multidiedge.h" #include "utils/graph/multidigraph_interfaces.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc index 88a566a198..2b816eea4f 100644 --- a/lib/utils/test/src/test_random_utils.cc +++ b/lib/utils/test/src/test_random_utils.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "utils/random_utils.h" #include +#include void checkProbabilities(std::vector const &counts, int numIterations, diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/test_sequence.cc index ee72febe05..a758476fd9 100644 --- a/lib/utils/test/src/test_sequence.cc +++ b/lib/utils/test/src/test_sequence.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/sequence.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/test_stack_map.cc index 21c1b07d1b..f117820c5d 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/test_stack_map.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/stack_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/test_stack_string.cc index a044f85fe3..b89e3277cd 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/test_stack_string.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/stack_string.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 1af43b6993..577e61092c 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/stack_vector.h" +#include #include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/test_tuple.cc index 31308dec2c..96171510a7 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/test_tuple.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/tuple.h" +#include #include #include diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc index b2d8aea848..e7ce12346a 100644 --- a/lib/utils/test/src/test_type_index.cc +++ b/lib/utils/test/src/test_type_index.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/type_index.h" +#include #include using namespace FlexFlow; diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index 33b102bd3b..ea519478d3 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -1,7 +1,8 @@ -#include "test/utils/all.h" +#include "test/utils/rapidcheck.h" #include "test/utils/rapidcheck/visitable.h" #include "utils/graph/hashmap_undirected_graph.h" #include "utils/graph/undirected.h" +#include /* namespace rc { */ diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 98b28a48e9..0bd01b8dfe 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,6 +1,6 @@ -#include "test/utils/doctest.h" #include "test/utils/rapidcheck.h" #include "utils/variant.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/test_vector.cc index 4bdc724dd8..c6eb0828b8 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/test_vector.cc @@ -1,5 +1,5 @@ -#include "test/utils/doctest.h" #include "utils/vector.h" +#include TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") { diff --git a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc index 6e3ac8c155..b5a373e5c9 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_from_enumerating.cc @@ -1,7 +1,7 @@ #include "utils/bidict/algorithms/bidict_from_enumerating.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" -#include "utils/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 5c2ffd5bba..fed655013f 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -1,6 +1,8 @@ #include "utils/bidict/bidict.h" -#include "test/utils/doctest.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/check_without_stringify.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/vector.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc index 2eb8f869f9..49fed81b29 100644 --- a/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/try_merge_nondisjoint_bidicts.cc @@ -1,6 +1,6 @@ #include "utils/bidict/try_merge_nondisjoint_bidicts.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/cli/cli_get_help_message.cc b/lib/utils/test/src/utils/cli/cli_get_help_message.cc new file mode 100644 index 0000000000..b3ee4d3318 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_get_help_message.cc @@ -0,0 +1,519 @@ +#include "utils/cli/cli_get_help_message.h" +#include "utils/join_strings.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_get_help_message(std::string, CLISpec)") { + std::string program_name = "prog_name"; + + SUBCASE("no flags or positional arguments") { + CLISpec cli = CLISpec{ + {}, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name\n"); + + CHECK(result == correct); + } + + SUBCASE("no flags") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "pos-arg-1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name pos-arg-1\n" + "\n" + "positional arguments:\n" + " pos-arg-1\n"); + + CHECK(result == correct); + } + + SUBCASE("no positional arguments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag-1", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag-1\n"); + + CHECK(result == correct); + } + + SUBCASE("flag formatting") { + SUBCASE("flag with shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f]\n" + "\n" + "options:\n" + " -f, --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flag without shortname") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag]\n" + "\n" + "options:\n" + " --flag\n"); + + CHECK(result == correct); + } + + SUBCASE("flags are displayed in provided order") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [--flag2] [--flag1]\n" + "\n" + "options:\n" + " --flag2\n" + " --flag1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("positional argument formatting") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg\n" + "\n" + "positional arguments:\n" + " posarg\n"); + + CHECK(result == correct); + } + + SUBCASE("with choices") { + SUBCASE("choices are not empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {red,blue,green}\n" + "\n" + "positional arguments:\n" + " {red,blue,green}\n"); + + CHECK(result == correct); + } + + SUBCASE("choices are empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {}\n" + "\n" + "positional arguments:\n" + " {}\n"); + + CHECK(result == correct); + } + } + + SUBCASE("are displayed in provided order") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name posarg2 posarg1\n" + "\n" + "positional arguments:\n" + " posarg2\n" + " posarg1\n"); + + CHECK(result == correct); + } + } + + SUBCASE("flag and positional argument alignment") { + SUBCASE("flags are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + CLIFlagSpec{ + "flag2-is-long", + std::nullopt, + "flag2-is-long description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + "help text for posarg", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-1] [--flag2-is-long] posarg\n" + "\n" + "positional arguments:\n" + " posarg help text for posarg\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n" + " --flag2-is-long flag2-is-long description\n"); + + CHECK(result == correct); + } + + SUBCASE("pos args are longer") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + '1', + "flag1 description", + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1-is-very-long", + std::nullopt, + "help text for posarg1-is-very-long", + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + "help text for posarg2", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-1] posarg1-is-very-long posarg2\n" + "\n" + "positional arguments:\n" + " posarg1-is-very-long help text for posarg1-is-very-long\n" + " posarg2 help text for posarg2\n" + "\n" + "options:\n" + " -1, --flag1 flag1 description\n"); + + CHECK(result == correct); + } + + SUBCASE("line break behavior") { + SUBCASE("line breaks max out other argument alignments") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag", + 'f', + "flag help text", + }, + }, + { + CLIPositionalArgumentSpec{ + "abcdefghijklmnopqrstuvwxyz0123456789", + std::nullopt, + "long arg help text", + }, + CLIPositionalArgumentSpec{ + "posarg", + std::nullopt, + "posarg help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name [-f] " + "abcdefghijklmnopqrstuvwxyz0123456789 posarg\n" + "\n" + "positional arguments:\n" + " abcdefghijklmnopqrstuvwxyz0123456789\n" + " long arg help text\n" + " posarg posarg help text\n" + "\n" + "options:\n" + " -f, --flag flag help text\n"); + + CHECK(result == correct); + } + SUBCASE("positional argument line break behavior") { + SUBCASE("positional arguments cause a line break at or above " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 22); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE("positional arguments do not cause a line break below " + "formatted-length 22") { + std::string arg_name = "aaaaaaaaaaaaaaaaaaaaa"; + REQUIRE(arg_name.size() == 21); + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + arg_name, + std::nullopt, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name aaaaaaaaaaaaaaaaaaaaa\n" + "\n" + "positional arguments:\n" + " aaaaaaaaaaaaaaaaaaaaa\n" + " help text\n"); + } + } + + SUBCASE("flag line break behavior") { + SUBCASE("flags cause a line break at or above formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 21); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + "\n" + "options:\n" + " -b, --bbbbbbbbbbbbbbb\n" + " flag description\n"); + + CHECK(result == correct); + } + + SUBCASE("flags do not cause a line break below formatted-length 21") { + std::string arg_name = "bbbbbbbbbbbbbb"; + { + std::string formatted = "-b, --" + arg_name; + REQUIRE(formatted.size() == 20); + } + + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + arg_name, + 'b', + "flag description", + }, + }, + {}, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = + ("usage: prog_name [-b]\n" + "\n" + "options:\n" + " -b, --bbbbbbbbbbbbbb flag description\n"); + + CHECK(result == correct); + } + } + + SUBCASE("choice line breakpoint formatting") { + SUBCASE( + "choices cause a line break at or above formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "fffffffff"}; + { + std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 21); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,fffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,fffffffff}\n" + " help text\n"); + + CHECK(result == correct); + } + + SUBCASE( + "choices do not cause a line break below formatted-length 21") { + std::vector choices = { + "a", "b", "c", "d", "e", "ffffffff"}; + { + std::string formatted_choices = + "{" + join_strings(choices, ",") + "}"; + REQUIRE(formatted_choices.size() == 20); + } + + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + choices, + "help text", + }, + }, + }; + + std::string result = cli_get_help_message(program_name, cli); + std::string correct = ("usage: prog_name {a,b,c,d,e,ffffffff}\n" + "\n" + "positional arguments:\n" + " {a,b,c,d,e,ffffffff} help text\n"); + + CHECK(result == correct); + } + } + } + } + } +} diff --git a/lib/utils/test/src/utils/cli/cli_parse.cc b/lib/utils/test/src/utils/cli/cli_parse.cc new file mode 100644 index 0000000000..40dea86ae0 --- /dev/null +++ b/lib/utils/test/src/utils/cli/cli_parse.cc @@ -0,0 +1,477 @@ +#include "utils/cli/cli_parse.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cli_parse_flag(CLISpec, std::string)") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("correctly parses short flag") { + std::string input = "-2"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag2; + + CHECK(result == correct); + } + + SUBCASE("correctly parses long flag") { + std::string input = "--flag1"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = key_flag1; + + CHECK(result == correct); + } + + SUBCASE("fails on unknown flag") { + std::string input = "--not-real"; + + tl::expected result = cli_parse_flag(cli, input); + tl::expected correct = + tl::unexpected("Encountered unknown flag --not-real"); + + CHECK(result == correct); + } + + SUBCASE("fails on non-flag") { + std::string input = "-flag1"; + + std::optional result = + optional_from_expected(cli_parse_flag(cli, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + TEST_CASE("cli_parse(CLISpec, std::vector)") { + SUBCASE("works even if cli is empty") { + CLISpec cli = CLISpec{{}, {}}; + std::vector inputs = {"prog_name"}; + + tl::expected result = cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, {}}; + + CHECK(result == correct); + } + + SUBCASE("flag parsing") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + '2', + std::nullopt, + }, + }, + {}, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + + SUBCASE("parses flags in any order") { + std::vector inputs = {"prog_name", "-2", "--flag1"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if some are not present") { + std::vector inputs = {"prog_name", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine if none are present") { + std::vector inputs = {"prog_name"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, false}, + }, + {}, + }; + + CHECK(result == correct); + } + + SUBCASE("is fine even if the program name is a flag") { + std::vector inputs = {"--flag1", "-2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, false}, + {key_flag2, true}, + }, + {}, + }; + + CHECK(result == correct); + } + } + + SUBCASE("positional argument parsing") { + SUBCASE("without choices") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::nullopt, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("can parse multiple positional arguments") { + std::vector inputs = {"prog_name", "hello", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("requires all positional arguments to be present") { + std::vector inputs = {"prog_name", "hello"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + + SUBCASE("requires no extra positional arguments to be present") { + std::vector inputs = { + "prog_name", "hello", "there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + tl::unexpected("Too many positional arguments: expected 2"); + + CHECK(result == correct); + } + + SUBCASE("allows arguments to contain spaces") { + std::vector inputs = { + "prog_name", "hello there", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello there"}, + {key_posarg2, "world"}, + }}; + + CHECK(result == correct); + } + + SUBCASE("allows arguments to be empty") { + std::vector inputs = {"prog_name", "hello", ""}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = + CLIParseResult{{}, + { + {key_posarg1, "hello"}, + {key_posarg2, ""}, + }}; + + CHECK(result == correct); + } + } + + SUBCASE("with choices") { + SUBCASE("choices is non-empty") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + }, + }; + + CLIPositionalArgumentKey key_posarg = CLIPositionalArgumentKey{0}; + + SUBCASE( + "succeeds if a positional argument is set to a valid choice") { + std::vector inputs = {"prog_name", "blue"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + {}, + { + {key_posarg, "red"}, + }, + }; + } + + SUBCASE( + "fails if a positional argument is set to an invalid choice") { + std::vector inputs = {"prog_name", " red"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \" red\""); + + CHECK(result == correct); + } + } + + SUBCASE("if choices is empty, rejects everything") { + CLISpec cli = CLISpec{ + {}, + { + CLIPositionalArgumentSpec{ + "posarg", + std::vector{}, + std::nullopt, + }, + }, + }; + + std::vector inputs = {"prog_name", ""}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Invalid option for positional argument \"posarg\": \"\""); + + CHECK(result == correct); + } + } + } + + SUBCASE("correctly differentiates mixed arguments/flags") { + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + SUBCASE("works if flags are before positional arguments") { + std::vector inputs = { + "prog_name", "-f", "--flag3", "red", "world"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("works if flags are interspersed") { + std::vector inputs = { + "prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("detects if posargs are missing instead of treating flags as " + "posarg values") { + std::vector inputs = {"prog_name", "-f", "red", "--flag2"}; + + tl::expected result = + cli_parse(cli, inputs); + tl::expected correct = tl::unexpected( + "Not enough positional arguments: found 1, expected 2"); + + CHECK(result == correct); + } + } + } + + TEST_CASE("cli_parse(CLISpec, int argc, char const * const *argv)") { + // most cases are checked in the other overload, + // i.e., cli_parse(CLISpec, std::vector), + // so here we just throw in a single check to make sure + // nothing has unexpectedly gone wrong + CLISpec cli = CLISpec{ + { + CLIFlagSpec{ + "flag1", + 'f', + std::nullopt, + }, + CLIFlagSpec{ + "flag2", + std::nullopt, + std::nullopt, + }, + CLIFlagSpec{ + "flag3", + 'a', + std::nullopt, + }, + }, + { + CLIPositionalArgumentSpec{ + "posarg1", + std::vector{"red", "blue", "green"}, + std::nullopt, + }, + CLIPositionalArgumentSpec{ + "posarg2", + std::nullopt, + std::nullopt, + }, + }, + }; + CLIFlagKey key_flag1 = CLIFlagKey{0}; + CLIFlagKey key_flag2 = CLIFlagKey{1}; + CLIFlagKey key_flag3 = CLIFlagKey{2}; + CLIPositionalArgumentKey key_posarg1 = CLIPositionalArgumentKey{0}; + CLIPositionalArgumentKey key_posarg2 = CLIPositionalArgumentKey{1}; + + int argc = 5; + char const *argv[] = {"prog_name", "red", "-f", "world", "--flag3"}; + + tl::expected result = + cli_parse(cli, argc, argv); + tl::expected correct = CLIParseResult{ + { + {key_flag1, true}, + {key_flag2, false}, + {key_flag3, true}, + }, + { + {key_posarg1, "red"}, + {key_posarg2, "world"}, + }, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/contains_key.cc b/lib/utils/test/src/utils/containers/contains_key.cc index acc6551cd4..da099113a6 100644 --- a/lib/utils/test/src/utils/containers/contains_key.cc +++ b/lib/utils/test/src/utils/containers/contains_key.cc @@ -1,8 +1,11 @@ #include "utils/containers/contains_key.h" -#include "test/utils/doctest.h" +#include #include +#include #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("contains_key(std::unordered_map, K)") { std::unordered_map m = { diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 2be5f1ef93..c6ce9942e9 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,8 +1,12 @@ #include "utils/containers/enumerate.h" -#include "utils/containers/as_vector.h" -#include "utils/fmt/map.h" -#include "utils/fmt/pair.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/containers/keys.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" #include #include @@ -25,7 +29,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("check iteration order") { std::vector> iterated_result = - as_vector(result); + vector_of(result); std::vector> correct_iteration_order = { {0, "zero"}, {1, "one"}, @@ -46,5 +50,17 @@ TEST_SUITE(FF_TEST_SUITE) { {2, "two"}, {3, "three"}, }; + + std::map result = enumerate(input); + + std::unordered_set result_keys = keys(correct); + std::unordered_set result_values = + unordered_set_of(values(correct)); + + std::unordered_set correct_keys = {0, 1, 2, 3}; + std::unordered_set correct_values = input; + + CHECK(result_keys == correct_keys); + CHECK(result_values == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/extend.cc b/lib/utils/test/src/utils/containers/extend.cc index e0d156a3fc..ef2a67725c 100644 --- a/lib/utils/test/src/utils/containers/extend.cc +++ b/lib/utils/test/src/utils/containers/extend.cc @@ -1,6 +1,6 @@ #include "utils/containers/extend.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filter.cc b/lib/utils/test/src/utils/containers/filter.cc index da459094ef..770ad40375 100644 --- a/lib/utils/test/src/utils/containers/filter.cc +++ b/lib/utils/test/src/utils/containers/filter.cc @@ -1,10 +1,10 @@ #include "utils/containers/filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_keys.cc b/lib/utils/test/src/utils/containers/filtermap_keys.cc index 758264627b..582e94392b 100644 --- a/lib/utils/test/src/utils/containers/filtermap_keys.cc +++ b/lib/utils/test/src/utils/containers/filtermap_keys.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_keys.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtermap_values.cc b/lib/utils/test/src/utils/containers/filtermap_values.cc index d2b6ddd220..8db6d6a964 100644 --- a/lib/utils/test/src/utils/containers/filtermap_values.cc +++ b/lib/utils/test/src/utils/containers/filtermap_values.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtermap_values.h" -#include "test/utils/doctest.h" -#include "utils/fmt/map.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace FlexFlow; diff --git a/lib/utils/test/src/utils/containers/filtrans.cc b/lib/utils/test/src/utils/containers/filtrans.cc index b8bb832b06..cd1c2f896c 100644 --- a/lib/utils/test/src/utils/containers/filtrans.cc +++ b/lib/utils/test/src/utils/containers/filtrans.cc @@ -1,7 +1,7 @@ #include "utils/containers/filtrans.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/foldl1.cc b/lib/utils/test/src/utils/containers/foldl1.cc new file mode 100644 index 0000000000..597aa5e109 --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldl1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldl1.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldl1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldl1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"a s", "tr", "ing"}; + + std::string result = foldl1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/foldr1.cc b/lib/utils/test/src/utils/containers/foldr1.cc new file mode 100644 index 0000000000..3c9d9b66ae --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldr1.cc @@ -0,0 +1,27 @@ +#include "utils/containers/foldr1.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldr1(std::vector, F)") { + auto concat = [](std::string const &accum, std::string const &s) { + return accum + s; + }; + + SUBCASE("empty input") { + std::vector input = {}; + CHECK_THROWS(foldr1(input, concat)); + } + + SUBCASE("non-empty input") { + std::vector input = {"ing", "tr", "a s"}; + + std::string result = foldr1(input, concat); + + std::string correct = "a string"; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_permutations.cc b/lib/utils/test/src/utils/containers/get_all_permutations.cc index 5f22266809..cc5edb4075 100644 --- a/lib/utils/test/src/utils/containers/get_all_permutations.cc +++ b/lib/utils/test/src/utils/containers/get_all_permutations.cc @@ -1,8 +1,7 @@ #include "utils/containers/get_all_permutations.h" -#include "utils/containers/as_vector.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/vector.h" #include "utils/hash/vector.h" #include diff --git a/lib/utils/test/src/utils/containers/get_element_counts.cc b/lib/utils/test/src/utils/containers/get_element_counts.cc index 11e2ef7e05..8fc87dba90 100644 --- a/lib/utils/test/src/utils/containers/get_element_counts.cc +++ b/lib/utils/test/src/utils/containers/get_element_counts.cc @@ -1,5 +1,5 @@ #include "utils/containers/get_element_counts.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/inplace_filter.cc b/lib/utils/test/src/utils/containers/inplace_filter.cc index 7ef9d73339..ac430279b0 100644 --- a/lib/utils/test/src/utils/containers/inplace_filter.cc +++ b/lib/utils/test/src/utils/containers/inplace_filter.cc @@ -1,10 +1,11 @@ #include "utils/containers/inplace_filter.h" -#include "test/utils/all.h" -#include "utils/fmt/map.h" -#include "utils/fmt/set.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/map.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/intersection.cc b/lib/utils/test/src/utils/containers/intersection.cc index ac9acf5e2b..52de6ee6d3 100644 --- a/lib/utils/test/src/utils/containers/intersection.cc +++ b/lib/utils/test/src/utils/containers/intersection.cc @@ -1,6 +1,6 @@ #include "utils/containers/intersection.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc new file mode 100644 index 0000000000..71e7395805 --- /dev/null +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -0,0 +1,60 @@ +#include "utils/containers/maximum.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("maximum(T)", + T, + std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + + SUBCASE("input is empty") { + T input = {}; + + std::optional result = maximum(input); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input does not have duplicates") { + T input = {1, 3, 2}; + + std::optional result = maximum(input); + std::optional correct = 3; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + T input = {1, 2, 2, 0}; + + std::optional result = maximum(input); + std::optional correct = 2; + + CHECK(result == correct); + } + } + + TEST_CASE("maximum(std::vector)") { + std::vector input = {"hello", "world"}; + + std::optional result = maximum(input); + std::optional correct = "world"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/multiset_union.cc b/lib/utils/test/src/utils/containers/multiset_union.cc new file mode 100644 index 0000000000..8c40bf55ab --- /dev/null +++ b/lib/utils/test/src/utils/containers/multiset_union.cc @@ -0,0 +1,29 @@ +#include "utils/containers/multiset_union.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("multiset_union(std::unordered_multiset, " + "std::unordered_multiset)") { + std::unordered_multiset input_lhs = {1, 2, 2, 3}; + std::unordered_multiset input_rhs = {1, 2, 5}; + + std::unordered_multiset result = multiset_union(input_lhs, input_rhs); + std::unordered_multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } + + TEST_CASE("multiset_union(std::multiset, std::multiset)") { + std::multiset input_lhs = {1, 2, 2, 3}; + std::multiset input_rhs = {1, 2, 5}; + + std::multiset result = multiset_union(input_lhs, input_rhs); + std::multiset correct = {1, 1, 2, 2, 2, 3, 5}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/repeat.cc b/lib/utils/test/src/utils/containers/repeat.cc index 50e4b3e7c5..d8ffe76a64 100644 --- a/lib/utils/test/src/utils/containers/repeat.cc +++ b/lib/utils/test/src/utils/containers/repeat.cc @@ -1,5 +1,5 @@ #include "utils/containers/repeat.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/require_no_duplicates.cc b/lib/utils/test/src/utils/containers/require_no_duplicates.cc new file mode 100644 index 0000000000..67733d791a --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_no_duplicates.cc @@ -0,0 +1,62 @@ +#include "utils/containers/require_no_duplicates.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("require_no_duplicates(std::unordered_multiset)") { + SUBCASE("empty") { + std::unordered_multiset input = {}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::unordered_multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::unordered_multiset input = {1, 2, 4}; + + std::unordered_set result = require_no_duplicates(input); + std::unordered_set correct = {1, 2, 4}; + + CHECK(result == correct); + } + } + + TEST_CASE("require_no_duplicates(std::multiset)") { + SUBCASE("empty") { + std::multiset input = {}; + + std::set result = require_no_duplicates(input); + std::set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input has duplicates") { + std::multiset input = {1, 2, 2}; + + CHECK_THROWS(require_no_duplicates(input)); + } + + SUBCASE("input does not have duplicates") { + std::multiset input = {1, 2, 4}; + + std::set result = require_no_duplicates(input); + std::set correct = {1, 2, 4}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/reversed.cc b/lib/utils/test/src/utils/containers/reversed.cc new file mode 100644 index 0000000000..834a497152 --- /dev/null +++ b/lib/utils/test/src/utils/containers/reversed.cc @@ -0,0 +1,27 @@ +#include "utils/containers/reversed.h" +#include "test/utils/doctest/fmt/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("reversed(std::vector)") { + SUBCASE("non-empty input") { + std::vector input = {1, 2, 3, 2}; + + std::vector result = reversed(input); + std::vector correct = {2, 3, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + + std::vector result = reversed(input); + std::vector correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/to_uppercase.cc b/lib/utils/test/src/utils/containers/to_uppercase.cc new file mode 100644 index 0000000000..9729307304 --- /dev/null +++ b/lib/utils/test/src/utils/containers/to_uppercase.cc @@ -0,0 +1,15 @@ +#include "utils/containers/to_uppercase.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("to_uppercase(std::string)") { + std::string input = "Hello World"; + + std::string result = to_uppercase(input); + std::string correct = "HELLO WORLD"; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/transform.cc b/lib/utils/test/src/utils/containers/transform.cc index 916bc20928..3122c67117 100644 --- a/lib/utils/test/src/utils/containers/transform.cc +++ b/lib/utils/test/src/utils/containers/transform.cc @@ -1,7 +1,7 @@ #include "utils/containers/transform.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_set.h" -#include "utils/fmt/vector.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc index 6aeab4ae6e..b8a7a85f74 100644 --- a/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc +++ b/lib/utils/test/src/utils/containers/try_merge_nondisjoint_unordered_maps.cc @@ -1,7 +1,7 @@ #include "utils/containers/try_merge_nondisjoint_unordered_maps.h" -#include "test/utils/doctest.h" -#include "utils/fmt/optional.h" -#include "utils/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc index 0ab0ef1446..becb7fdce0 100644 --- a/lib/utils/test/src/utils/containers/unordered_multiset_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_multiset_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/containers/unordered_set_of.cc b/lib/utils/test/src/utils/containers/unordered_set_of.cc index d42b41dd50..b8ca1d1797 100644 --- a/lib/utils/test/src/utils/containers/unordered_set_of.cc +++ b/lib/utils/test/src/utils/containers/unordered_set_of.cc @@ -1,5 +1,5 @@ #include "utils/containers/unordered_set_of.h" -#include "utils/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include #include diff --git a/lib/utils/test/src/utils/containers/vector_of.cc b/lib/utils/test/src/utils/containers/vector_of.cc new file mode 100644 index 0000000000..8b9353e1b0 --- /dev/null +++ b/lib/utils/test/src/utils/containers/vector_of.cc @@ -0,0 +1,17 @@ +#include "utils/containers/vector_of.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("vector_of(std::set)") { + std::set input = {2, 3, 1, 4}; + + std::vector result = vector_of(input); + std::vector correct = {1, 2, 3, 4}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/without_order.cc b/lib/utils/test/src/utils/containers/without_order.cc index 939c6ff108..b4c8663b14 100644 --- a/lib/utils/test/src/utils/containers/without_order.cc +++ b/lib/utils/test/src/utils/containers/without_order.cc @@ -1,5 +1,5 @@ #include "utils/containers/without_order.h" -#include "utils/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include #include diff --git a/lib/utils/test/src/utils/expected.cc b/lib/utils/test/src/utils/expected.cc index 14679e0d13..3e5de13d49 100644 --- a/lib/utils/test/src/utils/expected.cc +++ b/lib/utils/test/src/utils/expected.cc @@ -1,6 +1,6 @@ #include "utils/expected.h" -#include "utils/fmt/expected.h" -#include "utils/fmt/optional.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/expected.cc b/lib/utils/test/src/utils/fmt/expected.cc index fb39732761..48df8634db 100644 --- a/lib/utils/test/src/utils/fmt/expected.cc +++ b/lib/utils/test/src/utils/fmt/expected.cc @@ -1,5 +1,6 @@ #include "utils/fmt/expected.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/expected.h" +#include using namespace ::FlexFlow; @@ -19,24 +20,4 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } - - TEST_CASE("doctest::toString(tl::expected)") { - SUBCASE("has expected") { - tl::expected input = 3; - - doctest::String result = doctest::toString(input); - doctest::String correct = "expected(3)"; - - CHECK(result == correct); - } - - SUBCASE("has unexpected") { - tl::expected input = tl::make_unexpected("error"); - - doctest::String result = doctest::toString(input); - doctest::String correct = "unexpected(error)"; - - CHECK(result == correct); - } - } } diff --git a/lib/utils/test/src/utils/fmt/map.cc b/lib/utils/test/src/utils/fmt/map.cc index b65b4791ea..19f3a7d5cf 100644 --- a/lib/utils/test/src/utils/fmt/map.cc +++ b/lib/utils/test/src/utils/fmt/map.cc @@ -1,5 +1,5 @@ #include "utils/fmt/map.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/optional.cc b/lib/utils/test/src/utils/fmt/optional.cc index e7815a26ac..1cd79da747 100644 --- a/lib/utils/test/src/utils/fmt/optional.cc +++ b/lib/utils/test/src/utils/fmt/optional.cc @@ -1,5 +1,5 @@ #include "utils/fmt/optional.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/pair.cc b/lib/utils/test/src/utils/fmt/pair.cc index 3d7cc78756..e848eb08c7 100644 --- a/lib/utils/test/src/utils/fmt/pair.cc +++ b/lib/utils/test/src/utils/fmt/pair.cc @@ -1,5 +1,5 @@ #include "utils/fmt/pair.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/set.cc b/lib/utils/test/src/utils/fmt/set.cc index 66824f2b2a..e317954b02 100644 --- a/lib/utils/test/src/utils/fmt/set.cc +++ b/lib/utils/test/src/utils/fmt/set.cc @@ -1,5 +1,5 @@ #include "utils/fmt/set.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_map.cc b/lib/utils/test/src/utils/fmt/unordered_map.cc index 99752d73f4..c980bc1e52 100644 --- a/lib/utils/test/src/utils/fmt/unordered_map.cc +++ b/lib/utils/test/src/utils/fmt/unordered_map.cc @@ -1,6 +1,7 @@ #include "utils/fmt/unordered_map.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_map.h" #include "utils/containers/get_element_counts.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/unordered_set.cc b/lib/utils/test/src/utils/fmt/unordered_set.cc index 9dc8d236f1..f492ea844d 100644 --- a/lib/utils/test/src/utils/fmt/unordered_set.cc +++ b/lib/utils/test/src/utils/fmt/unordered_set.cc @@ -1,7 +1,7 @@ #include "utils/fmt/unordered_set.h" -#include "test/utils/doctest.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "utils/containers/unordered_multiset_of.h" -#include "utils/fmt/unordered_multiset.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/variant.cc b/lib/utils/test/src/utils/fmt/variant.cc index 3ada166de9..0c8dca35d7 100644 --- a/lib/utils/test/src/utils/fmt/variant.cc +++ b/lib/utils/test/src/utils/fmt/variant.cc @@ -1,5 +1,5 @@ #include "utils/fmt/variant.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/fmt/vector.cc b/lib/utils/test/src/utils/fmt/vector.cc index fee3eb34a5..91ef6c9efc 100644 --- a/lib/utils/test/src/utils/fmt/vector.cc +++ b/lib/utils/test/src/utils/fmt/vector.cc @@ -1,5 +1,5 @@ #include "utils/fmt/vector.h" -#include "test/utils/doctest.h" +#include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/graph/cow_ptr_t.cc b/lib/utils/test/src/utils/graph/cow_ptr_t.cc index 65088c19de..e6a6f9661e 100644 --- a/lib/utils/test/src/utils/graph/cow_ptr_t.cc +++ b/lib/utils/test/src/utils/graph/cow_ptr_t.cc @@ -1,5 +1,5 @@ #include "utils/graph/cow_ptr_t.h" -#include "test/utils/doctest.h" +#include #include #include #include diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc new file mode 100644 index 0000000000..330628adfd --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.cc @@ -0,0 +1,43 @@ +#include "utils/graph/dataflow_graph/algorithms/get_subgraph_incoming_edges.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_subgraph_incoming_edges(DataflowGraphView, " + "std::unordered_set") { + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + std::unordered_set input_node_set = {n2, n3}; + + std::unordered_set result = + get_subgraph_incoming_edges(g, input_node_set); + + std::unordered_set correct = { + DataflowEdge{o1, DataflowInput{n2, 0}}, + DataflowEdge{o1, DataflowInput{n3, 0}}, + DataflowEdge{o1, DataflowInput{n3, 2}}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc index 7e02686dde..779d0a9560 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.cc @@ -7,7 +7,8 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_outgoing_edges(DataflowGraphView, std::unordered_set") { + TEST_CASE("get_subgraph_outgoing_edges(DataflowGraphView, " + "std::unordered_set") { DataflowGraph g = DataflowGraph::create(); NodeAddedResult n1_added = g.add_node({}, 1); diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc index cfc912af6b..7a3237d432 100644 --- a/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc +++ b/lib/utils/test/src/utils/graph/dataflow_graph/unordered_open_dataflow_graph.cc @@ -1,9 +1,11 @@ -#include "test/utils/doctest.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" #include "utils/graph/dataflow_graph/dataflow_graph.h" #include "utils/graph/dataflow_graph/dataflow_output_query.h" #include "utils/graph/instances/unordered_set_dataflow_graph.h" #include "utils/graph/node/node_query.h" +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("UnorderedSetDataflowGraph") { diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 2ebfe232b6..eca7aa6c79 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,5 +1,8 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" +#include "utils/containers/reversed.h" +#include "utils/containers/vector_of.h" #include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include @@ -9,6 +12,25 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_cbc_decomposition") { DiGraph g = DiGraph::create(); + // used to check that the cbc decomposition result is the same regardless + // of the order in which the graph edges are processed, as this is a + // property that should hold, and violations of this property have been a + // source of bugs in the past + auto check_cbc_decomposition_is_edge_order_invariant = + [](DiGraphView const &g) { + std::unordered_set edges = get_edges(g); + + std::vector edge_order1 = vector_of(edges); + std::vector edge_order2 = reversed(edge_order1); + + std::optional result1 = + get_cbc_decomposition_with_edge_order_internal(g, edge_order1); + std::optional result2 = + get_cbc_decomposition_with_edge_order_internal(g, edge_order2); + + CHECK(result1 == result2); + }; + SUBCASE("six-node diamond graph") { std::vector n = add_nodes(g, 6); add_edges(g, @@ -32,6 +54,8 @@ TEST_SUITE(FF_TEST_SUITE) { }}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } SUBCASE("graph without any edges") { @@ -43,6 +67,27 @@ TEST_SUITE(FF_TEST_SUITE) { CompleteBipartiteCompositeDecomposition{{}}; CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); + } + + SUBCASE("irreducible n-graph (non-cbc graph)") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_cbc_decomposition(g); + std::optional correct = + std::nullopt; + + CHECK(result == correct); + + check_cbc_decomposition_is_edge_order_invariant(g); } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc new file mode 100644 index 0000000000..17c8b8da27 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_graph.cc @@ -0,0 +1,175 @@ +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView, " + "std::unordered_set)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + SUBCASE("source group") { + std::unordered_set group1 = {n.at(0), n.at(1), n.at(2)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("sink group") { + std::unordered_set group1 = {n.at(3), n.at(4)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + SUBCASE("missing an edge (i.e., not complete)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge (i.e., not bipartite)") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("flipped edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(2), n.at(1)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0), n.at(1)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("group too small") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + std::unordered_set group1 = {n.at(0)}; + + bool result = is_complete_bipartite_digraph(g, group1); + bool correct = false; + + CHECK(result == correct); + } + } + + TEST_CASE("is_complete_bipartite_digraph(UndirectedGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("simple bipartite graph") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("missing an edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("extra edge") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + bool result = is_complete_bipartite_digraph(g); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc index fd2f469f93..a635658755 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/inverse_line_graph/get_inverse_line_graph.cc @@ -4,6 +4,7 @@ #include "utils/containers/transform.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edge_counts.h" @@ -139,5 +140,27 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_bidict == correct_bidict); } } + + SUBCASE("sp n-graph (inverse line graph does not exist)") { + // Tests that the inverse line graph of the sp n-graph + // + // a-b + // \ + // c-d + // + // does not exist + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + std::optional result = + get_inverse_line_graph(transitive_reduction(g)); + + CHECK_FALSE(result.has_value()); + } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc index 3ad506f40a..e675e6903f 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc new file mode 100644 index 0000000000..5f72355ed0 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_closure.cc @@ -0,0 +1,50 @@ +#include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transitive_closure(DiGraphView)") { + DiGraph g = DiGraph::create(); + + SUBCASE("maximum number of new edges") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_closure(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc index b8a35346f4..1f9062a8ed 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -1,4 +1,5 @@ #include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "test/utils/doctest/fmt/optional.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" @@ -76,5 +77,66 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result_edges == correct_edges); } } + + SUBCASE("longer paths") { + std::vector n = add_nodes(g, 5); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + CHECK(result_edges == correct_edges); + } + } + + SUBCASE("irreducible sp n-graph") { + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + + DiGraphView result = transitive_reduction(g); + + SUBCASE("nodes") { + std::unordered_set result_nodes = get_nodes(result); + std::unordered_set correct_nodes = unordered_set_of(n); + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("edges") { + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + CHECK(result_edges == correct_edges); + } + } } } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc new file mode 100644 index 0000000000..66b657eaaa --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc @@ -0,0 +1,51 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("fmt GenericBinarySPDecompositionTree") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + std::string result = fmt::to_string(input); + std::string correct = ""; + + CHECK(result == correct); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + + std::string result = fmt::to_string(input); + std::string correct = (" " + "" + ">" + ">"); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + + std::string result = fmt::to_string(input); + std::string correct = (" " + "" + ">" + ">"); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc new file mode 100644 index 0000000000..abae9286b6 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -0,0 +1,86 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_leaves(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5}; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 6}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 5}; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 6}; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {5, 5}; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5))), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(2))); + + std::unordered_multiset result = get_leaves(input); + std::unordered_multiset correct = {2, 2, 4, 4, 5}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc new file mode 100644 index 0000000000..92c556ad28 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc @@ -0,0 +1,41 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_left_child(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + CHECK_THROWS(get_left_child(input)); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(3)); + + GenericBinarySPDecompositionTree result = get_left_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(5); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(7)); + + GenericBinarySPDecompositionTree result = get_left_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(4); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..3de61d3313 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -0,0 +1,85 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + int result = get_num_tree_nodes(input); + int correct = 1; + + CHECK(result == correct); + } + + SUBCASE("series split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + SUBCASE("children are not the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + + SUBCASE("children are the same") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(5)); + + int result = get_num_tree_nodes(input); + int correct = 3; + + CHECK(result == correct); + } + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5))), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(2))); + + int result = get_num_tree_nodes(input); + int correct = 9; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc new file mode 100644 index 0000000000..33b5d37955 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc @@ -0,0 +1,41 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_right_child(GenericBinarySPDecompositionTree)") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(5); + + CHECK_THROWS(get_right_child(input)); + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(3)); + + GenericBinarySPDecompositionTree result = get_right_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(3); + + CHECK(result == correct); + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(7)); + + GenericBinarySPDecompositionTree result = get_right_child(input); + GenericBinarySPDecompositionTree correct = + make_generic_binary_sp_leaf(7); + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc new file mode 100644 index 0000000000..e7025dbfad --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc @@ -0,0 +1,117 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree leaf_5 = + make_generic_binary_sp_leaf(5); + size_t leaf_5_hash = get_std_hash(leaf_5); + + SUBCASE("leaves with same labels hash to the same value") { + GenericBinarySPDecompositionTree also_leaf_5 = + make_generic_binary_sp_leaf(5); + size_t also_leaf_5_hash = get_std_hash(also_leaf_5); + + CHECK(leaf_5_hash == also_leaf_5_hash); + } + + SUBCASE("leaves with different labels hash to different values") { + GenericBinarySPDecompositionTree leaf_6 = + make_generic_binary_sp_leaf(6); + size_t leaf_6_hash = get_std_hash(leaf_6); + + CHECK(leaf_5_hash != leaf_6_hash); + } + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree series_5_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t series_5_6_hash = get_std_hash(series_5_6); + + SUBCASE("same children lead to the same hash") { + GenericBinarySPDecompositionTree also_series_5_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t also_series_5_6_hash = get_std_hash(also_series_5_6); + + CHECK(series_5_6_hash == also_series_5_6_hash); + } + + SUBCASE("hash is order dependent") { + GenericBinarySPDecompositionTree series_6_5 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(6), + make_generic_binary_sp_leaf(5)); + size_t series_6_5_hash = get_std_hash(series_6_5); + + CHECK(series_5_6_hash != series_6_5_hash); + } + + SUBCASE("different left child leads to different hash") { + GenericBinarySPDecompositionTree series_4_6 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(6)); + size_t series_4_6_hash = get_std_hash(series_4_6); + + CHECK(series_5_6_hash != series_4_6_hash); + } + + SUBCASE("different right child leads to different hash") { + GenericBinarySPDecompositionTree series_5_7 = + make_generic_binary_series_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + size_t series_5_7_hash = get_std_hash(series_5_7); + + CHECK(series_5_6_hash != series_5_7_hash); + } + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree parallel_5_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t parallel_5_6_hash = get_std_hash(parallel_5_6); + + SUBCASE("same children lead to the same hash") { + GenericBinarySPDecompositionTree also_parallel_5_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(6)); + size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); + + CHECK(parallel_5_6_hash == also_parallel_5_6_hash); + } + + SUBCASE("hash is order dependent") { + GenericBinarySPDecompositionTree parallel_6_5 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), + make_generic_binary_sp_leaf(5)); + size_t parallel_6_5_hash = get_std_hash(parallel_6_5); + + CHECK(parallel_5_6_hash != parallel_6_5_hash); + } + + SUBCASE("different left child leads to different hash") { + GenericBinarySPDecompositionTree parallel_4_6 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), + make_generic_binary_sp_leaf(6)); + size_t parallel_4_6_hash = get_std_hash(parallel_4_6); + + CHECK(parallel_5_6_hash != parallel_4_6_hash); + } + + SUBCASE("different right child leads to different hash") { + GenericBinarySPDecompositionTree parallel_5_7 = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), + make_generic_binary_sp_leaf(7)); + size_t parallel_5_7_hash = get_std_hash(parallel_5_7); + + CHECK(parallel_5_6_hash != parallel_5_7_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc new file mode 100644 index 0000000000..7a8756c6cc --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -0,0 +1,102 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_left_associative(" + "GenericBinarySPDecompositionTree)") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + SUBCASE("input is actually left associative") { + SUBCASE("just node") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(n1); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n3), + make_generic_binary_sp_leaf(n4))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not left associative") { + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_left_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc new file mode 100644 index 0000000000..3cf87368ab --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -0,0 +1,102 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_binary_sp_tree_right_associative(" + "GenericBinarySPDecompositionTree)") { + int n1 = 1; + int n2 = 2; + int n3 = 3; + int n4 = 4; + + SUBCASE("input is actually right associative") { + SUBCASE("just node") { + GenericBinarySPDecompositionTree input = + make_generic_binary_sp_leaf(n1); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n2), + make_generic_binary_sp_leaf(n3))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("nested") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n3), + make_generic_binary_sp_leaf(n4))); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = true; + + CHECK(result == correct); + } + } + + SUBCASE("input is not right associative") { + SUBCASE("just series") { + GenericBinarySPDecompositionTree input = + make_generic_binary_series_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + + SUBCASE("just parallel") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_parallel_split( + make_generic_binary_sp_leaf(n1), + make_generic_binary_sp_leaf(n2)), + make_generic_binary_sp_leaf(n3)); + + bool result = is_binary_sp_tree_right_associative(input); + bool correct = false; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc new file mode 100644 index 0000000000..cc234bacf8 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc @@ -0,0 +1,131 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer>") { + SUBCASE("leaf") { + GenericBinarySPDecompositionTree tt = make_generic_binary_sp_leaf(5); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + + SUBCASE("series split") { + GenericBinarySPDecompositionTree tt = + make_generic_binary_series_split(make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5)); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "series"}, + { + "value", + { + {"__type", "GenericBinarySeriesSplit"}, + { + "left_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }, + }, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + + SUBCASE("parallel split") { + GenericBinarySPDecompositionTree tt = + make_generic_binary_parallel_split(make_generic_binary_sp_leaf(2), + make_generic_binary_sp_leaf(5)); + + nlohmann::json tt_json = { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "parallel"}, + { + "value", + { + {"__type", "GenericBinaryParallelSplit"}, + { + "left_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"__type", "GenericBinarySPDecompositionTree"}, + {"type", "leaf"}, + {"value", 5}, + }, + }, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = tt; + nlohmann::json correct = tt_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + GenericBinarySPDecompositionTree result = + tt_json.get>(); + GenericBinarySPDecompositionTree correct = tt; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc new file mode 100644 index 0000000000..4ede4e84b5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc @@ -0,0 +1,28 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("transform(GenericBinarySPDecompositionTree, F)") { + GenericBinarySPDecompositionTree input = + make_generic_binary_parallel_split( + make_generic_binary_series_split(make_generic_binary_sp_leaf(1), + make_generic_binary_sp_leaf(4)), + make_generic_binary_sp_leaf(2)); + + GenericBinarySPDecompositionTree result = + transform(input, [](int x) { return std::to_string(x); }); + + GenericBinarySPDecompositionTree correct = + make_generic_binary_parallel_split( + make_generic_binary_series_split( + make_generic_binary_sp_leaf(std::string{"1"}), + make_generic_binary_sp_leaf(std::string{"4"})), + make_generic_binary_sp_leaf(std::string{"2"})); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..1e3217a2de --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,95 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/rapidcheck.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("left_associative_binary_sp_tree_from_nary(" + "SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf_node(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_series_split(make_leaf_node(n1), make_leaf_node(n2)), + make_leaf_node(n3)); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // left-associative binary SP trees + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + BinarySPDecompositionTree result = + left_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_left_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc new file mode 100644 index 0000000000..0befbde5cc --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -0,0 +1,132 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("nary_sp_tree_from_binary(BinarySPDecompositionTree)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("leaf") { + BinarySPDecompositionTree input = make_leaf_node(n1); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + + CHECK(result == correct); + } + + SUBCASE("left associative series") { + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf_node(n2), make_leaf_node(n1)), + make_leaf_node(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("right associative series") { + BinarySPDecompositionTree input = make_series_split( + make_leaf_node(n2), + make_series_split(make_leaf_node(n1), make_leaf_node(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("series with duplicate children") { + BinarySPDecompositionTree input = + make_series_split(make_leaf_node(n1), make_leaf_node(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{n1, n1}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("left associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf_node(n2), make_leaf_node(n1)), + make_leaf_node(n3)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("right associative parallel") { + BinarySPDecompositionTree input = make_parallel_split( + make_leaf_node(n2), + make_parallel_split(make_leaf_node(n1), make_leaf_node(n3))); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + + CHECK(result == correct); + } + + SUBCASE("parallel with duplicate children") { + BinarySPDecompositionTree input = + make_parallel_split(make_leaf_node(n1), make_leaf_node(n1)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{n1, n1}}; + + CHECK(get_nodes(result).size() == 2); + CHECK(result == correct); + } + + SUBCASE("nested") { + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split( + make_parallel_split( + make_leaf_node(n1), + make_series_split( + make_series_split(make_series_split(make_leaf_node(n2), + make_leaf_node(n3)), + make_leaf_node(n3)), + make_leaf_node(n5))), + make_series_split(make_leaf_node(n6), make_leaf_node(n4))), + make_leaf_node(n5)); + + SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc new file mode 100644 index 0000000000..db1b440481 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -0,0 +1,93 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("right_associative_binary_sp_tree_from_nary(" + "SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + Node n5 = Node{5}; + Node n6 = Node{6}; + + SUBCASE("only node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + BinarySPDecompositionTree correct = make_leaf_node(n1); + + CHECK(result == correct); + } + + SUBCASE("only serial") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + BinarySPDecompositionTree correct = make_series_split( + make_leaf_node(n1), + make_series_split(make_leaf_node(n2), make_leaf_node(n3))); + + CHECK(result == correct); + } + + SUBCASE("only parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{n1, n2, n3}, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + // we use multiple checks here because SerialParallelDecomposition's + // ParallelSplit is unordered, so there are multiple possible + // right-associative binary SP trees + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = {n1, n2, n3}; + + CHECK(result_nodes == correct_nodes); + } + + SUBCASE("nested") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{ + n1, + SeriesSplit{ + n2, + n3, + n3, + n5, + }, + SeriesSplit{ + n6, + n4, + }, + n5, + }, + }; + + BinarySPDecompositionTree result = + right_associative_binary_sp_tree_from_nary(input); + + CHECK(is_binary_sp_tree_right_associative(result)); + + std::unordered_multiset result_nodes = get_nodes(input); + std::unordered_multiset correct_nodes = { + n1, n2, n3, n3, n5, n6, n4, n5}; + + CHECK(result_nodes == correct_nodes); + } + } +} diff --git a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc similarity index 50% rename from lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc rename to lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 04d82bf1d8..45f796c824 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/get_serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/get_serial_parallel_decomposition.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include @@ -6,47 +6,47 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_serial_parallel_decomposition (base case)") { + TEST_CASE("get_series_parallel_decomposition (base case)") { DiGraph g = DiGraph::create(); Node n = g.add_node(); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{n}; + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{n}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (parallel)") { + TEST_CASE("get_series_parallel_decomposition (parallel)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 2); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ParallelSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ParallelSplit{ n.at(0), n.at(1), }}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (serial)") { + TEST_CASE("get_series_parallel_decomposition (serial)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 2); g.add_edge(DirectedEdge{n.at(0), n.at(1)}); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ n.at(0), n.at(1), }}; CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (composite)") { + TEST_CASE("get_series_parallel_decomposition (composite)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 3); add_edges(g, @@ -55,11 +55,11 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(0), n.at(2)}, }); - std::optional result = - get_serial_parallel_decomposition(g); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ + std::optional result = + get_series_parallel_decomposition(g); + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ n.at(0), ParallelSplit{ n.at(1), @@ -70,7 +70,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (diamond graph)") { + TEST_CASE("get_series_parallel_decomposition (diamond graph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 6); @@ -85,15 +85,15 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(4), n.at(5)}, }); - std::optional correct = - SerialParallelDecomposition{SerialSplit{ + std::optional correct = + SeriesParallelDecomposition{SeriesSplit{ n.at(0), ParallelSplit{ - SerialSplit{ + SeriesSplit{ n.at(1), n.at(3), }, - SerialSplit{ + SeriesSplit{ n.at(2), n.at(4), }, @@ -101,13 +101,13 @@ TEST_SUITE(FF_TEST_SUITE) { n.at(5), }}; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (all-to-all connection)") { + TEST_CASE("get_series_parallel_decomposition (all-to-all connection)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -120,9 +120,9 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(1), n.at(3)}, }); - std::optional correct = - SerialParallelDecomposition{ - SerialSplit{ + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ ParallelSplit{ n.at(0), n.at(1), @@ -134,13 +134,13 @@ TEST_SUITE(FF_TEST_SUITE) { }, }; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } - TEST_CASE("get_serial_parallel_decomposition (non-sp graph)") { + TEST_CASE("get_series_parallel_decomposition (non-sp graph)") { DiGraph g = DiGraph::create(); std::vector n = add_nodes(g, 4); @@ -153,9 +153,39 @@ TEST_SUITE(FF_TEST_SUITE) { DirectedEdge{n.at(1), n.at(3)}, }); - std::optional correct = std::nullopt; - std::optional result = - get_serial_parallel_decomposition(g); + std::optional correct = std::nullopt; + std::optional result = + get_series_parallel_decomposition(g); + + CHECK(result == correct); + } + + TEST_CASE( + "get_series_parallel_decomposition (requires transitive reduction)") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 4); + + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + std::optional correct = + SeriesParallelDecomposition{ + SeriesSplit{ + n.at(0), + n.at(1), + n.at(2), + n.at(3), + }, + }; + std::optional result = + get_series_parallel_decomposition(g); CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc similarity index 83% rename from lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc rename to lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 4560f95ff7..3a486c7094 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/fmt/variant.h" #include @@ -8,11 +8,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("flatten_ast") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{2}, Node{3}, @@ -25,7 +25,7 @@ TEST_SUITE(FF_TEST_SUITE) { flatten_ast(input); std::variant correct = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, Node{2}, diff --git a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc similarity index 99% rename from lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc rename to lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc index 8259d256d3..a62f528bcf 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/parallel_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/parallel_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/parallel_reduction.h" +#include "utils/graph/series_parallel/parallel_reduction.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" #include "utils/graph/multidigraph/algorithms/add_nodes.h" diff --git a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc similarity index 66% rename from lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc rename to lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 7cf17c3fee..f5766c9fdd 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/serial_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,5 +1,5 @@ -#include "utils/graph/serial_parallel/serial_parallel_decomposition.h" -#include "utils/fmt/unordered_set.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include using namespace ::FlexFlow; @@ -7,20 +7,20 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (base case)") { std::variant input = Node{1}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{Node{1}}; + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{Node{1}}; CHECK(result == correct); } TEST_CASE("to_final_ast (serial)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, {Node{1}, Node{2}}, }; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = SerialParallelDecomposition{ - SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{ Node{1}, Node{2}, }}, @@ -30,11 +30,11 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("to_final_ast (composite)") { std::variant input = IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{0}, IntermediateSpDecompositionTree{ - SplitType::SERIAL, + SplitType::SERIES, { Node{1}, IntermediateSpDecompositionTree{ @@ -55,9 +55,9 @@ TEST_SUITE(FF_TEST_SUITE) { Node{5}, }}; - SerialParallelDecomposition result = to_final_ast(input); - SerialParallelDecomposition correct = - SerialParallelDecomposition{SerialSplit{{ + SeriesParallelDecomposition result = to_final_ast(input); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{ Node{0}, Node{1}, ParallelSplit{{ @@ -70,55 +70,55 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - TEST_CASE("get_nodes(SerialParallelDecomposition)") { - SerialParallelDecomposition input = - SerialParallelDecomposition{SerialSplit{{ + TEST_CASE("get_nodes(SeriesParallelDecomposition)") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{ ParallelSplit{{ Node{1}, Node{2}, }}, - Node{3}, + Node{2}, ParallelSplit{{ Node{4}, Node{5}, }}, }}}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, - Node{3}, + Node{2}, Node{4}, Node{5}, }; CHECK(result == correct); } - TEST_CASE("get_nodes(SerialSplit)") { + TEST_CASE("get_nodes(SeriesSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, ParallelSplit{{ Node{3}, Node{4}, }}, }}, - SerialSplit{{ - Node{5}, + SeriesSplit{{ + Node{1}, Node{6}, }}, Node{7}, }}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, Node{3}, Node{4}, - Node{5}, + Node{1}, Node{6}, Node{7}, }; @@ -129,9 +129,9 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(ParallelSplit)") { ParallelSplit input = ParallelSplit{{ Node{1}, - SerialSplit{{ + SeriesSplit{{ Node{2}, - Node{3}, + Node{4}, ParallelSplit{{ Node{4}, Node{5}, @@ -139,11 +139,11 @@ TEST_SUITE(FF_TEST_SUITE) { }}, }}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = { + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = { Node{1}, Node{2}, - Node{3}, + Node{4}, Node{4}, Node{5}, }; @@ -153,8 +153,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("get_nodes(Node)") { Node input = Node{5}; - std::unordered_set result = get_nodes(input); - std::unordered_set correct = {input}; + std::unordered_multiset result = get_nodes(input); + std::unordered_multiset correct = {input}; CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc similarity index 99% rename from lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc rename to lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc index e4d53b4136..c6b45ec6ce 100644 --- a/lib/utils/test/src/utils/graph/serial_parallel/series_reduction.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_reduction.cc @@ -1,4 +1,4 @@ -#include "utils/graph/serial_parallel/series_reduction.h" +#include "utils/graph/series_parallel/series_reduction.h" #include "utils/containers/set_minus.h" #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/graph/multidigraph/algorithms/add_edges.h" diff --git a/lib/utils/test/src/utils/hash/multiset.cc b/lib/utils/test/src/utils/hash/multiset.cc new file mode 100644 index 0000000000..5c2e01fda8 --- /dev/null +++ b/lib/utils/test/src/utils/hash/multiset.cc @@ -0,0 +1,34 @@ +#include "utils/hash/multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::multiset input = {1, 2, 2, 1, 5}; + size_t input_hash = get_std_hash(input); + + SUBCASE("same values have the same hash") { + std::multiset also_input = {2, 1, 2, 5, 1}; + size_t also_input_hash = get_std_hash(input); + + CHECK(input_hash == also_input_hash); + } + + SUBCASE("different values have different hashes") { + SUBCASE("different number of duplicates") { + std::multiset other = {1, 2, 2, 1, 5, 5}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + + SUBCASE("different elements") { + std::multiset other = {1, 2, 2, 1, 6}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/hash/unordered_multiset.cc b/lib/utils/test/src/utils/hash/unordered_multiset.cc new file mode 100644 index 0000000000..6c730fad3c --- /dev/null +++ b/lib/utils/test/src/utils/hash/unordered_multiset.cc @@ -0,0 +1,34 @@ +#include "utils/hash/unordered_multiset.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::unordered_multiset input = {1, 2, 2, 1, 5}; + size_t input_hash = get_std_hash(input); + + SUBCASE("same values have the same hash") { + std::unordered_multiset also_input = {2, 1, 2, 5, 1}; + size_t also_input_hash = get_std_hash(input); + + CHECK(input_hash == also_input_hash); + } + + SUBCASE("different values have different hashes") { + SUBCASE("different number of duplicates") { + std::unordered_multiset other = {1, 2, 2, 1, 5, 5}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + + SUBCASE("different elements") { + std::unordered_multiset other = {1, 2, 2, 1, 6}; + size_t other_hash = get_std_hash(other); + + CHECK(input_hash != other_hash); + } + } + } +} diff --git a/lib/utils/test/src/utils/json/optional.cc b/lib/utils/test/src/utils/json/optional.cc new file mode 100644 index 0000000000..61f5868c53 --- /dev/null +++ b/lib/utils/test/src/utils/json/optional.cc @@ -0,0 +1,49 @@ +#include "utils/json/optional.h" +#include "test/utils/doctest/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer>") { + SUBCASE("to_json") { + SUBCASE("has value") { + std::optional input = 5; + + nlohmann::json result = input; + nlohmann::json correct = 5; + + CHECK(result == correct); + } + + SUBCASE("has nullopt") { + std::optional input = std::nullopt; + + nlohmann::json result = input; + nlohmann::json correct = nullptr; + + CHECK(result == correct); + } + } + + SUBCASE("from_json") { + SUBCASE("has value") { + nlohmann::json input = 5; + + std::optional result = input; + std::optional correct = 5; + + CHECK(result == correct); + } + + SUBCASE("has nullopt") { + nlohmann::json input = nullptr; + + std::optional result = input.get>(); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/utils/test/src/utils/optional.cc b/lib/utils/test/src/utils/rapidcheck/optional.cc similarity index 67% rename from lib/utils/test/src/utils/optional.cc rename to lib/utils/test/src/utils/rapidcheck/optional.cc index 16c9e964cb..96b17a5400 100644 --- a/lib/utils/test/src/utils/optional.cc +++ b/lib/utils/test/src/utils/rapidcheck/optional.cc @@ -1,7 +1,8 @@ -#include "utils/optional.h" -#include "test/utils/doctest.h" +#include "utils/rapidcheck/optional.h" #include "test/utils/rapidcheck.h" -#include +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE( From 64554159ab66da117ee243fdac0592b7aeabd613 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 16 Sep 2024 11:17:19 -0700 Subject: [PATCH 05/12] Add interface for differentiating inputs and weights in CG & PCG (#1493) * Add interface for differentiating inputs and weights in CG/PCG * Format * Address Reyna PR comments * fix bugs from merge * Format --- lib/kernels/include/kernels/legion_dim.h | 2 +- .../include/local-execution/serialization.h | 2 +- .../src/local_cost_estimator.cc | 5 +- .../src/models/transformer/transformer.cc | 9 +- .../op-attrs/computation_graph_op_attrs.h | 3 + .../op-attrs/{ => dim_ordered}/dim_ordered.h | 0 .../include/op-attrs/dim_ordered/enumerate.h | 2 +- .../op-attrs/dim_ordered/ff_ordered_of.h | 2 +- .../include/op-attrs/dim_ordered/get_idxs.h | 2 +- .../include/op-attrs/dim_ordered/slice.h | 2 +- .../include/op-attrs/dim_ordered/transform.h | 2 +- .../include/op-attrs/dim_ordered/zip.h | 2 +- .../op-attrs/get_incoming_tensor_roles.h | 17 ++ .../op-attrs/incoming_tensor_role.enum.toml | 14 + lib/op-attrs/include/op-attrs/ops/attention.h | 20 ++ .../include/op-attrs/ops/batch_matmul.h | 3 + lib/op-attrs/include/op-attrs/ops/conv_2d.h | 4 + .../include/op-attrs/ops/layer_norm.h | 4 + lib/op-attrs/include/op-attrs/ops/linear.h | 10 +- lib/op-attrs/include/op-attrs/ops/topk.h | 2 + .../op-attrs/ops/transpose_attrs.struct.toml | 2 +- .../op-attrs/parallel_tensor_dims.struct.toml | 2 +- .../include/op-attrs/pcg_operator_attrs.h | 4 +- .../op-attrs/pcg_operator_attrs.variant.toml | 5 + .../include/op-attrs/tensor_dims.struct.toml | 2 +- .../op-attrs/computation_graph_op_attrs.cc | 18 ++ .../src/op-attrs/get_incoming_tensor_roles.cc | 103 +++++++ lib/op-attrs/src/op-attrs/ops/attention.cc | 19 ++ lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 14 + lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 13 + lib/op-attrs/src/op-attrs/ops/linear.cc | 23 +- lib/op-attrs/src/op-attrs/ops/topk.cc | 4 + .../src/op-attrs/pcg_operator_attrs.cc | 57 +--- .../op-attrs/computation_graph_op_attrs.cc | 15 + .../test/src/{ => op-attrs}/datatype.cc | 0 .../dim_ordered/dim_ordered.cc} | 2 +- .../src/{ => op-attrs}/dim_ordered/slice.cc | 0 .../src/op-attrs/get_incoming_tensor_roles.cc | 26 ++ .../test/src/{ => op-attrs}/ops/attention.cc | 49 ++++ .../src/{ => op-attrs}/ops/batch_matmul.cc | 0 .../test/src/op-attrs/ops/batch_norm_attrs.cc | 15 + .../test/src/{ => op-attrs}/ops/cast.cc | 0 .../test/src/{ => op-attrs}/ops/combine.cc | 0 .../test/src/{ => op-attrs}/ops/conv_2d.cc | 42 +++ .../src/{ => op-attrs}/ops/element_binary.cc | 0 .../src/{ => op-attrs}/ops/element_unary.cc | 0 .../test/src/{ => op-attrs}/ops/embedding.cc | 0 .../test/src/op-attrs/ops/layer_norm.cc | 36 +++ .../test/src/{ => op-attrs}/ops/linear.cc | 86 ++++-- .../test/src/{ => op-attrs}/ops/reduction.cc | 0 .../src/{ => op-attrs}/ops/repartition.cc | 0 .../test/src/{ => op-attrs}/ops/replicate.cc | 0 .../test/src/op-attrs/pcg_operator_attrs.cc | 17 ++ .../regularizer_attrs.cc} | 0 lib/op-attrs/test/src/test_operator_attrs.cc | 37 --- lib/pcg/include/pcg/computation_graph.h | 6 + .../include/pcg/computation_graph_builder.h | 39 +-- .../parallel_computation_graph.h | 11 +- .../parallel_computation_graph_builder.h | 5 +- .../include/pcg/strided_rectangle.struct.toml | 2 +- lib/pcg/src/pcg/computation_graph.cc | 41 +++ lib/pcg/src/pcg/computation_graph_builder.cc | 270 ++++++++---------- .../parallel_computation_graph.cc | 48 +++- .../parallel_computation_graph_builder.cc | 29 +- lib/pcg/test/src/pcg/computation_graph.cc | 206 +++++++++++++ .../parallel_computation_graph.cc | 227 +++++++++++++++ .../parallel_computation_graph_builder.cc | 60 ++-- .../operator_pattern/get_attribute.h | 60 ++-- .../operator_attribute_value.variant.toml | 3 + .../operator_pattern/get_attribute.cc | 12 + .../test/src/substitutions/pcg_pattern.cc | 18 +- .../perform_shape_inference.cc | 2 +- 72 files changed, 1352 insertions(+), 385 deletions(-) rename lib/op-attrs/include/op-attrs/{ => dim_ordered}/dim_ordered.h (100%) create mode 100644 lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h create mode 100644 lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml create mode 100644 lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc create mode 100644 lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc rename lib/op-attrs/test/src/{ => op-attrs}/datatype.cc (100%) rename lib/op-attrs/test/src/{test_dim_ordered.cc => op-attrs/dim_ordered/dim_ordered.cc} (89%) rename lib/op-attrs/test/src/{ => op-attrs}/dim_ordered/slice.cc (100%) create mode 100644 lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc rename lib/op-attrs/test/src/{ => op-attrs}/ops/attention.cc (87%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/batch_matmul.cc (100%) create mode 100644 lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc rename lib/op-attrs/test/src/{ => op-attrs}/ops/cast.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/combine.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/conv_2d.cc (85%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/element_binary.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/element_unary.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/embedding.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/linear.cc (74%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/reduction.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/repartition.cc (100%) rename lib/op-attrs/test/src/{ => op-attrs}/ops/replicate.cc (100%) create mode 100644 lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc rename lib/op-attrs/test/src/{test_regularizer_attrs.cc => op-attrs/regularizer_attrs.cc} (100%) delete mode 100644 lib/op-attrs/test/src/test_operator_attrs.cc create mode 100644 lib/pcg/test/src/pcg/computation_graph.cc diff --git a/lib/kernels/include/kernels/legion_dim.h b/lib/kernels/include/kernels/legion_dim.h index aafbd2cdcb..e4dd9723b8 100644 --- a/lib/kernels/include/kernels/legion_dim.h +++ b/lib/kernels/include/kernels/legion_dim.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_KERNELS_INCLUDE_KERNELS_LEGION_DIM_H #include "kernels/legion_dim_t.dtg.h" -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/serialization.h b/lib/local-execution/include/local-execution/serialization.h index a260519a55..2fc4b4b706 100644 --- a/lib/local-execution/include/local-execution/serialization.h +++ b/lib/local-execution/include/local-execution/serialization.h @@ -3,7 +3,7 @@ #include "kernels/device.h" #include "kernels/nccl.h" -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/required.h" #include "utils/strong_typedef.h" #include "utils/type_traits.h" diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index 5203991f25..b42aec10bb 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -69,7 +69,10 @@ CostDetails LocalCostEstimator::estimate_cost( std::vector output_tensor_ids = cg_builder.add_layer(layer_attrs, input_tensor_ids, - get_vector_piece_attrs(weights), + transform(get_vector_piece_attrs(weights), + [&](TensorAttrs const &a) { + return cg_builder.create_weight(a); + }), get_vector_piece_attrs(outputs)); LocalTrainingBacking local_backing(allocator, diff --git a/lib/models/src/models/transformer/transformer.cc b/lib/models/src/models/transformer/transformer.cc index e179359940..173a1b291c 100644 --- a/lib/models/src/models/transformer/transformer.cc +++ b/lib/models/src/models/transformer/transformer.cc @@ -42,7 +42,8 @@ tensor_guid_t create_transformer_encoder_layer(ComputationGraphBuilder &cgb, config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention)); @@ -88,7 +89,8 @@ tensor_guid_t config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent( cgb.computation_graph, input, self_attention)); @@ -107,7 +109,8 @@ tensor_guid_t config.num_heads, kdim, vdim, - config.dropout); + config.dropout, + /*bias=*/false); assert(are_tensor_guid_shapes_equivalent(cgb.computation_graph, input, mha)); tensor_guid_t mha_normalized = diff --git a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h index 03f38bb8f9..52e6e12a8c 100644 --- a/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h +++ b/lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H #include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" #include "utils/record_formatter.h" namespace FlexFlow { OperatorType get_op_type(ComputationGraphOpAttrs const &); RecordFormatter as_dot(ComputationGraphOpAttrs const &); +ComputationGraphOpAttrs + compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h similarity index 100% rename from lib/op-attrs/include/op-attrs/dim_ordered.h rename to lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h index f9f6d00532..38e7da4bb2 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/bidict/bidict.h" #include "utils/containers/count.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h index c843ed3842..8cc1bf3a51 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_of.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_OF_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" namespace FlexFlow { diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h index 560862677e..7343dc0e69 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/get_idxs.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_GET_IDXS_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/count.h" #include "utils/containers/transform.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index d39bac1bde..23b971da6b 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_SLICE_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/subvec.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h index ae6e552243..4fd3df0abb 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/transform.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/transform.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_TRANSFORM_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h index 023dcfc586..cc8b050f50 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/zip.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/zip.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ZIP_H -#include "op-attrs/dim_ordered.h" +#include "op-attrs/dim_ordered/dim_ordered.h" #include "utils/containers/vector_of.h" #include "utils/containers/zip.h" diff --git a/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h new file mode 100644 index 0000000000..b395736773 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/get_incoming_tensor_roles.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_INCOMING_TENSOR_ROLES_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_GET_INCOMING_TENSOR_ROLES_H + +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include "op-attrs/incoming_tensor_role.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" + +namespace FlexFlow { + +std::vector + get_incoming_tensor_roles(ComputationGraphOpAttrs const &, int num_inputs); +std::vector + get_incoming_tensor_roles(PCGOperatorAttrs const &, int num_inputs); + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml b/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml new file mode 100644 index 0000000000..427701c801 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/incoming_tensor_role.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "IncomingTensorRole" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "INPUT" + +[[values]] +name = "WEIGHT" diff --git a/lib/op-attrs/include/op-attrs/ops/attention.h b/lib/op-attrs/include/op-attrs/ops/attention.h index 40f57d08af..e06d795c04 100644 --- a/lib/op-attrs/include/op-attrs/ops/attention.h +++ b/lib/op-attrs/include/op-attrs/ops/attention.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_ATTENTION_ATTRS_H #define _FLEXFLOW_ATTENTION_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/attention/multihead_attention_inputs.dtg.h" #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.dtg.h" #include "op-attrs/ops/attention_attrs.dtg.h" @@ -37,6 +38,9 @@ int get_kvSeqLength(MultiHeadAttentionInputs const &); int get_num_samples(MultiHeadAttentionParallelInputs const &); int get_num_samples(MultiHeadAttentionInputs const &); +std::vector + get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &); + tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, TensorShape const &input_q, @@ -58,6 +62,22 @@ tl::expected TensorShape const &input_k, TensorShape const &input_v); +tl::expected + get_weights_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_input_bias_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); +tl::expected + get_output_bias_parallel_dims(MultiHeadAttentionAttrs const &, + ParallelTensorShape const &input_q, + ParallelTensorShape const &input_k, + ParallelTensorShape const &input_v); + tl::expected get_weights_shape(MultiHeadAttentionAttrs const &, ParallelTensorShape const &input_q, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h index 57760d1110..574b4ef579 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_matmul.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_matmul.h @@ -2,12 +2,15 @@ #define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_MATMUL_H #include "op-attrs/ops/batch_matmul.dtg.h" +#include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" #include namespace FlexFlow { +CHECK_VALID_OP_ATTR(BatchMatmulAttrs); + bool is_valid(BatchMatmulAttrs const &, ParallelTensorShape const &, ParallelTensorShape const &); diff --git a/lib/op-attrs/include/op-attrs/ops/conv_2d.h b/lib/op-attrs/include/op-attrs/ops/conv_2d.h index 72d1123c39..ae9f9249c6 100644 --- a/lib/op-attrs/include/op-attrs/ops/conv_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/conv_2d.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_CONV_2D_ATTRS_H #define _FLEXFLOW_CONV_2D_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/conv_2d_attrs.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/parallel_tensor_shape.h" @@ -10,6 +11,9 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Conv2DAttrs); +std::vector + get_conv2d_incoming_tensor_roles(Conv2DAttrs const &); + TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &input); TensorShape get_bias_shape(Conv2DAttrs const &attrs, TensorShape const &input); diff --git a/lib/op-attrs/include/op-attrs/ops/layer_norm.h b/lib/op-attrs/include/op-attrs/ops/layer_norm.h index 29b0b2f514..0fbadae2a1 100644 --- a/lib/op-attrs/include/op-attrs/ops/layer_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/layer_norm.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H #define _FLEXFLOW_OP_META_OPS_LAYER_NORM_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/ops/layer_norm_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -8,6 +9,9 @@ namespace FlexFlow { +std::vector + get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &); + tl::expected get_output_shape(LayerNormAttrs const &, TensorShape const &); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/linear.h b/lib/op-attrs/include/op-attrs/ops/linear.h index 795ba19ae8..065cc7e38e 100644 --- a/lib/op-attrs/include/op-attrs/ops/linear.h +++ b/lib/op-attrs/include/op-attrs/ops/linear.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LINEAR_ATTRS_H #define _FLEXFLOW_LINEAR_ATTRS_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/core.h" #include "op-attrs/ops/linear_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" @@ -10,20 +11,23 @@ namespace FlexFlow { +std::vector + get_linear_incoming_tensor_roles(LinearAttrs const &); + CHECK_VALID_OP_ATTR(LinearAttrs); RecordFormatter as_dot(LinearAttrs const &); tl::expected - get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input); + get_projection_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected get_output_shape(LinearAttrs const &attrs, TensorShape const &input); tl::expected - get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input); + get_projection_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input); tl::expected get_bias_shape(LinearAttrs const &attrs, ParallelTensorShape const &input); tl::expected diff --git a/lib/op-attrs/include/op-attrs/ops/topk.h b/lib/op-attrs/include/op-attrs/ops/topk.h index bd11f0ae91..d6de90903a 100644 --- a/lib/op-attrs/include/op-attrs/ops/topk.h +++ b/lib/op-attrs/include/op-attrs/ops/topk.h @@ -4,11 +4,13 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/topk_attrs.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { CHECK_VALID_OP_ATTR(TopKAttrs); +TensorShape get_output_shape(TopKAttrs const &, TensorShape const &); ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape); diff --git a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml index 756091f653..0dc30d9a79 100644 --- a/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/transpose_attrs.struct.toml @@ -12,7 +12,7 @@ features = [ includes = [ "op-attrs/ff_dim.h", "op-attrs/ff_dim.dtg.h", - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml index ae6eab1e58..f24fa12309 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", "op-attrs/shard_parallel_dim.dtg.h", "op-attrs/replica_parallel_dim_set.dtg.h", "", diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h index 08167fe3d9..723c05298d 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.h @@ -8,8 +8,8 @@ namespace FlexFlow { bool is_parallel_op(PCGOperatorAttrs const &); OperatorType get_op_type(PCGOperatorAttrs const &); -ComputationGraphOpAttrs - compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &); +PCGOperatorAttrs + pcg_op_attrs_from_compgraph_op_attrs(ComputationGraphOpAttrs const &); RecordFormatter as_dot(PCGOperatorAttrs const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml index 8617c5fd64..a44d712dbf 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.variant.toml @@ -13,6 +13,7 @@ includes = [ "op-attrs/ops/attention_attrs.dtg.h", "op-attrs/ops/batch_matmul.dtg.h", "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", "op-attrs/ops/cast_attrs.dtg.h", "op-attrs/ops/combine_attrs.dtg.h", "op-attrs/ops/concat_attrs.dtg.h", @@ -49,6 +50,10 @@ key = "batch_matmul" type = "::FlexFlow::BatchNormAttrs" key = "batch_norm" +[[values]] +type = "::FlexFlow::BroadcastAttrs" +key = "broadcast" + [[values]] type = "::FlexFlow::CastAttrs" key = "cast" diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml index cff8e08b0f..b262dd32b6 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml +++ b/lib/op-attrs/include/op-attrs/tensor_dims.struct.toml @@ -9,7 +9,7 @@ features = [ "fmt", ] includes = [ - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc index 054930cebd..c4ae7b31e5 100644 --- a/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/src/op-attrs/computation_graph_op_attrs.cc @@ -23,4 +23,22 @@ RecordFormatter as_dot(ComputationGraphOpAttrs const &attrs) { }); } +ComputationGraphOpAttrs + compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &op) { + auto fail_on_parallel_op = [](auto const &attrs) -> ComputationGraphOpAttrs { + throw mk_runtime_error( + fmt::format("Encountered parallel operator in " + "compgraph_op_attrs_from_pcg_op_attrs: {}", + attrs)); + }; + + return op.visit(overload{ + [&](CombineAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](ReductionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](RepartitionAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [&](ReplicateAttrs const &attrs) { return fail_on_parallel_op(attrs); }, + [](auto const &attrs) { return ComputationGraphOpAttrs{attrs}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc new file mode 100644 index 0000000000..c7febde1d6 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -0,0 +1,103 @@ +#include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::vector get_incoming_tensor_roles( + ComputationGraphOpAttrs const &comp_graph_op_attrs, int num_incoming) { + return get_incoming_tensor_roles( + pcg_op_attrs_from_compgraph_op_attrs(comp_graph_op_attrs), num_incoming); +} + +std::vector + get_incoming_tensor_roles(PCGOperatorAttrs const &pcg_op_attrs, + int num_incoming) { + return pcg_op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT}; + }, + [](BatchNormAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](BroadcastAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](CastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](CombineAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [&](ConcatAttrs const &) { + return std::vector(num_incoming, IncomingTensorRole::INPUT); + }, + [](Conv2DAttrs const &attrs) { + return get_conv2d_incoming_tensor_roles(attrs); + }, + [](DropoutAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ElementBinaryAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT}; + }, + [](ElementUnaryAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](EmbeddingAttrs const &) { + return std::vector{IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT}; + }, + [](FlatAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](GatherAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](InputAttrs const &) { return std::vector{}; }, + [](LayerNormAttrs const &attrs) { + return get_layer_norm_incoming_tensor_roles(attrs); + }, + [](LinearAttrs const &attrs) { + return get_linear_incoming_tensor_roles(attrs); + }, + [](MultiHeadAttentionAttrs const &attrs) { + return get_attention_incoming_tensor_roles(attrs); + }, + [](NoopAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](Pool2DAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReduceAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReductionAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](RepartitionAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReplicateAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReverseAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](ReshapeAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](SplitAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](SoftmaxAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](TopKAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; }, + [](TransposeAttrs const &) { + return std::vector{IncomingTensorRole::INPUT}; + }, + [](WeightAttrs const &) { return std::vector{}; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/attention.cc b/lib/op-attrs/src/op-attrs/ops/attention.cc index 036daa6e67..483d832fee 100644 --- a/lib/op-attrs/src/op-attrs/ops/attention.cc +++ b/lib/op-attrs/src/op-attrs/ops/attention.cc @@ -3,6 +3,7 @@ #include "op-attrs/ops/attention/multihead_attention_parallel_inputs.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.h" +#include "utils/containers/extend.h" #include "utils/integer_conversions.h" namespace FlexFlow { @@ -91,6 +92,24 @@ int get_num_samples(MultiHeadAttentionInputs const &inputs) { return inputs.batch_size; } +std::vector + get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs const &attrs) { + + std::vector roles = std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.bias) { + extend(roles, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return roles; +} + tl::expected get_output_shape(MultiHeadAttentionAttrs const &attrs, TensorShape const &input_q, diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index 03ae18a1d9..f77daf451f 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -5,6 +5,20 @@ namespace FlexFlow { +std::vector + get_conv2d_incoming_tensor_roles(Conv2DAttrs const &attrs) { + std::vector result = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.use_bias) { + result.push_back(IncomingTensorRole::WEIGHT); + } + + return result; +} + TensorShape get_kernel_shape(Conv2DAttrs const &attrs, TensorShape const &raw_input_shape) { assert(attrs.groups == 1); // TODO(@lockshaw): currently not supported diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index b9603d7850..d3c00efbb9 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -6,10 +6,23 @@ #include "utils/containers/all_of.h" #include "utils/containers/any_of.h" #include "utils/containers/contains.h" +#include "utils/containers/extend.h" #include "utils/containers/filter.h" namespace FlexFlow { +std::vector + get_layer_norm_incoming_tensor_roles(LayerNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if (attrs.elementwise_affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + static std::optional check_input_shape(LayerNormAttrs const &attrs, TensorShape const &input_shape) { diff --git a/lib/op-attrs/src/op-attrs/ops/linear.cc b/lib/op-attrs/src/op-attrs/ops/linear.cc index 24a8250690..feac647216 100644 --- a/lib/op-attrs/src/op-attrs/ops/linear.cc +++ b/lib/op-attrs/src/op-attrs/ops/linear.cc @@ -8,6 +8,20 @@ namespace FlexFlow { +std::vector + get_linear_incoming_tensor_roles(LinearAttrs const &attrs) { + std::vector result = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + if (attrs.use_bias) { + result.push_back(IncomingTensorRole::WEIGHT); + } + + return result; +} + RecordFormatter as_dot(LinearAttrs const &attrs) { RecordFormatter r; @@ -25,7 +39,8 @@ RecordFormatter as_dot(LinearAttrs const &attrs) { } tl::expected - get_kernel_shape(LinearAttrs const &attrs, TensorShape const &input_shape) { + get_projection_shape(LinearAttrs const &attrs, + TensorShape const &input_shape) { size_t in_channels = dim_at_idx(input_shape, ff_dim_t{-1}); return TensorShape{ @@ -56,11 +71,11 @@ tl::expected } tl::expected - get_kernel_shape(LinearAttrs const &attrs, - ParallelTensorShape const &input) { + get_projection_shape(LinearAttrs const &attrs, + ParallelTensorShape const &input) { TensorShape unpar = ({ tl::expected result_unpar = - get_kernel_shape(attrs, get_reduced_shape(input)); + get_projection_shape(attrs, get_reduced_shape(input)); if (!result_unpar.has_value()) { return tl::unexpected(result_unpar.error()); } diff --git a/lib/op-attrs/src/op-attrs/ops/topk.cc b/lib/op-attrs/src/op-attrs/ops/topk.cc index 9d2fd35a94..7a6868340b 100644 --- a/lib/op-attrs/src/op-attrs/ops/topk.cc +++ b/lib/op-attrs/src/op-attrs/ops/topk.cc @@ -2,6 +2,10 @@ namespace FlexFlow { +TensorShape get_output_shape(TopKAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + ParallelTensorShape get_output_shape(TopKAttrs const &attrs, ParallelTensorShape const &input_shape) { NOT_IMPLEMENTED(); diff --git a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc index 0bb134da6b..4fe01c2c1a 100644 --- a/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc +++ b/lib/op-attrs/src/op-attrs/pcg_operator_attrs.cc @@ -15,56 +15,6 @@ OperatorType get_op_type(PCGOperatorAttrs const &attrs) { [](auto const &x) { return get_op_type(x); }); } -ComputationGraphOpAttrs - compgraph_op_attrs_from_pcg_op_attrs(PCGOperatorAttrs const &op) { - return op.visit(overload{ - [](BatchMatmulAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](BatchNormAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](CastAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ConcatAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](Conv2DAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](DropoutAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ElementBinaryAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](ElementUnaryAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](EmbeddingAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](FlatAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](GatherAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](InputAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](LayerNormAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](LinearAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](MultiHeadAttentionAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](NoopAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](Pool2DAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReduceAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReverseAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](ReshapeAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](SplitAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](SoftmaxAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](TopKAttrs const &attrs) { return ComputationGraphOpAttrs{attrs}; }, - [](TransposeAttrs const &attrs) { - return ComputationGraphOpAttrs{attrs}; - }, - [](auto const &attrs) -> ComputationGraphOpAttrs { - throw mk_runtime_error(fmt::format( - "Cannot convert parallel op to non-parallel, received {}", attrs)); - }, - }); -} - RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { return attrs.visit(overload{ [](LinearAttrs const &l) { return as_dot(l); }, @@ -76,4 +26,11 @@ RecordFormatter as_dot(PCGOperatorAttrs const &attrs) { }); } +PCGOperatorAttrs pcg_op_attrs_from_compgraph_op_attrs( + ComputationGraphOpAttrs const &cg_attrs) { + return cg_attrs.visit(overload{ + [](auto const &attrs) { return PCGOperatorAttrs{attrs}; }, + }); +} + } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc new file mode 100644 index 0000000000..42ea07e6b5 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -0,0 +1,15 @@ +#include "op-attrs/computation_graph_op_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("ComputationGraphOpAttrs to/from json") { + ComputationGraphOpAttrs correct = + ComputationGraphOpAttrs{BatchNormAttrs{true}}; + nlohmann::json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/datatype.cc b/lib/op-attrs/test/src/op-attrs/datatype.cc similarity index 100% rename from lib/op-attrs/test/src/datatype.cc rename to lib/op-attrs/test/src/op-attrs/datatype.cc diff --git a/lib/op-attrs/test/src/test_dim_ordered.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc similarity index 89% rename from lib/op-attrs/test/src/test_dim_ordered.cc rename to lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc index ac05767800..d7901a0c53 100644 --- a/lib/op-attrs/test/src/test_dim_ordered.cc +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/dim_ordered.cc @@ -1,5 +1,5 @@ +#include "op-attrs/dim_ordered/dim_ordered.h" #include "doctest/doctest.h" -#include "op-attrs/dim_ordered.h" #include "test/utils/rapidcheck.h" using namespace FlexFlow; diff --git a/lib/op-attrs/test/src/dim_ordered/slice.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc similarity index 100% rename from lib/op-attrs/test/src/dim_ordered/slice.cc rename to lib/op-attrs/test/src/op-attrs/dim_ordered/slice.cc diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc new file mode 100644 index 0000000000..60dedfe70a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -0,0 +1,26 @@ +#include "op-attrs/get_incoming_tensor_roles.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE( + "get_incoming_tensor_roles(ComputationGraphOpAttrs, int num_incoming)") { + SUBCASE("Concat") { + int num_incoming = 4; + ComputationGraphOpAttrs attrs = + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}, num_incoming}}; + + std::vector result = + get_incoming_tensor_roles(attrs, num_incoming); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/ops/attention.cc b/lib/op-attrs/test/src/op-attrs/ops/attention.cc similarity index 87% rename from lib/op-attrs/test/src/ops/attention.cc rename to lib/op-attrs/test/src/op-attrs/ops/attention.cc index 2fb804ca8c..eca8559b21 100644 --- a/lib/op-attrs/test/src/ops/attention.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/attention.cc @@ -7,6 +7,55 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_attention_incoming_tensor_roles(MultiHeadAttentionAttrs)") { + auto make_attrs = [](bool bias) { + return MultiHeadAttentionAttrs{ + /*embed_dim=*/32, + /*num_heads=*/10, + /*kdim=*/32, + /*vdim=*/32, + /*dropout=*/0.0, + /*bias=*/bias, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + }; + }; + + SUBCASE("without bias") { + MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/false); + + tl::expected, std::string> result = + get_attention_incoming_tensor_roles(attrs); + tl::expected, std::string> correct = + std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("with bias") { + MultiHeadAttentionAttrs attrs = make_attrs(/*bias=*/true); + + tl::expected, std::string> result = + get_attention_incoming_tensor_roles(attrs); + tl::expected, std::string> correct = + std::vector{ + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("get_output_shape(MultiHeadAttentionAttrs, TensorShape, " "TensorShape, TensorShape)") { int embed_dim = 32; diff --git a/lib/op-attrs/test/src/ops/batch_matmul.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc similarity index 100% rename from lib/op-attrs/test/src/ops/batch_matmul.cc rename to lib/op-attrs/test/src/op-attrs/ops/batch_matmul.cc diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc new file mode 100644 index 0000000000..df436da66c --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -0,0 +1,15 @@ +#include "op-attrs/ops/batch_norm_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("BatchNormAttrs to/from json") { + BatchNormAttrs correct = BatchNormAttrs{true}; + + nlohmann::json j = correct; + BatchNormAttrs result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/ops/cast.cc b/lib/op-attrs/test/src/op-attrs/ops/cast.cc similarity index 100% rename from lib/op-attrs/test/src/ops/cast.cc rename to lib/op-attrs/test/src/op-attrs/ops/cast.cc diff --git a/lib/op-attrs/test/src/ops/combine.cc b/lib/op-attrs/test/src/op-attrs/ops/combine.cc similarity index 100% rename from lib/op-attrs/test/src/ops/combine.cc rename to lib/op-attrs/test/src/op-attrs/ops/combine.cc diff --git a/lib/op-attrs/test/src/ops/conv_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc similarity index 85% rename from lib/op-attrs/test/src/ops/conv_2d.cc rename to lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index c4462eb7ec..152df09eca 100644 --- a/lib/op-attrs/test/src/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -5,6 +5,48 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_conv2d_incoming_tensor_roles(Conv2DAttrs") { + auto make_attrs = [](bool use_bias) { + return Conv2DAttrs{/*out_channels=*/4, + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*groups=*/1, + /*activation=*/std::nullopt, + /*use_bias=*/use_bias}; + }; + + SUBCASE("with bias") { + Conv2DAttrs attrs = make_attrs(/*use_bias=*/true); + + std::vector result = + get_conv2d_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("without bias") { + Conv2DAttrs attrs = make_attrs(/*use_bias=*/false); + + std::vector result = + get_conv2d_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("Conv2D shape inference") { int out_channels = 4; int kernel_h = 3; diff --git a/lib/op-attrs/test/src/ops/element_binary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_binary.cc similarity index 100% rename from lib/op-attrs/test/src/ops/element_binary.cc rename to lib/op-attrs/test/src/op-attrs/ops/element_binary.cc diff --git a/lib/op-attrs/test/src/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc similarity index 100% rename from lib/op-attrs/test/src/ops/element_unary.cc rename to lib/op-attrs/test/src/op-attrs/ops/element_unary.cc diff --git a/lib/op-attrs/test/src/ops/embedding.cc b/lib/op-attrs/test/src/op-attrs/ops/embedding.cc similarity index 100% rename from lib/op-attrs/test/src/ops/embedding.cc rename to lib/op-attrs/test/src/op-attrs/ops/embedding.cc diff --git a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc index cbcebdbce1..f45ea91dac 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/layer_norm.cc @@ -8,6 +8,42 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_layer_norm_incoming_tensor_roles(LayerNormAttrs)") { + auto make_attrs = [](bool elementwise_affine) { + return LayerNormAttrs{ + /*axes=*/{ff_dim_t{0}, ff_dim_t{2}}, + elementwise_affine, + /*eps=*/1.0, + }; + }; + + SUBCASE("elementwise_affine = true") { + LayerNormAttrs attrs = make_attrs(/*elementwise_affine=*/true); + + std::vector result = + get_layer_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("elementwise_affine = false") { + LayerNormAttrs attrs = make_attrs(/*elementwise_affine=*/false); + + std::vector result = + get_layer_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("shape inference (LayerNorm)") { LayerNormAttrs attrs_affine_true = LayerNormAttrs{ /*axes=*/{ff_dim_t{1}, ff_dim_t{3}}, diff --git a/lib/op-attrs/test/src/ops/linear.cc b/lib/op-attrs/test/src/op-attrs/ops/linear.cc similarity index 74% rename from lib/op-attrs/test/src/ops/linear.cc rename to lib/op-attrs/test/src/op-attrs/ops/linear.cc index f838ff4285..191515b062 100644 --- a/lib/op-attrs/test/src/ops/linear.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/linear.cc @@ -7,6 +7,45 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_linear_incoming_tensor_roles(LinearAttrs)") { + auto make_attrs = [](bool use_bias) { + return LinearAttrs{ + /*out_channels=*/16, + /*use_bias=*/use_bias, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + }; + + SUBCASE("use_bias = true") { + LinearAttrs attrs = make_attrs(/*use_bias=*/true); + + std::vector result = + get_linear_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("use_bias = false") { + LinearAttrs attrs = make_attrs(/*use_bias=*/false); + + std::vector result = + get_linear_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + } + TEST_CASE("Linear shape inference") { int out_channels = 16; LinearAttrs attrs = LinearAttrs{ @@ -43,7 +82,7 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - TensorShape kernel = TensorShape{ + TensorShape projection = TensorShape{ TensorDims{ FFOrdered{ in_channels, @@ -72,10 +111,10 @@ TEST_SUITE(FF_TEST_SUITE) { // get_weight_shape { - tl::expected kernel_result = - get_kernel_shape(attrs, input); - tl::expected kernel_correct = kernel; - CHECK(kernel_result == kernel_correct); + tl::expected projection_result = + get_projection_shape(attrs, input); + tl::expected projection_correct = projection; + CHECK(projection_result == projection_correct); } // get_bias_shape @@ -104,12 +143,12 @@ TEST_SUITE(FF_TEST_SUITE) { output, o_sum, o_eq, FFOrdered{o_batch, o_extra_dim, o_channel}); }; - auto make_kernel = [&](SumDegree o_sum, - DiscardCopyDegree o_eq, - int o_inchannel, - int o_outchannel) { + auto make_projection = [&](SumDegree o_sum, + DiscardCopyDegree o_eq, + int o_inchannel, + int o_outchannel) { return lift_to_parallel_with_degrees( - kernel, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); + projection, o_sum, o_eq, FFOrdered{o_inchannel, o_outchannel}); }; auto make_bias = @@ -143,12 +182,13 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, - DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, - 1, - 1); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, + DiscardCopyDegree{input_sum_degree * degree * extra_dim_degree}, + 1, + 1); CHECK(result == correct); } @@ -184,9 +224,10 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, degree, 1); CHECK(result == correct); } @@ -216,9 +257,10 @@ TEST_SUITE(FF_TEST_SUITE) { { tl::expected result = - get_kernel_shape(attrs, par_input); - tl::expected correct = make_kernel( - SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); + get_projection_shape(attrs, par_input); + tl::expected correct = + make_projection( + SumDegree{1}, DiscardCopyDegree{input_sum_degree}, 1, degree); CHECK(result == correct); } diff --git a/lib/op-attrs/test/src/ops/reduction.cc b/lib/op-attrs/test/src/op-attrs/ops/reduction.cc similarity index 100% rename from lib/op-attrs/test/src/ops/reduction.cc rename to lib/op-attrs/test/src/op-attrs/ops/reduction.cc diff --git a/lib/op-attrs/test/src/ops/repartition.cc b/lib/op-attrs/test/src/op-attrs/ops/repartition.cc similarity index 100% rename from lib/op-attrs/test/src/ops/repartition.cc rename to lib/op-attrs/test/src/op-attrs/ops/repartition.cc diff --git a/lib/op-attrs/test/src/ops/replicate.cc b/lib/op-attrs/test/src/op-attrs/ops/replicate.cc similarity index 100% rename from lib/op-attrs/test/src/ops/replicate.cc rename to lib/op-attrs/test/src/op-attrs/ops/replicate.cc diff --git a/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc new file mode 100644 index 0000000000..ebeaec4d19 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/pcg_operator_attrs.cc @@ -0,0 +1,17 @@ +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("PCGOperatorAttrs to/from json") { + PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ + /*repartition_dim=*/ff_dim_t{1}, + /*repartition_degree=*/4, + }}; + nlohmann::json j = correct; + auto result = j.get(); + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/test_regularizer_attrs.cc b/lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc similarity index 100% rename from lib/op-attrs/test/src/test_regularizer_attrs.cc rename to lib/op-attrs/test/src/op-attrs/regularizer_attrs.cc diff --git a/lib/op-attrs/test/src/test_operator_attrs.cc b/lib/op-attrs/test/src/test_operator_attrs.cc deleted file mode 100644 index 20825f5d73..0000000000 --- a/lib/op-attrs/test/src/test_operator_attrs.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include -#include -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{true}; - nlohmann::json j = correct; - BatchNormAttrs result = j.get(); - CHECK(result == correct); - } - - TEST_CASE("ComputationGraphAttrs to/from json") { - ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{true}}; - nlohmann::json j = correct; - ComputationGraphOpAttrs result = j.get(); - - CHECK(result == correct); - } - - TEST_CASE("PCGOperatorAttrs to/from json") { - PCGOperatorAttrs correct = PCGOperatorAttrs{RepartitionAttrs{ - /*repartition_dim=*/ff_dim_t{1}, - /*repartition_degree=*/4, - }}; - nlohmann::json j = correct; - PCGOperatorAttrs result = j.get(); - - CHECK(result == correct); - } -} diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index 499b26af89..f70d9f7404 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H #define _FLEXFLOW_PCG_INCLUDE_PCG_COMPUTATION_GRAPH_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "pcg/computation_graph.dtg.h" #include "pcg/computation_graph/computation_graph_edge.dtg.h" #include "pcg/computation_graph/layer_added_result.dtg.h" @@ -31,6 +32,11 @@ std::vector get_outgoing_tensors(ComputationGraph const &cg, std::vector get_incoming_tensors(ComputationGraph const &cg, layer_guid_t n); +std::vector get_incoming_inputs(ComputationGraph const &, + layer_guid_t const &); +std::vector get_incoming_weights(ComputationGraph const &, + layer_guid_t const &); + std::unordered_set get_subgraph_incoming_edges(ComputationGraph const &, std::unordered_set const &); diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index a35763cacc..11e591545d 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -231,21 +231,26 @@ struct ComputationGraphBuilder { tensor_guid_t create_input(TensorShape const &, CreateGrad, - std::optional const &maybe_name = std::nullopt); + std::optional const &name = std::nullopt); + tensor_guid_t create_weight( TensorShape const &, CreateGrad create_grad = CreateGrad::YES, std::optional const &initializer = std::nullopt, std::optional sync_type = std::nullopt, std::optional const &name = std::nullopt); + tensor_guid_t + create_weight(TensorAttrs const &, + std::optional const &name = std::nullopt); std::vector get_outputs(LayerAttrs const &) const; tensor_guid_t get_output(LayerAttrs const &, int idx) const; - std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); + std::vector + add_layer(LayerAttrs const &layer, + std::vector const &inputs, + std::vector const &weights, + std::vector const &outputs); private: TensorShape get_shape(tensor_guid_t const &) const; @@ -255,30 +260,6 @@ struct ComputationGraphBuilder { tensor_guid_t as_type(tensor_guid_t const &, DataType, std::string const &); - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorAttrs const &output); - - std::vector add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs); - - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - TensorShape const &output); - - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output); - - tensor_guid_t add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output); - TensorDims get_broadcast_target_dims(std::vector const &); TensorDims get_broadcast_target_dims(std::vector const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 9150681070..d7248afde4 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -22,12 +22,19 @@ ParallelLayerAddedResult std::vector const &output_labels); std::vector - get_layer_inputs(ParallelComputationGraph const &, - parallel_layer_guid_t const &); + get_incoming_tensors(ParallelComputationGraph const &, + parallel_layer_guid_t const &); std::vector get_layer_outputs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::vector + get_incoming_inputs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); +std::vector + get_incoming_weights(ParallelComputationGraph const &, + parallel_layer_guid_t const &); + ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 20e947ad58..3a7f67dcf0 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -13,7 +13,7 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t create_input_tensor( ParallelTensorShape const &shape, - bool create_grad = true, + CreateGrad create_grad = CreateGrad::YES, std::optional const &name = std::nullopt); parallel_tensor_guid_t @@ -54,7 +54,8 @@ struct ParallelComputationGraphBuilder { std::optional activation = std::nullopt, bool use_bias = true, DataType data_type = DataType::FLOAT, - std::optional const &kernel_initializer = std::nullopt, + std::optional const &projection_initializer = + std::nullopt, std::optional const &bias_initializer = std::nullopt, std::optional const &name = std::nullopt); diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml index 3dfd90e296..577825238d 100644 --- a/lib/pcg/include/pcg/strided_rectangle.struct.toml +++ b/lib/pcg/include/pcg/strided_rectangle.struct.toml @@ -11,7 +11,7 @@ features = [ includes = [ "pcg/strided_rectangle_side.dtg.h", - "op-attrs/dim_ordered.h", + "op-attrs/dim_ordered/dim_ordered.h", ] [[fields]] diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index cf4b1496cf..a69e54fd93 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -1,5 +1,7 @@ #include "pcg/computation_graph.h" #include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/reversed.h" #include "utils/containers/transform.h" @@ -80,6 +82,45 @@ std::vector get_incoming_tensors(ComputationGraph const &cg, [](DataflowOutput const &o) { return tensor_guid_t{o}; }); } +static std::vector + get_incoming_tensors_with_role(ComputationGraph const &cg, + layer_guid_t const &l, + IncomingTensorRole desired_role) { + ComputationGraphOpAttrs attrs = get_layer_attrs(cg, l).attrs; + + std::vector incoming_tensors = get_incoming_tensors(cg, l); + + std::vector incoming_tensor_roles = + get_incoming_tensor_roles(attrs, incoming_tensors.size()); + + assert(incoming_tensors.size() == incoming_tensor_roles.size()); + + std::vector result = + filtrans(zip(incoming_tensors, incoming_tensor_roles), + [&](std::pair const &p) + -> std::optional { + tensor_guid_t tensor = p.first; + IncomingTensorRole role = p.second; + + if (role == desired_role) { + return tensor; + } else { + return std::nullopt; + } + }); + return result; +} + +std::vector get_incoming_inputs(ComputationGraph const &cg, + layer_guid_t const &l) { + return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::INPUT); +} + +std::vector get_incoming_weights(ComputationGraph const &cg, + layer_guid_t const &l) { + return get_incoming_tensors_with_role(cg, l, IncomingTensorRole::WEIGHT); +} + std::unordered_set get_subgraph_incoming_edges( ComputationGraph const &cg, std::unordered_set const &subgraph_nodes) { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index e0b6935a6d..a4f61cff98 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -1,5 +1,6 @@ #include "pcg/computation_graph_builder.h" #include "op-attrs/computation_graph_op_attrs.h" +#include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/get_op_type.h" #include "op-attrs/get_output_shapes.h" #include "op-attrs/ops/attention.h" @@ -55,112 +56,63 @@ tensor_guid_t ComputationGraphBuilder::create_input( maybe_name, }; - return this->add_layer(layer_attrs, {}, {}, tensor_attrs); + return get_only(this->add_layer(layer_attrs, {}, {}, {tensor_attrs})); } tensor_guid_t ComputationGraphBuilder::create_weight( - TensorShape const &shape, - CreateGrad create_grad, - std::optional const &initializer, - std::optional param_sync, + TensorAttrs const &tensor_attrs, std::optional const &maybe_name) { - TensorAttrs tensor_attrs = - TensorAttrs{shape, initializer, param_sync, create_grad}; LayerAttrs layer_attrs = LayerAttrs{ ComputationGraphOpAttrs{InputAttrs{}}, maybe_name, }; - return this->add_layer(layer_attrs, {}, {}, tensor_attrs); + return get_only(this->add_layer(layer_attrs, + std::vector{}, + std::vector{}, + {tensor_attrs})); } -std::vector ComputationGraphBuilder::add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - std::vector raw_weight_tensors; - for (auto const &kv : enumerate_vector(weights)) { - int weight_idx = kv.first; - TensorAttrs weight_tensor_attrs = kv.second; - - std::optional weight_name = - transform(layer.name, [&](std::string const &layer_name) { - return fmt::format("{}.weights[{}]", layer_name, weight_idx); - }); - LayerAttrs weight_layer_attrs = LayerAttrs{ - ComputationGraphOpAttrs{WeightAttrs{weight_tensor_attrs.shape}}, - weight_name, - }; - std::vector weight_layer_inputs = {}; - std::vector weight_output_attrs = {weight_tensor_attrs}; - raw_weight_tensors.push_back(get_only(this->computation_graph.raw_graph - .add_node(weight_layer_attrs, - weight_layer_inputs, - weight_output_attrs) - .outputs)); - } +tensor_guid_t ComputationGraphBuilder::create_weight( + TensorShape const &shape, + CreateGrad create_grad, + std::optional const &initializer, + std::optional param_sync, + std::optional const &maybe_name) { + TensorAttrs tensor_attrs = + TensorAttrs{shape, initializer, param_sync, create_grad}; - std::vector raw_inputs = transform( - inputs, [](tensor_guid_t const &t) { return t.raw_graph_output; }); - std::vector raw_outputs = - this->computation_graph.raw_graph - .add_node( - layer, concat_vectors(raw_inputs, raw_weight_tensors), outputs) - .outputs; - return transform(raw_outputs, - [](DataflowOutput const &o) { return tensor_guid_t{o}; }); + return this->create_weight(tensor_attrs, maybe_name); } -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorAttrs const &output) { - std::vector outputs = {output}; - return get_only(this->add_layer(layer, inputs, weights, outputs)); -} +static void check_incoming_tensor_roles(LayerAttrs const &layer, + int num_inputs, + int num_weights) { + std::vector correct = + get_incoming_tensor_roles(layer.attrs, num_inputs + num_weights); + std::vector current = concat_vectors( + std::vector(num_inputs, IncomingTensorRole::INPUT), + std::vector(num_weights, IncomingTensorRole::WEIGHT)); -std::vector ComputationGraphBuilder::add_layer( - LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs) { - return this->add_layer( - layer, inputs, weights, transform(outputs, make_output_attrs)); + if (correct != current) { + throw mk_runtime_error( + fmt::format("check_incoming_tensor_roles found deviation in incoming " + "tensors: expected {}, received {}", + correct, + current)); + } } -tensor_guid_t ComputationGraphBuilder::add_layer( +std::vector ComputationGraphBuilder::add_layer( LayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, - TensorShape const &output_shape) { - - TensorAttrs output_attrs = make_output_attrs(output_shape); - LayerAddedResult added = - ::FlexFlow::add_layer(this->computation_graph, - layer, - concat_vectors(inputs, weights), - {output_attrs}); - return get_only(added.outputs); -} - -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - TensorShape const &output_shape) { - - std::vector weights = {}; - return this->add_layer(layer, inputs, weights, output_shape); -} + std::vector const &outputs) { + check_incoming_tensor_roles(layer, inputs.size(), weights.size()); -tensor_guid_t - ComputationGraphBuilder::add_layer(LayerAttrs const &layer, - std::vector const &inputs, - std::vector const &weights, - TensorShape const &output) { - return get_only(this->add_layer( - layer, inputs, weights, std::vector{output})); + LayerAddedResult added = ::FlexFlow::add_layer( + this->computation_graph, layer, concat_vectors(inputs, weights), outputs); + return added.outputs; } tensor_guid_t ComputationGraphBuilder::as_type(tensor_guid_t const &x, @@ -201,7 +153,8 @@ tensor_guid_t ComputationGraphBuilder::broadcast(tensor_guid_t const &input, TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t @@ -238,7 +191,8 @@ tensor_guid_t ComputationGraphBuilder::element_unary( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::element_binary( @@ -271,7 +225,8 @@ tensor_guid_t ComputationGraphBuilder::element_binary( TensorShape output_shape = throw_if_unexpected(get_output_shape( attrs, this->get_shape(lhs_input), this->get_shape(rhs_input))); - return this->add_layer(layer, {lhs_input, rhs_input}, output_shape); + return get_only(this->add_layer( + layer, {lhs_input, rhs_input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t @@ -461,7 +416,12 @@ tensor_guid_t ComputationGraphBuilder::conv2d( bias_initializer)); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::dropout( @@ -479,7 +439,8 @@ tensor_guid_t ComputationGraphBuilder::dropout( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::embedding( @@ -507,7 +468,10 @@ tensor_guid_t ComputationGraphBuilder::embedding( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); - return this->add_layer(layer, {input}, {weight_attrs}, output_shape); + return get_only(this->add_layer(layer, + {input}, + {this->create_weight(weight_attrs)}, + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::gather( @@ -531,41 +495,9 @@ tensor_guid_t ComputationGraphBuilder::gather( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); - return this->add_layer(layer, {input}, output_shape); -} - -/* std::vector - * ComputationGraphBuilder::get_shapes(std::vector const &ts) - * const { */ -/* return transform(ts, [&](tensor_guid_t const &t) { return - * this->get_shape(t); }); */ -/* } */ - -// tensor_guid_t ComputationGraphBuilder::aggregate( -// tensor_guid_t const &gate_preds, -// tensor_guid_t const &gate_assign, -// tensor_guid_t const &true_gate_assign, -// tensor_guid_t const &full_gate_gradients, -// std::vector const &exp_preds, -// int n, -// float lambda_bal, -// std::optional const &maybe_name) { -// AggregateAttrs attrs = {n, lambda_bal}; -// std::string name = maybe_name.value_or(get_default_name(attrs)); - -// LayerAttrs layer = {attrs, name}; -// TensorShape output_shape = get_output_shape(attrs, -// this->get_shape(gate_preds), -// this->get_shape(gate_assign), -// this->get_shape(true_gate_assign), -// this->get_shape(full_gate_gradients), -// this->get_shape(exp_preds)); - -// std::vector inputs = { -// gate_preds, gate_assign, true_gate_assign, full_gate_gradients}; -// extend(inputs, exp_preds); -// return this->add_layer(layer, inputs, {}, output_shape); -// } + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, @@ -579,7 +511,8 @@ tensor_guid_t ComputationGraphBuilder::batch_norm( TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -597,6 +530,20 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( std::optional initializer, std::optional const &maybe_name) { + if (add_bias_kv) { + throw mk_runtime_error( + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_bias_kv=true. " + "If you need this functionality, please create an issue."); + } + + if (add_zero_attn) { + throw mk_runtime_error( + "ComputationGraphBuilder::multihead_attention received currently " + "unsupported argument add_zero_attn=true. " + "If you need this functionality, please create an issue."); + } + MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{embed_dim, num_heads, kdim, @@ -609,24 +556,48 @@ tensor_guid_t ComputationGraphBuilder::multihead_attention( std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + TensorShape query_shape = this->get_shape(query); + TensorShape key_shape = this->get_shape(key); + TensorShape value_shape = this->get_shape(value); + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = - throw_if_unexpected(get_output_shape(attrs, - this->get_shape(query), - this->get_shape(key), - this->get_shape(value))); + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, query_shape, key_shape, value_shape)); + + std::vector weights; + + TensorShape weights_shape = throw_if_unexpected( + get_weights_shape(attrs, query_shape, key_shape, value_shape)); + weights.push_back(make_weight_attrs(weights_shape, initializer)); + + if (bias) { + TensorShape input_bias_shape = throw_if_unexpected( + get_input_bias_shape(attrs, query_shape, key_shape, value_shape)); + // initializer chosen based on + // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 + InitializerAttrs input_bias_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; - TensorShape weights_shape = - throw_if_unexpected(get_weights_shape(attrs, - this->get_shape(query), - this->get_shape(key), - this->get_shape(value))); - TensorAttrs weight_attrs = make_weight_attrs(weights_shape, initializer); + weights.push_back( + make_weight_attrs(input_bias_shape, input_bias_initializer)); - return this->add_layer(layer, - std::vector{query, key, value}, - {weight_attrs}, - output_shape); + TensorShape output_bias_shape = throw_if_unexpected( + get_output_bias_shape(attrs, query_shape, key_shape, value_shape)); + // initializer chosen based on + // https://github.com/pytorch/pytorch/blob/31c4e0d37d8efc37a0697159e5b9121ec34d5141/torch/nn/modules/activation.py#L1120-L1121 + InitializerAttrs output_bias_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + + weights.push_back( + make_weight_attrs(output_bias_shape, output_bias_initializer)); + } + + return get_only(this->add_layer( + layer, + {query, key, value}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } TensorDims ComputationGraphBuilder::get_broadcast_target_dims( @@ -676,7 +647,7 @@ tensor_guid_t ComputationGraphBuilder::dense( std::vector weights; TensorShape projection_shape = - throw_if_unexpected(get_kernel_shape(attrs, this->get_shape(input))); + throw_if_unexpected(get_projection_shape(attrs, this->get_shape(input))); tensor_guid_t projection_weights = this->create_weight(projection_shape, @@ -699,7 +670,8 @@ tensor_guid_t ComputationGraphBuilder::dense( weights.push_back(bias_weights); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, {input}, weights, {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::layer_norm( @@ -752,7 +724,12 @@ tensor_guid_t ComputationGraphBuilder::layer_norm( weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); } - return this->add_layer(layer, {input}, weights, output_shape); + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::softmax( @@ -781,7 +758,8 @@ tensor_guid_t ComputationGraphBuilder::softmax( TensorShape output_shape = throw_if_unexpected(get_output_shape(attrs, input_shape)); - return this->add_layer(layer, {input}, output_shape); + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 5b178160cd..b04d9d37b3 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/get_incoming_tensor_roles.h" +#include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" @@ -41,8 +43,8 @@ ParallelLayerAddedResult } std::vector - get_layer_inputs(ParallelComputationGraph const &pcg, - parallel_layer_guid_t const &l) { + get_incoming_tensors(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { return transform( get_input_values(pcg.raw_graph, l.raw_graph_node), [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); @@ -56,6 +58,48 @@ std::vector [](DataflowOutput const &o) { return parallel_tensor_guid_t{o}; }); } +static std::vector + get_incoming_tensors_with_role(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l, + IncomingTensorRole desired_role) { + PCGOperatorAttrs attrs = get_parallel_layer_attrs(pcg, l).op_attrs; + + std::vector incoming_tensors = + get_incoming_tensors(pcg, l); + + std::vector incoming_tensor_roles = + get_incoming_tensor_roles(attrs, incoming_tensors.size()); + + assert(incoming_tensors.size() == incoming_tensor_roles.size()); + + std::vector result = filtrans( + zip(incoming_tensors, incoming_tensor_roles), + [&](std::pair const &p) + -> std::optional { + parallel_tensor_guid_t tensor = p.first; + IncomingTensorRole role = p.second; + + if (role == desired_role) { + return tensor; + } else { + return std::nullopt; + } + }); + return result; +} + +std::vector + get_incoming_inputs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::INPUT); +} + +std::vector + get_incoming_weights(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_incoming_tensors_with_role(pcg, l, IncomingTensorRole::WEIGHT); +} + parallel_layer_guid_t get_source_layer(ParallelComputationGraph const &g, parallel_tensor_guid_t const &t) { return parallel_layer_guid_t{t.raw_graph_output.node}; diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 8290a2ff94..620dc035fc 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,4 +1,5 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/parallel_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" @@ -35,13 +36,13 @@ ParallelComputationGraphBuilder::ParallelComputationGraphBuilder() parallel_tensor_guid_t ParallelComputationGraphBuilder::create_input_tensor( ParallelTensorShape const &shape, - bool create_grad, + CreateGrad create_grad, std::optional const &name) { ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ /*shape=*/shape, /*sync_type=*/std::nullopt, /*initializer=*/std::nullopt, - /*create_gradients=*/(create_grad ? CreateGrad::YES : CreateGrad::NO), + /*create_gradients=*/create_grad, }; ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ PCGOperatorAttrs{InputAttrs{}}, @@ -205,7 +206,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( { ParallelTensorShape kernel_shape = - throw_if_unexpected(get_kernel_shape(attrs, input_shape)); + throw_if_unexpected(get_projection_shape(attrs, input_shape)); weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); } @@ -580,11 +581,32 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::add_weight( return parallel_tensor_guid_t{current_raw_weight_tensor}; } +static void check_incoming_tensor_roles(ParallelLayerAttrs const &layer, + int num_inputs, + int num_weights) { + std::vector correct = + get_incoming_tensor_roles(layer.op_attrs, num_inputs + num_weights); + std::vector current = concat_vectors( + std::vector(num_inputs, IncomingTensorRole::INPUT), + std::vector(num_weights, IncomingTensorRole::WEIGHT)); + + if (correct != current) { + throw mk_runtime_error( + fmt::format("check_incoming_tensor_roles found deviation in incoming " + "tensors: expected {}, received {}", + correct, + current)); + } +} + std::vector ParallelComputationGraphBuilder::add_layer( ParallelLayerAttrs const &layer, std::vector const &inputs, std::vector const &weights, std::vector const &outputs) { + + check_incoming_tensor_roles(layer, inputs.size(), weights.size()); + std::vector raw_weight_tensors; for (auto const &kv : enumerate_vector(weights)) { int weight_idx = kv.first; @@ -603,6 +625,7 @@ std::vector ParallelComputationGraphBuilder::add_layer( transform(inputs, [](parallel_tensor_guid_t const &t) { return t.raw_graph_output; }); + std::vector raw_outputs = this->pcg.raw_graph .add_node( diff --git a/lib/pcg/test/src/pcg/computation_graph.cc b/lib/pcg/test/src/pcg/computation_graph.cc new file mode 100644 index 0000000000..e2ed51b2f1 --- /dev/null +++ b/lib/pcg/test/src/pcg/computation_graph.cc @@ -0,0 +1,206 @@ +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_incoming_inputs(ComputationGraph, layer_guid_t)") { + SUBCASE("layer has no inputs") { + std::string input_name = "input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_name); + + std::vector result = get_incoming_inputs(cg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string layer_name = "my op"; + + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.relu(input, layer_name); + + ComputationGraph cg = b.computation_graph; + + layer_guid_t layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_inputs(cg, layer); + std::vector correct = {input}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string layer_name = "my op"; + + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/layer_name); + + ComputationGraph cg = b.computation_graph; + + layer_guid_t dense_layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_inputs(cg, dense_layer); + std::vector correct = { + input, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_incoming_weights(ComputationGraph, layer_guid_t)") { + SUBCASE("layer has no inputs or weights") { + std::string input_name = "input"; + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = + b.create_input(input_shape, CreateGrad::YES, input_name); + + return b.computation_graph; + }(); + + layer_guid_t input_layer = get_layer_by_name(cg, input_name); + + std::vector result = get_incoming_weights(cg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string layer_name = "my op"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.relu(input, layer_name); + + return b.computation_graph; + }(); + + layer_guid_t layer = get_layer_by_name(cg, layer_name); + + std::vector result = get_incoming_weights(cg, layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string layer_name = "my op"; + std::string projection_name = "my projection weight"; + std::string bias_name = "my bias weight"; + + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = b.create_input(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/layer_name, + /*projection_name=*/projection_name, + /*bias_name=*/bias_name); + + return b.computation_graph; + }(); + + layer_guid_t dense_layer = get_layer_by_name(cg, layer_name); + + layer_guid_t projection_weight_layer = + get_layer_by_name(cg, projection_name); + tensor_guid_t projection_weight = + get_only(get_outgoing_tensors(cg, projection_weight_layer)); + + layer_guid_t bias_weight_layer = get_layer_by_name(cg, bias_name); + tensor_guid_t bias_weight = + get_only(get_outgoing_tensors(cg, bias_weight_layer)); + + std::vector result = get_incoming_weights(cg, dense_layer); + std::vector correct = { + projection_weight, + bias_weight, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 188447da92..77d938e08a 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,5 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test/utils/rapidcheck.h" #include "utils/containers/get_only.h" @@ -35,4 +36,230 @@ TEST_SUITE(FF_TEST_SUITE) { // std::vector correct = {layer1, layer2, layer3}; // CHECK(result == correct); } + + TEST_CASE( + "get_incoming_inputs(ParallelComputationGraph, parallel_layer_guid_t)") { + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("layer has no inputs") { + std::string input_name = "my input"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + b.create_input_tensor(input_shape, CreateGrad::YES, input_name); + + return b.pcg; + }(); + + parallel_layer_guid_t input_layer = + get_parallel_layer_by_name(pcg, input_name); + + std::vector result = + get_incoming_inputs(pcg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights") { + std::string my_op_name = "my op"; + + ParallelComputationGraphBuilder b; + + parallel_tensor_guid_t input = + b.create_input_tensor(input_shape, CreateGrad::YES); + b.dense(input, + /*outDim=*/14, + /*activation=*/Activation::RELU, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt, + /*name=*/my_op_name); + + ParallelComputationGraph pcg = b.pcg; + + parallel_layer_guid_t my_op_layer = + get_parallel_layer_by_name(pcg, my_op_name); + + std::vector result = + get_incoming_inputs(pcg, my_op_layer); + std::vector correct = {input}; + + CHECK(result == correct); + } + } + + TEST_CASE( + "get_incoming_weights(ParallelComputationGraph, parallel_layer_guid_t)") { + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("layer has no inputs or weights") { + std::string input_name = "my input"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + b.create_input_tensor(input_shape, CreateGrad::YES, input_name); + + return b.pcg; + }(); + + parallel_layer_guid_t input_layer = + get_parallel_layer_by_name(pcg, input_name); + + std::vector result = + get_incoming_weights(pcg, input_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs but no weights") { + std::string my_op_name = "my op"; + ParallelComputationGraph pcg = [&] { + ParallelComputationGraphBuilder b; + + parallel_tensor_guid_t input = + b.create_input_tensor(input_shape, CreateGrad::YES); + b.relu(input, my_op_name); + + return b.pcg; + }(); + + parallel_layer_guid_t my_op_layer = + get_parallel_layer_by_name(pcg, my_op_name); + + std::vector result = + get_incoming_weights(pcg, my_op_layer); + std::vector correct = {}; + + CHECK(result == correct); + } + + SUBCASE("layer has inputs and weights, and weights are separate by " + "parallel ops") { + std::string my_op_name = "my op"; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + LinearAttrs op_attrs = LinearAttrs{ + /*out_channels=*/14, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*activation=*/Activation::RELU, + /*regularizer=*/std::nullopt, + }; + + ParallelLayerAddedResult input_added = [&] { + ParallelLayerAttrs input_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{InputAttrs{}}, + std::nullopt, + }; + ParallelTensorAttrs input_tensor_attrs = + ParallelTensorAttrs{input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + return add_parallel_layer(pcg, input_attrs, {}, {input_tensor_attrs}); + }(); + parallel_tensor_guid_t input = get_only(input_added.outputs); + + ParallelLayerAddedResult projection_weight_added = [&] { + ParallelTensorShape projection_weight_shape = + throw_if_unexpected(get_projection_shape(op_attrs, input_shape)); + + TensorShape unpar_projection_shape = + get_reduced_shape(projection_weight_shape); + ParallelTensorShape raw_projection_weight_shape = + lift_to_parallel(unpar_projection_shape); + + ParallelLayerAttrs raw_projection_weight_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{unpar_projection_shape}}, + std::nullopt, + }; + ParallelTensorAttrs raw_projection_tensor_attrs = + ParallelTensorAttrs{raw_projection_weight_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + ParallelLayerAddedResult raw_weight_added = + add_parallel_layer(pcg, + raw_projection_weight_attrs, + {}, + {raw_projection_tensor_attrs}); + + ReplicateAttrs replicate_attrs = ReplicateAttrs{/*degree=*/2}; + ParallelLayerAttrs replicate_layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{replicate_attrs}, + std::nullopt, + }; + ParallelTensorAttrs replicated_projection_tensor_attrs = + ParallelTensorAttrs{ + get_output_shape(replicate_attrs, raw_projection_weight_shape), + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + return add_parallel_layer(pcg, + replicate_layer_attrs, + {}, + {replicated_projection_tensor_attrs}); + }(); + parallel_tensor_guid_t projection_weight = + get_only(projection_weight_added.outputs); + + ParallelLayerAddedResult my_op_added = [&] { + ParallelTensorShape output_shape = + throw_if_unexpected(get_output_shape(op_attrs, input_shape)); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + PCGOperatorAttrs{op_attrs}, + std::nullopt, + }; + ParallelTensorAttrs output_tensor_attrs = + ParallelTensorAttrs{output_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + CreateGrad::YES}; + + return add_parallel_layer(pcg, + layer_attrs, + {input, projection_weight}, + {output_tensor_attrs}); + }(); + + parallel_layer_guid_t my_op_layer = my_op_added.parallel_layer; + + std::vector result = + get_incoming_weights(pcg, my_op_layer); + std::vector correct = {projection_weight}; + + CHECK(result == correct); + } + } } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index f46f267859..c445085635 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -15,6 +15,12 @@ using namespace ::FlexFlow; +// Stylistically these tests are not great (they're rather complicated +// and hard to read) and should not be used as a model for other FlexFlow +// tests. +// +// Improving them is being tracked in +// https://github.com/flexflow/FlexFlow/issues/1474 TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ParallelComputationGraphBuilder::add") { ParallelComputationGraphBuilder b; @@ -44,9 +50,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t out = b.add(lhs, rhs); parallel_layer_guid_t layer = get_source_layer(out); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {lhs, rhs}; CHECK(result == correct); } @@ -107,9 +113,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t out = b.batch_matmul(a_tensor, b_tensor); parallel_layer_guid_t layer = get_source_layer(out); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {a_tensor, b_tensor}; CHECK(result == correct); } @@ -150,9 +156,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.cast(input, output_datatype); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -260,20 +266,20 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelTensorShape correct_bias_shape = get_bias_shape(correct_attrs, input_shape); - std::vector conv_inputs = - get_layer_inputs(b.pcg, conv_guid); + std::vector conv_incoming = + get_incoming_tensors(b.pcg, conv_guid); - parallel_tensor_guid_t conv_input = conv_inputs.at(0); + parallel_tensor_guid_t conv_input = conv_incoming.at(0); ParallelTensorShape conv_input_shape = get_parallel_tensor_attrs(b.pcg, conv_input).shape; CHECK(conv_input_shape == input_shape); - parallel_tensor_guid_t conv_kernel = conv_inputs.at(1); + parallel_tensor_guid_t conv_kernel = conv_incoming.at(1); ParallelTensorShape conv_kernel_shape = get_parallel_tensor_attrs(b.pcg, conv_kernel).shape; CHECK(conv_kernel_shape == correct_kernel_shape); - parallel_tensor_guid_t conv_bias = conv_inputs.at(2); + parallel_tensor_guid_t conv_bias = conv_incoming.at(2); ParallelTensorShape conv_bias_shape = get_parallel_tensor_attrs(b.pcg, conv_bias).shape; CHECK(conv_bias_shape == correct_bias_shape); @@ -315,9 +321,9 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 3); @@ -358,9 +364,9 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == input); CHECK(result.size() == 2); @@ -408,9 +414,9 @@ TEST_SUITE(FF_TEST_SUITE) { b.multihead_attention(query, key, value, embed_dim, num_heads); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); CHECK(result.at(0) == query); CHECK(result.at(1) == key); CHECK(result.at(2) == value); @@ -449,9 +455,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.relu(input); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -488,9 +494,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_partition(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -527,9 +533,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_combine(input, ff_dim_t{0}, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -566,9 +572,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_replicate(input, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } @@ -605,9 +611,9 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t output = b.parallel_reduce(input, 2); parallel_layer_guid_t layer = get_source_layer(output); - SUBCASE("inputs") { + SUBCASE("incoming") { std::vector result = - get_layer_inputs(b.pcg, layer); + get_incoming_tensors(b.pcg, layer); std::vector correct = {input}; CHECK(result == correct); } diff --git a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h index d4d38af228..a5f0cc6fdc 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h +++ b/lib/substitutions/include/substitutions/operator_pattern/get_attribute.h @@ -10,64 +10,66 @@ namespace FlexFlow { std::optional get_attribute(PCGOperatorAttrs const &, OperatorAttributeKey); -std::optional get_attribute(BatchMatmulAttrs const &p, +std::optional get_attribute(BatchMatmulAttrs const &, OperatorAttributeKey); -std::optional get_attribute(BatchNormAttrs const &p, +std::optional get_attribute(BatchNormAttrs const &, OperatorAttributeKey); -std::optional get_attribute(CastAttrs const &p, +std::optional get_attribute(BroadcastAttrs const &, OperatorAttributeKey); -std::optional get_attribute(CombineAttrs const &p, +std::optional get_attribute(CastAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ConcatAttrs const &p, +std::optional get_attribute(CombineAttrs const &, OperatorAttributeKey); -std::optional get_attribute(Conv2DAttrs const &p, +std::optional get_attribute(ConcatAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ElementBinaryAttrs const &p, +std::optional get_attribute(Conv2DAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ElementUnaryAttrs const &p, +std::optional get_attribute(ElementBinaryAttrs const &, OperatorAttributeKey); -std::optional get_attribute(DropoutAttrs const &p, +std::optional get_attribute(ElementUnaryAttrs const &, OperatorAttributeKey); -std::optional get_attribute(EmbeddingAttrs const &p, +std::optional get_attribute(DropoutAttrs const &, OperatorAttributeKey); -std::optional get_attribute(FlatAttrs const &p, +std::optional get_attribute(EmbeddingAttrs const &, OperatorAttributeKey); -std::optional get_attribute(GatherAttrs const &p, +std::optional get_attribute(FlatAttrs const &, OperatorAttributeKey); -std::optional get_attribute(InputAttrs const &p, +std::optional get_attribute(GatherAttrs const &, OperatorAttributeKey); -std::optional get_attribute(LayerNormAttrs const &p, +std::optional get_attribute(InputAttrs const &, OperatorAttributeKey); -std::optional get_attribute(LinearAttrs const &p, +std::optional get_attribute(LayerNormAttrs const &, + OperatorAttributeKey); +std::optional get_attribute(LinearAttrs const &, OperatorAttributeKey); std::optional - get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); + get_attribute(MultiHeadAttentionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(NoopAttrs const &p, +std::optional get_attribute(NoopAttrs const &, OperatorAttributeKey); -std::optional get_attribute(Pool2DAttrs const &p, +std::optional get_attribute(Pool2DAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReduceAttrs const &p, +std::optional get_attribute(ReduceAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReductionAttrs const &p, +std::optional get_attribute(ReductionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(RepartitionAttrs const &p, +std::optional get_attribute(RepartitionAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReplicateAttrs const &p, +std::optional get_attribute(ReplicateAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReshapeAttrs const &p, +std::optional get_attribute(ReshapeAttrs const &, OperatorAttributeKey); -std::optional get_attribute(ReverseAttrs const &p, +std::optional get_attribute(ReverseAttrs const &, OperatorAttributeKey); -std::optional get_attribute(SplitAttrs const &p, +std::optional get_attribute(SplitAttrs const &, OperatorAttributeKey); -std::optional get_attribute(SoftmaxAttrs const &p, +std::optional get_attribute(SoftmaxAttrs const &, OperatorAttributeKey); -std::optional get_attribute(TopKAttrs const &p, +std::optional get_attribute(TopKAttrs const &, OperatorAttributeKey); -std::optional get_attribute(TransposeAttrs const &p, +std::optional get_attribute(TransposeAttrs const &, OperatorAttributeKey); -// optional get_attribute(FusedParallelOpAttrs const &p, +// optional get_attribute(FusedParallelOpAttrs const &, // OperatorAttributeKey); } // namespace FlexFlow diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml index ad36f1bc4b..02a856f59a 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_value.variant.toml @@ -69,5 +69,8 @@ type = "::FlexFlow::PoolOp" [[values]] type = "::FlexFlow::TensorShape" +[[values]] +type = "::FlexFlow::TensorDims" + [[values]] type = "::FlexFlow::DataType" diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index a18737085a..d5d735ef59 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -26,6 +26,18 @@ std::optional get_attribute(BatchNormAttrs const &p, } } +std::optional get_attribute(BroadcastAttrs const &p, + OperatorAttributeKey key) { + switch (key) { + case OperatorAttributeKey::OP_TYPE: + return get_op_type(p); + case OperatorAttributeKey::TARGET_DIMS: + return p.target_dims; + default: + return std::nullopt; + } +} + std::optional get_attribute(CastAttrs const &p, OperatorAttributeKey key) { switch (key) { diff --git a/lib/substitutions/test/src/substitutions/pcg_pattern.cc b/lib/substitutions/test/src/substitutions/pcg_pattern.cc index 4f56a76d0d..d9273b4bcf 100644 --- a/lib/substitutions/test/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/test/src/substitutions/pcg_pattern.cc @@ -35,7 +35,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::string a_name = "a"; parallel_tensor_guid_t a_tensor = - builder.create_input_tensor(a_shape, /*create_grad=*/true, a_name); + builder.create_input_tensor(a_shape, CreateGrad::YES, a_name); int outDim = 16; std::string x_matmul_name = "x_matmul"; @@ -65,14 +65,14 @@ TEST_SUITE(FF_TEST_SUITE) { get_parallel_layer_by_name(pcg, x_matmul_name); parallel_layer_guid_t y_matmul = get_parallel_layer_by_name(pcg, y_matmul_name); - std::vector x_inputs = - get_layer_inputs(pcg, x_matmul); - REQUIRE(x_inputs.size() == 2); - parallel_tensor_guid_t x_weights = x_inputs.at(1); - std::vector y_inputs = - get_layer_inputs(pcg, y_matmul); - REQUIRE(y_inputs.size() == 2); - parallel_tensor_guid_t y_weights = y_inputs.at(1); + std::vector x_incoming = + get_incoming_tensors(pcg, x_matmul); + REQUIRE(x_incoming.size() == 2); + parallel_tensor_guid_t x_weights = x_incoming.at(1); + std::vector y_incoming = + get_incoming_tensors(pcg, y_matmul); + REQUIRE(y_incoming.size() == 2); + parallel_tensor_guid_t y_weights = y_incoming.at(1); LabelledOpenDataflowGraph g = LabelledOpenDataflowGraph Date: Tue, 17 Sep 2024 11:50:37 -0700 Subject: [PATCH 06/12] Add Inception-v3 model (#1495) * inception v3 initial implementation * Add parallel shape inference for concat and pool2d * Format * Respond to PR comments * Fix model bugs * Update batch norm to match pytorch interface for inception v3 * Finishing touches for inception, re-add relu flag for batchnorm * Format * Document adaptive pool2d formula simplification --------- Co-authored-by: Pietro Max Marsella Co-authored-by: Colin Unger --- .../src/export_model_arch.cc | 6 +- ...ion_graph_series_parallel_decomposition.cc | 11 + lib/local-execution/src/ops/concat.cc | 4 +- .../models/inception_v3/inception_v3.h | 23 + .../inception_v3_config.struct.toml | 23 + .../inception_v3_output.struct.toml | 25 + .../src/models/inception_v3/inception_v3.cc | 750 ++++++++++++++++++ .../src/models/inception_v3/inception_v3.cc | 19 + .../models/{ => transformer}/transformer.cc | 0 .../include/op-attrs/dim_ordered/concat.h | 34 + .../op-attrs/dim_ordered/dim_ordered.h | 6 +- .../dim_ordered/ff_ordered_from_map.h | 29 + .../include/op-attrs/dim_ordered/slice.h | 7 + .../include/op-attrs/ops/batch_norm.h | 35 +- .../op-attrs/ops/batch_norm_attrs.struct.toml | 22 + lib/op-attrs/include/op-attrs/ops/concat.h | 9 +- .../op-attrs/ops/concat_attrs.struct.toml | 4 - lib/op-attrs/include/op-attrs/ops/flat.h | 8 +- .../op-attrs/ops/flat_attrs.struct.toml | 21 +- lib/op-attrs/include/op-attrs/ops/pool_2d.h | 20 +- .../op-attrs/ops/pool_2d_attrs.struct.toml | 9 +- .../parallel_tensor_dim_degrees.struct.toml | 28 + .../include/op-attrs/parallel_tensor_dims.h | 3 + .../include/op-attrs/parallel_tensor_shape.h | 6 + .../src/op-attrs/dim_ordered/concat.cc | 1 + .../dim_ordered/ff_ordered_from_map.cc | 1 + .../src/op-attrs/get_incoming_tensor_roles.cc | 5 +- .../src/op-attrs/get_output_shapes.cc | 10 +- lib/op-attrs/src/op-attrs/ops/batch_norm.cc | 255 +++++- lib/op-attrs/src/op-attrs/ops/concat.cc | 135 +++- lib/op-attrs/src/op-attrs/ops/conv_2d.cc | 8 +- lib/op-attrs/src/op-attrs/ops/flat.cc | 120 +-- lib/op-attrs/src/op-attrs/ops/layer_norm.cc | 2 +- lib/op-attrs/src/op-attrs/ops/pool_2d.cc | 202 ++++- .../src/op-attrs/parallel_tensor_dims.cc | 8 + .../src/op-attrs/parallel_tensor_shape.cc | 13 + .../op-attrs/computation_graph_op_attrs.cc | 8 +- .../test/src/op-attrs/dim_ordered/concat.cc | 66 ++ .../dim_ordered/ff_ordered_from_map.cc | 66 ++ .../src/op-attrs/get_incoming_tensor_roles.cc | 2 +- .../test/src/op-attrs/ops/batch_norm.cc | 404 ++++++++++ .../test/src/op-attrs/ops/batch_norm_attrs.cc | 7 +- lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc | 8 +- lib/op-attrs/test/src/op-attrs/ops/flat.cc | 244 ++++++ lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc | 400 ++++++++++ .../include/pcg/computation_graph_builder.h | 24 +- .../parallel_computation_graph_builder.h | 5 +- lib/pcg/src/pcg/computation_graph_builder.cc | 166 +++- .../parallel_computation_graph_builder.cc | 44 +- .../operator_attribute_key.enum.toml | 3 +- .../operator_pattern/get_attribute.cc | 12 +- .../test/src/substitutions/substitution.cc | 2 +- .../include/utils/containers/are_all_same.h | 23 + .../utils/containers/require_all_same1.h | 31 + lib/utils/include/utils/containers/subvec.h | 5 + lib/utils/include/utils/containers/sum.h | 17 + lib/utils/include/utils/optional.h | 4 +- .../src/utils/containers/are_all_same.cc | 1 + .../src/utils/containers/require_all_same1.cc | 1 + lib/utils/src/utils/containers/sum.cc | 1 + .../test/src/utils/containers/are_all_same.cc | 36 + .../src/utils/containers/require_all_same1.cc | 54 ++ lib/utils/test/src/utils/containers/sum.cc | 27 + 63 files changed, 3354 insertions(+), 169 deletions(-) create mode 100644 lib/models/include/models/inception_v3/inception_v3.h create mode 100644 lib/models/include/models/inception_v3/inception_v3_config.struct.toml create mode 100644 lib/models/include/models/inception_v3/inception_v3_output.struct.toml create mode 100644 lib/models/src/models/inception_v3/inception_v3.cc create mode 100644 lib/models/test/src/models/inception_v3/inception_v3.cc rename lib/models/test/src/models/{ => transformer}/transformer.cc (100%) create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/concat.h create mode 100644 lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/concat.cc create mode 100644 lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc create mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc create mode 100644 lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/flat.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc create mode 100644 lib/utils/include/utils/containers/are_all_same.h create mode 100644 lib/utils/include/utils/containers/require_all_same1.h create mode 100644 lib/utils/include/utils/containers/sum.h create mode 100644 lib/utils/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/src/utils/containers/require_all_same1.cc create mode 100644 lib/utils/src/utils/containers/sum.cc create mode 100644 lib/utils/test/src/utils/containers/are_all_same.cc create mode 100644 lib/utils/test/src/utils/containers/require_all_same1.cc create mode 100644 lib/utils/test/src/utils/containers/sum.cc diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index ccc720ed14..98b7a003ce 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,6 +1,7 @@ #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "op-attrs/computation_graph_op_attrs.h" @@ -59,6 +60,9 @@ tl::expected get_model_computation_graph(std::string const &model_name) { if (model_name == "transformer") { return get_default_transformer_computation_graph(); + } else if (model_name == "inception_v3") { + return get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); @@ -132,7 +136,7 @@ int main(int argc, char **argv) { "for preprocessed to help check series-parallel structure"}); std::vector model_options = { - "transformer", "split_test", "single_operator"}; + "transformer", "inception_v3", "split_test", "single_operator"}; CLIArgumentKey key_model_name = cli_add_positional_argument( cli, CLIPositionalArgumentSpec{ diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index ab537e73de..c9d84a8948 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,5 @@ #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" #include "pcg/computation_graph.h" @@ -291,6 +292,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(sp_decomposition.has_value()); } + + SUBCASE("inception_v3") { + ComputationGraph cg = get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } } } diff --git a/lib/local-execution/src/ops/concat.cc b/lib/local-execution/src/ops/concat.cc index 35f663b1cd..4c3462e694 100644 --- a/lib/local-execution/src/ops/concat.cc +++ b/lib/local-execution/src/ops/concat.cc @@ -50,7 +50,7 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { auto output = acc.get_tensor(OUTPUT); auto inputs = acc.get_variadic_tensor(INPUTS); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(inputs.size() <= MAX_NUM_INPUTS); return profile(forward_kernel, profiling, @@ -68,7 +68,7 @@ static std::optional auto input_grads = acc.get_variadic_tensor_grad(INPUTS); auto output_grad = acc.get_tensor_grad(OUTPUT); - assert(attrs.num_inputs <= MAX_NUM_INPUTS); + assert(input_grads.size() <= MAX_NUM_INPUTS); return profile(backward_kernel, profiling, diff --git a/lib/models/include/models/inception_v3/inception_v3.h b/lib/models/include/models/inception_v3/inception_v3.h new file mode 100644 index 0000000000..5c4754e441 --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_INCEPTION_V3 + +#include "models/inception_v3/inception_v3_config.dtg.h" +#include "pcg/computation_graph.dtg.h" + +namespace FlexFlow { + +/** + * @brief Get the default training config from https://arxiv.org/abs/1512.00567. + */ +InceptionV3Config get_default_inception_v3_training_config(); + +/** + * @brief Get a computation graph for Inception-v3 as described in + * https://arxiv.org/abs/1512.00567. + */ +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/inception_v3/inception_v3_config.struct.toml b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml new file mode 100644 index 0000000000..a2a75c83bb --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_config.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "InceptionV3Config" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "num_classes" +type = "int" + +[[fields]] +name = "batch_size" +type = "int" + +[[fields]] +name = "aux_logits" +type = "bool" diff --git a/lib/models/include/models/inception_v3/inception_v3_output.struct.toml b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml new file mode 100644 index 0000000000..066e6df02b --- /dev/null +++ b/lib/models/include/models/inception_v3/inception_v3_output.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "InceptionV3Output" +features = [ + "eq", + "ord", + "hash", + "fmt", +] + +includes = [ + "pcg/tensor_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "standard_logits" +type = "::FlexFlow::tensor_guid_t" + +[[fields]] +name = "aux_logits" +type = "std::optional<::FlexFlow::tensor_guid_t>" diff --git a/lib/models/src/models/inception_v3/inception_v3.cc b/lib/models/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..f540eae629 --- /dev/null +++ b/lib/models/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,750 @@ +#include "models/inception_v3/inception_v3.h" +#include "models/inception_v3/inception_v3_output.dtg.h" +#include "op-attrs/tensor_shape.h" +#include "pcg/computation_graph.h" +#include "pcg/computation_graph_builder.h" +#include "utils/integer_conversions.h" + +namespace FlexFlow { + +struct CheckShape { + CheckShape(ComputationGraphBuilder const &cgb, + InceptionV3Config const &config) + : cgb(cgb), config(config) {} + + ComputationGraphBuilder const &cgb; + InceptionV3Config const &config; + + void operator()(tensor_guid_t t, int c, int h, int w) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + size_t_from_int(h), + size_t_from_int(w), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } + + void operator()(tensor_guid_t t, int c) const { + TensorShape current_shape = cgb.get_shape(t); + TensorShape expected_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + size_t_from_int(c), + }}, + DataType::FLOAT, + }; + + if (current_shape != expected_shape) { + throw mk_runtime_error(fmt::format( + "Expected activation shape {}, but found activation shape {}", + expected_shape, + current_shape)); + } + } +}; + +InceptionV3Config get_default_inception_v3_training_config() { + return InceptionV3Config{ + /*num_classes=*/1000, + + // see section 8 of https://arxiv.org/abs/1512.00567 for the source of the + // batch size + /*batch_size=*/32, + + // see section 4 of https://arxiv.org/abs/1512.00567 for a discussion of + // auxiliary logits. they are used by default in training + /*aux_logits=*/true, + }; +} + +static tensor_guid_t create_conv_block(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int filters, + int kernel_size_h, + int kernel_size_w, + int stride_h = 1, + int stride_w = 1, + int padding_h = 0, + int padding_w = 0, + bool use_bias = false) { + tensor_guid_t conv = cgb.conv2d(input, + /*outChannels=*/filters, + /*kernelH=*/kernel_size_h, + /*kernelW=*/kernel_size_w, + /*strideH=*/stride_h, + /*strideW=*/stride_w, + /*paddingH=*/padding_h, + /*paddingW=*/padding_w, + /*activation=*/std::nullopt, + /*groups=*/1, + /*use_bias=*/use_bias); + return cgb.batch_norm(conv, + /*affine=*/true, + /*activation=*/Activation::RELU, + /*eps=*/1e-5, + /*momentum=*/0.1); +} + +static tensor_guid_t create_inception_module_a(ComputationGraphBuilder &cgb, + tensor_guid_t const &input, + int pool_features) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch5x5 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/48, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/2, + /*padding_w=*/2); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/pool_features, + /*kernel_stride_h=*/1, + /*kernel_stride_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch5x5, branch3x3dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_inception_module_b(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = create_conv_block(cgb, + input, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/96, + /*kernel_stride_h=*/3, + /*kernel_stride_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch3x3dbl, branch_pool}, /*axis=*/1); +} + +static tensor_guid_t create_inception_module_c(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + int channels_7x7) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(branch1x1, 192, 17, 17); + + tensor_guid_t branch7x7 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + return t; + }(); + check_shape(branch7x7, 192, 17, 17); + + tensor_guid_t branch7x7dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/channels_7x7, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + return t; + }(); + check_shape(branch7x7dbl, 192, 17, 17); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + check_shape(branch_pool, 192, 17, 17); + + return cgb.concat({branch1x1, branch7x7, branch7x7dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_inception_module_d(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, t, 320, 3, 3, 2, 2); + return t; + }(); + + tensor_guid_t branch7x7x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/7, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/3); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/7, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/3, + /*padding_w=*/0); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + return t; + }(); + + tensor_guid_t branch_pool = cgb.pool2d(input, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + + return cgb.concat({branch3x3, branch7x7x3, branch_pool}, /*axis=*/1); +} + +static tensor_guid_t create_inception_module_e(ComputationGraphBuilder &cgb, + tensor_guid_t const &input) { + tensor_guid_t branch1x1 = create_conv_block(cgb, + input, + /*filters=*/320, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + + tensor_guid_t branch3x3 = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch3x3dbl = [&] { + tensor_guid_t t = input; + t = create_conv_block(cgb, + t, + /*filters=*/448, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + t = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + tensor_guid_t t_1 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/1, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/1); + tensor_guid_t t_2 = create_conv_block(cgb, + t, + /*filters=*/384, + /*kernel_size_h=*/3, + /*kernel_size_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/0); + t = cgb.concat({t_1, t_2}, /*axis=*/1); + return t; + }(); + + tensor_guid_t branch_pool = [&] { + tensor_guid_t t = input; + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/1, + /*paddingW=*/1, + /*type=*/PoolOp::AVG); + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + return t; + }(); + + return cgb.concat({branch1x1, branch3x3, branch3x3dbl, branch_pool}, + /*axis=*/1); +} + +static tensor_guid_t create_initial_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input) { + tensor_guid_t t = input; + + check_shape(t, 3, 299, 299); + + // Conv2d_1a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/2, + /*stride_w=*/2); + check_shape(t, 32, 149, 149); + + // Conv2d_2a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/32, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 32, 147, 147); + + // Conv2d_2b_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/64, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/1, + /*padding_w=*/1); + check_shape(t, 64, 147, 147); + + // maxpool1 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 64, 73, 73); + + // Conv2d_3b_1x1 + t = create_conv_block(cgb, + t, + /*filters=*/80, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(t, 80, 73, 73); + + // Conv2d_4a_3x3 + t = create_conv_block(cgb, + t, + /*filters=*/192, + /*kernel_size_h=*/3, + /*kernel_size_w=*/3); + check_shape(t, 192, 71, 71); + + // maxpool2 + t = cgb.pool2d(t, + /*kernelH=*/3, + /*kernelW=*/3, + /*strideH=*/2, + /*strideW=*/2, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::MAX); + check_shape(t, 192, 35, 35); + + return t; +} + +static tensor_guid_t create_final_layers(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + // avgpool + tensor_guid_t x = cgb.pool2d(input, + /*kernelH=*/8, + /*kernelW=*/8, + /*strideH=*/1, + /*strideW=*/1, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 2048, 1, 1); + + // dropout + x = cgb.dropout(x, + /*rate=*/0.5); + check_shape(x, 2048, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 2048); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + // softmax (not in pytorch model, but shown in Table 1 on p6 of + // https://arxiv.org/abs/1512.00567) + x = cgb.softmax(x); + check_shape(x, num_classes); + + return x; +} + +static tensor_guid_t create_inception_aux(ComputationGraphBuilder &cgb, + CheckShape const &check_shape, + tensor_guid_t const &input, + size_t num_classes) { + tensor_guid_t x = input; + check_shape(x, 768, 17, 17); + + x = cgb.pool2d(x, + /*kernelH=*/5, + /*kernelW=*/5, + /*strideH=*/3, + /*strideW=*/3, + /*paddingH=*/0, + /*paddingW=*/0, + /*type=*/PoolOp::AVG); + check_shape(x, 768, 5, 5); + + // conv0 + x = create_conv_block(cgb, + x, + /*filters=*/128, + /*kernel_size_h=*/1, + /*kernel_size_w=*/1); + check_shape(x, 128, 5, 5); + + // conv1 + x = create_conv_block(cgb, + x, + /*filters=*/768, + /*kernel_size_h=*/5, + /*kernel_size_w=*/5); + check_shape(x, 768, 1, 1); + + x = cgb.adaptive_pool2d(x, + /*output_h=*/1, + /*output_w=*/1); + check_shape(x, 768, 1, 1); + + x = cgb.flat(x, + /*start_dim=*/1); + check_shape(x, 768); + + // fc + x = cgb.dense(x, + /*outDim=*/num_classes); + check_shape(x, num_classes); + + return x; +} + +static InceptionV3Output create_inception_v3(ComputationGraphBuilder &cgb, + InceptionV3Config const &config, + tensor_guid_t const &input) { + // NOTE: the shapes for check_shape (as well as the layer names in comments) + // are pulled from + // https://github.com/pytorch/vision/blob/6d7851bd5e2bedc294e40e90532f0e375fcfee04/torchvision/models/inception.py#L103-L155 + CheckShape check_shape = CheckShape{ + /*cgb=*/cgb, + /*config=*/config, + }; + + tensor_guid_t x = create_initial_layers(cgb, check_shape, input); + check_shape(x, 192, 35, 35); + + // Mixed_5b + x = create_inception_module_a(cgb, x, 32); + check_shape(x, 256, 35, 35); + + // Mixed_5c + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_5d + x = create_inception_module_a(cgb, x, 64); + check_shape(x, 288, 35, 35); + + // Mixed_6a + x = create_inception_module_b(cgb, x); + check_shape(x, 768, 17, 17); + + // Mixed_6b + x = create_inception_module_c(cgb, check_shape, x, 128); + check_shape(x, 768, 17, 17); + + // Mixed_6c + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6d + x = create_inception_module_c(cgb, check_shape, x, 160); + check_shape(x, 768, 17, 17); + + // Mixed_6e + x = create_inception_module_c(cgb, check_shape, x, 192); + check_shape(x, 768, 17, 17); + + std::optional aux; + if (config.aux_logits) { + aux = create_inception_aux(cgb, check_shape, x, config.num_classes); + check_shape(aux.value(), config.num_classes); + } + + // Mixed_7a + x = create_inception_module_d(cgb, x); + check_shape(x, 1280, 8, 8); + + // Mixed_7b + x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + // Mixed_7c + x = create_inception_module_e(cgb, x); + check_shape(x, 2048, 8, 8); + + x = create_final_layers(cgb, check_shape, x, config.num_classes); + check_shape(x, config.num_classes); + + return InceptionV3Output{ + x, + aux, + }; +} + +ComputationGraph + get_inception_v3_computation_graph(InceptionV3Config const &config) { + ComputationGraphBuilder cgb; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + size_t_from_int(config.batch_size), + 3, + 299, + 299, + }}, + DataType::FLOAT, + }; + + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES); + InceptionV3Output output = create_inception_v3(cgb, config, input); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/src/models/inception_v3/inception_v3.cc b/lib/models/test/src/models/inception_v3/inception_v3.cc new file mode 100644 index 0000000000..2b0fe82fd6 --- /dev/null +++ b/lib/models/test/src/models/inception_v3/inception_v3.cc @@ -0,0 +1,19 @@ +#include "models/inception_v3/inception_v3.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_inception_v3_computation_graph") { + InceptionV3Config config = get_default_inception_v3_training_config(); + + ComputationGraph result = get_inception_v3_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 522; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/models/test/src/models/transformer.cc b/lib/models/test/src/models/transformer/transformer.cc similarity index 100% rename from lib/models/test/src/models/transformer.cc rename to lib/models/test/src/models/transformer/transformer.cc diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/concat.h b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h new file mode 100644 index 0000000000..9b9eaf9b93 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/concat.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_CONCAT_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "utils/containers/concat_vectors.h" +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template +FFOrdered concat(FFOrdered const &l, FFOrdered const &r) { + std::vector l_vec = std::vector(l.cbegin(), l.cend()); + std::vector r_vec = std::vector(r.cbegin(), r.cend()); + + std::vector raw_result = concat_vectors(l_vec, r_vec); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +template +FFOrdered concat(std::vector> const &inputs) { + std::vector> vec_inputs = + transform(inputs, [](FFOrdered const &input) { + return std::vector(input.cbegin(), input.cend()); + }); + + std::vector raw_result = concat_vectors(vec_inputs); + + return FFOrdered(raw_result.cbegin(), raw_result.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h index 34d186e74e..6aa23d40fc 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/dim_ordered.h @@ -10,7 +10,7 @@ namespace FlexFlow { template struct DimOrdered { - DimOrdered() = delete; + DimOrdered() {} DimOrdered(std::initializer_list const &l) : contents(l.begin(), l.end()) {} @@ -138,6 +138,10 @@ struct DimOrdered { return this->contents.size(); } + size_t empty() const { + return this->contents.empty(); + } + size_t num_dims() const { return this->size(); } diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h new file mode 100644 index 0000000000..79d4929797 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/dim_ordered/ff_ordered_from_map.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H +#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_FF_ORDERED_FROM_MAP_H + +#include "op-attrs/dim_ordered/dim_ordered.h" +#include "op-attrs/dim_ordered/ff_ordered_of.h" + +namespace FlexFlow { + +template +FFOrdered ff_ordered_from_map(std::map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +template +FFOrdered ff_ordered_from_map(std::unordered_map const &m) { + std::vector raw; + for (int i = 0; i < m.size(); i++) { + raw.push_back(m.at(ff_dim_t{i})); + } + return ff_ordered_of(raw); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h index 23b971da6b..e4c0e8e275 100644 --- a/lib/op-attrs/include/op-attrs/dim_ordered/slice.h +++ b/lib/op-attrs/include/op-attrs/dim_ordered/slice.h @@ -21,6 +21,13 @@ DimOrdered nonoverloaded_slice(DimOrdered const &d, subvec(vector_of(d), to_raw_idx(start), to_raw_idx(end))}; } +template +FFOrdered slice(FFOrdered const &d, + std::optional const &start, + std::optional const &end) { + return nonoverloaded_slice(d, start, end); +} + template DimOrdered slice(DimOrdered const &d, std::optional const &start, diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm.h b/lib/op-attrs/include/op-attrs/ops/batch_norm.h index 8afcbb06b1..f2e95690d1 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm.h +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm.h @@ -1,15 +1,42 @@ #ifndef _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_OPS_BATCH_NORM_H +#include "op-attrs/incoming_tensor_role.dtg.h" #include "op-attrs/ops/batch_norm_attrs.dtg.h" #include "op-attrs/ops/core.h" -#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/tensor_shape.dtg.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &); +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &); + +tl::expected get_output_shape(BatchNormAttrs const &, + TensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, TensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, TensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_gamma_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_beta_weights_parallel_dim_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &); + +tl::expected + get_output_shape(BatchNormAttrs const &, ParallelTensorShape const &); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &, + ParallelTensorShape const &); +tl::expected + get_beta_weights_shape(BatchNormAttrs const &, ParallelTensorShape const &); CHECK_VALID_OP_ATTR(BatchNormAttrs); diff --git a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml index bc82f3c743..fdc3bce1fe 100644 --- a/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/batch_norm_attrs.struct.toml @@ -10,6 +10,28 @@ features = [ "fmt", ] +includes = [ + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", +] + [[fields]] name = "relu" type = "bool" + +[[fields]] +name = "affine" +type = "bool" + +[[fields]] +name = "eps" +type = "float" + +[[fields]] +name = "momentum" +type = "std::optional" diff --git a/lib/op-attrs/include/op-attrs/ops/concat.h b/lib/op-attrs/include/op-attrs/ops/concat.h index f3ac8494c0..f07f06df85 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat.h +++ b/lib/op-attrs/include/op-attrs/ops/concat.h @@ -10,10 +10,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(ConcatAttrs); -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, std::vector const &); +tl::expected + get_output_shape(ConcatAttrs const &, + std::vector const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml index 4faa870bc4..fab8132993 100644 --- a/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/concat_attrs.struct.toml @@ -17,7 +17,3 @@ includes = [ [[fields]] name = "axis" type = "::FlexFlow::ff_dim_t" - -[[fields]] -name = "num_inputs" -type = "int" diff --git a/lib/op-attrs/include/op-attrs/ops/flat.h b/lib/op-attrs/include/op-attrs/ops/flat.h index 676d21c59b..710cbdb44b 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat.h +++ b/lib/op-attrs/include/op-attrs/ops/flat.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/flat_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -11,8 +12,11 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(FlatAttrs); TensorShape get_output_shape(FlatAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &); +tl::expected + get_output_parallel_dim_degrees(FlatAttrs const &, + ParallelTensorDimDegrees const &); +tl::expected + get_output_shape(FlatAttrs const &, ParallelTensorShape const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml index e445535e29..7349e2a8c4 100644 --- a/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/flat_attrs.struct.toml @@ -8,4 +8,23 @@ features = [ "rapidcheck", "fmt", ] -fields = [] + +includes = [ + "", + "op-attrs/ff_dim.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", + "op-attrs/ff_dim.h", +] + +[[fields]] +name = "start_dim" +type = "::FlexFlow::ff_dim_t" + +[[fields]] +name = "end_dim" +type = "::FlexFlow::ff_dim_t" diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d.h b/lib/op-attrs/include/op-attrs/ops/pool_2d.h index 505fdd9f8c..1af22ad022 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d.h +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d.h @@ -3,6 +3,7 @@ #include "op-attrs/ops/core.h" #include "op-attrs/ops/pool_2d_attrs.dtg.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/tensor_shape.dtg.h" @@ -10,9 +11,22 @@ namespace FlexFlow { CHECK_VALID_OP_ATTR(Pool2DAttrs); -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &); -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &); +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation); + +tl::expected get_output_shape(Pool2DAttrs const &, + TensorShape const &); + +tl::expected + get_output_shape(Pool2DAttrs const &, ParallelTensorShape const &); + +tl::expected + get_output_parallel_dim_degrees(Pool2DAttrs const &, + ParallelTensorDimDegrees const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml index 56bf682f50..20ca7deabc 100644 --- a/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml +++ b/lib/op-attrs/include/op-attrs/ops/pool_2d_attrs.struct.toml @@ -12,6 +12,13 @@ features = [ includes = [ "op-attrs/pool_op.dtg.h", "op-attrs/activation.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", + "utils/rapidcheck/optional.h", ] [[fields]] @@ -44,4 +51,4 @@ type = "::FlexFlow::PoolOp" [[fields]] name = "activation" -type = "::FlexFlow::Activation" +type = "std::optional<::FlexFlow::Activation>" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml new file mode 100644 index 0000000000..974b27d2a7 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_degrees.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "ParallelTensorDimDegrees" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", + "rapidcheck", +] + +includes = [ + "op-attrs/parallel_tensor_shape/sum_degree.dtg.h", + "op-attrs/parallel_tensor_shape/discard_copy_degree.dtg.h", + "op-attrs/dim_ordered/dim_ordered.h", +] + +[[fields]] +name = "sum_degree" +type = "::FlexFlow::SumDegree" + +[[fields]] +name = "discard_copy_degree" +type = "::FlexFlow::DiscardCopyDegree" + +[[fields]] +name = "shard_degrees" +type = "::FlexFlow::FFOrdered" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 8e02e3607b..7a89b4bd78 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_OP_ATTRS_INCLUDE_OP_ATTRS_PARALLEL_TENSOR_DIMS_H #include "op-attrs/parallel_dim.h" +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_dims.dtg.h" #include "op-attrs/tensor_dims.dtg.h" @@ -14,6 +15,8 @@ std::unordered_set replica_dims(ParallelTensorDims const &); /* size_t get_volume(ParallelTensorDims const &); */ size_t num_shard_dims(ParallelTensorDims const &); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &); + int total_replica_degree(ParallelTensorDims const &); int total_shard_degree(ParallelTensorDims const &); int total_parallel_degree(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 76356b39d4..806a5f0de7 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,6 +1,7 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H +#include "op-attrs/parallel_tensor_dim_degrees.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" @@ -17,12 +18,17 @@ FFOrdered ff_ordered_shard_degrees(ParallelTensorShape const &); std::optional try_get_shard_dim_at_idx(ParallelTensorShape const &, ff_dim_t); +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &); + ParallelTensorShape lift_to_parallel(TensorShape const &); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, SumDegree sum_degree, DiscardCopyDegree discard_copy_degree, FFOrdered const &shard_degrees); +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &, + ParallelTensorDimDegrees const &); std::unordered_set replica_dims(ParallelTensorShape const &); diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..cb29f708a3 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/concat.h" diff --git a/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..2de88f38c8 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index c7febde1d6..21efc26466 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -1,5 +1,6 @@ #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" @@ -22,8 +23,8 @@ std::vector return std::vector{IncomingTensorRole::INPUT, IncomingTensorRole::INPUT}; }, - [](BatchNormAttrs const &) { - return std::vector{IncomingTensorRole::INPUT}; + [](BatchNormAttrs const &attrs) { + return get_batch_norm_incoming_tensor_roles(attrs); }, [](BroadcastAttrs const &) { return std::vector{IncomingTensorRole::INPUT}; diff --git a/lib/op-attrs/src/op-attrs/get_output_shapes.cc b/lib/op-attrs/src/op-attrs/get_output_shapes.cc index d91d1a1eca..0058ee35a2 100644 --- a/lib/op-attrs/src/op-attrs/get_output_shapes.cc +++ b/lib/op-attrs/src/op-attrs/get_output_shapes.cc @@ -14,6 +14,7 @@ #include "op-attrs/ops/input.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight.h" #include "utils/overload.h" @@ -29,7 +30,7 @@ std::vector get_output_shape(attrs, inputs.at(0), inputs.at(1)))}; }, [&](BatchNormAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](CastAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; @@ -38,7 +39,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](ConcatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs)}; + return {throw_if_unexpected(get_output_shape(attrs, inputs))}; }, [&](Conv2DAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; @@ -57,7 +58,7 @@ std::vector return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](FlatAttrs const &attrs) -> std::vector { - return {get_output_shape(attrs, inputs.at(0))}; + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, [&](GatherAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0), inputs.at(1))}; @@ -71,6 +72,9 @@ std::vector [&](LinearAttrs const &attrs) -> std::vector { return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; }, + [&](Pool2DAttrs const &attrs) -> std::vector { + return {throw_if_unexpected(get_output_shape(attrs, inputs.at(0)))}; + }, [&](ReplicateAttrs const &attrs) -> std::vector { return {get_output_shape(attrs, inputs.at(0))}; }, diff --git a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc index b75c3521c6..f394bb8473 100644 --- a/lib/op-attrs/src/op-attrs/ops/batch_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/batch_norm.cc @@ -1,15 +1,260 @@ #include "op-attrs/ops/batch_norm.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/any_of.h" +#include "utils/containers/extend.h" namespace FlexFlow { -TensorShape get_output_shape(BatchNormAttrs const &, - TensorShape const &input_shape) { +std::vector + get_batch_norm_incoming_tensor_roles(BatchNormAttrs const &attrs) { + std::vector result = {IncomingTensorRole::INPUT}; + + if (attrs.affine) { + extend(result, + std::vector{IncomingTensorRole::WEIGHT, IncomingTensorRole::WEIGHT}); + } + + return result; +} + +static std::optional + check_input_shape(BatchNormAttrs const &, TensorShape const &input_shape) { + if (num_dims(input_shape) < 2) { + return fmt::format( + "BatchNormAttrs expected input dims >= 2, but received input shape {}", + input_shape); + } + + if (input_shape.data_type != DataType::FLOAT) { + return fmt::format("BatchNormAttrs currently only supports data_type = " + "FLOAT, but received input data_type {}. " + "If you need this feature, please create an issue.", + input_shape.data_type); + } + + return std::nullopt; +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + return input_shape; } -ParallelTensorShape get_output_shape(BatchNormAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + { + std::optional maybe_err_msg = + check_input_shape(attrs, input_shape); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + + return TensorShape{ + TensorDims{FFOrdered{ + num_channels, + }}, + DataType::FLOAT, + }; +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + TensorShape const &input_shape) { + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_shape(attrs, input_shape); +} + +static std::optional + check_input_degrees(BatchNormAttrs const &, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.shard_degrees.size() < 2) { + return fmt::format("BatchNormAttrs expected input dims >= 2, but received " + "input degrees {}", + input_degrees); + } + + if (input_degrees.sum_degree != SumDegree{1}) { + return fmt::format("Expected sum degree 1, but receieved sum degree {}", + input_degrees.sum_degree); + } + + if (input_degrees.discard_copy_degree != DiscardCopyDegree{1}) { + return fmt::format( + "Expected discard copy degree 1, but receieved discard copy degree {}", + input_degrees.discard_copy_degree); + } + + FFOrdered non_channel_degrees = + concat(slice(input_degrees.shard_degrees, ff_dim_t{0}, ff_dim_t{1}), + slice(input_degrees.shard_degrees, ff_dim_t{2}, std::nullopt)); + + if (any_of(non_channel_degrees, [](int degree) { return degree != 1; })) { + return fmt::format("Expected parallel degree of all non-channel dimensions " + "to be 1, but received input with degrees {}", + input_degrees); + } + + return std::nullopt; +} + +tl::expected + get_output_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + return input_degrees; +} + +tl::expected + get_gamma_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No gamma weights exist for attrs.affine = false"); + } + + ff_dim_t channel_dim = ff_dim_t{1}; + + return ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{input_degrees.shard_degrees.at(channel_dim)}, + }; +} + +tl::expected + get_beta_weights_parallel_dim_degrees( + BatchNormAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + { + std::optional maybe_err_msg = + check_input_degrees(attrs, input_degrees); + if (maybe_err_msg.has_value()) { + return tl::unexpected(maybe_err_msg.value()); + } + } + + if (!attrs.affine) { + return tl::unexpected("No beta weights exist for attrs.affine = false"); + } + + return get_gamma_weights_parallel_dim_degrees(attrs, input_degrees); +} + +tl::expected + get_output_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_gamma_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_gamma_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_gamma_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_beta_weights_shape(BatchNormAttrs const &attrs, + ParallelTensorShape const &input_shape) { + + TensorShape unpar = ({ + tl::expected returned = + get_beta_weights_shape(attrs, get_reduced_shape(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_beta_weights_parallel_dim_degrees( + attrs, get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/concat.cc b/lib/op-attrs/src/op-attrs/ops/concat.cc index 02fee70bea..74295f279e 100644 --- a/lib/op-attrs/src/op-attrs/ops/concat.cc +++ b/lib/op-attrs/src/op-attrs/ops/concat.cc @@ -1,24 +1,129 @@ #include "op-attrs/ops/concat.h" +#include "op-attrs/dim_ordered/enumerate.h" +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/containers/all_of.h" +#include "utils/containers/are_all_same.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/fmt/map.h" namespace FlexFlow { -/* bool ConcatAttrs::is_valid( */ -/* std::vector const &input) const { */ -/* bool valid = true; */ -/* for (auto p : input) { */ -/* valid &= p.is_valid(); */ -/* } */ -/* return valid; */ -/* } */ - -TensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + auto get_non_axis_dims = [&](TensorShape const &s) { + std::map dim_sizes = enumerate(ff_ordered(s.dims)); + dim_sizes.erase(attrs.axis); + return dim_sizes; + }; + + if (inputs.size() <= 1) { + return tl::unexpected(fmt::format("get_output_shape for Concat expected 2 " + "or more input, but receieved {}", + inputs)); + } + + if (attrs.axis.value < 0) { + return tl::unexpected(fmt::format("ConcatAttrs requires axis >= 0")); + } + + if (!are_all_same(transform( + inputs, [](TensorShape const &s) { return num_dims(s); }))) { + return tl::unexpected( + fmt::format("get_output_shape for Concat expected all inputs to have " + "the same number of dimensions, but receieved {}", + inputs)); + } + + std::map non_axis_dims = ({ + tl::expected, std::string> returned = + require_all_same1(transform(inputs, get_non_axis_dims)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + std::vector axis_dim_sizes = transform( + inputs, [&](TensorShape const &s) { return dim_at_idx(s, attrs.axis); }); + + size_t output_axis_dim_size = sum(axis_dim_sizes); + + non_axis_dims.insert({attrs.axis, output_axis_dim_size}); + + DataType datatype = ({ + tl::expected returned = require_all_same1( + transform(inputs, [](TensorShape const &s) { return s.data_type; })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return TensorShape{ + TensorDims{ + ff_ordered_from_map(non_axis_dims), + }, + datatype, + }; } -ParallelTensorShape get_output_shape(ConcatAttrs const &, - std::vector const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_shape(ConcatAttrs const &attrs, + std::vector const &inputs) { + TensorShape unpar = ({ + tl::expected returned = + get_output_shape(attrs, transform(inputs, get_reduced_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + SumDegree sum_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_sum_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + SumDegree{returned.value()}; + }); + + DiscardCopyDegree discard_copy_degree = ({ + tl::expected returned = + require_all_same1(transform(inputs, get_discard_copy_degree)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + DiscardCopyDegree{returned.value()}; + }); + + if (!all_of(inputs, [&](ParallelTensorShape const &s) { + return shard_dim_at_idx(s, attrs.axis).degree == 1; + })) { + return tl::unexpected(fmt::format( + "get_output_shape for Concat expected input tensors to have parallel " + "degree 1 in the concat axis dimension, but received {}", + inputs)); + } + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + require_all_same1(transform(inputs, [](ParallelTensorShape const &s) { + return get_parallel_degrees(s); + })); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); } } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc index f77daf451f..eac756cc15 100644 --- a/lib/op-attrs/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/conv_2d.cc @@ -54,11 +54,11 @@ TensorShape get_output_shape(Conv2DAttrs const &attrs, Conv2DInputShape input = parse_input_shape(raw_input_shape); size_t out_height = - (input.height - (2 * attrs.padding_h) - (attrs.kernel_h - 1)) / - attrs.stride_h; + (input.height + (2 * attrs.padding_h) - attrs.kernel_h) / attrs.stride_h + + 1; size_t out_width = - (input.width - (2 * attrs.padding_w) - (attrs.kernel_w - 1)) / - attrs.stride_w; + (input.width + (2 * attrs.padding_w) - attrs.kernel_w) / attrs.stride_w + + 1; assert(attrs.out_channels > 0); diff --git a/lib/op-attrs/src/op-attrs/ops/flat.cc b/lib/op-attrs/src/op-attrs/ops/flat.cc index 5d318207ee..e9833d5e3f 100644 --- a/lib/op-attrs/src/op-attrs/ops/flat.cc +++ b/lib/op-attrs/src/op-attrs/ops/flat.cc @@ -1,57 +1,85 @@ #include "op-attrs/ops/flat.h" +#include "op-attrs/dim_ordered/concat.h" +#include "op-attrs/dim_ordered/slice.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "utils/containers/any_of.h" +#include "utils/containers/product.h" #include namespace FlexFlow { -TensorShape get_output_shape(FlatAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); +TensorShape get_output_shape(FlatAttrs const &attrs, + TensorShape const &input_shape) { + FFOrdered leading_dims = + slice(ff_ordered(input_shape.dims), ff_dim_t{0}, attrs.start_dim); + FFOrdered flattened_dims = + slice(ff_ordered(input_shape.dims), attrs.start_dim, attrs.end_dim); + FFOrdered trailing_dims = + slice(ff_ordered(input_shape.dims), attrs.end_dim, std::nullopt); + + if (flattened_dims.empty()) { + return input_shape; + } + + return TensorShape{ + TensorDims{ + concat(std::vector{ + leading_dims, + {product(flattened_dims)}, + trailing_dims, + }), + }, + input_shape.data_type, + }; } -ParallelTensorShape get_output_shape(FlatAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); +tl::expected + get_output_parallel_dim_degrees( + FlatAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { + FFOrdered flattened_dim_degrees = + slice(input_degrees.shard_degrees, attrs.start_dim, attrs.end_dim); + + if (flattened_dim_degrees.empty()) { + return input_degrees; + } + + if (any_of(flattened_dim_degrees, [](int degree) { return degree != 1; })) { + return tl::unexpected( + fmt::format("get_output_parallel_dim_degrees for {} expected all shard " + "degrees of flattened dimensions to be 1, but received {}", + attrs, + input_degrees)); + } + + return ParallelTensorDimDegrees{ + /*sum_degree=*/input_degrees.sum_degree, + /*discard_copy_degree=*/input_degrees.discard_copy_degree, + /*shard_degrees=*/ + concat(std::vector{ + slice(input_degrees.shard_degrees, ff_dim_t{0}, attrs.start_dim), + {product(flattened_dim_degrees)}, + slice(input_degrees.shard_degrees, attrs.end_dim, std::nullopt), + }), + }; } -// namespace Input { -// constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, -// REPLICA = 4; -// } -// -// namespace Output { -// constexpr int NUMDIM = 3, CHANNEL = 0, SAMPLE = 1, REPLICA = 2; -// } -// -/* bool FlatAttrs::is_valid(ParallelTensorShape const &input) const { */ -/* ParallelTensorShape output_shape = this->calculate_output_shape(input); */ - -/* bool is_valid = true; */ -/* is_valid &= input.is_valid(); */ -/* is_valid &= output_shape.is_valid(); */ -/* is_valid &= (input.at(Input::WIDTH).degree == 1); */ - -/* return is_valid; */ -/* } */ - -/* ParallelTensorShape FlatAttrs::calculate_output_shape(ParallelTensorShape - * const &input) const { */ -/* assert (input.num_dims() == Input::NUMDIM); */ -/* ParallelTensorShape output_dims; */ -/* output_dims.data_type = input.data_type; */ - -/* output_dims.at(Output::REPLICA) = input.at(Input::REPLICA); */ -/* output_dims.at(Output::SAMPLE) = input.at(Input::SAMPLE); */ - -/* output_dims.at(Output::CHANNEL).degree = input.at(Input::CHANNEL).degree; - */ -/* assert (input.at(Input::HEIGHT).degree == 1); */ -/* assert (input.at(Input::WIDTH).degree == 1); */ - -/* output_dims.at(Output::CHANNEL).size = input.at(Input::CHANNEL).size * - * input.at(Input::HEIGHT).size * input.at(Input::WIDTH).size; */ -/* output_dims.at(Output::CHANNEL).parallel_idx = - * input.at(Input::CHANNEL).parallel_idx; */ - -/* return output_dims; */ -/* } */ +tl::expected + get_output_shape(FlatAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = get_output_shape(attrs, get_reduced_shape(input_shape)); + + ParallelTensorDimDegrees degrees = ({ + tl::expected returned = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!returned.has_value()) { + return tl::unexpected(returned.error()); + } + returned.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} } // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc index d3c00efbb9..0dd9ac7a17 100644 --- a/lib/op-attrs/src/op-attrs/ops/layer_norm.cc +++ b/lib/op-attrs/src/op-attrs/ops/layer_norm.cc @@ -112,7 +112,7 @@ static std::optional if (get_discard_copy_degree(input_shape) != 1) { return fmt::format( - "Expected discard copy degree 1, but received discartd copy degree {}", + "Expected discard copy degree 1, but received discard copy degree {}", get_discard_copy_degree(input_shape)); } diff --git a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc index e1917efd89..95bcd8b336 100644 --- a/lib/op-attrs/src/op-attrs/ops/pool_2d.cc +++ b/lib/op-attrs/src/op-attrs/ops/pool_2d.cc @@ -1,62 +1,184 @@ #include "op-attrs/ops/pool_2d.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_dims.h" +#include "op-attrs/tensor_shape.h" +#include "utils/integer_conversions.h" namespace FlexFlow { -TensorShape get_output_shape(Pool2DAttrs const &, TensorShape const &) { - NOT_IMPLEMENTED(); -} +tl::expected + make_adaptive_pool2d_attrs(TensorDims const &input_dims, + int output_h, + int output_w, + PoolOp pool_type, + std::optional const &activation) { + // AdaptivePool2D semantics pulled from + // https://stackoverflow.com/questions/53841509/how-does-adaptive-pooling-in-pytorch-work/63603993 -ParallelTensorShape get_output_shape(Pool2DAttrs const &, - ParallelTensorShape const &) { - NOT_IMPLEMENTED(); -} + if (num_dims(input_dims) != 4) { + return tl::unexpected( + fmt::format("make_adaptive_pool2d_attrs expected input tensor to " + "have 4 dims, but received dims {}", + input_dims)); + } -} // namespace FlexFlow + size_t num_samples = dim_at_idx(input_dims, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_dims, ff_dim_t{1}); + size_t input_h = dim_at_idx(input_dims, ff_dim_t{2}); + size_t input_w = dim_at_idx(input_dims, ff_dim_t{3}); -/* -#include "op-attrs/ops/pool_2d.h" -#include "parallel_dim_mapping_record.h" -#include "parallel_dim_mapping_record_solver.h" + if (input_h % output_h != 0) { + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_h % output_h " + "== 0, but received input_h={} and output_h={} (input_dims={}). If you " + "need input_h % output_h != 0 supported, please create an issue.", + input_h, + output_h, + input_dims)); + } -namespace FlexFlow { + if (input_w % output_w != 0) { + return tl::unexpected(fmt::format( + "Currently make_adaptive_pool2d_attrs only supports input_w % output_w " + "== 0, but received input_w={} and output_w={} (input_dims={}). If you " + "need input_w % output_w != 0 supported, please create an issue.", + input_w, + output_w, + input_dims)); + } -namespace Input { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + // Note that for some reason the stack overflow post linked above states that + // `kernel_size = ind - (outd-1)*stride`, but some simplification yields + // `kernel_size` = `ind - (outd - 1)*stride` + // = `ind - (outd - 1) * (ind / outd)` + // = `ind - ind + (ind /outd)` + // = `ind / outd` + // = `stride` -namespace Output { -constexpr int NUMDIM = 5, WIDTH = 0, HEIGHT = 1, CHANNEL = 2, SAMPLE = 3, - REPLICA = 4; -}; + int kernel_h = input_h / output_h; + int kernel_w = input_w / output_w; -bool Pool2DAttrs::is_valid(ParallelTensorShape const &input) const { - ParallelTensorShape output_shape = this->calculate_output_shape(input); + int stride_h = kernel_h; + int stride_w = kernel_w; - return output_shape.is_valid() && (input.at(Input::REPLICA).degree == 1); -} + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernel_h, + /*kernel_w=*/kernel_w, + /*stride_h=*/stride_h, + /*stride_w=*/stride_w, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + + TensorShape expected_ouput_shape = TensorShape{ + TensorDims{FFOrdered{ + num_samples, + num_channels, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; -static std::vector - construct_mappings(ParallelTensorShape const &input_shape) { - auto const outputMappings = construct_output_parallel_dims({ - {Input::REPLICA, MappingOperation::PARTITION, Output::REPLICA}, - {Input::SAMPLE, MappingOperation::PARTITION, Output::SAMPLE}, - {Input::CHANNEL, MappingOperation::PARTITION, Output::CHANNEL}, - {Input::HEIGHT, MappingOperation::PARTITION, Output::HEIGHT}, - {Input::WIDTH, MappingOperation::PARTITION, Output::WIDTH}, + TensorShape output_shape = ({ + tl::expected result = + get_output_shape(attrs, TensorShape{input_dims, DataType::FLOAT}); + if (!result.has_value()) { + return tl::unexpected(result.error()); + } + result.value(); }); - return outputMappings; + if (output_shape != expected_ouput_shape) { + return tl::unexpected( + fmt::format("Result of make_adaptive_pool_2d (i.e., {}) should produce " + "expected output shape {}, but produced {}. This is a bug " + "in FlexFlow, Please create an issue.", + attrs, + expected_ouput_shape, + output_shape)); + } + + return attrs; } -static ParallelDimMappingSolution - solve_mappings(ParallelTensorShape const &input) { - return solve_parallel_dim_mappings(construct_mappings(input), {input}, 0, 1); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, TensorShape const &input_shape) { + if (num_dims(input_shape) != 4) { + return tl::unexpected( + fmt::format("get_output_shape for Pool2DAttrs expected input tensor to " + "have 4 dims, but received shape {}", + input_shape)); + } + + size_t num_samples = dim_at_idx(input_shape, ff_dim_t{0}); + size_t num_channels = dim_at_idx(input_shape, ff_dim_t{1}); + size_t input_height = dim_at_idx(input_shape, ff_dim_t{2}); + size_t input_width = dim_at_idx(input_shape, ff_dim_t{3}); + + size_t output_height = + (input_height + 2 * attrs.padding_h - attrs.kernel_h) / attrs.stride_h + + 1; + + size_t output_width = + (input_width + 2 * attrs.padding_w - attrs.kernel_w) / attrs.stride_w + 1; + + return TensorShape{TensorDims{FFOrdered{ + num_samples, + num_channels, + output_height, + output_width, + }}, + input_shape.data_type}; } -ParallelTensorShape Pool2DAttrs::calculate_output_shape(ParallelTensorShape -const &input) const { return solve_mappings(input).output_shapes.at(0); +tl::expected + get_output_shape(Pool2DAttrs const &attrs, + ParallelTensorShape const &input_shape) { + TensorShape unpar = ({ + tl::expected result_unpar = + get_output_shape(attrs, get_reduced_shape(input_shape)); + if (!result_unpar.has_value()) { + return tl::unexpected(result_unpar.error()); + } + result_unpar.value(); + }); + + ParallelTensorDimDegrees degrees = ({ + tl::expected result_degrees = + get_output_parallel_dim_degrees(attrs, + get_parallel_degrees(input_shape)); + if (!result_degrees.has_value()) { + return tl::unexpected(result_degrees.error()); + } + result_degrees.value(); + }); + + return lift_to_parallel_with_degrees(unpar, degrees); +} + +tl::expected + get_output_parallel_dim_degrees( + Pool2DAttrs const &attrs, + ParallelTensorDimDegrees const &input_degrees) { + if (input_degrees.sum_degree.value > 1) { + if (attrs.pool_type == PoolOp::MAX) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with PoolOp::MAX " + "expected input sum degree == 1, but received {}", + input_degrees)); + } else if (attrs.activation.has_value()) { + return tl::unexpected(fmt::format( + "get_output_parallel_dim_degrees for Pool2DAttrs with activation={} " + "expected input sum degree == 1, but received {}", + attrs.activation.value(), + input_degrees)); + } + } + + return input_degrees; } } // namespace FlexFlow -*/ diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 4bce5449f4..61062b84b0 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -29,6 +29,14 @@ size_t num_shard_dims(ParallelTensorDims const &dims) { return dims.shard_dims.size(); } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { + return ParallelTensorDimDegrees{ + d.replica_dims.sum_degree, + d.replica_dims.discard_copy_degree, + ff_ordered_shard_degrees(d), + }; +} + int total_replica_degree(ParallelTensorDims const &dims) { return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 10bf5027a4..3cd0f47a5d 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -59,6 +59,10 @@ std::optional } } +ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &s) { + return get_parallel_degrees(s.dims); +} + ParallelTensorShape lift_to_parallel(TensorShape const &s) { return ParallelTensorShape{lift_to_parallel(s.dims), s.data_type}; } @@ -75,6 +79,15 @@ ParallelTensorShape }; } +ParallelTensorShape + lift_to_parallel_with_degrees(TensorShape const &s, + ParallelTensorDimDegrees const °rees) { + return lift_to_parallel_with_degrees(s, + degrees.sum_degree, + degrees.discard_copy_degree, + degrees.shard_degrees); +} + TensorShape require_not_parallel(ParallelTensorShape const &s) { int total_degree = get_total_parallel_degree(s); if (total_degree != 1) { diff --git a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc index 42ea07e6b5..84f1861f0b 100644 --- a/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/computation_graph_op_attrs.cc @@ -5,8 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("ComputationGraphOpAttrs to/from json") { - ComputationGraphOpAttrs correct = - ComputationGraphOpAttrs{BatchNormAttrs{true}}; + ComputationGraphOpAttrs correct = ComputationGraphOpAttrs{BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }}; nlohmann::json j = correct; auto result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc new file mode 100644 index 0000000000..2ac641cfc2 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/concat.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/concat.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat(FFOrdered, FFOrdered)") { + SUBCASE("inputs have elements") { + FFOrdered l_input = FFOrdered{1, 3, 1}; + FFOrdered r_input = FFOrdered{2, 1}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {1, 3, 1, 2, 1}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + FFOrdered l_input = FFOrdered{}; + FFOrdered r_input = FFOrdered{}; + + FFOrdered result = concat(l_input, r_input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("concat(std::vector>)") { + SUBCASE("inputs have elements") { + std::vector> input = { + {1}, + {2, 1}, + {1}, + }; + + FFOrdered result = concat(input); + FFOrdered correct = { + 1, + 2, + 1, + 1, + }; + + CHECK(result == correct); + } + + SUBCASE("no inputs") { + std::vector> input = {}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("inputs are empty") { + std::vector> input = {{}, {}, {}}; + + FFOrdered result = concat(input); + FFOrdered correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc new file mode 100644 index 0000000000..7bc1695e5c --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/dim_ordered/ff_ordered_from_map.cc @@ -0,0 +1,66 @@ +#include "op-attrs/dim_ordered/ff_ordered_from_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("ff_ordered_from_map", + T, + std::map, + std::unordered_map) { + SUBCASE("input is empty") { + T m = {}; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is missing keys") { + SUBCASE("missing key is in middle") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 2}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("missing key is 0 idx") { + T m = { + {ff_dim_t{1}, 2}, + {ff_dim_t{2}, 7}, + {ff_dim_t{3}, 5}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + } + + SUBCASE("input has negative keys") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{-1}, 2}, + }; + + CHECK_THROWS(ff_ordered_from_map(m)); + } + + SUBCASE("input is valid") { + T m = { + {ff_dim_t{0}, 4}, + {ff_dim_t{1}, 5}, + {ff_dim_t{2}, 2}, + {ff_dim_t{3}, 7}, + }; + + FFOrdered result = ff_ordered_from_map(m); + FFOrdered correct = {4, 5, 2, 7}; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc index 60dedfe70a..33cc00c6a1 100644 --- a/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/test/src/op-attrs/get_incoming_tensor_roles.cc @@ -9,7 +9,7 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("Concat") { int num_incoming = 4; ComputationGraphOpAttrs attrs = - ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}, num_incoming}}; + ComputationGraphOpAttrs{ConcatAttrs{ff_dim_t{0}}}; std::vector result = get_incoming_tensor_roles(attrs, num_incoming); diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc new file mode 100644 index 0000000000..4196394d00 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm.cc @@ -0,0 +1,404 @@ +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_batch_norm_incoming_tensor_roles(BatchNormAttrs)") { + auto make_attrs = [](bool affine) { + return BatchNormAttrs{ + /*relu=*/false, + /*affine=*/affine, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + }; + + SUBCASE("affine = true") { + BatchNormAttrs attrs = make_attrs(/*affine=*/true); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + IncomingTensorRole::WEIGHT, + IncomingTensorRole::WEIGHT, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + BatchNormAttrs attrs = make_attrs(/*affine=*/false); + + std::vector result = + get_batch_norm_incoming_tensor_roles(attrs); + std::vector correct = { + IncomingTensorRole::INPUT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("shape inference (BatchNorm)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 12, + 14, + 16, + 18, + }}, + DataType::FLOAT, + }; + + TensorShape output = input; + + TensorShape gamma = TensorShape{ + TensorDims{FFOrdered{ + 14, + }}, + DataType::FLOAT, + }; + + TensorShape beta = gamma; + + SUBCASE("get_output_shape(BatchNormAttrs, TensorShape)") { + tl::expected result = + get_output_shape(attrs_affine_true, input); + tl::expected correct = output; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_shape(attrs_affine_true, input); + tl::expected correct = gamma; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_gamma_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, TensorShape)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_shape(attrs_affine_true, input); + tl::expected correct = beta; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = optional_from_expected( + get_beta_weights_shape(attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel dim degree inference (BatchNormAttrs)") { + BatchNormAttrs attrs_affine_true = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + BatchNormAttrs attrs_affine_false = [&] { + BatchNormAttrs attrs = attrs_affine_true; + attrs.affine = false; + return attrs; + }(); + + SUBCASE("partition parallelism (in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + degree, + 1, + 1, + }, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + tl::expected result = + get_output_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_gamma_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + SUBCASE("affine = true") { + tl::expected result = + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{degree}, + }; + + CHECK(result == correct); + } + + SUBCASE("affine = false") { + std::optional result = + optional_from_expected(get_beta_weights_parallel_dim_degrees( + attrs_affine_false, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + SUBCASE("partition parallelism (not in channel dim)") { + int degree = 2; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, degree, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("sum parallelism") { + SumDegree sum_degree = SumDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + sum_degree, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy parallelism") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + discard_copy_degree, + FFOrdered{1, 1, 1, 1}, + }; + + SUBCASE("get_output_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_gamma_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_parallel_dim_degrees(BatchNormAttrs, " + "ParallelTensorDimDegrees)") { + std::optional result = optional_from_expected( + get_beta_weights_parallel_dim_degrees(attrs_affine_true, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel shape inference (BatchNormAttrs)") { + // since most of the edge cases are already tested in the above test cases + // (i.e., shape inference and parallel degree inference) + // here we just do a basic check that they compose + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/true, + /*affine=*/true, + /*eps=*/1.0, + /*momentum=*/0.1, + }; + + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 1}, + ShardParallelDim{14, 2}, + ShardParallelDim{16, 1}, + ShardParallelDim{18, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + SUBCASE("get_output_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("get_gamma_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_gamma_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("get_beta_weights_shape(BatchNormAttrs, ParallelTensorShape)") { + tl::expected result = + get_beta_weights_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc index df436da66c..3d86576279 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/batch_norm_attrs.cc @@ -5,7 +5,12 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("BatchNormAttrs to/from json") { - BatchNormAttrs correct = BatchNormAttrs{true}; + BatchNormAttrs correct = BatchNormAttrs{ + /*relu=*/false, + /*affine=*/true, + /*eps=*/1e-5, + /*momentum=*/0.1, + }; nlohmann::json j = correct; BatchNormAttrs result = j.get(); diff --git a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc index 152df09eca..7abb98f3e3 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/conv_2d.cc @@ -73,8 +73,8 @@ TEST_SUITE(FF_TEST_SUITE) { }; size_t num_samples = 7; - size_t input_channels = 6; - size_t input_height = 10; + size_t input_channels = 4; + size_t input_height = 11; size_t input_width = 15; TensorShape input = TensorShape{ @@ -87,8 +87,8 @@ TEST_SUITE(FF_TEST_SUITE) { DataType::FLOAT, }; - size_t output_height = 3; - size_t output_width = 6; + size_t output_height = 6; + size_t output_width = 8; TensorShape output = TensorShape{ TensorDims{FFOrdered{ diff --git a/lib/op-attrs/test/src/op-attrs/ops/flat.cc b/lib/op-attrs/test/src/op-attrs/ops/flat.cc new file mode 100644 index 0000000000..d81ab95c35 --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/flat.cc @@ -0,0 +1,244 @@ +#include "op-attrs/ops/flat.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(FlatAttrs, TensorShape)") { + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + SUBCASE("flatten all dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4 * 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten trailing dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{4}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4, + 2 * 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten leading dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{0}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2 * 4, + 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten middle dims") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = TensorShape{ + TensorDims{FFOrdered{ + 2, + 4 * 2, + 3, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim == end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{2}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + + SUBCASE("flatten no dims (start_dim < end_dim)") { + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{2}, + /*end_dim=*/ff_dim_t{1}, + }; + + TensorShape result = get_output_shape(attrs, input_shape); + TensorShape correct = input_shape; + + CHECK(result == correct); + } + } + + TEST_CASE( + "get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees)") { + FlatAttrs attrs = FlatAttrs{/*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}}; + + SUBCASE("allows shard parallelism in non-flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 1, 3}, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{2, 1, 3}, + }; + + CHECK(result == correct); + } + + SUBCASE("does not allow shard parallelism in flattened dims") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 2, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("allows sum parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1, 1}, + }; + + std::optional result = + optional_from_expected(get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = + ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{2}, + FFOrdered{1, 1, 1}, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(FlatAttrs, ParallelTensorShape)") { + // since most of the edge cases are already tested in + // get_output_shape(FlatAttrs, TensorShape) and + // get_output_parallel_dim_degrees(FlatAttrs, ParallelTensorDimDegrees), + // here we just do a basic check that they compose + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8, 1}, + ShardParallelDim{6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{1}, + /*end_dim=*/ff_dim_t{3}, + }; + + tl::expected result = + get_output_shape(attrs, input_shape); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{4, 2}, + ShardParallelDim{8 * 6, 1}, + ShardParallelDim{9, 3}, + }, + ReplicaParallelDimSet{ + SumDegree{7}, + DiscardCopyDegree{5}, + }, + }, + DataType::FLOAT, + }; + + CHECK(result == correct); + } +} diff --git a/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc new file mode 100644 index 0000000000..0c14c0fc2a --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/pool_2d.cc @@ -0,0 +1,400 @@ +#include "op-attrs/ops/pool_2d.h" +#include "utils/expected.h" +#include "utils/fmt/expected.h" +#include "utils/fmt/optional.h" +#include "utils/integer_conversions.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("make_adaptive_pool2d") { + size_t input_n = 10; + size_t input_c = 11; + size_t input_h = 15; + size_t input_w = 20; + Activation activation = Activation::RELU; + PoolOp op = PoolOp::AVG; + + TensorDims input_dims = + TensorDims{FFOrdered{input_n, input_c, input_h, input_w}}; + + SUBCASE("input_h divisible by output_h && input_w divisible by output_w") { + int output_h = 5; + int output_w = 2; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/10, + /*stride_h=*/3, + /*stride_w=*/10, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + input_n, + input_c, + size_t_from_int(output_h), + size_t_from_int(output_w), + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + SUBCASE("input_h not divisible by output_h") { + int output_h = 6; + int output_w = 2; + + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_w not divisible by output_w") { + int output_h = 5; + int output_w = 3; + + std::optional result = + optional_from_expected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input_h == output_h and input_w == output_w") { + int output_h = input_h; + int output_w = input_w; + + Pool2DAttrs correct_attrs = Pool2DAttrs{ + /*kernel_h=*/1, + /*kernel_w=*/1, + /*stride_h=*/1, + /*stride_w=*/1, + /*padding_h=*/0, + /*padding_w=*/0, + /*pool_type=*/op, + /*activation=*/activation, + }; + + SUBCASE("returns correct attrs") { + tl::expected result = + make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, op, activation); + tl::expected correct = correct_attrs; + + CHECK(result == correct); + } + + SUBCASE( + "confirm that output shape is as expected for the expected attrs") { + TensorShape input_shape = TensorShape{input_dims, DataType::FLOAT}; + + tl::expected result = + get_output_shape(correct_attrs, input_shape); + tl::expected correct = input_shape; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, TensorShape)") { + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("fails on non-4d inputs") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{ + 10, + 12, + 14, + }}, + DataType::FLOAT, + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("4d input") { + TensorShape input = TensorShape{ + TensorDims{FFOrdered{11, 13, 12, 6}}, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{11, 13, 6, 4}}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_parallel_dim_degrees(Pool2DAttrs, " + "ParallelTensorDimDegrees)") { + auto make_attrs = [](PoolOp pool_type, + std::optional const &activation) { + return Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/pool_type, + /*activation=*/activation, + }; + }; + + SUBCASE("allows data parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows arbitrary input sharding parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{1}, + FFOrdered{ + 4, + 2, + 5, + 6, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("allows discard copy parallelism") { + Pool2DAttrs attrs = make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{1}, + DiscardCopyDegree{3}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + + SUBCASE("sum parallelism") { + SUBCASE("without activation") { + SUBCASE("PoolOp::MAX does not allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::MAX, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + std::optional result = + optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("PoolOp::AVG does allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/std::nullopt); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + tl::expected result = + get_output_parallel_dim_degrees(attrs, input); + tl::expected correct = input; + + CHECK(result == correct); + } + } + + SUBCASE("with activation does not allow sum parallelism") { + Pool2DAttrs attrs = + make_attrs(PoolOp::AVG, /*activation=*/Activation::RELU); + + ParallelTensorDimDegrees input = ParallelTensorDimDegrees{ + SumDegree{2}, + DiscardCopyDegree{1}, + FFOrdered{ + 1, + 1, + 1, + 1, + }, + }; + + std::optional result = optional_from_expected( + get_output_parallel_dim_degrees(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + } + + TEST_CASE("get_output_shape(Pool2DAttrs, ParallelTensorShape)") { + // this function is mostly covered by the tests above, so we + // just do a single test to make sure it works/exists + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/3, + /*kernel_w=*/2, + /*stride_h=*/2, + /*stride_w=*/2, + /*padding_h=*/1, + /*padding_w=*/1, + /*pool_type=*/PoolOp::MAX, + /*activation=*/std::nullopt, + }; + + SUBCASE("valid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{12, 3}, + ShardParallelDim{6, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + tl::expected result = + get_output_shape(attrs, input); + tl::expected correct = + ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 7}, + ShardParallelDim{16, 8}, + ShardParallelDim{6, 3}, + ShardParallelDim{4, 2}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + } + + SUBCASE("invalid parallelism") { + ParallelTensorShape input = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{14, 1}, + ShardParallelDim{16, 1}, + ShardParallelDim{12, 1}, + ShardParallelDim{6, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/include/pcg/computation_graph_builder.h b/lib/pcg/include/pcg/computation_graph_builder.h index 11e591545d..45cde0de57 100644 --- a/lib/pcg/include/pcg/computation_graph_builder.h +++ b/lib/pcg/include/pcg/computation_graph_builder.h @@ -137,6 +137,13 @@ struct ComputationGraphBuilder { PoolOp type = PoolOp::MAX, std::optional const &activation = std::nullopt, std::optional const &name = std::nullopt); + tensor_guid_t adaptive_pool2d( + tensor_guid_t const &input, + int output_h, + int output_w, + PoolOp type = PoolOp::MAX, + std::optional const &activation = std::nullopt, + std::optional const &name = std::nullopt); tensor_guid_t layer_norm(tensor_guid_t const &input, std::vector const &axes, @@ -145,7 +152,10 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); tensor_guid_t batch_norm(tensor_guid_t const &input, - bool relu = true, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); tensor_guid_t batch_matmul(tensor_guid_t const &A, @@ -170,11 +180,9 @@ struct ComputationGraphBuilder { DataType dtype, std::optional const &name = std::nullopt); // Add a concat layer - tensor_guid_t - concat(int n, - std::vector const &tensors, - int axis, - std::optional const &maybe_name = std::nullopt); + tensor_guid_t concat(std::vector const &tensors, + int axis, + std::optional const &name = std::nullopt); // Add a mean layer tensor_guid_t mean(tensor_guid_t const &input, std::vector const &dims, @@ -188,6 +196,8 @@ struct ComputationGraphBuilder { std::optional const &name = std::nullopt); // Add a flat layer tensor_guid_t flat(tensor_guid_t const &input, + int start_dim = 0, + std::optional const &end_dim = std::nullopt, std::optional const &name = std::nullopt); // Add a softmax layer tensor_guid_t softmax(tensor_guid_t const &input, @@ -252,9 +262,9 @@ struct ComputationGraphBuilder { std::vector const &weights, std::vector const &outputs); -private: TensorShape get_shape(tensor_guid_t const &) const; +private: tensor_guid_t broadcast(tensor_guid_t const &, TensorDims const &, std::string const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h index 3a7f67dcf0..019b120936 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph_builder.h @@ -87,7 +87,10 @@ struct ParallelComputationGraphBuilder { parallel_tensor_guid_t batch_norm(parallel_tensor_guid_t const &input, - bool relu = true, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &name = std::nullopt); parallel_tensor_guid_t diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index a4f61cff98..4a565476bd 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -6,14 +6,17 @@ #include "op-attrs/ops/attention.h" #include "op-attrs/ops/batch_norm.h" #include "op-attrs/ops/broadcast.h" +#include "op-attrs/ops/concat.h" #include "op-attrs/ops/conv_2d.h" #include "op-attrs/ops/dropout.h" #include "op-attrs/ops/element_binary.h" #include "op-attrs/ops/element_unary.h" #include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/flat.h" #include "op-attrs/ops/gather.h" #include "op-attrs/ops/layer_norm.h" #include "op-attrs/ops/linear.h" +#include "op-attrs/ops/pool_2d.h" #include "op-attrs/ops/softmax.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/tensor_dims.h" @@ -498,21 +501,130 @@ tensor_guid_t ComputationGraphBuilder::gather( return get_only( this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); } +tensor_guid_t ComputationGraphBuilder::pool2d( + tensor_guid_t const &x, + int kernelH, + int kernelW, + int strideH, + int strideW, + int paddingH, + int paddingW, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + Pool2DAttrs attrs = Pool2DAttrs{ + /*kernel_h=*/kernelH, + /*kernel_w=*/kernelW, + /*stride_h=*/strideH, + /*stride_w=*/strideW, + /*padding_h=*/paddingH, + /*padding_w=*/paddingW, + /*pool_type=*/type, + /*activation=*/activation, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t input = + this->as_type(x, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, this->get_shape(input))); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t ComputationGraphBuilder::adaptive_pool2d( + tensor_guid_t const &uncasted_input, + int output_h, + int output_w, + PoolOp type, + std::optional const &activation, + std::optional const &maybe_name) { + + TensorDims input_dims = this->get_shape(uncasted_input).dims; + + Pool2DAttrs attrs = throw_if_unexpected(make_adaptive_pool2d_attrs( + input_dims, output_h, output_w, type, activation)); + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + tensor_guid_t casted_input = + this->as_type(uncasted_input, DataType::FLOAT, name + "input_pre_cast"); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = throw_if_unexpected( + get_output_shape(attrs, this->get_shape(casted_input))); + + return get_only(this->add_layer( + layer, {casted_input}, {}, {make_output_attrs(output_shape)})); +} tensor_guid_t ComputationGraphBuilder::batch_norm( tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. " + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; + std::string name = maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; - TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + TensorShape input_shape = this->get_shape(input); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shape)); - return get_only( - this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); + std::vector weights; + + if (affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + TensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + TensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } + + return get_only(this->add_layer( + layer, + {input}, + transform(weights, + [&](TensorAttrs const &a) { return this->create_weight(a); }), + {make_output_attrs(output_shape)})); } tensor_guid_t ComputationGraphBuilder::multihead_attention( @@ -674,6 +786,50 @@ tensor_guid_t ComputationGraphBuilder::dense( layer, {input}, weights, {make_output_attrs(output_shape)})); } +tensor_guid_t ComputationGraphBuilder::concat( + std::vector const &inputs, + int axis, + std::optional const &maybe_name) { + + ConcatAttrs attrs = ConcatAttrs{ff_dim_t{axis}}; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + std::vector input_shapes = transform( + inputs, [&](tensor_guid_t const &i) { return this->get_shape(i); }); + TensorShape output_shape = + throw_if_unexpected(get_output_shape(attrs, input_shapes)); + + return get_only( + this->add_layer(layer, inputs, {}, {make_output_attrs(output_shape)})); +} + +tensor_guid_t ComputationGraphBuilder::flat( + tensor_guid_t const &input, + int start_dim, + std::optional const &end_dim, + std::optional const &maybe_name) { + int input_num_dims = num_dims(this->get_shape(input)); + + FlatAttrs attrs = FlatAttrs{ + /*start_dim=*/ff_dim_t{start_dim}, + /*end_dim=*/ff_dim_t{end_dim.value_or(input_num_dims)}, + }; + + std::string name = + maybe_name.value_or(get_default_name(ComputationGraphOpAttrs{attrs})); + + LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; + + TensorShape output_shape = get_output_shape(attrs, this->get_shape(input)); + + return get_only( + this->add_layer(layer, {input}, {}, {make_output_attrs(output_shape)})); +} + tensor_guid_t ComputationGraphBuilder::layer_norm( tensor_guid_t const &input, std::vector const &axes, diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 620dc035fc..ce00ea62f4 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -331,18 +331,56 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::multihead_attention( parallel_tensor_guid_t ParallelComputationGraphBuilder::batch_norm( parallel_tensor_guid_t const &input, - bool relu, + bool affine, + std::optional const &activation, + float eps, + std::optional const &momentum, std::optional const &maybe_name) { - BatchNormAttrs attrs = BatchNormAttrs{relu}; + if (activation.has_value() && activation.value() != Activation::RELU) { + throw mk_runtime_error(fmt::format( + "batch_norm currently only supports (1) no activation function, or (2) " + "relu activation function, but received {}. " + "If you need support for additional activation functions, please " + "create an issue.", + activation)); + } + + BatchNormAttrs attrs = BatchNormAttrs{ + /*relu=*/activation.has_value(), + /*affine=*/affine, + /*eps=*/eps, + /*momentum=*/momentum, + }; std::string name = maybe_name.value_or(get_default_name(PCGOperatorAttrs{attrs})); ParallelLayerAttrs layer = ParallelLayerAttrs{PCGOperatorAttrs{attrs}, name}; + ParallelTensorShape input_shape = this->get_shape(input); + ParallelTensorShape output_shape = - get_output_shape(attrs, this->get_shape(input)); + throw_if_unexpected(get_output_shape(attrs, input_shape)); + + std::vector weights; + + if (attrs.affine) { + // initializers chosen to match those of + // https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html + + ParallelTensorShape gamma_shape = + throw_if_unexpected(get_gamma_weights_shape(attrs, input_shape)); + InitializerAttrs gamma_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{1}}}}; + weights.push_back(make_weight_attrs(gamma_shape, gamma_initializer)); + + ParallelTensorShape beta_shape = + throw_if_unexpected(get_beta_weights_shape(attrs, input_shape)); + InitializerAttrs beta_initializer = + InitializerAttrs{ConstantInitializerAttrs{DataTypeValue{float{0}}}}; + weights.push_back(make_weight_attrs(beta_shape, beta_initializer)); + } return this->add_layer(layer, {input}, {}, {output_shape}); } diff --git a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml index 59e913750e..eb758ea4fc 100644 --- a/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml +++ b/lib/substitutions/include/substitutions/operator_pattern/operator_attribute_key.enum.toml @@ -55,7 +55,8 @@ values = [ { name = "SHOULD_BROADCAST_LHS" }, { name = "SHOULD_BROADCAST_RHS" }, { name = "DIM" }, - { name = "ELEMENTWISE_AFFINE" }, + { name = "AFFINE" }, + { name = "MOMENTUM" }, { name = "REGULARIZER" }, { name = "SHAPE" }, { name = "SPLITS" }, diff --git a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc index d5d735ef59..442d3345a1 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/get_attribute.cc @@ -19,8 +19,12 @@ std::optional get_attribute(BatchNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); - case OperatorAttributeKey::RELU: - return p.relu; + case OperatorAttributeKey::EPSILON: + return p.eps; + case OperatorAttributeKey::AFFINE: + return p.affine; + case OperatorAttributeKey::MOMENTUM: + return p.momentum; default: return std::nullopt; } @@ -189,6 +193,10 @@ std::optional get_attribute(LayerNormAttrs const &p, switch (key) { case OperatorAttributeKey::OP_TYPE: return get_op_type(p); + case OperatorAttributeKey::AFFINE: + return p.elementwise_affine; + case OperatorAttributeKey::AXES: + return vector_of(p.axes); default: return std::nullopt; } diff --git a/lib/substitutions/test/src/substitutions/substitution.cc b/lib/substitutions/test/src/substitutions/substitution.cc index 87ffc01f0b..1718b03b5c 100644 --- a/lib/substitutions/test/src/substitutions/substitution.cc +++ b/lib/substitutions/test/src/substitutions/substitution.cc @@ -21,7 +21,7 @@ TEST_SUITE(FF_TEST_SUITE) { // } TEST_CASE("evaluate_substitution_output(SubParallelComputationGraph, " - "Substituion, PCGPatternMatch)") { + "Substitution, PCGPatternMatch)") { // Currently Substitution creation is very verbose. // This is being addressed in // https://github.com/flexflow/FlexFlow/issues/1473. diff --git a/lib/utils/include/utils/containers/are_all_same.h b/lib/utils/include/utils/containers/are_all_same.h new file mode 100644 index 0000000000..37b1838146 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_same.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_SAME_H + +namespace FlexFlow { + +template +bool are_all_same(C const &c) { + if (c.empty()) { + return true; + } + + auto const &first = *c.cbegin(); + for (auto const &v : c) { + if (v != first) { + return false; + } + } + return true; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_same1.h b/lib/utils/include/utils/containers/require_all_same1.h new file mode 100644 index 0000000000..2f42243857 --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_same1.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_SAME1_H + +#include +#include + +namespace FlexFlow { + +template +tl::expected require_all_same1(C const &c) { + if (c.empty()) { + return tl::unexpected(fmt::format( + "require_all_same1 expected non-empty container, but received {}", c)); + } + + T const &first = *c.cbegin(); + for (T const &v : c) { + if (v != first) { + return tl::unexpected(fmt::format("require_all_same1 found non-same " + "elements {} and {} in containers {}", + first, + v, + c)); + } + } + return first; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 52368f94ad..5ae90ec5ba 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -25,10 +25,15 @@ std::vector subvec(std::vector const &v, if (maybe_start.has_value()) { begin_iter += resolve_loc(maybe_start.value()); } + if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } + if (end_iter < begin_iter) { + end_iter = begin_iter; + } + std::vector output(begin_iter, end_iter); return output; } diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h new file mode 100644 index 0000000000..5dbd620781 --- /dev/null +++ b/lib/utils/include/utils/containers/sum.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_H + +namespace FlexFlow { + +template +T sum(C const &c) { + T result = 0; + for (T const &t : c) { + result += t; + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 3ec165d595..377561d70c 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_OPTIONAL_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_OPTIONAL_H #include "utils/exception.h" #include "utils/fmt/optional.h" diff --git a/lib/utils/src/utils/containers/are_all_same.cc b/lib/utils/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..c515bceee2 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_same.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_same.h" diff --git a/lib/utils/src/utils/containers/require_all_same1.cc b/lib/utils/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..295339a91d --- /dev/null +++ b/lib/utils/src/utils/containers/require_all_same1.cc @@ -0,0 +1 @@ +#include "utils/containers/require_all_same1.h" diff --git a/lib/utils/src/utils/containers/sum.cc b/lib/utils/src/utils/containers/sum.cc new file mode 100644 index 0000000000..088b5f1983 --- /dev/null +++ b/lib/utils/src/utils/containers/sum.cc @@ -0,0 +1 @@ +#include "utils/containers/sum.h" diff --git a/lib/utils/test/src/utils/containers/are_all_same.cc b/lib/utils/test/src/utils/containers/are_all_same.cc new file mode 100644 index 0000000000..fd8b321439 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_same.cc @@ -0,0 +1,36 @@ +#include "utils/containers/are_all_same.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_same(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("input elements are all same") { + std::vector input = {1, 1, 1}; + + bool result = are_all_same(input); + bool correct = true; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all same") { + std::vector input = {1, 1, 2, 1}; + + bool result = are_all_same(input); + bool correct = false; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/require_all_same1.cc b/lib/utils/test/src/utils/containers/require_all_same1.cc new file mode 100644 index 0000000000..48c1ab0b99 --- /dev/null +++ b/lib/utils/test/src/utils/containers/require_all_same1.cc @@ -0,0 +1,54 @@ +#include "utils/containers/require_all_same1.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/multiset.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/expected.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("require_all_same1(T)", + T, + std::vector, + std::unordered_set, + std::unordered_multiset, + std::set, + std::multiset) { + SUBCASE("input is empty") { + T input = {}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input elements are all the same") { + T input = {1, 1, 1}; + + tl::expected result = require_all_same1(input); + tl::expected correct = 1; + + CHECK(result == correct); + } + + SUBCASE("input elements are not all the same") { + T input = {1, 1, 2, 1}; + + std::optional result = + optional_from_expected(require_all_same1(input)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/sum.cc b/lib/utils/test/src/utils/containers/sum.cc new file mode 100644 index 0000000000..32d8cd32a3 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sum.cc @@ -0,0 +1,27 @@ +#include "utils/containers/sum.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sum(std::vector)") { + SUBCASE("input is empty") { + std::vector input = {}; + + int result = sum(input); + int correct = 0; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::vector input = {1, 3, 2}; + + int result = sum(input); + int correct = 6; + + CHECK(result == correct); + } + } +} From dbb642a970ec74f73d0666b4ca5855a5e5b2a184 Mon Sep 17 00:00:00 2001 From: hsdfzhsdfz <157999659+hsdfzhsdfz@users.noreply.github.com> Date: Mon, 23 Sep 2024 15:39:50 -0400 Subject: [PATCH 07/12] Add Candle Uno Model PCG (#1479) * Add initial Candle Uno model PCG * Adapt candle uno model to existing code style * Change to use dtgen * Add shape inference and ComputationGraphBuilder support for Concat * Update based on review * Update following review * Add glorot normal initializer * Update * Update * Fix bugs created by merge, add candle_uno to export-model-arch * Format --------- Co-authored-by: Colin Unger --- .github/workflows/per-lib-check.yml | 3 + .../src/export_model_arch.cc | 10 +- ...ion_graph_series_parallel_decomposition.cc | 11 + .../include/models/candle_uno/candle_uno.h | 41 +++ .../candle_uno/candle_uno_config.struct.toml | 52 +++ .../src/models/candle_uno/candle_uno.cc | 123 +++++++ .../test/src/models/candle_uno/candle_uno.cc | 19 + .../include/op-attrs/parallel_tensor_dims.h | 10 + .../include/op-attrs/parallel_tensor_shape.h | 4 +- lib/op-attrs/include/op-attrs/tensor_dims.h | 7 - .../src/op-attrs/parallel_tensor_dims.cc | 38 ++ .../src/op-attrs/parallel_tensor_shape.cc | 20 +- lib/op-attrs/src/op-attrs/tensor_dims.cc | 27 -- lib/op-attrs/test/src/op-attrs/ops/concat.cc | 331 ++++++++++++++++++ .../pcg/initializer_attrs.variant.toml | 5 + .../glorot_normal_attrs.struct.toml | 14 + lib/utils/include/utils/containers.decl.h | 6 - lib/utils/include/utils/containers.h | 20 -- 18 files changed, 667 insertions(+), 74 deletions(-) create mode 100644 lib/models/include/models/candle_uno/candle_uno.h create mode 100644 lib/models/include/models/candle_uno/candle_uno_config.struct.toml create mode 100644 lib/models/src/models/candle_uno/candle_uno.cc create mode 100644 lib/models/test/src/models/candle_uno/candle_uno.cc create mode 100644 lib/op-attrs/test/src/op-attrs/ops/concat.cc create mode 100644 lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index a5ac6fd29f..f976d369d5 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -23,6 +23,9 @@ jobs: - name: Add helpers directory to path run: echo "${PWD}/.github/workflows/helpers" >> $GITHUB_PATH + - name: Free additional space on runner + run: free_space_on_runner.sh + - name: Install nix uses: cachix/install-nix-action@v25 with: diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 98b7a003ce..9833102bd1 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,6 +1,7 @@ #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/candle_uno/candle_uno.h" #include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" @@ -63,6 +64,8 @@ tl::expected } else if (model_name == "inception_v3") { return get_inception_v3_computation_graph( get_default_inception_v3_training_config()); + } else if (model_name == "candle_uno") { + return get_candle_uno_computation_graph(get_default_candle_uno_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); @@ -135,8 +138,11 @@ int main(int argc, char **argv) { "output a dot representation of model's computation graph " "for preprocessed to help check series-parallel structure"}); - std::vector model_options = { - "transformer", "inception_v3", "split_test", "single_operator"}; + std::vector model_options = {"transformer", + "inception_v3", + "candle_uno", + "split_test", + "single_operator"}; CLIArgumentKey key_model_name = cli_add_positional_argument( cli, CLIPositionalArgumentSpec{ diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index c9d84a8948..08dbeba218 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,5 @@ #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/candle_uno/candle_uno.h" #include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" #include "models/transformer/transformer.h" @@ -302,6 +303,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(sp_decomposition.has_value()); } + + SUBCASE("candle_uno") { + ComputationGraph cg = + get_candle_uno_computation_graph(get_default_candle_uno_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } } } diff --git a/lib/models/include/models/candle_uno/candle_uno.h b/lib/models/include/models/candle_uno/candle_uno.h new file mode 100644 index 0000000000..a2d21f2830 --- /dev/null +++ b/lib/models/include/models/candle_uno/candle_uno.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_CANDLE_UNO_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_CANDLE_UNO_H + +#include "candle_uno_config.dtg.h" +#include "pcg/computation_graph_builder.h" +#include +#include +#include + +namespace FlexFlow { + +// Helper functions to construct the Candle Uno model +tensor_guid_t create_candle_uno_feature_model(ComputationGraphBuilder &, + CandleUnoConfig const &, + tensor_guid_t const &); + +/** + * @brief Get the default configs of Candle Uno model. + * + * @details The default configs come from the dataset used by the original + * model: + * https://github.com/ECP-CANDLE/Benchmarks/tree/f6a3da8818308c9edcd9fedbcf831dd5968efcdd/Pilot1/Uno + */ +CandleUnoConfig get_default_candle_uno_config(); + +/** + * @brief Get the Candle Uno computation graph. + * + * @details CandleUnoConfig.feature_shapes is a map from feature name to the + * number of channels for the feature, and CandleUnoConfig.input_features is a + * map from specific data identifier in the dataset to the feature name used in + * this model. + * + * @param CandleUnoConfig The config of the Candle Uno model. + * @return ComputationGraph The PCG of a Transformer model. + */ +ComputationGraph get_candle_uno_computation_graph(CandleUnoConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/candle_uno/candle_uno_config.struct.toml b/lib/models/include/models/candle_uno/candle_uno_config.struct.toml new file mode 100644 index 0000000000..667a6531c3 --- /dev/null +++ b/lib/models/include/models/candle_uno/candle_uno_config.struct.toml @@ -0,0 +1,52 @@ +namespace = "FlexFlow" +name = "CandleUnoConfig" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "", + "", + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/fmt/map.h", + "utils/hash/vector.h", + "utils/hash/map.h", +] + +[[fields]] +name = "batch_size" +type = "size_t" + +[[fields]] +name = "dense_layers" +type = "std::vector" + +[[fields]] +name = "dense_feature_layers" +type = "std::vector" + +[[fields]] +name = "feature_shapes" +type = "std::map" + +[[fields]] +name = "input_features" +type = "std::map" + +[[fields]] +name = "dropout" +type = "float" + +[[fields]] +name = "residual" +type = "bool" diff --git a/lib/models/src/models/candle_uno/candle_uno.cc b/lib/models/src/models/candle_uno/candle_uno.cc new file mode 100644 index 0000000000..4d52d515fb --- /dev/null +++ b/lib/models/src/models/candle_uno/candle_uno.cc @@ -0,0 +1,123 @@ +#include "models/candle_uno/candle_uno.h" +#include "pcg/initializers/glorot_normal_attrs.dtg.h" + +namespace FlexFlow { + +CandleUnoConfig get_default_candle_uno_config() { + CandleUnoConfig config{ + /*batch_size=*/64, + /*dense_layers=*/std::vector(4, 4192), + /*dense_feature_layers=*/std::vector(8, 4192), + /*feature_shapes=*/std::map{}, + /*input_features=*/std::map{}, + /*dropout=*/0.1, + /*residual=*/false}; + + config.feature_shapes["dose"] = 1; + config.feature_shapes["cell.rnaseq"] = 942; + config.feature_shapes["drug.descriptors"] = 5270; + config.feature_shapes["drug.fingerprints"] = 2048; + + config.input_features["dose1"] = "dose"; + config.input_features["dose2"] = "dose"; + config.input_features["cell.rnaseq"] = "cell.rnaseq"; + config.input_features["drug1.descriptors"] = "drug.descriptors"; + config.input_features["drug1.fingerprints"] = "drug.fingerprints"; + config.input_features["drug2.descriptors"] = "drug.descriptors"; + config.input_features["drug2.fingerprints"] = "drug.fingerprints"; + + return config; +} + +tensor_guid_t create_candle_uno_feature_model( + ComputationGraphBuilder &cgb, + CandleUnoConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &kernel_initializer) { + tensor_guid_t t = input; + for (int const dense_dim : config.dense_feature_layers) { + t = cgb.dense(t, + dense_dim, + Activation::RELU, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + if (config.dropout > 0) { + t = cgb.dropout(t, config.dropout); + } + } + return t; +} + +ComputationGraph + get_candle_uno_computation_graph(CandleUnoConfig const &config) { + ComputationGraphBuilder cgb; + InitializerAttrs kernel_initializer = + InitializerAttrs{GlorotNormalAttrs{/*seed=*/0}}; + + auto create_input_tensor = + [&](FFOrdered const &dims) -> tensor_guid_t { + TensorShape input_shape = TensorShape{ + TensorDims{dims}, + DataType::FLOAT, + }; + return cgb.create_input(input_shape, CreateGrad::YES); + }; + + std::set input_models; + for (auto const &shape : config.feature_shapes) { + auto const &type = shape.first; + if (type.find(".") != std::string::npos) { + std::string base_type = type.substr(0, type.find(".")); + // The string parsing here is to match with original implementation at + // https://github.com/ECP-CANDLE/Benchmarks/blob/f6a3da8818308c9edcd9fedbcf831dd5968efcdd/Pilot1/Uno/uno_baseline_keras2.py#L178. + if (base_type == "cell" || base_type == "drug") { + input_models.insert(type); + } + } + } + + std::vector all_inputs; + std::vector encoded_inputs; + + for (auto const &input_feature : config.input_features) { + std::string const &feature_name = input_feature.second; + size_t shape = config.feature_shapes.at(feature_name); + tensor_guid_t input = create_input_tensor({config.batch_size, shape}); + all_inputs.push_back(input); + + if (contains(input_models, feature_name)) { + encoded_inputs.emplace_back(create_candle_uno_feature_model( + cgb, config, input, kernel_initializer)); + } else { + encoded_inputs.emplace_back(input); + } + } + + tensor_guid_t output = cgb.concat(encoded_inputs, /*axis=*/1); + for (int const &dense_layer_dim : config.dense_layers) { + tensor_guid_t residual_input = output; + output = cgb.dense(output, + dense_layer_dim, + Activation::RELU, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + if (config.dropout > 0) { + output = cgb.dropout(output, config.dropout); + } + if (config.residual) { + output = cgb.add(output, residual_input); + } + } + output = cgb.dense(output, + /*outDim=*/1, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/kernel_initializer); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/src/models/candle_uno/candle_uno.cc b/lib/models/test/src/models/candle_uno/candle_uno.cc new file mode 100644 index 0000000000..e32c5b5486 --- /dev/null +++ b/lib/models/test/src/models/candle_uno/candle_uno.cc @@ -0,0 +1,19 @@ +#include "models/candle_uno/candle_uno.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_candle_uno_computation_graph") { + CandleUnoConfig config = get_default_candle_uno_config(); + + ComputationGraph result = get_candle_uno_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 142; + CHECK(result_num_layers == correct_num_layers); + } + } +} diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h index 7a89b4bd78..1b8361abf6 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dims.h @@ -17,6 +17,16 @@ size_t num_shard_dims(ParallelTensorDims const &); ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &); +ParallelTensorDims lift_to_parallel(TensorDims const &); +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &, + SumDegree const &, + DiscardCopyDegree const &, + FFOrdered const &shard_degrees); +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &, + ParallelTensorDimDegrees const &); + int total_replica_degree(ParallelTensorDims const &); int total_shard_degree(ParallelTensorDims const &); int total_parallel_degree(ParallelTensorDims const &); diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index 806a5f0de7..a03151160b 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -23,8 +23,8 @@ ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorShape const &); ParallelTensorShape lift_to_parallel(TensorShape const &); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, + SumDegree const &, + DiscardCopyDegree const &, FFOrdered const &shard_degrees); ParallelTensorShape lift_to_parallel_with_degrees(TensorShape const &, diff --git a/lib/op-attrs/include/op-attrs/tensor_dims.h b/lib/op-attrs/include/op-attrs/tensor_dims.h index c8af3b02e7..ee44a39170 100644 --- a/lib/op-attrs/include/op-attrs/tensor_dims.h +++ b/lib/op-attrs/include/op-attrs/tensor_dims.h @@ -17,13 +17,6 @@ bool tensor_dims_is_broadcastable_to(TensorDims const &curr, std::optional get_broadcast_target_dims(std::unordered_set const &); -ParallelTensorDims lift_to_parallel(TensorDims const &); -ParallelTensorDims - lift_to_parallel_with_degrees(TensorDims const &, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, - FFOrdered const &shard_degrees); - } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc index 61062b84b0..2955545561 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dims.cc @@ -1,8 +1,10 @@ #include "op-attrs/parallel_tensor_dims.h" #include "op-attrs/dim_ordered/transform.h" +#include "op-attrs/dim_ordered/zip.h" #include "op-attrs/replica_parallel_dim.h" #include "op-attrs/replica_parallel_dim_set.h" #include "op-attrs/shard_parallel_dim.h" +#include "op-attrs/tensor_dims.h" #include "utils/containers/all_of.h" #include "utils/containers/product.h" #include "utils/containers/transform.h" @@ -37,6 +39,42 @@ ParallelTensorDimDegrees get_parallel_degrees(ParallelTensorDims const &d) { }; } +ParallelTensorDims lift_to_parallel(TensorDims const &dims) { + std::vector shard_degrees(num_dims(dims), + 1); // 1 repeated num_dims(dims) times + return lift_to_parallel_with_degrees( + dims, SumDegree{1}, DiscardCopyDegree{1}, shard_degrees); +} + +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &unpar, + SumDegree const &sum_degree, + DiscardCopyDegree const &discard_copy_degree, + FFOrdered const &shard_degrees) { + std::vector lifted = + transform(zip(vector_of(unpar.ff_ordered), vector_of(shard_degrees)), + [](std::pair const &p) { + size_t size = p.first; + int degree = p.second; + return ShardParallelDim{size, degree}; + }); + + return ParallelTensorDims{FFOrdered{lifted}, + ReplicaParallelDimSet{ + sum_degree, + discard_copy_degree, + }}; +} + +ParallelTensorDims + lift_to_parallel_with_degrees(TensorDims const &unpar, + ParallelTensorDimDegrees const °rees) { + return lift_to_parallel_with_degrees(unpar, + degrees.sum_degree, + degrees.discard_copy_degree, + degrees.shard_degrees); +} + int total_replica_degree(ParallelTensorDims const &dims) { return dims.replica_dims.discard_copy_degree.value * dims.replica_dims.sum_degree.value; diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 3cd0f47a5d..0663795db5 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -68,24 +68,24 @@ ParallelTensorShape lift_to_parallel(TensorShape const &s) { } ParallelTensorShape - lift_to_parallel_with_degrees(TensorShape const &s, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, + lift_to_parallel_with_degrees(TensorShape const &unpar, + SumDegree const &sum_degree, + DiscardCopyDegree const &discard_copy_degree, FFOrdered const &shard_degrees) { return ParallelTensorShape{ lift_to_parallel_with_degrees( - s.dims, sum_degree, discard_copy_degree, shard_degrees), - s.data_type, + unpar.dims, sum_degree, discard_copy_degree, shard_degrees), + unpar.data_type, }; } ParallelTensorShape - lift_to_parallel_with_degrees(TensorShape const &s, + lift_to_parallel_with_degrees(TensorShape const &unpar, ParallelTensorDimDegrees const °rees) { - return lift_to_parallel_with_degrees(s, - degrees.sum_degree, - degrees.discard_copy_degree, - degrees.shard_degrees); + return ParallelTensorShape{ + lift_to_parallel_with_degrees(unpar.dims, degrees), + unpar.data_type, + }; } TensorShape require_not_parallel(ParallelTensorShape const &s) { diff --git a/lib/op-attrs/src/op-attrs/tensor_dims.cc b/lib/op-attrs/src/op-attrs/tensor_dims.cc index ba7d6e8357..1bb050db52 100644 --- a/lib/op-attrs/src/op-attrs/tensor_dims.cc +++ b/lib/op-attrs/src/op-attrs/tensor_dims.cc @@ -59,31 +59,4 @@ std::optional return std::nullopt; } -ParallelTensorDims lift_to_parallel(TensorDims const &dims) { - std::vector shard_degrees(num_dims(dims), - 1); // 1 repeated num_dims(dims) times - return lift_to_parallel_with_degrees( - dims, SumDegree{1}, DiscardCopyDegree{1}, shard_degrees); -} - -ParallelTensorDims - lift_to_parallel_with_degrees(TensorDims const &dims, - SumDegree sum_degree, - DiscardCopyDegree discard_copy_degree, - FFOrdered const &shard_degrees) { - std::vector lifted = - transform(zip(vector_of(dims.ff_ordered), vector_of(shard_degrees)), - [](std::pair const &p) { - size_t size = p.first; - int degree = p.second; - return ShardParallelDim(size, degree); - }); - - return ParallelTensorDims{FFOrdered{lifted}, - ReplicaParallelDimSet{ - sum_degree, - discard_copy_degree, - }}; -} - } // namespace FlexFlow diff --git a/lib/op-attrs/test/src/op-attrs/ops/concat.cc b/lib/op-attrs/test/src/op-attrs/ops/concat.cc new file mode 100644 index 0000000000..9e842c3ebe --- /dev/null +++ b/lib/op-attrs/test/src/op-attrs/ops/concat.cc @@ -0,0 +1,331 @@ +#include "op-attrs/ops/concat.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "test/utils/doctest/fmt/expected.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/expected.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { + ConcatAttrs attrs = ConcatAttrs{ + ff_dim_t{1}, + }; + + SUBCASE("empty input shapes list passed") { + std::vector input_shapes = {}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + size_t dim0_size = 12; + size_t dim2_size = 20; + TensorShape input_shape1 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 14, + dim2_size, + }}, + DataType::FLOAT, + }; + + SUBCASE("single element input shapes list passed") { + std::vector input_shapes = {input_shape1}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + TensorShape input_shape2 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 16, + dim2_size, + }}, + DataType::FLOAT, + }; + + TensorShape input_shape3 = TensorShape{ + TensorDims{FFOrdered{dim0_size, 18, dim2_size}}, + DataType::FLOAT, + }; + + SUBCASE("input shapes do not shared the same num_dims") { + TensorShape mismatched_num_dims = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 20, + dim2_size, + 1, + }}, + DataType::FLOAT, + }; + + std::vector input_shapes = { + input_shape1, input_shape2, input_shape3, mismatched_num_dims}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("concat axis is out of bounds") { + attrs = ConcatAttrs{ + ff_dim_t{3}, + }; + + std::vector input_shapes = { + input_shape1, input_shape2, input_shape3}; + + std::optional result = + optional_from_expected(get_output_shape(attrs, input_shapes)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("input shapes are valid") { + std::vector input_shapes = { + input_shape1, input_shape2, input_shape3}; + + tl::expected result = + get_output_shape(attrs, input_shapes); + tl::expected correct = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 14 + 16 + 18, + dim2_size, + }}, + DataType::FLOAT, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("get_output_shape(ConcatAttrs, std::vector)") { + ConcatAttrs attrs = ConcatAttrs{ + ff_dim_t{1}, + }; + + size_t dim0_size = 12; + size_t dim2_size = 20; + + TensorShape input_shape1 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 14, + dim2_size, + }}, + DataType::FLOAT, + }; + + TensorShape input_shape2 = TensorShape{ + TensorDims{FFOrdered{ + dim0_size, + 16, + dim2_size, + }}, + DataType::FLOAT, + }; + + TensorShape input_shape3 = TensorShape{ + TensorDims{FFOrdered{dim0_size, 18, dim2_size}}, + DataType::FLOAT, + }; + + TensorShape output_shape = TensorShape{ + TensorDims{FFOrdered{dim0_size, 14 + 16 + 18, dim2_size}}, + DataType::FLOAT, + }; + + auto lift_input1 = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input_shape1, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto lift_input2 = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input_shape2, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto lift_input3 = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + input_shape3, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + auto lift_output = + [&](SumDegree o_sum, DiscardCopyDegree o_eq, int o0, int o1, int o2) { + return lift_to_parallel_with_degrees( + output_shape, o_sum, o_eq, FFOrdered{o0, o1, o2}); + }; + + SUBCASE("sum reduction parallelism") { + SUBCASE("matching") { + SumDegree sum_degree = SumDegree{2}; + + std::vector inputs = { + lift_input1(sum_degree, DiscardCopyDegree{1}, 1, 1, 1), + lift_input2(sum_degree, DiscardCopyDegree{1}, 1, 1, 1), + lift_input3(sum_degree, DiscardCopyDegree{1}, 1, 1, 1), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = + lift_output(sum_degree, DiscardCopyDegree{1}, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector inputs = { + lift_input1(SumDegree{2}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input2(SumDegree{4}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input3(SumDegree{4}, DiscardCopyDegree{1}, 1, 1, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("discard copy reduction parallelism") { + SUBCASE("matching") { + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{2}; + + std::vector inputs = { + lift_input1(SumDegree{1}, discard_copy_degree, 1, 1, 1), + lift_input2(SumDegree{1}, discard_copy_degree, 1, 1, 1), + lift_input3(SumDegree{1}, discard_copy_degree, 1, 1, 1), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = + lift_output(SumDegree{1}, discard_copy_degree, 1, 1, 1); + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{2}, 1, 1, 1), + lift_input2(SumDegree{1}, DiscardCopyDegree{2}, 1, 1, 1), + lift_input3(SumDegree{1}, DiscardCopyDegree{4}, 1, 1, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("parallelism in axis dim") { + SUBCASE("matching") { + int degree = 2; + + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1), + lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1), + lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 1, degree, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 1, 1, 1), + lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 1, 2, 1), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("parallelism in non-axis shard dims") { + SUBCASE("matching") { + int degree0 = 2; + int degree2 = 4; + + std::vector inputs = { + lift_input1( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2), + lift_input2( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2), + lift_input3( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = lift_output( + SumDegree{1}, DiscardCopyDegree{1}, degree0, 1, degree2); + + CHECK(result == correct); + } + + SUBCASE("not matching") { + std::vector inputs = { + lift_input1(SumDegree{1}, DiscardCopyDegree{1}, 2, 1, 4), + lift_input2(SumDegree{1}, DiscardCopyDegree{1}, 4, 1, 2), + lift_input3(SumDegree{1}, DiscardCopyDegree{1}, 4, 1, 2), + }; + + std::optional result = + optional_from_expected(get_output_shape(attrs, inputs)); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } + + SUBCASE("parallelism degrees are not mutually exclusive") { + SumDegree sum_degree = SumDegree{3}; + DiscardCopyDegree discard_copy_degree = DiscardCopyDegree{5}; + int degree0 = 2; + int degree2 = 4; + + std::vector inputs = { + lift_input1(sum_degree, discard_copy_degree, degree0, 1, degree2), + lift_input2(sum_degree, discard_copy_degree, degree0, 1, degree2), + lift_input3(sum_degree, discard_copy_degree, degree0, 1, degree2), + }; + + tl::expected result = + get_output_shape(attrs, inputs); + tl::expected correct = + lift_output(sum_degree, discard_copy_degree, degree0, 1, degree2); + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml index 1ea9ce05a6..4b7d10ef5e 100644 --- a/lib/pcg/include/pcg/initializer_attrs.variant.toml +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -11,6 +11,7 @@ features = [ includes = [ "pcg/initializers/glorot_uniform_attrs.dtg.h", + "pcg/initializers/glorot_normal_attrs.dtg.h", "pcg/initializers/zero_initializer_attrs.dtg.h", "pcg/initializers/uniform_initializer_attrs.h", "pcg/initializers/norm_initializer_attrs.dtg.h", @@ -21,6 +22,10 @@ includes = [ type = "::FlexFlow::GlorotUniformAttrs" key = "glorot_uniform" +[[values]] +type = "::FlexFlow::GlorotNormalAttrs" +key = "glorot_normal" + [[values]] type = "::FlexFlow::ZeroInitializerAttrs" key = "zero" diff --git a/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml b/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml new file mode 100644 index 0000000000..fd0d8eb9be --- /dev/null +++ b/lib/pcg/include/pcg/initializers/glorot_normal_attrs.struct.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "GlorotNormalAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 20ab6ce440..cb652a9e69 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -11,9 +11,6 @@ namespace FlexFlow { -template -Element sum(Container const &container); - template @@ -68,9 +65,6 @@ std::optional maybe_get_only(C const &c); template std::optional optional_all_of(Container const &, Function const &); -template -bool are_all_same(C const &c); - template std::function compare_by(F const &f); diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index f60ef77cda..6ac9eb10b0 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -31,15 +31,6 @@ namespace FlexFlow { -template -Element sum(Container const &container) { - Element result = 0; - for (Element const &element : container) { - result += element; - } - return result; -} - template Element sum_where(Container const &container, ConditionF const &condition) { Element result = 0; @@ -135,17 +126,6 @@ std::optional optional_all_of(Container const &container, return true; } -template -bool are_all_same(C const &c) { - auto const &first = *c.cbegin(); - for (auto const &v : c) { - if (v != first) { - return false; - } - } - return true; -} - template std::vector flatmap(std::vector const &v, F const &f) { std::vector result; From cf96db6c069b812dac3537f145e68e254596fa7e Mon Sep 17 00:00:00 2001 From: hsdfzhsdfz <157999659+hsdfzhsdfz@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:12:23 -0400 Subject: [PATCH 08/12] Add BERT model computation graph (#1488) * Add initial bert model structure * Update following review * Rename config * Add additional bert configs * Update based on reviewing * Added assert checks * Add error message for unsupported BertConfig.position_embedding_type * Format * fix typo * Add bert to export-model-arch * Format --------- Co-authored-by: Colin Unger --- .../src/export_model_arch.cc | 4 + ...ion_graph_series_parallel_decomposition.cc | 35 ++++ lib/models/include/models/bert/bert.h | 41 +++++ .../models/bert/bert_config.struct.toml | 71 ++++++++ lib/models/src/models/bert/bert.cc | 160 ++++++++++++++++++ lib/models/test/src/models/bert/bert.cc | 33 ++++ .../pcg/initializer_attrs.variant.toml | 5 + ...cated_normal_initializer_attrs.struct.toml | 30 ++++ 8 files changed, 379 insertions(+) create mode 100644 lib/models/include/models/bert/bert.h create mode 100644 lib/models/include/models/bert/bert_config.struct.toml create mode 100644 lib/models/src/models/bert/bert.cc create mode 100644 lib/models/test/src/models/bert/bert.cc create mode 100644 lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 9833102bd1..1c2dfd6ea3 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,6 +1,7 @@ #include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" +#include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" #include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" @@ -66,6 +67,8 @@ tl::expected get_default_inception_v3_training_config()); } else if (model_name == "candle_uno") { return get_candle_uno_computation_graph(get_default_candle_uno_config()); + } else if (model_name == "bert") { + return get_bert_computation_graph(get_default_bert_config()); } else if (model_name == "split_test") { int batch_size = 8; return get_split_test_computation_graph(batch_size); @@ -141,6 +144,7 @@ int main(int argc, char **argv) { std::vector model_options = {"transformer", "inception_v3", "candle_uno", + "bert", "split_test", "single_operator"}; CLIArgumentKey key_model_name = cli_add_positional_argument( diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc index 08dbeba218..564cffaebe 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,5 @@ #include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" #include "models/inception_v3/inception_v3.h" #include "models/split_test/split_test.h" @@ -313,6 +314,16 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(sp_decomposition.has_value()); } + + SUBCASE("bert") { + ComputationGraph cg = + get_bert_computation_graph(get_default_bert_config()); + + std::optional sp_decomposition = + get_computation_graph_series_parallel_decomposition(cg); + + CHECK(sp_decomposition.has_value()); + } } } @@ -358,5 +369,29 @@ TEST_SUITE(FF_TEST_SUITE) { std::string result = render_preprocessed_computation_graph_for_sp_decomposition(cg); } + + SUBCASE("inception_v3") { + ComputationGraph cg = get_inception_v3_computation_graph( + get_default_inception_v3_training_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("candle_uno") { + ComputationGraph cg = + get_candle_uno_computation_graph(get_default_candle_uno_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } + + SUBCASE("bert") { + ComputationGraph cg = + get_bert_computation_graph(get_default_bert_config()); + + std::string result = + render_preprocessed_computation_graph_for_sp_decomposition(cg); + } } } diff --git a/lib/models/include/models/bert/bert.h b/lib/models/include/models/bert/bert.h new file mode 100644 index 0000000000..0047996b78 --- /dev/null +++ b/lib/models/include/models/bert/bert.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H +#define _FLEXFLOW_LIB_MODELS_INCLUDE_MODELS_BERT_H + +#include "models/bert/bert_config.dtg.h" +#include "pcg/computation_graph_builder.h" + +namespace FlexFlow { + +// Helper functions to construct the BERT model +tensor_guid_t create_bert_feedforward_network(ComputationGraphBuilder &, + BertConfig const &, + tensor_guid_t const &); +tensor_guid_t create_bert_encoder_layer(ComputationGraphBuilder &, + BertConfig const &, + tensor_guid_t const &); +tensor_guid_t create_bert_encoder(ComputationGraphBuilder &, + BertConfig const &, + tensor_guid_t const &); + +/** + * @brief Get the base config of the BERT model. + * + * @details Refer to + * https://huggingface.co/docs/transformers/v4.18.0/en/model_doc/bert#transformers.BertConfig + * for default configs. + */ +BertConfig get_default_bert_config(); + +/** + * @brief Get the BERT computation graph. + * + * @note This is a plain encoder-only model for pre-training. + * + * @param BertConfig The config of BERT model. + * @return ComputationGraph The computation graph of a BERT model. + */ +ComputationGraph get_bert_computation_graph(BertConfig const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/models/include/models/bert/bert_config.struct.toml b/lib/models/include/models/bert/bert_config.struct.toml new file mode 100644 index 0000000000..398210cf48 --- /dev/null +++ b/lib/models/include/models/bert/bert_config.struct.toml @@ -0,0 +1,71 @@ +namespace = "FlexFlow" +name = "BertConfig" + +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +includes = [ + "op-attrs/activation.dtg.h", +] + +[[fields]] +name = "vocab_size" +type = "size_t" + +[[fields]] +name = "hidden_size" +type = "size_t" + +[[fields]] +name = "num_encoder_layers" +type = "size_t" + +[[fields]] +name = "num_heads" +type = "size_t" + +[[fields]] +name = "dim_feedforward" +type = "size_t" + +[[fields]] +name = "hidden_act" +type = "::FlexFlow::Activation" + +[[fields]] +name = "hidden_dropout_prob" +type = "float" + +[[fields]] +name = "attention_probs_dropout_prob" +type = "float" + +[[fields]] +name = "initializer_range" +type = "float" + +[[fields]] +name = "layer_norm_eps" +type = "float" + +[[fields]] +name = "position_embedding_type" +type = "std::string" + +[[fields]] +name = "classifier_dropout" +type = "float" + +[[fields]] +name = "sequence_length" +type = "size_t" + +[[fields]] +name = "batch_size" +type = "size_t" diff --git a/lib/models/src/models/bert/bert.cc b/lib/models/src/models/bert/bert.cc new file mode 100644 index 0000000000..cf48f2399b --- /dev/null +++ b/lib/models/src/models/bert/bert.cc @@ -0,0 +1,160 @@ +#include "models/bert/bert.h" +#include "op-attrs/tensor_shape.h" +#include "pcg/computation_graph.h" +#include "pcg/initializers/truncated_normal_initializer_attrs.dtg.h" + +namespace FlexFlow { + +BertConfig get_default_bert_config() { + return BertConfig{/*vocab_size=*/30522, + /*hidden_size=*/768, + /*num_encoder_layers=*/12, + /*num_heads=*/12, + /*dim_feedforward=*/3072, + /*hidden_act=*/Activation::GELU, + /*hidden_dropout_prob=*/0.1, + /*attention_probs_dropout_prob=*/0.1, + /*initializer_range=*/0.02, + /*layer_norm_eps=*/1e-12, + /*position_embedding_type=*/"absolute", + /*classifier_dropout=*/0.1, + /*sequence_length=*/512, + /*batch_size=*/64}; +} + +tensor_guid_t + create_feedforward_network(ComputationGraphBuilder &cgb, + BertConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &bias_initializer, + InitializerAttrs const &projection_initializer) { + tensor_guid_t layer1_out = + cgb.dense(input, + config.dim_feedforward, + /*activation=*/config.hidden_act, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/projection_initializer, + /*bias_initializer=*/bias_initializer); + tensor_guid_t dropout_out = + cgb.dropout(layer1_out, config.hidden_dropout_prob); + tensor_guid_t layer2_out = + cgb.dense(dropout_out, + config.hidden_size, + /*activation=*/std::nullopt, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/projection_initializer, + /*bias_initializer=*/bias_initializer); + return cgb.dropout(layer2_out, config.hidden_dropout_prob); +}; + +tensor_guid_t + create_bert_encoder_layer(ComputationGraphBuilder &cgb, + BertConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &bias_initializer, + InitializerAttrs const &projection_initializer) { + assert(num_dims(cgb.get_shape(input)) == 3); + std::vector layer_norm_axis = {2}; // Apply layernorm across the last dim + int kdim = config.dim_feedforward / config.num_heads; + int vdim = config.dim_feedforward / config.num_heads; + tensor_guid_t self_attention = + cgb.multihead_attention(input, + input, + input, + config.hidden_size, + config.num_heads, + kdim, + vdim, + /*dropout=*/config.attention_probs_dropout_prob, + /*bias=*/true, + /*add_bias_kv=*/false, + /*add_zero_attn=*/false, + /*initializer=*/projection_initializer); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, self_attention)); + + tensor_guid_t normalized = cgb.layer_norm(cgb.add(self_attention, input), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, normalized)); + + tensor_guid_t feedforward_output = create_feedforward_network( + cgb, config, normalized, bias_initializer, projection_initializer); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, feedforward_output)); + return cgb.layer_norm(cgb.add(normalized, feedforward_output), + layer_norm_axis, + /*elementwise_affine=*/true, + config.layer_norm_eps); +} + +tensor_guid_t + create_bert_encoder(ComputationGraphBuilder &cgb, + BertConfig const &config, + tensor_guid_t const &input, + InitializerAttrs const &bias_initializer, + InitializerAttrs const &projection_initializer) { + tensor_guid_t t = input; + for (int i = 0; i < config.num_encoder_layers; i++) { + t = create_bert_encoder_layer( + cgb, config, t, bias_initializer, projection_initializer); + } + return t; +}; + +ComputationGraph get_bert_computation_graph(BertConfig const &config) { + if (config.position_embedding_type != "absolute") { + throw mk_runtime_error( + fmt::format("Currently only position_embedding_type=absolute is " + "supported, but found position_embedding_type={}. " + "If you need support for additional " + "position_embedding_type values, please create an issue.", + config.position_embedding_type)); + } + + ComputationGraphBuilder cgb; + InitializerAttrs projection_initializer = + InitializerAttrs{TruncatedNormalInitializerAttrs{ + /*seed=*/0, + /*mean=*/0, + /*stddev=*/config.initializer_range, + /*min_cutoff=*/-2 * config.initializer_range, + /*max_cutoff=*/2 * config.initializer_range}}; + InitializerAttrs bias_initializer = InitializerAttrs{ZeroInitializerAttrs{}}; + + TensorShape input_shape = TensorShape{ + TensorDims{FFOrdered{ + config.batch_size, config.sequence_length, config.hidden_size}}, + DataType::FLOAT, + }; + tensor_guid_t input = cgb.create_input(input_shape, CreateGrad::YES); + + tensor_guid_t encoder_output = create_bert_encoder( + cgb, config, input, bias_initializer, projection_initializer); + assert(are_tensor_guid_shapes_equivalent( + cgb.computation_graph, input, encoder_output)); + + tensor_guid_t out_prob = + cgb.softmax(cgb.dense(encoder_output, + /*outDim=*/config.vocab_size, + /*activation=*/config.hidden_act, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*projection_initializer=*/projection_initializer, + /*bias_initializer=*/bias_initializer)); + assert( + (cgb.get_shape(out_prob) == + TensorShape{ + TensorDims{FFOrdered{ + config.batch_size, config.sequence_length, config.vocab_size}}, + DataType::FLOAT, + })); + + return cgb.computation_graph; +} + +} // namespace FlexFlow diff --git a/lib/models/test/src/models/bert/bert.cc b/lib/models/test/src/models/bert/bert.cc new file mode 100644 index 0000000000..1defc3a1a2 --- /dev/null +++ b/lib/models/test/src/models/bert/bert.cc @@ -0,0 +1,33 @@ +#include "models/bert/bert.h" +#include "pcg/computation_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_bert_computation_graph") { + + SUBCASE("default config") { + BertConfig config = get_default_bert_config(); + + ComputationGraph result = get_bert_computation_graph(config); + + SUBCASE("num layers") { + int result_num_layers = get_layers(result).size(); + int correct_num_layers = 245; + CHECK(result_num_layers == correct_num_layers); + } + } + + SUBCASE("throws on position_embedding_type != absolute as other values are " + "currently unsupported") { + BertConfig config = [] { + BertConfig c = get_default_bert_config(); + c.position_embedding_type = "relative_key"; + return c; + }(); + + CHECK_THROWS(get_bert_computation_graph(config)); + } + } +} diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml index 4b7d10ef5e..2e878c5c53 100644 --- a/lib/pcg/include/pcg/initializer_attrs.variant.toml +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -15,6 +15,7 @@ includes = [ "pcg/initializers/zero_initializer_attrs.dtg.h", "pcg/initializers/uniform_initializer_attrs.h", "pcg/initializers/norm_initializer_attrs.dtg.h", + "pcg/initializers/truncated_normal_initializer_attrs.dtg.h", "pcg/initializers/constant_initializer_attrs.dtg.h", ] @@ -38,6 +39,10 @@ key = "uniform" type = "::FlexFlow::NormInitializerAttrs" key = "normal" +[[values]] +type = "::FlexFlow::TruncatedNormalInitializerAttrs" +key = "normal" + [[values]] type = "::FlexFlow::ConstantInitializerAttrs" key = "constant" diff --git a/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml b/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml new file mode 100644 index 0000000000..9e4ec0272d --- /dev/null +++ b/lib/pcg/include/pcg/initializers/truncated_normal_initializer_attrs.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "TruncatedNormalInitializerAttrs" +features = [ + "eq", + "ord", + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[fields]] +name = "seed" +type = "int" + +[[fields]] +name = "mean" +type = "float" + +[[fields]] +name = "stddev" +type = "float" + +[[fields]] +name = "min_cutoff" +type = "float" + +[[fields]] +name = "max_cutoff" +type = "float" From b77598c321d3070df31803003e5b808b995fda75 Mon Sep 17 00:00:00 2001 From: Mengdi Wu <48128384+wmdi@users.noreply.github.com> Date: Tue, 8 Oct 2024 00:46:06 -0400 Subject: [PATCH 09/12] Unity device mapping algorithm (#1459) * pass existing tests * unity algorithm builds * fmt * fix * refactor machine mapping * add unit tests * fmt * add more tests * fmt * fix * refactor get_optimal_machine_mapping a bit and improve the tests * remove debug codes * A lot of simplifying and modularizing of unity dp code * Get tests building again * Get all the new testcases working * Move over to ProblemTree/ResultTree framework for machine mapping * Settle on ProblemTree/BinaryTreePath-indexed-MachineMappingResult for machine mapping * More code cleanup and PR prep * Get tests building again * Pass some basic tests of get_optimal_machine_mapping * Migrate over to use type-erased binary tree * Move back to templated FullBinaryTree * Get all existing tests passing again * Fix tests and format * Move graph_optimize_state.cc to correct location --------- Co-authored-by: Mengdi Wu Co-authored-by: Colin Unger Co-authored-by: Colin Unger --- .github/workflows/per-lib-check.yml | 6 +- .../json_sp_model_export.struct.toml | 10 +- .../src/export_model_arch.cc | 10 +- flake.lock | 6 +- lib/compiler/include/compiler/cost_estimate.h | 61 --- .../compiler/cost_estimator/cost_estimator.h | 45 +++ .../op_cost_estimate_key.struct.toml | 40 ++ .../single_tensor_movement.struct.toml | 30 ++ .../tensor_set_movement.struct.toml | 21 + .../graph_optimize_result.struct.toml | 16 + .../include/compiler/graph_optimize_state.h | 31 ++ lib/compiler/include/compiler/graph_utils.h | 37 -- .../include/compiler/machine_mapping.h | 70 ---- ...tracted_single_tensor_movement.struct.toml | 30 ++ .../abstracted_tensor_set_movement.h | 25 ++ ...abstracted_tensor_set_movement.struct.toml | 21 + ...tracted_tensor_set_movement_across_split.h | 16 + ...easible_machine_mapping_result.struct.toml | 19 + .../get_machine_resource_splits.h | 15 + .../get_optimal_machine_mapping.h | 47 +++ .../get_tensor_set_movement_across_split.h | 19 + .../include_unconstrained.struct.toml | 16 + .../machine_mapping/machine_mapping.h | 15 + .../machine_mapping.struct.toml | 11 +- .../machine_mapping/machine_mapping_cache.h | 18 + .../machine_mapping_cache.struct.toml | 22 ++ .../machine_mapping_constraints.h | 39 ++ .../machine_mapping_constraints.struct.toml | 23 ++ .../machine_mapping_context.struct.toml | 18 + .../get_machine_mapping_problem_tree.h | 18 + .../machine_mapping_problem_tree.h | 32 ++ .../machine_mapping_problem_tree.variant.toml | 25 ++ ...mm_problem_tree_parallel_split.struct.toml | 27 ++ .../mm_problem_tree_series_split.struct.toml | 33 ++ .../unmapped_op_cost_estimate_key.h | 20 + .../unmapped_op_cost_estimate_key.struct.toml | 36 ++ .../machine_mapping/machine_mapping_result.h | 36 ++ .../machine_mapping_result.struct.toml | 20 + .../machine_mapping_state.struct.toml | 25 ++ ...lel_layer_guid_oblivious_machine_mapping.h | 23 ++ ...guid_oblivious_machine_mapping.struct.toml | 21 + .../parallel_split_transformation.enum.toml | 14 + .../pcg_split_boundary_layers.struct.toml | 24 ++ .../machine_mapping/transitive_reduced_pcg.h | 34 ++ .../transitive_reduced_pcg.struct.toml | 16 + .../compiler/optimal_cost_state.struct.toml | 36 -- .../compiler/optimizer_config.struct.toml | 26 ++ ...on_graph_binary_parallel_split.struct.toml | 27 ++ ...tion_graph_binary_series_split.struct.toml | 27 ++ ...omputation_graph_binary_sp_decomposition.h | 24 +- ...graph_binary_sp_decomposition.variant.toml | 25 ++ ...tion_graph_series_parallel_decomposition.h | 0 ..._graph_binary_sp_decomposition.struct.toml | 22 -- ...get_pcg_balanced_binary_sp_decomposition.h | 11 + .../get_pcg_series_parallel_decomposition.h | 14 + .../pcg/pcg_binary_parallel_split.h | 14 + .../pcg/pcg_binary_parallel_split.struct.toml | 27 ++ .../pcg/pcg_binary_series_split.h | 14 + .../pcg/pcg_binary_series_split.struct.toml | 27 ++ .../pcg/pcg_binary_sp_decomposition.h | 38 ++ .../pcg_binary_sp_decomposition.variant.toml | 25 ++ .../include/compiler/unity_algorithm.h | 34 +- .../compiler/cost_estimator/cost_estimator.cc | 16 + .../src/compiler/graph_optimize_state.cc | 85 ++++ .../abstracted_tensor_set_movement.cc | 62 +++ ...racted_tensor_set_movement_across_split.cc | 63 +++ .../get_machine_resource_splits.cc | 32 ++ .../get_optimal_machine_mapping.cc | 254 ++++++++++++ .../get_tensor_set_movement_across_split.cc | 26 ++ .../machine_mapping/machine_mapping.cc | 18 + .../machine_mapping/machine_mapping_cache.cc | 30 ++ .../machine_mapping_constraints.cc | 112 ++++++ .../get_machine_mapping_problem_tree.cc | 53 +++ .../machine_mapping_problem_tree.cc | 91 +++++ .../unmapped_op_cost_estimate_key.cc | 36 ++ .../machine_mapping/machine_mapping_result.cc | 138 +++++++ ...el_layer_guid_oblivious_machine_mapping.cc | 24 ++ .../machine_mapping/transitive_reduced_pcg.cc | 93 +++++ ...mputation_graph_binary_sp_decomposition.cc | 192 +++++++++ ...ion_graph_series_parallel_decomposition.cc | 2 +- ...mputation_graph_binary_sp_decomposition.cc | 90 ----- .../get_pcg_series_parallel_decomposition.cc | 10 + .../pcg/pcg_binary_parallel_split.cc | 14 + .../pcg/pcg_binary_series_split.cc | 14 + .../pcg/pcg_binary_sp_decomposition.cc | 115 ++++++ lib/compiler/src/graph_utils.cc | 153 -------- lib/compiler/src/machine_mapping.cc | 366 ------------------ lib/compiler/src/unity_algorithm.cc | 97 +++-- ...racted_tensor_set_movement_across_split.cc | 300 ++++++++++++++ .../cost_estimator_for_test.cc | 41 ++ .../machine_mapping/cost_estimator_for_test.h | 38 ++ .../get_machine_resource_splits.cc | 235 +++++++++++ .../get_optimal_machine_mapping.cc | 236 +++++++++++ .../get_tensor_set_movement_across_split.cc | 239 ++++++++++++ .../machine_mapping/machine_mapping.cc | 55 +++ .../get_machine_mapping_problem_tree.cc | 289 ++++++++++++++ .../machine_mapping/machine_mapping_result.cc | 342 ++++++++++++++++ ...ion_graph_series_parallel_decomposition.cc | 34 +- lib/compiler/test/src/graph_optimize_state.cc | 80 ++++ lib/compiler/test/src/test_cost_estimator.h | 25 -- lib/compiler/test/src/test_generator.h | 174 --------- .../test/src/test_labelled_open_graph.cc | 132 ------- lib/compiler/test/src/test_machine_mapping.cc | 23 -- lib/compiler/test/src/test_open_graph.cc | 81 ---- lib/compiler/test/src/test_optimal_cost.cc | 72 ---- ..._unity_algorithm.cc => unity_algorithm.cc} | 2 - lib/kernels/include/kernels/accessor.h | 12 +- .../include/kernels/datatype_dispatch.h | 2 +- lib/kernels/include/kernels/linear_kernels.h | 2 +- .../include/local-execution/device_specific.h | 9 +- .../include/local-execution/permissions.h | 4 +- .../src/local_task_argument_accessor.cc | 4 +- lib/local-execution/src/ops/batch_matmul.h | 2 +- lib/local-execution/src/ops/batch_norm.h | 2 +- lib/local-execution/src/ops/cast.h | 2 +- lib/local-execution/src/ops/combine.h | 2 +- lib/local-execution/src/ops/concat.h | 2 +- lib/local-execution/src/ops/conv_2d.h | 2 +- lib/local-execution/src/ops/dropout.h | 2 +- lib/local-execution/src/ops/element_binary.h | 2 +- lib/local-execution/src/ops/element_unary.h | 2 +- lib/local-execution/src/ops/embedding.h | 2 +- lib/local-execution/src/ops/flat.h | 2 +- lib/local-execution/src/ops/gather.h | 2 +- lib/local-execution/src/ops/input.h | 2 +- lib/local-execution/src/ops/layer_norm.h | 2 +- lib/local-execution/src/ops/linear.h | 2 +- lib/local-execution/src/ops/noop.h | 4 +- lib/local-execution/src/ops/pool_2d.h | 2 +- lib/local-execution/src/ops/reduce.h | 2 +- lib/local-execution/src/ops/reduction.h | 2 +- lib/local-execution/src/ops/repartition.h | 2 +- lib/local-execution/src/ops/replicate.h | 2 +- lib/local-execution/src/ops/reshape.h | 2 +- lib/local-execution/src/ops/reverse.h | 2 +- lib/local-execution/src/ops/softmax.h | 2 +- lib/local-execution/src/ops/split.h | 2 +- lib/local-execution/src/ops/topk.h | 2 +- lib/local-execution/src/ops/transpose.h | 2 +- lib/local-execution/src/permissions.cc | 3 +- .../include/op-attrs/operator_attrs.h | 50 --- lib/op-attrs/src/operator_attrs.cc | 287 -------------- lib/pcg/include/pcg/computation_graph.h | 5 + .../v1/v1_binary_sp_decomposition/json.h | 29 ++ .../v1_binary_parallel_split.struct.toml | 25 ++ .../v1_binary_series_split.struct.toml | 25 ++ .../v1_binary_sp_decomposition.variant.toml | 24 ++ .../pcg/initializer_attrs.variant.toml | 2 +- .../parallel_computation_graph.h | 19 + .../parallel_layer_attrs.struct.toml | 2 +- lib/pcg/src/pcg/computation_graph.cc | 22 ++ lib/pcg/src/pcg/computation_graph_builder.cc | 11 +- .../v1/v1_binary_sp_decomposition/json.cc | 84 ++++ .../parallel_computation_graph.cc | 71 ++++ .../parallel_computation_graph_builder.cc | 20 +- .../v1/v1_binary_sp_decomposition/json.cc | 178 +++++++++ .../parallel_computation_graph.cc | 50 +++ .../parallel_computation_graph_builder.cc | 1 + lib/runtime/test/src/test_serialization.cc | 1 - .../perform_shape_inference.h | 1 + .../sub_parallel_computation_graph.cc | 8 +- .../perform_shape_inference.cc | 1 + .../perform_shape_inference.cc | 3 + .../utils/any_value_type/any_value_type.h | 66 ++++ .../include/utils/archetypes/value_type.h | 48 +++ lib/utils/include/utils/containers.h | 28 -- .../utils/containers/cartesian_product.h | 39 ++ lib/utils/include/utils/containers/flatmap.h | 19 + .../utils/containers/get_all_assignments.h | 43 ++ lib/utils/include/utils/containers/get_only.h | 4 +- .../include/utils/containers/merge_maps.h | 9 +- .../include/utils/containers/transform.h | 27 +- lib/utils/include/utils/containers/try_at.h | 31 ++ .../containers/unordered_map_from_pairs.h | 17 + lib/utils/include/utils/containers/values.h | 8 +- lib/utils/include/utils/exception.h | 7 +- lib/utils/include/utils/fmt/json.h | 21 + lib/utils/include/utils/fmt/monostate.h | 32 ++ .../utils/full_binary_tree/binary_tree_path.h | 17 + .../binary_tree_path.struct.toml | 24 ++ .../binary_tree_path_entry.enum.toml | 16 + .../full_binary_tree/find_paths_to_leaf.h | 49 +++ ...ull_binary_tree_implementation.struct.toml | 33 ++ .../full_binary_tree_node_type.enum.toml | 16 + .../full_binary_tree_visitor.struct.toml | 22 ++ .../full_binary_tree/get_all_leaf_paths.h | 43 ++ .../utils/full_binary_tree/get_child.h | 28 ++ .../utils/full_binary_tree/get_leaves.h | 33 ++ .../full_binary_tree/get_num_tree_nodes.h | 27 ++ .../full_binary_tree/get_subtree_at_path.h | 36 ++ .../include/utils/full_binary_tree/visit.h | 23 ++ .../get_dataflow_edges_from_node_to_node.h | 13 + ...nsitive_reduced_boundary_nodes_for_split.h | 15 + ...et_transitive_reduced_edges_across_split.h | 14 + ..._transitive_reduced_outputs_across_split.h | 14 + .../split_boundary_nodes.struct.toml | 25 ++ .../transitive_reduced_dataflow_graph.h | 13 + ...nsitive_reduced_dataflow_graph.struct.toml | 17 + .../get_edges_from_subgraph_to_subgraph.h | 14 + ...zy_copy_of_labelled_dataflow_graph_view.h} | 14 +- .../algorithms/rewrite_node_labels.h | 21 + .../binary_parallel_split.struct.toml | 25 ++ .../binary_series_split.struct.toml | 25 ++ .../binary_sp_decomposition_tree.h | 17 +- .../binary_sp_decomposition_tree.struct.toml | 22 -- .../binary_sp_decomposition_tree.variant.toml | 25 ++ .../find_paths_to_leaf.h | 25 ++ .../fmt.h | 63 --- .../generic_binary_sp_decomposition_tree.h | 155 -------- ...ary_sp_decomposition_tree_implementation.h | 73 ++++ ...omposition_tree_implementation.struct.toml | 47 +++ ..._sp_decomposition_tree_visitor.struct.toml | 27 ++ .../get.h | 15 - .../get_all_leaf_paths.h | 25 ++ .../get_leaves.h | 39 +- .../get_left_child.h | 44 --- .../get_node_type.h | 29 -- .../get_num_tree_nodes.h | 39 +- .../get_right_child.h | 44 --- .../get_subtree_at_path.h | 26 ++ .../hash.h | 34 -- .../generic_binary_sp_decomposition_tree/is.h | 25 -- .../is_binary_sp_tree_left_associative.h | 54 ++- .../is_binary_sp_tree_right_associative.h | 53 ++- .../json.h | 103 ----- .../make.h | 39 -- .../require.h | 28 -- .../transform.h | 43 -- .../visit.h | 59 +-- .../parallel_split.struct.toml | 32 ++ .../series_parallel/series_parallel_splits.h | 136 ++++--- .../series_parallel/series_split.struct.toml | 31 ++ .../utils/json/check_is_json_deserializable.h | 14 + .../utils/json/check_is_json_serializable.h | 14 + lib/utils/include/utils/sequence.h | 2 +- lib/utils/include/utils/tuple.h | 4 +- .../utils/any_value_type/any_value_type.cc | 34 ++ lib/utils/src/utils/archetypes/value_type.cc | 7 + .../src/utils/containers/cartesian_product.cc | 1 + .../utils/containers/get_all_assignments.cc | 1 + lib/utils/src/utils/containers/try_at.cc | 1 + .../containers/unordered_map_from_pairs.cc | 1 + lib/utils/src/utils/exception.cc | 8 + lib/utils/src/utils/fmt/json.cc | 7 + lib/utils/src/utils/fmt/monostate.cc | 9 + .../full_binary_tree/binary_tree_path.cc | 34 ++ .../full_binary_tree/find_paths_to_leaf.cc | 15 + .../full_binary_tree/get_all_leaf_paths.cc | 12 + .../src/utils/full_binary_tree/get_child.cc | 15 + .../src/utils/full_binary_tree/get_leaves.cc | 14 + .../full_binary_tree/get_num_tree_nodes.cc | 13 + .../full_binary_tree/get_subtree_at_path.cc | 15 + lib/utils/src/utils/full_binary_tree/visit.cc | 9 + .../get_dataflow_edges_from_node_to_node.cc | 15 + ...sitive_reduced_boundary_nodes_for_split.cc | 24 ++ ...t_transitive_reduced_edges_across_split.cc | 27 ++ ...transitive_reduced_outputs_across_split.cc | 14 + .../transitive_reduced_dataflow_graph.cc | 17 + .../get_edges_from_subgraph_to_subgraph.cc | 25 ++ .../graph/instances/adjacency_multidigraph.cc | 1 + .../instances/hashmap_undirected_graph.cc | 8 +- ...zy_copy_of_labelled_dataflow_graph_view.cc | 1 + .../algorithms/rewrite_node_labels.cc | 1 + .../algorithms/get_subgraph_inputs.cc | 1 + .../binary_sp_decomposition_tree.cc | 88 +++-- .../find_paths_to_leaf.cc | 19 + .../fmt.cc | 1 - .../generic_binary_sp_decomposition_tree.cc | 1 - ...ry_sp_decomposition_tree_implementation.cc | 18 + .../get.cc | 1 - .../get_all_leaf_paths.cc | 18 + .../get_leaves.cc | 17 + .../get_left_child.cc | 1 - .../get_node_type.cc | 1 - .../get_num_tree_nodes.cc | 17 + .../get_right_child.cc | 1 - .../get_subtree_at_path.cc | 19 + .../hash.cc | 1 - .../is.cc | 1 - .../is_binary_sp_tree_left_associative.cc | 17 + .../is_binary_sp_tree_right_associative.cc | 17 + .../json.cc | 1 - .../make.cc | 1 - .../require.cc | 1 - .../transform.cc | 1 - .../visit.cc | 23 ++ ...ft_associative_binary_sp_tree_from_nary.cc | 65 ++-- ...ht_associative_binary_sp_tree_from_nary.cc | 59 +-- .../get_series_parallel_decomposition.cc | 24 +- .../intermediate_sp_decomposition_tree.cc | 55 ++- .../series_parallel_decomposition.cc | 2 +- .../series_parallel/series_parallel_splits.cc | 168 ++++---- .../json/check_is_json_deserializable.cc | 1 + .../utils/json/check_is_json_serializable.cc | 1 + .../src/utils/containers/cartesian_product.cc | 62 +++ .../test/src/utils/containers/flatmap.cc | 105 +++++ .../utils/containers/get_all_assignments.cc | 53 +++ lib/utils/test/src/utils/containers/try_at.cc | 29 ++ .../containers/unordered_map_from_pairs.cc | 57 +++ .../get_dataflow_edges_from_node_to_node.cc | 104 +++++ ...sitive_reduced_boundary_nodes_for_split.cc | 55 +++ ...t_transitive_reduced_edges_across_split.cc | 146 +++++++ ...transitive_reduced_outputs_across_split.cc | 52 +++ .../get_edges_from_subgraph_to_subgraph.cc | 142 +++++++ .../fmt.cc | 51 --- .../get_leaves.cc | 110 ++++-- .../get_left_child.cc | 41 -- .../get_num_tree_nodes.cc | 83 ++-- .../get_right_child.cc | 41 -- .../hash.cc | 117 ------ .../is_binary_sp_tree_left_associative.cc | 79 ++-- .../is_binary_sp_tree_right_associative.cc | 79 ++-- .../json.cc | 131 ------- .../transform.cc | 28 -- ...ft_associative_binary_sp_tree_from_nary.cc | 33 +- .../nary_sp_tree_from_binary.cc | 68 ++-- ...ht_associative_binary_sp_tree_from_nary.cc | 33 +- .../get_series_parallel_decomposition.cc | 48 +-- 318 files changed, 8579 insertions(+), 3627 deletions(-) delete mode 100644 lib/compiler/include/compiler/cost_estimate.h create mode 100644 lib/compiler/include/compiler/cost_estimator/cost_estimator.h create mode 100644 lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml create mode 100644 lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml create mode 100644 lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml create mode 100644 lib/compiler/include/compiler/graph_optimize_result.struct.toml create mode 100644 lib/compiler/include/compiler/graph_optimize_state.h delete mode 100644 lib/compiler/include/compiler/graph_utils.h delete mode 100644 lib/compiler/include/compiler/machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h create mode 100644 lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h create mode 100644 lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping.h rename lib/compiler/include/compiler/{ => machine_mapping}/machine_mapping.struct.toml (50%) create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h create mode 100644 lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml create mode 100644 lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h create mode 100644 lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml delete mode 100644 lib/compiler/include/compiler/optimal_cost_state.struct.toml create mode 100644 lib/compiler/include/compiler/optimizer_config.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml rename lib/compiler/include/compiler/series_parallel/{ => computation_graph}/computation_graph_binary_sp_decomposition.h (52%) create mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml rename lib/compiler/include/compiler/series_parallel/{ => computation_graph}/get_computation_graph_series_parallel_decomposition.h (100%) delete mode 100644 lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h create mode 100644 lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml create mode 100644 lib/compiler/src/compiler/cost_estimator/cost_estimator.cc create mode 100644 lib/compiler/src/compiler/graph_optimize_state.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc create mode 100644 lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc create mode 100644 lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc rename lib/compiler/src/compiler/series_parallel/{ => computation_graph}/get_computation_graph_series_parallel_decomposition.cc (97%) delete mode 100644 lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc create mode 100644 lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc delete mode 100644 lib/compiler/src/graph_utils.cc delete mode 100644 lib/compiler/src/machine_mapping.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h create mode 100644 lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc create mode 100644 lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc rename lib/compiler/test/src/compiler/series_parallel/{ => computation_graph}/get_computation_graph_series_parallel_decomposition.cc (96%) create mode 100644 lib/compiler/test/src/graph_optimize_state.cc delete mode 100644 lib/compiler/test/src/test_cost_estimator.h delete mode 100644 lib/compiler/test/src/test_generator.h delete mode 100644 lib/compiler/test/src/test_labelled_open_graph.cc delete mode 100644 lib/compiler/test/src/test_machine_mapping.cc delete mode 100644 lib/compiler/test/src/test_open_graph.cc delete mode 100644 lib/compiler/test/src/test_optimal_cost.cc rename lib/compiler/test/src/{test_unity_algorithm.cc => unity_algorithm.cc} (93%) delete mode 100644 lib/op-attrs/include/op-attrs/operator_attrs.h delete mode 100644 lib/op-attrs/src/operator_attrs.cc create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml create mode 100644 lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml create mode 100644 lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc create mode 100644 lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc create mode 100644 lib/utils/include/utils/any_value_type/any_value_type.h create mode 100644 lib/utils/include/utils/archetypes/value_type.h create mode 100644 lib/utils/include/utils/containers/cartesian_product.h create mode 100644 lib/utils/include/utils/containers/get_all_assignments.h create mode 100644 lib/utils/include/utils/containers/try_at.h create mode 100644 lib/utils/include/utils/containers/unordered_map_from_pairs.h create mode 100644 lib/utils/include/utils/fmt/json.h create mode 100644 lib/utils/include/utils/fmt/monostate.h create mode 100644 lib/utils/include/utils/full_binary_tree/binary_tree_path.h create mode 100644 lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml create mode 100644 lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml create mode 100644 lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml create mode 100644 lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_child.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_leaves.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h create mode 100644 lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h create mode 100644 lib/utils/include/utils/full_binary_tree/visit.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h create mode 100644 lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml create mode 100644 lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h rename lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/{create_lazy_copy_of_labelled_open_dataflow_graph_view.h => create_lazy_copy_of_labelled_dataflow_graph_view.h} (88%) create mode 100644 lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h create mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h delete mode 100644 lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h create mode 100644 lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml create mode 100644 lib/utils/include/utils/graph/series_parallel/series_split.struct.toml create mode 100644 lib/utils/include/utils/json/check_is_json_deserializable.h create mode 100644 lib/utils/include/utils/json/check_is_json_serializable.h create mode 100644 lib/utils/src/utils/any_value_type/any_value_type.cc create mode 100644 lib/utils/src/utils/archetypes/value_type.cc create mode 100644 lib/utils/src/utils/containers/cartesian_product.cc create mode 100644 lib/utils/src/utils/containers/get_all_assignments.cc create mode 100644 lib/utils/src/utils/containers/try_at.cc create mode 100644 lib/utils/src/utils/containers/unordered_map_from_pairs.cc create mode 100644 lib/utils/src/utils/fmt/json.cc create mode 100644 lib/utils/src/utils/fmt/monostate.cc create mode 100644 lib/utils/src/utils/full_binary_tree/binary_tree_path.cc create mode 100644 lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_child.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_leaves.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc create mode 100644 lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc create mode 100644 lib/utils/src/utils/full_binary_tree/visit.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc create mode 100644 lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc create mode 100644 lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc create mode 100644 lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc create mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc delete mode 100644 lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc create mode 100644 lib/utils/src/utils/json/check_is_json_deserializable.cc create mode 100644 lib/utils/src/utils/json/check_is_json_serializable.cc create mode 100644 lib/utils/test/src/utils/containers/cartesian_product.cc create mode 100644 lib/utils/test/src/utils/containers/flatmap.cc create mode 100644 lib/utils/test/src/utils/containers/get_all_assignments.cc create mode 100644 lib/utils/test/src/utils/containers/try_at.cc create mode 100644 lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc create mode 100644 lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc create mode 100644 lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc delete mode 100644 lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index f976d369d5..b54ef25819 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -123,9 +123,9 @@ jobs: run: | test_target.sh substitutions - # - name: Test compiler - # run: | - # test_target.sh compiler + - name: Test compiler + run: | + test_target.sh compiler - name: Test substitution-generator run: | diff --git a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml index efaf368bc8..efaf10c255 100644 --- a/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml +++ b/bin/export-model-arch/include/export_model_arch/json_sp_model_export.struct.toml @@ -3,24 +3,22 @@ name = "JsonSPModelExport" features = [ "eq", "hash", - "json", "fmt", + "json", ] includes = [ "pcg/file_format/v1/v1_computation_graph.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", ] src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/json.h", ] [[fields]] name = "sp_decomposition" -type = "::FlexFlow::GenericBinarySPDecompositionTree" +type = "::FlexFlow::V1BinarySPDecomposition" [[fields]] name = "computation_graph" diff --git a/bin/export-model-arch/src/export_model_arch.cc b/bin/export-model-arch/src/export_model_arch.cc index 1c2dfd6ea3..64419acce4 100644 --- a/bin/export-model-arch/src/export_model_arch.cc +++ b/bin/export-model-arch/src/export_model_arch.cc @@ -1,5 +1,5 @@ -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "export_model_arch/json_sp_model_export.dtg.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" @@ -13,7 +13,6 @@ #include "utils/cli/cli_parse.h" #include "utils/cli/cli_parse_result.h" #include "utils/cli/cli_spec.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" #include "utils/graph/series_parallel/get_series_parallel_decomposition.h" @@ -105,9 +104,8 @@ tl::expected to_v1_including_node_numbering(computation_graph); V1ComputationGraph v1_cg = v1_result.first; bidict layer_numbering = v1_result.second; - GenericBinarySPDecompositionTree v1_sp_decomposition = - transform(sp_decomposition.raw_tree, - [&](layer_guid_t const &l) { return layer_numbering.at_r(l); }); + V1BinarySPDecomposition v1_sp_decomposition = + to_v1(sp_decomposition, layer_numbering); return JsonSPModelExport{ v1_sp_decomposition, diff --git a/flake.lock b/flake.lock index 1aad68ae29..87fae7f446 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1722923482, - "narHash": "sha256-myUec+oBcnKNCqLQqSiPCyXFsIsvlrsGoj/mQFlHVrY=", + "lastModified": 1728341842, + "narHash": "sha256-XMS52KBSS6z3k2VaiVcHyZQD6b2QUm1wIvTClel4xwg=", "owner": "lockshaw", "repo": "proj", - "rev": "c650b0e52337652ea7190131988c0370e0ee7f25", + "rev": "830fb5b1a0c7087752693990e90bbbf021168dfe", "type": "github" }, "original": { diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h deleted file mode 100644 index 2e4ff8448b..0000000000 --- a/lib/compiler/include/compiler/cost_estimate.h +++ /dev/null @@ -1,61 +0,0 @@ - -#ifndef _FLEXFLOW_COMPILER_COST_ESTIMATE_H -#define _FLEXFLOW_COMPILER_COST_ESTIMATE_H - -#include "op-attrs/operator_attrs.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" - -namespace FlexFlow { - -struct ICostEstimator { - virtual float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const = 0; - virtual float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const = 0; - - ICostEstimator() = default; - ICostEstimator(ICostEstimator const &) = delete; - ICostEstimator &operator=(ICostEstimator const &) = delete; - - virtual ~ICostEstimator() = default; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); - -struct CostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const { - return this->implementation_ptr->estimate_cost( - op, inputs, weights, outputs, mv); - } - - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const { - return this->implementation_ptr->estimate_cost(tensor_shape, src, dst); - } - - template - static typename std::enable_if::value, - CostEstimator>::type - create(Args &&...args) { - return CostEstimator(std::make_shared(std::forward(args)...)); - } - -private: - CostEstimator(std::shared_ptr implementation_ptr) - : implementation_ptr(implementation_ptr) {} - std::shared_ptr implementation_ptr; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h new file mode 100644 index 0000000000..65bae0c76a --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h @@ -0,0 +1,45 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H + +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "pcg/machine_view.dtg.h" +#include + +namespace FlexFlow { + +struct ICostEstimator { + virtual float estimate_cost(OpCostEstimateKey const &) const = 0; + virtual float estimate_cost(TensorSetMovement const &) const = 0; + + ICostEstimator() = default; + ICostEstimator(ICostEstimator const &) = delete; + ICostEstimator &operator=(ICostEstimator const &) = delete; + + virtual ~ICostEstimator() = default; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); + +struct CostEstimator { + float estimate_cost(OpCostEstimateKey const &k) const; + float estimate_cost(TensorSetMovement const &m) const; + + template + static typename std::enable_if::value, + CostEstimator>::type + create(Args &&...args) { + return CostEstimator(std::make_shared(std::forward(args)...)); + } + +private: + CostEstimator(std::shared_ptr implementation_ptr); + +private: + std::shared_ptr implementation_ptr; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..8fd860d00d --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_estimate_key.struct.toml @@ -0,0 +1,40 @@ +namespace = "FlexFlow" +name = "OpCostEstimateKey" +features = [ + "eq", + "ord", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "machine_view" +type = "::FlexFlow::MachineView" diff --git a/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml new file mode 100644 index 0000000000..70f73ebe51 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "SingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "pcg/machine_view.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::MachineView>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml new file mode 100644 index 0000000000..3625605239 --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "TensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/cost_estimator/single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::SingleTensorMovement>" diff --git a/lib/compiler/include/compiler/graph_optimize_result.struct.toml b/lib/compiler/include/compiler/graph_optimize_result.struct.toml new file mode 100644 index 0000000000..22f29cbd59 --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_result.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "GraphOptimizeResult" +features = [ ] + +includes = [ + "compiler/machine_mapping/machine_mapping.dtg.h", + "pcg/parallel_computation_graph/parallel_computation_graph.h" +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" diff --git a/lib/compiler/include/compiler/graph_optimize_state.h b/lib/compiler/include/compiler/graph_optimize_state.h new file mode 100644 index 0000000000..2de2321ba6 --- /dev/null +++ b/lib/compiler/include/compiler/graph_optimize_state.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_STATE_H +#define _FLEXFLOW_COMPILER_MCMC_STATE_H + +#include "compiler/graph_optimize_result.dtg.h" + +namespace FlexFlow { + +struct GraphOptimizeState { + GraphOptimizeState(GraphOptimizeResult const &graph_optimize_result, + float runtime); + + GraphOptimizeResult graph_optimize_result; + float runtime; + + bool operator==(GraphOptimizeState const &other) const; + bool operator!=(GraphOptimizeState const &other) const; + bool operator<(GraphOptimizeState const &other) const; +}; + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::GraphOptimizeState> { + size_t operator()(::FlexFlow::GraphOptimizeState const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/compiler/include/compiler/graph_utils.h b/lib/compiler/include/compiler/graph_utils.h deleted file mode 100644 index 75fd369434..0000000000 --- a/lib/compiler/include/compiler/graph_utils.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_GRAPH_UTILS_H -#define _FLEXFLOW_COMPILER_GRAPH_UTILS_H - -#include "compiler/unity_algorithm.h" -#include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" - -namespace FlexFlow { - -SeriesParallelDecomposition - get_series_parallel_decomposition(ParallelComputationGraph const &pcg); - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); - -// NOTE(@wmdi): I think we should have the following interfaces in the graph -// library eventually. - -template -void minimize(T &t, T const &v) { - if (v < t) { - t = v; - } -} - -template -void minimize(T &t, T const &v, Compare comp) { - if (comp(v, t)) { - t = v; - } -} - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h deleted file mode 100644 index 3774f2cd52..0000000000 --- a/lib/compiler/include/compiler/machine_mapping.h +++ /dev/null @@ -1,70 +0,0 @@ -#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H -#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H - -#include "compiler/machine_mapping.dtg.h" -#include "compiler/optimal_cost_state.dtg.h" -#include "cost_estimate.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" -#include "utils/visitable.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &, MachineMapping const &); - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); - -struct OptimalCostResult { - static OptimalCostResult sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2); - static OptimalCostResult infinity(); - - float runtime; - req machine_mapping; -}; -FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); - -struct OptimalCostRuntimeCmp { - bool operator()(OptimalCostResult const &, OptimalCostResult const &); -}; - -class OptimalCostCache { -public: - OptimalCostCache() = default; - - std::optional load(OptimalCostState const &) const; - void save(OptimalCostState const &, OptimalCostResult const &); - -private: - std::unordered_map cache; -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs); - -} // namespace FlexFlow - -// namespace std { -// -// template <> -// struct hash> { -// size_t operator()( -// std::unordered_map const &g) -// const; -// }; - -// }; // namespace std - -#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml new file mode 100644 index 0000000000..449a448706 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.struct.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "AbstractedSingleTensorMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "op-attrs/parallel_tensor_shape.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "parallel_tensor_shape" +type = "::FlexFlow::ParallelTensorShape" + +[[fields]] +name = "src_machine_views" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" + +[[fields]] +name = "dst_machine_views" +type = "std::unordered_set<::FlexFlow::BinaryTreePath>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h new file mode 100644 index 0000000000..5b7e2f3613 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ABSTRACTED_TENSOR_SET_MOVEMENT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement(); + +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &); +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &); + +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &, + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml new file mode 100644 index 0000000000..4cf184706b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "AbstractedTensorSetMovement" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "single_tensor_movements" +type = "std::unordered_multiset<::FlexFlow::AbstractedSingleTensorMovement>" diff --git a/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..8567a7a3e6 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_ABSTRACTED_TENSOR_SET_MOVEMENT_ACROSS_SPLIT_H + +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml new file mode 100644 index 0000000000..e71cfc540f --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/feasible_machine_mapping_result.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "FeasibleMachineMappingResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", +] + +[[fields]] +name = "runtime" +type = "float" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h new file mode 100644 index 0000000000..990c1c8205 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_machine_resource_splits.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_RESOURCE_SPLITS_H + +#include "pcg/machine_specification.dtg.h" +#include +#include + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h new file mode 100644 index 0000000000..62da90bfcb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_optimal_machine_mapping.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_GET_OPTIMAL_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation); + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const ¶llel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h new file mode 100644 index 0000000000..2aed9a20e4 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/get_tensor_set_movement_across_split.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_ESTIMATE_COST_ACROSS_SPLIT_H + +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &transitive_reduced_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml new file mode 100644 index 0000000000..b9a7f9ac59 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/include_unconstrained.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "IncludeUnconstrained" +features = [ + "eq", + "ord", + "hash", + "fmt", + "rapidcheck", + "json", +] + +includes = [] + +[[fields]] +name = "raw_bool" +type = "bool" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h new file mode 100644 index 0000000000..06cbbf942d --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/machine_mapping.dtg.h" + +namespace FlexFlow { + +MachineMapping combine_disjoint_mappings(MachineMapping const &, + MachineMapping const &); + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml similarity index 50% rename from lib/compiler/include/compiler/machine_mapping.struct.toml rename to lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml index 4c4912a3fd..92517c1110 100644 --- a/lib/compiler/include/compiler/machine_mapping.struct.toml +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping.struct.toml @@ -9,13 +9,16 @@ features = [ "fmt", ] -includes = [ - "utils/graph/node/node.dtg.h", +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", "pcg/machine_view.dtg.h", +] + +src_includes = [ "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/fmt/unordered_map.h", ] [[fields]] name = "machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" \ No newline at end of file +type = "std::unordered_map<::FlexFlow::parallel_layer_guid_t, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h new file mode 100644 index 0000000000..3a0fcf0e15 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CACHE_H + +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" + +namespace FlexFlow { + +MachineMappingCache empty_machine_mapping_cache(); +std::optional + machine_mapping_cache_load(MachineMappingCache const &, + MachineMappingState const &); +void machine_mapping_cache_save(MachineMappingCache &, + MachineMappingState const &, + MachineMappingResult const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml new file mode 100644 index 0000000000..a76ff26eb9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_cache.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "MachineMappingCache" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/machine_mapping_state.dtg.h", + "compiler/machine_mapping/machine_mapping_result.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_map" +type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h new file mode 100644 index 0000000000..d314ab493b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_CONSTRAINTS_H + +#include "compiler/machine_mapping/include_unconstrained.dtg.h" +#include "compiler/machine_mapping/machine_mapping.dtg.h" +#include "compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &); + +std::unordered_set + get_all_layers(MachineMappingConstraints const &, + IncludeUnconstrained const &); + +std::optional + get_machine_view_for_layer(MachineMappingConstraints const &, + BinaryTreePath const &); + +MachineMappingConstraints restrict_to_child(MachineMappingConstraints const &, + BinaryTreePathEntry const &); +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &); +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &); + +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &, + ParallelLayerGuidObliviousMachineMapping const &); + +std::optional require_only_root(MachineMappingConstraints const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml new file mode 100644 index 0000000000..8e13abedb9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_constraints.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "MachineMappingConstraints" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", + "utils/fmt/optional.h", +] + +[[fields]] +name = "machine_views" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, std::optional<::FlexFlow::MachineView>>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml new file mode 100644 index 0000000000..81e26f491d --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_context.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "MachineMappingContext" +features = [] + +includes = [ + "compiler/cost_estimator/cost_estimator.h", + "pcg/machine_view.dtg.h", + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[fields]] +name = "cost_estimator" +type = "::FlexFlow::CostEstimator" + +[[fields]] +name = "allowed_machine_views" +type = "std::function(::FlexFlow::UnmappedOpCostEstimateKey const &, ::FlexFlow::MachineSpecification const &)>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h new file mode 100644 index 0000000000..68d02aaa54 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_GET_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" + +namespace FlexFlow { + +MachineMappingProblemTree + get_machine_mapping_problem_tree(ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h new file mode 100644 index 0000000000..29e9e7c90b --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_H + +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree(); + +SPDecompositionTreeNodeType get_node_type(MachineMappingProblemTree const &); + +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &); +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &); + +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &, + BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml new file mode 100644 index 0000000000..1949f143cb --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineMappingProblemTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h", +] + +[[values]] +type = "::FlexFlow::MMProblemTreeSeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::MMProblemTreeParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::UnmappedOpCostEstimateKey" +key = "leaf" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml new file mode 100644 index 0000000000..5247b2006a --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "MMProblemTreeParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml new file mode 100644 index 0000000000..d4f61bb3f5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "MMProblemTreeSeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct MachineMappingProblemTree", +] + +post_includes = [ + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +includes = [ + "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h", +] + +[[fields]] +name = "tensor_set_movement" +type = "::FlexFlow::AbstractedTensorSetMovement" + +[[fields]] +name = "left_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::MachineMappingProblemTree" +indirect = true diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h new file mode 100644 index 0000000000..9fbad4a1d0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_PROBLEM_TREE_UNMAPPED_OP_COST_ESTIMATE_KEY_H + +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &, parallel_layer_guid_t const &); + +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml new file mode 100644 index 0000000000..fe76683eb7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.struct.toml @@ -0,0 +1,36 @@ +namespace = "FlexFlow" +name = "UnmappedOpCostEstimateKey" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ + "op-attrs/pcg_operator_attrs.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "", + "pcg/machine_view.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "op_attrs" +type = "::FlexFlow::PCGOperatorAttrs" + +[[fields]] +name = "input_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "weight_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "output_shapes" +type = "std::vector<::FlexFlow::ParallelTensorShape>" + diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h new file mode 100644 index 0000000000..b21fea5f24 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MACHINE_MAPPING_RESULT_H + +#include "compiler/machine_mapping/machine_mapping_result.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" + +namespace FlexFlow { + +[[nodiscard]] MachineMappingResult infeasible_machine_mapping_result(); +[[nodiscard]] bool is_infeasible(MachineMappingResult const &); +FeasibleMachineMappingResult require_feasible(MachineMappingResult const &); + +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &); + +[[nodiscard]] MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &pre_result, + MachineMappingResult const &post_result, + std::optional const + ¶llel_split_transformation); +[[nodiscard]] MachineMappingResult + parallel_combine(MachineMappingResult const &lhs_result, + MachineMappingResult const &rhs_result); + +[[nodiscard]] MachineMappingResult + minimize_runtime(MachineMappingResult const &m1, + MachineMappingResult const &m2); + +[[nodiscard]] MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml new file mode 100644 index 0000000000..92a2873af5 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_result.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "MachineMappingResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/feasible_machine_mapping_result.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/optional.h", +] + +[[fields]] +name = "raw_result" +type = "std::optional<::FlexFlow::FeasibleMachineMappingResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml new file mode 100644 index 0000000000..1346f6ebe7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_state.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MachineMappingState" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_specification.dtg.h", + "compiler/machine_mapping/machine_mapping_constraints.dtg.h", + "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h", +] + +[[fields]] +name = "problem_tree" +type = "::FlexFlow::MachineMappingProblemTree" + +[[fields]] +name = "resources" +type = "::FlexFlow::MachineSpecification" + +[[fields]] +name = "constraints" +type = "::FlexFlow::MachineMappingConstraints" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h new file mode 100644 index 0000000000..accd96af4c --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H + +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &pre, + ParallelLayerGuidObliviousMachineMapping const &post); + +ParallelLayerGuidObliviousMachineMapping + restrict_to_left_child(ParallelLayerGuidObliviousMachineMapping const &); +ParallelLayerGuidObliviousMachineMapping + restrict_to_right_child(ParallelLayerGuidObliviousMachineMapping const &); + +std::optional + get_machine_view_for_path(ParallelLayerGuidObliviousMachineMapping const &, + BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml new file mode 100644 index 0000000000..f00fcc8490 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.struct.toml @@ -0,0 +1,21 @@ +namespace = "FlexFlow" +name = "ParallelLayerGuidObliviousMachineMapping" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/machine_view.dtg.h", + "utils/full_binary_tree/binary_tree_path.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_mapping" +type = "std::unordered_map<::FlexFlow::BinaryTreePath, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml new file mode 100644 index 0000000000..8247c0cbdc --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/parallel_split_transformation.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "ParallelSplitTransformation" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LthenR" + +[[values]] +name = "RthenL" diff --git a/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml new file mode 100644 index 0000000000..155e526672 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/pcg_split_boundary_layers.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "PCGSplitBoundaryLayers" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::parallel_layer_guid_t>" diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h new file mode 100644 index 0000000000..2b2bc9bf84 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_TRANSITIVE_REDUCED_PCG_H + +#include "compiler/machine_mapping/pcg_split_boundary_layers.dtg.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &); + +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &); + +std::unordered_set + pcg_get_transitive_reduced_edges_across_split(TransitiveReducedPCG const &, + PCGBinarySeriesSplit const &); + +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); + +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &, PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml new file mode 100644 index 0000000000..bb76ec2ff7 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/transitive_reduced_pcg.struct.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "TransitiveReducedPCG" +features = [] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h", +] + +[[fields]] +name = "full_pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/compiler/include/compiler/optimal_cost_state.struct.toml b/lib/compiler/include/compiler/optimal_cost_state.struct.toml deleted file mode 100644 index 036647c0b1..0000000000 --- a/lib/compiler/include/compiler/optimal_cost_state.struct.toml +++ /dev/null @@ -1,36 +0,0 @@ -namespace = "FlexFlow" -name = "OptimalCostState" -features = [ - "eq", - # "ord", - "hash", - # "json", - # "rapidcheck", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/series_parallel_decomposition.dtg.h", - "pcg/machine_specification.dtg.h", - "pcg/machine_view.dtg.h", - "utils/graph/node/node.dtg.h", - "utils/graph/open_dataflow_graph/open_dataflow_edge.dtg.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h", -] - -[[fields]] -name = "subgraph" -type = "::FlexFlow::SeriesParallelDecomposition" - -[[fields]] -name = "resource" -type = "::FlexFlow::MachineSpecification" - -[[fields]] -name = "given_machine_views" -type = "std::unordered_map<::FlexFlow::Node, ::FlexFlow::MachineView>" - -[[fields]] -name = "frontier_machine_views" -type = "std::unordered_map<::FlexFlow::OpenDataflowEdge, ::FlexFlow::MachineView>" diff --git a/lib/compiler/include/compiler/optimizer_config.struct.toml b/lib/compiler/include/compiler/optimizer_config.struct.toml new file mode 100644 index 0000000000..b7f4f71e9c --- /dev/null +++ b/lib/compiler/include/compiler/optimizer_config.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "OptimizerConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ +] + +[[fields]] +name = "alpha" +type = "float" + +[[fields]] +name = "budget" +type = "int" + +[[fields]] +name = "threshold" +type = "float" + +[[fields]] +name = "max_num_ops" +type = "int" diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..9654a2546e --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml new file mode 100644 index 0000000000..aa66c80b43 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ComputationGraphBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::ComputationGraphBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h similarity index 52% rename from lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h rename to lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h index 3032e3efe9..fdc80a1e37 100644 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.h +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h @@ -1,19 +1,30 @@ #ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_COMPUTATION_GRAPH_BINARY_SP_DECOMPOSITION_H -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.dtg.h" #include "pcg/computation_graph.dtg.h" +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include "utils/overload.h" namespace FlexFlow { +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree(); + SPDecompositionTreeNodeType get_node_type(ComputationGraphBinarySPDecomposition const &); + ComputationGraphBinarySPDecomposition - get_left_child(ComputationGraphBinarySPDecomposition const &); -ComputationGraphBinarySPDecomposition - get_right_child(ComputationGraphBinarySPDecomposition const &); -layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &); + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &); + std::optional get_computation_graph_left_assoc_binary_sp_decomposition( ComputationGraph const &); @@ -25,6 +36,9 @@ bool is_right_associative(ComputationGraphBinarySPDecomposition const &); std::unordered_multiset get_layers(ComputationGraphBinarySPDecomposition const &); +V1BinarySPDecomposition to_v1(ComputationGraphBinarySPDecomposition const &, + bidict const &layer_numbering); + } // namespace FlexFlow #endif diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..452470620b --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "ComputationGraphBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/layer_guid_t.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_series_split.dtg.h", + "compiler/series_parallel/computation_graph/computation_graph_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::ComputationGraphBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::ComputationGraphBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h similarity index 100% rename from lib/compiler/include/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h rename to lib/compiler/include/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h diff --git a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml b/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml deleted file mode 100644 index 147b1e3acf..0000000000 --- a/lib/compiler/include/compiler/series_parallel/computation_graph_binary_sp_decomposition.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "ComputationGraphBinarySPDecomposition" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "pcg/layer_guid_t.dtg.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::layer_guid_t>" diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h new file mode 100644 index 0000000000..d43edaa79d --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_balanced_binary_sp_decomposition.h @@ -0,0 +1,11 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_BALANCED_BINARY_SP_DECOMPOSITION_H + +namespace FlexFlow { + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h new file mode 100644 index 0000000000..d4ae77541a --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_GET_PCG_SERIES_PARALLEL_DECOMPOSITION_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h new file mode 100644 index 0000000000..f348b1a851 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_PARALLEL_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..f7f7026716 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_parallel_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h new file mode 100644 index 0000000000..0842ffb48f --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_PCG_BINARY_SERIES_SPLIT_H + +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +BinarySeriesSplit + binary_series_split_from_pcg_series_split(PCGBinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml new file mode 100644 index 0000000000..af2c8c4dae --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_series_split.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "PCGBinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct PCGBinarySPDecomposition", +] + +post_includes = [ + "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h", +] + +includes = [] + +[[fields]] +name = "left_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::PCGBinarySPDecomposition" +indirect = true diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h new file mode 100644 index 0000000000..86fa1a59aa --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_SERIES_PARALLEL_PCG_BINARY_SP_DECOMPOSITION_H + +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" +#include + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree(); + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &); + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &); +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &); + +SPDecompositionTreeNodeType get_node_type(PCGBinarySPDecomposition const &); + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &, + parallel_layer_guid_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..52372fb270 --- /dev/null +++ b/lib/compiler/include/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "PCGBinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_series_split.dtg.h", + "compiler/series_parallel/pcg/pcg_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::PCGBinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::PCGBinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::parallel_layer_guid_t" +key = "leaf" diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index abddef37ed..232f2b9563 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -1,39 +1,17 @@ #ifndef _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H #define _FLEXFLOW_COMPILER_UNITY_ALGORITHM_H -#include "compiler/machine_mapping.h" -#include "cost_estimate.h" -#include "machine_mapping.h" +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/graph_optimize_result.dtg.h" +#include "optimizer_config.dtg.h" #include "pcg/computation_graph.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/sub_parallel_computation_graph.h" -namespace FlexFlow { - -struct Strategy { - ParallelComputationGraph pcg; - MachineMapping machine_mapping; - req runtime; - friend bool operator!=(Strategy const &lhs, Strategy const &rhs) { - return (lhs.machine_mapping != rhs.machine_mapping) || - (lhs.runtime != rhs.runtime); - } -}; - -FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); -struct StrategyRuntimeCmp { - bool operator()(Strategy const &, Strategy const &); -}; - -struct OptimizerConfig { - float alpha; - int budget; - float threshold; - int max_num_ops; -}; +namespace FlexFlow { -Strategy graph_optimize( - ComputationGraph &cg, +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( diff --git a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc new file mode 100644 index 0000000000..051ffcd190 --- /dev/null +++ b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc @@ -0,0 +1,16 @@ +#include "compiler/cost_estimator/cost_estimator.h" + +namespace FlexFlow { + +CostEstimator::CostEstimator(std::shared_ptr implementation_ptr) + : implementation_ptr(implementation_ptr) {} + +float CostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->implementation_ptr->estimate_cost(k); +} + +float CostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->implementation_ptr->estimate_cost(m); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/graph_optimize_state.cc b/lib/compiler/src/compiler/graph_optimize_state.cc new file mode 100644 index 0000000000..4b4f323ea4 --- /dev/null +++ b/lib/compiler/src/compiler/graph_optimize_state.cc @@ -0,0 +1,85 @@ +#include "compiler/graph_optimize_state.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" + +namespace FlexFlow { + +GraphOptimizeState::GraphOptimizeState( + GraphOptimizeResult const &graph_optimize_result, float runtime) + : graph_optimize_result(graph_optimize_result), runtime(runtime) {} + +bool GraphOptimizeState::operator==(GraphOptimizeState const &other) const { + // Note(@wmdi): This is a hack to implement a partially correct homomorphism + // check. Switch to the homomorphism check used in substitutions right after + // https://github.com/flexflow/FlexFlow/pull/1471 is merged. + auto layers1 = topological_ordering(graph_optimize_result.pcg); + auto layers2 = topological_ordering(other.graph_optimize_result.pcg); + if (layers1.size() != layers2.size()) { + return false; + } + std::unordered_map mapping; + for (size_t i = 0; i < layers1.size(); ++i) { + if (get_parallel_layer_attrs(graph_optimize_result.pcg, layers1[i]) != + get_parallel_layer_attrs(other.graph_optimize_result.pcg, layers2[i])) { + return false; + } + auto inputs1 = get_incoming_tensors(graph_optimize_result.pcg, layers1[i]); + auto inputs2 = + get_incoming_tensors(other.graph_optimize_result.pcg, layers2[i]); + if (inputs1.size() != inputs2.size()) { + return false; + } + for (size_t j = 0; j < inputs1.size(); ++j) { + if (inputs1[j] != mapping.at(inputs2[j])) { + return false; + } + } + auto outputs1 = get_layer_outputs(graph_optimize_result.pcg, layers1[i]); + auto outputs2 = + get_layer_outputs(other.graph_optimize_result.pcg, layers2[i]); + if (outputs1.size() != outputs2.size()) { + return false; + } + for (size_t j = 0; j < outputs1.size(); ++j) { + mapping.emplace(outputs2[j], outputs1[j]); + } + } + return true; +} + +bool GraphOptimizeState::operator!=(GraphOptimizeState const &other) const { + return !(*this == other); +} + +bool GraphOptimizeState::operator<(GraphOptimizeState const &other) const { + return runtime < other.runtime; +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::GraphOptimizeState>::operator()( + ::FlexFlow::GraphOptimizeState const &state) const { + // TODO(@wmdi): Eventually it might be good to use a proper graph hash like + // https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash.html#networkx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash + size_t seed = 0; + auto layers = topological_ordering(state.graph_optimize_result.pcg); + ::FlexFlow::hash_combine(seed, layers.size()); + for (auto layer : layers) { + ::FlexFlow::hash_combine( + seed, get_parallel_layer_attrs(state.graph_optimize_result.pcg, layer)); + auto inputs = get_incoming_tensors(state.graph_optimize_result.pcg, layer); + ::FlexFlow::hash_combine(seed, inputs.size()); + for (auto input : inputs) { + for (size_t i = 0; i < layers.size(); ++i) { + if (get_source_layer(input) == layers[i]) { + ::FlexFlow::hash_combine(seed, i); + break; + } + } + } + } + return seed; +} + +} // namespace std diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc new file mode 100644 index 0000000000..6f3deca138 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -0,0 +1,62 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement empty_abstracted_tensor_set_movement() { + return AbstractedTensorSetMovement{{}}; +} + +std::unordered_set + get_src_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.src_machine_views; + }); +} + +std::unordered_set + get_dst_layers(AbstractedTensorSetMovement const &m) { + return flatmap(unordered_set_of(m.single_tensor_movements), + [](AbstractedSingleTensorMovement const &s) { + return s.dst_machine_views; + }); +} + +TensorSetMovement concretize_abstracted_tensor_set_movement( + AbstractedTensorSetMovement const &abstracted, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + ParallelLayerGuidObliviousMachineMapping mapping = + binary_combine_mappings(/*lhs=*/pre_mapping, + /*rhs=*/post_mapping); + + auto concretize_tensor_movement = + [&](AbstractedSingleTensorMovement const &a) { + return SingleTensorMovement{ + /*parallel_tensor_shape=*/a.parallel_tensor_shape, + /*src_machine_views=*/ + transform( + a.src_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(pre_mapping, path).value(); + }), + /*dst_machine_views=*/ + transform( + a.dst_machine_views, + [&](BinaryTreePath const &path) { + return get_machine_view_for_path(post_mapping, path).value(); + }), + }; + }; + + return TensorSetMovement{ + /*single_tensor_movements=*/transform(abstracted.single_tensor_movements, + concretize_tensor_movement), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..0e0f60c891 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,63 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + + std::unordered_set edges_across_split = + pcg_get_transitive_reduced_edges_across_split(tr_pcg, split); + + auto get_movement_for_tensor = + [&](parallel_tensor_guid_t const &t) -> AbstractedSingleTensorMovement { + std::unordered_set tensor_edges = + filter(edges_across_split, [&](ParallelComputationGraphEdge const &e) { + return get_parallel_tensor(e) == t; + }); + + std::unordered_set src_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_src_layer(e); + }); + + std::unordered_set dst_layers = + transform(tensor_edges, [&](ParallelComputationGraphEdge const &e) { + return get_dst_layer(e); + }); + + return AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/get_parallel_tensor_shape(tr_pcg.full_pcg, t), + /*src_machine_views=*/ + transform(src_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(split.get_left_child(), l)); + }), + /*dst_machine_views=*/ + transform(dst_layers, + [&](parallel_layer_guid_t const &l) { + return get_only( + find_paths_to_leaf(split.get_right_child(), l)); + }), + }; + }; + + std::unordered_map + single_tensor_movements = generate_map( + pcg_get_transitive_reduced_tensors_across_split(tr_pcg, split), + get_movement_for_tensor); + + return AbstractedTensorSetMovement{ + values(single_tensor_movements), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc new file mode 100644 index 0000000000..5126d9687e --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,32 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "utils/hash/pair.h" + +namespace FlexFlow { + +std::unordered_set> + get_machine_resource_splits(MachineSpecification const &resource) { + std::unordered_set> + result; + + for (int i = 1; i < resource.num_nodes; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; + sub_resource1.num_nodes = i; + sub_resource2.num_nodes = resource.num_nodes - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); + } + + for (int i = 1; i < resource.num_gpus_per_node; i *= 2) { + MachineSpecification sub_resource1 = resource; + MachineSpecification sub_resource2 = resource; + sub_resource1.num_gpus_per_node = i; + sub_resource2.num_gpus_per_node = resource.num_gpus_per_node - i; + result.insert(std::make_pair(sub_resource1, sub_resource2)); + result.insert(std::make_pair(sub_resource2, sub_resource1)); + } + + return result; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..10abd7ff90 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -0,0 +1,254 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/contains.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/exception.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { + + MachineMappingState state = MachineMappingState{ + problem_tree, + resources, + constraints, + }; + + { + std::optional cached_result = + machine_mapping_cache_load(result_cache, state); + if (cached_result) { + return cached_result.value(); + } + } + + MachineMappingResult result = + problem_tree.visit(overload{ + [&](MMProblemTreeSeriesSplit const &series_split) { + return get_optimal_machine_mapping( + result_cache, + context, + series_split, + resources, + constraints, + /*parallel_split_transformation=*/std::nullopt); + }, + [&](auto const &decomp_tree_node) { + return get_optimal_machine_mapping(result_cache, + context, + decomp_tree_node, + resources, + constraints); + }, + }); + + machine_mapping_cache_save(result_cache, state, result); + return result; +} + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation) { + + auto get_boundary_machine_view_assignments = + [&](std::unordered_set const &boundary_layers) + -> std::unordered_set { + std::unordered_map> + allowed = generate_map( + boundary_layers, + [&](BinaryTreePath const &l) -> std::unordered_set { + UnmappedOpCostEstimateKey leaf = + mm_problem_tree_get_subtree_at_path( + MachineMappingProblemTree{series_split}, l) + .value() + .get(); + return context.allowed_machine_views(leaf, resources); + }); + return transform( + get_all_assignments(allowed), + [](std::unordered_map const &m) { + return ParallelLayerGuidObliviousMachineMapping{m}; + }); + }; + + auto eval_pre_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views) { + MachineMappingConstraints pre_candidate = with_additional_constraints( + restrict_to_left_child(constraints), assigned_pre_machine_views); + + MachineMappingResult pre_result = + get_optimal_machine_mapping(result_cache, + context, + series_split.get_left_child(), + resources, + pre_candidate); + + return pre_result; + }; + + auto eval_post_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views) { + MachineMappingConstraints post_candidate = with_additional_constraints( + restrict_to_right_child(constraints), assigned_post_machine_views); + + MachineMappingResult post_result = + get_optimal_machine_mapping(result_cache, + context, + series_split.get_right_child(), + resources, + post_candidate); + + return post_result; + }; + + MachineMappingResult result = infeasible_machine_mapping_result(); + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views : + get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + + MachineMappingResult pre_result = + eval_pre_boundary_mapping(assigned_pre_machine_views); + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views : + get_boundary_machine_view_assignments( + get_dst_layers(tensor_movement))) { + + MachineMappingResult post_result = + eval_post_boundary_mapping(assigned_post_machine_views); + + TensorSetMovement comm_across_split = + concretize_abstracted_tensor_set_movement( + tensor_movement, + /*pre_mapping=*/assigned_pre_machine_views, + /*post_mapping=*/assigned_post_machine_views); + float cost_across_split = + context.cost_estimator.estimate_cost(comm_across_split); + + result = minimize_runtime(result, + series_combine(cost_across_split, + pre_result, + post_result, + parallel_split_transformation)); + } + } + + return result; +} + +MachineMappingResult get_optimal_machine_mapping( + MachineMappingCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const ¶llel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { + + MachineMappingProblemTree lhs = parallel_split.get_left_child(); + MachineMappingProblemTree rhs = parallel_split.get_right_child(); + + MachineMappingResult series_result = [&] { + MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*left_child=*/lhs, + /*right_child=*/rhs, + }; + + return get_optimal_machine_mapping(result_cache, + context, + series_split, + resources, + constraints, + ParallelSplitTransformation::LthenR); + }(); + + MachineMappingConstraints left_constraints = + restrict_to_left_child(constraints); + MachineMappingConstraints right_constraints = + restrict_to_right_child(constraints); + + auto evaluate_resource_split = + [&](std::pair const + &resource_split) { + MachineMappingResult left_result = get_optimal_machine_mapping( + result_cache, context, lhs, resource_split.first, left_constraints); + MachineMappingResult right_result = + get_optimal_machine_mapping(result_cache, + context, + rhs, + resource_split.second, + right_constraints); + + return parallel_combine(left_result, right_result); + }; + + std::unordered_set parallel_results = transform( + get_machine_resource_splits(resources), evaluate_resource_split); + + return minimize_runtime(series_result, + get_mapping_with_minimal_runtime(parallel_results)); +} + +MachineMappingResult + get_optimal_machine_mapping(MachineMappingCache &result_cache, + MachineMappingContext const &context, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resource, + MachineMappingConstraints const &constraints) { + + std::unordered_set candidates = [&] { + std::optional machine_view = require_only_root(constraints); + if (machine_view.has_value()) { + return std::unordered_set{machine_view.value()}; + } else { + return context.allowed_machine_views(leaf, resource); + } + }(); + + auto get_mapping_result = [&](MachineView const &machine_view) { + OpCostEstimateKey mapped = + map_unmapped_op_cost_estimate_key(leaf, machine_view); + float cost = context.cost_estimator.estimate_cost(mapped); + + return make_singleton_machine_mapping_result(cost, machine_view); + }; + + std::unordered_set candidate_results = + transform(candidates, get_mapping_result); + + return get_mapping_with_minimal_runtime(candidate_results); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..6cc3f4329c --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,26 @@ +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/sum.h" +#include "utils/containers/values.h" + +namespace FlexFlow { + +TensorSetMovement get_tensor_set_movement_across_split( + TransitiveReducedPCG const &tr_pcg, + PCGBinarySeriesSplit const &split, + ParallelLayerGuidObliviousMachineMapping const &pre_mapping, + ParallelLayerGuidObliviousMachineMapping const &post_mapping) { + AbstractedTensorSetMovement abstracted = + get_abstracted_tensor_set_movement_across_split(tr_pcg, split); + return concretize_abstracted_tensor_set_movement( + abstracted, pre_mapping, post_mapping); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc new file mode 100644 index 0000000000..6f350d8773 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -0,0 +1,18 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "utils/containers.h" +#include "utils/containers/are_disjoint.h" +#include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" + +namespace FlexFlow { + +MachineMapping combine_disjoint_mappings(MachineMapping const &s1, + MachineMapping const &s2) { + return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; +} + +bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { + return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc new file mode 100644 index 0000000000..fbfccf737f --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_cache.cc @@ -0,0 +1,30 @@ +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/try_at.h" + +namespace FlexFlow { + +MachineMappingCache empty_machine_mapping_cache() { + return MachineMappingCache{{}}; +} + +std::optional + machine_mapping_cache_load(MachineMappingCache const &cache, + MachineMappingState const &k) { + return try_at(cache.raw_map, k); +} + +void machine_mapping_cache_save(MachineMappingCache &cache, + MachineMappingState const &k, + MachineMappingResult const &v) { + if (contains_key(cache.raw_map, k)) { + throw mk_runtime_error( + fmt::format("machine_mapping_cache_save expected key to not already " + "exist, but received existing key {}", + k)); + } + + cache.raw_map.emplace(k, v); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc new file mode 100644 index 0000000000..2cee866a01 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -0,0 +1,112 @@ +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "utils/containers/filter.h" +#include "utils/containers/filtermap_keys.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/keys.h" +#include "utils/containers/map_values.h" +#include "utils/containers/restrict_keys.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +MachineMappingConstraints get_unconstrained_solution_for_layers( + std::unordered_set const &layers) { + return MachineMappingConstraints{ + generate_map(layers, + [](BinaryTreePath const &) -> std::optional { + return std::nullopt; + }), + }; +} + +std::unordered_set + get_all_layers(MachineMappingConstraints const &partial_solution, + IncludeUnconstrained const &include_unconstrained) { + std::unordered_set with_unconstrained = + keys(partial_solution.machine_views); + + if (include_unconstrained.raw_bool) { + return with_unconstrained; + } else { + return filter(with_unconstrained, [&](BinaryTreePath const &l) { + return partial_solution.machine_views.at(l).has_value(); + }); + } +} + +std::optional get_machine_view_for_layer( + MachineMappingConstraints const &partial_solution, + BinaryTreePath const &layer) { + return partial_solution.machine_views.at(layer); +} + +MachineMappingConstraints + restrict_to_child(MachineMappingConstraints const &constraints, + BinaryTreePathEntry const &prefix) { + return MachineMappingConstraints{filtermap_keys( + constraints.machine_views, + [&](BinaryTreePath const &path) -> std::optional { + BinaryTreePathEntry head = binary_tree_path_get_top_level(path); + + if (head == prefix) { + BinaryTreePath rest = binary_tree_path_get_non_top_level(path); + return rest; + } else { + return std::nullopt; + } + })}; +} + +MachineMappingConstraints + restrict_to_left_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::LEFT_CHILD); +} + +MachineMappingConstraints + restrict_to_right_child(MachineMappingConstraints const &c) { + return restrict_to_child(c, BinaryTreePathEntry::RIGHT_CHILD); +} + +MachineMappingConstraints with_additional_constraints( + MachineMappingConstraints const &constraints, + ParallelLayerGuidObliviousMachineMapping const &additional) { + MachineMappingConstraints result = constraints; + + for (auto const &[layer, machine_view] : additional.raw_mapping) { + std::optional current_machine_view = + result.machine_views.at(layer); + + if (!current_machine_view.has_value()) { + result.machine_views.at(layer) = machine_view; + } else { + if (current_machine_view.value() != machine_view) { + throw mk_runtime_error( + fmt::format("with_additional_layer_machine_views received machine " + "view assignment for layer {} " + "to machine view {}, but that layer is already " + "assigned to machine view {}.", + layer, + machine_view, + current_machine_view.value())); + } + } + } + + return result; +} + +std::optional + require_only_root(MachineMappingConstraints const &constraints) { + if (keys(constraints.machine_views) != + std::unordered_set{binary_tree_root_path()}) { + throw mk_runtime_error( + fmt::format("require_only_root expected constraints to have only a " + "single key (the root path), but received {}", + constraints)); + } + + return constraints.machine_views.at(binary_tree_root_path()); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..367af3701e --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -0,0 +1,53 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingProblemTree get_machine_mapping_problem_tree( + ParallelComputationGraph const &pcg, + PCGBinarySPDecomposition const &sp_decomposition_tree) { + TransitiveReducedPCG tr_pcg = pcg_get_transitive_reduction(pcg); + + std::function + to_problem_tree; + + to_problem_tree = + [&](PCGBinarySPDecomposition const &sp) -> MachineMappingProblemTree { + return sp.visit(overload{ + [&](PCGBinarySeriesSplit const &series) { + AbstractedTensorSetMovement tensor_movement = + get_abstracted_tensor_set_movement_across_split(tr_pcg, series); + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_movement, + /*lhs=*/to_problem_tree(series.get_left_child()), + /*rhs=*/to_problem_tree(series.get_right_child()), + }, + }; + }, + [&](PCGBinaryParallelSplit const ¶llel) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + to_problem_tree(parallel.get_left_child()), + to_problem_tree(parallel.get_right_child()), + }, + }; + }, + [&](parallel_layer_guid_t const &leaf) { + return MachineMappingProblemTree{ + get_unmapped_op_cost_estimate_key_for_layer(pcg, leaf), + }; + }, + }); + }; + + return to_problem_tree(sp_decomposition_tree); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..1e39a7be19 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.cc @@ -0,0 +1,91 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_binary_sp_impl_for_mm_problem_tree() { + return GenericBinarySPDecompositionTreeImplementation< + MachineMappingProblemTree, + MMProblemTreeSeriesSplit, + MMProblemTreeParallelSplit, + UnmappedOpCostEstimateKey>{ + /*series_get_left_child=*/[](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](MMProblemTreeSeriesSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](MMProblemTreeParallelSplit const &split) + -> MachineMappingProblemTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](MachineMappingProblemTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeSeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](MachineMappingProblemTree const &tree) + -> MMProblemTreeParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](MachineMappingProblemTree const &tree) + -> UnmappedOpCostEstimateKey const & { + return tree.get(); + }, + }; +} + +SPDecompositionTreeNodeType + get_node_type(MachineMappingProblemTree const &tree) { + return tree.visit(overload{ + [](MMProblemTreeSeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](MMProblemTreeParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](UnmappedOpCostEstimateKey const &) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +std::unordered_multiset + get_leaves(MachineMappingProblemTree const &tree) { + return get_leaves(tree, generic_binary_sp_impl_for_mm_problem_tree()); +} + +std::unordered_set + get_all_leaf_paths(MachineMappingProblemTree const &tree) { + return get_all_leaf_paths(tree, generic_binary_sp_impl_for_mm_problem_tree()); +} + +std::optional + mm_problem_tree_get_subtree_at_path(MachineMappingProblemTree const &tree, + BinaryTreePath const &path) { + return get_subtree_at_path( + tree, generic_binary_sp_impl_for_mm_problem_tree(), path); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc new file mode 100644 index 0000000000..990b287f8b --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.cc @@ -0,0 +1,36 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" + +namespace FlexFlow { + +UnmappedOpCostEstimateKey get_unmapped_op_cost_estimate_key_for_layer( + ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { + auto get_tensor_shape = [&](parallel_tensor_guid_t const &t) { + return get_parallel_tensor_shape(pcg, t); + }; + + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/pcg_get_op_attrs(pcg, layer), + /*input_shapes=*/ + transform(get_incoming_inputs(pcg, layer), get_tensor_shape), + /*weight_shapes=*/ + transform(get_incoming_weights(pcg, layer), get_tensor_shape), + /*output_shapes=*/ + transform(get_layer_outputs(pcg, layer), get_tensor_shape), + }; +} + +OpCostEstimateKey + map_unmapped_op_cost_estimate_key(UnmappedOpCostEstimateKey const &unmapped, + MachineView const &machine_view) { + return OpCostEstimateKey{ + /*op_attrs=*/unmapped.op_attrs, + /*input_shapes=*/unmapped.input_shapes, + /*weight_shapes=*/unmapped.weight_shapes, + /*output_shapes=*/unmapped.output_shapes, + /*machine_view=*/machine_view, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc new file mode 100644 index 0000000000..3409f7f871 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_result.cc @@ -0,0 +1,138 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +MachineMappingResult infeasible_machine_mapping_result() { + return MachineMappingResult{std::nullopt}; +} + +bool is_infeasible(MachineMappingResult const &result) { + return !result.raw_result.has_value(); +} + +FeasibleMachineMappingResult + require_feasible(MachineMappingResult const &result) { + return result.raw_result.value(); +} + +[[nodiscard]] MachineMappingResult get_mapping_with_minimal_runtime( + std::unordered_set const &candidates) { + MachineMappingResult result = infeasible_machine_mapping_result(); + + for (MachineMappingResult const &candidate : candidates) { + result = minimize_runtime(result, candidate); + } + + return result; +} + +MachineMappingResult + series_combine(float comm_cost, + MachineMappingResult const &maybe_pre_result, + MachineMappingResult const &maybe_post_result, + std::optional const + ¶llel_split_transformation) { + FeasibleMachineMappingResult pre_result = ({ + if (is_infeasible(maybe_pre_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_pre_result); + }); + + FeasibleMachineMappingResult post_result = ({ + if (is_infeasible(maybe_post_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_post_result); + }); + + ParallelLayerGuidObliviousMachineMapping mapping = [&] { + if (parallel_split_transformation.has_value() && + parallel_split_transformation.value() == + ParallelSplitTransformation::RthenL) { + return binary_combine_mappings(/*lhs=*/post_result.machine_mapping, + /*rhs=*/pre_result.machine_mapping); + } else { + return binary_combine_mappings(/*lhs=*/pre_result.machine_mapping, + /*rhs=*/post_result.machine_mapping); + } + }(); + + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_result.runtime + comm_cost + post_result.runtime, + /*machine_mapping=*/mapping, + }, + }; +} + +MachineMappingResult + parallel_combine(MachineMappingResult const &maybe_lhs_result, + MachineMappingResult const &maybe_rhs_result) { + FeasibleMachineMappingResult lhs_result = ({ + if (is_infeasible(maybe_lhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_lhs_result); + }); + + FeasibleMachineMappingResult rhs_result = ({ + if (is_infeasible(maybe_rhs_result)) { + return infeasible_machine_mapping_result(); + } + require_feasible(maybe_rhs_result); + }); + + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/std::max(lhs_result.runtime, rhs_result.runtime), + /*machine_mapping=*/ + binary_combine_mappings(/*lhs=*/lhs_result.machine_mapping, + /*rhs=*/rhs_result.machine_mapping), + }, + }; +} + +MachineMappingResult minimize_runtime(MachineMappingResult const &maybe_m1, + MachineMappingResult const &maybe_m2) { + FeasibleMachineMappingResult m1 = ({ + if (is_infeasible(maybe_m1)) { + return maybe_m2; + } + require_feasible(maybe_m1); + }); + + FeasibleMachineMappingResult m2 = ({ + if (is_infeasible(maybe_m2)) { + return maybe_m1; + } + require_feasible(maybe_m2); + }); + + if (m2.runtime < m1.runtime) { + return maybe_m2; + } else { + return maybe_m1; + } +} + +MachineMappingResult + make_singleton_machine_mapping_result(float runtime, + MachineView const &machine_view) { + return MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/runtime, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), machine_view}, + }}, + }, + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc new file mode 100644 index 0000000000..715a4c2e3d --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -0,0 +1,24 @@ +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/try_at.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( + ParallelLayerGuidObliviousMachineMapping const &lhs, + ParallelLayerGuidObliviousMachineMapping const &rhs) { + return ParallelLayerGuidObliviousMachineMapping{ + merge_maps(map_keys(lhs.raw_mapping, nest_inside_left_child), + map_keys(rhs.raw_mapping, nest_inside_right_child)), + }; +} + +std::optional get_machine_view_for_path( + ParallelLayerGuidObliviousMachineMapping const &mapping, + BinaryTreePath const &path) { + return try_at(mapping.raw_mapping, path); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc new file mode 100644 index 0000000000..96c8106cad --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/transitive_reduced_pcg.cc @@ -0,0 +1,93 @@ +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_underlying_transitive_reduced_dataflow_graph( + TransitiveReducedPCG const &tr_pcg) { + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/tr_pcg.full_pcg.raw_graph, + /*transitive_reduction=*/tr_pcg.transitive_reduction, + }; +} + +TransitiveReducedPCG + pcg_get_transitive_reduction(ParallelComputationGraph const &pcg) { + DiGraphView raw_digraph = pcg.raw_graph; + DiGraphView transitive_reduced = transitive_reduction(raw_digraph); + + return TransitiveReducedPCG{ + /*pcg=*/pcg, + /*transitive_reduction=*/transitive_reduced, + }; +} + +std::unordered_set + pcg_get_transitive_reduced_edges_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + std::unordered_set raw_edges = + get_transitive_reduced_edges_across_split(raw_tr_g, raw_split); + + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + +std::unordered_set + pcg_get_transitive_reduced_tensors_across_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + std::unordered_set raw_outputs = + get_transitive_reduced_outputs_across_split(raw_tr_g, raw_split); + + return transform(raw_outputs, [](DataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); +} + +PCGSplitBoundaryLayers pcg_get_transitive_reduced_boundary_layers_for_split( + TransitiveReducedPCG const &tr_pcg, PCGBinarySeriesSplit const &split) { + TransitiveReducedDataflowGraphView raw_tr_g = + get_underlying_transitive_reduced_dataflow_graph(tr_pcg); + + BinarySeriesSplit raw_split = + binary_series_split_from_pcg_series_split(split); + + SplitBoundaryNodes raw_boundary = + get_transitive_reduced_boundary_nodes_for_split(raw_tr_g, raw_split); + + return PCGSplitBoundaryLayers{ + /*pre_split_boundary=*/transform( + raw_boundary.pre_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + /*post_split_boundary=*/ + transform(raw_boundary.post_split_boundary, + [](Node const &n) { return parallel_layer_guid_t{n}; }), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc new file mode 100644 index 0000000000..32fb53b58a --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.cc @@ -0,0 +1,192 @@ +#include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t> + generic_impl_for_computation_graph_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + ComputationGraphBinarySPDecomposition, + ComputationGraphBinarySeriesSplit, + ComputationGraphBinaryParallelSplit, + layer_guid_t>{ + /*series_get_left_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](ComputationGraphBinarySeriesSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](ComputationGraphBinaryParallelSplit const &split) + -> ComputationGraphBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> SPDecompositionTreeNodeType { return get_node_type(tree); }, + /*require_series=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> ComputationGraphBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](ComputationGraphBinarySPDecomposition const &tree) + -> layer_guid_t const & { return tree.get(); }, + }; +} + +SPDecompositionTreeNodeType + get_node_type(ComputationGraphBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](ComputationGraphBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](ComputationGraphBinaryParallelSplit const ¶llel) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](layer_guid_t const &leaf) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &tree) { + return tree.get(); +} + +ComputationGraphBinarySPDecomposition + computation_graph_sp_decomp_from_binary_sp_decomp( + BinarySPDecompositionTree const &bin) { + return bin.visit(overload{ + [](BinarySeriesSplit const &series) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinarySeriesSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + series.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const ¶llel) { + return ComputationGraphBinarySPDecomposition{ + ComputationGraphBinaryParallelSplit{ + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_left_child()), + computation_graph_sp_decomp_from_binary_sp_decomp( + parallel.get_right_child()), + }, + }; + }, + [](Node const &node) { + return ComputationGraphBinarySPDecomposition{ + layer_guid_t{node}, + }; + }, + }); +} + +std::optional + get_computation_graph_left_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + left_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +std::optional + get_computation_graph_right_assoc_binary_sp_decomposition( + ComputationGraph const &cg) { + SeriesParallelDecomposition sp_decomposition = ({ + std::optional result = + get_computation_graph_series_parallel_decomposition(cg); + if (!result.has_value()) { + return std::nullopt; + } + result.value(); + }); + + BinarySPDecompositionTree raw_binary_tree = + right_associative_binary_sp_tree_from_nary(sp_decomposition); + + return computation_graph_sp_decomp_from_binary_sp_decomp(raw_binary_tree); +} + +bool is_left_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_left_associative( + tree, generic_impl_for_computation_graph_sp_tree()); +} + +bool is_right_associative(ComputationGraphBinarySPDecomposition const &tree) { + return is_binary_sp_tree_right_associative( + tree, generic_impl_for_computation_graph_sp_tree()); +} + +std::unordered_multiset + get_layers(ComputationGraphBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_computation_graph_sp_tree()); +} + +V1BinarySPDecomposition + to_v1(ComputationGraphBinarySPDecomposition const &tree, + bidict const &layer_numbering) { + return tree.visit( + overload{[&](ComputationGraphBinarySeriesSplit const &series) { + return V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + to_v1(series.get_left_child(), layer_numbering), + to_v1(series.get_right_child(), layer_numbering), + }, + }; + }, + [&](ComputationGraphBinaryParallelSplit const ¶llel) { + return V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + to_v1(parallel.get_left_child(), layer_numbering), + to_v1(parallel.get_right_child(), layer_numbering), + }, + }; + }, + [&](layer_guid_t const &layer) { + return V1BinarySPDecomposition{ + layer_numbering.at_r(layer), + }; + }}); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc similarity index 97% rename from lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc rename to lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 184ad93f4d..8f78d423b3 100644 --- a/lib/compiler/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "op-attrs/computation_graph_op_attrs.h" #include "pcg/computation_graph.h" #include "pcg/computation_graph/computation_graph_edge.h" diff --git a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc deleted file mode 100644 index 63054385ac..0000000000 --- a/lib/compiler/src/compiler/series_parallel/computation_graph_binary_sp_decomposition.cc +++ /dev/null @@ -1,90 +0,0 @@ -#include "compiler/series_parallel/computation_graph_binary_sp_decomposition.h" -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.h" - -namespace FlexFlow { - -SPDecompositionTreeNodeType - get_node_type(ComputationGraphBinarySPDecomposition const &d) { - return get_node_type(d.raw_tree); -} - -ComputationGraphBinarySPDecomposition - get_left_child(ComputationGraphBinarySPDecomposition const &d) { - return ComputationGraphBinarySPDecomposition{ - get_left_child(d.raw_tree), - }; -} - -ComputationGraphBinarySPDecomposition - get_right_child(ComputationGraphBinarySPDecomposition const &d) { - return ComputationGraphBinarySPDecomposition{ - get_right_child(d.raw_tree), - }; -} - -layer_guid_t require_node(ComputationGraphBinarySPDecomposition const &d) { - return require_node(d.raw_tree); -} - -std::optional - get_computation_graph_left_assoc_binary_sp_decomposition( - ComputationGraph const &cg) { - SeriesParallelDecomposition sp_decomposition = ({ - std::optional result = - get_computation_graph_series_parallel_decomposition(cg); - if (!result.has_value()) { - return std::nullopt; - } - result.value(); - }); - - BinarySPDecompositionTree raw_binary_tree = - left_associative_binary_sp_tree_from_nary(sp_decomposition); - - return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; -} - -std::optional - get_computation_graph_right_assoc_binary_sp_decomposition( - ComputationGraph const &cg) { - SeriesParallelDecomposition sp_decomposition = ({ - std::optional result = - get_computation_graph_series_parallel_decomposition(cg); - if (!result.has_value()) { - return std::nullopt; - } - result.value(); - }); - - BinarySPDecompositionTree raw_binary_tree = - right_associative_binary_sp_tree_from_nary(sp_decomposition); - - return ComputationGraphBinarySPDecomposition{transform( - raw_binary_tree.raw_tree, [](Node const &n) { return layer_guid_t{n}; })}; -} - -bool is_left_associative(ComputationGraphBinarySPDecomposition const &d) { - return is_binary_sp_tree_left_associative(d.raw_tree); -} - -bool is_right_associative(ComputationGraphBinarySPDecomposition const &d) { - return is_binary_sp_tree_right_associative(d.raw_tree); -} - -std::unordered_multiset - get_layers(ComputationGraphBinarySPDecomposition const &d) { - return get_leaves(d.raw_tree); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc new file mode 100644 index 0000000000..220614bb8b --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.cc @@ -0,0 +1,10 @@ +#include "compiler/series_parallel/pcg/get_pcg_series_parallel_decomposition.h" + +namespace FlexFlow { + +std::optional + get_pcg_series_parallel_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc new file mode 100644 index 0000000000..657a3c3166 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_parallel_split.cc @@ -0,0 +1,14 @@ +#include "compiler/series_parallel/pcg/pcg_binary_parallel_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinaryParallelSplit binary_parallel_split_from_pcg_parallel_split( + PCGBinaryParallelSplit const &pcg_split) { + return BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc new file mode 100644 index 0000000000..304ad224b1 --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_series_split.cc @@ -0,0 +1,14 @@ +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" + +namespace FlexFlow { + +BinarySeriesSplit binary_series_split_from_pcg_series_split( + PCGBinarySeriesSplit const &pcg_split) { + return BinarySeriesSplit{ + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(pcg_split.get_right_child()), + }; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc new file mode 100644 index 0000000000..5eb993c6ef --- /dev/null +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -0,0 +1,115 @@ +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "compiler/series_parallel/pcg/pcg_binary_series_split.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/overload.h" + +namespace FlexFlow { + +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_pcg_sp_tree() { + + return GenericBinarySPDecompositionTreeImplementation< + PCGBinarySPDecomposition, + PCGBinarySeriesSplit, + PCGBinaryParallelSplit, + parallel_layer_guid_t>{ + /*series_get_left_child=*/[](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](PCGBinarySeriesSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](PCGBinaryParallelSplit const &split) + -> PCGBinarySPDecomposition const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](PCGBinarySPDecomposition const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](PCGBinarySPDecomposition const &tree) -> PCGBinarySeriesSplit const & { + return tree.get(); + }, + /*require_parallel=*/ + [](PCGBinarySPDecomposition const &tree) + -> PCGBinaryParallelSplit const & { + return tree.get(); + }, + /*require_leaf=*/ + [](PCGBinarySPDecomposition const &tree) + -> parallel_layer_guid_t const & { + return tree.get(); + }, + }; +} + +BinarySPDecompositionTree + binary_sp_tree_from_pcg_sp_tree(PCGBinarySPDecomposition const &pcg_tree) { + return pcg_tree.visit(overload{ + [](PCGBinarySeriesSplit const &series) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + binary_series_split_from_pcg_series_split(series), + }; + }, + [](PCGBinaryParallelSplit const ¶llel) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{ + binary_sp_tree_from_pcg_sp_tree(parallel.get_left_child()), + binary_sp_tree_from_pcg_sp_tree(parallel.get_right_child()), + }, + }; + }, + [](parallel_layer_guid_t const &layer) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + layer.raw_graph_node, + }; + }, + }); +} + +std::optional + get_pcg_balanced_binary_sp_decomposition(ParallelComputationGraph const &) { + NOT_IMPLEMENTED(); +} + +std::unordered_multiset + get_parallel_layers(PCGBinarySPDecomposition const &tree) { + return get_leaves(tree, generic_impl_for_pcg_sp_tree()); +} + +SPDecompositionTreeNodeType + get_node_type(PCGBinarySPDecomposition const &tree) { + return tree.visit(overload{ + [](PCGBinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](PCGBinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](parallel_layer_guid_t const &) { + return SPDecompositionTreeNodeType::NODE; + }, + }); +} + +std::unordered_set + find_paths_to_leaf(PCGBinarySPDecomposition const &tree, + parallel_layer_guid_t const &leaf) { + return find_paths_to_leaf(tree, generic_impl_for_pcg_sp_tree(), leaf); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc deleted file mode 100644 index a19c5e8597..0000000000 --- a/lib/compiler/src/graph_utils.cc +++ /dev/null @@ -1,153 +0,0 @@ -#include "compiler/graph_utils.h" -#include "pcg/computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "substitutions/sub_parallel_computation_graph.dtg.h" -#include "utils/containers/without_order.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" -namespace FlexFlow { - -SeriesParallelDecomposition - get_series_parallel_decomposition(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return get_series_parallel_decomposition(pcg.raw_graph); -} - -ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { - NOT_IMPLEMENTED(); -} - -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &pcg) { - NOT_IMPLEMENTED(); - // return view_output_labelled_as_output_labelled_open(pcg.raw_graph); -} - -// std::vector -// get_sorted_node_input_edges(ParallelComputationGraph const &pcg, -// Node const &n) { -// std::unordered_map> -// incoming_edges = -// get_incoming_edges_by_idx(pcg, n); - -// std::vector result; -// for (auto const &p_id_edge_set : incoming_edges) { -// result.push_back(get_only(p_id_edge_set.second)); -// } - -// return result; -// } - -// std::unordered_map -// infer_tensor_shapes(ParallelComputationGraph const &pcg) { -// std::unordered_map result; -// for (Node const &n : get_topological_ordering(pcg)) { -// PCGOperatorAttrs op = pcg.raw_graph.at(n); - -// std::vector input_tensor_shapes = -// vector_transform([&](MultiDiEdge const &e) { return result.at(e); }, -// get_sorted_node_input_edges(pcg, n)); - -// std::vector output_tensor_shapes = -// get_output_shapes(op, input_tensor_shapes); - -// auto outgoing_edges = get_outgoing_edges_by_idx(pcg, n); - -// int i = 0; - -// for (auto const &[node_port, edges] : outgoing_edges) { -// for (MultiDiEdge const &e : edges) { -// result.insert({e, output_tensor_shapes[i++]}); -// } -// } -// } - -// assert(result.size() == get_edges(pcg.raw_graph).size()); - -// return result; -// } - -/* template */ -/* LabelledOpenMultiDiGraph */ -/* get_subgraph(LabelledOpenMultiDiGraph const &g, */ -/* std::unordered_set const &nodes, */ -/* InputSettings input_settings, */ -/* OutputSettings output_settings) { */ - -/* auto iview = LabelledOpenMultiDiGraphView(g) */ -/* .unsafe(); */ - -/* if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledOpenMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view(subgraph_view); */ -/* } else if (input_settings == InputSettings::INCLUDE_INPUTS && */ -/* output_settings == OutputSettings::EXCLUDE_OUTPUTS) { */ -/* LabelledUpwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else if (input_settings == InputSettings::EXCLUDE_INPUTS && */ -/* output_settings == OutputSettings::INCLUDE_OUTPUTS) { */ -/* LabelledDownwardMultiDiSubgraphView */ -/* subgraph_view(*iview, nodes); */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); */ -/* } else { */ -/* LabelledMultiDiSubgraphView subgraph_view(*iview, - */ -/* nodes); - */ -/* return materialize_labelled_openmultidigraph_view( */ -/* view_as_labelled_open_multidisubgraph(subgraph_view)); - */ -/* } */ -/* } */ - -// struct GetNodes { -// template -// std::unordered_set operator()(T const &t) { -// return get_nodes(t); -// } -// }; - -// std::unordered_set get_nodes(SeriesParallelDecomposition const &sp) { -// return std::visit(GetNodes{}, sp.raw_variant); -// } - -// std::unordered_set get_nodes(SeriesSplit const &serial) { -// return set_union( -// transform(serial.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(ParallelSplit const ¶llel) { -// return set_union( -// transform(parallel.children, [](std::variant const -// child) { -// return std::visit(GetNodes{}, child); -// })); -// } - -// std::unordered_set get_nodes(Node const &node) { -// return {node}; -// } - -} // namespace FlexFlow diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc deleted file mode 100644 index fddd825109..0000000000 --- a/lib/compiler/src/machine_mapping.cc +++ /dev/null @@ -1,366 +0,0 @@ -#include "compiler/machine_mapping.h" -#include "compiler/cost_estimate.h" -#include "compiler/graph_utils.h" -#include "pcg/machine_specification.dtg.h" -#include "pcg/machine_specification.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/machine_view.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers.h" -#include "utils/containers/are_disjoint.h" -#include "utils/containers/contains_key.h" -#include "utils/containers/get_only.h" -#include "utils/containers/keys.h" -#include "utils/containers/merge_maps.h" -#include "utils/containers/require_no_duplicates.h" -#include "utils/containers/vector_of.h" -#include "utils/exception.h" -#include "utils/graph/graph_split.dtg.h" -#include "utils/graph/node/algorithms.h" -#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" -#include "utils/graph/series_parallel/series_parallel_decomposition.h" -#include "utils/graph/series_parallel/series_parallel_splits.h" - -namespace FlexFlow { - -MachineMapping combine(MachineMapping const &s1, MachineMapping const &s2) { - return MachineMapping{merge_maps(s1.machine_views, s2.machine_views)}; -} - -bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { - return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); -} - -OptimalCostResult - OptimalCostResult::sequential_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{s1.runtime + s2.runtime, - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult - OptimalCostResult::parallel_combine(OptimalCostResult const &s1, - OptimalCostResult const &s2) { - return OptimalCostResult{std::max(s1.runtime, s2.runtime), - combine(s1.machine_mapping, s2.machine_mapping)}; -} - -OptimalCostResult OptimalCostResult::infinity() { - return {std::numeric_limits::infinity(), - MachineMapping{std::unordered_map{}}}; -} - -bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, - OptimalCostResult const &rhs) { - return lhs.runtime < rhs.runtime; -} - -std::optional - OptimalCostCache::load(OptimalCostState const &state) const { - if (contains_key(cache, state)) { - OptimalCostResult result = cache.at(state); - return std::make_optional(result); - } - return std::nullopt; -} - -void OptimalCostCache::save(OptimalCostState const &state, - OptimalCostResult const &result) { - assert(!contains_key(cache, state)); - cache.emplace(state, result); -} - -std::vector> - get_resource_split(MachineSpecification const &resource) { - std::vector> result; - for (int i = 1; i < resource.num_nodes; ++i) { - MachineSpecification sub_resource1 = resource, sub_resource2 = resource; - sub_resource1.num_nodes = i; - sub_resource2.num_nodes = resource.num_nodes - i; - result.push_back(std::make_pair(sub_resource1, sub_resource2)); - } - return result; -} - -// We may replace this by having unflattened AST -std::pair - decompose(SeriesSplit const &serial) { - if (serial.children.size() == 2) { - return {widen(serial.children[0]), - widen(serial.children[1])}; - } - SeriesSplit decompn1 = serial; - decompn1.children.pop_back(); - return {SeriesParallelDecomposition(decompn1), - widen(serial.children.back())}; -} - -std::pair - decompose(ParallelSplit const ¶llel) { - if (parallel.children.size() == 2) { - std::vector children = - transform(vector_of(parallel.children), [&](auto const &child) { - return widen(child); - }); - return {children[0], children[1]}; - } - ParallelSplit decompn1 = parallel; - std::variant child = *parallel.children.begin(); - decompn1.children.erase(child); - return {SeriesParallelDecomposition(decompn1), - widen(child)}; -} - -GraphSplit - get_graph_split(SeriesParallelDecomposition const &pre_decomposition, - SeriesParallelDecomposition const &post_decomposition) { - std::unordered_set pre_nodes = - require_no_duplicates(get_nodes(pre_decomposition)); - std::unordered_set post_nodes = - require_no_duplicates(get_nodes(post_decomposition)); - assert(are_disjoint(pre_nodes, post_nodes)); - return GraphSplit{pre_nodes, post_nodes}; -} - -float estimate_cost(SubParallelComputationGraph const &g, - CostEstimator const &estimator, - MachineMapping const &device_mapping, - std::unordered_map const - &frontier_machine_views) { - // TODO: Consider parallelism - float cost = 0; - // for (Node const &node : get_nodes(g.raw_graph)) { - // std::vector incoming_edges = - // get_incoming_edges(g.raw_graph, node); - // std::vector inputs = - // transform(incoming_edges, - // [&](OpenDataflowEdge const &input_edge) { - // return g.raw_graph.at(input_edge).get_shape(); - // }); - // cost += estimator.estimate_cost( - // g.raw_graph.at(node).op_attrs, inputs, - // device_mapping.machine_views.at(node)); - // } - return cost; -} - -void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { - minimize(m1, m2, OptimalCostRuntimeCmp{}); -} - -struct MachineMappingSearcher { - MachineMappingSearcher( - CostEstimator cost_estimator, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - OptimalCostCache &cached_subgraph_costs) - : cost_estimator(cost_estimator), - allowed_machine_views(allowed_machine_views), - cached_subgraph_costs(cached_subgraph_costs) {} - - CostEstimator cost_estimator; - std::function(ParallelLayerAttrs const &, - MachineSpecification const &)> - allowed_machine_views; - OptimalCostCache &cached_subgraph_costs; - - struct OptimalCostFunctor { - OptimalCostFunctor( - MachineMappingSearcher *searcher, - SubParallelComputationGraph const &g, - MachineSpecification resource, - std::unordered_map given_machine_views, - std::unordered_map - frontier_machine_views) - : searcher(searcher), g(g), resource(resource), - given_machine_views(given_machine_views), - frontier_machine_views(frontier_machine_views) {} - - MachineMappingSearcher *searcher; - SubParallelComputationGraph const &g; - MachineSpecification resource; - std::unordered_map given_machine_views; - std::unordered_map frontier_machine_views; - - template - OptimalCostResult operator()(T const &t) { - OptimalCostState state{SeriesParallelDecomposition{t}, - resource, - given_machine_views, - frontier_machine_views}; - std::optional cached_result = - searcher->cached_subgraph_costs.load(state); - - if (cached_result) { - return cached_result.value(); - } - OptimalCostResult result = searcher->optimal_cost( - t, g, resource, given_machine_views, frontier_machine_views); - - searcher->cached_subgraph_costs.save(state, result); - return result; - } - }; - - OptimalCostResult - optimal_cost(SubParallelComputationGraph const &g, - MachineSpecification resource, - SeriesParallelDecomposition const &sp_decomposition) { - return std::visit(OptimalCostFunctor(this, g, resource, {}, {}), - sp_decomposition.raw_variant); - } - - OptimalCostResult optimal_cost( - SeriesSplit const &serial, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - NOT_IMPLEMENTED(); - // OptimalCostResult optimal_result = OptimalCostResult::infinity(); - - // auto decomposed = decompose(serial); - // SeriesParallelDecomposition pre_decompn = decomposed.first; - // SeriesParallelDecomposition post_decompn = decomposed.second; - - // GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); - // SubParallelComputationGraph pre_graph = - // get_subgraph(g, graph_split.first); - // SubParallelComputationGraph post_graph = - // get_subgraph(g, graph_split.second); - - // std::unordered_set post_graph_sources = - // get_closed_sources(post_graph); - - // assert(post_graph_sources.size() == 1); // assume perfect SP - - // Node split_point = get_only(post_graph_sources); - // OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); - - // for (MachineView const &mv : - // allowed_machine_views(g.raw_graph.at(split_point), resource)) { - // std::unordered_map new_given_machine_views = - // given_machine_views; - // new_given_machine_views.emplace(split_point, mv); - // std::unordered_map - // new_frontier_machine_views = frontier_machine_views; - // new_frontier_machine_views.emplace(split_edge, mv); - // minimize_runtime( - // optimal_result, - // OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // pre_graph, - // resource, - // given_machine_views, - // new_frontier_machine_views), - // pre_decompn.raw_variant), - // std::visit(OptimalCostFunctor(this, - // post_graph, - // resource, - // new_given_machine_views, - // frontier_machine_views), - // post_decompn.raw_variant))); - // } - - // return optimal_result; - } - - OptimalCostResult optimal_cost( - ParallelSplit const ¶llel, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - - NOT_IMPLEMENTED(); - // auto decomposed = decompose(parallel); - // SeriesParallelDecomposition decompn1 = decomposed.first; - // SeriesParallelDecomposition decompn2 = decomposed.second; - - // GraphSplit graph_split = get_graph_split(decompn1, decompn2); - // SubParallelComputationGraph g1 = get_subgraph(g, graph_split.first), - // g2 = get_subgraph(g, graph_split.second); - - // OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // std::visit(OptimalCostFunctor(this, - // g2, - // resource, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant)); - - // for (auto const &resource_split : get_resource_split(resource)) { - // minimize_runtime( - // optimal_result, - // OptimalCostResult::parallel_combine( - // std::visit(OptimalCostFunctor(this, - // g1, - // resource_split.first, - // given_machine_views, - // frontier_machine_views), - // decompn1.raw_variant), - // std::visit(OptimalCostFunctor(this, - // g2, - // resource_split.second, - // given_machine_views, - // frontier_machine_views), - // decompn2.raw_variant))); - // } - - // return optimal_result; - } - - OptimalCostResult optimal_cost( - Node const &node, - SubParallelComputationGraph const &g, - MachineSpecification const &resource, - std::unordered_map const &given_machine_views, - std::unordered_map const - &frontier_machine_views) { - if (contains_key(given_machine_views, node)) { - assert(contains(allowed_machine_views(g.raw_graph.at(node), resource), - given_machine_views.at(node))); - MachineMapping mv_map{given_machine_views}; - return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), - mv_map}; - } else { - OptimalCostResult optimal_result = OptimalCostResult::infinity(); - for (auto mv : allowed_machine_views(g.raw_graph.at(node), resource)) { - MachineMapping mv_map{{{node, mv}}}; - minimize_runtime( - optimal_result, - {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), - mv_map}); - } - return optimal_result; - } - } -}; - -OptimalCostResult optimal_cost( - ParallelComputationGraph const &g, - std::function( - ParallelLayerAttrs const &, MachineSpecification const &)> const - &allowed_machine_views, - CostEstimator const &cost_estimator, - MachineSpecification const &resources, - OptimalCostCache &cached_subgraph_costs) { - SeriesParallelDecomposition sp_decomposition = - get_series_parallel_decomposition(g); - SubParallelComputationGraph subpcg = pcg_to_subpcg(g); - MachineMappingSearcher searcher( - cost_estimator, allowed_machine_views, cached_subgraph_costs); - return searcher.optimal_cost(subpcg, resources, sp_decomposition); -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index ba6ef28daa..86a211c535 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -1,20 +1,16 @@ #include "compiler/unity_algorithm.h" -#include "compiler/graph_utils.h" -#include "compiler/machine_mapping.h" +#include "compiler/graph_optimize_state.h" +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" #include "pcg/machine_specification.dtg.h" #include "substitutions/substitution.h" #include "utils/deduplicated_priority_queue.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { -bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { - return lhs.runtime < rhs.runtime; -} - /* * Gets all substitutions applicable to a PCG */ -std::unordered_set +std::vector get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { NOT_IMPLEMENTED(); } @@ -22,14 +18,14 @@ std::unordered_set /* * Applies a substitution to all possible positions in PCG */ -std::unordered_set +std::vector apply_substitution(ParallelComputationGraph const &pcg, Substitution const &) { NOT_IMPLEMENTED(); } -Strategy graph_optimize( - ComputationGraph &cg, +GraphOptimizeResult graph_optimize( + ParallelComputationGraph &pcg, CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( @@ -37,58 +33,61 @@ Strategy graph_optimize( &allowed_machine_views, OptimizerConfig const &opt_config) { NOT_IMPLEMENTED(); - // ParallelComputationGraph pcg = cg_to_pcg(cg); - - // std::unordered_set subs = - // get_all_applicable_substitutions(pcg); - - // OptimalCostCache cached_subgraph_costs; - // DeduplicatedPriorityQueue, - // StrategyRuntimeCmp> - // candidates; - - // OptimalCostResult initial_pcg_result = optimal_cost(pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy initial_result{ - // pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; - - // Strategy best_result = initial_result; - // candidates.push(initial_result); + // std::vector substitutions = + // get_all_applicable_substitutions(pcg); + // + // MachineMappingCache cached_subgraph_costs; + // DeduplicatedPriorityQueue candidates; + // + // MachineMappingResult original_pcg_cost = + // get_optimal_machine_mapping(pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // + // GraphOptimizeState initial_state = { + // GraphOptimizeResult(pcg, original_pcg_cost.machine_mapping), + // original_pcg_cost.runtime}; + // + // GraphOptimizeState best_state = initial_state; + // candidates.push(initial_state); + // // for (int iteration = 0; !candidates.empty() && iteration < // opt_config.budget; // ++iteration) { - // Strategy const ¤t_result = candidates.top(); + // GraphOptimizeState current_state = candidates.top(); // candidates.pop(); - - // if (current_result.runtime < best_result.runtime) { - // best_result = current_result; - // } else if (current_result.runtime > - // best_result.runtime * opt_config.alpha) { + // + // if (current_state.runtime < best_state.runtime) { + // best_state = current_state; + // } else if (current_state.runtime > best_state.runtime * opt_config.alpha) + // { // continue; // } - - // for (auto const &sub : subs) { - // for (auto const &new_pcg : apply_substitution(current_result.pcg, sub)) - // { - // OptimalCostResult c = optimal_cost(new_pcg, - // allowed_machine_views, - // cost_estimator, - // resources, - // cached_subgraph_costs); - // Strategy new_result{new_pcg, c.machine_mapping, c.runtime}; - // if (new_result.runtime <= opt_config.threshold && + // + // for (Substitution const &substitution : substitutions) { + // for (ParallelComputationGraph const &new_pcg : apply_substitution( + // current_state.graph_optimize_result.pcg, substitution)) { + // MachineMappingResult new_pcg_cost = + // get_optimal_machine_mapping(new_pcg, + // allowed_machine_views, + // cost_estimator, + // resources, + // cached_subgraph_costs); + // GraphOptimizeState new_state{ + // GraphOptimizeResult(new_pcg, new_pcg_cost.machine_mapping), + // new_pcg_cost.runtime}; + // if (new_pcg_cost.runtime <= opt_config.threshold && // get_nodes(new_pcg.raw_graph).size() <= opt_config.max_num_ops) { - // candidates.push(new_result); + // candidates.push(new_state); // } // } // } // } - // return best_result; + // return best_state.graph_optimize_result; } } // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..5c8ea1c0f1 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -0,0 +1,300 @@ +#include "compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_abstracted_tensor_set_movement_across_split") { + auto make_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelLayerAttrs ew_add_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + SUBCASE("no edges across split") { + ParallelLayerAddedResult input1 = pcg_add_input_layer(pcg, input_shape); + ParallelLayerAddedResult input2 = pcg_add_input_layer(pcg, input_shape); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_leaf(input1.parallel_layer), + make_leaf(input2.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{}, + }; + + CHECK(result == correct); + } + + SUBCASE("single edge across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_leaf(layer_2.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not include edges removed by transitive reduction") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + make_leaf(input.parallel_layer), + make_series_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_leaf(layer_3.parallel_layer), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split(make_leaf(input.parallel_layer), + make_leaf(layer_1.parallel_layer)), + make_parallel_split(make_leaf(layer_2.parallel_layer), + make_leaf(layer_3.parallel_layer)), + }; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("multiple tensors, multiple consumers across split") { + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAddedResult layer_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(layer_1.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult layer_4 = add_parallel_layer( + pcg, + ew_add_attrs, + {get_only(layer_1.outputs), get_only(layer_2.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_series_split( + make_leaf(input.parallel_layer), + make_parallel_split(make_leaf(layer_1.parallel_layer), + make_leaf(layer_2.parallel_layer))), + make_parallel_split(make_leaf(layer_3.parallel_layer), + make_leaf(layer_4.parallel_layer))}; + + AbstractedTensorSetMovement result = + get_abstracted_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split); + + AbstractedTensorSetMovement correct = AbstractedTensorSetMovement{ + /*single_tensor_movements=*/{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + }, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc new file mode 100644 index 0000000000..9ee596af3e --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -0,0 +1,41 @@ +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" + +namespace FlexFlow { + +TestCostEstimator::TestCostEstimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost) + : get_operator_cost(get_operator_cost), + get_communication_cost(get_communication_cost) {} + +float TestCostEstimator::estimate_cost(OpCostEstimateKey const &k) const { + return this->get_operator_cost(k); +} + +float TestCostEstimator::estimate_cost(TensorSetMovement const &m) const { + return this->get_communication_cost(m); +} + +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost) { + + return CostEstimator::create(get_operator_cost, + get_communication_cost); +} + +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map) { + return make_fake_cost_estimator( + [op_cost_map](OpCostEstimateKey const &k) { return op_cost_map.at(k); }, + [comm_cost_map](TensorSetMovement const &m) { + return comm_cost_map.at(m); + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h new file mode 100644 index 0000000000..7c1d06207a --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H +#define _FLEXFLOW_TEST_COST_ESTIMATOR_H + +#include "compiler/cost_estimator/cost_estimator.h" +#include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.dtg.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" + +namespace FlexFlow { + +struct TestCostEstimator : public ICostEstimator { + std::function get_operator_cost; + std::function get_communication_cost; + + TestCostEstimator() = delete; + TestCostEstimator(decltype(get_operator_cost) const &get_operator_cost, + decltype(get_communication_cost) + const &get_communication_cost); + + float estimate_cost(OpCostEstimateKey const &) const override; + + float estimate_cost(TensorSetMovement const &) const override; +}; + +CostEstimator make_fake_cost_estimator( + std::function const &get_operator_cost, + std::function const + &get_communication_cost); + +CostEstimator make_fake_cost_estimator( + std::unordered_map const &op_cost_map, + std::unordered_map const &comm_cost_map); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc new file mode 100644 index 0000000000..499b111f8f --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_machine_resource_splits.cc @@ -0,0 +1,235 @@ +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/hash/pair.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_resource_splits") { + auto make_machine_spec = [](int num_nodes, int num_gpus_per_node) { + return MachineSpecification{ + /*num_nodes=*/num_nodes, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/num_gpus_per_node, + /*inter_node_bandwidth=*/1.0, + /*intra_node_bandwidth=*/1.0, + }; + }; + + SUBCASE("returns no splits if no splits are possible") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1); + + std::unordered_set> + result = get_machine_resource_splits(input); + std::unordered_set> + correct = {}; + + CHECK(result == correct); + } + + SUBCASE( + "returns splits in gpu and node dimensions, but not at the same time") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/2); + + std::unordered_set> + result = get_machine_resource_splits(input); + + std::unordered_set> + correct = { + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + + }; + + CHECK(result == correct); + } + + SUBCASE("returns splits in node dimension in powers of two") { + SUBCASE("num_nodes is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/8, + /*num_gpus_per_node=*/1); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/7, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_nodes is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/6, + /*num_gpus_per_node=*/1); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/4, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/2, + /*num_gpus_per_node=*/1), + }, + { + make_machine_spec(/*num_nodes=*/5, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + } + + SUBCASE("returns splits in gpu dimension in powers of two") { + SUBCASE("num_gpus_per_node is a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/8); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/7), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + + CHECK(result == correct); + } + + SUBCASE("num_gpus_per_node is not a power of 2") { + MachineSpecification input = make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/6); + + std::unordered_set< + std::pair> + result = get_machine_resource_splits(input); + + std::unordered_set< + std::pair> + correct = { + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/4), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/2), + }, + { + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/5), + make_machine_spec(/*num_nodes=*/1, + /*num_gpus_per_node=*/1), + }, + }; + } + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc new file mode 100644 index 0000000000..0a874948e4 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -0,0 +1,236 @@ +#include "compiler/machine_mapping/get_optimal_machine_mapping.h" +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_cache.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_optimal_machine_mapping") { + auto make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto make_series_split = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_set_movement, + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + auto make_parallel_split = [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + MachineView mv1 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(2)); + MachineView mv2 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(3)); + + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/2, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + MachineSpecification split_machine_spec = MachineSpecification{ + /*num_nodes=*/1, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, + MachineSpecification const &resources) { + if (resources == full_machine_spec) { + return std::unordered_set{mv1, mv2}; + } else { + return std::unordered_set{mv2}; + } + }; + + UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + ParallelTensorShape tensor_shape1 = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/tensor_shape1, + /*src_machine_views=*/{}, + /*dst_machine_views=*/{}, + }, + }}; + + ParallelLayerGuidObliviousMachineMapping mm1 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}; + ParallelLayerGuidObliviousMachineMapping mm2 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv2}, + }}; + + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{{ + {map_unmapped_op_cost_estimate_key(k1, mv1), 1.0}, + {map_unmapped_op_cost_estimate_key(k2, mv1), 2.0}, + {map_unmapped_op_cost_estimate_key(k1, mv2), 1.5}, + {map_unmapped_op_cost_estimate_key(k2, mv2), 2.5}, + }}, + std::unordered_map{{ + {TensorSetMovement{{}}, 0.0}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), + 0.1}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), + 0.2}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), + 0.3}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), + 0.4}, + }}); + + MachineMappingContext context = MachineMappingContext{ + cost_estimator, + allowed_machine_views1, + }; + + MachineMappingCache cache = empty_machine_mapping_cache(); + + SUBCASE("single layer") { + MachineMappingProblemTree problem_tree = make_leaf(k1); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in sequence") { + MachineMappingProblemTree problem_tree = + make_series_split(movement1, make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/1.0 + 2.0 + 0.1, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in parallel") { + MachineMappingProblemTree problem_tree = + make_parallel_split(make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingResult result = get_optimal_machine_mapping( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.5, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv2, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc new file mode 100644 index 0000000000..82210a138b --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -0,0 +1,239 @@ +// #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +// #include "compiler/machine_mapping/transitive_reduced_pcg.h" +// #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" +// #include "pcg/machine_view.h" +// #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +// #include +// "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +// #include "utils/containers/get_only.h" +// #include +// #include "./cost_estimator_for_test.h" +// +// using namespace ::FlexFlow; +// +// TEST_SUITE(FF_TEST_SUITE) { +// TEST_CASE("get_tensor_set_movement_across_split") { +// ParallelComputationGraph pcg = empty_parallel_computation_graph(); +// +// ParallelTensorShape input_shape = +// ParallelTensorShape{ +// ParallelTensorDims{ +// FFOrdered{ +// ShardParallelDim{10, 2}, +// ShardParallelDim{12, 1}, +// }, +// ReplicaParallelDimSet{ +// SumDegree{1}, +// DiscardCopyDegree{1}, +// }, +// }, +// DataType::FLOAT, +// }; +// ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); +// +// ParallelLayerAttrs relu_attrs +// = ParallelLayerAttrs{ +// /*op_attrs=*/PCGOperatorAttrs{ +// ElementUnaryAttrs{ +// /*op_type=*/OperatorType::RELU, +// /*scalar=*/std::nullopt, +// }, +// }, +// /*name=*/std::nullopt, +// }; +// +// ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ +// /*shape=*/input_shape, +// /*sync_type=*/std::nullopt, +// /*initializer=*/std::nullopt, +// /*create_gradients=*/CreateGrad::YES, +// }; +// +// ParallelLayerAddedResult relu_1 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(input.outputs)}, +// {relu_output_attrs}); +// ParallelLayerAddedResult relu_2 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(relu_1.outputs)}, +// {relu_output_attrs}); +// +// MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); +// MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); +// MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); +// MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); +// +// SUBCASE("single edge across split") { +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_leaf_node(relu_1.parallel_layer)), +// make_pcg_leaf_node(relu_2.parallel_layer))); +// +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// }}; +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("does not include edges removed by transitive reduction") { +// +// } +// +// SUBCASE("single tensor, multiple consumers across split") { +// ParallelLayerAddedResult relu_3 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(relu_1.outputs)}, +// {relu_output_attrs}); +// +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_leaf_node(relu_1.parallel_layer)), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_2.parallel_layer), +// make_pcg_leaf_node(relu_3.parallel_layer)))); +// +// SUBCASE("consumers have same view") { +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_3.parallel_layer, post_mv1}, +// }}; +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// +// SUBCASE("consumers have different views") { +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_3.parallel_layer, post_mv2}, +// }}; +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1, post_mv2}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// } +// +// SUBCASE("multiple tensors, multiple consumers across split") { +// ParallelLayerAddedResult relu_3 +// = add_parallel_layer(pcg, +// relu_attrs, +// {get_only(input.outputs)}, +// {relu_output_attrs}); +// +// ParallelLayerAddedResult relu_4 +// = add_parallel_layer(pcg, +// relu_attrs, +// // relu's don't have two inputs, but for the +// purposes of this test it's fine. +// {get_only(relu_1.outputs), +// get_only(relu_3.outputs)}, {relu_output_attrs}); +// +// PartialMachineMapping pre_mapping = PartialMachineMapping{{ +// {relu_1.parallel_layer, pre_mv1}, +// {relu_3.parallel_layer, pre_mv2}, +// }}; +// +// PartialMachineMapping post_mapping = PartialMachineMapping{{ +// {relu_2.parallel_layer, post_mv1}, +// {relu_4.parallel_layer, post_mv2}, +// }}; +// +// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( +// make_pcg_series_split( +// make_pcg_leaf_node(input.parallel_layer), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_1.parallel_layer), +// make_pcg_leaf_node(relu_3.parallel_layer))), +// make_pcg_parallel_split( +// make_pcg_leaf_node(relu_2.parallel_layer), +// make_pcg_leaf_node(relu_4.parallel_layer)))); +// +// TensorSetMovement result = +// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), +// split, +// pre_mapping, +// post_mapping); +// +// +// TensorSetMovement correct = TensorSetMovement{ +// /*single_tensor_movements=*/{ +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv1}, +// /*dst_machine_views=*/{post_mv1, post_mv2}, +// }, +// SingleTensorMovement{ +// /*parallel_tensor_shape=*/input_shape, +// /*src_machine_views=*/{pre_mv2}, +// /*dst_machine_views=*/{post_mv2}, +// }, +// }, +// }; +// +// CHECK(result == correct); +// } +// } +// } diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc new file mode 100644 index 0000000000..6b16a54c1f --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -0,0 +1,55 @@ +#include "compiler/machine_mapping/machine_mapping.h" +#include "cost_estimator_for_test.h" +#include "doctest/doctest.h" +#include "pcg/machine_view.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("combine_disjoint_mappings(MachineMapping, MachineMappping)") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_0 = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + }); + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + MachineMapping correct = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + MachineMapping result = + combine_disjoint_mappings(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + + TEST_CASE("nodes_are_disjoint(MachineMapping, MachineMappping)") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineMapping machine_mapping_0 = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + }); + + SUBCASE("nodes are disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + + bool correct = true; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + + SUBCASE("nodes are not disjoint") { + MachineMapping machine_mapping_1 = MachineMapping({ + {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t(Node(1)), machine_view_1}, + }); + bool correct = false; + bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc new file mode 100644 index 0000000000..06ab1e5b8c --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.cc @@ -0,0 +1,289 @@ +#include "compiler/machine_mapping/machine_mapping_problem_tree/get_machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_machine_mapping_problem_tree") { + auto pcg_make_leaf = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + auto pcg_make_series = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + PCGBinarySeriesSplit{ + lhs, + rhs, + }, + }; + }; + + auto pcg_make_parallel = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{ + PCGBinaryParallelSplit{ + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto mm_problem_tree_make_series = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + tensor_set_movement, + lhs, + rhs, + }, + }; + }; + + auto mm_problem_tree_make_parallel = + [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + lhs, + rhs, + }, + }; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + auto make_output_attrs = [](ParallelTensorShape const &shape) { + return ParallelTensorAttrs{ + /*shape=*/shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + }; + + auto make_layer_attrs = [](PCGOperatorAttrs const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }; + }; + + PCGOperatorAttrs input_attrs = PCGOperatorAttrs{InputAttrs{}}; + + auto make_input_key = + [&](ParallelTensorShape const ¶llel_tensor_shape) { + return UnmappedOpCostEstimateKey{ + /*op_attrs=*/input_attrs, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{parallel_tensor_shape}, + }; + }; + + SUBCASE("single layer") { + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + + PCGBinarySPDecomposition sp_decomposition = + PCGBinarySPDecomposition{input_layer}; + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + MachineMappingProblemTree correct = MachineMappingProblemTree{input_key}; + + CHECK(result == correct); + } + + SUBCASE("two layers in series") { + ParallelLayerAddedResult input_added = add_parallel_layer( + pcg, + /*layer_attrs=*/make_layer_attrs(input_attrs), + /*inputs=*/{}, + /*output_labels=*/{make_output_attrs(input_shape)}); + parallel_layer_guid_t input_layer = input_added.parallel_layer; + parallel_tensor_guid_t input = get_only(input_added.outputs); + + UnmappedOpCostEstimateKey input_key = make_input_key(input_shape); + + PCGOperatorAttrs relu_attrs = PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }; + ParallelTensorShape relu_output_shape = input_shape; + ParallelLayerAddedResult relu_added = + add_parallel_layer(pcg, + make_layer_attrs(relu_attrs), + {input}, + {make_output_attrs(relu_output_shape)}); + parallel_layer_guid_t relu_layer = relu_added.parallel_layer; + parallel_tensor_guid_t relu_output = get_only(relu_added.outputs); + + UnmappedOpCostEstimateKey relu_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/relu_attrs, + /*input_shapes=*/{input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{relu_output_shape}, + }; + + PCGBinarySPDecomposition sp_decomposition = pcg_make_series( + pcg_make_leaf(input_layer), pcg_make_leaf(relu_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = mm_problem_tree_make_series( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{}}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }}, + mm_problem_tree_make_leaf(input_key), + mm_problem_tree_make_leaf(relu_key)); + + CHECK(result == correct); + } + + SUBCASE("two layers in parallel") { + ParallelLayerAddedResult input1_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); + + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); + + PCGBinarySPDecomposition sp_decomposition = pcg_make_parallel( + pcg_make_leaf(input1_layer), pcg_make_leaf(input2_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)); + + CHECK(result == correct); + } + + SUBCASE("multiple tensors across split") { + ParallelLayerAddedResult input1_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input1_layer = input1_added.parallel_layer; + parallel_tensor_guid_t input1_tensor = get_only(input1_added.outputs); + UnmappedOpCostEstimateKey input1_key = make_input_key(input_shape); + + ParallelLayerAddedResult input2_added = + pcg_add_input_layer(pcg, input_shape); + parallel_layer_guid_t input2_layer = input2_added.parallel_layer; + parallel_tensor_guid_t input2_tensor = get_only(input2_added.outputs); + UnmappedOpCostEstimateKey input2_key = make_input_key(input_shape); + + PCGOperatorAttrs ew_op_attrs = PCGOperatorAttrs{ + ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }, + }; + ParallelTensorShape ew_op_output_shape = input_shape; + ParallelLayerAddedResult ew_op_added = + add_parallel_layer(pcg, + make_layer_attrs(ew_op_attrs), + {input1_tensor, input2_tensor}, + {make_output_attrs(ew_op_output_shape)}); + parallel_layer_guid_t ew_op_layer = ew_op_added.parallel_layer; + UnmappedOpCostEstimateKey ew_op_key = UnmappedOpCostEstimateKey{ + /*op_attrs=*/ew_op_attrs, + /*input_shapes=*/{input_shape, input_shape}, + /*weight_shapes=*/{}, + /*output_shapes=*/{ew_op_output_shape}, + }; + + PCGBinarySPDecomposition sp_decomposition = + pcg_make_series(pcg_make_parallel(pcg_make_leaf(input1_layer), + pcg_make_leaf(input2_layer)), + pcg_make_leaf(ew_op_layer)); + + MachineMappingProblemTree result = + get_machine_mapping_problem_tree(pcg, sp_decomposition); + + MachineMappingProblemTree correct = mm_problem_tree_make_series( + AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + }, + /*dst_machine_views=*/ + { + BinaryTreePath{{}}, + }, + }, + }}, + /*pre=*/ + mm_problem_tree_make_parallel(mm_problem_tree_make_leaf(input1_key), + mm_problem_tree_make_leaf(input2_key)), + /*post=*/mm_problem_tree_make_leaf(ew_op_key)); + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc new file mode 100644 index 0000000000..254d6b2784 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -0,0 +1,342 @@ +#include "compiler/machine_mapping/machine_mapping_result.h" +#include "pcg/machine_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("series_combine") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + float pre_cost = 2.0; + MachineMappingResult pre = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + float post_cost = 4.0; + MachineMappingResult post = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + float comm_cost = 3.0; + + SUBCASE("pre is infeasbile") { + MachineMappingResult result = series_combine( + comm_cost, infeasible, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("post is infeasbile") { + MachineMappingResult result = series_combine( + comm_cost, pre, infeasible, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = + series_combine(comm_cost, + infeasible, + infeasible, + ParallelSplitTransformation::LthenR); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult no_parallel_split_transform = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + SUBCASE("parallel_split_transformation = std::nullopt") { + MachineMappingResult result = + series_combine(comm_cost, pre, post, std::nullopt); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = LthenR") { + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::LthenR); + MachineMappingResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = RthenL") { + MachineMappingResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::RthenL); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/pre_cost + comm_cost + post_cost, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel_combine") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + MachineMappingResult lhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult rhs = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasbile") { + MachineMappingResult result = parallel_combine(infeasible, rhs); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasbile") { + MachineMappingResult result = parallel_combine(lhs, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = parallel_combine(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + MachineMappingResult result = parallel_combine(lhs, rhs); + MachineMappingResult correct = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("minimize_runtime") { + MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); + MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + + MachineMappingResult faster = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/2.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult slower = MachineMappingResult{ + FeasibleMachineMappingResult{ + /*runtime=*/4.0, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + machine_view_1, + }, + }}, + }, + }; + + MachineMappingResult infeasible = infeasible_machine_mapping_result(); + + SUBCASE("lhs is infeasbile") { + MachineMappingResult result = minimize_runtime(infeasible, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("rhs is infeasible") { + MachineMappingResult result = minimize_runtime(slower, infeasible); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + + SUBCASE("both are infeasible") { + MachineMappingResult result = minimize_runtime(infeasible, infeasible); + MachineMappingResult correct = infeasible; + + CHECK(result == correct); + } + + SUBCASE("both are feasible") { + SUBCASE("lhs is faster") { + MachineMappingResult result = minimize_runtime(faster, slower); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("rhs is faster") { + MachineMappingResult result = minimize_runtime(slower, faster); + MachineMappingResult correct = faster; + + CHECK(result == correct); + } + + SUBCASE("lhs and rhs have the same speed") { + MachineMappingResult result = minimize_runtime(slower, slower); + MachineMappingResult correct = slower; + + CHECK(result == correct); + } + } + } +} diff --git a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc similarity index 96% rename from lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc rename to lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc index 564cffaebe..2b59669aad 100644 --- a/lib/compiler/test/src/compiler/series_parallel/get_computation_graph_series_parallel_decomposition.cc +++ b/lib/compiler/test/src/compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.cc @@ -1,4 +1,4 @@ -#include "compiler/series_parallel/get_computation_graph_series_parallel_decomposition.h" +#include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" #include "models/inception_v3/inception_v3.h" @@ -89,14 +89,14 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ - ParallelSplit{ + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ input_layer.raw_node, projection_weights_layer.raw_node, bias_weights_layer.raw_node, - }, + }}, operator_layer.raw_node, - }}; + }}}; CHECK(result == correct); } @@ -159,17 +159,17 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ - ParallelSplit{ + SeriesParallelDecomposition{SeriesSplit{{ + ParallelSplit{{ w1.raw_node, input.raw_node, w2.raw_node, - }, - ParallelSplit{ + }}, + ParallelSplit{{ op1.raw_node, op2.raw_node, - }, - }}; + }}, + }}}; } SUBCASE("SP with or without preprocessing, but preprocessing would SP " @@ -214,16 +214,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_computation_graph_series_parallel_decomposition(cg); std::optional correct = - SeriesParallelDecomposition{ParallelSplit{ - SeriesSplit{ + SeriesParallelDecomposition{ParallelSplit{{ + SeriesSplit{{ input1.raw_node, op1.raw_node, - }, - SeriesSplit{ + }}, + SeriesSplit{{ input2.raw_node, op2.raw_node, - }, - }}; + }}, + }}}; } SUBCASE("not SP with or without weight nodes") { diff --git a/lib/compiler/test/src/graph_optimize_state.cc b/lib/compiler/test/src/graph_optimize_state.cc new file mode 100644 index 0000000000..46177ad420 --- /dev/null +++ b/lib/compiler/test/src/graph_optimize_state.cc @@ -0,0 +1,80 @@ +#include "compiler/graph_optimize_state.h" +#include "doctest/doctest.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("GraphOptimizeState::operator==") { + ParallelComputationGraphBuilder builder; + + ParallelTensorShape input_shape = + ParallelTensorShape{ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{32, 2}, + ShardParallelDim{16, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT}; + + parallel_tensor_guid_t input0 = + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); + parallel_tensor_guid_t dense0 = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + parallel_tensor_guid_t dense1 = builder.dense(dense0, + 4, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense1"); + + ParallelComputationGraph pcg = builder.pcg; + + // `machine_mapping` is determined by the PCG and the device mapping + // algorithm, and `runtime` is determined by the PCG and the device mapping, + // so their values here do not matter. + std::unordered_map empty_machine_views; + MachineMapping empty_machine_mapping(empty_machine_views); + bool result1 = + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), 0); + bool correct1 = true; + CHECK(result1 == correct1); + + ParallelComputationGraphBuilder builder_; + + parallel_tensor_guid_t input0_ = + builder.create_input_tensor(input_shape, CreateGrad::YES, "input0"); + parallel_tensor_guid_t dense0_ = builder.dense(input0, + 8, + Activation::RELU, + true, + DataType::FLOAT, + std::nullopt, + std::nullopt, + "dense0"); + + ParallelComputationGraph pcg_ = builder.pcg; + + bool result2 = + GraphOptimizeState(GraphOptimizeResult(pcg, empty_machine_mapping), + 0) == + GraphOptimizeState(GraphOptimizeResult(pcg_, empty_machine_mapping), 0); + bool correct2 = false; + CHECK(result2 == correct2); + } +} diff --git a/lib/compiler/test/src/test_cost_estimator.h b/lib/compiler/test/src/test_cost_estimator.h deleted file mode 100644 index 9417b863e4..0000000000 --- a/lib/compiler/test/src/test_cost_estimator.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FLEXFLOW_TEST_COST_ESTIMATOR_H -#define _FLEXFLOW_TEST_COST_ESTIMATOR_H - -#include "compiler/cost_estimate.h" - -namespace FlexFlow { - -struct TestCostEstimator : public ICostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - std::vector const &weights, - std::vector const &outputs, - MachineView const &mv) const override { - return 0.1; - } - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override { - return 0.1; - } -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h deleted file mode 100644 index 9f5a768b27..0000000000 --- a/lib/compiler/test/src/test_generator.h +++ /dev/null @@ -1,174 +0,0 @@ -#ifndef _FLEXFLOW_TEST_GENERATOR_H -#define _FLEXFLOW_TEST_GENERATOR_H - -#include "compiler/machine_mapping.h" -#include "pcg/computation_graph.h" -#include "rapidcheck.h" -#include "substitutions/sub_parallel_computation_graph.h" - -using namespace FlexFlow; - -// Rapidcheck does not work for now -// /* -// Generates computation graphs with trivial layers and tensors, which are -// used for tests focusing on graph structures. -// */ -// ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, -// [](Tensor(MultiDiOutput const &)) { -// return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; -// })); -// } - -// /* -// Generates parallel computation graphs with trivial layers and tensors, -// which are used for tests focusing on graph structures. -// */ -// ParallelComputationGraph -// test_parallel_computation_graph(MultiDiGraphView const &g) { -// return materialize_output_labelled_multidigraph_view( -// ViewMultiDiGraphAsOutputLabelled( -// g, -// [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, -// [](Operator(MultiDiOutput const &)) { -// return ParallelTensor(ParallelTensorDims(TensorDims({})), -// DataType::FLOAT); -// })); -// } - -// rc::Gen small_integer_generator() { -// return rc::gen::inRange(1, 4); -// } - -// namespace rc { - -// Gen serialParallelMultiDiGraph() { -// return gen::map(gen::arbitrary(), -// multidigraph_from_sp_decomposition); -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_computataion_graph); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return -// gen::map(gen::cast(serialParallelMultiDiGraph()), -// test_parallel_computation_graph); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? gen::cast>(gen::arbitrary()) -// : gen::cast>(gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary> { -// static Gen> arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_node) { -// return is_node -// ? gen::cast>(gen::arbitrary()) -// : gen::cast>( -// gen::arbitrary()); -// }); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Serial::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&Parallel::children, -// gen::container>>( -// gen::arbitrary>()))); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::mapcat(gen::arbitrary(), [](bool is_serial) { -// return is_serial ? gen::construct( -// gen::arbitrary()) -// : gen::construct( -// gen::arbitrary()); -// }); -// } -// }; - -// template -// struct Arbitrary { -// static Gen< -// std::enable_if, -// Tag>::value>::type> arbitrary() { -// return gen::construct(gen::arbitrary()); -// } -// }; - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::apply(make_1d_machine_view, -// gen::arbitrary, -// gen::arbitrary, -// small_integer_generator()); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineMapping::machine_views, -// gen::container>( -// gen::arbitrary(), -// gen::arbitrary()))); -// } -// } - -// template <> -// struct Arbitrary { -// static Gen arbitrary() { -// return gen::build( -// gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), -// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, -// 64)), gen::set(&MachineSpecification::num_gpus_per_node, -// gen::inRange(1, 16)), -// gen::set(&MachineSpecification::inter_node_bandwidth, -// gen::nonZero()), -// gen::set(&MachineSpecification::intra_node_bandwidth, -// gen::nonZero())); -// } -// } - -// } // namespace rc - -#endif diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc deleted file mode 100644 index 59fa0f1e5e..0000000000 --- a/lib/compiler/test/src/test_labelled_open_graph.cc +++ /dev/null @@ -1,132 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// // #include "rapidcheck.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { -// auto g = OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); -// Node n2 = g.add_node(); -// Node n3 = g.add_node(); -// Node n4 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); -// NodePort p2 = g.add_node_port(); -// NodePort p3 = g.add_node_port(); -// NodePort p4 = g.add_node_port(); -// NodePort p5 = g.add_node_port(); -// NodePort p6 = g.add_node_port(); -// NodePort p7 = g.add_node_port(); -// NodePort p8 = g.add_node_port(); -// NodePort p9 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; -// MultiDiEdge e1{n2, p2, n0, p0}; -// MultiDiEdge e2{n3, p5, n1, p3}; -// MultiDiEdge e3{n3, p6, n2, p4}; -// MultiDiEdge e4{n4, p8, n3, p7}; -// OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// std::unordered_set node_set0{n3, n4}; - -// auto subgraph0 = get_subgraph(g, node_set0); -// auto subgraph1 = get_subgraph(g, -// node_set0); auto subgraph2 = -// get_subgraph(g, node_set0); -// auto subgraph3 = get_subgraph(g, node_set0); - -// CHECK(bool(get_nodes(subgraph0) == node_set0)); -// CHECK(bool(get_nodes(subgraph1) == node_set0)); -// CHECK(bool(get_nodes(subgraph2) == node_set0)); -// CHECK(bool(get_nodes(subgraph3) == node_set0)); - -// std::unordered_set input_set{split_edge(e2).second, -// split_edge(e3).second}; -// std::unordered_set output_set{e5}; - -// CHECK(bool(get_open_inputs(subgraph0) == input_set)); -// CHECK(bool(get_open_inputs(subgraph1) == input_set)); -// CHECK(bool(get_open_inputs(subgraph2).empty())); -// CHECK(bool(get_open_inputs(subgraph3).empty())); - -// CHECK(bool(get_open_outputs(subgraph0) == output_set)); -// CHECK(bool(get_open_outputs(subgraph1).empty())); -// CHECK(bool(get_open_outputs(subgraph2) == output_set)); -// CHECK(bool(get_open_outputs(subgraph3).empty())); - -// CHECK(bool(get_edges(subgraph0) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4, e5})); -// CHECK(bool(get_edges(subgraph1) == -// std::unordered_set{ -// split_edge(e2).second, split_edge(e3).second, e4})); -// CHECK(bool(get_edges(subgraph2) == -// std::unordered_set{e4, e5})); -// CHECK( -// bool(get_edges(subgraph3) == -// std::unordered_set{e4})); - -// CHECK(bool(get_closed_sources(subgraph2) == -// std::unordered_set{n3})); -// } - -// TEST_CASE("view OutputLabelledMultiDiGraph as open") { -// OutputLabelledMultiDiGraph g = -// OutputLabelledMultiDiGraph::create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_output(e0, 2); - -// CHECK(bool(get_edges(g).size() == 1)); - -// OutputLabelledOpenMultiDiGraphView open_graph = -// view_output_labelled_as_output_labelled_open(g); - -// CHECK(bool(open_graph.at(n0) == 0)); -// CHECK(bool(open_graph.at(n1) == 1)); -// CHECK(bool(open_graph.at(e0) == 2)); - -// CHECK(get_edges(open_graph).size() == 1); -// } - -// TEST_CASE("OutputLabelledOpenMultiDiGraph") { -// OutputLabelledOpenMultiDiGraph g = -// OutputLabelledOpenMultiDiGraph::create< -// UnorderedOutputLabelledOpenMultiDiGraph>(); - -// Node n0 = g.add_node(0); -// Node n1 = g.add_node(1); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// MultiDiEdge e0{n1, p1, n0, p0}; - -// g.add_edge(e0); -// g.add_label(e0, 2); - -// CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); -// CHECK(bool(get_edges(g).size() == 1)); -// } -// } diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc deleted file mode 100644 index 4f9b879574..0000000000 --- a/lib/compiler/test/src/test_machine_mapping.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "doctest/doctest.h" -#include "test_generator.h" - -TEST_SUITE(FF_TEST_SUITE) { - // TEST_CASE("MachineMapping::combine") { - // RC_SUBCASE([](MachineMapping const &m0, MachineMapping const &m1) { - // RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); - - // MachineMapping comb = MachineMapping::combine(m0, m1); - - // RC_ASSERT(comb.machine_views.size() == - // m0.machine_views.size() + m1.machine_views.size()); - // RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); - // RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); - // }); - // } - - // TEST_CASE("OptimalCostResult::infinity") { - // RC_SUBCASE([](OptimalCostResult const &c) { - // RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); - // }); - // } -} diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc deleted file mode 100644 index e3426aa293..0000000000 --- a/lib/compiler/test/src/test_open_graph.cc +++ /dev/null @@ -1,81 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "utils/graph/algorithms.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_source_sink_open_graph") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// NodePort p0 = g.add_node_port(); -// InputMultiDiEdge e0{ -// n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; -// g.add_edge(e0); - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{})); -// } - -// TEST_CASE("get_source_sink_open_graph:unconnected") { -// OpenMultiDiGraph g = -// OpenMultiDiGraph::create(); - -// Node n0 = g.add_node(); -// Node n1 = g.add_node(); - -// NodePort p0 = g.add_node_port(); -// NodePort p1 = g.add_node_port(); - -// InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; -// OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; -// g.add_edge(e0); -// g.add_edge(e1); - -// /* -// g: ->n0 -// n1-> -// */ - -// CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); -// CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); - -// CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); -// CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); -// } - -// TEST_CASE("get_cut") { -// auto g = OpenMultiDiGraph::create(); - -// std::vector ns = add_nodes(g, 5); - -// MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; -// MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; -// MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; -// MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; -// OutputMultiDiEdge e5{ -// ns[4], g.add_node_port(), std::make_pair(ns[4].value(), -// ns[4].value())}; - -// g.add_edge(e0); -// g.add_edge(e1); -// g.add_edge(e2); -// g.add_edge(e3); -// g.add_edge(e4); -// g.add_edge(e5); - -// GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; -// CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, -// e2})); - -// GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; -// CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, -// e4})); -// } -// } diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc deleted file mode 100644 index 133558f83a..0000000000 --- a/lib/compiler/test/src/test_optimal_cost.cc +++ /dev/null @@ -1,72 +0,0 @@ -// #include "compiler/unity_algorithm.h" -// #include "doctest/doctest.h" -// #include "test_cost_estimator.h" - -// using namespace FlexFlow; - -// TEST_SUITE(FF_TEST_SUITE) { -// // Rapidcheck infrastructures for graphs does not work for now -// /* -// Tests whether optimal_cost can give a valid result given random PCG, -// trivial allowed machine views, trivial cost estimator and random machine -// specification. -// */ -// // TEST_CASE("optimal_cost") { -// // auto test_allowed_machine_views = [](Operator const &, -// // MachineSpecification const &) { -// // return std::unordered_set{make_1d_machine_view(0, 1, -// 1)}; -// // }; -// // RC_SUBCASE([](ParallelComputationGraph const &g, -// // MachineSpecification const &machine_spec) { -// // OptimalCostCache cached_subgraph_costs; -// // OptimalCostResult result = optimal_cost(g, -// // test_allowed_machine_views, -// // TestCostEstimator{}, -// // machine_spec, -// // cached_subgraph_costs); -// // RC_ASSERT(result.runtime > 0); -// // RC_ASSERT(keys(result.machine_mapping.machine_views) == -// get_nodes(g)); -// // }); -// // } - -// TEST_CASE("optimal_cost_0") { -// auto pcg = -// OutputLabelledMultiDiGraph::template -// create< -// UnorderedOutputLabelledMultiDiGraph>(); - -// Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); -// Node n1 = pcg.add_node(Operator{ -// LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, -// std::nullopt}, "linear"}); - -// MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; -// pcg.add_edge(e); -// ParallelDim dim = {2, 1, false}; -// ParallelTensorDims dims = {FFOrdered{dim}}; -// pcg.add_output(e, ParallelTensor(dims, DataType::FLOAT, -// CreateGrad::YES)); - -// auto test_allowed_machine_views = [](Operator const &, -// MachineSpecification const &) { -// return std::unordered_set{ -// make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; -// }; - -// CostEstimator estimator = CostEstimator::create(); - -// MachineSpecification machine_spec{1, 1, 1, 1, 1}; - -// OptimalCostCache cached_results; - -// OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), -// test_allowed_machine_views, -// estimator, -// machine_spec, -// cached_results); - -// CHECK(bool(result.runtime > 0)); -// } -// } diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/unity_algorithm.cc similarity index 93% rename from lib/compiler/test/src/test_unity_algorithm.cc rename to lib/compiler/test/src/unity_algorithm.cc index ed5e895a75..8ff0978ea5 100644 --- a/lib/compiler/test/src/test_unity_algorithm.cc +++ b/lib/compiler/test/src/unity_algorithm.cc @@ -1,7 +1,5 @@ #include "compiler/unity_algorithm.h" #include "doctest/doctest.h" -#include "test_cost_estimator.h" -#include "test_generator.h" TEST_SUITE(FF_TEST_SUITE) { // Rapidcheck does not work for now diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index f523520f9f..5fbcd91a06 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -18,8 +18,8 @@ class GenericTensorAccessorW { if (this->data_type == DT) { return static_cast *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", this->data_type, DT)); } } @@ -49,8 +49,8 @@ class GenericTensorAccessorR { if (this->data_type == DT) { return static_cast const *>(this->ptr); } else { - throw mk_runtime_error( - "Invalid access data type ({} != {})", this->data_type, DT); + throw mk_runtime_error(fmt::format( + "Invalid access data type ({} != {})", this->data_type, DT)); } } @@ -97,7 +97,7 @@ typename data_type_enum_to_class
::type * return static_cast *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } @@ -118,7 +118,7 @@ typename data_type_enum_to_class
::type const * return static_cast const *>(a.ptr); } else { throw mk_runtime_error( - "Invalid access data type ({} != {})", a.data_type, DT); + fmt::format("Invalid access data type ({} != {})", a.data_type, DT)); } } diff --git a/lib/kernels/include/kernels/datatype_dispatch.h b/lib/kernels/include/kernels/datatype_dispatch.h index e6ab9fa8cc..e83fc3325d 100644 --- a/lib/kernels/include/kernels/datatype_dispatch.h +++ b/lib/kernels/include/kernels/datatype_dispatch.h @@ -22,7 +22,7 @@ Out dispatch(DataType dt, Args &&...args) { case DataType::BOOL: return F{}(std::forward(args)...); default: - throw mk_runtime_error("Unknown datatype {}", dt); + throw mk_runtime_error(fmt::format("Unknown datatype {}", dt)); } } diff --git a/lib/kernels/include/kernels/linear_kernels.h b/lib/kernels/include/kernels/linear_kernels.h index c761eaf1d9..3128e39fd0 100644 --- a/lib/kernels/include/kernels/linear_kernels.h +++ b/lib/kernels/include/kernels/linear_kernels.h @@ -4,7 +4,7 @@ #include "device.h" #include "ff_handle.h" #include "op-attrs/datatype.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/include/local-execution/device_specific.h b/lib/local-execution/include/local-execution/device_specific.h index 3a36e02327..4035aaf7cf 100644 --- a/lib/local-execution/include/local-execution/device_specific.h +++ b/lib/local-execution/include/local-execution/device_specific.h @@ -28,10 +28,11 @@ struct DeviceSpecific { T const *get(size_t curr_device_idx) const { if (curr_device_idx != this->device_idx) { - throw mk_runtime_error("Invalid access to DeviceSpecific: attempted " - "device_idx {} != correct device_idx {})", - curr_device_idx, - this->device_idx); + throw mk_runtime_error( + fmt::format("Invalid access to DeviceSpecific: attempted " + "device_idx {} != correct device_idx {})", + curr_device_idx, + this->device_idx)); } return (T const *)this->ptr.get(); } diff --git a/lib/local-execution/include/local-execution/permissions.h b/lib/local-execution/include/local-execution/permissions.h index ce19e38e7e..f34969f233 100644 --- a/lib/local-execution/include/local-execution/permissions.h +++ b/lib/local-execution/include/local-execution/permissions.h @@ -42,8 +42,8 @@ struct formatter<::FlexFlow::Permissions> : formatter { name = "READ_WRITE"; break; default: - throw ::FlexFlow::mk_runtime_error("Unknown permission {}", - static_cast(p)); + throw ::FlexFlow::mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } return formatter::format(name, ctx); } diff --git a/lib/local-execution/src/local_task_argument_accessor.cc b/lib/local-execution/src/local_task_argument_accessor.cc index 5d0156201e..54eca7e514 100644 --- a/lib/local-execution/src/local_task_argument_accessor.cc +++ b/lib/local-execution/src/local_task_argument_accessor.cc @@ -30,7 +30,7 @@ GenericTensorAccessor LocalTaskArgumentAccessor::get_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( @@ -49,7 +49,7 @@ VariadicGenericTensorAccessor LocalTaskArgumentAccessor::get_variadic_tensor( } else if (priv == Permissions::RW || priv == Permissions::WO) { return variadic_tensor_backing; } else { - throw mk_runtime_error("Unhandled privilege mode {}", priv); + throw mk_runtime_error(fmt::format("Unhandled privilege mode {}", priv)); } } diff --git a/lib/local-execution/src/ops/batch_matmul.h b/lib/local-execution/src/ops/batch_matmul.h index c082dec020..a7e29b1931 100644 --- a/lib/local-execution/src/ops/batch_matmul.h +++ b/lib/local-execution/src/ops/batch_matmul.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/op_task_signature.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_matmul.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/batch_norm.h b/lib/local-execution/src/ops/batch_norm.h index 1f6cceec19..36aa8ffa4e 100644 --- a/lib/local-execution/src/ops/batch_norm.h +++ b/lib/local-execution/src/ops/batch_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/batch_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/cast.h b/lib/local-execution/src/ops/cast.h index b4a1e91c91..e7af6aca6b 100644 --- a/lib/local-execution/src/ops/cast.h +++ b/lib/local-execution/src/ops/cast.h @@ -17,7 +17,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/cast_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/combine.h b/lib/local-execution/src/ops/combine.h index c6157a2955..e85e8fba39 100644 --- a/lib/local-execution/src/ops/combine.h +++ b/lib/local-execution/src/ops/combine.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/combine_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/concat.h b/lib/local-execution/src/ops/concat.h index 1f1443f25d..eab70d621c 100644 --- a/lib/local-execution/src/ops/concat.h +++ b/lib/local-execution/src/ops/concat.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/concat.h" +#include "op-attrs/ops/concat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/conv_2d.h b/lib/local-execution/src/ops/conv_2d.h index f70d36d514..0358d71eea 100644 --- a/lib/local-execution/src/ops/conv_2d.h +++ b/lib/local-execution/src/ops/conv_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/conv_2d_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/dropout.h b/lib/local-execution/src/ops/dropout.h index 84b67a29c2..a3dc5ff8af 100644 --- a/lib/local-execution/src/ops/dropout.h +++ b/lib/local-execution/src/ops/dropout.h @@ -4,7 +4,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" #include "local-execution/task_id_t.dtg.h" -#include "op-attrs/ops/dropout.h" +#include "op-attrs/ops/dropout_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_binary.h b/lib/local-execution/src/ops/element_binary.h index 05273e34b4..72c0976df8 100644 --- a/lib/local-execution/src/ops/element_binary.h +++ b/lib/local-execution/src/ops/element_binary.h @@ -3,7 +3,7 @@ #include "local-execution/sim_environment.h" #include "local-execution/task_signature_impl.h" -#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_binary_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/element_unary.h b/lib/local-execution/src/ops/element_unary.h index 4d1783f1f6..04a72e2e12 100644 --- a/lib/local-execution/src/ops/element_unary.h +++ b/lib/local-execution/src/ops/element_unary.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/element_unary_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/embedding.h b/lib/local-execution/src/ops/embedding.h index 0463984122..995d2296e1 100644 --- a/lib/local-execution/src/ops/embedding.h +++ b/lib/local-execution/src/ops/embedding.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/embedding_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/flat.h b/lib/local-execution/src/ops/flat.h index d1501f85ca..e019bfc654 100644 --- a/lib/local-execution/src/ops/flat.h +++ b/lib/local-execution/src/ops/flat.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_FLAT_H #include "local-execution/sim_environment.h" -#include "op-attrs/ops/flat.h" +#include "op-attrs/ops/flat_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/gather.h b/lib/local-execution/src/ops/gather.h index 74db276e35..e339683381 100644 --- a/lib/local-execution/src/ops/gather.h +++ b/lib/local-execution/src/ops/gather.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/gather.h" +#include "op-attrs/ops/gather_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/input.h b/lib/local-execution/src/ops/input.h index 97985585e1..baad25b798 100644 --- a/lib/local-execution/src/ops/input.h +++ b/lib/local-execution/src/ops/input.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_INPUT_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" +#include "op-attrs/ops/input_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/layer_norm.h b/lib/local-execution/src/ops/layer_norm.h index 4f8d87153b..8e034ac519 100644 --- a/lib/local-execution/src/ops/layer_norm.h +++ b/lib/local-execution/src/ops/layer_norm.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/layer_norm.h" +#include "op-attrs/ops/layer_norm_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/linear.h b/lib/local-execution/src/ops/linear.h index 2c76483df4..2aaf13a95a 100644 --- a/lib/local-execution/src/ops/linear.h +++ b/lib/local-execution/src/ops/linear.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/linear_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/noop.h b/lib/local-execution/src/ops/noop.h index 959f7dc054..1097adeb5e 100644 --- a/lib/local-execution/src/ops/noop.h +++ b/lib/local-execution/src/ops/noop.h @@ -2,9 +2,7 @@ #define _FLEXFLOW_NOOP_H #include "local-execution/op_task_invocation.h" -#include "op-attrs/ops/input.h" -#include "op-attrs/ops/noop.h" -#include "op-attrs/ops/weight_attrs.dtg.h" +#include "op-attrs/ops/noop_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/pool_2d.h b/lib/local-execution/src/ops/pool_2d.h index e8624185ac..908fd5462f 100644 --- a/lib/local-execution/src/ops/pool_2d.h +++ b/lib/local-execution/src/ops/pool_2d.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/pool_2d.h" +#include "op-attrs/ops/pool_2d_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduce.h b/lib/local-execution/src/ops/reduce.h index 92f0578757..7900c28159 100644 --- a/lib/local-execution/src/ops/reduce.h +++ b/lib/local-execution/src/ops/reduce.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduce.h" +#include "op-attrs/ops/reduce_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reduction.h b/lib/local-execution/src/ops/reduction.h index a0af4f3aea..56833602e6 100644 --- a/lib/local-execution/src/ops/reduction.h +++ b/lib/local-execution/src/ops/reduction.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/reduction_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/repartition.h b/lib/local-execution/src/ops/repartition.h index b38a93f8b1..5187d04ca0 100644 --- a/lib/local-execution/src/ops/repartition.h +++ b/lib/local-execution/src/ops/repartition.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/repartition_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/replicate.h b/lib/local-execution/src/ops/replicate.h index 77bda411c1..85d1dff41a 100644 --- a/lib/local-execution/src/ops/replicate.h +++ b/lib/local-execution/src/ops/replicate.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/replicate.h" +#include "op-attrs/ops/replicate_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reshape.h b/lib/local-execution/src/ops/reshape.h index 06a6b32597..37f07534ee 100644 --- a/lib/local-execution/src/ops/reshape.h +++ b/lib/local-execution/src/ops/reshape.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reshape.h" +#include "op-attrs/ops/reshape_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/reverse.h b/lib/local-execution/src/ops/reverse.h index 10072860b0..7c16073be7 100644 --- a/lib/local-execution/src/ops/reverse.h +++ b/lib/local-execution/src/ops/reverse.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/reverse.h" +#include "op-attrs/ops/reverse_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/softmax.h b/lib/local-execution/src/ops/softmax.h index b5756d92ff..d440fe7239 100644 --- a/lib/local-execution/src/ops/softmax.h +++ b/lib/local-execution/src/ops/softmax.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/softmax.h" +#include "op-attrs/ops/softmax_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/split.h b/lib/local-execution/src/ops/split.h index c82152b06a..dde46c20bf 100644 --- a/lib/local-execution/src/ops/split.h +++ b/lib/local-execution/src/ops/split.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/split.h" +#include "op-attrs/ops/split_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/topk.h b/lib/local-execution/src/ops/topk.h index b04d807400..c8f3175ebd 100644 --- a/lib/local-execution/src/ops/topk.h +++ b/lib/local-execution/src/ops/topk.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/topk.h" +#include "op-attrs/ops/topk_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/ops/transpose.h b/lib/local-execution/src/ops/transpose.h index 3feffc7d86..0f3a2e80a0 100644 --- a/lib/local-execution/src/ops/transpose.h +++ b/lib/local-execution/src/ops/transpose.h @@ -3,7 +3,7 @@ #include "local-execution/op_task_invocation.h" #include "local-execution/sim_environment.h" -#include "op-attrs/ops/transpose.h" +#include "op-attrs/ops/transpose_attrs.dtg.h" namespace FlexFlow { diff --git a/lib/local-execution/src/permissions.cc b/lib/local-execution/src/permissions.cc index e5c46b42f8..2286215987 100644 --- a/lib/local-execution/src/permissions.cc +++ b/lib/local-execution/src/permissions.cc @@ -33,7 +33,8 @@ static int as_int(Permissions p) { case Permissions::RW: return 2; default: - throw mk_runtime_error("Unknown permission {}", static_cast(p)); + throw mk_runtime_error( + fmt::format("Unknown permission {}", static_cast(p))); } } diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h deleted file mode 100644 index 268554b5be..0000000000 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ /dev/null @@ -1,50 +0,0 @@ -#ifndef _OPERATOR_PARAMS_H -#define _OPERATOR_PARAMS_H - -#include "op-attrs/ops/core.h" -#include "op-attrs/pcg_operator_attrs.dtg.h" -#include "ops/attention.h" -#include "ops/batch_matmul.h" -#include "ops/batch_norm.h" -#include "ops/broadcast.h" -#include "ops/cast.h" -#include "ops/combine.h" -#include "ops/concat.h" -#include "ops/conv_2d.h" -#include "ops/dropout.h" -#include "ops/element_binary.h" -#include "ops/element_unary.h" -#include "ops/embedding.h" -#include "ops/flat.h" -#include "ops/gather.h" -#include "ops/input.h" -#include "ops/layer_norm.h" -#include "ops/linear.h" -#include "ops/noop.h" -#include "ops/pool_2d.h" -#include "ops/reduce.h" -#include "ops/reduction.h" -#include "ops/repartition.h" -#include "ops/replicate.h" -#include "ops/reshape.h" -#include "ops/reverse.h" -#include "ops/softmax.h" -#include "ops/split.h" -#include "ops/topk.h" -#include "ops/transpose.h" -#include "utils/record_formatter.h" -#include "utils/variant.h" -#include - -namespace FlexFlow { - -std::vector get_output_shapes( - PCGOperatorAttrs const &op_params, - std::vector const &input_tensor_shapes); - -bool is_valid(PCGOperatorAttrs const &, - std::vector const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/op-attrs/src/operator_attrs.cc b/lib/op-attrs/src/operator_attrs.cc deleted file mode 100644 index e6459c6819..0000000000 --- a/lib/op-attrs/src/operator_attrs.cc +++ /dev/null @@ -1,287 +0,0 @@ -#include "op-attrs/operator_attrs.h" -#include "utils/fmt.h" -#include "utils/record_formatter.h" -#include "utils/type_traits.h" - -namespace FlexFlow { - -/* OperatorType GetOpType::operator()(BatchMatmulAttrs const &p) const { return - * OP_BATCHMATMUL; } */ -/* OperatorType GetOpType::operator()(Conv2DAttrs const &p) const { return - * OP_CONV2D; } */ -/* OperatorType GetOpType::operator()(ConcatAttrs const &p) const { return - * OP_CONCAT; } */ -/* OperatorType GetOpType::operator()(CastAttrs const &p) const { return - * OP_CAST; } */ -/* OperatorType GetOpType::operator()(ElementBinaryAttrs const &p) const { - * return p.type; } */ -/* OperatorType GetOpType::operator()(ElementUnaryAttrs const &p) const { return - * p.op_type; } */ -/* OperatorType GetOpType::operator()(DropoutAttrs const &p) const { return - * OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(EmbeddingAttrs const &p) const { return - * OP_EMBEDDING; } */ -/* OperatorType GetOpType::operator()(FlatAttrs const &p) const { return - * OP_FLAT; } */ -/* OperatorType GetOpType::operator()(LayerNormAttrs const &p) const { return - * OP_LAYERNORM; } */ -/* OperatorType GetOpType::operator()(LinearAttrs const &p) const { return - * OP_LINEAR; } */ -/* OperatorType GetOpType::operator()(MultiHeadAttentionAttrs const &p) const { - * return OP_DROPOUT; } */ -/* OperatorType GetOpType::operator()(Pool2DAttrs const &p) const { return - * OP_POOL2D; } */ -/* OperatorType GetOpType::operator()(ReshapeAttrs const &p) const { return - * OP_RESHAPE; } */ -/* OperatorType GetOpType::operator()(SplitAttrs const &p) const { return - * OP_SPLIT; } */ -/* OperatorType GetOpType::operator()(SoftmaxAttrs const &p) const { return - * OP_SOFTMAX; } */ -/* OperatorType GetOpType::operator()(TransposeAttrs const &p) const { return - * OP_TRANSPOSE; } */ -/* OperatorType GetOpType::operator()(RepartitionAttrs const &p) const { return - * OP_REPARTITION; } */ -/* OperatorType GetOpType::operator()(ReplicateAttrs const &p) const { return - * OP_REPLICATE; } */ -/* OperatorType GetOpType::operator()(ReductionAttrs const &p) const { return - * OP_REDUCTION; } */ -/* OperatorType GetOpType::operator()(CombineAttrs const &p) const { return - * OP_COMBINE; } */ -/* OperatorType GetOpType::operator()(FusedParallelOpAttrs const &p) const { - * return OP_FUSED_PARALLEL; } */ - -/* struct AsOpAttrs { */ -/* template */ -/* OpAttrsInterface const &operator()(T const &p) { */ -/* return p; */ -/* } */ -/* }; */ - -/* OperatorType get_op_type(OpAttrsInterface const &o) { */ -/* return o.op_type(); */ -/* } */ -/* // */ -/* OperatorType get_op_type(CompGraphOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* OperatorType get_op_type(PCGOperatorAttrs const &o) { */ -/* return get_op_type(visit(AsOpAttrs{}, o)); */ -/* } */ - -/* std::vector get_output_shapes(PCGOperatorAttrs const - * &op_params, std::vector const &input_tensor_shapes) { */ -/* return mpark::visit(AsOpAttrs{}, - * op_params).output_shapes(input_tensor_shapes); */ -/* } */ - -/* bool is_parallel_op(PCGOperatorAttrs const &o) { */ -/* return is_parallel_op(get_op_type(o)); */ -/* } */ -template -typename std::enable_if<(is_streamable::value && - !is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - std::ostringstream oss; - oss << t; - r << oss; -} - -template -typename std::enable_if<(is_fmtable::value)>::type - as_dot(T const &t, RecordFormatter &r) { - r << fmt::to_string(t); -} -void as_dot(int x, RecordFormatter &r) { - r << std::to_string(x); -} - -void as_dot(std::string const &s, RecordFormatter &r) { - r << s; -} - -template -void as_dot(std::vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -template -void as_dot(stack_vector const &x, RecordFormatter &r) { - RecordFormatter rr; - for (T const &t : x) { - as_dot(t, r); - } - r << rr; -} - -struct as_dot_visitor { - as_dot_visitor() = delete; - as_dot_visitor(RecordFormatter &result) : result(result) {} - - RecordFormatter &result; - - template - void operator()(char const *name, T const &t) { - RecordFormatter kv; - kv << name; - as_dot(t, result); - result << kv; - } - - template - void operator()(T const &t) { - as_dot(t, result); - } - - /* template */ - /* void operator()(const char *name, std::vector const &t) { */ - /* RecordFormatter kv; */ - /* kv << name; */ - /* RecordFormatter v; */ - /* for (V const &vv : t) { */ - /* v << as_dot_str(vv); */ - /* } */ - /* kv << v; */ - /* } */ -}; - -template -typename std::enable_if::value>::type - as_dot(T const &t, RecordFormatter &r) { - as_dot_visitor vis(r); - visit_struct::for_each(t, vis); -} - -struct AsDot { - template - RecordFormatter operator()(T const &t) { - return as_dot(t); - } -}; - -template -RecordFormatter as_dot(std::variant const &o) { - return std::visit(AsDot{}, o); -} - -struct IsValidFunctor { - IsValidFunctor(std::vector const &_input_shapes) - : input_shapes(_input_shapes) {} - - std::vector const &input_shapes; - - // bool operator()(AggregateAttrs const &attrs) { - // return is_valid(attrs, - // input_shapes.at(0), - // input_shapes.at(1), - // input_shapes.at(2), - // input_shapes.at(3), - // subvec(input_shapes, 4, nullopt)); - // } - - template - bool operator()(T const &) { - return true; // TODO FIXME @lockshaw - } -}; - -bool is_valid(PCGOperatorAttrs const &attrs, - std::vector const &input_shapes) { - NOT_IMPLEMENTED(); -} - -/* int num_outputs(OperatorParameters const &o) { */ -/* switch (get_op_type(o)) { */ -/* case OP_SPLIT: */ -/* } */ -/* } */ - -// tl::optional get_op_parameters(Op const *op) { -// switch (op->op_type) { -// case OP_LINEAR: -// return ((Linear *)op)->get_params(); -// case OP_CONV2D: -// return ((Conv2D *)op)->get_params(); -// case OP_EW_ADD: -// case OP_EW_SUB: -// case OP_EW_MUL: -// case OP_EW_DIV: -// return ((ElementBinary *)op)->get_params(); -// case OP_EXP: -// case OP_SIN: -// case OP_COS: -// case OP_SCALAR_MULTIPLY: -// case OP_SCALAR_ADD: -// case OP_SCALAR_SUB: -// case OP_SCALAR_TRUE_DIV: -// case OP_RELU: -// case OP_SIGMOID: -// case OP_TANH: -// case OP_IDENTITY: -// case OP_GELU: -// case OP_ELU: -// return ((ElementUnary *)op)->get_params(); -// case OP_CONCAT: -// return ((Concat *)op)->get_params(); -// case OP_POOL2D: -// return ((Pool2D *)op)->get_params(); -// case OP_CAST: -// return ((Cast *)op)->get_params(); -// case OP_DROPOUT: -// return ((Dropout *)op)->get_params(); -// case OP_EMBEDDING: -// return ((Embedding *)op)->get_params(); -// case OP_FLAT: -// return ((Flat *)op)->get_params(); -// case OP_MULTIHEAD_ATTENTION: -// return ((MultiHeadAttention *)op)->get_params(); -// case OP_LAYERNORM: -// return ((LayerNorm *)op)->get_params(); -// case OP_RESHAPE: -// return ((Reshape *)op)->get_params(); -// case OP_SOFTMAX: -// return ((Softmax *)op)->get_params(); -// case OP_REPARTITION: -// return ((Repartition *)op)->get_params(); -// case OP_REPLICATE: -// return ((Replicate *)op)->get_params(); -// case OP_REDUCTION: -// return ((Reduction *)op)->get_params(); -// case OP_COMBINE: -// return ((Combine *)op)->get_params(); -// case OP_FUSED_PARALLEL: -// return ((FusedParallelOp *)op)->get_params(); -// case OP_TRANSPOSE: -// return ((Transpose *)op)->get_params(); -// case OP_BATCHMATMUL: -// return ((BatchMatmul *)op)->get_params(); -// case OP_SPLIT: -// return ((Split *)op)->get_params(); -// -// // TODO: implement the get_params() function for the operators below -// and -// // uncomment the lines below -// -// // case OP_NOOP: -// // return ((NoOp *)op)->get_params(); -// // case OP_TOPK: -// // return ((TopK *)op)->get_params(); -// // case OP_MEAN: -// // return ((Mean *)op)->get_params(); -// // case OP_CACHE: -// // return ((Cache *)op)->get_params(); -// // case OP_REVERSE: -// // return ((Reverse *)op)->get_params(); -// // case OP_BATCHNORM: -// // return ((BatchNorm *)op)->get_params(); -// -// default: -// return tl::nullopt; -// } -// } - -} // namespace FlexFlow diff --git a/lib/pcg/include/pcg/computation_graph.h b/lib/pcg/include/pcg/computation_graph.h index f70d9f7404..b29d683edb 100644 --- a/lib/pcg/include/pcg/computation_graph.h +++ b/lib/pcg/include/pcg/computation_graph.h @@ -52,6 +52,11 @@ LayerAttrs get_layer_attrs(ComputationGraph const &cg, layer_guid_t const &n); layer_guid_t get_layer_by_name(ComputationGraph const &cg, std::string const &name); +ComputationGraph without_layer_names(ComputationGraph const &); + +bool computation_graphs_are_isomorphic(ComputationGraph const &, + ComputationGraph const &); + std::string as_dot(ComputationGraph const &); void debug_print_dot(ComputationGraph const &); diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h new file mode 100644 index 0000000000..a1ca0aceed --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/json.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_FILE_FORMAT_V1_V1_BINARY_SP_DECOMPOSITION_JSON_H + +#include "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h" +#include + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::V1BinarySPDecomposition> { + static ::FlexFlow::V1BinarySPDecomposition from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySPDecomposition const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinarySeriesSplit> { + static ::FlexFlow::V1BinarySeriesSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinarySeriesSplit const &); +}; + +template <> +struct adl_serializer<::FlexFlow::V1BinaryParallelSplit> { + static ::FlexFlow::V1BinaryParallelSplit from_json(json const &); + static void to_json(json &, ::FlexFlow::V1BinaryParallelSplit const &); +}; + +} // namespace nlohmann + +#endif diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml new file mode 100644 index 0000000000..d2d0c3bc77 --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml new file mode 100644 index 0000000000..317fa8b6ce --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "V1BinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct V1BinarySPDecomposition" +] + +post_includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::V1BinarySPDecomposition" +indirect = true diff --git a/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml new file mode 100644 index 0000000000..0fe0b1761f --- /dev/null +++ b/lib/pcg/include/pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_sp_decomposition.variant.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "V1BinarySPDecomposition" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_series_split.dtg.h", + "pcg/file_format/v1/v1_binary_sp_decomposition/v1_binary_parallel_split.dtg.h", +] + +[[values]] +type = "::FlexFlow::V1BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::V1BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "int" +key = "leaf" diff --git a/lib/pcg/include/pcg/initializer_attrs.variant.toml b/lib/pcg/include/pcg/initializer_attrs.variant.toml index 2e878c5c53..c67999ff32 100644 --- a/lib/pcg/include/pcg/initializer_attrs.variant.toml +++ b/lib/pcg/include/pcg/initializer_attrs.variant.toml @@ -41,7 +41,7 @@ key = "normal" [[values]] type = "::FlexFlow::TruncatedNormalInitializerAttrs" -key = "normal" +key = "truncated_normal" [[values]] type = "::FlexFlow::ConstantInitializerAttrs" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index d7248afde4..c740e1ffd2 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_H #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" @@ -21,6 +22,15 @@ ParallelLayerAddedResult std::vector const &inputs, std::vector const &output_labels); +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape); + +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &, + parallel_layer_guid_t const &, + parallel_layer_guid_t const &); + std::vector get_incoming_tensors(ParallelComputationGraph const &, parallel_layer_guid_t const &); @@ -37,8 +47,12 @@ std::vector ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &, parallel_layer_guid_t const &); +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &, + parallel_layer_guid_t const &); ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &, parallel_tensor_guid_t const &); +ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); std::vector topological_ordering(ParallelComputationGraph const &); @@ -47,6 +61,11 @@ parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name); +ParallelComputationGraph without_layer_names(ParallelComputationGraph const &); + +bool pcgs_are_isomorphic(ParallelComputationGraph const &, + ParallelComputationGraph const &); + } // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml index 4d61f24d37..027b9f6c80 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_attrs.struct.toml @@ -10,7 +10,7 @@ features = [ ] includes = [ - "op-attrs/operator_attrs.h", + "op-attrs/pcg_operator_attrs.dtg.h", "utils/stack_string.h", "", ] diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index a69e54fd93..3d1bc629e4 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -11,6 +11,8 @@ #include "utils/graph/digraph/algorithms/get_subgraph_successors.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/node/algorithms.h" @@ -175,6 +177,26 @@ layer_guid_t get_layer_by_name(ComputationGraph const &cg, return get_only(found); } +ComputationGraph without_layer_names(ComputationGraph const &cg) { + return ComputationGraph{ + LabelledDataflowGraph::create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels(cg.raw_graph, + [](Node const &n, LayerAttrs const &old_attrs) { + LayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool computation_graphs_are_isomorphic(ComputationGraph const &lhs, + ComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + std::string as_dot(ComputationGraph const &cg) { std::function get_node_label = [](LayerAttrs const &a) -> std::string { diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index 4a565476bd..dff647f5a1 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -489,11 +489,12 @@ tensor_guid_t ComputationGraphBuilder::gather( LayerAttrs layer = LayerAttrs{ComputationGraphOpAttrs{attrs}, name}; if (this->get_shape(index).data_type != DataType::INT32 && this->get_shape(index).data_type != DataType::INT64) { - throw mk_runtime_error("Invalid data type for input tensor 2 for Gather: " - "{} (should be {} or {})", - this->get_shape(input).data_type, - DataType::INT32, - DataType::INT64); + throw mk_runtime_error( + fmt::format("Invalid data type for input tensor 2 for Gather: " + "{} (should be {} or {})", + this->get_shape(input).data_type, + DataType::INT32, + DataType::INT64)); } TensorShape output_shape = get_output_shape(attrs, this->get_shape(input), this->get_shape(index)); diff --git a/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..5341e03c0a --- /dev/null +++ b/lib/pcg/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,84 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include "utils/exception.h" +#include "utils/fmt/json.h" +#include "utils/overload.h" + +using namespace ::FlexFlow; + +namespace nlohmann { + +V1BinarySPDecomposition + adl_serializer::from_json(json const &j) { + std::string type = j.at("type").get(); + + if (type == "series") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "parallel") { + return V1BinarySPDecomposition{ + j.get(), + }; + } else if (type == "leaf") { + return V1BinarySPDecomposition{ + j.at("value").get(), + }; + } else { + throw mk_runtime_error(fmt::format( + "Unknown json type value for LeafOnlyBinarySPDecompositionTree \"{}\" " + "in json object: {}", + type, + j)); + } +} + +void adl_serializer::to_json( + json &j, V1BinarySPDecomposition const &tree) { + tree.visit(overload{ + [&](V1BinarySeriesSplit const &split) { + j = split; + j["type"] = "series"; + return std::monostate{}; + }, + [&](V1BinaryParallelSplit const &split) { + j = split; + j["type"] = "parallel"; + return std::monostate{}; + }, + [&](int leaf) { + j["value"] = leaf; + j["type"] = "leaf"; + return std::monostate{}; + }, + }); +} + +V1BinarySeriesSplit + adl_serializer::from_json(json const &j) { + return V1BinarySeriesSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void adl_serializer::to_json( + json &j, V1BinarySeriesSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + +V1BinaryParallelSplit + adl_serializer::from_json(json const &j) { + return V1BinaryParallelSplit{ + /*lhs=*/j.at("left_child").get(), + /*rhs=*/j.at("right_child").get(), + }; +} + +void adl_serializer::to_json( + json &j, V1BinaryParallelSplit const &series) { + j["left_child"] = series.get_left_child(); + j["right_child"] = series.get_right_child(); +} + +} // namespace nlohmann diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index b04d9d37b3..781c44640c 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -4,8 +4,11 @@ #include "utils/containers/get_only.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/algorithms.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/find_isomorphism.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" #include "utils/graph/node/algorithms.h" namespace FlexFlow { @@ -42,6 +45,39 @@ ParallelLayerAddedResult }; } +ParallelLayerAddedResult + pcg_add_input_layer(ParallelComputationGraph &pcg, + ParallelTensorShape const &tensor_shape) { + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + return add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); +} + +std::unordered_set + get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &src, + parallel_layer_guid_t const &dst) { + std::unordered_set raw_edges = + get_dataflow_edges_from_node_to_node( + pcg.raw_graph, src.raw_graph_node, dst.raw_graph_node); + return transform(raw_edges, [](DataflowEdge const &e) { + return ParallelComputationGraphEdge{e}; + }); +} + std::vector get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -110,12 +146,23 @@ ParallelLayerAttrs get_parallel_layer_attrs(ParallelComputationGraph const &pcg, return pcg.raw_graph.at(l.raw_graph_node); } +PCGOperatorAttrs pcg_get_op_attrs(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return get_parallel_layer_attrs(pcg, l).op_attrs; +} + ParallelTensorAttrs get_parallel_tensor_attrs(ParallelComputationGraph const &pcg, parallel_tensor_guid_t const &t) { return pcg.raw_graph.at(t.raw_graph_output); } +ParallelTensorShape + get_parallel_tensor_shape(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + return get_parallel_tensor_attrs(pcg, t).shape; +} + std::vector topological_ordering(ParallelComputationGraph const &pcg) { return transform(get_topological_ordering(pcg.raw_graph), @@ -132,4 +179,28 @@ parallel_layer_guid_t return get_only(found); } +ParallelComputationGraph + without_layer_names(ParallelComputationGraph const &pcg) { + return ParallelComputationGraph{ + LabelledDataflowGraph:: + create_copy_of< + UnorderedSetLabelledOpenDataflowGraph>( + rewrite_node_labels( + pcg.raw_graph, + [](Node const &n, ParallelLayerAttrs const &old_attrs) { + ParallelLayerAttrs new_attrs = old_attrs; + new_attrs.name = std::nullopt; + return new_attrs; + })), + }; +} + +bool pcgs_are_isomorphic(ParallelComputationGraph const &lhs, + ParallelComputationGraph const &rhs) { + return find_isomorphism(without_layer_names(lhs).raw_graph, + without_layer_names(rhs).raw_graph) + .has_value(); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index ce00ea62f4..f33b4dcd17 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,5 +1,18 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "op-attrs/get_incoming_tensor_roles.h" +#include "op-attrs/ops/attention.h" +#include "op-attrs/ops/batch_matmul.h" +#include "op-attrs/ops/batch_norm.h" +#include "op-attrs/ops/cast.h" +#include "op-attrs/ops/combine.h" +#include "op-attrs/ops/conv_2d.h" +#include "op-attrs/ops/element_binary.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/embedding.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/reduction.h" +#include "op-attrs/ops/repartition.h" +#include "op-attrs/ops/replicate.h" #include "op-attrs/ops/weight_attrs.dtg.h" #include "op-attrs/parallel_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" @@ -182,7 +195,7 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::optional activation, bool use_bias, DataType data_type, - std::optional const &kernel_initializer, + std::optional const &projection_initializer, std::optional const &bias_initializer, std::optional const &maybe_name) { LinearAttrs attrs = LinearAttrs{ @@ -205,9 +218,10 @@ parallel_tensor_guid_t ParallelComputationGraphBuilder::dense( std::vector weights; { - ParallelTensorShape kernel_shape = + ParallelTensorShape projection_shape = throw_if_unexpected(get_projection_shape(attrs, input_shape)); - weights.push_back(make_weight_attrs(kernel_shape, kernel_initializer)); + weights.push_back( + make_weight_attrs(projection_shape, projection_initializer)); } if (use_bias) { diff --git a/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc new file mode 100644 index 0000000000..9068e14517 --- /dev/null +++ b/lib/pcg/test/src/pcg/file_format/v1/v1_binary_sp_decomposition/json.cc @@ -0,0 +1,178 @@ +#include "pcg/file_format/v1/v1_binary_sp_decomposition/json.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer") { + V1BinarySPDecomposition example_tree = V1BinarySPDecomposition{ + V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }, + }; + + nlohmann::json example_json = { + {"type", "series"}, + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_tree; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySPDecomposition result = + example_json.get(); + V1BinarySPDecomposition correct = example_tree; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinarySeriesSplit example_split = V1BinarySeriesSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinarySeriesSplit result = example_json.get(); + V1BinarySeriesSplit correct = example_split; + + CHECK(result == correct); + } + } + + TEST_CASE("adl_serializer") { + V1BinaryParallelSplit example_split = V1BinaryParallelSplit{ + V1BinarySPDecomposition{ + V1BinaryParallelSplit{ + V1BinarySPDecomposition{2}, + V1BinarySPDecomposition{2}, + }, + }, + V1BinarySPDecomposition{3}, + }; + + nlohmann::json example_json = { + { + "left_child", + { + {"type", "parallel"}, + { + "left_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 2}, + }, + }, + }, + }, + { + "right_child", + { + {"type", "leaf"}, + {"value", 3}, + }, + }, + }; + + SUBCASE("to_json") { + nlohmann::json result = example_split; + nlohmann::json correct = example_json; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + V1BinaryParallelSplit result = example_json.get(); + V1BinaryParallelSplit correct = example_split; + + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 77d938e08a..fc07edf5b3 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -1,4 +1,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" #include "test/utils/rapidcheck.h" #include "utils/containers/get_only.h" @@ -262,4 +265,51 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } } + + TEST_CASE("pcg_add_input_layer") { + ParallelTensorShape tensor_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{12, 2}, + ShardParallelDim{10, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{2}, + DiscardCopyDegree{2}, + }, + }, + DataType::FLOAT, + }; + + ParallelComputationGraph result = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + pcg_add_input_layer(pcg, tensor_shape); + return pcg; + }(); + + ParallelComputationGraph correct = [&] { + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs tensor_attrs = ParallelTensorAttrs{ + /*shape=*/tensor_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::NO, + }; + + add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*output_labels=*/{tensor_attrs}); + + return pcg; + }(); + + CHECK(pcgs_are_isomorphic(result, correct)); + } } diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index c445085635..20bd0ac92d 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -1,4 +1,5 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "op-attrs/ops/conv_2d.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" diff --git a/lib/runtime/test/src/test_serialization.cc b/lib/runtime/test/src/test_serialization.cc index e46a481a1a..471f2a2709 100644 --- a/lib/runtime/test/src/test_serialization.cc +++ b/lib/runtime/test/src/test_serialization.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "legion/legion_utilities.h" #include "op-attrs/ffconst.h" -#include "op-attrs/operator_attrs.h" #include "serialization.h" #include diff --git a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h index de9d1cd78a..b7ce13db0e 100644 --- a/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h +++ b/lib/substitutions/include/substitutions/substitution_internal/perform_shape_inference.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_SUBSTITUTION_INTERNAL_PERFORM_SHAPE_INFERENCE_H +#include "op-attrs/parallel_tensor_shape.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 0bbe0e97a7..0c673f0a8a 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -4,7 +4,7 @@ #include "utils/containers/values.h" #include "utils/graph/dataflow_graph/algorithms/get_subgraph_outgoing_edges.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" -#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h" +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" #include "utils/graph/labelled_dataflow_graph/algorithms/view_as_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/as_dot.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/find_isomorphism.h" @@ -54,12 +54,6 @@ ParallelComputationGraph pcg_from_sub_pcg_by_dropping_inputs( UnorderedSetLabelledOpenDataflowGraph>( sub_pcg.raw_graph)}; - // return ParallelComputationGraph{ - // make_lazy_copy_of< - // UnorderedSetLabelledOpenDataflowGraph - // >(sub_pcg.raw_graph) - // }; } parallel_layer_guid_t diff --git a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc index 0bde326bd1..9fa91d75b7 100644 --- a/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -1,4 +1,5 @@ #include "substitutions/substitution_internal/perform_shape_inference.h" +#include "op-attrs/get_output_shapes.h" #include "utils/containers/map_keys.h" #include "utils/containers/transform.h" #include "utils/containers/zip.h" diff --git a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc index 6c3d53d3b9..4d4e557fb8 100644 --- a/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc +++ b/lib/substitutions/test/src/substitutions/substitution_internal/perform_shape_inference.cc @@ -1,4 +1,7 @@ #include "substitutions/substitution_internal/perform_shape_inference.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/parallel_tensor_shape.h" #include "utils/containers/get_only.h" #include "utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h" diff --git a/lib/utils/include/utils/any_value_type/any_value_type.h b/lib/utils/include/utils/any_value_type/any_value_type.h new file mode 100644 index 0000000000..a99ce5c8f0 --- /dev/null +++ b/lib/utils/include/utils/any_value_type/any_value_type.h @@ -0,0 +1,66 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ANY_VALUE_TYPE_ANY_VALUE_TYPE_H + +#include +#include +#include +#include + +namespace FlexFlow { + +struct any_value_type { +public: + any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string); + + bool operator==(any_value_type const &other) const; + bool operator!=(any_value_type const &other) const; + + template + T get() const { + return std::any_cast(value); + } + + friend std::string format_as(any_value_type const &); + +private: + std::any value; + std::function eq; + std::function neq; + std::function hash; + std::function to_string; + + friend std::hash; +}; + +template +any_value_type make_any_value_type(T const &t) { + return any_value_type{ + std::make_any(t), + [](std::any const &l, std::any const &r) { + return std::any_cast(l) == std::any_cast(r); + }, + [](std::any const &l, std::any const &r) { + return std::any_cast(l) != std::any_cast(r); + }, + [](std::any const &v) { return std::hash{}(std::any_cast(v)); }, + [](std::any const &v) { return fmt::to_string(std::any_cast(v)); }, + }; +} + +} // namespace FlexFlow + +namespace std { + +template <> +struct hash<::FlexFlow::any_value_type> { + size_t operator()(::FlexFlow::any_value_type const &) const; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/archetypes/value_type.h b/lib/utils/include/utils/archetypes/value_type.h new file mode 100644 index 0000000000..1635747612 --- /dev/null +++ b/lib/utils/include/utils/archetypes/value_type.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_VALUE_TYPE_H + +#include +#include + +namespace FlexFlow { + +template +struct value_type { + value_type() = delete; + + value_type(value_type const &) { + assert(false); + } + value_type &operator=(value_type const &) { + assert(false); + } + + value_type(value_type &&) { + assert(false); + } + value_type &operator=(value_type &&) { + assert(false); + } + + bool operator==(value_type const &) const { + assert(false); + } + bool operator!=(value_type const &) const { + assert(false); + } +}; + +} // namespace FlexFlow + +namespace std { + +template +struct hash<::FlexFlow::value_type> { + size_t operator()(::FlexFlow::value_type const &) const { + assert(false); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index 6ac9eb10b0..0e3b1fc0bd 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -126,34 +126,6 @@ std::optional optional_all_of(Container const &container, return true; } -template -std::vector flatmap(std::vector const &v, F const &f) { - std::vector result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set flatmap(std::unordered_set const &v, F const &f) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - -template -std::unordered_set flatmap_v2(std::unordered_set const &v, - std::unordered_set (*f)(In const &)) { - std::unordered_set result; - for (auto const &elem : v) { - extend(result, f(elem)); - } - return result; -} - template std::function compare_by(F const &f) { return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; diff --git a/lib/utils/include/utils/containers/cartesian_product.h b/lib/utils/include/utils/containers/cartesian_product.h new file mode 100644 index 0000000000..bcba52113e --- /dev/null +++ b/lib/utils/include/utils/containers/cartesian_product.h @@ -0,0 +1,39 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H + +#include "utils/containers/vector_of.h" +#include "utils/hash/vector.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set> + cartesian_product(std::vector> const &containers) { + std::unordered_set> result; + + std::function &, size_t)> recurse = + [&](std::vector ¤t, size_t depth) { + if (depth == containers.size()) { + result.insert(current); + return; + } + + for (E const &item : containers.at(depth)) { + current.push_back(item); + recurse(current, depth + 1); + current.pop_back(); + } + }; + + std::vector current; + recurse(current, 0); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index 0f8906f34a..b016a1e03d 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -3,7 +3,9 @@ #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" +#include "utils/containers/merge_maps.h" #include +#include namespace FlexFlow { @@ -39,6 +41,23 @@ std::unordered_set flatmap_v2(std::unordered_set const &v, return result; } +template < + typename InK, + typename InV, + typename F, + typename OutK = typename std::invoke_result_t::key_type, + typename OutV = typename std::invoke_result_t::mapped_type> +std::unordered_map flatmap(std::unordered_map const &m, + F &&f) { + std::unordered_map result; + + for (auto const &[k, v] : m) { + result = merge_maps(result, f(k, v)); + } + + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h new file mode 100644 index 0000000000..b7b30cbae4 --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H + +#include "utils/containers/cartesian_product.h" +#include "utils/containers/keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_map_from_pairs.h" +#include "utils/containers/vector_of.h" +#include "utils/containers/zip.h" +#include "utils/hash/unordered_map.h" +#include +#include +#include + +namespace FlexFlow { + +/** + * @note If \p options_per_key is empty, an set containing a single empty + * assignment is returned + */ +template +std::unordered_set> get_all_assignments( + std::unordered_map> const &options_per_key) { + if (options_per_key.empty()) { + return {{}}; + } + + std::vector ordered_keys = vector_of(keys(options_per_key)); + std::vector> ordered_value_option_sets = transform( + ordered_keys, [&](K const &k) { return options_per_key.at(k); }); + + std::unordered_set> result = transform( + cartesian_product(ordered_value_option_sets), + [&](std::vector const &chosen_values) { + return unordered_map_from_pairs(zip(ordered_keys, chosen_values)); + }); + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index fedb87413d..201095c47d 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -10,8 +10,8 @@ namespace FlexFlow { template typename C::value_type get_only(C const &c) { return unwrap(maybe_get_only(c), [&] { - throw mk_runtime_error("Encountered container with size {} in get_only", - c.size()); + throw mk_runtime_error(fmt::format( + "Encountered container with size {} in get_only", c.size())); }); } diff --git a/lib/utils/include/utils/containers/merge_maps.h b/lib/utils/include/utils/containers/merge_maps.h index 653c9d24f1..dd886ab8aa 100644 --- a/lib/utils/include/utils/containers/merge_maps.h +++ b/lib/utils/include/utils/containers/merge_maps.h @@ -3,6 +3,8 @@ #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" +#include "utils/exception.h" +#include "utils/fmt/unordered_map.h" #include namespace FlexFlow { @@ -10,7 +12,12 @@ namespace FlexFlow { template std::unordered_map merge_maps(std::unordered_map const &lhs, std::unordered_map const &rhs) { - assert(are_disjoint(keys(lhs), keys(rhs))); + if (!are_disjoint(keys(lhs), keys(rhs))) { + throw mk_runtime_error(fmt::format("Key sets of merge_maps parameters are " + "non-disjoint: lhs = {}, rhs = {}", + lhs, + rhs)); + } std::unordered_map result; for (auto const &kv : lhs) { diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index ec3d5f5612..ef6a26c79a 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -22,9 +22,7 @@ auto transform(req const &c, F const &f) return transform(static_cast(c), f); } -template ()(std::declval()))> +template > std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; for (auto const &e : v) { @@ -33,9 +31,17 @@ std::unordered_set transform(std::unordered_set const &v, F const &f) { return result; } -template ()(std::declval()))> +template > +std::unordered_multiset transform(std::unordered_multiset const &v, + F const &f) { + std::unordered_multiset result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + +template > std::set transform(std::set const &v, F const &f) { std::set result; for (auto const &e : v) { @@ -44,6 +50,15 @@ std::set transform(std::set const &v, F const &f) { return result; } +template > +std::multiset transform(std::multiset const &v, F const &f) { + std::multiset result; + for (auto const &e : v) { + result.insert(f(e)); + } + return result; +} + template std::string transform(std::string const &s, F const &f) { std::string result; diff --git a/lib/utils/include/utils/containers/try_at.h b/lib/utils/include/utils/containers/try_at.h new file mode 100644 index 0000000000..45e50fca27 --- /dev/null +++ b/lib/utils/include/utils/containers/try_at.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRY_AT_H + +#include "utils/containers/contains_key.h" +#include +#include +#include + +namespace FlexFlow { + +template +std::optional try_at(std::unordered_map const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +template +std::optional try_at(std::map const &m, K const &k) { + if (contains_key(m, k)) { + return m.at(k); + } else { + return std::nullopt; + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_map_from_pairs.h b/lib/utils/include/utils/containers/unordered_map_from_pairs.h new file mode 100644 index 0000000000..660c57c5e7 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_map_from_pairs.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_PAIRS_H + +#include + +namespace FlexFlow { + +template +std::unordered_map unordered_map_from_pairs(C const &c) { + return std::unordered_map(c.cbegin(), c.cend()); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/values.h b/lib/utils/include/utils/containers/values.h index 7c487d1d43..2a730ccc42 100644 --- a/lib/utils/include/utils/containers/values.h +++ b/lib/utils/include/utils/containers/values.h @@ -1,15 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUES_H -#include +#include namespace FlexFlow { template -std::vector values(C const &c) { - std::vector result; +std::unordered_multiset values(C const &c) { + std::unordered_multiset result; for (auto const &kv : c) { - result.push_back(kv.second); + result.insert(kv.second); } return result; } diff --git a/lib/utils/include/utils/exception.h b/lib/utils/include/utils/exception.h index 20a8098040..080cbb3611 100644 --- a/lib/utils/include/utils/exception.h +++ b/lib/utils/include/utils/exception.h @@ -34,12 +34,7 @@ T throw_if_unexpected(tl::expected const &r) { } } -template -std::runtime_error mk_runtime_error(fmt::format_string fmt_str, - T &&...args) { - return std::runtime_error( - fmt::vformat(fmt_str, fmt::make_format_args(args...))); -} +std::runtime_error mk_runtime_error(std::string const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/fmt/json.h b/lib/utils/include/utils/fmt/json.h new file mode 100644 index 0000000000..c7aa87e3eb --- /dev/null +++ b/lib/utils/include/utils/fmt/json.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_JSON_H + +#include +#include + +namespace fmt { + +template +struct formatter<::nlohmann::json, Char> : formatter { + template + auto format(::nlohmann::json const &j, FormatContext &ctx) { + std::ostringstream oss; + oss << j; + return formatter::format(oss.str(), ctx); + } +}; + +} // namespace fmt + +#endif diff --git a/lib/utils/include/utils/fmt/monostate.h b/lib/utils/include/utils/fmt/monostate.h new file mode 100644 index 0000000000..884f4d389e --- /dev/null +++ b/lib/utils/include/utils/fmt/monostate.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_MONOSTATE_H + +#include +#include + +namespace fmt { + +template +struct formatter< + ::std::monostate, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { + template + auto format(::std::monostate const &, FormatContext &ctx) + -> decltype(ctx.out()) { + std::string result = ""; + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &, std::monostate const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.h b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h new file mode 100644 index 0000000000..e3ed967a23 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_BINARY_TREE_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path(); +BinaryTreePath nest_inside_left_child(BinaryTreePath const &); +BinaryTreePath nest_inside_right_child(BinaryTreePath const &); + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &); +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml new file mode 100644 index 0000000000..08955c2d75 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "BinaryTreePath" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", + "rapidcheck", +] + +includes = [ + "utils/full_binary_tree/binary_tree_path_entry.dtg.h", + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "entries" +type = "std::vector<::FlexFlow::BinaryTreePathEntry>" diff --git a/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml new file mode 100644 index 0000000000..6c81123dcf --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/binary_tree_path_entry.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "BinaryTreePathEntry" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] + +[[values]] +name = "LEFT_CHILD" +key = "left" + +[[values]] +name = "RIGHT_CHILD" +key = "right" diff --git a/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..9cf5d63210 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/find_paths_to_leaf.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + Leaf const &needle) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform( + find_paths_to_leaf(impl.get_left_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform( + find_paths_to_leaf(impl.get_right_child(parent), impl, needle), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + if (leaf == needle) { + return {binary_tree_root_path()}; + } else { + return {}; + } + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml new file mode 100644 index 0000000000..bf08701840 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_implementation.struct.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "get_left_child" +type = "std::function" + +[[fields]] +name = "get_right_child" +type = "std::function" + +[[fields]] +name = "is_leaf" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" + +[[fields]] +name = "require_parent" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml new file mode 100644 index 0000000000..1f8af17cf3 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_node_type.enum.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeNodeType" +features = [ + "hash", + "fmt", + "json", + "rapidcheck", +] + +[[values]] +name = "PARENT" +key = "parent" + +[[values]] +name = "LEAF" +key = "leaf" diff --git a/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml new file mode 100644 index 0000000000..7418d7a016 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/full_binary_tree_visitor.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "FullBinaryTreeVisitor" +features = [] + +template_params = [ + "Result", + "Tree", + "Parent", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "parent_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..822acfe9ee --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_all_leaf_paths.h @@ -0,0 +1,43 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/visit.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + auto visitor = FullBinaryTreeVisitor, + Tree, + Parent, + Leaf>{ + [&](Parent const &parent) -> std::unordered_set { + return set_union( + transform(get_all_leaf_paths(impl.get_left_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_left_child(path); + }), + transform(get_all_leaf_paths(impl.get_right_child(parent), impl), + [](BinaryTreePath const &path) { + return nest_inside_right_child(path); + })); + }, + [&](Leaf const &leaf) -> std::unordered_set { + return {binary_tree_root_path()}; + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_child.h b/lib/utils/include/utils/full_binary_tree/get_child.h new file mode 100644 index 0000000000..7517028ec0 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_child.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_CHILD_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/binary_tree_path_entry.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include + +namespace FlexFlow { + +template +Tree get_child(Parent const &parent, + FullBinaryTreeImplementation const &impl, + BinaryTreePathEntry const &e) { + switch (e) { + case BinaryTreePathEntry::LEFT_CHILD: + return impl.get_left_child(parent); + case BinaryTreePathEntry::RIGHT_CHILD: + return impl.get_right_child(parent); + default: + throw mk_runtime_error( + fmt::format("Unhandled BinaryTreePathEntry value: {}", e)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_leaves.h b/lib/utils/include/utils/full_binary_tree/get_leaves.h new file mode 100644 index 0000000000..8f9d8e919f --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_leaves.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_LEAVES_H + +#include "utils/containers/multiset_union.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::unordered_multiset + get_leaves(Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = + FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::unordered_multiset { + return multiset_union( + get_leaves(impl.get_left_child(parent), impl), + get_leaves(impl.get_right_child(parent), impl)); + }, + [](Leaf const &leaf) -> std::unordered_multiset { + return {leaf}; + }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h new file mode 100644 index 0000000000..922a42242c --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_num_tree_nodes.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_NUM_TREE_NODES_H + +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template +int get_num_tree_nodes( + Tree const &tree, + FullBinaryTreeImplementation const &impl) { + + auto visitor = FullBinaryTreeVisitor{ + [&](Parent const &parent) -> int { + return 1 + get_num_tree_nodes(impl.get_left_child(parent), impl) + + get_num_tree_nodes(impl.get_right_child(parent), impl); + }, + [](Leaf const &) -> int { return 1; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..83ce1367b9 --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/get_subtree_at_path.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/binary_tree_path.dtg.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/full_binary_tree/get_child.h" +#include "utils/full_binary_tree/visit.h" +#include + +namespace FlexFlow { + +template +std::optional get_subtree_at_path( + Tree const &tree, + FullBinaryTreeImplementation const &impl, + BinaryTreePath const &p) { + if (p == binary_tree_root_path()) { + return tree; + } + + auto visitor = FullBinaryTreeVisitor, Tree, Parent, Leaf>{ + [&](Parent const &parent) -> std::optional { + BinaryTreePathEntry curr = binary_tree_path_get_top_level(p); + BinaryTreePath rest = binary_tree_path_get_non_top_level(p); + + return get_subtree_at_path(get_child(parent, impl, curr), impl, rest); + }, + [](Leaf const &leaf) -> std::optional { return std::nullopt; }, + }; + + return visit(tree, impl, visitor); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/full_binary_tree/visit.h b/lib/utils/include/utils/full_binary_tree/visit.h new file mode 100644 index 0000000000..832d39bdff --- /dev/null +++ b/lib/utils/include/utils/full_binary_tree/visit.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_VISIT_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/full_binary_tree/full_binary_tree_visitor.dtg.h" + +namespace FlexFlow { + +template +Result visit(Tree const &tree, + FullBinaryTreeImplementation const &impl, + FullBinaryTreeVisitor const &visitor) { + if (impl.is_leaf(tree)) { + return visitor.leaf_func(impl.require_leaf(tree)); + } else { + return visitor.parent_func(impl.require_parent(tree)); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h new file mode 100644 index 0000000000..de7ead8fb6 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_GET_DATAFLOW_EDGES_FROM_NODE_TO_NODE_H + +#include "utils/graph/dataflow_graph/dataflow_graph_view.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h new file mode 100644 index 0000000000..be0e57435a --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_BOUNDARY_NODES_FOR_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.dtg.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h new file mode 100644 index 0000000000..e53bb876a1 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_EDGES_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h new file mode 100644 index 0000000000..ad8eadda0e --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_GET_TRANSITIVE_REDUCED_OUTPUTS_ACROSS_SPLIT_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &, BinarySeriesSplit const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml new file mode 100644 index 0000000000..32582a6b74 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/split_boundary_nodes.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "SplitBoundaryNodes" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/node/node.dtg.h", + "", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "pre_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "post_split_boundary" +type = "std::unordered_set<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h new file mode 100644 index 0000000000..916e8f7896 --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DATAFLOW_GRAPH_ALGORITHMS_TRANSITIVE_REDUCED_DATAFLOW_GRAPH_H + +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.dtg.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml new file mode 100644 index 0000000000..54c710b26e --- /dev/null +++ b/lib/utils/include/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "TransitiveReducedDataflowGraphView" +features = [] + +includes = [ + "utils/graph/dataflow_graph/dataflow_graph_view.h", + "utils/graph/digraph/digraph_view.h", +] + +[[fields]] +name = "full_dataflow_graph" +type = "::FlexFlow::DataflowGraphView" + +[[fields]] +name = "transitive_reduction" +type = "::FlexFlow::DiGraphView" + diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h new file mode 100644 index 0000000000..240fc66426 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_EDGES_FROM_SUBGRAPH_TO_SUBGRAPH_H + +#include "utils/graph/digraph/digraph_view.h" +namespace FlexFlow { + +std::unordered_set + get_edges_from_subgraph_to_subgraph(DiGraphView const &, + std::unordered_set const &, + std::unordered_set const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h similarity index 88% rename from lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h rename to lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h index a8e08cb995..b9894fbac3 100644 --- a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_open_dataflow_graph_view.h +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_OPEN_DATAFLOW_GRAPH_VIEW_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_CREATE_LAZY_COPY_OF_LABELLED_DATAFLOW_GRAPH_VIEW_H #include "utils/graph/labelled_dataflow_graph/i_labelled_dataflow_graph.h" #include "utils/graph/labelled_dataflow_graph/labelled_dataflow_graph.h" @@ -8,6 +8,10 @@ namespace FlexFlow { +// NOTE(@lockshaw) This code is not tested and I don't necessarily trust it. +// Figuring out what to do with it is tracked in +// https://github.com/flexflow/FlexFlow/issues/1513 + template struct LazyLabelledDataflowGraph final : public ILabelledDataflowGraph { @@ -42,11 +46,11 @@ struct LazyLabelledDataflowGraph final return this->get_view().query_outputs(q); } - NodeLabel const &at(Node const &n) const override { + NodeLabel at(Node const &n) const override { return this->get_view().at(n); } - ValueLabel const &at(DataflowOutput const &v) const override { + ValueLabel at(DataflowOutput const &v) const override { return this->get_view().at(v); } @@ -95,7 +99,7 @@ template static typename std::enable_if< std::is_base_of, T>::value, LabelledDataflowGraph>::type - make_lazy_copy_of( + create_lazy_copy_of_labelled_dataflow_graph_view( LabelledDataflowGraphView const &view) { std::function( LabelledDataflowGraphView const &)> diff --git a/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h new file mode 100644 index 0000000000..07aa64aa62 --- /dev/null +++ b/lib/utils/include/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_NODE_LABELS_H + +#include "utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_node_labels.h" + +namespace FlexFlow { + +template > +LabelledDataflowGraphView rewrite_node_labels( + LabelledDataflowGraphView const &g, F f) { + return rewrite_node_labels( + view_as_labelled_open_dataflow_graph(g), f); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml new file mode 100644 index 0000000000..37e3bbee09 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinaryParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml new file mode 100644 index 0000000000..7e6e86ba76 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinarySeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct BinarySPDecompositionTree", +] + +post_includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h", +] + +[[fields]] +name = "left_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true + +[[fields]] +name = "right_child" +type = "::FlexFlow::BinarySPDecompositionTree" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h index b1607e7a76..de48cd17e9 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h @@ -1,23 +1,28 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_BINARY_SP_DECOMPOSITION_TREE_H +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" #include namespace FlexFlow { -BinarySPDecompositionTree make_series_split(BinarySPDecompositionTree const &, - BinarySPDecompositionTree const &); -BinarySPDecompositionTree - make_parallel_split(BinarySPDecompositionTree const &, - BinarySPDecompositionTree const &); -BinarySPDecompositionTree make_leaf_node(Node const &); +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree(); bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &); bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &); std::unordered_multiset get_leaves(BinarySPDecompositionTree const &); +SPDecompositionTreeNodeType get_node_type(BinarySPDecompositionTree const &); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml deleted file mode 100644 index 1241311150..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "BinarySPDecompositionTree" -features = [ - "eq", - "ord", - "hash", - "fmt", -] - -includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h", - "utils/graph/node/node.dtg.h", -] - -src_includes = [ - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h", - "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h", -] - -[[fields]] -name = "raw_tree" -type = "::FlexFlow::GenericBinarySPDecompositionTree<::FlexFlow::Node>" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml new file mode 100644 index 0000000000..c586b49d9d --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.variant.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "BinarySPDecompositionTree" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_series_split.dtg.h", + "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_parallel_split.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::BinarySeriesSplit" +key = "series" + +[[values]] +type = "::FlexFlow::BinaryParallelSplit" +key = "parallel" + +[[values]] +type = "::FlexFlow::Node" +key = "node" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h new file mode 100644 index 0000000000..105f5490a4 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FIND_PATHS_TO_LEAF_H + +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_set find_paths_to_leaf( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + Leaf const &needle) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return find_paths_to_leaf(tree, full_binary_impl, needle); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h deleted file mode 100644 index 42d71ce54e..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_FMT_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include - -namespace FlexFlow { - -template -std::string format_as(GenericBinarySeriesSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySeriesSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinaryParallelSplit const &s) { - return fmt::format("", - get_left_child(s), - get_right_child(s)); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinaryParallelSplit const &x) { - return (s << fmt::to_string(x)); -} - -template -std::string format_as(GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return fmt::format("", s); - }, - [](GenericBinaryParallelSplit const &s) { - return fmt::format("", s); - }, - [](T const &t) { - return fmt::format("", t); - }, - }); -} - -template -std::ostream &operator<<(std::ostream &s, - GenericBinarySPDecompositionTree const &t) { - return (s << fmt::to_string(t)); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h deleted file mode 100644 index 74f5ba5d8a..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h +++ /dev/null @@ -1,155 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_H - -#include -#include -#include - -namespace FlexFlow { - -template -struct GenericBinarySPDecompositionTree; - -template -struct GenericBinarySeriesSplit { -public: - GenericBinarySeriesSplit() = delete; - explicit GenericBinarySeriesSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinarySeriesSplit(GenericBinarySeriesSplit const &) = default; - - bool operator==(GenericBinarySeriesSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySeriesSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySeriesSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinaryParallelSplit { -public: - GenericBinaryParallelSplit() = delete; - explicit GenericBinaryParallelSplit( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) - : left_child_ptr( - std::make_shared>(lhs)), - right_child_ptr( - std::make_shared>(rhs)) {} - - GenericBinaryParallelSplit(GenericBinaryParallelSplit const &) = default; - - bool operator==(GenericBinaryParallelSplit const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinaryParallelSplit const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinaryParallelSplit const &other) const { - return this->tie() < other.tie(); - } - -public: - std::shared_ptr> left_child_ptr; - std::shared_ptr> right_child_ptr; - -private: - std::tuple const &, - GenericBinarySPDecompositionTree const &> - tie() const { - return std::tie(*this->left_child_ptr, *this->right_child_ptr); - } - - friend std::hash; -}; - -template -struct GenericBinarySPDecompositionTree { -public: - GenericBinarySPDecompositionTree() = delete; - explicit GenericBinarySPDecompositionTree( - GenericBinarySeriesSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree( - GenericBinaryParallelSplit const &s) - : root{s} {} - - explicit GenericBinarySPDecompositionTree(T const &t) : root{t} {} - - GenericBinarySPDecompositionTree(GenericBinarySPDecompositionTree const &) = - default; - - bool operator==(GenericBinarySPDecompositionTree const &other) const { - return this->tie() == other.tie(); - } - - bool operator!=(GenericBinarySPDecompositionTree const &other) const { - return this->tie() != other.tie(); - } - - bool operator<(GenericBinarySPDecompositionTree const &other) const { - return this->tie() < other.tie(); - } - -public: - std::variant, GenericBinaryParallelSplit, T> - root; - -private: - std::tuple tie() const { - return std::tie(this->root); - } - - friend std::hash; -}; - -} // namespace FlexFlow - -// namespace rc { -// -// template <> -// struct Arbitrary<::FlexFlow::BinarySeriesSplit> { -// static Gen<::FlexFlow::BinarySeriesSplit> arbitrary(); -// }; -// -// template <> -// struct Arbitrary<::FlexFlow::GenericBinaryParallelSplit> { -// static Gen<::FlexFlow::GenericBinaryParallelSplit> arbitrary(); -// }; -// -// template <> -// struct Arbitrary<::FlexFlow::GenericBinarySPDecompositionTree> { -// static Gen<::FlexFlow::GenericBinarySPDecompositionTree> arbitrary(); -// }; -// -// } // namespace rc - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h new file mode 100644 index 0000000000..0bddbee81c --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IMPLEMENTATION_H + +#include "utils/exception.h" +#include "utils/full_binary_tree/full_binary_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/overload.h" +#include + +namespace FlexFlow { + +template +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &impl) { + + using Parent = std::variant; + + auto full_binary_impl = FullBinaryTreeImplementation{ + /*get_left_child=*/[impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_left_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_left_child(parallel); + }, + }, + parent); + }, + /*get_right_child=*/ + [impl](Parent const &parent) -> Tree const & { + return std::visit(overload{ + [&](Series const &series) -> Tree const & { + return impl.series_get_right_child(series); + }, + [&](Parallel const ¶llel) -> Tree const & { + return impl.parallel_get_right_child(parallel); + }, + }, + parent); + }, + /*is_leaf=*/ + [impl](Tree const &tree) -> bool { + return impl.get_node_type(tree) == SPDecompositionTreeNodeType::NODE; + }, + /*require_leaf=*/ + [impl](Tree const &tree) -> Leaf const & { + return impl.require_leaf(tree); + }, + /*require_parent=*/ + [impl](Tree const &tree) -> Parent { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: + return Parent{impl.require_series(tree)}; + case SPDecompositionTreeNodeType::PARALLEL: + return Parent{impl.require_parallel(tree)}; + default: + throw mk_runtime_error(fmt::format( + "Unexpected SPDecompositionTreeNodeType: {}", node_type)); + } + }}; + + return full_binary_impl; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml new file mode 100644 index 0000000000..3ccbfd959b --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.struct.toml @@ -0,0 +1,47 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeImplementation" +features = [] + +template_params = [ + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", + "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h", +] + +[[fields]] +name = "series_get_left_child" +type = "std::function" + +[[fields]] +name = "parallel_get_left_child" +type = "std::function" + +[[fields]] +name = "series_get_right_child" +type = "std::function" + +[[fields]] +name = "parallel_get_right_child" +type = "std::function" + +[[fields]] +name = "get_node_type" +type = "std::function<::FlexFlow::SPDecompositionTreeNodeType(Tree const &)>" + +[[fields]] +name = "require_series" +type = "std::function" + +[[fields]] +name = "require_parallel" +type = "std::function" + +[[fields]] +name = "require_leaf" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml new file mode 100644 index 0000000000..6275c82a0c --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.struct.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "GenericBinarySPDecompositionTreeVisitor" +features = [] + +template_params = [ + "ReturnType", + "Tree", + "Series", + "Parallel", + "Leaf", +] + +includes = [ + "", +] + +[[fields]] +name = "series_func" +type = "std::function" + +[[fields]] +name = "parallel_func" +type = "std::function" + +[[fields]] +name = "leaf_func" +type = "std::function" diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h deleted file mode 100644 index c6c1186d3d..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace FlexFlow { - -template -TT const &get(GenericBinarySPDecompositionTree const &t) { - return std::get(t.root); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h new file mode 100644 index 0000000000..b0bb8355db --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_ALL_LEAF_PATHS_H + +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" + +namespace FlexFlow { + +template +std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_all_leaf_paths(tree, full_binary_impl); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h index 51e1e20bac..c543375148 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h @@ -1,38 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEAVES_H -#include "utils/containers/multiset_union.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" -#include +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template -std::unordered_multiset - get_leaves(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](T const &t) { return std::unordered_multiset{t}; }, - [](GenericBinarySeriesSplit const &s) { return get_leaves(s); }, - [](GenericBinaryParallelSplit const &p) { return get_leaves(p); }, - }); -} +template +std::unordered_multiset get_leaves( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { -template -std::unordered_multiset get_leaves(GenericBinarySeriesSplit const &s) { - return multiset_union(get_leaves(get_left_child(s)), - get_leaves(get_right_child(s))); -} + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); -template -std::unordered_multiset get_leaves(GenericBinaryParallelSplit const &p) { - return multiset_union(get_leaves(get_left_child(p)), - get_leaves(get_right_child(p))); + return get_leaves(tree, full_binary_impl); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h deleted file mode 100644 index 46a460b64e..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_LEFT_CHILD_H - -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySeriesSplit const &s) { - return *s.left_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinaryParallelSplit const &p) { - return *p.left_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_left_child(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_left_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_left_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_left_child incorrectly called on leaf node"); - }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h deleted file mode 100644 index 883acda480..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_NODE_TYPE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/graph/series_parallel/sp_decomposition_tree_node_type.dtg.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -SPDecompositionTreeNodeType - get_node_type(GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](GenericBinarySeriesSplit const &) { - return SPDecompositionTreeNodeType::SERIES; - }, - [](GenericBinaryParallelSplit const &) { - return SPDecompositionTreeNodeType::PARALLEL; - }, - [](T const &) { return SPDecompositionTreeNodeType::NODE; }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h index 7c6d28d7b4..4678e0c0f7 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h @@ -1,38 +1,23 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GET_NUM_TREE_NODES_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" namespace FlexFlow { -template -int get_num_tree_nodes(GenericBinarySPDecompositionTree const &tt) { - return visit(tt, - overload{ - [](T const &t) { return 1; }, - [](GenericBinarySeriesSplit const &s) { - return get_num_tree_nodes(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_num_tree_nodes(p); - }, - }); -} +template +int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { -template -int get_num_tree_nodes(GenericBinarySeriesSplit const &s) { - return 1 + get_num_tree_nodes(get_left_child(s)) + - get_num_tree_nodes(get_right_child(s)); -} + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); -template -int get_num_tree_nodes(GenericBinaryParallelSplit const &p) { - return 1 + get_num_tree_nodes(get_left_child(p)) + - get_num_tree_nodes(get_right_child(p)); + return get_num_tree_nodes(tree, full_binary_impl); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h deleted file mode 100644 index f0bfba43a2..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_RIGHT_CHILD_H - -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include "utils/overload.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySeriesSplit const &s) { - return *s.right_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinaryParallelSplit const &p) { - return *p.right_child_ptr; -} - -template -GenericBinarySPDecompositionTree - get_right_child(GenericBinarySPDecompositionTree const &tt) { - return visit>( - tt, - overload{ - [](GenericBinarySeriesSplit const &s) { - return get_right_child(s); - }, - [](GenericBinaryParallelSplit const &p) { - return get_right_child(p); - }, - [](T const &t) -> GenericBinarySPDecompositionTree { - throw mk_runtime_error( - "get_right_child incorrectly called on leaf node"); - }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h new file mode 100644 index 0000000000..c48185fb7f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_GET_SUBTREE_AT_PATH_H + +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include + +namespace FlexFlow { + +template +std::optional get_subtree_at_path( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + BinaryTreePath const &path) { + FullBinaryTreeImplementation, Leaf> + full_binary_impl = get_full_binary_impl_from_generic_sp_impl(impl); + + return get_subtree_at_path(tree, full_binary_impl, path); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h deleted file mode 100644 index 983dc4a572..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_HASH_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/hash-utils.h" -#include "utils/hash/tuple.h" - -namespace std { - -template -struct hash<::FlexFlow::GenericBinarySeriesSplit> { - size_t operator()(::FlexFlow::GenericBinarySeriesSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinaryParallelSplit> { - size_t operator()(::FlexFlow::GenericBinaryParallelSplit const &s) const { - return get_std_hash(s.tie()); - } -}; - -template -struct hash<::FlexFlow::GenericBinarySPDecompositionTree> { - size_t operator()( - ::FlexFlow::GenericBinarySPDecompositionTree const &s) const { - return get_std_hash(s.tie()); - } -}; - -} // namespace std - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h deleted file mode 100644 index 8086f38244..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace FlexFlow { - -template -bool is_series_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); -} - -template -bool is_parallel_split(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative>(t.root); -} - -template -bool is_leaf(GenericBinarySPDecompositionTree const &t) { - return std::holds_alternative(t.root); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h index 3ffa63753a..68e0a3af32 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h @@ -1,32 +1,44 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_LEFT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" namespace FlexFlow { -template +template bool is_binary_sp_tree_left_associative( - GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](T const &) { return true; }, - [](GenericBinarySeriesSplit const &s) { - return !is_series_split(get_right_child(s)) && - is_binary_sp_tree_left_associative(get_left_child(s)) && - is_binary_sp_tree_left_associative(get_right_child(s)); - }, - [](GenericBinaryParallelSplit const &p) { - return !is_parallel_split(get_right_child(p)) && - is_binary_sp_tree_left_associative(get_left_child(p)) && - is_binary_sp_tree_left_associative(get_right_child(p)); - }, - }); + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_right_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_left_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_right_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_left_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_left_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h index d88459b432..7042765203 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h @@ -1,32 +1,43 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_IS_BINARY_SP_TREE_RIGHT_ASSOCIATIVE_H -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" namespace FlexFlow { -template +template bool is_binary_sp_tree_right_associative( - GenericBinarySPDecompositionTree const &tt) { - return visit( - tt, - overload{ - [](T const &t) { return true; }, - [](GenericBinarySeriesSplit const &s) { - return !is_series_split(get_left_child(s)) && - is_binary_sp_tree_right_associative(get_left_child(s)) && - is_binary_sp_tree_right_associative(get_right_child(s)); - }, - [](GenericBinaryParallelSplit const &p) { - return !is_parallel_split(get_left_child(p)) && - is_binary_sp_tree_right_associative(get_left_child(p)) && - is_binary_sp_tree_right_associative(get_right_child(p)); - }, - }); + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl) { + auto visitor = GenericBinarySPDecompositionTreeVisitor{ + [&](Series const &split) { + return impl.get_node_type(impl.series_get_left_child(split)) != + SPDecompositionTreeNodeType::SERIES && + is_binary_sp_tree_right_associative( + impl.series_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.series_get_right_child(split), impl); + }, + [&](Parallel const &split) { + return impl.get_node_type(impl.parallel_get_left_child(split)) != + SPDecompositionTreeNodeType::PARALLEL && + is_binary_sp_tree_right_associative( + impl.parallel_get_left_child(split), impl) && + is_binary_sp_tree_right_associative( + impl.parallel_get_right_child(split), impl); + }, + [&](Leaf const &leaf) { return true; }, + }; + + return visit(tree, impl, visitor); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h deleted file mode 100644 index 4f1f8266e1..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h +++ /dev/null @@ -1,103 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_JSON_H - -#include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" -#include - -namespace nlohmann { - -template -struct adl_serializer<::FlexFlow::GenericBinarySeriesSplit> { - static ::FlexFlow::GenericBinarySeriesSplit from_json(json const &j) { - return ::FlexFlow::GenericBinarySeriesSplit{ - j.at("left_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - j.at("right_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - }; - } - - static void to_json(json &j, - ::FlexFlow::GenericBinarySeriesSplit const &v) { - j["__type"] = "GenericBinarySeriesSplit"; - j["left_child"] = get_left_child(v); - j["right_child"] = get_right_child(v); - } -}; - -template -struct adl_serializer<::FlexFlow::GenericBinaryParallelSplit> { - static ::FlexFlow::GenericBinaryParallelSplit from_json(json const &j) { - return ::FlexFlow::GenericBinaryParallelSplit{ - j.at("left_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - j.at("right_child") - .template get<::FlexFlow::GenericBinarySPDecompositionTree>(), - }; - } - - static void to_json(json &j, - ::FlexFlow::GenericBinaryParallelSplit const &v) { - j["__type"] = "GenericBinaryParallelSplit"; - j["left_child"] = get_left_child(v); - j["right_child"] = get_right_child(v); - } -}; - -template -struct adl_serializer<::FlexFlow::GenericBinarySPDecompositionTree> { - static ::FlexFlow::GenericBinarySPDecompositionTree - from_json(json const &j) { - std::string key = j.at("type").get(); - - if (key == "series") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get<::FlexFlow::GenericBinarySeriesSplit>(), - }; - } else if (key == "parallel") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get<::FlexFlow::GenericBinaryParallelSplit>(), - }; - } else if (key == "leaf") { - return ::FlexFlow::GenericBinarySPDecompositionTree{ - j.at("value").get(), - }; - } else { - throw ::FlexFlow::mk_runtime_error( - fmt::format("Unknown json type key: {}", key)); - } - } - - static void - to_json(json &j, - ::FlexFlow::GenericBinarySPDecompositionTree const &v) { - j["__type"] = "GenericBinarySPDecompositionTree"; - ::FlexFlow::visit( - v, - ::FlexFlow::overload{ - [&](::FlexFlow::GenericBinarySeriesSplit const &s) { - j["type"] = "series"; - j["value"] = s; - return std::monostate{}; - }, - [&](::FlexFlow::GenericBinaryParallelSplit const &p) { - j["type"] = "parallel"; - j["value"] = p; - return std::monostate{}; - }, - [&](T const &t) { - j["type"] = "leaf"; - j["value"] = t; - return std::monostate{}; - }, - }); - } -}; - -} // namespace nlohmann - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h deleted file mode 100644 index f55b71146a..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_MAKE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" - -namespace FlexFlow { - -template -GenericBinarySPDecompositionTree make_generic_binary_series_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{ - lhs, - rhs, - }, - }; -} - -template -GenericBinarySPDecompositionTree make_generic_binary_parallel_split( - GenericBinarySPDecompositionTree const &lhs, - GenericBinarySPDecompositionTree const &rhs) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{ - lhs, - rhs, - }, - }; -} - -template -GenericBinarySPDecompositionTree make_generic_binary_sp_leaf(T const &t) { - return GenericBinarySPDecompositionTree{t}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h deleted file mode 100644 index a8de1ee8f8..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_REQUIRE_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" - -namespace FlexFlow { - -template -GenericBinarySeriesSplit const & - require_series(GenericBinarySPDecompositionTree const &t) { - return get>(t); -} - -template -GenericBinaryParallelSplit const & - require_parallel(GenericBinarySPDecompositionTree const &t) { - return get>(t); -} - -template -T const &require_node(GenericBinarySPDecompositionTree const &t) { - return get(t); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h deleted file mode 100644 index 4d7fa05960..0000000000 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_TRANSFORM_H - -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" - -namespace FlexFlow { - -template > -GenericBinarySPDecompositionTree - transform(GenericBinarySPDecompositionTree const &tt, F f) { - return visit>( - tt, - overload{ - [&](GenericBinarySeriesSplit const &s) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{ - transform(get_left_child(s), f), - transform(get_right_child(s), f), - }, - }; - }, - [&](GenericBinaryParallelSplit const &s) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{ - transform(get_left_child(s), f), - transform(get_right_child(s), f), - }, - }; - }, - [&](T const &t) { - return GenericBinarySPDecompositionTree{ - f(t), - }; - }, - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h index 0d9503e59f..c06db135b2 100644 --- a/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h +++ b/lib/utils/include/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h @@ -2,34 +2,45 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_BINARY_SP_DECOMPOSITION_TREE_GENERIC_BINARY_SP_DECOMPOSITION_TREE_VISIT_H #include "utils/exception.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_visitor.dtg.h" namespace FlexFlow { -template -Result visit(GenericBinarySPDecompositionTree const &tt, F f) { - if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative>(tt.root)) { - return f(std::get>(tt.root)); - } else if (std::holds_alternative(tt.root)) { - return f(std::get(tt.root)); - } else { - throw mk_runtime_error( - "Unexpected case in visit(GenericBinarySPDecompositionTree)"); +template +ReturnType + visit(Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &impl, + GenericBinarySPDecompositionTreeVisitor const &visitor) { + SPDecompositionTreeNodeType node_type = impl.get_node_type(tree); + switch (node_type) { + case SPDecompositionTreeNodeType::SERIES: { + ReturnType result = visitor.series_func(impl.require_series(tree)); + return result; + } + case SPDecompositionTreeNodeType::PARALLEL: { + ReturnType result = visitor.parallel_func(impl.require_parallel(tree)); + return result; + } + case SPDecompositionTreeNodeType::NODE: { + ReturnType result = visitor.leaf_func(impl.require_leaf(tree)); + return result; + } + default: + throw mk_runtime_error(fmt::format( + "Unknown SPDecompositionTreeNodeType value: {}", node_type)); } - - // return std::visit(tt.root, overload { - // [&](GenericBinarySeriesSplit const &s) -> Result { - // return f(s); - // }, - // [&](GenericBinaryParallelSplit const &p) -> Result { - // return f(p); - // }, - // [&](T const &t) -> Result { - // return f(t); - // }, - // }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml new file mode 100644 index 0000000000..dd68adf3f6 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/parallel_split.struct.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "ParallelSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct SeriesSplit" +] + +post_includes = [ + "utils/graph/series_parallel/series_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "children" +type = "std::unordered_multiset>" +indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h index 18434d2b67..7374b45a60 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_splits.h @@ -1,80 +1,76 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_FLATTENED_DECOMPOSITION_TREE_H -#include "utils/graph/node/node.dtg.h" -#include -#include +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" namespace FlexFlow { -struct SeriesSplit; -struct ParallelSplit; - -struct SeriesSplit { -public: - SeriesSplit() = delete; - explicit SeriesSplit(std::vector> const &); - explicit SeriesSplit( - std::initializer_list> const &); - - bool operator==(SeriesSplit const &) const; - bool operator!=(SeriesSplit const &) const; - -public: - std::vector> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(SeriesSplit const &); -std::ostream &operator<<(std::ostream &, SeriesSplit const &); - -} // namespace FlexFlow - -namespace std { - -template <> -struct hash<::FlexFlow::SeriesSplit> { - size_t operator()(::FlexFlow::SeriesSplit const &) const; -}; - -} // namespace std - -namespace FlexFlow { - -struct ParallelSplit { -public: - ParallelSplit() = delete; - explicit ParallelSplit( - std::unordered_multiset> const &); - explicit ParallelSplit( - std::initializer_list> const &); - - bool operator==(ParallelSplit const &) const; - bool operator!=(ParallelSplit const &) const; - -public: - std::unordered_multiset> children; - -private: - using Tie = std::tuple; - Tie tie() const; -}; - -std::string format_as(ParallelSplit const &); -std::ostream &operator<<(std::ostream &, ParallelSplit const &); +// struct SeriesSplit { +// public: +// SeriesSplit() = delete; +// explicit SeriesSplit(std::vector> const +// &); explicit SeriesSplit( +// std::initializer_list> const &); +// +// bool operator==(SeriesSplit const &) const; +// bool operator!=(SeriesSplit const &) const; +// +// public: +// std::vector> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(SeriesSplit const &); +// std::ostream &operator<<(std::ostream &, SeriesSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::SeriesSplit> { +// size_t operator()(::FlexFlow::SeriesSplit const &) const; +// }; +// +// } // namespace std +// +// namespace FlexFlow { +// +// struct ParallelSplit { +// public: +// ParallelSplit() = delete; +// explicit ParallelSplit( +// std::unordered_multiset> const &); +// explicit ParallelSplit( +// std::initializer_list> const &); +// +// bool operator==(ParallelSplit const &) const; +// bool operator!=(ParallelSplit const &) const; +// +// public: +// std::unordered_multiset> children; +// +// private: +// using Tie = std::tuple; +// Tie tie() const; +// }; +// +// std::string format_as(ParallelSplit const &); +// std::ostream &operator<<(std::ostream &, ParallelSplit const &); +// +// } // namespace FlexFlow +// +// namespace std { +// +// template <> +// struct hash<::FlexFlow::ParallelSplit> { +// size_t operator()(::FlexFlow::ParallelSplit const &) const; +// }; } // namespace FlexFlow -namespace std { - -template <> -struct hash<::FlexFlow::ParallelSplit> { - size_t operator()(::FlexFlow::ParallelSplit const &) const; -}; - -} // namespace std - #endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml new file mode 100644 index 0000000000..fdb0a29972 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_split.struct.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "SeriesSplit" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct ParallelSplit" +] + +post_includes = [ + "utils/graph/series_parallel/parallel_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "children" +type = "std::vector>" diff --git a/lib/utils/include/utils/json/check_is_json_deserializable.h b/lib/utils/include/utils/json/check_is_json_deserializable.h new file mode 100644 index 0000000000..dd5f397c19 --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_deserializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_DESERIALIZABLE_H + +#include "utils/json/is_json_deserializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_DESERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_deserializable::value, \ + #TYPENAME " should be json deserializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/json/check_is_json_serializable.h b/lib/utils/include/utils/json/check_is_json_serializable.h new file mode 100644 index 0000000000..dfcb26081d --- /dev/null +++ b/lib/utils/include/utils/json/check_is_json_serializable.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_JSON_CHECK_IS_JSON_SERIALIZABLE_H + +#include "utils/json/is_json_serializable.h" + +namespace FlexFlow { + +#define CHECK_IS_JSON_SERIALIZABLE(TYPENAME) \ + static_assert(::FlexFlow::is_json_serializable::value, \ + #TYPENAME " should be json serializeable") + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/sequence.h b/lib/utils/include/utils/sequence.h index 6c66949fd8..07e4554299 100644 --- a/lib/utils/include/utils/sequence.h +++ b/lib/utils/include/utils/sequence.h @@ -135,7 +135,7 @@ auto seq_get(F const &f, int i, seq const &s) template auto seq_get(F const &f, int i, seq<> const &) -> decltype(f(std::declval>())) { - throw mk_runtime_error("Failed seq_get for index {}", i); + throw mk_runtime_error(fmt::format("Failed seq_get for index {}", i)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/tuple.h b/lib/utils/include/utils/tuple.h index afc16d4c4b..0296e365a3 100644 --- a/lib/utils/include/utils/tuple.h +++ b/lib/utils/include/utils/tuple.h @@ -67,8 +67,8 @@ template std::any get(std::tuple const &t, int idx) { size_t tuple_size = std::tuple_size::value; if (idx < 0 || idx >= tuple_size) { - throw mk_runtime_error( - "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size); + throw mk_runtime_error(fmt::format( + "Error: idx {} out of bounds for tuple of size {}", idx, tuple_size)); } std::any result; visit_tuple(t, tuple_get_visitor{idx, result}); diff --git a/lib/utils/src/utils/any_value_type/any_value_type.cc b/lib/utils/src/utils/any_value_type/any_value_type.cc new file mode 100644 index 0000000000..d4c605c441 --- /dev/null +++ b/lib/utils/src/utils/any_value_type/any_value_type.cc @@ -0,0 +1,34 @@ +#include "utils/any_value_type/any_value_type.h" + +namespace FlexFlow { + +any_value_type::any_value_type( + std::any const &value, + std::function const &eq, + std::function const &neq, + std::function const &hash, + std::function const &to_string) + : value(value), eq(eq), neq(neq), hash(hash), to_string(to_string) {} + +bool any_value_type::operator==(any_value_type const &other) const { + return this->eq(this->value, other.value); +} + +bool any_value_type::operator!=(any_value_type const &other) const { + return this->neq(this->value, other.value); +} + +std::string format_as(any_value_type const &v) { + return v.to_string(v.value); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash<::FlexFlow::any_value_type>::operator()( + ::FlexFlow::any_value_type const &v) const { + return v.hash(v); +} + +} // namespace std diff --git a/lib/utils/src/utils/archetypes/value_type.cc b/lib/utils/src/utils/archetypes/value_type.cc new file mode 100644 index 0000000000..f7da47d8f9 --- /dev/null +++ b/lib/utils/src/utils/archetypes/value_type.cc @@ -0,0 +1,7 @@ +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template struct value_type<0>; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/cartesian_product.cc b/lib/utils/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..b716a49ad5 --- /dev/null +++ b/lib/utils/src/utils/containers/cartesian_product.cc @@ -0,0 +1 @@ +#include "utils/containers/cartesian_product.h" diff --git a/lib/utils/src/utils/containers/get_all_assignments.cc b/lib/utils/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..3a7cf6377a --- /dev/null +++ b/lib/utils/src/utils/containers/get_all_assignments.cc @@ -0,0 +1 @@ +#include "utils/containers/get_all_assignments.h" diff --git a/lib/utils/src/utils/containers/try_at.cc b/lib/utils/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..0d1ed3b04a --- /dev/null +++ b/lib/utils/src/utils/containers/try_at.cc @@ -0,0 +1 @@ +#include "utils/containers/try_at.h" diff --git a/lib/utils/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..60cc978be7 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1 @@ +#include "utils/containers/unordered_map_from_pairs.h" diff --git a/lib/utils/src/utils/exception.cc b/lib/utils/src/utils/exception.cc index 9bbf780fd8..c645f241aa 100644 --- a/lib/utils/src/utils/exception.cc +++ b/lib/utils/src/utils/exception.cc @@ -1 +1,9 @@ #include "utils/exception.h" + +namespace FlexFlow { + +std::runtime_error mk_runtime_error(std::string const &s) { + return std::runtime_error(s); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/fmt/json.cc b/lib/utils/src/utils/fmt/json.cc new file mode 100644 index 0000000000..49ad57fba7 --- /dev/null +++ b/lib/utils/src/utils/fmt/json.cc @@ -0,0 +1,7 @@ +#include "utils/fmt/json.h" + +namespace fmt { + +template struct formatter<::nlohmann::json, char>; + +} diff --git a/lib/utils/src/utils/fmt/monostate.cc b/lib/utils/src/utils/fmt/monostate.cc new file mode 100644 index 0000000000..55988cdce0 --- /dev/null +++ b/lib/utils/src/utils/fmt/monostate.cc @@ -0,0 +1,9 @@ +#include "utils/fmt/monostate.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, std::monostate const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc new file mode 100644 index 0000000000..8445a2721a --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/binary_tree_path.cc @@ -0,0 +1,34 @@ +#include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/containers/subvec.h" + +namespace FlexFlow { + +BinaryTreePath binary_tree_root_path() { + return BinaryTreePath{{}}; +} + +BinaryTreePath nest_inside_left_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::LEFT_CHILD); + return result; +} + +BinaryTreePath nest_inside_right_child(BinaryTreePath const &p) { + BinaryTreePath result = p; + result.entries.insert(result.entries.begin(), + BinaryTreePathEntry::RIGHT_CHILD); + return result; +} + +BinaryTreePathEntry binary_tree_path_get_top_level(BinaryTreePath const &p) { + return p.entries.at(0); +} + +BinaryTreePath binary_tree_path_get_non_top_level(BinaryTreePath const &p) { + return BinaryTreePath{ + subvec(p.entries, 1, std::nullopt), + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..47845720ed --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/find_paths_to_leaf.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::unordered_set + find_paths_to_leaf(Tree const &, + FullBinaryTreeImplementation const &, + Leaf const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..b4d8aa1011 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_all_leaf_paths.cc @@ -0,0 +1,12 @@ +#include "utils/full_binary_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +template std::unordered_set + get_all_leaf_paths(value_type<0> const &, + FullBinaryTreeImplementation, + value_type<1>, + value_type<2>> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_child.cc b/lib/utils/src/utils/full_binary_tree/get_child.cc new file mode 100644 index 0000000000..19362ae510 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_child.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_child.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template Tree + get_child(Parent const &, + FullBinaryTreeImplementation const &, + BinaryTreePathEntry const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_leaves.cc b/lib/utils/src/utils/full_binary_tree/get_leaves.cc new file mode 100644 index 0000000000..0d7e9106f6 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_leaves.cc @@ -0,0 +1,14 @@ +#include "utils/full_binary_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::unordered_multiset + get_leaves(Tree const &, + FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc new file mode 100644 index 0000000000..7a99dd60fa --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_num_tree_nodes.cc @@ -0,0 +1,13 @@ +#include "utils/full_binary_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template int get_num_tree_nodes( + Tree const &, FullBinaryTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..1eea13fedd --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/get_subtree_at_path.cc @@ -0,0 +1,15 @@ +#include "utils/full_binary_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Parent = value_type<1>; +using Leaf = value_type<2>; + +template std::optional get_subtree_at_path( + Tree const &, + FullBinaryTreeImplementation const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/full_binary_tree/visit.cc b/lib/utils/src/utils/full_binary_tree/visit.cc new file mode 100644 index 0000000000..4a4f7c9302 --- /dev/null +++ b/lib/utils/src/utils/full_binary_tree/visit.cc @@ -0,0 +1,9 @@ +#include "utils/full_binary_tree/visit.h" + +namespace FlexFlow { + +template int visit(std::string const &, + FullBinaryTreeImplementation const &, + FullBinaryTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 0000000000..c07d344d05 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,15 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" + +namespace FlexFlow { + +std::unordered_set get_dataflow_edges_from_node_to_node( + DataflowGraphView const &g, Node const &src, Node const &dst) { + return g.query_edges(DataflowEdgeQuery{ + /*src_nodes=*/query_set{src}, + /*src_idxs=*/query_set::matchall(), + /*dst_nodes=*/query_set{dst}, + /*dst_idxs=*/query_set::matchall(), + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..70a66c9a21 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,24 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +SplitBoundaryNodes get_transitive_reduced_boundary_nodes_for_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set edges = + get_transitive_reduced_edges_across_split(tr_g, split); + + std::unordered_set src_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.src.node; }); + + std::unordered_set dst_boundary_nodes = + transform(edges, [](DataflowEdge const &e) { return e.dst.node; }); + + return SplitBoundaryNodes{ + /*pre_split_boundary=*/src_boundary_nodes, + /*post_split_boundary=*/dst_boundary_nodes, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 0000000000..8a4adf0b3a --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,27 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/containers/flatmap.h" +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_edges_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + std::unordered_set src_subgraph = + unordered_set_of(get_leaves(split.get_left_child())); + std::unordered_set dst_subgraph = + unordered_set_of(get_leaves(split.get_right_child())); + + std::unordered_set raw_edges = + get_edges_from_subgraph_to_subgraph( + tr_g.transitive_reduction, src_subgraph, dst_subgraph); + + return flatmap(raw_edges, [&](DirectedEdge const &e) { + return get_dataflow_edges_from_node_to_node( + tr_g.full_dataflow_graph, e.src, e.dst); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..0bb94c87f4 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,14 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/containers/transform.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" + +namespace FlexFlow { + +std::unordered_set get_transitive_reduced_outputs_across_split( + TransitiveReducedDataflowGraphView const &tr_g, + BinarySeriesSplit const &split) { + return transform(get_transitive_reduced_edges_across_split(tr_g, split), + [](DataflowEdge const &e) { return e.src; }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc new file mode 100644 index 0000000000..81751702a2 --- /dev/null +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.cc @@ -0,0 +1,17 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" + +namespace FlexFlow { + +TransitiveReducedDataflowGraphView + get_dataflow_graph_transitive_reduction(DataflowGraphView const &g) { + DiGraphView as_digraph = g; + DiGraphView transitive_reduced = transitive_reduction(as_digraph); + + return TransitiveReducedDataflowGraphView{ + /*full_dataflow_graph=*/g, + /*transitive_reduction=*/transitive_reduced, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..2c6606a06b --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,25 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/containers/are_disjoint.h" + +namespace FlexFlow { + +std::unordered_set get_edges_from_subgraph_to_subgraph( + DiGraphView const &g, + std::unordered_set const &src_subgraph, + std::unordered_set const &dst_subgraph) { + if (!are_disjoint(src_subgraph, dst_subgraph)) { + throw mk_runtime_error( + fmt::format("get_edges_from_subgraph_to_subgraph(DiGraphView, ...) " + "expected src_subgraph and dst_subgraph to be disjoint, " + "but found src_subgraph={}, dst_subgraph={}", + src_subgraph, + dst_subgraph)); + } + + return g.query_edges(DirectedEdgeQuery{ + /*srcs=*/query_set{src_subgraph}, + /*dsts=*/query_set{dst_subgraph}, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc index 4f4c846433..941c8e8e3e 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc @@ -6,6 +6,7 @@ #include "utils/containers/values.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc index 92a6d0b9eb..61c4f80763 100644 --- a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc +++ b/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.cc @@ -23,12 +23,12 @@ void HashmapUndirectedGraph::remove_node_unsafe(Node const &n) { void HashmapUndirectedGraph::add_edge(UndirectedEdge const &e) { if (!contains_key(this->adjacency, e.bigger)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.bigger); + throw mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.bigger)); } if (!contains_key(this->adjacency, e.smaller)) { - throw mk_runtime_error( - "Could not add edge connected to non-existent node {}", e.smaller); + throw mk_runtime_error(fmt::format( + "Could not add edge connected to non-existent node {}", e.smaller)); } this->adjacency.at(e.bigger).insert(e.smaller); diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc new file mode 100644 index 0000000000..28d63f9ee1 --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/create_lazy_copy_of_labelled_dataflow_graph_view.h" diff --git a/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc new file mode 100644 index 0000000000..dc5ce4fbda --- /dev/null +++ b/lib/utils/src/utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.cc @@ -0,0 +1 @@ +#include "utils/graph/labelled_dataflow_graph/algorithms/rewrite_node_labels.h" diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc index 4ade34941c..08dda09698 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_subgraph_inputs.cc @@ -5,6 +5,7 @@ #include "utils/containers/values.h" #include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" +#include "utils/hash/vector.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 18d1f922c6..62489ff75f 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -2,42 +2,84 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" namespace FlexFlow { -BinarySPDecompositionTree - make_series_split(BinarySPDecompositionTree const &lhs, - BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{ - make_generic_binary_series_split(lhs.raw_tree, rhs.raw_tree), - }; -} +GenericBinarySPDecompositionTreeImplementation + generic_impl_for_binary_sp_tree() { -BinarySPDecompositionTree - make_parallel_split(BinarySPDecompositionTree const &lhs, - BinarySPDecompositionTree const &rhs) { - return BinarySPDecompositionTree{ - make_generic_binary_parallel_split(lhs.raw_tree, rhs.raw_tree), + return GenericBinarySPDecompositionTreeImplementation< + BinarySPDecompositionTree, + BinarySeriesSplit, + BinaryParallelSplit, + Node>{ + /*series_get_left_child=*/[](BinarySeriesSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*parallel_get_left_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_left_child(); + }, + /*series_get_right_child=*/ + [](BinarySeriesSplit const &split) -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*parallel_get_right_child=*/ + [](BinaryParallelSplit const &split) + -> BinarySPDecompositionTree const & { + return split.get_right_child(); + }, + /*get_node_type=*/ + [](BinarySPDecompositionTree const &tree) -> SPDecompositionTreeNodeType { + return get_node_type(tree); + }, + /*require_series=*/ + [](BinarySPDecompositionTree const &tree) -> BinarySeriesSplit const & { + return tree.require_series(); + }, + /*require_parallel=*/ + [](BinarySPDecompositionTree const &tree) -> BinaryParallelSplit const & { + return tree.require_parallel(); + }, + /*require_leaf=*/ + [](BinarySPDecompositionTree const &tree) -> Node const & { + return tree.require_node(); + }, }; } -BinarySPDecompositionTree make_leaf_node(Node const &n) { - return BinarySPDecompositionTree{ - make_generic_binary_sp_leaf(n), - }; +bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_left_associative(tree, + generic_impl_for_binary_sp_tree()); } -bool is_binary_sp_tree_left_associative(BinarySPDecompositionTree const &tt) { - return is_binary_sp_tree_left_associative(tt.raw_tree); +bool is_binary_sp_tree_right_associative( + BinarySPDecompositionTree const &tree) { + return is_binary_sp_tree_right_associative(tree, + generic_impl_for_binary_sp_tree()); } -bool is_binary_sp_tree_right_associative(BinarySPDecompositionTree const &tt) { - return is_binary_sp_tree_right_associative(tt.raw_tree); +std::unordered_multiset + get_leaves(BinarySPDecompositionTree const &tree) { + return get_leaves(tree, generic_impl_for_binary_sp_tree()); } -std::unordered_multiset get_leaves(BinarySPDecompositionTree const &tt) { - return get_leaves(tt.raw_tree); +SPDecompositionTreeNodeType + get_node_type(BinarySPDecompositionTree const &tree) { + return tree.visit(overload{ + [](BinarySeriesSplit const &) { + return SPDecompositionTreeNodeType::SERIES; + }, + [](BinaryParallelSplit const &) { + return SPDecompositionTreeNodeType::PARALLEL; + }, + [](Node const &) { return SPDecompositionTreeNodeType::NODE; }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc new file mode 100644 index 0000000000..07e2c3e3e3 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.cc @@ -0,0 +1,19 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/find_paths_to_leaf.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_set find_paths_to_leaf( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + Leaf const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc deleted file mode 100644 index 4cd7206408..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc deleted file mode 100644 index 3a4dbad8ec..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc new file mode 100644 index 0000000000..56a6d0cc85 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree_implementation.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +FullBinaryTreeImplementation, Leaf> + get_full_binary_impl_from_generic_sp_impl( + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc deleted file mode 100644 index 4ee18af5be..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc new file mode 100644 index 0000000000..71d3f6ac31 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.cc @@ -0,0 +1,18 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_all_leaf_paths.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_set get_all_leaf_paths( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index 71b67acc54..3bb90bfa32 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::unordered_multiset + get_leaves(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc deleted file mode 100644 index 227e5bd79c..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc deleted file mode 100644 index 1618128226..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_node_type.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 05ec6b5925..3d166145c1 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template int get_num_tree_nodes( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc deleted file mode 100644 index f168ba1e2f..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc new file mode 100644 index 0000000000..d1d8079c0b --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.cc @@ -0,0 +1,19 @@ +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_subtree_at_path.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template std::optional get_subtree_at_path( + Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + BinaryTreePath const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc deleted file mode 100644 index 75c472c435..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc deleted file mode 100644 index 3da024743c..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 8fe9397003..69cbb28582 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<0>; +using Series = value_type<1>; +using Parallel = value_type<2>; +using Leaf = value_type<3>; + +template bool is_binary_sp_tree_left_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index d202f55964..584099e33e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1 +1,18 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template bool is_binary_sp_tree_right_associative( + Tree const &tree, + GenericBinarySPDecompositionTreeImplementation const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc deleted file mode 100644 index b569ff9265..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc deleted file mode 100644 index fb1532b3ef..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc deleted file mode 100644 index 3fee45fcf5..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/require.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc deleted file mode 100644 index cabd66cff7..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc index 25409333f2..056ae2a8d4 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.cc @@ -1 +1,24 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using ReturnType = value_type<0>; +using Tree = value_type<1>; +using Series = value_type<2>; +using Parallel = value_type<3>; +using Leaf = value_type<4>; + +template ReturnType + visit(Tree const &, + GenericBinarySPDecompositionTreeImplementation const &, + GenericBinarySPDecompositionTreeVisitor const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 02e541b7e4..69b2ebea8e 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -2,50 +2,53 @@ #include "utils/containers/foldl1.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( SeriesParallelDecomposition const &nary) { - std::function( + std::function const &)> from_series_child; - std::function( + std::function const &)> from_parallel_child; - auto from_node = [](Node const &n) -> GenericBinarySPDecompositionTree { - return GenericBinarySPDecompositionTree{n}; + auto from_node = [](Node const &n) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{n}; }; - auto from_series = - [&](SeriesSplit const &s) -> GenericBinarySPDecompositionTree { - std::vector> children = + auto from_series = [&](SeriesSplit const &s) -> BinarySPDecompositionTree { + std::vector children = transform(s.children, from_series_child); - return foldl1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{accum, x}, - }; - }); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{accum, x}, + }; + }); }; auto from_parallel = - [&](ParallelSplit const &s) -> GenericBinarySPDecompositionTree { - std::vector> children = - transform(vector_of(s.children), from_parallel_child); - return foldl1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{accum, x}}; - }); + [&](ParallelSplit const &s) -> BinarySPDecompositionTree { + std::vector children = + transform(vector_of(s.get_children()), from_parallel_child); + return foldl1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{accum, x}, + }; + }); }; from_parallel_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit(overload{ [&](Node const &n) { return from_node(n); }, [&](SeriesSplit const &s) { return from_series(s); }, @@ -54,7 +57,7 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( }; from_series_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit( overload{ [&](Node const &n) { return from_node(n); }, @@ -63,13 +66,11 @@ BinarySPDecompositionTree left_associative_binary_sp_tree_from_nary( v); }; - return BinarySPDecompositionTree{ - nary.visit>(overload{ - [&](Node const &n) { return from_node(n); }, - [&](SeriesSplit const &s) { return from_series(s); }, - [&](ParallelSplit const &p) { return from_parallel(p); }, - }), - }; + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 673a4118a6..478d90e0c3 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -2,47 +2,50 @@ #include "utils/containers/foldr1.h" #include "utils/containers/transform.h" #include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( SeriesParallelDecomposition const &nary) { - std::function( + std::function const &)> from_series_child; - std::function( + std::function const &)> from_parallel_child; - auto from_node = [](Node const &n) { - return GenericBinarySPDecompositionTree{n}; - }; + auto from_node = [](Node const &n) { return BinarySPDecompositionTree{n}; }; auto from_series = [&](SeriesSplit const &s) { - std::vector> children = + std::vector children = transform(s.children, from_series_child); - return foldr1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinarySeriesSplit{x, accum}}; - }); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinarySeriesSplit{x, accum}, + }; + }); }; auto from_parallel = [&](ParallelSplit const &s) { - std::vector> children = - transform(vector_of(s.children), from_parallel_child); - return foldr1(children, - [](GenericBinarySPDecompositionTree const &accum, - GenericBinarySPDecompositionTree const &x) { - return GenericBinarySPDecompositionTree{ - GenericBinaryParallelSplit{x, accum}}; - }); + std::vector children = + transform(vector_of(s.get_children()), from_parallel_child); + return foldr1( + children, + [](BinarySPDecompositionTree const &accum, + BinarySPDecompositionTree const &x) -> BinarySPDecompositionTree { + return BinarySPDecompositionTree{ + BinaryParallelSplit{x, accum}, + }; + }); }; from_parallel_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit(overload{ [&](Node const &n) { return from_node(n); }, [&](SeriesSplit const &s) { return from_series(s); }, @@ -51,7 +54,7 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( }; from_series_child = [&](std::variant const &v) - -> GenericBinarySPDecompositionTree { + -> BinarySPDecompositionTree { return std::visit( overload{ [&](Node const &n) { return from_node(n); }, @@ -60,13 +63,11 @@ BinarySPDecompositionTree right_associative_binary_sp_tree_from_nary( v); }; - return BinarySPDecompositionTree{ - nary.visit>(overload{ - [&](Node const &n) { return from_node(n); }, - [&](SeriesSplit const &s) { return from_series(s); }, - [&](ParallelSplit const &p) { return from_parallel(p); }, - }), - }; + return nary.visit(overload{ + [&](Node const &n) { return from_node(n); }, + [&](SeriesSplit const &s) { return from_series(s); }, + [&](ParallelSplit const &p) { return from_parallel(p); }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index ab231f256c..cd29af59a0 100644 --- a/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -33,10 +33,10 @@ std::optional MultiDiGraph ttsp = MultiDiGraph::materialize_copy_of( inverse_line_graph_result.graph); std::unordered_map - ttsp_edge_to_sp_tree = - map_values(inverse_line_graph_result.inverse_edge_to_line_node_bidict - .as_unordered_map(), - [](Node const &n) { return make_leaf_node(n); }); + ttsp_edge_to_sp_tree = map_values( + inverse_line_graph_result.inverse_edge_to_line_node_bidict + .as_unordered_map(), + [](Node const &n) { return BinarySPDecompositionTree{n}; }); while (true) { assert(ttsp_edge_to_sp_tree.size() == get_edges(ttsp).size()); @@ -46,8 +46,12 @@ std::optional ParallelReduction parallel_reduction = maybe_parallel_reduction.value(); auto [e1, e2] = parallel_reduction.edges.ordered(); MultiDiEdge merged = apply_parallel_reduction(ttsp, parallel_reduction); - BinarySPDecompositionTree new_tree = make_parallel_split( - ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinaryParallelSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); @@ -62,8 +66,12 @@ std::optional MultiDiEdge e1 = series_reduction.first; MultiDiEdge e2 = series_reduction.second; MultiDiEdge merged = apply_series_reduction(ttsp, series_reduction); - BinarySPDecompositionTree new_tree = make_series_split( - ttsp_edge_to_sp_tree.at(e1), ttsp_edge_to_sp_tree.at(e2)); + BinarySPDecompositionTree new_tree = BinarySPDecompositionTree{ + BinarySeriesSplit{ + ttsp_edge_to_sp_tree.at(e1), + ttsp_edge_to_sp_tree.at(e2), + }, + }; ttsp_edge_to_sp_tree.erase(e1); ttsp_edge_to_sp_tree.erase(e2); ttsp_edge_to_sp_tree.insert({merged, new_tree}); diff --git a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc index 48c936ec39..410a40236d 100644 --- a/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc +++ b/lib/utils/src/utils/graph/series_parallel/intermediate_sp_decomposition_tree.cc @@ -1,8 +1,6 @@ #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" #include "utils/containers/extend.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/visit.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include "utils/overload.h" namespace FlexFlow { @@ -49,36 +47,31 @@ std::variant flatten_ast( return std::visit(FlattenAST{}, ast); } -std::variant - from_binary_sp_tree(GenericBinarySPDecompositionTree const &binary) { - return visit>( - binary, - overload{ - [](Node const &n) { return n; }, - [](GenericBinarySeriesSplit const &s) { - return IntermediateSpDecompositionTree{ - SplitType::SERIES, - { - from_binary_sp_tree(get_left_child(s)), - from_binary_sp_tree(get_right_child(s)), - }, - }; - }, - [](GenericBinaryParallelSplit const &p) { - return IntermediateSpDecompositionTree{ - SplitType::PARALLEL, - { - from_binary_sp_tree(get_left_child(p)), - from_binary_sp_tree(get_right_child(p)), - }, - }; - }, - }); -} - std::variant from_binary_sp_tree(BinarySPDecompositionTree const &binary) { - return from_binary_sp_tree(binary.raw_tree); + return binary + .template visit>( + overload{ + [](Node const &n) { return n; }, + [](BinarySeriesSplit const &s) { + return IntermediateSpDecompositionTree{ + SplitType::SERIES, + { + from_binary_sp_tree(s.get_left_child()), + from_binary_sp_tree(s.get_right_child()), + }, + }; + }, + [](BinaryParallelSplit const &p) { + return IntermediateSpDecompositionTree{ + SplitType::PARALLEL, + { + from_binary_sp_tree(p.get_left_child()), + from_binary_sp_tree(p.get_right_child()), + }, + }; + }, + }); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index e697533054..b7a84b871a 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -64,7 +64,7 @@ std::unordered_multiset get_nodes(SeriesSplit const &serial) { std::unordered_multiset get_nodes(ParallelSplit const ¶llel) { return multiset_union(transform( - vector_of(parallel.children), + vector_of(parallel.get_children()), [](std::variant const &child) { return std::visit([](auto &&t) { return get_nodes(t); }, child); })); diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc index 0e04a4f904..7d36371e49 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc @@ -1,85 +1,85 @@ #include "utils/graph/series_parallel/series_parallel_splits.h" -#include "utils/fmt/unordered_multiset.h" -#include "utils/fmt/variant.h" -#include "utils/fmt/vector.h" -#include "utils/hash-utils.h" -#include "utils/hash/unordered_multiset.h" -#include "utils/hash/vector.h" - -namespace FlexFlow { - -SeriesSplit::SeriesSplit( - std::vector> const &children) - : children(children) {} - -SeriesSplit::SeriesSplit( - std::initializer_list> const &children) - : children(children) {} - -bool SeriesSplit::operator==(SeriesSplit const &other) const { - return this->tie() == other.tie(); -} - -bool SeriesSplit::operator!=(SeriesSplit const &other) const { - return this->tie() != other.tie(); -} - -SeriesSplit::Tie SeriesSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(SeriesSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { - return s << fmt::to_string(split); -} - -ParallelSplit::ParallelSplit( - std::unordered_multiset> const &children) - : children(children) {} - -ParallelSplit::ParallelSplit( - std::initializer_list> const &children) - : children(children) {} - -bool ParallelSplit::operator==(ParallelSplit const &other) const { - return this->tie() == other.tie(); -} - -bool ParallelSplit::operator!=(ParallelSplit const &other) const { - return this->tie() != other.tie(); -} - -ParallelSplit::Tie ParallelSplit::tie() const { - return std::tie(this->children); -} - -std::string format_as(ParallelSplit const &split) { - return fmt::format("", split.children); -} - -std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { - return s << fmt::to_string(split); -} - -} // namespace FlexFlow - -namespace std { - -size_t hash<::FlexFlow::SeriesSplit>::operator()( - ::FlexFlow::SeriesSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -size_t hash<::FlexFlow::ParallelSplit>::operator()( - ::FlexFlow::ParallelSplit const &s) const { - size_t result = 0; - ::FlexFlow::hash_combine(result, s.children); - return result; -} - -} // namespace std +// #include "utils/fmt/unordered_multiset.h" +// #include "utils/fmt/variant.h" +// #include "utils/fmt/vector.h" +// #include "utils/hash-utils.h" +// #include "utils/hash/unordered_multiset.h" +// #include "utils/hash/vector.h" +// +// namespace FlexFlow { +// +// SeriesSplit::SeriesSplit( +// std::vector> const &children) +// : children(children) {} +// +// SeriesSplit::SeriesSplit( +// std::initializer_list> const &children) +// : children(children) {} +// +// bool SeriesSplit::operator==(SeriesSplit const &other) const { +// return this->tie() == other.tie(); +// } +// +// bool SeriesSplit::operator!=(SeriesSplit const &other) const { +// return this->tie() != other.tie(); +// } +// +// SeriesSplit::Tie SeriesSplit::tie() const { +// return std::tie(this->children); +// } +// +// std::string format_as(SeriesSplit const &split) { +// return fmt::format("", split.children); +// } +// +// std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { +// return s << fmt::to_string(split); +// } +// +// ParallelSplit::ParallelSplit( +// std::unordered_multiset> const &children) +// : children(children) {} +// +// ParallelSplit::ParallelSplit( +// std::initializer_list> const &children) +// : children(children) {} +// +// bool ParallelSplit::operator==(ParallelSplit const &other) const { +// return this->tie() == other.tie(); +// } +// +// bool ParallelSplit::operator!=(ParallelSplit const &other) const { +// return this->tie() != other.tie(); +// } +// +// ParallelSplit::Tie ParallelSplit::tie() const { +// return std::tie(this->children); +// } +// +// std::string format_as(ParallelSplit const &split) { +// return fmt::format("", split.children); +// } +// +// std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { +// return s << fmt::to_string(split); +// } +// +// } // namespace FlexFlow +// +// namespace std { +// +// size_t hash<::FlexFlow::SeriesSplit>::operator()( +// ::FlexFlow::SeriesSplit const &s) const { +// size_t result = 0; +// ::FlexFlow::hash_combine(result, s.children); +// return result; +// } +// +// size_t hash<::FlexFlow::ParallelSplit>::operator()( +// ::FlexFlow::ParallelSplit const &s) const { +// size_t result = 0; +// ::FlexFlow::hash_combine(result, s.children); +// return result; +// } +// +// } // namespace std diff --git a/lib/utils/src/utils/json/check_is_json_deserializable.cc b/lib/utils/src/utils/json/check_is_json_deserializable.cc new file mode 100644 index 0000000000..7e17ced7e5 --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_deserializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_deserializable.h" diff --git a/lib/utils/src/utils/json/check_is_json_serializable.cc b/lib/utils/src/utils/json/check_is_json_serializable.cc new file mode 100644 index 0000000000..1c9af4d3cb --- /dev/null +++ b/lib/utils/src/utils/json/check_is_json_serializable.cc @@ -0,0 +1 @@ +#include "utils/json/check_is_json_serializable.h" diff --git a/lib/utils/test/src/utils/containers/cartesian_product.cc b/lib/utils/test/src/utils/containers/cartesian_product.cc new file mode 100644 index 0000000000..42b8a10439 --- /dev/null +++ b/lib/utils/test/src/utils/containers/cartesian_product.cc @@ -0,0 +1,62 @@ +#include "utils/containers/cartesian_product.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cartesian_product") { + + SUBCASE("empty") { + std::vector> containers = {}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{}}; + CHECK(result == correct); + } + + SUBCASE("single container, one element") { + std::vector> containers = {{1}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1}}; + CHECK(result == correct); + } + + SUBCASE("single container, multiple elements") { + std::vector> containers = {{1, 2, 3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1}, {2}, {3}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, one element each") { + std::vector> containers = {{1}, {2}, {3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {{1, 2, 3}}; + CHECK(result == correct); + } + + SUBCASE("multiple containers, multiple elements") { + std::vector> containers = {{1, 2}, {3, 4}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = { + {1, 3}, {1, 4}, {2, 3}, {2, 4}}; + CHECK(result == correct); + } + + SUBCASE("1 empty container, 1 non-empty container") { + std::vector> containers = {{}, {2, 3}}; + std::unordered_set> result = + cartesian_product(containers); + std::unordered_set> correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc new file mode 100644 index 0000000000..c10cc5ae75 --- /dev/null +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -0,0 +1,105 @@ +#include "utils/containers/flatmap.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "utils/containers/map_keys.h" +#include "utils/hash/pair.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flatmap(std::unordered_set, F)") { + auto get_chars = [](std::string const &s) { + std::unordered_set result; + for (char c : s) { + result.insert(c); + } + return result; + }; + + SUBCASE("type changing") { + std::unordered_set input = {"hello", " ", "", "world", "!"}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = { + 'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd', '!'}; + + CHECK(result == correct); + } + + SUBCASE("input is empty") { + std::unordered_set input = {}; + + std::unordered_set result = flatmap(input, get_chars); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + TEST_CASE("flatmap(std::unordered_map, F)") { + auto de_nest_keys = [](int k1, + std::unordered_map const &v) { + return map_keys(v, [&](int k2) { return std::pair{k1, k2}; }); + }; + + SUBCASE("input is empty") { + std::unordered_map> input = {}; + + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input is not empty") { + std::unordered_map> input = { + { + 1, + { + {2, "a"}, + {3, "b"}, + }, + }, + { + 2, + {}, + }, + { + 3, + { + {3, "a"}, + }, + }, + }; + + std::unordered_map, std::string> result = + flatmap(input, de_nest_keys); + std::unordered_map, std::string> correct = { + {{1, 2}, "a"}, + {{1, 3}, "b"}, + {{3, 3}, "a"}, + }; + + CHECK(result == correct); + } + + SUBCASE("duplicate result keys") { + auto always_return_same_map = [](int, std::string const &) { + return std::unordered_map{ + {"mykey", 10000}, + }; + }; + + std::unordered_map input = { + {1, "a"}, + {2, "b"}, + }; + + CHECK_THROWS(flatmap(input, always_return_same_map)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_assignments.cc b/lib/utils/test/src/utils/containers/get_all_assignments.cc new file mode 100644 index 0000000000..d5f989318f --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_all_assignments.cc @@ -0,0 +1,53 @@ +#include "utils/containers/get_all_assignments.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_all_assignments") { + SUBCASE("empty input") { + std::unordered_map> input = {}; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = {{}}; + + CHECK(result == correct); + } + + SUBCASE("non-empty input") { + std::unordered_map> input = { + {"a", {1, 2, 3}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = { + {{"a", 1}, {"b", 2}}, + {{"a", 1}, {"b", 3}}, + {{"a", 2}, {"b", 2}}, + {{"a", 2}, {"b", 3}}, + {{"a", 3}, {"b", 2}}, + {{"a", 3}, {"b", 3}}, + }; + + CHECK(result == correct); + } + + SUBCASE("one possible-values set is empty") { + std::unordered_map> input = { + {"a", {}}, + {"b", {2, 3}}, + }; + + std::unordered_set> result = + get_all_assignments(input); + std::unordered_set> correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/try_at.cc b/lib/utils/test/src/utils/containers/try_at.cc new file mode 100644 index 0000000000..548c9b0c79 --- /dev/null +++ b/lib/utils/test/src/utils/containers/try_at.cc @@ -0,0 +1,29 @@ +#include "utils/containers/try_at.h" +#include "test/utils/doctest/fmt/optional.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("try_at(T, K)", + T, + std::unordered_map, + std::map) { + T m = {{1, "one"}, {2, "two"}}; + + SUBCASE("map contains key") { + std::optional result = try_at(m, 1); + std::optional correct = "one"; + + CHECK(result == correct); + } + + SUBCASE("map does not contain key") { + std::optional result = try_at(m, 3); + std::optional correct = std::nullopt; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc new file mode 100644 index 0000000000..f0cdb19611 --- /dev/null +++ b/lib/utils/test/src/utils/containers/unordered_map_from_pairs.cc @@ -0,0 +1,57 @@ +#include "utils/containers/unordered_map_from_pairs.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "utils/containers/contains.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("unordered_map_from_pairs") { + SUBCASE("nonempty input") { + std::vector> input = { + {1, "hello"}, + {3, "world"}, + }; + + std::unordered_map result = + unordered_map_from_pairs(input); + std::unordered_map correct = { + {1, "hello"}, + {3, "world"}, + }; + + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector> input = {}; + + std::unordered_map result = + unordered_map_from_pairs(input); + std::unordered_map correct = {}; + + CHECK(result == correct); + } + + SUBCASE("input with duplicate keys") { + std::vector> input = { + {1, "a"}, + {2, "c"}, + {1, "b"}, + }; + + std::unordered_map result = + unordered_map_from_pairs(input); + + std::vector> + possible_correct_values = { + {{1, "a"}, {2, "c"}}, + {{1, "b"}, {2, "c"}}, + }; + + CHECK(contains(possible_correct_values, result)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc new file mode 100644 index 0000000000..fec5d3401e --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.cc @@ -0,0 +1,104 @@ +#include "utils/graph/dataflow_graph/algorithms/get_dataflow_edges_from_node_to_node.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_dataflow_edges_from_node_to_node") { + DataflowGraph g = DataflowGraph::create(); + + SUBCASE("gets edges if there are multiple") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o0 = n1_added.outputs.at(0); + DataflowOutput n1_o1 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o0, n1_o0, n1_o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = { + DataflowEdge{ + n1_o0, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o0, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not get edges to/from other nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n3); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE( + "does not get flipped edges (i.e., respects from vs to direction)") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 0); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n2, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if no edges exist between the given nodes") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n2); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns empty set if src node == dst node (as cycles cannot exist " + "in DataflowGraph") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + + std::unordered_set result = + get_dataflow_edges_from_node_to_node(g, n1, n1); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc new file mode 100644 index 0000000000..c35789044d --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.cc @@ -0,0 +1,55 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_boundary_nodes_for_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_boundary_nodes_for_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + SplitBoundaryNodes result = + get_transitive_reduced_boundary_nodes_for_split(tr_g, split); + SplitBoundaryNodes correct = SplitBoundaryNodes{ + /*pre_split_boundary=*/{n2}, + /*post_split_boundary=*/{n3}, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc new file mode 100644 index 0000000000..1f8f66b932 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.cc @@ -0,0 +1,146 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_edges_across_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_edges_across_split") { + DataflowGraph g = DataflowGraph::create(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + SUBCASE("multiple nodes with edges across") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o2, o1}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o1}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o1, + DataflowInput{n3, 1}, + }, + DataflowEdge{ + o2, + DataflowInput{n3, 0}, + }, + DataflowEdge{ + o1, + DataflowInput{n4, 0}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("nodes each have multiple edges across") { + NodeAddedResult n1_added = g.add_node({}, 2); + Node n1 = n1_added.node; + DataflowOutput n1_o1 = n1_added.outputs.at(0); + DataflowOutput n1_o2 = n1_added.outputs.at(1); + + NodeAddedResult n2_added = g.add_node({n1_o1, n1_o2, n1_o1}, 1); + Node n2 = n2_added.node; + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_leaf(n1), + make_leaf(n2), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + n1_o1, + DataflowInput{n2, 0}, + }, + DataflowEdge{ + n1_o2, + DataflowInput{n2, 1}, + }, + DataflowEdge{ + n1_o1, + DataflowInput{n2, 2}, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not return edges eliminated by transitive reduction") { + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_edges_across_split(tr_g, split); + std::unordered_set correct = { + DataflowEdge{ + o2, + DataflowInput{n3, 1}, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc new file mode 100644 index 0000000000..0e77739434 --- /dev/null +++ b/lib/utils/test/src/utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.cc @@ -0,0 +1,52 @@ +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/get_transitive_reduced_outputs_across_split.h" +#include "utils/containers/get_only.h" +#include "utils/graph/dataflow_graph/algorithms/transitive_reduced_dataflow_graph/transitive_reduced_dataflow_graph.h" +#include "utils/graph/dataflow_graph/dataflow_graph.h" +#include "utils/graph/instances/unordered_set_dataflow_graph.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_transitive_reduced_outputs_across_split") { + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + DataflowGraph g = DataflowGraph::create(); + + NodeAddedResult n1_added = g.add_node({}, 1); + Node n1 = n1_added.node; + DataflowOutput o1 = get_only(n1_added.outputs); + + NodeAddedResult n2_added = g.add_node({o1}, 1); + Node n2 = n2_added.node; + DataflowOutput o2 = get_only(n2_added.outputs); + + NodeAddedResult n3_added = g.add_node({o1, o2}, 1); + Node n3 = n3_added.node; + DataflowOutput o3 = get_only(n3_added.outputs); + + NodeAddedResult n4_added = g.add_node({o2, o3}, 1); + Node n4 = n4_added.node; + DataflowOutput o4 = get_only(n4_added.outputs); + + TransitiveReducedDataflowGraphView tr_g = + get_dataflow_graph_transitive_reduction(g); + + BinarySeriesSplit split = BinarySeriesSplit{ + make_series_split(make_leaf(n1), make_leaf(n2)), + make_series_split(make_leaf(n3), make_leaf(n4)), + }; + + std::unordered_set result = + get_transitive_reduced_outputs_across_split(tr_g, split); + std::unordered_set correct = {o2}; + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc new file mode 100644 index 0000000000..5a1ea99671 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.cc @@ -0,0 +1,142 @@ +#include "utils/graph/digraph/algorithms/get_edges_from_subgraph_to_subgraph.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_edges_from_subgraph_to_subgraph") { + DiGraph g = DiGraph::create(); + + std::vector n = add_nodes(g, 5); + SUBCASE("basic tests") { + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(4)}; + std::unordered_set dst_subgraph = {n.at(2), n.at(3)}; + + SUBCASE("returns all edges between subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(4), n.at(2)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = unordered_set_of(e); + + CHECK(result == correct); + } + + SUBCASE("does not return reverse edges") { + std::vector e = { + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(0)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(0)}; + + CHECK(result == correct); + } + + SUBCASE("does not return edges within subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if there are no edges from src_subgraph to " + "dst_subgraph") { + std::vector e = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(2), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("empty subgraphs") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + SUBCASE("returns no edges if no nodes in src_subgraph") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, unordered_set_of(n)); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if no nodes in dst_subgraph") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, unordered_set_of(n), {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + + SUBCASE("returns no edges if both subgraphs are empty") { + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, {}, {}); + std::unordered_set correct = {}; + + CHECK(result == correct); + } + } + + SUBCASE("if subgraphs do not cover graph, then does not return external " + "edges") { + std::vector e = { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }; + + add_edges(g, e); + + std::unordered_set src_subgraph = {n.at(0)}; + std::unordered_set dst_subgraph = {n.at(3)}; + + std::unordered_set result = + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph); + std::unordered_set correct = {e.at(1)}; + + CHECK(result == correct); + } + + SUBCASE("throws an error if subgraphs are not disjoint") { + std::unordered_set src_subgraph = {n.at(0), n.at(1), n.at(2)}; + std::unordered_set dst_subgraph = {n.at(1), n.at(3)}; + CHECK_THROWS( + get_edges_from_subgraph_to_subgraph(g, src_subgraph, dst_subgraph)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc deleted file mode 100644 index 66b657eaaa..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.cc +++ /dev/null @@ -1,51 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("fmt GenericBinarySPDecompositionTree") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - std::string result = fmt::to_string(input); - std::string correct = ""; - - CHECK(result == correct); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - - std::string result = fmt::to_string(input); - std::string correct = (" " - "" - ">" - ">"); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - - std::string result = fmt::to_string(input); - std::string correct = (" " - "" - ">" - ">"); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc index abae9286b6..9ca869b2b0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.cc @@ -1,41 +1,61 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_leaves.h" #include "test/utils/doctest/fmt/unordered_multiset.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_leaves(GenericBinarySPDecompositionTree)") { + TEST_CASE("get_leaves") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto generic_get_leaves = [&](BinarySPDecompositionTree const &tree) { + return get_leaves(tree, impl); + }; + SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); + BinarySPDecompositionTree input = BinarySPDecompositionTree{n1}; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1}; CHECK(result == correct); } SUBCASE("series split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 6}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinarySeriesSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 5}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; CHECK(result == correct); } @@ -43,42 +63,54 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("parallel split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n2}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 6}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n2}; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = BinarySPDecompositionTree{ + BinaryParallelSplit{ + BinarySPDecompositionTree{n1}, + BinarySPDecompositionTree{n1}, + }, + }; - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {5, 5}; + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1}; CHECK(result == correct); } } + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5))), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(2))); - - std::unordered_multiset result = get_leaves(input); - std::unordered_multiset correct = {2, 2, 4, 4, 5}; + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); + + std::unordered_multiset result = generic_get_leaves(input); + std::unordered_multiset correct = {n1, n1, n2, n2, n3}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc deleted file mode 100644 index 92c556ad28..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_left_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_left_child(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - CHECK_THROWS(get_left_child(input)); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(3)); - - GenericBinarySPDecompositionTree result = get_left_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(5); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(7)); - - GenericBinarySPDecompositionTree result = get_left_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(4); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc index 3de61d3313..ad7e1c2609 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.cc @@ -1,16 +1,43 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_num_tree_nodes.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_num_tree_nodes(GenericBinarySPDecompositionTree)") { + TEST_CASE("get_num_tree_nodes") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + + auto generic_get_num_tree_nodes = + [&](BinarySPDecompositionTree const &tree) { + return get_num_tree_nodes(tree, impl); + }; + SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); + BinarySPDecompositionTree input = make_leaf(n1); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 1; CHECK(result == correct); @@ -18,22 +45,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("series split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n2)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = + make_series_split(make_leaf(n1), make_leaf(n1)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); @@ -42,22 +67,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("parallel split") { SUBCASE("children are not the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n2)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); } SUBCASE("children are the same") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(5)); + BinarySPDecompositionTree input = + make_parallel_split(make_leaf(n1), make_leaf(n1)); - int result = get_num_tree_nodes(input); + int result = generic_get_num_tree_nodes(input); int correct = 3; CHECK(result == correct); @@ -65,18 +88,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5))), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(2))); - - int result = get_num_tree_nodes(input); + BinarySPDecompositionTree input = make_parallel_split( + make_series_split(make_leaf(n1), + make_series_split(make_leaf(n2), make_leaf(n3))), + make_parallel_split(make_leaf(n2), make_leaf(n1))); + + int result = generic_get_num_tree_nodes(input); int correct = 9; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc deleted file mode 100644 index 33b5d37955..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/get_right_child.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_right_child(GenericBinarySPDecompositionTree)") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(5); - - CHECK_THROWS(get_right_child(input)); - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(3)); - - GenericBinarySPDecompositionTree result = get_right_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(3); - - CHECK(result == correct); - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(7)); - - GenericBinarySPDecompositionTree result = get_right_child(input); - GenericBinarySPDecompositionTree correct = - make_generic_binary_sp_leaf(7); - - CHECK(result == correct); - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc deleted file mode 100644 index e7025dbfad..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.cc +++ /dev/null @@ -1,117 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/hash.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("std::hash>") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree leaf_5 = - make_generic_binary_sp_leaf(5); - size_t leaf_5_hash = get_std_hash(leaf_5); - - SUBCASE("leaves with same labels hash to the same value") { - GenericBinarySPDecompositionTree also_leaf_5 = - make_generic_binary_sp_leaf(5); - size_t also_leaf_5_hash = get_std_hash(also_leaf_5); - - CHECK(leaf_5_hash == also_leaf_5_hash); - } - - SUBCASE("leaves with different labels hash to different values") { - GenericBinarySPDecompositionTree leaf_6 = - make_generic_binary_sp_leaf(6); - size_t leaf_6_hash = get_std_hash(leaf_6); - - CHECK(leaf_5_hash != leaf_6_hash); - } - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree series_5_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t series_5_6_hash = get_std_hash(series_5_6); - - SUBCASE("same children lead to the same hash") { - GenericBinarySPDecompositionTree also_series_5_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t also_series_5_6_hash = get_std_hash(also_series_5_6); - - CHECK(series_5_6_hash == also_series_5_6_hash); - } - - SUBCASE("hash is order dependent") { - GenericBinarySPDecompositionTree series_6_5 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(6), - make_generic_binary_sp_leaf(5)); - size_t series_6_5_hash = get_std_hash(series_6_5); - - CHECK(series_5_6_hash != series_6_5_hash); - } - - SUBCASE("different left child leads to different hash") { - GenericBinarySPDecompositionTree series_4_6 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(6)); - size_t series_4_6_hash = get_std_hash(series_4_6); - - CHECK(series_5_6_hash != series_4_6_hash); - } - - SUBCASE("different right child leads to different hash") { - GenericBinarySPDecompositionTree series_5_7 = - make_generic_binary_series_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - size_t series_5_7_hash = get_std_hash(series_5_7); - - CHECK(series_5_6_hash != series_5_7_hash); - } - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree parallel_5_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t parallel_5_6_hash = get_std_hash(parallel_5_6); - - SUBCASE("same children lead to the same hash") { - GenericBinarySPDecompositionTree also_parallel_5_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(6)); - size_t also_parallel_5_6_hash = get_std_hash(also_parallel_5_6); - - CHECK(parallel_5_6_hash == also_parallel_5_6_hash); - } - - SUBCASE("hash is order dependent") { - GenericBinarySPDecompositionTree parallel_6_5 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(6), - make_generic_binary_sp_leaf(5)); - size_t parallel_6_5_hash = get_std_hash(parallel_6_5); - - CHECK(parallel_5_6_hash != parallel_6_5_hash); - } - - SUBCASE("different left child leads to different hash") { - GenericBinarySPDecompositionTree parallel_4_6 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(4), - make_generic_binary_sp_leaf(6)); - size_t parallel_4_6_hash = get_std_hash(parallel_4_6); - - CHECK(parallel_5_6_hash != parallel_4_6_hash); - } - - SUBCASE("different right child leads to different hash") { - GenericBinarySPDecompositionTree parallel_5_7 = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(5), - make_generic_binary_sp_leaf(7)); - size_t parallel_5_7_hash = get_std_hash(parallel_5_7); - - CHECK(parallel_5_6_hash != parallel_5_7_hash); - } - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc index 7a8756c6cc..3fae155280 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.cc @@ -1,22 +1,38 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_left_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("is_binary_sp_tree_left_associative(" - "GenericBinarySPDecompositionTree)") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; + TEST_CASE("is_binary_sp_tree_left_associative") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually left associative") { SUBCASE("just node") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -25,12 +41,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -39,12 +51,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -53,14 +61,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n3), - make_generic_binary_sp_leaf(n4))); + BinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); bool result = is_binary_sp_tree_left_associative(input); bool correct = true; @@ -71,12 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not left associative") { SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); bool correct = false; @@ -85,12 +84,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_left_associative(input); bool correct = false; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc index 3cf87368ab..5b4e26107e 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.cc @@ -1,22 +1,38 @@ #include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/is_binary_sp_tree_right_associative.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.dtg.h" +#include "utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("is_binary_sp_tree_right_associative(" - "GenericBinarySPDecompositionTree)") { - int n1 = 1; - int n2 = 2; - int n3 = 3; - int n4 = 4; + TEST_CASE("is_binary_sp_tree_right_associative") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + GenericBinarySPDecompositionTreeImplementation + impl = generic_impl_for_binary_sp_tree(); + + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; SUBCASE("input is actually right associative") { SUBCASE("just node") { - GenericBinarySPDecompositionTree input = - make_generic_binary_sp_leaf(n1); + BinarySPDecompositionTree input = make_leaf(n1); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -25,12 +41,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_series_split( + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -39,12 +51,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n2), - make_generic_binary_sp_leaf(n3))); + BinarySPDecompositionTree input = make_parallel_split( + make_leaf(n1), make_parallel_split(make_leaf(n2), make_leaf(n3))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -53,14 +61,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("nested") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n3), - make_generic_binary_sp_leaf(n4))); + BinarySPDecompositionTree input = make_series_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), + make_parallel_split(make_leaf(n3), make_leaf(n4))); bool result = is_binary_sp_tree_right_associative(input); bool correct = true; @@ -71,12 +74,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("input is not right associative") { SUBCASE("just series") { - GenericBinarySPDecompositionTree input = - make_generic_binary_series_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_series_split( + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_right_associative(input); bool correct = false; @@ -85,12 +84,8 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("just parallel") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_parallel_split( - make_generic_binary_sp_leaf(n1), - make_generic_binary_sp_leaf(n2)), - make_generic_binary_sp_leaf(n3)); + BinarySPDecompositionTree input = make_parallel_split( + make_parallel_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); bool result = is_binary_sp_tree_right_associative(input); bool correct = false; diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc deleted file mode 100644 index cc234bacf8..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.cc +++ /dev/null @@ -1,131 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/json.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("adl_serializer>") { - SUBCASE("leaf") { - GenericBinarySPDecompositionTree tt = make_generic_binary_sp_leaf(5); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - - SUBCASE("series split") { - GenericBinarySPDecompositionTree tt = - make_generic_binary_series_split(make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5)); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "series"}, - { - "value", - { - {"__type", "GenericBinarySeriesSplit"}, - { - "left_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 2}, - }, - }, - { - "right_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }, - }, - }, - }, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - - SUBCASE("parallel split") { - GenericBinarySPDecompositionTree tt = - make_generic_binary_parallel_split(make_generic_binary_sp_leaf(2), - make_generic_binary_sp_leaf(5)); - - nlohmann::json tt_json = { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "parallel"}, - { - "value", - { - {"__type", "GenericBinaryParallelSplit"}, - { - "left_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 2}, - }, - }, - { - "right_child", - { - {"__type", "GenericBinarySPDecompositionTree"}, - {"type", "leaf"}, - {"value", 5}, - }, - }, - }, - }, - }; - - SUBCASE("to_json") { - nlohmann::json result = tt; - nlohmann::json correct = tt_json; - - CHECK(result == correct); - } - - SUBCASE("from_json") { - GenericBinarySPDecompositionTree result = - tt_json.get>(); - GenericBinarySPDecompositionTree correct = tt; - - CHECK(result == correct); - } - } - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc deleted file mode 100644 index 4ede4e84b5..0000000000 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.cc +++ /dev/null @@ -1,28 +0,0 @@ -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/transform.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/fmt.h" -#include "utils/graph/series_parallel/binary_sp_decomposition_tree/generic_binary_sp_decomposition_tree/make.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("transform(GenericBinarySPDecompositionTree, F)") { - GenericBinarySPDecompositionTree input = - make_generic_binary_parallel_split( - make_generic_binary_series_split(make_generic_binary_sp_leaf(1), - make_generic_binary_sp_leaf(4)), - make_generic_binary_sp_leaf(2)); - - GenericBinarySPDecompositionTree result = - transform(input, [](int x) { return std::to_string(x); }); - - GenericBinarySPDecompositionTree correct = - make_generic_binary_parallel_split( - make_generic_binary_series_split( - make_generic_binary_sp_leaf(std::string{"1"}), - make_generic_binary_sp_leaf(std::string{"4"})), - make_generic_binary_sp_leaf(std::string{"2"})); - - CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 1e3217a2de..fee971e5e0 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -18,34 +18,45 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); - BinarySPDecompositionTree correct = make_leaf_node(n1); + BinarySPDecompositionTree correct = make_leaf(n1); CHECK(result == correct); } SUBCASE("only serial") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - SeriesSplit{n1, n2, n3}, + SeriesSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_series_split(make_leaf_node(n1), make_leaf_node(n2)), - make_leaf_node(n3)); + make_series_split(make_leaf(n1), make_leaf(n2)), make_leaf(n3)); CHECK(result == correct); } SUBCASE("only parallel") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{n1, n2, n3}, + ParallelSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = @@ -64,20 +75,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nested") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; BinarySPDecompositionTree result = diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc index 0befbde5cc..fd540f853f 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/nary_sp_tree_from_binary.cc @@ -14,8 +14,20 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("leaf") { - BinarySPDecompositionTree input = make_leaf_node(n1); + BinarySPDecompositionTree input = make_leaf(n1); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; @@ -25,35 +37,33 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative series") { BinarySPDecompositionTree input = make_series_split( - make_series_split(make_leaf_node(n2), make_leaf_node(n1)), - make_leaf_node(n3)); + make_series_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("right associative series") { BinarySPDecompositionTree input = make_series_split( - make_leaf_node(n2), - make_series_split(make_leaf_node(n1), make_leaf_node(n3))); + make_leaf(n2), make_series_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n2, n1, n3}}; + SeriesParallelDecomposition{SeriesSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("series with duplicate children") { BinarySPDecompositionTree input = - make_series_split(make_leaf_node(n1), make_leaf_node(n1)); + make_series_split(make_leaf(n1), make_leaf(n1)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{SeriesSplit{n1, n1}}; + SeriesParallelDecomposition{SeriesSplit{{n1, n1}}}; CHECK(get_nodes(result).size() == 2); CHECK(result == correct); @@ -61,35 +71,33 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("left associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_parallel_split(make_leaf_node(n2), make_leaf_node(n1)), - make_leaf_node(n3)); + make_parallel_split(make_leaf(n2), make_leaf(n1)), make_leaf(n3)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("right associative parallel") { BinarySPDecompositionTree input = make_parallel_split( - make_leaf_node(n2), - make_parallel_split(make_leaf_node(n1), make_leaf_node(n3))); + make_leaf(n2), make_parallel_split(make_leaf(n1), make_leaf(n3))); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n2, n1, n3}}; + SeriesParallelDecomposition{ParallelSplit{{n2, n1, n3}}}; CHECK(result == correct); } SUBCASE("parallel with duplicate children") { BinarySPDecompositionTree input = - make_parallel_split(make_leaf_node(n1), make_leaf_node(n1)); + make_parallel_split(make_leaf(n1), make_leaf(n1)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = - SeriesParallelDecomposition{ParallelSplit{n1, n1}}; + SeriesParallelDecomposition{ParallelSplit{{n1, n1}}}; CHECK(get_nodes(result).size() == 2); CHECK(result == correct); @@ -99,31 +107,31 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree input = make_parallel_split( make_parallel_split( make_parallel_split( - make_leaf_node(n1), + make_leaf(n1), make_series_split( - make_series_split(make_series_split(make_leaf_node(n2), - make_leaf_node(n3)), - make_leaf_node(n3)), - make_leaf_node(n5))), - make_series_split(make_leaf_node(n6), make_leaf_node(n4))), - make_leaf_node(n5)); + make_series_split( + make_series_split(make_leaf(n2), make_leaf(n3)), + make_leaf(n3)), + make_leaf(n5))), + make_series_split(make_leaf(n6), make_leaf(n4))), + make_leaf(n5)); SeriesParallelDecomposition result = nary_sp_tree_from_binary(input); SeriesParallelDecomposition correct = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; CHECK(result == correct); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index db1b440481..532ff86c90 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -16,34 +16,45 @@ TEST_SUITE(FF_TEST_SUITE) { Node n5 = Node{5}; Node n6 = Node{6}; + auto make_series_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinarySeriesSplit{lhs, rhs}}; + }; + + auto make_parallel_split = [](BinarySPDecompositionTree const &lhs, + BinarySPDecompositionTree const &rhs) { + return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; + }; + + auto make_leaf = [](Node const &n) { return BinarySPDecompositionTree{n}; }; + SUBCASE("only node") { SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); - BinarySPDecompositionTree correct = make_leaf_node(n1); + BinarySPDecompositionTree correct = make_leaf(n1); CHECK(result == correct); } SUBCASE("only serial") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - SeriesSplit{n1, n2, n3}, + SeriesSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); BinarySPDecompositionTree correct = make_series_split( - make_leaf_node(n1), - make_series_split(make_leaf_node(n2), make_leaf_node(n3))); + make_leaf(n1), make_series_split(make_leaf(n2), make_leaf(n3))); CHECK(result == correct); } SUBCASE("only parallel") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{n1, n2, n3}, + ParallelSplit{{n1, n2, n3}}, }; BinarySPDecompositionTree result = @@ -62,20 +73,20 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nested") { SeriesParallelDecomposition input = SeriesParallelDecomposition{ - ParallelSplit{ + ParallelSplit{{ n1, - SeriesSplit{ + SeriesSplit{{ n2, n3, n3, n5, - }, - SeriesSplit{ + }}, + SeriesSplit{{ n6, n4, - }, + }}, n5, - }, + }}, }; BinarySPDecompositionTree result = diff --git a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc index 45f796c824..e5b9045739 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/get_series_parallel_decomposition.cc @@ -24,10 +24,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_series_parallel_decomposition(g); std::optional correct = - SeriesParallelDecomposition{ParallelSplit{ + SeriesParallelDecomposition{ParallelSplit{{ n.at(0), n.at(1), - }}; + }}}; CHECK(result == correct); } @@ -39,10 +39,10 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional result = get_series_parallel_decomposition(g); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ + SeriesParallelDecomposition{SeriesSplit{{ n.at(0), n.at(1), - }}; + }}}; CHECK(result == correct); } @@ -59,13 +59,13 @@ TEST_SUITE(FF_TEST_SUITE) { get_series_parallel_decomposition(g); std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ + SeriesSplit{{ n.at(0), - ParallelSplit{ + ParallelSplit{{ n.at(1), n.at(2), - }, - }, + }}, + }}, }; CHECK(result == correct); } @@ -86,20 +86,20 @@ TEST_SUITE(FF_TEST_SUITE) { }); std::optional correct = - SeriesParallelDecomposition{SeriesSplit{ + SeriesParallelDecomposition{SeriesSplit{{ n.at(0), - ParallelSplit{ - SeriesSplit{ + ParallelSplit{{ + SeriesSplit{{ n.at(1), n.at(3), - }, - SeriesSplit{ + }}, + SeriesSplit{{ n.at(2), n.at(4), - }, - }, + }}, + }}, n.at(5), - }}; + }}}; std::optional result = get_series_parallel_decomposition(g); @@ -122,16 +122,16 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ - ParallelSplit{ + SeriesSplit{{ + ParallelSplit{{ n.at(0), n.at(1), - }, - ParallelSplit{ + }}, + ParallelSplit{{ n.at(2), n.at(3), - }, - }, + }}, + }}, }; std::optional result = @@ -177,12 +177,12 @@ TEST_SUITE(FF_TEST_SUITE) { std::optional correct = SeriesParallelDecomposition{ - SeriesSplit{ + SeriesSplit{{ n.at(0), n.at(1), n.at(2), n.at(3), - }, + }}, }; std::optional result = get_series_parallel_decomposition(g); From a9d10d7c9c21bc79afa032f5139bbad507370705 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Mon, 7 Oct 2024 22:52:47 -0700 Subject: [PATCH 10/12] Fix `concretize_abstract_tensor_set_movement` (#1519) * Fix concretize_tensor_set_movement * Format --- .../get_tensor_set_movement_across_split.cc | 523 ++++++++++-------- .../series_parallel/series_parallel_splits.cc | 85 --- 2 files changed, 284 insertions(+), 324 deletions(-) delete mode 100644 lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 82210a138b..c66d533d0f 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -1,239 +1,284 @@ -// #include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" -// #include "compiler/machine_mapping/transitive_reduced_pcg.h" -// #include "compiler/series_parallel/pcg_binary_sp_decomposition.h" -// #include "pcg/machine_view.h" -// #include "pcg/parallel_computation_graph/parallel_computation_graph.h" -// #include -// "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" -// #include "utils/containers/get_only.h" -// #include -// #include "./cost_estimator_for_test.h" -// -// using namespace ::FlexFlow; -// -// TEST_SUITE(FF_TEST_SUITE) { -// TEST_CASE("get_tensor_set_movement_across_split") { -// ParallelComputationGraph pcg = empty_parallel_computation_graph(); -// -// ParallelTensorShape input_shape = -// ParallelTensorShape{ -// ParallelTensorDims{ -// FFOrdered{ -// ShardParallelDim{10, 2}, -// ShardParallelDim{12, 1}, -// }, -// ReplicaParallelDimSet{ -// SumDegree{1}, -// DiscardCopyDegree{1}, -// }, -// }, -// DataType::FLOAT, -// }; -// ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); -// -// ParallelLayerAttrs relu_attrs -// = ParallelLayerAttrs{ -// /*op_attrs=*/PCGOperatorAttrs{ -// ElementUnaryAttrs{ -// /*op_type=*/OperatorType::RELU, -// /*scalar=*/std::nullopt, -// }, -// }, -// /*name=*/std::nullopt, -// }; -// -// ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ -// /*shape=*/input_shape, -// /*sync_type=*/std::nullopt, -// /*initializer=*/std::nullopt, -// /*create_gradients=*/CreateGrad::YES, -// }; -// -// ParallelLayerAddedResult relu_1 -// = add_parallel_layer(pcg, -// relu_attrs, -// {get_only(input.outputs)}, -// {relu_output_attrs}); -// ParallelLayerAddedResult relu_2 -// = add_parallel_layer(pcg, -// relu_attrs, -// {get_only(relu_1.outputs)}, -// {relu_output_attrs}); -// -// MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); -// MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); -// MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); -// MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); -// -// SUBCASE("single edge across split") { -// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( -// make_pcg_series_split( -// make_pcg_leaf_node(input.parallel_layer), -// make_pcg_leaf_node(relu_1.parallel_layer)), -// make_pcg_leaf_node(relu_2.parallel_layer))); -// -// PartialMachineMapping pre_mapping = PartialMachineMapping{{ -// {relu_1.parallel_layer, pre_mv1}, -// }}; -// -// PartialMachineMapping post_mapping = PartialMachineMapping{{ -// {relu_2.parallel_layer, post_mv1}, -// }}; -// -// TensorSetMovement result = -// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), -// split, -// pre_mapping, -// post_mapping); -// TensorSetMovement correct = TensorSetMovement{ -// /*single_tensor_movements=*/{ -// SingleTensorMovement{ -// /*parallel_tensor_shape=*/input_shape, -// /*src_machine_views=*/{pre_mv1}, -// /*dst_machine_views=*/{post_mv1}, -// }, -// }, -// }; -// -// CHECK(result == correct); -// } -// -// SUBCASE("does not include edges removed by transitive reduction") { -// -// } -// -// SUBCASE("single tensor, multiple consumers across split") { -// ParallelLayerAddedResult relu_3 -// = add_parallel_layer(pcg, -// relu_attrs, -// {get_only(relu_1.outputs)}, -// {relu_output_attrs}); -// -// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( -// make_pcg_series_split( -// make_pcg_leaf_node(input.parallel_layer), -// make_pcg_leaf_node(relu_1.parallel_layer)), -// make_pcg_parallel_split( -// make_pcg_leaf_node(relu_2.parallel_layer), -// make_pcg_leaf_node(relu_3.parallel_layer)))); -// -// SUBCASE("consumers have same view") { -// PartialMachineMapping pre_mapping = PartialMachineMapping{{ -// {relu_1.parallel_layer, pre_mv1}, -// }}; -// -// PartialMachineMapping post_mapping = PartialMachineMapping{{ -// {relu_2.parallel_layer, post_mv1}, -// {relu_3.parallel_layer, post_mv1}, -// }}; -// -// TensorSetMovement result = -// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), -// split, -// pre_mapping, -// post_mapping); -// -// TensorSetMovement correct = TensorSetMovement{ -// /*single_tensor_movements=*/{ -// SingleTensorMovement{ -// /*parallel_tensor_shape=*/input_shape, -// /*src_machine_views=*/{pre_mv1}, -// /*dst_machine_views=*/{post_mv1}, -// }, -// }, -// }; -// -// CHECK(result == correct); -// } -// -// SUBCASE("consumers have different views") { -// PartialMachineMapping pre_mapping = PartialMachineMapping{{ -// {relu_1.parallel_layer, pre_mv1}, -// }}; -// -// PartialMachineMapping post_mapping = PartialMachineMapping{{ -// {relu_2.parallel_layer, post_mv1}, -// {relu_3.parallel_layer, post_mv2}, -// }}; -// -// TensorSetMovement result = -// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), -// split, -// pre_mapping, -// post_mapping); -// -// TensorSetMovement correct = TensorSetMovement{ -// /*single_tensor_movements=*/{ -// SingleTensorMovement{ -// /*parallel_tensor_shape=*/input_shape, -// /*src_machine_views=*/{pre_mv1}, -// /*dst_machine_views=*/{post_mv1, post_mv2}, -// }, -// }, -// }; -// -// CHECK(result == correct); -// } -// } -// -// SUBCASE("multiple tensors, multiple consumers across split") { -// ParallelLayerAddedResult relu_3 -// = add_parallel_layer(pcg, -// relu_attrs, -// {get_only(input.outputs)}, -// {relu_output_attrs}); -// -// ParallelLayerAddedResult relu_4 -// = add_parallel_layer(pcg, -// relu_attrs, -// // relu's don't have two inputs, but for the -// purposes of this test it's fine. -// {get_only(relu_1.outputs), -// get_only(relu_3.outputs)}, {relu_output_attrs}); -// -// PartialMachineMapping pre_mapping = PartialMachineMapping{{ -// {relu_1.parallel_layer, pre_mv1}, -// {relu_3.parallel_layer, pre_mv2}, -// }}; -// -// PartialMachineMapping post_mapping = PartialMachineMapping{{ -// {relu_2.parallel_layer, post_mv1}, -// {relu_4.parallel_layer, post_mv2}, -// }}; -// -// PCGBinarySeriesSplit split = require_series(make_pcg_series_split( -// make_pcg_series_split( -// make_pcg_leaf_node(input.parallel_layer), -// make_pcg_parallel_split( -// make_pcg_leaf_node(relu_1.parallel_layer), -// make_pcg_leaf_node(relu_3.parallel_layer))), -// make_pcg_parallel_split( -// make_pcg_leaf_node(relu_2.parallel_layer), -// make_pcg_leaf_node(relu_4.parallel_layer)))); -// -// TensorSetMovement result = -// get_tensor_set_movement_across_split(pcg_get_transitive_reduction(pcg), -// split, -// pre_mapping, -// post_mapping); -// -// -// TensorSetMovement correct = TensorSetMovement{ -// /*single_tensor_movements=*/{ -// SingleTensorMovement{ -// /*parallel_tensor_shape=*/input_shape, -// /*src_machine_views=*/{pre_mv1}, -// /*dst_machine_views=*/{post_mv1, post_mv2}, -// }, -// SingleTensorMovement{ -// /*parallel_tensor_shape=*/input_shape, -// /*src_machine_views=*/{pre_mv2}, -// /*dst_machine_views=*/{post_mv2}, -// }, -// }, -// }; -// -// CHECK(result == correct); -// } -// } -// } +#include "compiler/machine_mapping/get_tensor_set_movement_across_split.h" +#include "./cost_estimator_for_test.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_tensor_set_movement_across_split") { + auto make_pcg_series_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinarySeriesSplit{lhs, rhs}}; + }; + + auto make_pcg_parallel_split = [](PCGBinarySPDecomposition const &lhs, + PCGBinarySPDecomposition const &rhs) { + return PCGBinarySPDecomposition{PCGBinaryParallelSplit{lhs, rhs}}; + }; + + auto make_pcg_leaf_node = [](parallel_layer_guid_t const &l) { + return PCGBinarySPDecomposition{l}; + }; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + ParallelTensorShape input_shape = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{ + ShardParallelDim{10, 2}, + ShardParallelDim{12, 1}, + }, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + ParallelLayerAddedResult input = pcg_add_input_layer(pcg, input_shape); + + ParallelLayerAttrs relu_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{ + ElementUnaryAttrs{ + /*op_type=*/OperatorType::RELU, + /*scalar=*/std::nullopt, + }, + }, + /*name=*/std::nullopt, + }; + + ParallelTensorAttrs relu_output_attrs = ParallelTensorAttrs{ + /*shape=*/input_shape, + /*sync_type=*/std::nullopt, + /*initializer=*/std::nullopt, + /*create_gradients=*/CreateGrad::YES, + }; + + ParallelLayerAddedResult relu_1 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + ParallelLayerAddedResult relu_2 = add_parallel_layer( + pcg, relu_attrs, {get_only(relu_1.outputs)}, {relu_output_attrs}); + + MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); + MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); + MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); + MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); + + SUBCASE("single edge across split") { + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_leaf_node(relu_2.parallel_layer), + }; + + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + {BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv1}, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{}}, + post_mv1, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1}, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("does not include edges removed by transitive reduction") {} + + SUBCASE("single tensor, multiple consumers across split") { + ParallelLayerAddedResult relu_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(relu_1.outputs)}, {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_leaf_node(relu_1.parallel_layer)), + make_pcg_parallel_split(make_pcg_leaf_node(relu_2.parallel_layer), + make_pcg_leaf_node(relu_3.parallel_layer)), + }; + + SUBCASE("consumers have same view") { + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv1, + }, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + post_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + post_mv1, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1}, + }, + }, + }; + + CHECK(result == correct); + } + + SUBCASE("consumers have different views") { + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv1, + }, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + post_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + post_mv2, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), + split, + pre_mapping, + post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1, post_mv2}, + }, + }, + }; + + CHECK(result == correct); + } + } + + SUBCASE("multiple tensors, multiple consumers across split") { + ParallelLayerAddedResult relu_3 = add_parallel_layer( + pcg, relu_attrs, {get_only(input.outputs)}, {relu_output_attrs}); + + ParallelLayerAddedResult relu_4 = add_parallel_layer( + pcg, + relu_attrs, + // relu's don't have two inputs, but for the + // purposes of this test it's fine. + {get_only(relu_1.outputs), get_only(relu_3.outputs)}, + {relu_output_attrs}); + + PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ + make_pcg_series_split(make_pcg_leaf_node(input.parallel_layer), + make_pcg_parallel_split( + make_pcg_leaf_node(relu_1.parallel_layer), + make_pcg_leaf_node(relu_3.parallel_layer))), + make_pcg_parallel_split(make_pcg_leaf_node(relu_2.parallel_layer), + make_pcg_leaf_node(relu_4.parallel_layer)), + }; + + auto pre_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + pre_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + pre_mv2, + }, + }}; + + auto post_mapping = ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + post_mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + post_mv2, + }, + }}; + + TensorSetMovement result = get_tensor_set_movement_across_split( + pcg_get_transitive_reduction(pcg), split, pre_mapping, post_mapping); + + TensorSetMovement correct = TensorSetMovement{ + /*single_tensor_movements=*/{ + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv1}, + /*dst_machine_views=*/{post_mv1, post_mv2}, + }, + SingleTensorMovement{ + /*parallel_tensor_shape=*/input_shape, + /*src_machine_views=*/{pre_mv2}, + /*dst_machine_views=*/{post_mv2}, + }, + }, + }; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc deleted file mode 100644 index 7d36371e49..0000000000 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_splits.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "utils/graph/series_parallel/series_parallel_splits.h" -// #include "utils/fmt/unordered_multiset.h" -// #include "utils/fmt/variant.h" -// #include "utils/fmt/vector.h" -// #include "utils/hash-utils.h" -// #include "utils/hash/unordered_multiset.h" -// #include "utils/hash/vector.h" -// -// namespace FlexFlow { -// -// SeriesSplit::SeriesSplit( -// std::vector> const &children) -// : children(children) {} -// -// SeriesSplit::SeriesSplit( -// std::initializer_list> const &children) -// : children(children) {} -// -// bool SeriesSplit::operator==(SeriesSplit const &other) const { -// return this->tie() == other.tie(); -// } -// -// bool SeriesSplit::operator!=(SeriesSplit const &other) const { -// return this->tie() != other.tie(); -// } -// -// SeriesSplit::Tie SeriesSplit::tie() const { -// return std::tie(this->children); -// } -// -// std::string format_as(SeriesSplit const &split) { -// return fmt::format("", split.children); -// } -// -// std::ostream &operator<<(std::ostream &s, SeriesSplit const &split) { -// return s << fmt::to_string(split); -// } -// -// ParallelSplit::ParallelSplit( -// std::unordered_multiset> const &children) -// : children(children) {} -// -// ParallelSplit::ParallelSplit( -// std::initializer_list> const &children) -// : children(children) {} -// -// bool ParallelSplit::operator==(ParallelSplit const &other) const { -// return this->tie() == other.tie(); -// } -// -// bool ParallelSplit::operator!=(ParallelSplit const &other) const { -// return this->tie() != other.tie(); -// } -// -// ParallelSplit::Tie ParallelSplit::tie() const { -// return std::tie(this->children); -// } -// -// std::string format_as(ParallelSplit const &split) { -// return fmt::format("", split.children); -// } -// -// std::ostream &operator<<(std::ostream &s, ParallelSplit const &split) { -// return s << fmt::to_string(split); -// } -// -// } // namespace FlexFlow -// -// namespace std { -// -// size_t hash<::FlexFlow::SeriesSplit>::operator()( -// ::FlexFlow::SeriesSplit const &s) const { -// size_t result = 0; -// ::FlexFlow::hash_combine(result, s.children); -// return result; -// } -// -// size_t hash<::FlexFlow::ParallelSplit>::operator()( -// ::FlexFlow::ParallelSplit const &s) const { -// size_t result = 0; -// ::FlexFlow::hash_combine(result, s.children); -// return result; -// } -// -// } // namespace std From 65c3911d0a3b9c225086f9d18ebc32720a89e6c2 Mon Sep 17 00:00:00 2001 From: Marsella8 <45826022+Marsella8@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:53:24 -0700 Subject: [PATCH 11/12] Unordered StridedRectangle, `get_allowed_machine_views` (#1458) * containers helper functions * Additional support for unordered_multiset * format fix * Unordered Machine Mapping and adjacent changes * formatting * Minor fixes * Update to StridedRectangle interface * Minor updates * added get_allowed_machine_views * formatting * minor fix * Added StartInvariantMachineView * formatting * Containers fix * Implemented tensor to machine view injection * small refactor * formatting * Cleaning Up * Formatting fix * new machine-view interface * update to allowed machine views * PR review fixes * update to machine view and getting allowed machine view to match new interface * formatting * minor fix * PR fixes * PR fixes * machineview interface change * Minor PR fixes * .cc machine view fixes + added StartInvariantMachineView * minor PR fixes * minor fixes * Post-merge fixes * Format --------- Co-authored-by: Pietro Max Marsella Co-authored-by: Colin Unger Co-authored-by: Colin Unger --- .../include/compiler/allowed_machine_views.h | 21 ++ ...lel_layer_guid_oblivious_machine_mapping.h | 1 + .../src/compiler/allowed_machine_views.cc | 122 +++++++ .../test/src/allowed_machine_views.cc | 104 ++++++ .../get_optimal_machine_mapping.cc | 31 +- .../get_tensor_set_movement_across_split.cc | 63 +++- .../machine_mapping/machine_mapping.cc | 84 ++++- .../machine_mapping/machine_mapping_result.cc | 103 +++++- .../include/local-execution/cost_estimate.h | 1 - .../src/local_cost_estimator.cc | 1 + .../test/src/test_local_cost_estimator.cc | 134 ++++---- lib/op-attrs/include/op-attrs/parallel_dim.h | 1 + .../parallel_tensor_dim_idx_t.variant.toml | 20 ++ .../include/op-attrs/parallel_tensor_shape.h | 8 + lib/op-attrs/src/op-attrs/parallel_dim.cc | 14 + .../src/op-attrs/parallel_tensor_shape.cc | 29 ++ lib/pcg/include/pcg/device_id.h | 1 + .../pcg/machine_space_coordinate.struct.toml | 26 ++ lib/pcg/include/pcg/machine_space_offset.h | 14 + .../pcg/machine_space_offset.struct.toml | 26 ++ lib/pcg/include/pcg/machine_specification.h | 21 +- .../machine_specification_dimension.enum.toml | 14 + lib/pcg/include/pcg/machine_view.h | 69 ++-- lib/pcg/include/pcg/machine_view.struct.toml | 18 +- .../pcg/machine_view_dimension.struct.toml | 24 ++ .../pcg/multi_dimensional_stride.struct.toml | 25 ++ lib/pcg/include/pcg/operator_task_space.h | 22 ++ .../pcg/operator_task_space.struct.toml | 23 ++ .../pcg/start_invariant_machine_view.h | 47 +++ .../start_invariant_machine_view.struct.toml | 29 ++ ...ize_t.struct.toml => stride_t.struct.toml} | 2 +- lib/pcg/include/pcg/strided_rectangle.h | 17 - .../include/pcg/strided_rectangle.struct.toml | 19 -- lib/pcg/include/pcg/strided_rectangle_side.h | 15 - .../pcg/strided_rectangle_side.struct.toml | 22 -- .../pcg/task_space_coordinate.struct.toml | 23 ++ lib/pcg/src/pcg/device_id.cc | 23 +- lib/pcg/src/pcg/machine_space_offset.cc | 25 ++ lib/pcg/src/pcg/machine_specification.cc | 53 +++ lib/pcg/src/pcg/machine_view.cc | 210 ++++++------ lib/pcg/src/pcg/operator_task_space.cc | 38 +++ .../src/pcg/start_invariant_machine_view.cc | 86 +++++ lib/pcg/src/pcg/strided_rectangle.cc | 35 -- lib/pcg/src/pcg/strided_rectangle_side.cc | 17 - .../computation_graph_builder.cc} | 2 +- lib/pcg/test/src/pcg/machine_specification.cc | 54 ++++ lib/pcg/test/src/pcg/machine_view.cc | 301 ++++++++++++++++++ lib/pcg/test/src/pcg/operator_task_space.cc | 66 ++++ .../src/pcg/start_invariant_machine_view.cc | 229 +++++++++++++ lib/pcg/test/src/test_machine_view.cc | 76 ----- lib/pcg/test/src/test_strided_rectangle.cc | 39 --- .../utils/containers/cartesian_product.h | 9 +- lib/utils/include/utils/containers/filter.h | 8 + lib/utils/include/utils/containers/foldl.h | 72 +++++ .../utils/containers/get_all_assignments.h | 3 +- .../get_all_permutations_with_repetition.h | 50 +++ .../containers/map_from_keys_and_values.h | 29 ++ lib/utils/include/utils/containers/range.h | 13 + .../include/utils/containers/replicate.h | 15 + lib/utils/include/utils/containers/scanl.h | 77 +++++ .../include/utils/containers/transform.h | 6 +- .../include/utils/containers/without_order.h | 15 - lib/utils/include/utils/containers/zip.h | 12 + .../include/utils/fmt/unordered_multiset.h | 1 - lib/utils/include/utils/fmt/unordered_set.h | 1 - lib/utils/src/utils/containers/foldl.cc | 1 + .../utils/containers/get_all_assignments.cc | 11 + lib/utils/src/utils/containers/range.cc | 26 ++ lib/utils/src/utils/containers/replicate.cc | 1 + lib/utils/src/utils/containers/scanl.cc | 1 + .../src/utils/containers/without_order.cc | 1 - lib/utils/test/src/test_containers.cc | 2 +- .../src/utils/containers/cartesian_product.cc | 47 +-- lib/utils/test/src/utils/containers/filter.cc | 10 + lib/utils/test/src/utils/containers/foldl.cc | 47 +++ .../get_all_permutations_with_repetition.cc | 75 +++++ lib/utils/test/src/utils/containers/range.cc | 54 ++++ .../test/src/utils/containers/replicate.cc | 25 ++ lib/utils/test/src/utils/containers/scanl.cc | 71 +++++ .../src/utils/containers/without_order.cc | 15 - 80 files changed, 2588 insertions(+), 558 deletions(-) create mode 100644 lib/compiler/include/compiler/allowed_machine_views.h create mode 100644 lib/compiler/src/compiler/allowed_machine_views.cc create mode 100644 lib/compiler/test/src/allowed_machine_views.cc create mode 100644 lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml create mode 100644 lib/op-attrs/src/op-attrs/parallel_dim.cc create mode 100644 lib/pcg/include/pcg/machine_space_coordinate.struct.toml create mode 100644 lib/pcg/include/pcg/machine_space_offset.h create mode 100644 lib/pcg/include/pcg/machine_space_offset.struct.toml create mode 100644 lib/pcg/include/pcg/machine_specification_dimension.enum.toml create mode 100644 lib/pcg/include/pcg/machine_view_dimension.struct.toml create mode 100644 lib/pcg/include/pcg/multi_dimensional_stride.struct.toml create mode 100644 lib/pcg/include/pcg/operator_task_space.h create mode 100644 lib/pcg/include/pcg/operator_task_space.struct.toml create mode 100644 lib/pcg/include/pcg/start_invariant_machine_view.h create mode 100644 lib/pcg/include/pcg/start_invariant_machine_view.struct.toml rename lib/pcg/include/pcg/{side_size_t.struct.toml => stride_t.struct.toml} (87%) delete mode 100644 lib/pcg/include/pcg/strided_rectangle.h delete mode 100644 lib/pcg/include/pcg/strided_rectangle.struct.toml delete mode 100644 lib/pcg/include/pcg/strided_rectangle_side.h delete mode 100644 lib/pcg/include/pcg/strided_rectangle_side.struct.toml create mode 100644 lib/pcg/include/pcg/task_space_coordinate.struct.toml create mode 100644 lib/pcg/src/pcg/machine_space_offset.cc create mode 100644 lib/pcg/src/pcg/machine_specification.cc create mode 100644 lib/pcg/src/pcg/operator_task_space.cc create mode 100644 lib/pcg/src/pcg/start_invariant_machine_view.cc delete mode 100644 lib/pcg/src/pcg/strided_rectangle.cc delete mode 100644 lib/pcg/src/pcg/strided_rectangle_side.cc rename lib/pcg/test/src/{test_computation_graph_builder.cc => pcg/computation_graph_builder.cc} (100%) create mode 100644 lib/pcg/test/src/pcg/machine_specification.cc create mode 100644 lib/pcg/test/src/pcg/machine_view.cc create mode 100644 lib/pcg/test/src/pcg/operator_task_space.cc create mode 100644 lib/pcg/test/src/pcg/start_invariant_machine_view.cc delete mode 100644 lib/pcg/test/src/test_machine_view.cc delete mode 100644 lib/pcg/test/src/test_strided_rectangle.cc create mode 100644 lib/utils/include/utils/containers/foldl.h create mode 100644 lib/utils/include/utils/containers/get_all_permutations_with_repetition.h create mode 100644 lib/utils/include/utils/containers/map_from_keys_and_values.h create mode 100644 lib/utils/include/utils/containers/range.h create mode 100644 lib/utils/include/utils/containers/replicate.h create mode 100644 lib/utils/include/utils/containers/scanl.h delete mode 100644 lib/utils/include/utils/containers/without_order.h create mode 100644 lib/utils/src/utils/containers/foldl.cc create mode 100644 lib/utils/src/utils/containers/range.cc create mode 100644 lib/utils/src/utils/containers/replicate.cc create mode 100644 lib/utils/src/utils/containers/scanl.cc delete mode 100644 lib/utils/src/utils/containers/without_order.cc create mode 100644 lib/utils/test/src/utils/containers/foldl.cc create mode 100644 lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc create mode 100644 lib/utils/test/src/utils/containers/range.cc create mode 100644 lib/utils/test/src/utils/containers/replicate.cc create mode 100644 lib/utils/test/src/utils/containers/scanl.cc delete mode 100644 lib/utils/test/src/utils/containers/without_order.cc diff --git a/lib/compiler/include/compiler/allowed_machine_views.h b/lib/compiler/include/compiler/allowed_machine_views.h new file mode 100644 index 0000000000..9bb73fd1a9 --- /dev/null +++ b/lib/compiler/include/compiler/allowed_machine_views.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H +#define _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H + +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/operator_task_space.dtg.h" + +namespace FlexFlow { + +bool is_valid_machine_view(MachineView const &mv, + OperatorTaskSpace const &task, + MachineSpecification const &ms); + +std::unordered_set + get_allowed_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType device_type); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h index accd96af4c..cb3af9c689 100644 --- a/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_PARALLEL_LAYER_GUID_OBLIVIOUS_MACHINE_MAPPING_H #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h" +#include namespace FlexFlow { diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/allowed_machine_views.cc new file mode 100644 index 0000000000..1c226f79b0 --- /dev/null +++ b/lib/compiler/src/compiler/allowed_machine_views.cc @@ -0,0 +1,122 @@ +#include "compiler/allowed_machine_views.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.h" +#include "pcg/multi_dimensional_stride.dtg.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/all_of.h" +#include "utils/containers/cartesian_product.h" +#include "utils/containers/extend.h" +#include "utils/containers/filter.h" +#include "utils/containers/get_all_permutations_with_repetition.h" +#include "utils/containers/map_from_keys_and_values.h" +#include "utils/containers/product.h" +#include "utils/containers/range.h" +#include "utils/containers/replicate.h" +#include "utils/containers/sorted.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip.h" +#include "utils/overload.h" + +namespace FlexFlow { + +bool is_valid_machine_view(MachineView const &mv, + OperatorTaskSpace const &task, + MachineSpecification const &ms) { + std::optional maximum_device_coord = + get_machine_space_coordinate( + task, mv, get_task_space_maximum_coordinate(task), ms); + return maximum_device_coord.has_value(); +} + +/* + * Generates a set of candidate `MachineView`s. + * The returned set includes all valid machine views, and might contain invalid + * ones. This function should not be used externally (see + * `get_allowed_machine_views` instead). There is no guarantee that a non-empty + * returned set contains a valid machine view (i.e. it's possible for all + * the returned `MachineView`s to be invalid) + */ +static std::unordered_set + get_candidate_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType const &device_type) { + + auto get_max_stride_upper_bound = [](std::vector const &tensor_dims, + int total_devices) -> int { + int min_num_devices_with_full_stride_volume = product(transform( + tensor_dims, [](int const &num_devices) { return num_devices - 1; })); + return std::ceil(total_devices / min_num_devices_with_full_stride_volume); + }; + + auto candidate_strides = [&](std::vector const &tensor_dims, + int total_devices) + -> std::unordered_multiset { + int max_stride_upper_bound = + get_max_stride_upper_bound(tensor_dims, total_devices); + + std::vector single_stride_range = + transform(range(1, max_stride_upper_bound + 1), + [](int stride) { return stride_t{stride}; }); + std::unordered_multiset> raw_stride_vectors = + cartesian_product(replicate(tensor_dims.size(), single_stride_range)); + std::unordered_multiset strides = + transform(raw_stride_vectors, [](auto const &stride_vec) { + return MultiDimensionalStride{stride_vec}; + }); + return strides; + }; + + auto candidate_starts = [](MachineSpecification const &ms, + DeviceType const &device_type) { + std::unordered_set result; + for (int node_idx : range(ms.num_nodes)) { + for (int device_idx : range(get_num_devices_per_node(ms, device_type))) { + result.insert( + MachineSpaceCoordinate{node_idx, device_idx, device_type}); + } + } + return result; + }; + + auto candidate_dimensions = [](OperatorTaskSpace const &task) { + std::unordered_set options = { + MachineSpecificationDimension::INTER_NODE, + MachineSpecificationDimension::INTRA_NODE}; + return get_all_permutations_with_repetition(options, num_dims(task)); + }; + + std::vector tensor_dims = task.degrees; + int total_devices = get_num_devices(machine_spec, device_type); + + std::unordered_set machine_views; + + for (MultiDimensionalStride const &strides : + candidate_strides(tensor_dims, total_devices)) { + for (MachineSpaceCoordinate start : + candidate_starts(machine_spec, device_type)) { + for (std::vector const &dims : + candidate_dimensions(task)) { + machine_views.insert( + machine_view_from_strides_and_machine_spec_dimensions( + start, strides.raw_strides, dims)); + } + } + } + return machine_views; +} + +std::unordered_set + get_allowed_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType device_type) { + + std::unordered_set views = + get_candidate_machine_views(machine_spec, task, device_type); + return filter(views, [&](MachineView const &mv) { + return is_valid_machine_view(mv, task, machine_spec); + }); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/allowed_machine_views.cc b/lib/compiler/test/src/allowed_machine_views.cc new file mode 100644 index 0000000000..936894ad2d --- /dev/null +++ b/lib/compiler/test/src/allowed_machine_views.cc @@ -0,0 +1,104 @@ +#include "compiler/allowed_machine_views.h" +#include "doctest/doctest.h" +#include "utils/containers/extend.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/zip.h" +#include "utils/fmt/unordered_set.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_allowed_machine_views") { + + SUBCASE("1 degree of parallelism") { + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/1, + /*num_cpus_per_node=*/5, + /*num_gpus_per_node=*/5, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + + OperatorTaskSpace task = OperatorTaskSpace{{3}}; + + std::unordered_set correct = { + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/2, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}, + }, + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + + SUBCASE("2 degrees of parallelism") { + + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/3, + /*num_cpus_per_node=*/3, + /*num_gpus_per_node=*/3, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + OperatorTaskSpace task = OperatorTaskSpace{{2, 3}}; + + auto make_2d_view = [&](int start_node_idx, + int start_device_idx, + int stride1, + int stride2, + MachineSpecificationDimension m1, + MachineSpecificationDimension m2) { + return MachineView{ + MachineSpaceCoordinate{ + start_node_idx, start_device_idx, DeviceType::GPU}, + {MachineViewDimension{stride_t{stride1}, m1}, + MachineViewDimension{stride_t{stride2}, m2}}, + }; + }; + + auto intra = MachineSpecificationDimension::INTRA_NODE; + auto inter = MachineSpecificationDimension::INTER_NODE; + std::unordered_set correct = { + make_2d_view(0, 0, /*stride1=*/1, /*stride2=*/1, inter, intra), + make_2d_view(1, 0, /*stride1=*/1, /*stride2=*/1, inter, intra), + make_2d_view(0, 0, /*stride1=*/2, /*stride2=*/1, inter, intra), + + make_2d_view(0, 0, /*stride1=*/1, /*stride2=*/1, intra, inter), + make_2d_view(0, 1, /*stride1=*/1, /*stride2=*/1, intra, inter), + make_2d_view(0, 0, /*stride1=*/2, /*stride2=*/1, intra, inter), + }; + + std::unordered_set result = + get_allowed_machine_views(ms, task, DeviceType::GPU); + + CHECK(correct == result); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 0a874948e4..a0d06fe930 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -42,8 +42,35 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; - MachineView mv1 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(2)); - MachineView mv2 = make_1d_machine_view(gpu_id_t(1), gpu_id_t(3)); + MachineView mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; MachineSpecification full_machine_spec = MachineSpecification{ /*num_nodes=*/2, diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index c66d533d0f..e22f715d82 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -64,10 +64,65 @@ TEST_SUITE(FF_TEST_SUITE) { ParallelLayerAddedResult relu_2 = add_parallel_layer( pcg, relu_attrs, {get_only(relu_1.outputs)}, {relu_output_attrs}); - MachineView pre_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1}); - MachineView pre_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{2}); - MachineView post_mv1 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{3}); - MachineView post_mv2 = make_1d_machine_view(gpu_id_t{0}, gpu_id_t{4}); + MachineView pre_mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView pre_mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView post_mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{3}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView post_mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{4}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; SUBCASE("single edge across split") { PCGBinarySeriesSplit split = PCGBinarySeriesSplit{ diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc index 6b16a54c1f..221cca3ae1 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping.cc @@ -8,33 +8,89 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("combine_disjoint_mappings(MachineMapping, MachineMappping)") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + MachineMapping machine_mapping_0 = MachineMapping({ - {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t{Node{0}}, machine_view_0}, }); MachineMapping machine_mapping_1 = MachineMapping({ - {parallel_layer_guid_t(Node(1)), machine_view_1}, - }); - MachineMapping correct = MachineMapping({ - {parallel_layer_guid_t(Node(0)), machine_view_0}, - {parallel_layer_guid_t(Node(1)), machine_view_1}, + {parallel_layer_guid_t{Node{1}}, machine_view_1}, }); + MachineMapping correct = MachineMapping{{ + {parallel_layer_guid_t{Node{0}}, machine_view_0}, + {parallel_layer_guid_t{Node{1}}, machine_view_1}, + }}; MachineMapping result = combine_disjoint_mappings(machine_mapping_0, machine_mapping_1); CHECK(result == correct); } TEST_CASE("nodes_are_disjoint(MachineMapping, MachineMappping)") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + MachineMapping machine_mapping_0 = MachineMapping({ - {parallel_layer_guid_t(Node(0)), machine_view_0}, + {parallel_layer_guid_t{Node{0}}, machine_view_0}, }); SUBCASE("nodes are disjoint") { MachineMapping machine_mapping_1 = MachineMapping({ - {parallel_layer_guid_t(Node(1)), machine_view_1}, + {parallel_layer_guid_t{Node{1}}, machine_view_1}, }); bool correct = true; @@ -44,8 +100,8 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("nodes are not disjoint") { MachineMapping machine_mapping_1 = MachineMapping({ - {parallel_layer_guid_t(Node(0)), machine_view_0}, - {parallel_layer_guid_t(Node(1)), machine_view_1}, + {parallel_layer_guid_t{Node{0}}, machine_view_0}, + {parallel_layer_guid_t{Node{1}}, machine_view_1}, }); bool correct = false; bool result = nodes_are_disjoint(machine_mapping_0, machine_mapping_1); diff --git a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc index 254d6b2784..73b921fc98 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/machine_mapping_result.cc @@ -6,8 +6,35 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("series_combine") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; float pre_cost = 2.0; MachineMappingResult pre = MachineMappingResult{ @@ -49,7 +76,7 @@ TEST_SUITE(FF_TEST_SUITE) { float comm_cost = 3.0; - SUBCASE("pre is infeasbile") { + SUBCASE("pre is infeasible") { MachineMappingResult result = series_combine( comm_cost, infeasible, post, ParallelSplitTransformation::LthenR); MachineMappingResult correct = infeasible; @@ -57,7 +84,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("post is infeasbile") { + SUBCASE("post is infeasible") { MachineMappingResult result = series_combine( comm_cost, pre, infeasible, ParallelSplitTransformation::LthenR); MachineMappingResult correct = infeasible; @@ -160,8 +187,35 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("parallel_combine") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; MachineMappingResult lhs = MachineMappingResult{ FeasibleMachineMappingResult{ @@ -199,14 +253,14 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingResult infeasible = infeasible_machine_mapping_result(); - SUBCASE("lhs is infeasbile") { + SUBCASE("lhs is infeasible") { MachineMappingResult result = parallel_combine(infeasible, rhs); MachineMappingResult correct = infeasible; CHECK(result == correct); } - SUBCASE("rhs is infeasbile") { + SUBCASE("rhs is infeasible") { MachineMappingResult result = parallel_combine(lhs, infeasible); MachineMappingResult correct = infeasible; @@ -256,8 +310,35 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("minimize_runtime") { - MachineView machine_view_0 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(1)); - MachineView machine_view_1 = make_1d_machine_view(gpu_id_t(0), gpu_id_t(2)); + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; MachineMappingResult faster = MachineMappingResult{ FeasibleMachineMappingResult{ @@ -295,7 +376,7 @@ TEST_SUITE(FF_TEST_SUITE) { MachineMappingResult infeasible = infeasible_machine_mapping_result(); - SUBCASE("lhs is infeasbile") { + SUBCASE("lhs is infeasible") { MachineMappingResult result = minimize_runtime(infeasible, slower); MachineMappingResult correct = slower; diff --git a/lib/local-execution/include/local-execution/cost_estimate.h b/lib/local-execution/include/local-execution/cost_estimate.h index 31503e0da9..7020089ccf 100644 --- a/lib/local-execution/include/local-execution/cost_estimate.h +++ b/lib/local-execution/include/local-execution/cost_estimate.h @@ -8,7 +8,6 @@ #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" - namespace FlexFlow { struct ICostEstimator { diff --git a/lib/local-execution/src/local_cost_estimator.cc b/lib/local-execution/src/local_cost_estimator.cc index b42aec10bb..6d82e26511 100644 --- a/lib/local-execution/src/local_cost_estimator.cc +++ b/lib/local-execution/src/local_cost_estimator.cc @@ -5,6 +5,7 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/pcg_operator_attrs.h" #include "pcg/computation_graph_builder.h" +#include "pcg/machine_view.dtg.h" #include "pcg/parallel_tensor_attrs.h" #include "utils/containers/transform.h" diff --git a/lib/local-execution/test/src/test_local_cost_estimator.cc b/lib/local-execution/test/src/test_local_cost_estimator.cc index 4c01df53e9..da3af6e3ad 100644 --- a/lib/local-execution/test/src/test_local_cost_estimator.cc +++ b/lib/local-execution/test/src/test_local_cost_estimator.cc @@ -1,77 +1,79 @@ -#include "doctest/doctest.h" -#include "kernels/local_cuda_allocator.h" -#include "kernels/managed_per_device_ff_handle.h" -#include "local-execution/local_cost_estimator.h" -#include "op-attrs/ops/attention.h" -#include "op-attrs/parallel_tensor_shape.h" -#include "pcg/computation_graph_builder.h" -#include "test_utils.h" +// #include "doctest/doctest.h" +// #include "kernels/local_cuda_allocator.h" +// #include "kernels/managed_per_device_ff_handle.h" +// #include "local-execution/local_cost_estimator.h" +// #include "op-attrs/ops/attention.h" +// #include "op-attrs/parallel_tensor_shape.h" +// #include "pcg/computation_graph_builder.h" +// #include "test_utils.h" -using namespace ::FlexFlow; +// using namespace ::FlexFlow; -TEST_SUITE(FF_CUDA_TEST_SUITE) { - TEST_CASE("Local Cost Estimator") { - // local backing initialization - ManagedPerDeviceFFHandle managed_handle{}; +// TEST_SUITE(FF_CUDA_TEST_SUITE) { +// TEST_CASE("Local Cost Estimator") { +// // local backing initialization +// ManagedPerDeviceFFHandle managed_handle{}; - RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{ - DeviceSpecific::create(managed_handle.raw_handle()), - EnableProfiling::YES, - ProfilingSettings{/*warmup_iters=*/0, - /*measure_iters=*/1}}; +// RuntimeArgConfig runtime_arg_config = RuntimeArgConfig{ +// DeviceSpecific::create(managed_handle.raw_handle()), +// EnableProfiling::YES, +// ProfilingSettings{/*warmup_iters=*/0, +// /*measure_iters=*/1}}; - LocalCostEstimator cost_estimator = LocalCostEstimator{runtime_arg_config}; +// LocalCostEstimator cost_estimator = +// LocalCostEstimator{runtime_arg_config}; - SUBCASE("Estimate cost -- Attention Op") { - int embed_dim = 32; - int num_heads = 10; - MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ - /*embed_dim=*/embed_dim, - /*num_heads=*/num_heads, - /*kdim=*/embed_dim, - /*vdim=*/embed_dim, - /*dropout=*/0.0, - /*bias=*/true, - /*add_bias_kv=*/false, - /*add_zero_attn=*/false, - }; +// SUBCASE("Estimate cost -- Attention Op") { +// int embed_dim = 32; +// int num_heads = 10; +// MultiHeadAttentionAttrs attrs = MultiHeadAttentionAttrs{ +// /*embed_dim=*/embed_dim, +// /*num_heads=*/num_heads, +// /*kdim=*/embed_dim, +// /*vdim=*/embed_dim, +// /*dropout=*/0.0, +// /*bias=*/true, +// /*add_bias_kv=*/false, +// /*add_zero_attn=*/false, +// }; - size_t batch_size = 40; - size_t seq_len = 48; - size_t feature_size = 36; +// size_t batch_size = 40; +// size_t seq_len = 48; +// size_t feature_size = 36; - DataType dtype = DataType::FLOAT; - ParallelTensorShape inputs_shape = lift_to_parallel(TensorShape{ - TensorDims{FFOrdered{batch_size, seq_len, feature_size}}, - DataType::FLOAT, - }); +// DataType dtype = DataType::FLOAT; +// ParallelTensorShape inputs_shape = lift_to_parallel(TensorShape{ +// TensorDims{FFOrdered{batch_size, seq_len, feature_size}}, +// DataType::FLOAT, +// }); - ParallelTensorShape weights_shape = throw_if_unexpected( - get_weights_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); - ParallelTensorAttrs weight_attrs = - ParallelTensorAttrs{weights_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - CreateGrad::YES}; +// ParallelTensorShape weights_shape = throw_if_unexpected( +// get_weights_shape(attrs, inputs_shape, inputs_shape, +// inputs_shape)); +// ParallelTensorAttrs weight_attrs = +// ParallelTensorAttrs{weights_shape, +// /*sync_type=*/std::nullopt, +// /*initializer=*/std::nullopt, +// CreateGrad::YES}; - ParallelTensorShape output_shape = throw_if_unexpected( - get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); - ParallelTensorAttrs output_attrs = - ParallelTensorAttrs{output_shape, - /*sync_type=*/std::nullopt, - /*initializer=*/std::nullopt, - CreateGrad::YES}; +// ParallelTensorShape output_shape = throw_if_unexpected( +// get_output_shape(attrs, inputs_shape, inputs_shape, inputs_shape)); +// ParallelTensorAttrs output_attrs = +// ParallelTensorAttrs{output_shape, +// /*sync_type=*/std::nullopt, +// /*initializer=*/std::nullopt, +// CreateGrad::YES}; - CostDetails result = cost_estimator.estimate_cost( - PCGOperatorAttrs{attrs}, - std::vector{ - inputs_shape, inputs_shape, inputs_shape}, - std::vector{weight_attrs}, - std::vector{output_attrs}, - make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1})); +// CostDetails result = cost_estimator.estimate_cost( +// PCGOperatorAttrs{attrs}, +// std::vector{ +// inputs_shape, inputs_shape, inputs_shape}, +// std::vector{weight_attrs}, +// std::vector{output_attrs}, +// make_1d_machine_view(gpu_id_t{0}, gpu_id_t{1})); - CHECK(result.total_elapsed_time > 0); - CHECK(result.total_mem_usage > 0); - } - } -} +// CHECK(result.total_elapsed_time > 0); +// CHECK(result.total_mem_usage > 0); +// } +// } +// } diff --git a/lib/op-attrs/include/op-attrs/parallel_dim.h b/lib/op-attrs/include/op-attrs/parallel_dim.h index 5397ad7c68..a12951dec9 100644 --- a/lib/op-attrs/include/op-attrs/parallel_dim.h +++ b/lib/op-attrs/include/op-attrs/parallel_dim.h @@ -11,6 +11,7 @@ bool is_replica_dim(ParallelDim const &); ParallelDim with_size_set_to(ParallelDim const &, size_t); ParallelDim with_degree_set_to(ParallelDim const &, int); ParallelDim with_is_replica_set_to(ParallelDim const &, bool); +int get_degree(ParallelDim const &); } // namespace FlexFlow diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml new file mode 100644 index 0000000000..9396cbcbe8 --- /dev/null +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_dim_idx_t.variant.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "parallel_tensor_dim_idx_t" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "op-attrs/ff_dim.dtg.h", + "op-attrs/replica_type.dtg.h", +] + +[[values]] +type = "::FlexFlow::ff_dim_t" + +[[values]] +type = "::FlexFlow::ReplicaType" diff --git a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h index a03151160b..0759dc746e 100644 --- a/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h +++ b/lib/op-attrs/include/op-attrs/parallel_tensor_shape.h @@ -1,7 +1,9 @@ #ifndef _OP_META_PARALLEL_TENSOR_SHAPE_H #define _OP_META_PARALLEL_TENSOR_SHAPE_H +#include "op-attrs/parallel_dim.h" #include "op-attrs/parallel_tensor_dim_degrees.dtg.h" +#include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/replica_parallel_dim.dtg.h" #include "op-attrs/tensor_shape.h" @@ -50,6 +52,12 @@ std::vector TensorShape get_reduced_shape(ParallelTensorShape const &); +ParallelDim get_parallel_dim_at_idx(ParallelTensorShape const &shape, + parallel_tensor_dim_idx_t idx); + +std::unordered_set + get_parallel_tensor_dim_indices(ParallelTensorShape const &shape); + } // namespace FlexFlow #endif diff --git a/lib/op-attrs/src/op-attrs/parallel_dim.cc b/lib/op-attrs/src/op-attrs/parallel_dim.cc new file mode 100644 index 0000000000..26ba2b3fa1 --- /dev/null +++ b/lib/op-attrs/src/op-attrs/parallel_dim.cc @@ -0,0 +1,14 @@ +#include "op-attrs/parallel_dim.h" +#include "utils/overload.h" + +namespace FlexFlow { + +int get_degree(ParallelDim const &dim) { + return dim.visit(overload{ + [](ShardParallelDim const &shard_dim) { return shard_dim.degree; }, + [](ReplicaParallelDim const &replica_dim) { + return replica_dim.degree; + }}); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc index 0663795db5..dcc567e0ca 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_shape.cc @@ -1,9 +1,12 @@ #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/parallel_tensor_dims.h" #include "op-attrs/tensor_dims.h" +#include "utils/containers/extend.h" #include "utils/containers/product.h" +#include "utils/containers/range.h" #include "utils/containers/transform.h" #include "utils/hash-utils.h" +#include "utils/overload.h" namespace FlexFlow { @@ -116,4 +119,30 @@ TensorShape get_reduced_shape(ParallelTensorShape const &s) { }; } +ParallelDim get_parallel_dim_at_idx(ParallelTensorShape const &shape, + parallel_tensor_dim_idx_t idx) { + return idx.visit( + overload{[&](ff_dim_t shard_dim) { + return ParallelDim{shape.dims.shard_dims.at(shard_dim)}; + }, + [&](ReplicaType replica_type) { + ReplicaParallelDimSet replicas = shape.dims.replica_dims; + int degree = (ReplicaType::SUM == replica_type + ? replicas.sum_degree.value + : replicas.discard_copy_degree.value); + return ParallelDim{ReplicaParallelDim{degree, replica_type}}; + }}); +} + +std::unordered_set + get_parallel_tensor_dim_indices(ParallelTensorShape const &shape) { + std::unordered_set indices; + extend(indices, transform(range(num_shard_dims(shape.dims)), [](int idx) { + return parallel_tensor_dim_idx_t(ff_dim_t(idx)); + })); + indices.insert(parallel_tensor_dim_idx_t(ReplicaType::SUM)); + indices.insert(parallel_tensor_dim_idx_t(ReplicaType::DISCARD_COPY)); + return indices; +} + } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index 1157a2932a..28cf30eaba 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -13,6 +13,7 @@ device_id_t operator+(device_id_t, size_t); DeviceType get_device_type(device_id_t const &device_id); gpu_id_t unwrap_gpu(device_id_t); cpu_id_t unwrap_cpu(device_id_t); +int get_raw_id(device_id_t); device_id_t device_id_from_index(int, DeviceType); diff --git a/lib/pcg/include/pcg/machine_space_coordinate.struct.toml b/lib/pcg/include/pcg/machine_space_coordinate.struct.toml new file mode 100644 index 0000000000..9b197a74c9 --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_coordinate.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MachineSpaceCoordinate" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", +] + +[[fields]] +name = "node_idx" +type = "int" + +[[fields]] +name = "device_idx" +type = "int" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/machine_space_offset.h b/lib/pcg/include/pcg/machine_space_offset.h new file mode 100644 index 0000000000..2f702cc518 --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_offset.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_MACHINE_SPACE_OFFSET_H +#define _FLEXFLOW_PCG_INCLUDE_MACHINE_SPACE_OFFSET_H + +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/machine_space_offset.dtg.h" + +namespace FlexFlow { + +MachineSpaceOffset get_machine_space_offset_from_coordinate( + MachineSpaceCoordinate const &start, MachineSpaceCoordinate const &coord); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/machine_space_offset.struct.toml b/lib/pcg/include/pcg/machine_space_offset.struct.toml new file mode 100644 index 0000000000..3f6eab38fd --- /dev/null +++ b/lib/pcg/include/pcg/machine_space_offset.struct.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "MachineSpaceOffset" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", +] + +[[fields]] +name = "node_offset" +type = "int" + +[[fields]] +name = "device_offset" +type = "int" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index f66723b0ff..6ffa9900c2 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -1,6 +1,25 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_SPECIFICATION_H -namespace FlexFlow {} // namespace FlexFlow +#include "pcg/device_id_t.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +int get_num_gpus(MachineSpecification const &ms); +int get_num_cpus(MachineSpecification const &ms); +int get_num_devices(MachineSpecification const &ms, + DeviceType const &device_type); +int get_num_devices_per_node(MachineSpecification const &ms, + DeviceType const &device_type); + +bool is_valid_machine_space_coordinate(MachineSpecification const &ms, + MachineSpaceCoordinate const &coord); + +device_id_t get_device_id(MachineSpecification const &ms, + MachineSpaceCoordinate const &coord); +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_specification_dimension.enum.toml b/lib/pcg/include/pcg/machine_specification_dimension.enum.toml new file mode 100644 index 0000000000..837b4306da --- /dev/null +++ b/lib/pcg/include/pcg/machine_specification_dimension.enum.toml @@ -0,0 +1,14 @@ +namespace = "FlexFlow" +name = "MachineSpecificationDimension" +features = [ + "hash", + "json", + "fmt", + "rapidcheck", +] + +[[values]] +name = "INTER_NODE" + +[[values]] +name = "INTRA_NODE" diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 56abf5aa20..293227b7a1 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -1,50 +1,41 @@ #ifndef _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H -#include "pcg/cpu_id_t.dtg.h" -#include "pcg/device_id.h" +#include "machine_specification.dtg.h" +#include "machine_view.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "pcg/device_type.dtg.h" -#include "pcg/gpu_id_t.dtg.h" -#include "pcg/machine_view.dtg.h" -#include "pcg/num_points_t.dtg.h" -#include "pcg/side_size_t.dtg.h" +#include "pcg/operator_task_space.dtg.h" +#include "task_space_coordinate.dtg.h" #include -#include +#include +#include namespace FlexFlow { -std::vector device_ids(MachineView const &); -size_t num_dims(MachineView const &); -std::size_t num_devices(MachineView const &); -DeviceType get_device_type(MachineView const &); - -MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride = 1); -MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride = 1); -MachineView - make_1d_machine_view(device_id_t start, device_id_t stop, int stride = 1); - -MachineView make_1d_machine_view(gpu_id_t start, - num_points_t num_points, - int stride = 1); -MachineView make_1d_machine_view(cpu_id_t start, - num_points_t num_points, - int stride = 1); -MachineView make_1d_machine_view(device_id_t start, - num_points_t num_points, - int stride = 1); - -MachineView make_1d_machine_view(gpu_id_t start, - side_size_t interval_size, - int stride = 1); -MachineView make_1d_machine_view(cpu_id_t start, - side_size_t interval_size, - int stride = 1); -MachineView make_1d_machine_view(device_id_t start, - side_size_t interval_size, - int stride = 1); - -MachineView make_1d_machine_view(device_id_t start, size_t interval_size); +size_t num_dims(MachineView const &mv); + +DeviceType get_device_type(MachineView const &mv); + +std::vector get_strides(MachineView const &mv); + +std::vector + get_dimensions(MachineView const &mv); + +MachineView machine_view_from_strides_and_machine_spec_dimensions( + MachineSpaceCoordinate const &start, + std::vector const &strides, + std::vector const &dims); + +std::optional + get_machine_space_coordinate(OperatorTaskSpace const &task, + MachineView const &mv, + TaskSpaceCoordinate const &coordinates, + MachineSpecification const &ms); + +std::unordered_set + get_machine_space_coordinates(OperatorTaskSpace const &task, + MachineView const &mv, + MachineSpecification const &ms); } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/machine_view.struct.toml b/lib/pcg/include/pcg/machine_view.struct.toml index c97731991f..e4de69eafc 100644 --- a/lib/pcg/include/pcg/machine_view.struct.toml +++ b/lib/pcg/include/pcg/machine_view.struct.toml @@ -9,15 +9,21 @@ features = [ "fmt", ] -includes = [ - "pcg/device_id_t.dtg.h", - "pcg/strided_rectangle.dtg.h", +includes = [ + "pcg/machine_view_dimension.dtg.h", + "pcg/machine_space_coordinate.dtg.h" ] +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h" +] + + [[fields]] name = "start" -type = "::FlexFlow::device_id_t" +type = "::FlexFlow::MachineSpaceCoordinate" [[fields]] -name = "rect" -type = "::FlexFlow::StridedRectangle" +name = "dimensions" +type = "std::vector<::FlexFlow::MachineViewDimension>" diff --git a/lib/pcg/include/pcg/machine_view_dimension.struct.toml b/lib/pcg/include/pcg/machine_view_dimension.struct.toml new file mode 100644 index 0000000000..03b0ac51e4 --- /dev/null +++ b/lib/pcg/include/pcg/machine_view_dimension.struct.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "MachineViewDimension" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/machine_specification_dimension.dtg.h", + "pcg/stride_t.dtg.h", +] + + +[[fields]] +name = "stride" +type = "::FlexFlow::stride_t" + +[[fields]] +name = "projection" +type = "::FlexFlow::MachineSpecificationDimension" diff --git a/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml b/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml new file mode 100644 index 0000000000..9fa5a77f77 --- /dev/null +++ b/lib/pcg/include/pcg/multi_dimensional_stride.struct.toml @@ -0,0 +1,25 @@ +namespace = "FlexFlow" +name = "MultiDimensionalStride" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", + "pcg/stride_t.dtg.h", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h" + +] + +[[fields]] +name = "raw_strides" +type = "std::vector<::FlexFlow::stride_t>" diff --git a/lib/pcg/include/pcg/operator_task_space.h b/lib/pcg/include/pcg/operator_task_space.h new file mode 100644 index 0000000000..61cab4eff1 --- /dev/null +++ b/lib/pcg/include/pcg/operator_task_space.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H +#define _FLEXFLOW_PCG_INCLUDE_OPERATOR_TASK_SPACE_H + +#include "pcg/operator_task_space.dtg.h" +#include "pcg/task_space_coordinate.dtg.h" +#include +#include + +namespace FlexFlow { + +std::unordered_set + get_task_space_coordinates(OperatorTaskSpace const &task); + +TaskSpaceCoordinate + get_task_space_maximum_coordinate(OperatorTaskSpace const &task); + +size_t num_dims(OperatorTaskSpace const &task); +size_t num_tasks(OperatorTaskSpace const &task); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/operator_task_space.struct.toml b/lib/pcg/include/pcg/operator_task_space.struct.toml new file mode 100644 index 0000000000..3ab8b83173 --- /dev/null +++ b/lib/pcg/include/pcg/operator_task_space.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "OperatorTaskSpace" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h" +] + +[[fields]] +name = "degrees" +type = "std::vector" diff --git a/lib/pcg/include/pcg/start_invariant_machine_view.h b/lib/pcg/include/pcg/start_invariant_machine_view.h new file mode 100644 index 0000000000..f5091c69d1 --- /dev/null +++ b/lib/pcg/include/pcg/start_invariant_machine_view.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_START_INVARIANT_MACHINE_VIEW_H +#define _FLEXFLOW_PCG_INCLUDE_PCG_START_INVARIANT_MACHINE_VIEW_H + +#include "pcg/machine_space_offset.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/operator_task_space.dtg.h" +#include "pcg/start_invariant_machine_view.dtg.h" +#include "pcg/task_space_coordinate.dtg.h" +#include + +namespace FlexFlow { + +MachineView + machine_view_from_start_invariant(StartInvariantMachineView const &mv, + MachineSpaceCoordinate const &start); +StartInvariantMachineView + start_invariant_from_machine_view(MachineView const &mv); + +size_t num_dims(StartInvariantMachineView const &mv); + +DeviceType get_device_type(StartInvariantMachineView const &mv); + +std::vector get_strides(StartInvariantMachineView const &mv); + +std::vector + get_dimensions(StartInvariantMachineView const &mv); + +StartInvariantMachineView + start_invariant_machine_view_from_strides_and_machine_spec_dimensions( + std::vector const &strides, + std::vector const &dims); + +std::optional + get_machine_space_offset(OperatorTaskSpace const &task, + StartInvariantMachineView const &mv, + TaskSpaceCoordinate const &coordinates, + MachineSpecification const &ms); + +std::unordered_set + get_machine_space_offsets(OperatorTaskSpace const &task, + StartInvariantMachineView const &mv, + MachineSpecification const &ms); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml b/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml new file mode 100644 index 0000000000..a1b2b40524 --- /dev/null +++ b/lib/pcg/include/pcg/start_invariant_machine_view.struct.toml @@ -0,0 +1,29 @@ +namespace = "FlexFlow" +name = "StartInvariantMachineView" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "pcg/machine_view_dimension.dtg.h", + "pcg/device_type.dtg.h" +] + +src_includes = [ + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "dimensions" +type = "std::vector<::FlexFlow::MachineViewDimension>" + + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" diff --git a/lib/pcg/include/pcg/side_size_t.struct.toml b/lib/pcg/include/pcg/stride_t.struct.toml similarity index 87% rename from lib/pcg/include/pcg/side_size_t.struct.toml rename to lib/pcg/include/pcg/stride_t.struct.toml index dbaad4fedb..a764497b8b 100644 --- a/lib/pcg/include/pcg/side_size_t.struct.toml +++ b/lib/pcg/include/pcg/stride_t.struct.toml @@ -1,5 +1,5 @@ namespace = "FlexFlow" -name = "side_size_t" +name = "stride_t" features = [ "eq", "ord", diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h deleted file mode 100644 index 9c3b8eeda9..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H -#define _FLEXFLOW_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_H - -#include "op-attrs/ff_dim.dtg.h" -#include "pcg/side_size_t.dtg.h" -#include "pcg/strided_rectangle.dtg.h" - -namespace FlexFlow { - -size_t get_num_dims(StridedRectangle const &); -StridedRectangleSide get_side_at_idx(StridedRectangle const &rect, - ff_dim_t const &idx); -num_points_t get_num_points(StridedRectangle const &rect); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/strided_rectangle.struct.toml b/lib/pcg/include/pcg/strided_rectangle.struct.toml deleted file mode 100644 index 577825238d..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle.struct.toml +++ /dev/null @@ -1,19 +0,0 @@ -namespace = "FlexFlow" -name = "StridedRectangle" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "pcg/strided_rectangle_side.dtg.h", - "op-attrs/dim_ordered/dim_ordered.h", -] - -[[fields]] -name = "sides" -type = "::FlexFlow::FFOrdered<::FlexFlow::StridedRectangleSide>" diff --git a/lib/pcg/include/pcg/strided_rectangle_side.h b/lib/pcg/include/pcg/strided_rectangle_side.h deleted file mode 100644 index 1486b73143..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle_side.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H -#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_STRIDED_RECTANGLE_SIDE_H - -#include "pcg/side_size_t.dtg.h" -#include "pcg/strided_rectangle_side.dtg.h" - -namespace FlexFlow { - -StridedRectangleSide strided_side_from_size_and_stride(side_size_t, int stride); - -side_size_t get_side_size(StridedRectangleSide const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/pcg/include/pcg/strided_rectangle_side.struct.toml b/lib/pcg/include/pcg/strided_rectangle_side.struct.toml deleted file mode 100644 index f26adfafd5..0000000000 --- a/lib/pcg/include/pcg/strided_rectangle_side.struct.toml +++ /dev/null @@ -1,22 +0,0 @@ -namespace = "FlexFlow" -name = "StridedRectangleSide" -features = [ - "eq", - "ord", - "hash", - "json", - "rapidcheck", - "fmt", -] - -includes = [ - "pcg/num_points_t.dtg.h", -] - -[[fields]] -name = "num_points" -type = "::FlexFlow::num_points_t" - -[[fields]] -name = "stride" -type = "int" diff --git a/lib/pcg/include/pcg/task_space_coordinate.struct.toml b/lib/pcg/include/pcg/task_space_coordinate.struct.toml new file mode 100644 index 0000000000..65aea167cb --- /dev/null +++ b/lib/pcg/include/pcg/task_space_coordinate.struct.toml @@ -0,0 +1,23 @@ +namespace = "FlexFlow" +name = "TaskSpaceCoordinate" +features = [ + "eq", + "ord", + "hash", + "json", + # "rapidcheck", + "fmt", +] + +includes = [ + "", +] + +src_includes = [ + "utils/hash/vector.h", + "utils/fmt/vector.h", +] + +[[fields]] +name = "raw_coord" +type = "std::vector" diff --git a/lib/pcg/src/pcg/device_id.cc b/lib/pcg/src/pcg/device_id.cc index 35b0c9aeda..a8cfe1f82f 100644 --- a/lib/pcg/src/pcg/device_id.cc +++ b/lib/pcg/src/pcg/device_id.cc @@ -25,8 +25,27 @@ cpu_id_t unwrap_cpu(device_id_t device_id) { return device_id.get(); } -device_id_t device_id_from_index(int, DeviceType) { - NOT_IMPLEMENTED(); +int get_raw_id(device_id_t device_id) { + switch (get_device_type(device_id)) { + case DeviceType::GPU: + return unwrap_gpu(device_id).gpu_index; + case DeviceType::CPU: + return unwrap_cpu(device_id).cpu_index; + default: + throw mk_runtime_error(fmt::format("Unsupported device {}", device_id)); + } +} + +device_id_t device_id_from_index(int idx, DeviceType device_type) { + switch (device_type) { + case DeviceType::GPU: + return device_id_t{gpu_id_t{idx}}; + case DeviceType::CPU: + return device_id_t{cpu_id_t{idx}}; + default: + throw mk_runtime_error( + fmt::format("Unsupported DeviceType {}", device_type)); + } } } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_space_offset.cc b/lib/pcg/src/pcg/machine_space_offset.cc new file mode 100644 index 0000000000..9990023f8c --- /dev/null +++ b/lib/pcg/src/pcg/machine_space_offset.cc @@ -0,0 +1,25 @@ +#include "pcg/machine_space_offset.h" +#include "utils/exception.h" + +namespace FlexFlow { +MachineSpaceOffset get_machine_space_offset_from_coordinate( + MachineSpaceCoordinate const &start, MachineSpaceCoordinate const &coord) { + if ((coord.device_idx < start.device_idx) || + (coord.node_idx < start.node_idx)) { + throw mk_runtime_error(fmt::format( + "One of the coordinates of start {} is greater than one of the " + "coordinates of coord {}, are you sure you didn't swap them?", + start, + coord)); + } + if (start.device_type != coord.device_type) { + throw mk_runtime_error( + fmt::format("{} has different DeviceType from {}", start, coord)); + } + + return MachineSpaceOffset{coord.node_idx - start.node_idx, + coord.device_idx - start.device_idx, + coord.device_type}; +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_specification.cc b/lib/pcg/src/pcg/machine_specification.cc new file mode 100644 index 0000000000..ca5b8ba047 --- /dev/null +++ b/lib/pcg/src/pcg/machine_specification.cc @@ -0,0 +1,53 @@ +#include "pcg/machine_specification.h" +#include "pcg/device_id.h" +#include "utils/exception.h" +namespace FlexFlow { + +int get_num_gpus(MachineSpecification const &ms) { + return ms.num_nodes * ms.num_gpus_per_node; +} +int get_num_cpus(MachineSpecification const &ms) { + return ms.num_nodes * ms.num_cpus_per_node; +} +int get_num_devices(MachineSpecification const &ms, + DeviceType const &device_type) { + switch (device_type) { + case DeviceType::GPU: + return get_num_gpus(ms); + case DeviceType::CPU: + return get_num_cpus(ms); + default: + throw mk_runtime_error(fmt::format("Unknown DeviceType {}", device_type)); + } +} + +int get_num_devices_per_node(MachineSpecification const &ms, + DeviceType const &device_type) { + switch (device_type) { + case DeviceType::GPU: + return ms.num_gpus_per_node; + case DeviceType::CPU: + return ms.num_cpus_per_node; + default: + throw mk_runtime_error(fmt::format("Unknown DeviceType {}", device_type)); + } +} +bool is_valid_machine_space_coordinate(MachineSpecification const &ms, + MachineSpaceCoordinate const &coord) { + return (coord.node_idx < ms.num_nodes) && + (coord.device_idx < get_num_devices_per_node(ms, coord.device_type)); +} + +device_id_t get_device_id(MachineSpecification const &ms, + MachineSpaceCoordinate const &coord) { + if (!is_valid_machine_space_coordinate(ms, coord)) { + throw mk_runtime_error(fmt::format( + "Invalid coordinate {} for machine specification {}", ms, coord)); + } + int raw_idx = + coord.node_idx * get_num_devices_per_node(ms, coord.device_type) + + coord.device_idx; + return device_id_from_index(raw_idx, coord.device_type); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index c09ab1a3c9..18f6cacb7e 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -1,121 +1,115 @@ #include "pcg/machine_view.h" -#include "pcg/device_id.h" -#include "pcg/strided_rectangle.dtg.h" -#include "pcg/strided_rectangle.h" -#include "pcg/strided_rectangle_side.h" +#include "pcg/machine_specification.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/contains.h" +#include "utils/containers/count.h" +#include "utils/containers/filter.h" +#include "utils/containers/scanl.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip.h" namespace FlexFlow { -std::vector device_ids(MachineView const &) { - NOT_IMPLEMENTED(); -} - -std::size_t num_dims(MachineView const &mv) { - return get_num_dims(mv.rect); -} - -size_t num_devices(MachineView const &mv) { - return get_num_points(mv.rect).unwrapped; +size_t num_dims(MachineView const &mv) { + return get_strides(mv).size(); } DeviceType get_device_type(MachineView const &mv) { - return get_device_type(mv.start); -} - -static StridedRectangle make_1d_rect(int start, int stop, int stride) { - assert(stop > start); - assert(stride > 0); - StridedRectangleSide side = - strided_side_from_size_and_stride(side_size_t{stop - start}, stride); - StridedRectangle rect = - StridedRectangle{std::vector{side}}; - return rect; -} - -MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.gpu_index, stop.gpu_index, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { - StridedRectangle rect = make_1d_rect(start.cpu_index, stop.cpu_index, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView - make_1d_machine_view(device_id_t start, device_id_t stop, int stride) { - assert(get_device_type(start) == get_device_type(stop)); - if (get_device_type(start) == DeviceType::CPU) { - return make_1d_machine_view(unwrap_cpu(start), unwrap_cpu(stop), stride); + return mv.start.device_type; +} + +std::vector get_strides(MachineView const &mv) { + return transform(mv.dimensions, + [](MachineViewDimension const &dim) { return dim.stride; }); +} + +std::vector + get_dimensions(MachineView const &mv) { + return transform(mv.dimensions, [](MachineViewDimension const &dim) { + return dim.projection; + }); +} + +MachineView machine_view_from_strides_and_machine_spec_dimensions( + MachineSpaceCoordinate const &start, + std::vector const &strides, + std::vector const &dims) { + std::vector dimensions = + transform(zip(strides, dims), [&](auto const &p) { + return MachineViewDimension{p.first, p.second}; + }); + return MachineView{start, dimensions}; +} + +std::optional get_machine_space_coordinate( + OperatorTaskSpace const &task, + MachineView const &machine_view, + TaskSpaceCoordinate const &coord, + MachineSpecification const &machine_specification) { + + auto get_dimension_indices_for_dimension = + [&](MachineSpecificationDimension dimension) { + std::vector mv_dimensions = + get_dimensions(machine_view); + return filter(count(mv_dimensions.size()), [&](size_t idx) { + return mv_dimensions.at(idx) == dimension; + }); + }; + + auto compute_index = [&](int start_idx, + std::vector const &dimension_indices) { + std::vector mv_strides = get_strides(machine_view); + + std::vector sizes = transform(dimension_indices, [&](size_t i) { + return task.degrees.at(i) * mv_strides.at(i).unwrapped; + }); + std::vector coord_points = transform( + dimension_indices, [&](size_t i) { return coord.raw_coord.at(i); }); + std::vector strides = transform(dimension_indices, [&](size_t i) { + return mv_strides.at(i).unwrapped; + }); + + std::vector coeffs = scanl(sizes, 1, std::multiplies()); + + int index = start_idx; + for (auto [coeff, coord_point, stride] : + zip(coeffs, coord_points, strides)) { + index += coeff * coord_point * stride; + } + return index; + }; + + std::vector inter_dimension_indices = + get_dimension_indices_for_dimension( + MachineSpecificationDimension::INTER_NODE); + std::vector intra_dimension_indices = + get_dimension_indices_for_dimension( + MachineSpecificationDimension::INTRA_NODE); + + int node_idx = + compute_index(machine_view.start.node_idx, inter_dimension_indices); + int device_idx = + compute_index(machine_view.start.device_idx, intra_dimension_indices); + MachineSpaceCoordinate ms_coord = MachineSpaceCoordinate{ + node_idx, device_idx, get_device_type(machine_view)}; + + if (!is_valid_machine_space_coordinate(machine_specification, ms_coord)) { + return std::nullopt; } - assert(get_device_type(start) == DeviceType::GPU); - return make_1d_machine_view(unwrap_gpu(start), unwrap_gpu(stop), stride); + return ms_coord; } -static StridedRectangle - make_1d_rect(int start, num_points_t num_points, int stride) { - return make_1d_rect(start, start + num_points.unwrapped * stride, stride); +std::unordered_set get_machine_space_coordinates( + OperatorTaskSpace const &task, + MachineView const &machine_view, + MachineSpecification const &machine_specification) { + return transform( + get_task_space_coordinates(task), [&](TaskSpaceCoordinate const &coord) { + return get_machine_space_coordinate( + task, machine_view, coord, machine_specification) + .value(); + }); } -MachineView - make_1d_machine_view(cpu_id_t start, num_points_t num_points, int stride) { - StridedRectangle rect = make_1d_rect(start.cpu_index, num_points, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView - make_1d_machine_view(gpu_id_t start, num_points_t num_points, int stride) { - StridedRectangle rect = make_1d_rect(start.gpu_index, num_points, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView make_1d_machine_view(device_id_t start, - num_points_t num_points, - int stride) { - if (get_device_type(start) == DeviceType::CPU) { - return make_1d_machine_view(unwrap_cpu(start), num_points, stride); - } else { - assert(get_device_type(start) == DeviceType::GPU); - return make_1d_machine_view(unwrap_gpu(start), num_points, stride); - } -} - -static StridedRectangle - make_1d_rect(int start, side_size_t interval_size, int stride) { - return make_1d_rect(start, start + interval_size.unwrapped, stride); -} - -MachineView make_1d_machine_view(cpu_id_t start, - side_size_t interval_size, - int stride) { - StridedRectangle rect = make_1d_rect(start.cpu_index, interval_size, stride); - return MachineView{device_id_t{start}, rect}; -} - -MachineView make_1d_machine_view(gpu_id_t start, - side_size_t interval_size, - int stride) { - StridedRectangle rect = make_1d_rect(start.gpu_index, interval_size, stride); - return MachineView{device_id_t{start}, rect}; -} -MachineView make_1d_machine_view(device_id_t start, - side_size_t interval_size, - int stride) { - - if (get_device_type(start) == DeviceType::CPU) { - return make_1d_machine_view(unwrap_cpu(start), interval_size, stride); - } else { - assert(get_device_type(start) == DeviceType::GPU); - return make_1d_machine_view(unwrap_gpu(start), interval_size, stride); - } -} -MachineView make_1d_machine_view(device_id_t start, size_t interval_size) { - NOT_IMPLEMENTED(); -} - -/* device_id_t MachineView::at(FFOrdered const &coord) const { */ -/* size_t offset = this->rect.at(coord); */ -/* return this->start + offset; */ -/* } */ - } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/pcg/src/pcg/operator_task_space.cc new file mode 100644 index 0000000000..02522ae411 --- /dev/null +++ b/lib/pcg/src/pcg/operator_task_space.cc @@ -0,0 +1,38 @@ +#include "pcg/operator_task_space.h" +#include "utils/containers/cartesian_product.h" +#include "utils/containers/maximum.h" +#include "utils/containers/product.h" +#include "utils/containers/range.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +std::unordered_set + get_task_space_coordinates(OperatorTaskSpace const &task) { + + std::vector> coordinate_ranges = transform( + task.degrees, [&](int const &num_points) { return range(num_points); }); + + std::unordered_set> raw_coordinates = + unordered_set_of(cartesian_product(coordinate_ranges)); + std::unordered_set task_space_coordinates = + transform(raw_coordinates, [](std::vector const &point) { + return TaskSpaceCoordinate{point}; + }); + return task_space_coordinates; +} + +TaskSpaceCoordinate + get_task_space_maximum_coordinate(OperatorTaskSpace const &task) { + return maximum(get_task_space_coordinates(task)).value(); +} + +size_t num_dims(OperatorTaskSpace const &task) { + return task.degrees.size(); +} +size_t num_tasks(OperatorTaskSpace const &task) { + return product(task.degrees); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/start_invariant_machine_view.cc b/lib/pcg/src/pcg/start_invariant_machine_view.cc new file mode 100644 index 0000000000..1fcc3ea12f --- /dev/null +++ b/lib/pcg/src/pcg/start_invariant_machine_view.cc @@ -0,0 +1,86 @@ +#include "pcg/start_invariant_machine_view.h" +#include "pcg/machine_space_offset.h" +#include "pcg/machine_view.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/count.h" +#include "utils/containers/filter.h" +#include "utils/containers/scanl.h" +#include "utils/containers/transform.h" +#include "utils/containers/zip.h" +namespace FlexFlow { + +MachineView machine_view_from_start_invariant( + StartInvariantMachineView const &start_inv_mv, + MachineSpaceCoordinate const &start) { + return MachineView{start, start_inv_mv.dimensions}; +} + +StartInvariantMachineView + start_invariant_from_machine_view(MachineView const &mv) { + return StartInvariantMachineView{mv.dimensions, get_device_type(mv)}; +} + +size_t num_dims(StartInvariantMachineView const &start_inv_mv) { + return start_inv_mv.dimensions.size(); +} + +DeviceType get_device_type(StartInvariantMachineView const &start_inv_mv) { + return start_inv_mv.device_type; +} + +std::vector + get_strides(StartInvariantMachineView const &start_inv_mv) { + return transform(start_inv_mv.dimensions, + [](MachineViewDimension const &dim) { return dim.stride; }); +} + +std::vector + get_dimensions(StartInvariantMachineView const &start_inv_mv) { + return transform( + start_inv_mv.dimensions, + [](MachineViewDimension const &dim) { return dim.projection; }); +} + +StartInvariantMachineView + start_invariant_machine_view_from_strides_and_machine_spec_dimensions( + std::vector const &strides, + std::vector const &dims, + DeviceType device_type) { + std::vector dimensions = + transform(zip(strides, dims), [&](auto const &p) { + return MachineViewDimension{p.first, p.second}; + }); + return StartInvariantMachineView{dimensions, device_type}; +} + +std::optional get_machine_space_offset( + OperatorTaskSpace const &task, + StartInvariantMachineView const &start_inv_machine_view, + TaskSpaceCoordinate const &coord, + MachineSpecification const &machine_specification) { + MachineSpaceCoordinate dummy_start = + MachineSpaceCoordinate{0, 0, get_device_type(start_inv_machine_view)}; + MachineView mv = + machine_view_from_start_invariant(start_inv_machine_view, dummy_start); + std::optional ms_coord = + get_machine_space_coordinate(task, mv, coord, machine_specification); + if (ms_coord == std::nullopt) { + return std::nullopt; + } + return get_machine_space_offset_from_coordinate(dummy_start, + ms_coord.value()); +} + +std::unordered_set get_machine_space_offsets( + OperatorTaskSpace const &task, + StartInvariantMachineView const &start_inv_machine_view, + MachineSpecification const &machine_specification) { + return transform( + get_task_space_coordinates(task), [&](TaskSpaceCoordinate const &coord) { + return get_machine_space_offset( + task, start_inv_machine_view, coord, machine_specification) + .value(); + }); +} + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle.cc b/lib/pcg/src/pcg/strided_rectangle.cc deleted file mode 100644 index dfb5d0af12..0000000000 --- a/lib/pcg/src/pcg/strided_rectangle.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "pcg/strided_rectangle.h" -#include "op-attrs/dim_ordered/transform.h" -#include "utils/containers/product.h" - -namespace FlexFlow { - -/* size_t StridedRectangle::at(FFOrdered const &coord) const { */ -/* assert(coord.size() == this->num_dims()); */ - -/* size_t _1d_stride = 1; */ -/* size_t idx = 0; */ -/* for (auto dim : inner_to_outer_idxs(this->sides)) { */ -/* idx += this->sides.at(dim).at(coord.at(dim)).value() * _1d_stride; */ -/* _1d_stride *= this->sides.at(dim).get_size().value(); */ -/* } */ -/* return idx; */ -/* } */ - -size_t get_num_dims(StridedRectangle const &rect) { - return rect.sides.size(); -} - -num_points_t get_num_points(StridedRectangle const &rect) { - return num_points_t{ - product(transform(rect.sides, [](StridedRectangleSide const &side) { - return side.num_points.unwrapped; - }))}; -} - -StridedRectangleSide get_side_at_idx(StridedRectangle const &rect, - ff_dim_t const &idx) { - return rect.sides.at(idx); -} - -} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc deleted file mode 100644 index e6caf4cb86..0000000000 --- a/lib/pcg/src/pcg/strided_rectangle_side.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "pcg/strided_rectangle_side.h" -#include "utils/exception.h" - -namespace FlexFlow { - -StridedRectangleSide strided_side_from_size_and_stride(side_size_t side_size, - int stride) { - assert((side_size.unwrapped % stride) == 0); - return StridedRectangleSide{num_points_t{side_size.unwrapped / stride}, - stride}; -} - -side_size_t get_side_size(StridedRectangleSide const &s) { - return side_size_t{s.num_points.unwrapped * s.stride}; -} - -} // namespace FlexFlow diff --git a/lib/pcg/test/src/test_computation_graph_builder.cc b/lib/pcg/test/src/pcg/computation_graph_builder.cc similarity index 100% rename from lib/pcg/test/src/test_computation_graph_builder.cc rename to lib/pcg/test/src/pcg/computation_graph_builder.cc index ff169d8312..e7fa853be9 100644 --- a/lib/pcg/test/src/test_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/computation_graph_builder.cc @@ -1,6 +1,6 @@ +#include "pcg/computation_graph_builder.h" #include "doctest/doctest.h" #include "pcg/computation_graph.h" -#include "pcg/computation_graph_builder.h" using namespace ::FlexFlow; diff --git a/lib/pcg/test/src/pcg/machine_specification.cc b/lib/pcg/test/src/pcg/machine_specification.cc new file mode 100644 index 0000000000..c183ae0d31 --- /dev/null +++ b/lib/pcg/test/src/pcg/machine_specification.cc @@ -0,0 +1,54 @@ +#include "pcg/machine_specification.h" +#include "pcg/device_id.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("MachineSpecification") { + + MachineSpecification ms = MachineSpecification{ + /*num_nodes=*/4, + /*num_cpus_per_node=*/16, + /*num_gpus_per_node=*/8, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0, + }; + + SUBCASE("get_num_gpus") { + CHECK(get_num_gpus(ms) == 4 * 8); + } + + SUBCASE("get_num_cpus") { + CHECK(get_num_cpus(ms) == 4 * 16); + } + + SUBCASE("get_num_devices") { + CHECK(get_num_devices(ms, DeviceType::GPU) == 4 * 8); + CHECK(get_num_devices(ms, DeviceType::CPU) == 16 * 4); + } + + SUBCASE("get_device_id") { + SUBCASE("valid MachineSpaceCoordinate") { + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/2, + /*device_idx=*/12, + DeviceType::CPU, + }; + device_id_t correct = + device_id_from_index(2 * 16 + 12, DeviceType::CPU); + device_id_t result = get_device_id(ms, coord); + CHECK(correct == result); + } + SUBCASE("MachineSpaceCoordinate out of bounds for given machine spec") { + MachineSpaceCoordinate coord = MachineSpaceCoordinate{ + /*node_idx=*/2, + /*device_idx=*/18, + DeviceType::CPU, + }; + CHECK_THROWS(get_device_id(ms, coord)); + } + } + } +} diff --git a/lib/pcg/test/src/pcg/machine_view.cc b/lib/pcg/test/src/pcg/machine_view.cc new file mode 100644 index 0000000000..dcf22d6c00 --- /dev/null +++ b/lib/pcg/test/src/pcg/machine_view.cc @@ -0,0 +1,301 @@ +#include "pcg/machine_view.h" +#include "test/utils/doctest/fmt/optional.h" +#include "utils/containers/transform.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MachineView - utility functions") { + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}}}; + + SUBCASE("num_dims") { + CHECK(num_dims(mv) == 2); + } + SUBCASE("get_device_type") { + CHECK(get_device_type(mv) == DeviceType::GPU); + } + } + + TEST_CASE("get_machine_space_coordinate") { + SUBCASE("1D case") { + + // This operator has shape (3,), and thus 3 tasks. + // The (only) dimension is projected on the INTER (device) dimension with + // a stride of 2. The start of the projection defined by MachineView + // starts at MachineSpaceCoordinate (0,1), and the machine space has 1 + // node and 6 devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+ + * | | (0,) | | (1,) | | (2,) | + * +-------+-------+-------+-------+-------+-------+ + * Where the (x,) are the `TaskSpaceCoordinate`s, and the underlying grid + * is the machine space. + */ + OperatorTaskSpace task = OperatorTaskSpace{{3}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}, + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/1, + /*num_cpus_per_node=*/6, + /*num_gpus_per_node=*/6, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/3, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (2,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/5, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("TaskSpaceCoordinate is out of bounds") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{4}}; + std::optional result = + get_machine_space_coordinate(task, mv, coord, ms); + std::optional correct = std::nullopt; + CHECK(result == correct); + } + + SUBCASE("2D case - projection on different dimensions") { + // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. + // The first dimension is projected onto the INTER (node) dimension with + // stride 1, while the second dimension is projected onto the INTRA + // (device) dimension with stride 2. The start of the projection defined + // by MachineView is at MachineSpaceCoordinates (1, 2), and the machine + // space has 3 nodes and 5 devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+ + * | | | | | | + * +-------+-------+-------+-------+-------+ + * | | | (0,0) | | (0,1) | + * +-------+-------+-------+-------+-------+ + * | | | (1,0) | | (1,1) | + * +-------+-------+-------+-------+-------+ + * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. + */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/2, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/3, + /*num_cpus_per_node=*/5, + /*num_gpus_per_node=*/5, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/2, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/4, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/2, /*device_idx=*/2, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/2, /*device_idx=*/4, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("2D case - projection on same dimension") { + // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. + // Both dimensions are projected on the INTRA (device) dimension, with + // strides 1 and 2 respectively. The start of the projection defined by + // MachineView is at MachineSpaceCoordinates (1, 0), and the machine + // space has 2 nodes and 6 devices per node. + + /** + * +-------+-------+-------+-------+-------+-------+ + * | (0,0) | (1,0) | | | (0,1) | (1,1) | + * +-------+-------+-------+-------+-------+-------+ + * Where the (x,y) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. + */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/0, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/2, + /*num_cpus_per_node=*/6, + /*num_gpus_per_node=*/6, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/0, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/4, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/1, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/5, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("3D case") { + // This operator has shape (2, 2, 2), and thus 2 * 2 * 2 = 8 tasks. + // - The first dimension is projected onto the INTER (node) dimension + // with stride 1, + // - The second dimension is projected onto the INTRA (device) dimension + // with stride 2, + // - The third dimension is projected onto the INTRA (device) dimension + // with stride 1. The start of the projection defined by MachineView is + // at MachineSpaceCoordinates (0, 1), and the machine space has 2 nodes + // and 8 devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * | |(0,0,0)| |(0,0,1)| |(0,1,0)| |(0,1,1)| + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * | |(1,0,0)| |(1,0,1)| |(1,1,0)| |(1,1,1)| + * +-------+-------+-------+-------+-------+-------+-------+-------+ + * Where the (x,y,z) are the `TaskSpaceCoordinate`s, and the underlying + * grid is the machine space. + */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2, 2}}; + MachineView mv = MachineView{ + MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/1, DeviceType::GPU}, + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}, + MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTRA_NODE}}}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/2, + /*num_cpus_per_node=*/8, + /*num_gpus_per_node=*/8, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("Task with TaskSpaceCoordinate = (0,0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1, 0}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/0, /*device_idx=*/3, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/5, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1, 1}}; + MachineSpaceCoordinate correct = MachineSpaceCoordinate{ + /*node_idx=*/1, /*device_idx=*/7, DeviceType::GPU}; + MachineSpaceCoordinate result = + get_machine_space_coordinate(task, mv, coord, ms).value(); + CHECK(correct == result); + } + } + } + } +} diff --git a/lib/pcg/test/src/pcg/operator_task_space.cc b/lib/pcg/test/src/pcg/operator_task_space.cc new file mode 100644 index 0000000000..13198d9456 --- /dev/null +++ b/lib/pcg/test/src/pcg/operator_task_space.cc @@ -0,0 +1,66 @@ +#include "pcg/operator_task_space.h" +#include "utils/fmt/unordered_set.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_task_space_coordinates") { + + SUBCASE("OperatorTaskSpace has 0 dimensions") { + OperatorTaskSpace task = OperatorTaskSpace{{}}; + + std::unordered_set correct = { + TaskSpaceCoordinate{{}}}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + SUBCASE("OperatorTaskSpace has 2 dimensions") { + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + + std::unordered_set correct = {{ + TaskSpaceCoordinate{{0, 0}}, + TaskSpaceCoordinate{{0, 1}}, + TaskSpaceCoordinate{{1, 0}}, + TaskSpaceCoordinate{{1, 1}}, + }}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + SUBCASE("OperatorTaskSpace has 3 dimensions") { + + OperatorTaskSpace task = OperatorTaskSpace{{1, 2, 2}}; + + std::unordered_set correct = {{ + TaskSpaceCoordinate{{0, 0, 0}}, + TaskSpaceCoordinate{{0, 0, 1}}, + TaskSpaceCoordinate{{0, 1, 0}}, + TaskSpaceCoordinate{{0, 1, 1}}, + }}; + std::unordered_set result = + get_task_space_coordinates(task); + CHECK(correct == result); + } + } + TEST_CASE("get_task_space_maximum_coordinate") { + SUBCASE("OperatorTaskSpace has 2 dimensions") { + + OperatorTaskSpace task = OperatorTaskSpace{{3, 2}}; + + TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2, 1}}; + TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); + CHECK(correct == result); + } + SUBCASE("OperatorTaskSpace has 3 dimensions") { + + OperatorTaskSpace task = OperatorTaskSpace{{3, 2, 4}}; + + TaskSpaceCoordinate correct = TaskSpaceCoordinate{{2, 1, 3}}; + TaskSpaceCoordinate result = get_task_space_maximum_coordinate(task); + CHECK(correct == result); + } + } +} diff --git a/lib/pcg/test/src/pcg/start_invariant_machine_view.cc b/lib/pcg/test/src/pcg/start_invariant_machine_view.cc new file mode 100644 index 0000000000..8383754aa2 --- /dev/null +++ b/lib/pcg/test/src/pcg/start_invariant_machine_view.cc @@ -0,0 +1,229 @@ +#include "pcg/start_invariant_machine_view.h" +#include "utils/fmt/unordered_set.h" +#include "utils/fmt/vector.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("StartInvariantMachineView - utility functions") { + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}}, + DeviceType::GPU}; + + SUBCASE("num_dims") { + int result = num_dims(simv); + int correct = 2; + CHECK(result == correct); + } + + SUBCASE("get_device_type") { + DeviceType result = get_device_type(simv); + DeviceType correct = DeviceType::GPU; + CHECK(result == correct); + } + + SUBCASE("get_strides") { + std::vector result = get_strides(simv); + std::vector correct = {stride_t{2}, stride_t{2}}; + CHECK(result == correct); + } + + SUBCASE("get_dimensions") { + std::vector result = get_dimensions(simv); + std::vector correct = { + MachineSpecificationDimension::INTER_NODE, + MachineSpecificationDimension::INTER_NODE}; + CHECK(result == correct); + } + } + + TEST_CASE("StartInvariantMachineView - conversions") { + MachineSpaceCoordinate start = + MachineSpaceCoordinate{1, 2, DeviceType::GPU}; + std::vector dimensions = { + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{3}, + MachineSpecificationDimension::INTRA_NODE}}; + + MachineView mv = MachineView{start, dimensions}; + StartInvariantMachineView simv = + StartInvariantMachineView{dimensions, DeviceType::GPU}; + + SUBCASE("start_invariant_from_machine_view") { + StartInvariantMachineView result = start_invariant_from_machine_view(mv); + StartInvariantMachineView correct = simv; + CHECK(result == correct); + } + + SUBCASE("machine_view_from_start_invariant") { + MachineView result = machine_view_from_start_invariant(simv, start); + MachineView correct = mv; + CHECK(result == correct); + } + + SUBCASE("conversion is invertible") { + SUBCASE("MachineView -> StartInvariant -> MachineView") { + MachineView result = machine_view_from_start_invariant( + start_invariant_from_machine_view(mv), start); + MachineView correct = mv; + CHECK(result == correct); + } + + SUBCASE("StartInvariant -> MachineView -> StartInvariant") { + StartInvariantMachineView result = start_invariant_from_machine_view( + machine_view_from_start_invariant(simv, start)); + StartInvariantMachineView correct = simv; + CHECK(result == correct); + } + } + } + + TEST_CASE("StartInvariantMachineView - get_machine_space_offset") { + SUBCASE("1D case") { + // This operator has shape (3,), and thus 3 tasks. + // The (only) dimension is projected on the INTRA (device) dimension with + // a stride of 2. The machine space has 1 node and 6 devices per node. + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+-------+-------+ + * | (0,) | | (1,) | | (2,) | | + * +-------+-------+-------+-------+-------+-------+ + */ + OperatorTaskSpace task = OperatorTaskSpace{{3}}; + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}, + DeviceType::GPU}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/1, + /*num_cpus_per_node=*/6, + /*num_gpus_per_node=*/6, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("get_machine_space_offset") { + SUBCASE("Task with TaskSpaceCoordinate = (0,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (2,)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{2}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 4, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("get_machine_space_offsets") { + std::unordered_set correct = { + MachineSpaceOffset{0, 0, DeviceType::GPU}, + MachineSpaceOffset{0, 2, DeviceType::GPU}, + MachineSpaceOffset{0, 4, DeviceType::GPU}}; + std::unordered_set result = + get_machine_space_offsets(task, simv, ms); + CHECK(correct == result); + } + } + + SUBCASE("2D case") { + // This operator has shape (2, 2), and thus 2 * 2 = 4 tasks. + // The first dimension is projected onto the INTER (node) dimension with + // stride 1, while the second dimension is projected onto the INTRA + // (device) dimension with stride 2. The machine space has 2 nodes and 4 + // devices per node. + + /** + * The tasks will thus be distributed like this: + * +-------+-------+-------+-------+ + * | (0,0) | | (0,1) | | + * +-------+-------+-------+-------+ + * | (1,0) | | (1,1) | | + * +-------+-------+-------+-------+ + */ + + OperatorTaskSpace task = OperatorTaskSpace{{2, 2}}; + StartInvariantMachineView simv = StartInvariantMachineView{ + {MachineViewDimension{stride_t{1}, + MachineSpecificationDimension::INTER_NODE}, + MachineViewDimension{stride_t{2}, + MachineSpecificationDimension::INTRA_NODE}}, + DeviceType::GPU}; + MachineSpecification ms = + MachineSpecification{/*num_nodes=*/2, + /*num_cpus_per_node=*/4, + /*num_gpus_per_node=*/4, + /*inter_node_bandwidth=*/0, + /*intra_node_bandwidth=*/0}; + + SUBCASE("get_machine_space_offset") { + SUBCASE("Task with TaskSpaceCoordinate = (0,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 0}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (0,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{0, 1}}; + MachineSpaceOffset correct = + MachineSpaceOffset{0, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,0)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 0}}; + MachineSpaceOffset correct = + MachineSpaceOffset{1, 0, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + + SUBCASE("Task with TaskSpaceCoordinate = (1,1)") { + TaskSpaceCoordinate coord = TaskSpaceCoordinate{{1, 1}}; + MachineSpaceOffset correct = + MachineSpaceOffset{1, 2, DeviceType::GPU}; + MachineSpaceOffset result = + get_machine_space_offset(task, simv, coord, ms).value(); + CHECK(correct == result); + } + } + + SUBCASE("get_machine_space_offsets") { + std::unordered_set correct = { + MachineSpaceOffset{0, 0, DeviceType::GPU}, + MachineSpaceOffset{0, 2, DeviceType::GPU}, + MachineSpaceOffset{1, 0, DeviceType::GPU}, + MachineSpaceOffset{1, 2, DeviceType::GPU}}; + std::unordered_set result = + get_machine_space_offsets(task, simv, ms); + CHECK(correct == result); + } + } + } +} diff --git a/lib/pcg/test/src/test_machine_view.cc b/lib/pcg/test/src/test_machine_view.cc deleted file mode 100644 index 25c6e21b87..0000000000 --- a/lib/pcg/test/src/test_machine_view.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "pcg/machine_view.h" -#include "pcg/strided_rectangle.h" -#include "pcg/strided_rectangle_side.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MachineView general util functions") { - StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}, - StridedRectangleSide{num_points_t{10}, 2}}}; - gpu_id_t start(1); - MachineView mv{device_id_t{start}, rect}; - SUBCASE("num_dims") { - CHECK(num_dims(mv) == 2); - } - SUBCASE("num_devices") { - CHECK(num_devices(mv) == 7 * 10); - } - SUBCASE("get_device_type") { - CHECK(get_device_type(mv) == DeviceType::GPU); - } - } - - TEST_CASE("MachineView make_1d_machine_view - GPU") { - StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}}}; - device_id_t start_gpu{gpu_id_t{1}}; - MachineView gpu_mv{start_gpu, rect}; - - SUBCASE("make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride)") { - MachineView result = - make_1d_machine_view(start_gpu, device_id_t{gpu_id_t(1 + 7 * 5)}, 5); - MachineView correct = gpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(gpu_id_t start, num_points_t num_points, int " - "stride)") { - MachineView result = make_1d_machine_view(start_gpu, num_points_t{7}, 5); - MachineView correct = gpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(gpu_id_t start, side_size_t interval_size, " - "int stride)") { - MachineView result = make_1d_machine_view( - start_gpu, get_side_size(rect.sides.at(ff_dim_t{0})), 5); - MachineView correct = gpu_mv; - CHECK(result == correct); - } - } - - TEST_CASE("MachineView make_1d_machine_view - CPU") { - StridedRectangle rect{{StridedRectangleSide{num_points_t{11}, 4}}}; - device_id_t start_cpu{cpu_id_t{2}}; - MachineView cpu_mv{start_cpu, rect}; - - SUBCASE("make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride)") { - MachineView result = - make_1d_machine_view(start_cpu, device_id_t{cpu_id_t(2 + 11 * 4)}, 4); - MachineView correct = cpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(cpu_id_t start, num_points_t num_points, int " - "stride)") { - MachineView result = make_1d_machine_view(start_cpu, num_points_t{11}, 4); - MachineView correct = cpu_mv; - CHECK(result == correct); - } - SUBCASE("make_1d_machine_view(cpu_id_t start, side_size_t interval_size, " - "int stride)") { - MachineView result = make_1d_machine_view( - start_cpu, get_side_size(rect.sides.at(ff_dim_t{0})), 4); - MachineView correct = cpu_mv; - CHECK(result == correct); - } - } -} diff --git a/lib/pcg/test/src/test_strided_rectangle.cc b/lib/pcg/test/src/test_strided_rectangle.cc deleted file mode 100644 index ac6af9fa19..0000000000 --- a/lib/pcg/test/src/test_strided_rectangle.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "pcg/strided_rectangle.h" -#include "pcg/strided_rectangle_side.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("get_side_size(StridedRectangleSide)") { - StridedRectangleSide side{num_points_t{7}, 5}; - - CHECK(get_side_size(side) == side_size_t{7 * 5}); - } - TEST_CASE("strided_side_from_size_and_stride") { - StridedRectangleSide correct{num_points_t{10}, 3}; - StridedRectangleSide result = - strided_side_from_size_and_stride(side_size_t{10 * 3}, 3); - CHECK(result == correct); - } - - TEST_CASE("StridedRectangle - helper functions") { - - StridedRectangleSide s0{num_points_t{7}, 5}; - StridedRectangleSide s1{num_points_t{10}, 2}; - StridedRectangleSide s2{num_points_t{8}, 1}; - StridedRectangle rect{{s0, s1, s2}}; - - SUBCASE("get_num_dims") { - CHECK(get_num_dims(rect) == 3); - } - SUBCASE("get_num_points") { - CHECK(get_num_points(rect) == num_points_t{7 * 8 * 10}); - } - SUBCASE("get_side_at_idx") { - CHECK(get_side_at_idx(rect, ff_dim_t{0}) == s0); - CHECK(get_side_at_idx(rect, ff_dim_t{1}) == s1); - CHECK(get_side_at_idx(rect, ff_dim_t{2}) == s2); - } - } -} diff --git a/lib/utils/include/utils/containers/cartesian_product.h b/lib/utils/include/utils/containers/cartesian_product.h index bcba52113e..28d0fb118c 100644 --- a/lib/utils/include/utils/containers/cartesian_product.h +++ b/lib/utils/include/utils/containers/cartesian_product.h @@ -1,7 +1,6 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_CARTESIAN_PRODUCT_H -#include "utils/containers/vector_of.h" #include "utils/hash/vector.h" #include #include @@ -9,10 +8,10 @@ namespace FlexFlow { -template -std::unordered_set> - cartesian_product(std::vector> const &containers) { - std::unordered_set> result; +template +std::unordered_multiset> + cartesian_product(std::vector const &containers) { + std::unordered_multiset> result; std::function &, size_t)> recurse = [&](std::vector ¤t, size_t depth) { diff --git a/lib/utils/include/utils/containers/filter.h b/lib/utils/include/utils/containers/filter.h index fb8c703d2a..07f25dc348 100644 --- a/lib/utils/include/utils/containers/filter.h +++ b/lib/utils/include/utils/containers/filter.h @@ -44,6 +44,14 @@ std::map filter(std::map const &m, F const &f) { return result; } +template +std::unordered_multiset filter(std::unordered_multiset const &m, + F const &f) { + std::unordered_multiset result; + std::copy_if(m.cbegin(), m.cend(), std::inserter(result, result.begin()), f); + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/foldl.h b/lib/utils/include/utils/containers/foldl.h new file mode 100644 index 0000000000..16851d7d9b --- /dev/null +++ b/lib/utils/include/utils/containers/foldl.h @@ -0,0 +1,72 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H + +#include "utils/exception.h" +#include "utils/fmt/vector.h" +#include +#include +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Iteratively applies `func` to the elements of `c` from left to right. + * `init` is used as the starting value. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * int result = foldl(nums, 0, [](int a, int b) { return a + b; }); + * result -> ((((0+1)+2)+3)+4) = 10 + * + * @note + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldl + */ +template +T foldl(C const &c, T init, F func) { + T result = init; + for (auto const &elem : c) { + result = func(result, elem); + } + return result; +} + +/** + * @brief + * Applies `func` to the elements of `c` from left to right, accumulating the + * result. The first element of `c` is used as the starting point for the + * accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * int result = foldl1(nums, [](int a, int b) { return a + b; }); + * result -> (((1+2)+3)+4) = 10 + * + * @note + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:foldl1 + * @throws std::runtime_error if the container is empty. + */ +template +E foldl1(C const &c, F func) { + if (c.empty()) { + throw mk_runtime_error( + fmt::format("foldl1 received empty container: {}", c)); + } + std::optional result = std::nullopt; + + for (E const &e : c) { + if (!result.has_value()) { + result = e; + } else { + result = func(result.value(), e); + } + } + return result.value(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h index b7b30cbae4..9981948f47 100644 --- a/lib/utils/include/utils/containers/get_all_assignments.h +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -5,6 +5,7 @@ #include "utils/containers/keys.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_map_from_pairs.h" +#include "utils/containers/unordered_set_of.h" #include "utils/containers/vector_of.h" #include "utils/containers/zip.h" #include "utils/hash/unordered_map.h" @@ -30,7 +31,7 @@ std::unordered_set> get_all_assignments( ordered_keys, [&](K const &k) { return options_per_key.at(k); }); std::unordered_set> result = transform( - cartesian_product(ordered_value_option_sets), + unordered_set_of(cartesian_product(ordered_value_option_sets)), [&](std::vector const &chosen_values) { return unordered_map_from_pairs(zip(ordered_keys, chosen_values)); }); diff --git a/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h new file mode 100644 index 0000000000..ccdde0131a --- /dev/null +++ b/lib/utils/include/utils/containers/get_all_permutations_with_repetition.h @@ -0,0 +1,50 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_WITH_REPETITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_PERMUTATIONS_WITH_REPETITION_H + +#include +#include + +namespace FlexFlow { + +/** + * @brief For a given container `c` and integer `n`, return all possible vectors + * of size `n` that only contain (possibly duplicated) elements of `c`. + * @details + * https://en.wikipedia.org/wiki/Permutation#Permutations_with_repetition + **/ +template +std::unordered_multiset> + get_all_permutations_with_repetition(C const &container, int n) { + std::unordered_multiset> result; + + if (container.empty() || n == 0) { + return result; + } + + std::vector elements(std::begin(container), std::end(container)); + std::vector indices(n, 0); + + while (true) { + std::vector perm(n); + for (int i = 0; i < n; ++i) { + perm[i] = elements[indices[i]]; + } + result.insert(perm); + + int i = n - 1; + while (i != -1 && ++indices[i] == elements.size()) { + indices[i] = 0; + --i; + } + + if (i == -1) { + break; + } + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_from_keys_and_values.h b/lib/utils/include/utils/containers/map_from_keys_and_values.h new file mode 100644 index 0000000000..499965dc5e --- /dev/null +++ b/lib/utils/include/utils/containers/map_from_keys_and_values.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_KEYS_AND_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_KEYS_AND_VALUES_H + +#include "utils/containers/zip.h" +#include "utils/exception.h" +#include + +namespace FlexFlow { + +template +std::unordered_map + map_from_keys_and_values(std::vector const &keys, + std::vector const &values) { + if (keys.size() != values.size()) { + throw mk_runtime_error(fmt::format( + "recieved keys (of size {}) not matching values (of size {})", + keys.size(), + values.size())); + } + std::unordered_map result; + for (auto const &[k, v] : zip(keys, values)) { + result.insert({k, v}); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/range.h b/lib/utils/include/utils/containers/range.h new file mode 100644 index 0000000000..ff6b9f44ee --- /dev/null +++ b/lib/utils/include/utils/containers/range.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RANGE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RANGE_H + +#include + +namespace FlexFlow { + +std::vector range(int start, int end, int step = 1); +std::vector range(int end); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/replicate.h b/lib/utils/include/utils/containers/replicate.h new file mode 100644 index 0000000000..aa3d0a7e35 --- /dev/null +++ b/lib/utils/include/utils/containers/replicate.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REPLICATE_H + +#include + +namespace FlexFlow { + +template +std::vector replicate(int n, T const &element) { + return std::vector(n, element); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/scanl.h b/lib/utils/include/utils/containers/scanl.h new file mode 100644 index 0000000000..a30a9e1576 --- /dev/null +++ b/lib/utils/include/utils/containers/scanl.h @@ -0,0 +1,77 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SCANL_H + +#include +#include + +namespace FlexFlow { + +/** + * @brief + * Applies `op` to the elements of `c` from left to right, accumulating + * the intermediate results in a vector. `init` is used as the starting point + * for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl(nums, 0, [](int a, int b) {return a+b;}); + * result -> {0,1,3,6,10} + * + * @note + * Essentially a foldl which stores the intermediate results + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl + */ +template +std::vector scanl(C const &c, T init, F const &op) { + std::vector result; + + result.push_back(init); + for (auto const &elem : c) { + init = op(init, elem); + result.push_back(init); + } + + return result; +} + +/** + * @brief + * Applies `op` to the elements of `c` from left to right, accumulating + * the intermediate results in a vector. The first item of `c` is used as the + * starting point for the accumulation. + * + * @example + * std::vector nums = {1, 2, 3, 4}; + * auto result = scanl1(nums, [](int a, int b) {return a+b;}); + * result -> {1,3,6,10} + * + * @note + * Essentially a foldl1 which stores the intermediate results. + * For more information, see + * https://hackage.haskell.org/package/base-4.20.0.1/docs/Prelude.html#v:scanl1 + */ +template +std::vector scanl1(C const &c, F op) { + + if (c.empty()) { + return std::vector(); + } + + std::optional init = std::nullopt; + std::vector result; + + for (T const &elem : c) { + if (!init.has_value()) { + init = elem; + } else { + init = op(init.value(), elem); + } + result.push_back(init.value()); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index ef6a26c79a..a8a6a749cd 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -25,7 +25,7 @@ auto transform(req const &c, F const &f) template > std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; - for (auto const &e : v) { + for (In const &e : v) { result.insert(f(e)); } return result; @@ -35,7 +35,7 @@ template > std::unordered_multiset transform(std::unordered_multiset const &v, F const &f) { std::unordered_multiset result; - for (auto const &e : v) { + for (In const &e : v) { result.insert(f(e)); } return result; @@ -44,7 +44,7 @@ std::unordered_multiset transform(std::unordered_multiset const &v, template > std::set transform(std::set const &v, F const &f) { std::set result; - for (auto const &e : v) { + for (In const &e : v) { result.insert(f(e)); } return result; diff --git a/lib/utils/include/utils/containers/without_order.h b/lib/utils/include/utils/containers/without_order.h deleted file mode 100644 index 7199b2bd4a..0000000000 --- a/lib/utils/include/utils/containers/without_order.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_WITHOUT_ORDER_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_WITHOUT_ORDER_H - -#include - -namespace FlexFlow { - -template -std::unordered_multiset without_order(C const &c) { - return {c.cbegin(), c.cend()}; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/zip.h b/lib/utils/include/utils/containers/zip.h index 94182577ee..0f6dbed1d3 100644 --- a/lib/utils/include/utils/containers/zip.h +++ b/lib/utils/include/utils/containers/zip.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_H +#include #include #include @@ -16,6 +17,17 @@ std::vector> zip(std::vector const &l, return result; } +template +std::vector> zip(std::vector const &a, + std::vector const &b, + std::vector const &c) { + std::vector> result; + for (int i = 0; i < std::min({a.size(), b.size(), c.size()}); i++) { + result.push_back(std::make_tuple(a.at(i), b.at(i), c.at(i))); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index deb03a04d4..09dd3c5eab 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -23,7 +23,6 @@ struct formatter< ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - // } return formatter::format("{" + result + "}", ctx); } }; diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index 257545af1b..be347ec5ea 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -24,7 +24,6 @@ struct formatter< ::FlexFlow::join_strings(m.cbegin(), m.cend(), ", ", [](T const &t) { return fmt::to_string(t); }); - // } return formatter::format("{" + result + "}", ctx); } }; diff --git a/lib/utils/src/utils/containers/foldl.cc b/lib/utils/src/utils/containers/foldl.cc new file mode 100644 index 0000000000..a4c32e83cc --- /dev/null +++ b/lib/utils/src/utils/containers/foldl.cc @@ -0,0 +1 @@ +#include "utils/containers/foldl.h" diff --git a/lib/utils/src/utils/containers/get_all_assignments.cc b/lib/utils/src/utils/containers/get_all_assignments.cc index 3a7cf6377a..f920ba1c1a 100644 --- a/lib/utils/src/utils/containers/get_all_assignments.cc +++ b/lib/utils/src/utils/containers/get_all_assignments.cc @@ -1 +1,12 @@ #include "utils/containers/get_all_assignments.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template std::unordered_set> + get_all_assignments(std::unordered_map> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/range.cc b/lib/utils/src/utils/containers/range.cc new file mode 100644 index 0000000000..d3ebd1063b --- /dev/null +++ b/lib/utils/src/utils/containers/range.cc @@ -0,0 +1,26 @@ +#include "utils/containers/range.h" +#include + +namespace FlexFlow { + +std::vector range(int start, int end, int step) { + assert(step != 0); + + std::vector result; + if (step > 0) { + for (int i = start; i < end; i += step) { + result.push_back(i); + } + } else { + for (int i = start; i > end; i += step) { + result.push_back(i); + } + } + return result; +} + +std::vector range(int end) { + return range(0, end); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/replicate.cc b/lib/utils/src/utils/containers/replicate.cc new file mode 100644 index 0000000000..2fb2f079f6 --- /dev/null +++ b/lib/utils/src/utils/containers/replicate.cc @@ -0,0 +1 @@ +#include "utils/containers/replicate.h" diff --git a/lib/utils/src/utils/containers/scanl.cc b/lib/utils/src/utils/containers/scanl.cc new file mode 100644 index 0000000000..4f7ff78b9f --- /dev/null +++ b/lib/utils/src/utils/containers/scanl.cc @@ -0,0 +1 @@ +#include "utils/containers/scanl.h" diff --git a/lib/utils/src/utils/containers/without_order.cc b/lib/utils/src/utils/containers/without_order.cc deleted file mode 100644 index 3ef44b8044..0000000000 --- a/lib/utils/src/utils/containers/without_order.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/without_order.h" diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index dca500ced5..76b7fd0d31 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -119,7 +119,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == expected); } - TEST_CASE("without_order") { + TEST_CASE("unordered_multiset_of") { std::vector v = {1, 4, 6, 4, 6}; std::unordered_set expected = {1, 4, 6}; CHECK(unordered_set_of(v) == expected); diff --git a/lib/utils/test/src/utils/containers/cartesian_product.cc b/lib/utils/test/src/utils/containers/cartesian_product.cc index 42b8a10439..773d94c8d0 100644 --- a/lib/utils/test/src/utils/containers/cartesian_product.cc +++ b/lib/utils/test/src/utils/containers/cartesian_product.cc @@ -1,5 +1,5 @@ #include "utils/containers/cartesian_product.h" -#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/vector.h" #include #include @@ -11,51 +11,60 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("cartesian_product") { SUBCASE("empty") { - std::vector> containers = {}; - std::unordered_set> result = + std::vector> containers = {}; + std::unordered_multiset> result = cartesian_product(containers); - std::unordered_set> correct = {{}}; + std::unordered_multiset> correct = {{}}; CHECK(result == correct); } SUBCASE("single container, one element") { - std::vector> containers = {{1}}; - std::unordered_set> result = + std::vector> containers = {{1}}; + std::unordered_multiset> result = cartesian_product(containers); - std::unordered_set> correct = {{1}}; + std::unordered_multiset> correct = {{1}}; CHECK(result == correct); } SUBCASE("single container, multiple elements") { - std::vector> containers = {{1, 2, 3}}; - std::unordered_set> result = + std::vector> containers = {{1, 2, 3}}; + std::unordered_multiset> result = cartesian_product(containers); - std::unordered_set> correct = {{1}, {2}, {3}}; + std::unordered_multiset> correct = {{1}, {2}, {3}}; CHECK(result == correct); } SUBCASE("multiple containers, one element each") { - std::vector> containers = {{1}, {2}, {3}}; - std::unordered_set> result = + std::vector> containers = {{1}, {2}, {3}}; + std::unordered_multiset> result = cartesian_product(containers); - std::unordered_set> correct = {{1, 2, 3}}; + std::unordered_multiset> correct = {{1, 2, 3}}; CHECK(result == correct); } SUBCASE("multiple containers, multiple elements") { - std::vector> containers = {{1, 2}, {3, 4}}; - std::unordered_set> result = + std::vector> containers = {{1, 2}, {3, 4}}; + std::unordered_multiset> result = cartesian_product(containers); - std::unordered_set> correct = { + std::unordered_multiset> correct = { {1, 3}, {1, 4}, {2, 3}, {2, 4}}; CHECK(result == correct); } + SUBCASE("multiple containers, duplicate elements") { + std::vector> containers = {{1, 1}, {2, 3}}; + std::unordered_multiset> result = + cartesian_product(containers); + std::unordered_multiset> correct = { + {1, 2}, {1, 3}, {1, 3}, {1, 2}}; + CHECK(result == correct); + } + SUBCASE("1 empty container, 1 non-empty container") { - std::vector> containers = {{}, {2, 3}}; - std::unordered_set> result = + std::vector> containers = {{}, {2, 3}}; + std::unordered_multiset> result = cartesian_product(containers); - std::unordered_set> correct = {}; + std::unordered_multiset> correct = {}; CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/containers/filter.cc b/lib/utils/test/src/utils/containers/filter.cc index 770ad40375..9462d30024 100644 --- a/lib/utils/test/src/utils/containers/filter.cc +++ b/lib/utils/test/src/utils/containers/filter.cc @@ -2,6 +2,7 @@ #include "test/utils/doctest/fmt/map.h" #include "test/utils/doctest/fmt/set.h" #include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" #include "test/utils/rapidcheck.h" @@ -95,4 +96,13 @@ TEST_SUITE(FF_TEST_SUITE) { }; CHECK(result == correct); } + + TEST_CASE("filter(std::unordered_multiset, F)") { + std::unordered_multiset input = {1, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, 8}; + auto predicate = [](int x) { return x % 2 == 0; }; + + std::unordered_multiset result = filter(input, predicate); + std::unordered_multiset correct = {2, 2, 2, 4, 6, 8, 8}; + CHECK(result == correct); + } } diff --git a/lib/utils/test/src/utils/containers/foldl.cc b/lib/utils/test/src/utils/containers/foldl.cc new file mode 100644 index 0000000000..9ed9768a92 --- /dev/null +++ b/lib/utils/test/src/utils/containers/foldl.cc @@ -0,0 +1,47 @@ +#include "utils/containers/foldl.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("foldl") { + SUBCASE("product") { + std::vector container = {1, 2, 3, 4, 5}; + int result = + foldl(container, 1, [](int acc, int elem) { return acc * elem; }); + int correct = 120; + CHECK(result == correct); + } + + SUBCASE("string concat") { + std::vector container = {1, 2, 3, 4, 5}; + std::string result = + foldl(container, std::string(""), [](std::string acc, int elem) { + return acc + std::to_string(elem); + }); + std::string correct = "12345"; + CHECK(result == correct); + } + } + + TEST_CASE("foldl1") { + SUBCASE("product") { + std::vector container = {1, 2, 3, 4, 5}; + int result = + foldl1(container, [](int acc, int elem) { return acc * elem; }); + int correct = 120; + CHECK(result == correct); + } + + SUBCASE("string concat") { + std::vector container = {"1", "2", "3", "4", "5"}; + std::string result = + foldl1(container, + [](std::string acc, std::string elem) { return acc + elem; }); + std::string correct = "12345"; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc b/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc new file mode 100644 index 0000000000..f25bcf65b1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_all_permutations_with_repetition.cc @@ -0,0 +1,75 @@ +#include "utils/containers/get_all_permutations_with_repetition.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "test/utils/doctest/fmt/vector.h" +#include "utils/hash/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("get_all_permutations_with_repetition") { + SUBCASE("output vector has only one element") { + std::vector input = {1, 2, 3}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 1); + std::unordered_multiset> correct = { + {1}, + {2}, + {3}, + }; + + CHECK(result == correct); + } + + SUBCASE("input vector has only one element") { + std::vector input = {1}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 2); + std::unordered_multiset> correct = { + {1, 1}, + }; + + CHECK(result == correct); + } + + SUBCASE("input, output vectors have more than 1 element") { + std::vector input = {1, 2}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 3); + std::unordered_multiset> correct = { + {1, 1, 1}, + {1, 1, 2}, + {1, 2, 1}, + {1, 2, 2}, + {2, 1, 1}, + {2, 1, 2}, + {2, 2, 1}, + {2, 2, 2}, + }; + + CHECK(result == correct); + } + + SUBCASE("duplicate elements") { + std::vector input = {1, 2, 2}; + + std::unordered_multiset> result = + get_all_permutations_with_repetition(input, 2); + std::unordered_multiset> correct = {{1, 1}, + {1, 2}, + {1, 2}, + {2, 1}, + {2, 1}, + {2, 2}, + {2, 2}, + {2, 2}, + {2, 2}}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/range.cc b/lib/utils/test/src/utils/containers/range.cc new file mode 100644 index 0000000000..f115855323 --- /dev/null +++ b/lib/utils/test/src/utils/containers/range.cc @@ -0,0 +1,54 @@ +#include "utils/containers/range.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("range") { + SUBCASE("step=1") { + std::vector result = range(0, 5); + std::vector correct = {0, 1, 2, 3, 4}; + CHECK(result == correct); + } + + SUBCASE("step = 2") { + std::vector result = range(-2, 10, 2); + std::vector correct = {-2, 0, 2, 4, 6, 8}; + CHECK(result == correct); + } + + SUBCASE("step = -1") { + std::vector result = range(5, 0, -1); + std::vector correct = {5, 4, 3, 2, 1}; + CHECK(result == correct); + } + + SUBCASE("single argument") { + std::vector result = range(5); + std::vector correct = {0, 1, 2, 3, 4}; + CHECK(result == correct); + } + + SUBCASE("start = end") { + std::vector result = range(5, 5); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("start > end") { + std::vector result = range(5, 4); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("start < end, step < 0") { + std::vector result = range(0, 10, -1); + std::vector correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/replicate.cc b/lib/utils/test/src/utils/containers/replicate.cc new file mode 100644 index 0000000000..1c7845642e --- /dev/null +++ b/lib/utils/test/src/utils/containers/replicate.cc @@ -0,0 +1,25 @@ +#include "utils/containers/replicate.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("replicate") { + SUBCASE("ints") { + int x = 42; + std::vector result = replicate(5, x); + std::vector correct = {42, 42, 42, 42, 42}; + CHECK(result == correct); + } + SUBCASE("unordered_set") { + std::unordered_set x = {1.0, 1.5}; + std::vector> result = replicate(3, x); + std::vector> correct = { + {1.0, 1.5}, {1.0, 1.5}, {1.0, 1.5}}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/scanl.cc b/lib/utils/test/src/utils/containers/scanl.cc new file mode 100644 index 0000000000..d6da0ac0a1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/scanl.cc @@ -0,0 +1,71 @@ +#include "utils/containers/scanl.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("scanl") { + SUBCASE("sum") { + std::vector input = {1, 2, 3, 4}; + std::vector result = + scanl(input, 0, [](int a, int b) { return a + b; }); + std::vector correct = {0, 1, 3, 6, 10}; + CHECK(result == correct); + } + + SUBCASE("custom function") { + std::vector input = {1, 3, 1, 2}; + auto op = [](int a, int b) { return (a + 1) * (b + 1); }; + std::vector result = scanl(input, 1, op); + std::vector correct = {1, 4, 20, 42, 129}; + CHECK(result == correct); + } + + SUBCASE("heterogeneous types") { + std::vector input = {1, 2, 3, 4}; + auto op = [](std::string const &a, int b) { + return a + std::to_string(b); + }; + std::vector result = scanl(input, std::string(""), op); + std::vector correct = {"", "1", "12", "123", "1234"}; + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + std::vector result = + scanl(input, 0, [](int a, int b) { return a + b; }); + std::vector correct = {0}; + CHECK(result == correct); + } + } + + TEST_CASE("scanl1") { + SUBCASE("sum") { + std::vector input = {1, 2, 3, 4}; + std::vector result = + scanl1(input, [](int a, int b) { return a + b; }); + std::vector correct = {1, 3, 6, 10}; + CHECK(result == correct); + } + + SUBCASE("custom function") { + std::vector input = {1, 2, 5, 2}; + auto op = [](int a, int b) { return a * b + 1; }; + std::vector result = scanl1(input, op); + std::vector correct = {1, 3, 16, 33}; + CHECK(result == correct); + } + + SUBCASE("empty input") { + std::vector input = {}; + std::vector result = + scanl1(input, [](int a, int b) { return a + b; }); + std::vector correct = {}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/without_order.cc b/lib/utils/test/src/utils/containers/without_order.cc deleted file mode 100644 index b4c8663b14..0000000000 --- a/lib/utils/test/src/utils/containers/without_order.cc +++ /dev/null @@ -1,15 +0,0 @@ -#include "utils/containers/without_order.h" -#include "test/utils/doctest/fmt/unordered_multiset.h" -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("without_order") { - std::vector input = {1, 2, 3, 3, 2, 3}; - std::unordered_multiset result = without_order(input); - std::unordered_multiset correct = {1, 2, 3, 3, 2, 3}; - CHECK(result == correct); - } -} From 1d5140d5e98c18e73ce576673aa34022ad6d804f Mon Sep 17 00:00:00 2001 From: Marsella8 <45826022+Marsella8@users.noreply.github.com> Date: Thu, 10 Oct 2024 07:53:36 -0700 Subject: [PATCH 12/12] Utils: Refactor and Test Updates (#1464) * removed GraphInternal * minor changes * Added views.cc testing * Fixed containers testing * views.cc fix * test reorganizing * fmt * test rearranging * rearranging * graph fixes * undirected graph test fix * Updated docs * Test updates + bug fixes * minor fix * PR fixes * moved graph-testing to other PR * minor fixes * Remove unnecessary includes * Post-merge fixes * Small bugfix and remove some unnecessary includes --------- Co-authored-by: Pietro Max Marsella Co-authored-by: Colin Unger Co-authored-by: Colin Unger --- .../machine_mapping/machine_mapping.cc | 1 - lib/kernels/include/kernels/accessor.h | 1 - .../include/kernels/initializer_kernels.h | 1 - lib/kernels/test/src/test_dropout.cc | 1 - lib/pcg/include/pcg/optimizer_attrs.h | 2 +- lib/pcg/src/pcg/operator_task_space.cc | 3 +- lib/runtime/src/model.cc | 1 - .../src/task_invocation_compilation.cc | 1 - .../task_spec/task_invocation_compilation.cc | 1 - .../src/task_spec/tensor_args_format.cc | 2 +- .../src/substitutions/pcg_pattern_match.cc | 4 +- .../output_expr_to_result_sub_pcg_mapping.cc | 4 +- .../unlabelled/multidigraph_pattern_match.cc | 56 --- ...rge_bidicts.h => merge_disjoint_bidicts.h} | 18 +- lib/utils/include/utils/containers.decl.h | 88 ---- lib/utils/include/utils/containers.h | 233 ----------- .../utils/containers/are_all_distinct.h | 16 + .../include/utils/containers/compare_by.h | 15 + .../include/utils/containers/get_first.h | 21 - .../include/utils/containers/get_one_of.h | 20 + lib/utils/include/utils/containers/index_of.h | 25 ++ .../include/utils/containers/is_submapeq_of.h | 18 + .../include/utils/containers/is_subseteq_of.h | 10 +- .../utils/containers/is_superseteq_of.h | 17 + lib/utils/include/utils/containers/map_keys.h | 14 + lib/utils/include/utils/containers/maximum.h | 13 +- lib/utils/include/utils/containers/product.h | 3 + .../include/utils/containers/product_where.h | 26 ++ .../utils/containers/reversed_container.h | 85 ++++ lib/utils/include/utils/containers/subvec.h | 15 +- lib/utils/include/utils/containers/sum.h | 13 +- .../include/utils/containers/sum_where.h | 24 ++ .../include/utils/containers/value_all.h | 34 ++ .../include/utils/containers/vector_split.h | 10 +- lib/utils/include/utils/fmt/expected.h | 2 +- lib/utils/include/utils/fmt/map.h | 2 +- lib/utils/include/utils/fmt/multiset.h | 2 +- lib/utils/include/utils/fmt/optional.h | 12 +- lib/utils/include/utils/fmt/pair.h | 2 +- lib/utils/include/utils/fmt/set.h | 2 +- lib/utils/include/utils/fmt/unordered_map.h | 2 +- .../include/utils/fmt/unordered_multiset.h | 2 +- lib/utils/include/utils/fmt/unordered_set.h | 2 +- lib/utils/include/utils/fmt/variant.h | 2 +- lib/utils/include/utils/fmt/vector.h | 2 +- lib/utils/include/utils/graph/cow_ptr_t.h | 2 - .../instances/hashmap_undirected_graph.h | 0 lib/utils/include/utils/join_strings.h | 9 +- lib/utils/include/utils/rapidcheck/variant.h | 19 + lib/utils/include/utils/stack_string.h | 7 + lib/utils/include/utils/unique.h | 13 - lib/utils/include/utils/variant.h | 11 - lib/utils/src/containers.cc | 1 - .../utils/bidict/algorithms/merge_bidicts.cc | 1 - .../algorithms/merge_disjoint_bidicts.cc | 1 + .../src/utils/cli/cli_get_help_message.cc | 3 +- .../src/utils/containers/are_all_distinct.cc | 1 + lib/utils/src/utils/containers/compare_by.cc | 1 + lib/utils/src/utils/containers/get_first.cc | 1 - lib/utils/src/utils/containers/get_one_of.cc | 1 + lib/utils/src/utils/containers/index_of.cc | 1 + .../src/utils/containers/is_submapeq_of.cc | 1 + .../src/utils/containers/is_superseteq_of.cc | 1 + .../src/utils/containers/product_where.cc | 1 + .../utils/containers/reversed_container.cc | 1 + lib/utils/src/utils/containers/sum_where.cc | 1 + lib/utils/src/utils/containers/value_all.cc | 1 + lib/utils/src/utils/fmt/optional.cc | 8 + .../algorithms/find_isomorphism.cc | 4 +- .../get_cbc_decomposition.cc | 4 - .../is_complete_bipartite_digraph.cc | 1 - .../algorithms/get_imm_post_dominator.cc | 4 +- .../algorithms/find_isomorphism.cc | 4 +- .../algorithms/find_isomorphisms.cc | 2 +- lib/utils/src/utils/rapidcheck/variant.cc | 10 + .../include/test/utils/doctest/fmt/optional.h | 5 + .../src/test/utils/doctest/fmt/optional.cc | 8 + lib/utils/test/src/test_algorithms.cc | 246 ----------- lib/utils/test/src/test_containers.cc | 393 ------------------ lib/utils/test/src/test_multidigraph.cc | 94 ----- lib/utils/test/src/test_random_utils.cc | 67 --- lib/utils/test/src/test_type_index.cc | 35 -- lib/utils/test/src/test_undirected_graph.cc | 62 --- .../algorithms/merge_disjoint_bidicts.cc | 42 ++ lib/utils/test/src/utils/bidict/bidict.cc | 36 +- lib/utils/test/src/utils/containers/all_of.cc | 13 + .../src/utils/containers/are_all_distinct.cc | 24 ++ .../test/src/utils/containers/are_disjoint.cc | 31 ++ .../test/src/utils/containers/compare_by.cc | 19 + .../test/src/utils/containers/contains.cc | 13 + lib/utils/test/src/utils/containers/count.cc | 13 + .../test/src/utils/containers/enumerate.cc | 25 +- .../test/src/utils/containers/filter_keys.cc | 18 + lib/utils/test/src/utils/containers/find.cc | 54 +++ .../test/src/utils/containers/flatmap.cc | 34 ++ .../test/src/utils/containers/get_one_of.cc | 19 + .../test/src/utils/containers/index_of.cc | 24 ++ .../src/utils/containers/is_submapeq_of.cc | 40 ++ .../src/utils/containers/is_subseteq_of.cc | 16 + .../src/utils/containers/is_superseteq_of.cc | 25 ++ lib/utils/test/src/utils/containers/keys.cc | 18 + .../test/src/utils/containers/map_keys.cc | 26 ++ .../test/src/utils/containers/map_values.cc | 17 + .../test/src/utils/containers/maximum.cc | 53 +-- .../test/src/utils/containers/merge_maps.cc | 30 ++ .../test/src/utils/containers/product.cc | 32 ++ .../src/utils/containers/product_where.cc | 34 ++ .../src/utils/containers/restrict_keys.cc | 17 + .../test/src/utils/containers/set_union.cc | 16 + .../test/src/utils/containers/sorted_by.cc | 35 ++ lib/utils/test/src/utils/containers/subvec.cc | 62 +++ .../test/src/utils/containers/sum_where.cc | 36 ++ .../test/src/utils/containers/value_all.cc | 24 ++ lib/utils/test/src/utils/containers/values.cc | 19 + .../test/src/utils/containers/vector_split.cc | 39 ++ .../deduplicated_priority_queue.cc} | 2 + .../disjoint_set.cc} | 19 +- .../{test_dot_file.cc => utils/dot_file.cc} | 0 .../utils/graph/multidigraph/multidigraph.cc | 171 ++++++++ lib/utils/test/src/utils/hash/map.cc | 20 + lib/utils/test/src/utils/hash/set.cc | 20 + lib/utils/test/src/utils/hash/tuple.cc | 21 + .../hash/unordered_map.cc} | 8 +- .../test/src/utils/hash/unordered_set.cc | 20 + lib/utils/test/src/utils/hash/vector.cc | 20 + lib/utils/test/src/utils/join_strings.cc | 47 +++ lib/utils/test/src/utils/random_utils.cc | 60 +++ .../test/src/utils/rapidcheck/variant.cc | 13 + .../record_formatter.cc} | 0 .../{test_sequence.cc => utils/sequence.cc} | 0 .../{test_stack_map.cc => utils/stack_map.cc} | 15 +- .../stack_string.cc} | 6 +- .../stack_vector.cc} | 31 +- .../src/{test_tuple.cc => utils/tuple.cc} | 27 +- lib/utils/test/src/utils/type_index.cc | 34 ++ .../src/{test_variant.cc => utils/variant.cc} | 11 +- .../src/{test_vector.cc => utils/vector.cc} | 4 + 137 files changed, 1791 insertions(+), 1515 deletions(-) delete mode 100644 lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc rename lib/utils/include/utils/bidict/algorithms/{merge_bidicts.h => merge_disjoint_bidicts.h} (52%) delete mode 100644 lib/utils/include/utils/containers.decl.h delete mode 100644 lib/utils/include/utils/containers.h create mode 100644 lib/utils/include/utils/containers/are_all_distinct.h create mode 100644 lib/utils/include/utils/containers/compare_by.h delete mode 100644 lib/utils/include/utils/containers/get_first.h create mode 100644 lib/utils/include/utils/containers/get_one_of.h create mode 100644 lib/utils/include/utils/containers/index_of.h create mode 100644 lib/utils/include/utils/containers/is_submapeq_of.h create mode 100644 lib/utils/include/utils/containers/is_superseteq_of.h create mode 100644 lib/utils/include/utils/containers/product_where.h create mode 100644 lib/utils/include/utils/containers/reversed_container.h create mode 100644 lib/utils/include/utils/containers/sum_where.h create mode 100644 lib/utils/include/utils/containers/value_all.h rename lib/utils/{src => include}/utils/graph/instances/hashmap_undirected_graph.h (100%) create mode 100644 lib/utils/include/utils/rapidcheck/variant.h delete mode 100644 lib/utils/include/utils/unique.h delete mode 100644 lib/utils/src/containers.cc delete mode 100644 lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc create mode 100644 lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc create mode 100644 lib/utils/src/utils/containers/are_all_distinct.cc create mode 100644 lib/utils/src/utils/containers/compare_by.cc delete mode 100644 lib/utils/src/utils/containers/get_first.cc create mode 100644 lib/utils/src/utils/containers/get_one_of.cc create mode 100644 lib/utils/src/utils/containers/index_of.cc create mode 100644 lib/utils/src/utils/containers/is_submapeq_of.cc create mode 100644 lib/utils/src/utils/containers/is_superseteq_of.cc create mode 100644 lib/utils/src/utils/containers/product_where.cc create mode 100644 lib/utils/src/utils/containers/reversed_container.cc create mode 100644 lib/utils/src/utils/containers/sum_where.cc create mode 100644 lib/utils/src/utils/containers/value_all.cc create mode 100644 lib/utils/src/utils/rapidcheck/variant.cc delete mode 100644 lib/utils/test/src/test_algorithms.cc delete mode 100644 lib/utils/test/src/test_containers.cc delete mode 100644 lib/utils/test/src/test_multidigraph.cc delete mode 100644 lib/utils/test/src/test_random_utils.cc delete mode 100644 lib/utils/test/src/test_type_index.cc delete mode 100644 lib/utils/test/src/test_undirected_graph.cc create mode 100644 lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc create mode 100644 lib/utils/test/src/utils/containers/all_of.cc create mode 100644 lib/utils/test/src/utils/containers/are_all_distinct.cc create mode 100644 lib/utils/test/src/utils/containers/are_disjoint.cc create mode 100644 lib/utils/test/src/utils/containers/compare_by.cc create mode 100644 lib/utils/test/src/utils/containers/contains.cc create mode 100644 lib/utils/test/src/utils/containers/count.cc create mode 100644 lib/utils/test/src/utils/containers/filter_keys.cc create mode 100644 lib/utils/test/src/utils/containers/find.cc create mode 100644 lib/utils/test/src/utils/containers/get_one_of.cc create mode 100644 lib/utils/test/src/utils/containers/index_of.cc create mode 100644 lib/utils/test/src/utils/containers/is_submapeq_of.cc create mode 100644 lib/utils/test/src/utils/containers/is_subseteq_of.cc create mode 100644 lib/utils/test/src/utils/containers/is_superseteq_of.cc create mode 100644 lib/utils/test/src/utils/containers/keys.cc create mode 100644 lib/utils/test/src/utils/containers/map_keys.cc create mode 100644 lib/utils/test/src/utils/containers/map_values.cc create mode 100644 lib/utils/test/src/utils/containers/merge_maps.cc create mode 100644 lib/utils/test/src/utils/containers/product.cc create mode 100644 lib/utils/test/src/utils/containers/product_where.cc create mode 100644 lib/utils/test/src/utils/containers/restrict_keys.cc create mode 100644 lib/utils/test/src/utils/containers/set_union.cc create mode 100644 lib/utils/test/src/utils/containers/sorted_by.cc create mode 100644 lib/utils/test/src/utils/containers/subvec.cc create mode 100644 lib/utils/test/src/utils/containers/sum_where.cc create mode 100644 lib/utils/test/src/utils/containers/value_all.cc create mode 100644 lib/utils/test/src/utils/containers/values.cc create mode 100644 lib/utils/test/src/utils/containers/vector_split.cc rename lib/utils/test/src/{test_deduplicated_priority_queue.cc => utils/deduplicated_priority_queue.cc} (96%) rename lib/utils/test/src/{test_disjoint_set.cc => utils/disjoint_set.cc} (73%) rename lib/utils/test/src/{test_dot_file.cc => utils/dot_file.cc} (100%) create mode 100644 lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc create mode 100644 lib/utils/test/src/utils/hash/map.cc create mode 100644 lib/utils/test/src/utils/hash/set.cc create mode 100644 lib/utils/test/src/utils/hash/tuple.cc rename lib/utils/test/src/{test_hash.cc => utils/hash/unordered_map.cc} (69%) create mode 100644 lib/utils/test/src/utils/hash/unordered_set.cc create mode 100644 lib/utils/test/src/utils/hash/vector.cc create mode 100644 lib/utils/test/src/utils/join_strings.cc create mode 100644 lib/utils/test/src/utils/random_utils.cc create mode 100644 lib/utils/test/src/utils/rapidcheck/variant.cc rename lib/utils/test/src/{test_format.cc => utils/record_formatter.cc} (100%) rename lib/utils/test/src/{test_sequence.cc => utils/sequence.cc} (100%) rename lib/utils/test/src/{test_stack_map.cc => utils/stack_map.cc} (73%) rename lib/utils/test/src/{test_stack_string.cc => utils/stack_string.cc} (95%) rename lib/utils/test/src/{test_stack_vector.cc => utils/stack_vector.cc} (68%) rename lib/utils/test/src/{test_tuple.cc => utils/tuple.cc} (73%) create mode 100644 lib/utils/test/src/utils/type_index.cc rename lib/utils/test/src/{test_variant.cc => utils/variant.cc} (91%) rename lib/utils/test/src/{test_vector.cc => utils/vector.cc} (91%) diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index 6f350d8773..57e82684e9 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -1,5 +1,4 @@ #include "compiler/machine_mapping/machine_mapping.h" -#include "utils/containers.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/keys.h" #include "utils/containers/merge_maps.h" diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 5fbcd91a06..39da65c3be 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -7,7 +7,6 @@ #include "op-attrs/datatype.h" #include "utils/exception.h" #include "utils/required.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/include/kernels/initializer_kernels.h b/lib/kernels/include/kernels/initializer_kernels.h index 52609a303f..9840e457e6 100644 --- a/lib/kernels/include/kernels/initializer_kernels.h +++ b/lib/kernels/include/kernels/initializer_kernels.h @@ -4,7 +4,6 @@ #include "accessor.h" #include "kernels/cpu.h" #include "op-attrs/datatype_value.dtg.h" -#include "utils/variant.h" namespace FlexFlow { diff --git a/lib/kernels/test/src/test_dropout.cc b/lib/kernels/test/src/test_dropout.cc index 981bc611d8..81f3c7183a 100644 --- a/lib/kernels/test/src/test_dropout.cc +++ b/lib/kernels/test/src/test_dropout.cc @@ -1,7 +1,6 @@ #include "doctest/doctest.h" #include "kernels/dropout_kernels.h" #include "test_utils.h" -#include "utils/containers.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { diff --git a/lib/pcg/include/pcg/optimizer_attrs.h b/lib/pcg/include/pcg/optimizer_attrs.h index 4bac74b999..3e787503d6 100644 --- a/lib/pcg/include/pcg/optimizer_attrs.h +++ b/lib/pcg/include/pcg/optimizer_attrs.h @@ -3,7 +3,7 @@ #include "pcg/optimizers/adam_optimizer_attrs.h" #include "pcg/optimizers/sgd_optimizer_attrs.h" -#include "utils/variant.h" +#include namespace FlexFlow { diff --git a/lib/pcg/src/pcg/operator_task_space.cc b/lib/pcg/src/pcg/operator_task_space.cc index 02522ae411..2538cb4ea0 100644 --- a/lib/pcg/src/pcg/operator_task_space.cc +++ b/lib/pcg/src/pcg/operator_task_space.cc @@ -5,6 +5,7 @@ #include "utils/containers/range.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" +#include "utils/fmt/unordered_set.h" namespace FlexFlow { @@ -25,7 +26,7 @@ std::unordered_set TaskSpaceCoordinate get_task_space_maximum_coordinate(OperatorTaskSpace const &task) { - return maximum(get_task_space_coordinates(task)).value(); + return maximum(get_task_space_coordinates(task)); } size_t num_dims(OperatorTaskSpace const &task) { diff --git a/lib/runtime/src/model.cc b/lib/runtime/src/model.cc index a655bcb050..22f0f2e98d 100644 --- a/lib/runtime/src/model.cc +++ b/lib/runtime/src/model.cc @@ -27,7 +27,6 @@ #include "parallel_tensor_mapping.h" #include "task_spec/task_argument_accessor.h" #include "test_utils.h" -#include "utils/containers.h" #include "utils/random_utils.h" #include #include diff --git a/lib/runtime/src/task_invocation_compilation.cc b/lib/runtime/src/task_invocation_compilation.cc index bfeb2be6d4..ff281370e7 100644 --- a/lib/runtime/src/task_invocation_compilation.cc +++ b/lib/runtime/src/task_invocation_compilation.cc @@ -1,5 +1,4 @@ #include "task_invocation_compilation.h" -#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/runtime/src/task_spec/task_invocation_compilation.cc b/lib/runtime/src/task_spec/task_invocation_compilation.cc index bfeb2be6d4..ff281370e7 100644 --- a/lib/runtime/src/task_spec/task_invocation_compilation.cc +++ b/lib/runtime/src/task_spec/task_invocation_compilation.cc @@ -1,5 +1,4 @@ #include "task_invocation_compilation.h" -#include "utils/containers.h" namespace FlexFlow { diff --git a/lib/runtime/src/task_spec/tensor_args_format.cc b/lib/runtime/src/task_spec/tensor_args_format.cc index 22bd021996..8fb7fc029b 100644 --- a/lib/runtime/src/task_spec/tensor_args_format.cc +++ b/lib/runtime/src/task_spec/tensor_args_format.cc @@ -1,5 +1,5 @@ #include "tensor_args_format.h" -#include "utils/containers.h" +#include "utils/containers/flatmap.h" namespace FlexFlow { diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index f1f4e31d57..b701be65cf 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -2,7 +2,7 @@ #include "substitutions/pcg_pattern.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" -#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/containers/map_values.h" #include "utils/containers/zip.h" @@ -27,7 +27,7 @@ bidict bidict_from_keys_and_values(pattern_node_outputs, matched_layer_output_tensors); - result = merge_bidicts(result, mapping); + result = merge_disjoint_bidicts(result, mapping); } return result; diff --git a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc index 083334f0db..22e6a9f333 100644 --- a/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc +++ b/lib/substitutions/src/substitutions/substitution_internal/output_expr_to_result_sub_pcg_mapping.cc @@ -2,7 +2,7 @@ #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" -#include "utils/bidict/algorithms/merge_bidicts.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" namespace FlexFlow { @@ -23,7 +23,7 @@ bidict mapping_for_layer = bidict_from_keys_and_values( layer_outputs, output_graph_expr_outputs); - result = merge_bidicts(result, mapping_for_layer); + result = merge_disjoint_bidicts(result, mapping_for_layer); } return result; diff --git a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc b/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc deleted file mode 100644 index 8ce60fab4f..0000000000 --- a/lib/substitutions/src/substitutions/unlabelled/multidigraph_pattern_match.cc +++ /dev/null @@ -1,56 +0,0 @@ -#include "substitutions/unlabelled/multidigraph_pattern_match.h" -// #include "substitutions/unlabelled/edge_splits.h" -// #include "substitutions/unlabelled/pattern_edge.h" -#include "utils/containers.h" - -namespace FlexFlow { - -// MultiDiGraphPatternMatch empty_multidigraph_pattern_match() { -// return MultiDiGraphPatternMatch{ -// bidict{}, -// bidict{}, -// }; -// } - -// std::optional -// unsplit_matches(MultiDiGraphPatternMatch const &prefix, -// MultiDiGraphPatternMatch const &postfix, -// UnlabelledPatternEdgeSplits const &edge_splits) { -// -// MultiDiGraphPatternMatch result = empty_multidigraph_pattern_match(); -// -// std::unordered_set handled; -// for (auto const &coi : as_closed_output_input_tuples(edge_splits)) { -// ClosedPatternEdge closed_edge = std::get(coi); -// OutputPatternEdge output_edge = std::get(coi); -// InputPatternEdge input_edge = std::get(coi); -// -// handled.insert(pattern_edge_from_output_edge(output_edge)); -// handled.insert(pattern_edge_from_input_edge(input_edge)); -// -// OpenMultiDiEdge output_graph_edge = -// prefix.edge_assignment.at_l(pattern_edge_from_output_edge(output_edge)); -// OpenMultiDiEdge input_graph_edge = -// postfix.edge_assignment.at_l(pattern_edge_from_input_edge(input_edge)); -// if (output_graph_edge == input_graph_edge) { -// result.edge_assignment.equate(pattern_edge_from_closed_edge(closed_edge), -// output_graph_edge); -// } else { -// return std::nullopt; -// } -// } -// -// for (auto const &kv : -// merge_maps(prefix.edge_assignment, postfix.edge_assignment)) { -// if (!contains(handled, kv.first)) { -// result.edge_assignment.equate(kv.first, kv.second); -// } -// } -// -// result.node_assignment = -// merge_maps(prefix.node_assignment, postfix.node_assignment); -// -// return result; -// } - -} // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h similarity index 52% rename from lib/utils/include/utils/bidict/algorithms/merge_bidicts.h rename to lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index d388e35d75..97e7334c26 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,17 +1,25 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_BIDICTS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" #include "utils/bidict/bidict.h" #include "utils/containers/are_disjoint.h" +#include "utils/exception.h" namespace FlexFlow { template -bidict merge_bidicts(bidict const &lhs, bidict const &rhs) { - assert(are_disjoint(left_entries(lhs), left_entries(rhs))); - assert(are_disjoint(right_entries(lhs), right_entries(rhs))); +bidict merge_disjoint_bidicts(bidict const &lhs, + bidict const &rhs) { + if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); + } + if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); + } bidict result; for (auto const &kv : lhs) { diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h deleted file mode 100644 index cb652a9e69..0000000000 --- a/lib/utils/include/utils/containers.decl.h +++ /dev/null @@ -1,88 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H - -#include "utils/bidict/bidict.h" -#include "utils/containers/get_element_type.h" -#include "utils/required_core.h" -#include "utils/type_traits_core.h" -#include -#include -#include - -namespace FlexFlow { - -template -Element sum_where(Container const &container, ConditionF const &condition); - -template -Element product_where(Container const &container, ConditionF const &condition); - -template -bool contains_l(bidict const &m, K const &k); - -template -bool contains_r(bidict const &m, V const &v); - -template -std::unordered_map filter_values(std::unordered_map const &m, - F const &f); - -template -std::optional index_of(Container const &c, Element const &e); - -template -std::unordered_map restrict_keys(std::unordered_map const &m, - std::unordered_set const &mask); - -template -std::optional at_idx(std::vector const &v, size_t idx); - -template -std::function lookup_in(std::unordered_map const &m); - -template -std::function lookup_in_l(bidict const &m); - -template -std::function lookup_in_r(bidict const &m); - -template -bool is_supserseteq_of(std::unordered_set const &l, - std::unordered_set const &r); - -template -std::unordered_set - map_over_unordered_set(std::function const &f, - std::unordered_set const &input); - -template -std::optional maybe_get_only(C const &c); - -template -std::optional optional_all_of(Container const &, Function const &); - -template -std::function compare_by(F const &f); - -template -T reversed(T const &t); - -template -std::vector value_all(std::vector> const &v); - -template -std::unordered_set value_all(std::unordered_set> const &v); - -template -struct reversed_container_t; - -template -reversed_container_t reversed_container(C const &c); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h deleted file mode 100644 index 0e3b1fc0bd..0000000000 --- a/lib/utils/include/utils/containers.h +++ /dev/null @@ -1,233 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_INL - -#include "containers.decl.h" -#include "required_core.h" -#include "type_traits_core.h" -#include "utils/bidict/bidict.h" -#include "utils/containers/contains.h" -#include "utils/containers/extend.h" -#include "utils/containers/extend_vector.h" -#include "utils/containers/filter.h" -#include "utils/containers/intersection.h" -#include "utils/containers/is_subseteq_of.h" -#include "utils/containers/sorted.h" -#include "utils/containers/transform.h" -#include "utils/containers/vector_transform.h" -#include "utils/exception.h" -#include "utils/hash/pair.h" -#include "utils/type_traits.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace FlexFlow { - -template -Element sum_where(Container const &container, ConditionF const &condition) { - Element result = 0; - for (Element const &element : container) { - if (condition(element)) { - result += element; - } - } - return result; -} - -template -Element product_where(Container const &container, ConditionF const &condition) { - Element result = 1; - for (Element const &element : container) { - if (condition(element)) { - result *= element; - } - } - return result; -} - -template -bool contains_l(bidict const &m, K const &k) { - return m.contains_l(k); -} - -template -bool contains_r(bidict const &m, V const &v) { - return m.contains_r(v); -} - -template -bool is_submap(std::unordered_map const &m, - std::unordered_map const &sub) { - return restrict_keys(m, keys(sub)) == sub; -} - -template -std::optional index_of(Container const &c, Element const &e) { - auto it = std::find(c.cbegin(), c.cend(), e); - if (it == c.cend()) { - return std::nullopt; - } else { - return std::distance(c.cbegin(), it); - } -} - -template -std::function lookup_in(std::unordered_map const &m) { - return [&m](K const &k) -> V { return m.at(k); }; -} - -template -std::function lookup_in_l(bidict const &m) { - return [&m](L const &l) -> R { return m.at_l(l); }; -} - -template -std::function lookup_in_r(bidict const &m) { - return [&m](R const &r) -> L { return m.at_r(r); }; -} - -template -bool is_supserseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { - return is_subseteq_of(r, l); -} - -template -std::unordered_set - map_over_unordered_set(std::function const &f, - std::unordered_set const &input) { - std::unordered_set result; - std::transform( - input.cbegin(), input.cend(), std::inserter(result, result.begin()), f); - return result; -} - -template -std::optional optional_all_of(Container const &container, - Function const &func) { - for (auto const &element : container) { - std::optional condition = func(element); - if (!condition.has_value()) { - return std::nullopt; - } - - if (!condition.value()) { - return false; - } - } - return true; -} - -template -std::function compare_by(F const &f) { - return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; -} - -template -std::vector value_all(std::vector> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); - }); - }); -} - -template -std::unordered_set value_all(std::unordered_set> const &v) { - return transform(v, [](std::optional const &element) { - return unwrap(element, [] { - throw mk_runtime_error( - "Encountered element without value in call to value_all"); - }); - }); -} - -template -struct reversed_container_t { - reversed_container_t() = delete; - reversed_container_t(C const &c) : container(c) {} - - reversed_container_t(reversed_container_t const &) = delete; - reversed_container_t(reversed_container_t &&) = delete; - reversed_container_t &operator=(reversed_container_t const &) = delete; - reversed_container_t &operator=(reversed_container_t &&) = delete; - - using iterator = typename C::reverse_iterator; - using const_iterator = typename C::const_reverse_iterator; - using reverse_iterator = typename C::iterator; - using const_reverse_iterator = typename C::const_iterator; - using value_type = typename C::value_type; - using pointer = typename C::pointer; - using const_pointer = typename C::const_pointer; - using reference = typename C::reference; - using const_reference = typename C::const_reference; - - iterator begin() { - return this->container.rend(); - } - - iterator end() { - return this->container.rbegin(); - } - - const_iterator cbegin() const { - return this->container.crend(); - } - - const_iterator cend() const { - return this->container.crbegin(); - } - - const_iterator begin() const { - return this->cbegin(); - } - - const_iterator end() const { - return this->cend(); - } - - reverse_iterator rbegin() { - return this->container.begin(); - } - - reverse_iterator rend() { - return this->container.end(); - } - - const_reverse_iterator crbegin() const { - return this->container.cbegin(); - } - - const_reverse_iterator crend() const { - return this->container.cend(); - } - - const_reverse_iterator rbegin() const { - return this->crbegin(); - } - - const_reverse_iterator rend() const { - return this->crend(); - } - -private: - C const &container; -}; - -template -reversed_container_t reversed_container(C const &c) { - return reversed_container_t(c); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/are_all_distinct.h b/lib/utils/include/utils/containers/are_all_distinct.h new file mode 100644 index 0000000000..d02845ba16 --- /dev/null +++ b/lib/utils/include/utils/containers/are_all_distinct.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_DISTINCT_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_ARE_ALL_DISTINCT_H + +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/unordered_set_of.h" + +namespace FlexFlow { + +template +bool are_all_distinct(C const &c) { + return unordered_set_of(c).size() == unordered_multiset_of(c).size(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/compare_by.h b/lib/utils/include/utils/containers/compare_by.h new file mode 100644 index 0000000000..d6cb7f48cd --- /dev/null +++ b/lib/utils/include/utils/containers/compare_by.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_COMPARE_BY_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_COMPARE_BY_H + +#include + +namespace FlexFlow { + +template +std::function compare_by(F const &f) { + return [=](T const &lhs, T const &rhs) { return f(lhs) < f(rhs); }; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_first.h b/lib/utils/include/utils/containers/get_first.h deleted file mode 100644 index a616c44c20..0000000000 --- a/lib/utils/include/utils/containers/get_first.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_FIRST_H - -#include -#include - -namespace FlexFlow { - -template -T get_first(std::unordered_set const &s) { - return *s.cbegin(); -} - -template -T get_first(std::set const &s) { - return *s.cbegin(); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/containers/get_one_of.h b/lib/utils/include/utils/containers/get_one_of.h new file mode 100644 index 0000000000..47c46fb1d6 --- /dev/null +++ b/lib/utils/include/utils/containers/get_one_of.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONE_OF_H + +#include "utils/exception.h" +#include "utils/fmt/unordered_set.h" +#include +namespace FlexFlow { + +template +T get_one_of(std::unordered_set const &s) { + if (s.empty()) { + throw mk_runtime_error(fmt::format( + "get_one_of expected non-empty container but receieved {}", s)); + } + return *s.cbegin(); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/index_of.h b/lib/utils/include/utils/containers/index_of.h new file mode 100644 index 0000000000..1792490d0c --- /dev/null +++ b/lib/utils/include/utils/containers/index_of.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INDEX_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INDEX_OF_H + +#include +#include + +namespace FlexFlow { + +/** + * @details If multiple `e` are present within the container, the function + * returns the index of the first appearance + **/ +template +std::optional index_of(Container const &c, Element const &e) { + auto it = std::find(c.cbegin(), c.cend(), e); + if (it == c.cend()) { + return std::nullopt; + } else { + return std::distance(c.cbegin(), it); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/is_submapeq_of.h b/lib/utils/include/utils/containers/is_submapeq_of.h new file mode 100644 index 0000000000..03cb5ccd78 --- /dev/null +++ b/lib/utils/include/utils/containers/is_submapeq_of.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H + +#include "utils/containers/keys.h" +#include "utils/containers/restrict_keys.h" +#include + +namespace FlexFlow { + +template +bool is_submapeq_of(std::unordered_map const &sub, + std::unordered_map const &m) { + return restrict_keys(m, keys(sub)) == sub; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/is_subseteq_of.h b/lib/utils/include/utils/containers/is_subseteq_of.h index 26543ca75b..705c092962 100644 --- a/lib/utils/include/utils/containers/is_subseteq_of.h +++ b/lib/utils/include/utils/containers/is_subseteq_of.h @@ -7,14 +7,14 @@ namespace FlexFlow { template -bool is_subseteq_of(std::unordered_set const &l, - std::unordered_set const &r) { - if (l.size() > r.size()) { +bool is_subseteq_of(std::unordered_set const &sub, + std::unordered_set const &super) { + if (sub.size() > super.size()) { return false; } - for (auto const &ll : l) { - if (!contains(r, ll)) { + for (auto const &s : sub) { + if (!contains(super, s)) { return false; } } diff --git a/lib/utils/include/utils/containers/is_superseteq_of.h b/lib/utils/include/utils/containers/is_superseteq_of.h new file mode 100644 index 0000000000..23b16d92f9 --- /dev/null +++ b/lib/utils/include/utils/containers/is_superseteq_of.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUPERSETEQ_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUPERSETEQ_OF_H + +#include "utils/containers/is_subseteq_of.h" +#include + +namespace FlexFlow { + +template +bool is_superseteq_of(std::unordered_set const &super, + std::unordered_set const &sub) { + return is_subseteq_of(sub, super); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index e252333e93..4e5352748d 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -1,21 +1,35 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H +#include "utils/containers/keys.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/exception.h" #include #include namespace FlexFlow { +/** + * @brief Applies the given function to all the keys within the given map and + * returns the updated map. + */ template > std::unordered_map map_keys(std::unordered_map const &m, F const &f) { + std::unordered_map result; for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); } + if (keys(m).size() != keys(result).size()) { + throw mk_runtime_error( + "keys passed to map_keys must be transformed into distinct keys"); + } + return result; } diff --git a/lib/utils/include/utils/containers/maximum.h b/lib/utils/include/utils/containers/maximum.h index 634bb61bc1..b3d6d0c6d7 100644 --- a/lib/utils/include/utils/containers/maximum.h +++ b/lib/utils/include/utils/containers/maximum.h @@ -1,18 +1,19 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAXIMUM_H +#include "utils/exception.h" #include -#include namespace FlexFlow { -template -std::optional maximum(C const &v) { - if (v.empty()) { - return std::nullopt; +template +typename C::value_type maximum(C const &c) { + if (c.empty()) { + throw mk_runtime_error( + fmt::format("maximum expected non-empty container but received {}", c)); } - return *std::max_element(std::cbegin(v), std::cend(v)); + return *std::max_element(c.begin(), c.end()); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/product.h b/lib/utils/include/utils/containers/product.h index 52ff36e790..af04edcb81 100644 --- a/lib/utils/include/utils/containers/product.h +++ b/lib/utils/include/utils/containers/product.h @@ -5,6 +5,9 @@ namespace FlexFlow { +/** + * @details An empty container vacuously has product 1 + **/ template Element product(Container const &container) { Element result = 1; diff --git a/lib/utils/include/utils/containers/product_where.h b/lib/utils/include/utils/containers/product_where.h new file mode 100644 index 0000000000..51af47c2fa --- /dev/null +++ b/lib/utils/include/utils/containers/product_where.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PRODUCT_WHERE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_PRODUCT_WHERE_H + +#include + +namespace FlexFlow { + +/** + * @details An empty container vacuously has product 1 + **/ +template +Element product_where(Container const &container, ConditionF const &condition) { + Element result = 1; + for (Element const &element : container) { + if (condition(element)) { + result *= element; + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/reversed_container.h b/lib/utils/include/utils/containers/reversed_container.h new file mode 100644 index 0000000000..cffef3c6e6 --- /dev/null +++ b/lib/utils/include/utils/containers/reversed_container.h @@ -0,0 +1,85 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_CONTAINER_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REVERSED_CONTAINER_H + +namespace FlexFlow { + +template +struct reversed_container_t { + reversed_container_t() = delete; + reversed_container_t(C const &c) : container(c) {} + + reversed_container_t(reversed_container_t const &) = delete; + reversed_container_t(reversed_container_t &&) = delete; + reversed_container_t &operator=(reversed_container_t const &) = delete; + reversed_container_t &operator=(reversed_container_t &&) = delete; + + using iterator = typename C::reverse_iterator; + using const_iterator = typename C::const_reverse_iterator; + using reverse_iterator = typename C::iterator; + using const_reverse_iterator = typename C::const_iterator; + using value_type = typename C::value_type; + using pointer = typename C::pointer; + using const_pointer = typename C::const_pointer; + using reference = typename C::reference; + using const_reference = typename C::const_reference; + + iterator begin() { + return this->container.rend(); + } + + iterator end() { + return this->container.rbegin(); + } + + const_iterator cbegin() const { + return this->container.crend(); + } + + const_iterator cend() const { + return this->container.crbegin(); + } + + const_iterator begin() const { + return this->cbegin(); + } + + const_iterator end() const { + return this->cend(); + } + + reverse_iterator rbegin() { + return this->container.begin(); + } + + reverse_iterator rend() { + return this->container.end(); + } + + const_reverse_iterator crbegin() const { + return this->container.cbegin(); + } + + const_reverse_iterator crend() const { + return this->container.cend(); + } + + const_reverse_iterator rbegin() const { + return this->crbegin(); + } + + const_reverse_iterator rend() const { + return this->crend(); + } + +private: + C const &container; +}; + +template +reversed_container_t reversed_container(C const &c) { + return reversed_container_t(c); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/subvec.h b/lib/utils/include/utils/containers/subvec.h index 5ae90ec5ba..c89e9227de 100644 --- a/lib/utils/include/utils/containers/subvec.h +++ b/lib/utils/include/utils/containers/subvec.h @@ -1,7 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUBVEC_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUBVEC_H +#include "utils/exception.h" #include +#include #include namespace FlexFlow { @@ -15,11 +17,15 @@ std::vector subvec(std::vector const &v, auto resolve_loc = [&](int idx) -> typename std::vector::iterator::difference_type { + int size = static_cast(v.size()); + int new_idx = idx; if (idx < 0) { - return v.size() + idx; - } else { - return idx; + new_idx = size + idx; } + if (new_idx < 0 || new_idx > size) { + throw mk_runtime_error("Index {} is out of bounds for array {}"); + } + return new_idx; }; if (maybe_start.has_value()) { @@ -29,6 +35,9 @@ std::vector subvec(std::vector const &v, if (maybe_end.has_value()) { end_iter = v.cbegin() + resolve_loc(maybe_end.value()); } + if (begin_iter >= end_iter) { + return {}; + } if (end_iter < begin_iter) { end_iter = begin_iter; diff --git a/lib/utils/include/utils/containers/sum.h b/lib/utils/include/utils/containers/sum.h index 5dbd620781..135e704045 100644 --- a/lib/utils/include/utils/containers/sum.h +++ b/lib/utils/include/utils/containers/sum.h @@ -3,11 +3,14 @@ namespace FlexFlow { -template -T sum(C const &c) { - T result = 0; - for (T const &t : c) { - result += t; +/** + * @details An empty container vacuously has sum 0 + **/ +template +Element sum(Container const &container) { + Element result = 0; + for (Element const &element : container) { + result += element; } return result; } diff --git a/lib/utils/include/utils/containers/sum_where.h b/lib/utils/include/utils/containers/sum_where.h new file mode 100644 index 0000000000..214f51c1c9 --- /dev/null +++ b/lib/utils/include/utils/containers/sum_where.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_WHERE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_SUM_WHERE_H + +namespace FlexFlow { + +/** + * @details An empty container vacuously has sum 0 + **/ +template +Element sum_where(Container const &container, ConditionF const &condition) { + Element result = 0; + for (Element const &element : container) { + if (condition(element)) { + result += element; + } + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/value_all.h b/lib/utils/include/utils/containers/value_all.h new file mode 100644 index 0000000000..5727bd8396 --- /dev/null +++ b/lib/utils/include/utils/containers/value_all.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUE_ALL_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VALUE_ALL_H + +#include "utils/containers/transform.h" +#include "utils/exception.h" +#include "utils/optional.h" + +namespace FlexFlow { + +template +std::vector value_all(std::vector> const &v) { + return transform(v, [&](std::optional const &element) { + return unwrap(element, [&] { + throw mk_runtime_error(fmt::format( + "value_all expected all elements to have values, but received {}", + v)); + }); + }); +} + +template +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, [&](std::optional const &element) { + return unwrap(element, [&] { + throw mk_runtime_error(fmt::format( + "value_all expected all elements to have values, but received {}", + v)); + }); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/vector_split.h b/lib/utils/include/utils/containers/vector_split.h index a1ab12a070..872733e3ce 100644 --- a/lib/utils/include/utils/containers/vector_split.h +++ b/lib/utils/include/utils/containers/vector_split.h @@ -1,14 +1,20 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_VECTOR_SPLIT_H +#include "utils/fmt/vector.h" +#include +#include #include namespace FlexFlow { template std::pair, std::vector> vector_split(std::vector const &v, - std::size_t idx) { - assert(v.size() > idx); + int idx) { + if (idx < 0 || idx > static_cast(v.size())) { + throw std::out_of_range(fmt::format( + "Index out of range in vector_split: index = {}, vector = {}", idx, v)); + } std::vector prefix(v.begin(), v.begin() + idx); std::vector postfix(v.begin() + idx, v.end()); diff --git a/lib/utils/include/utils/fmt/expected.h b/lib/utils/include/utils/fmt/expected.h index 4170882ae6..7ef7f24eb7 100644 --- a/lib/utils/include/utils/fmt/expected.h +++ b/lib/utils/include/utils/fmt/expected.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::tl::expected const &m, FormatContext &ctx) + auto format(::tl::expected const &m, FormatContext &ctx) const -> decltype(ctx.out()) { std::string result; diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 46bf9ca8fa..9225040d4d 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -17,7 +17,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::map const &m, FormatContext &ctx) + auto format(::std::map const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(K); CHECK_FMTABLE(V); diff --git a/lib/utils/include/utils/fmt/multiset.h b/lib/utils/include/utils/fmt/multiset.h index 616b784aac..4234d90e94 100644 --- a/lib/utils/include/utils/fmt/multiset.h +++ b/lib/utils/include/utils/fmt/multiset.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::multiset const &m, FormatContext &ctx) + auto format(::std::multiset const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/optional.h b/lib/utils/include/utils/fmt/optional.h index 2364e49568..16ccf61878 100644 --- a/lib/utils/include/utils/fmt/optional.h +++ b/lib/utils/include/utils/fmt/optional.h @@ -14,7 +14,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::optional const &m, FormatContext &ctx) + auto format(::std::optional const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); @@ -29,6 +29,14 @@ struct formatter< } }; +template +struct formatter : formatter { + template + auto format(std::nullopt_t, FormatContext &ctx) const -> decltype(ctx.out()) { + return formatter::format("nullopt", ctx); + } +}; + } // namespace fmt namespace FlexFlow { @@ -40,6 +48,8 @@ std::ostream &operator<<(std::ostream &s, std::optional const &t) { return s << fmt::to_string(t); } +std::ostream &operator<<(std::ostream &, std::nullopt_t); + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/fmt/pair.h b/lib/utils/include/utils/fmt/pair.h index ab5ddd4e28..d261c344a1 100644 --- a/lib/utils/include/utils/fmt/pair.h +++ b/lib/utils/include/utils/fmt/pair.h @@ -14,7 +14,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::pair const &m, FormatContext &ctx) + auto format(::std::pair const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(L); CHECK_FMTABLE(R); diff --git a/lib/utils/include/utils/fmt/set.h b/lib/utils/include/utils/fmt/set.h index a183d37542..c46984cc5a 100644 --- a/lib/utils/include/utils/fmt/set.h +++ b/lib/utils/include/utils/fmt/set.h @@ -16,7 +16,7 @@ struct formatter<::std::set, std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::set const &m, FormatContext &ctx) + auto format(::std::set const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/unordered_map.h b/lib/utils/include/utils/fmt/unordered_map.h index 876a032fe6..12faa64e32 100644 --- a/lib/utils/include/utils/fmt/unordered_map.h +++ b/lib/utils/include/utils/fmt/unordered_map.h @@ -18,7 +18,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_map const &m, FormatContext &ctx) + auto format(::std::unordered_map const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(K); CHECK_FMTABLE(V); diff --git a/lib/utils/include/utils/fmt/unordered_multiset.h b/lib/utils/include/utils/fmt/unordered_multiset.h index 09dd3c5eab..a4c17f7f5e 100644 --- a/lib/utils/include/utils/fmt/unordered_multiset.h +++ b/lib/utils/include/utils/fmt/unordered_multiset.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_multiset const &m, FormatContext &ctx) + auto format(::std::unordered_multiset const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/unordered_set.h b/lib/utils/include/utils/fmt/unordered_set.h index be347ec5ea..20a08916fc 100644 --- a/lib/utils/include/utils/fmt/unordered_set.h +++ b/lib/utils/include/utils/fmt/unordered_set.h @@ -16,7 +16,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::unordered_set const &m, FormatContext &ctx) + auto format(::std::unordered_set const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/fmt/variant.h b/lib/utils/include/utils/fmt/variant.h index 06a56417c3..1690955286 100644 --- a/lib/utils/include/utils/fmt/variant.h +++ b/lib/utils/include/utils/fmt/variant.h @@ -11,7 +11,7 @@ struct formatter, Char> /* std::enable_if_t>::value>> */ : formatter<::std::string> { template - auto format(std::variant const &m, FormatContext &ctx) + auto format(std::variant const &m, FormatContext &ctx) const -> decltype(ctx.out()) { std::string result = diff --git a/lib/utils/include/utils/fmt/vector.h b/lib/utils/include/utils/fmt/vector.h index 5d9ca0aeae..1eec7a306b 100644 --- a/lib/utils/include/utils/fmt/vector.h +++ b/lib/utils/include/utils/fmt/vector.h @@ -15,7 +15,7 @@ struct formatter< std::enable_if_t>::value>> : formatter<::std::string> { template - auto format(::std::vector const &m, FormatContext &ctx) + auto format(::std::vector const &m, FormatContext &ctx) const -> decltype(ctx.out()) { CHECK_FMTABLE(T); diff --git a/lib/utils/include/utils/graph/cow_ptr_t.h b/lib/utils/include/utils/graph/cow_ptr_t.h index 9a655ae072..7aed437136 100644 --- a/lib/utils/include/utils/graph/cow_ptr_t.h +++ b/lib/utils/include/utils/graph/cow_ptr_t.h @@ -2,8 +2,6 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_COW_PTR_T_H #include "utils/type_traits.h" -#include "utils/unique.h" -#include "utils/variant.h" #include #include diff --git a/lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h b/lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h similarity index 100% rename from lib/utils/src/utils/graph/instances/hashmap_undirected_graph.h rename to lib/utils/include/utils/graph/instances/hashmap_undirected_graph.h diff --git a/lib/utils/include/utils/join_strings.h b/lib/utils/include/utils/join_strings.h index 9eb717b066..7b53384b39 100644 --- a/lib/utils/include/utils/join_strings.h +++ b/lib/utils/include/utils/join_strings.h @@ -13,15 +13,12 @@ std::string join_strings(InputIt first, F const &f) { std::ostringstream oss; bool first_iter = true; - /* int i = 0; */ for (; first != last; first++) { if (!first_iter) { oss << delimiter; } oss << f(*first); - /* break; */ first_iter = false; - /* i++; */ } return oss.str(); } @@ -38,6 +35,12 @@ std::string join_strings(Container const &c, std::string const &delimiter) { return join_strings(c.cbegin(), c.cend(), delimiter); } +template +std::string + join_strings(Container const &c, std::string const &delimiter, F const &f) { + return join_strings(c.cbegin(), c.cend(), delimiter, f); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/rapidcheck/variant.h b/lib/utils/include/utils/rapidcheck/variant.h new file mode 100644 index 0000000000..bc741ea340 --- /dev/null +++ b/lib/utils/include/utils/rapidcheck/variant.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_RAPIDCHECK_VARIANT_H + +#include +#include + +namespace rc { + +template +struct Arbitrary> { + static Gen> arbitrary() { + return gen::oneOf( + gen::construct>(gen::arbitrary())...); + } +}; + +} // namespace rc + +#endif diff --git a/lib/utils/include/utils/stack_string.h b/lib/utils/include/utils/stack_string.h index 7a936ebd7b..2a7b2d1849 100644 --- a/lib/utils/include/utils/stack_string.h +++ b/lib/utils/include/utils/stack_string.h @@ -82,6 +82,13 @@ void from_json(nlohmann::json const &j, stack_string &v) { v = stack_string{as_string}; } +template +std::basic_ostream & + operator<<(std::basic_ostream &s, + stack_basic_string const &v) { + return s << fmt::to_string(v); +} + } // namespace FlexFlow namespace std { diff --git a/lib/utils/include/utils/unique.h b/lib/utils/include/utils/unique.h deleted file mode 100644 index cf6eb39026..0000000000 --- a/lib/utils/include/utils/unique.h +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H -#define _FLEXFLOW_UTILS_INCLUDE_UNIQUE_H - -#include - -namespace FlexFlow { -template -std::unique_ptr make_unique(Args &&...args) { - return std::unique_ptr(new T(std::forward(args)...)); -} -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index bb2286a9cd..241d631200 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -213,15 +213,4 @@ std::optional cast(VariantIn const &v) { } // namespace FlexFlow -namespace rc { - -template -struct Arbitrary> { - static Gen> arbitrary() { - return gen::oneOf(gen::cast>(gen::arbitrary())...); - } -}; - -} // namespace rc - #endif diff --git a/lib/utils/src/containers.cc b/lib/utils/src/containers.cc deleted file mode 100644 index 2af7fd1892..0000000000 --- a/lib/utils/src/containers.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers.h" diff --git a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc deleted file mode 100644 index f70be2355f..0000000000 --- a/lib/utils/src/utils/bidict/algorithms/merge_bidicts.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/bidict/algorithms/merge_bidicts.h" diff --git a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..754b8d2e90 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -0,0 +1 @@ +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" diff --git a/lib/utils/src/utils/cli/cli_get_help_message.cc b/lib/utils/src/utils/cli/cli_get_help_message.cc index 03c53c9356..51947aa885 100644 --- a/lib/utils/src/utils/cli/cli_get_help_message.cc +++ b/lib/utils/src/utils/cli/cli_get_help_message.cc @@ -2,6 +2,7 @@ #include "utils/containers/concat_vectors.h" #include "utils/containers/maximum.h" #include "utils/containers/transform.h" +#include "utils/fmt/vector.h" #include "utils/integer_conversions.h" #include "utils/join_strings.h" #include @@ -53,7 +54,7 @@ std::string cli_get_help_message(std::string const &program_name, if (!all_arg_columns.empty()) { int max_column_width = - std::min(int_from_size_t(maximum(all_arg_column_widths).value()), 20); + std::min(int_from_size_t(maximum(all_arg_column_widths)), 20); auto render_column = [&](std::string const &key, std::optional const &description) { diff --git a/lib/utils/src/utils/containers/are_all_distinct.cc b/lib/utils/src/utils/containers/are_all_distinct.cc new file mode 100644 index 0000000000..52c665d191 --- /dev/null +++ b/lib/utils/src/utils/containers/are_all_distinct.cc @@ -0,0 +1 @@ +#include "utils/containers/are_all_distinct.h" diff --git a/lib/utils/src/utils/containers/compare_by.cc b/lib/utils/src/utils/containers/compare_by.cc new file mode 100644 index 0000000000..b7df348f37 --- /dev/null +++ b/lib/utils/src/utils/containers/compare_by.cc @@ -0,0 +1 @@ +#include "utils/containers/compare_by.h" diff --git a/lib/utils/src/utils/containers/get_first.cc b/lib/utils/src/utils/containers/get_first.cc deleted file mode 100644 index ce8eb9cbea..0000000000 --- a/lib/utils/src/utils/containers/get_first.cc +++ /dev/null @@ -1 +0,0 @@ -#include "utils/containers/get_first.h" diff --git a/lib/utils/src/utils/containers/get_one_of.cc b/lib/utils/src/utils/containers/get_one_of.cc new file mode 100644 index 0000000000..4ce017f6b9 --- /dev/null +++ b/lib/utils/src/utils/containers/get_one_of.cc @@ -0,0 +1 @@ +#include "utils/containers/get_one_of.h" diff --git a/lib/utils/src/utils/containers/index_of.cc b/lib/utils/src/utils/containers/index_of.cc new file mode 100644 index 0000000000..d0c4b4dfd3 --- /dev/null +++ b/lib/utils/src/utils/containers/index_of.cc @@ -0,0 +1 @@ +#include "utils/containers/index_of.h" diff --git a/lib/utils/src/utils/containers/is_submapeq_of.cc b/lib/utils/src/utils/containers/is_submapeq_of.cc new file mode 100644 index 0000000000..567d94fac5 --- /dev/null +++ b/lib/utils/src/utils/containers/is_submapeq_of.cc @@ -0,0 +1 @@ +#include "utils/containers/is_submapeq_of.h" diff --git a/lib/utils/src/utils/containers/is_superseteq_of.cc b/lib/utils/src/utils/containers/is_superseteq_of.cc new file mode 100644 index 0000000000..0728c96f17 --- /dev/null +++ b/lib/utils/src/utils/containers/is_superseteq_of.cc @@ -0,0 +1 @@ +#include "utils/containers/is_superseteq_of.h" diff --git a/lib/utils/src/utils/containers/product_where.cc b/lib/utils/src/utils/containers/product_where.cc new file mode 100644 index 0000000000..3a435e7d80 --- /dev/null +++ b/lib/utils/src/utils/containers/product_where.cc @@ -0,0 +1 @@ +#include "utils/containers/product_where.h" diff --git a/lib/utils/src/utils/containers/reversed_container.cc b/lib/utils/src/utils/containers/reversed_container.cc new file mode 100644 index 0000000000..1a2fe3cf63 --- /dev/null +++ b/lib/utils/src/utils/containers/reversed_container.cc @@ -0,0 +1 @@ +#include "utils/containers/reversed_container.h" diff --git a/lib/utils/src/utils/containers/sum_where.cc b/lib/utils/src/utils/containers/sum_where.cc new file mode 100644 index 0000000000..c09f9c573e --- /dev/null +++ b/lib/utils/src/utils/containers/sum_where.cc @@ -0,0 +1 @@ +#include "utils/containers/sum_where.h" diff --git a/lib/utils/src/utils/containers/value_all.cc b/lib/utils/src/utils/containers/value_all.cc new file mode 100644 index 0000000000..1f863a20e4 --- /dev/null +++ b/lib/utils/src/utils/containers/value_all.cc @@ -0,0 +1 @@ +#include "utils/containers/value_all.h" diff --git a/lib/utils/src/utils/fmt/optional.cc b/lib/utils/src/utils/fmt/optional.cc index e21b32eaa9..4642292920 100644 --- a/lib/utils/src/utils/fmt/optional.cc +++ b/lib/utils/src/utils/fmt/optional.cc @@ -1 +1,9 @@ #include "utils/fmt/optional.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, std::nullopt_t) { + return (s << std::string{"nullopt"}); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc index d06a64597e..0e4d0c6759 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/find_isomorphism.cc @@ -1,5 +1,5 @@ #include "utils/graph/dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/graph/dataflow_graph/algorithms/find_isomorphisms.h" namespace FlexFlow { @@ -13,7 +13,7 @@ std::optional if (all_isomorphisms.empty()) { return std::nullopt; } else { - return get_first(all_isomorphisms); + return get_one_of(all_isomorphisms); } } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc index 8afe7da926..92bd1e32ca 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.cc @@ -1,12 +1,9 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/get_cbc_decomposition.h" #include "utils/containers/are_disjoint.h" #include "utils/containers/extend.h" -#include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" -#include "utils/containers/set_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" -#include "utils/fmt/set.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" @@ -16,7 +13,6 @@ #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" #include "utils/graph/node/algorithms.h" -#include "utils/hash/unordered_set.h" #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc index 2eab8371b2..ccd2808603 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/complete_bipartite_composite/is_complete_bipartite_digraph.h" -#include "utils/containers/get_first.h" #include "utils/containers/set_minus.h" #include "utils/graph/digraph/algorithms.h" #include "utils/graph/node/algorithms.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc index 1d98c44a9f..39523f2ec1 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc @@ -1,6 +1,6 @@ #include "utils/graph/digraph/algorithms/get_imm_post_dominator.h" #include "utils/containers/generate_map.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/get_only.h" #include "utils/containers/intersection.h" #include "utils/containers/restrict_keys.h" @@ -31,7 +31,7 @@ std::optional return get_imm_post_dominator(g, get_only(nodes)); } - Node contracted_node = get_first(nodes); + Node contracted_node = get_one_of(nodes); std::unordered_map contraction = generate_map(nodes, [&](Node const &) { return contracted_node; }); return get_imm_post_dominator(apply_contraction(g, contraction), diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc index d622497629..d75a447127 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphism.cc @@ -1,5 +1,5 @@ #include "utils/graph/open_dataflow_graph/algorithms/find_isomorphism.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.h" namespace FlexFlow { @@ -13,7 +13,7 @@ std::optional if (all_isomorphisms.empty()) { return std::nullopt; } else { - return get_first(all_isomorphisms); + return get_one_of(all_isomorphisms); } } diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc index 1dd5353301..fa17678943 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/find_isomorphisms.cc @@ -3,7 +3,7 @@ #include "utils/bidict/algorithms/left_entries.h" #include "utils/bidict/algorithms/right_entries.h" #include "utils/containers/get_all_permutations.h" -#include "utils/containers/get_first.h" +#include "utils/containers/get_one_of.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/keys.h" #include "utils/containers/values.h" diff --git a/lib/utils/src/utils/rapidcheck/variant.cc b/lib/utils/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..f0537d454f --- /dev/null +++ b/lib/utils/src/utils/rapidcheck/variant.cc @@ -0,0 +1,10 @@ +#include "utils/rapidcheck/variant.h" + +namespace rc { + +using T0 = int; +using T1 = std::string; + +template struct Arbitrary>; + +} // namespace rc diff --git a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h index 519cde7d74..25552003e3 100644 --- a/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h +++ b/lib/utils/test/common/include/test/utils/doctest/fmt/optional.h @@ -13,6 +13,11 @@ struct StringMaker> { } }; +template <> +struct StringMaker { + static String convert(std::nullopt_t const &); +}; + } // namespace doctest #endif diff --git a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc index 8a3f7f158e..09b02ac059 100644 --- a/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc +++ b/lib/utils/test/common/src/test/utils/doctest/fmt/optional.cc @@ -1 +1,9 @@ #include "test/utils/doctest/fmt/optional.h" + +namespace doctest { + +String StringMaker::convert(std::nullopt_t const &m) { + return toString(fmt::to_string(m)); +} + +} // namespace doctest diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc deleted file mode 100644 index 44f602f3bc..0000000000 --- a/lib/utils/test/src/test_algorithms.cc +++ /dev/null @@ -1,246 +0,0 @@ -#include "utils/graph/algorithms.h" -#include "utils/graph/construction.h" -#include "utils/graph/hashmap_undirected_graph.h" -#include "utils/graph/instances/adjacency_digraph.h" -#include "utils/graph/undirected.h" -#include -#include -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("MultiDiGraph") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector p = add_node_ports(g, 4); - - MultiDiEdge e0{n[3], p[3], n[0], p[0]}; - MultiDiEdge e1{n[2], p[2], n[1], p[0]}; - MultiDiEdge e2{n[3], p[3], n[1], p[1]}; - MultiDiEdge e3{n[3], p[3], n[2], p[2]}; - - std::vector e = {e0, e1, e2, e3}; - - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[1], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{e[3]}); - std::unordered_map> expected_result = - std::unordered_map>{ - {n[1], {}}, - {n[2], {n[1]}}, - {n[3], {n[0], n[1], n[2]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - } - - TEST_CASE("DiGraph") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - {n[0], n[3]}, - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[2]}, - }; - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[2], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{}); - auto expected_result = std::unordered_map>{ - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - {n[3], {n[0]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - - SUBCASE("get_imm_dominators") { - std::unordered_map> result = get_imm_dominators(g); - - std::unordered_map> expected_result = { - {n[2], n[0]}, - {n[1], n[0]}, - {n[3], n[0]}, - {n[0], nullopt}, - }; - CHECK(result == expected_result); - } - - SUBCASE("get_dominators") { - std::unordered_map> expected = { - {n[0], {n[0]}}, - {n[1], {n[0], n[1]}}, - {n[2], {n[0], n[2]}}, - {n[3], {n[0], n[3]}}, - }; - CHECK(get_dominators(g) == expected); - } - - SUBCASE("get_sinks") { - auto expected = std::unordered_set{n[2], n[3]}; - CHECK(get_sinks(g) == expected); - } - - SUBCASE("get_bfs") { - std::unordered_set start_points = std::unordered_set{n[0]}; - auto expected = std::vector{n[0], n[2], n[1], n[3]}; - CHECK(get_bfs_ordering(g, start_points) == expected); - } - - SUBCASE("get_predecessors") { - std::unordered_map> expected_result = { - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); - } - } - - TEST_CASE("traversal") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 5); - std::vector edges = { - {n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; - add_edges(g, edges); - - CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); - CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == true); - CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); - CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); - - SUBCASE("with root") { - g.add_edge({n[3], n[2]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - - SUBCASE("without root") { - g.add_edge({n[3], n[0]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - SUBCASE("nonlinear") { - g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs - } - - SUBCASE("not connected") { - g.remove_edge({n[2], n[3]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); - } - } - - TEST_CASE("bfs") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 7); - - std::vector e = { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }; - - add_edges(g, e); - - std::vector ordering = get_bfs_ordering(g, {n[0]}); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < - index_of(ordering, n[r]).value()); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - - CHECK_BEFORE(1, 3); - CHECK_BEFORE(1, 6); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(2, 6); - - CHECK_BEFORE(3, 4); - CHECK_BEFORE(6, 4); - - CHECK_BEFORE(4, 5); - } - - TEST_CASE("get_topological_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - std::vector edges = {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}; - add_edges(g, edges); - std::vector ordering = get_topological_ordering(g); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - CHECK_BEFORE(1, 5); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(3, 4); - CHECK_BEFORE(4, 5); - } - - TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - - CHECK(get_connected_components(g) == expected_components); - } - - TEST_CASE("get_weakly_connected_components") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); - - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; - - CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); - - CHECK(get_weakly_connected_components(g) == expected_components); - } -} diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc deleted file mode 100644 index 76b7fd0d31..0000000000 --- a/lib/utils/test/src/test_containers.cc +++ /dev/null @@ -1,393 +0,0 @@ -#include "utils/containers.h" -#include -#include -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("join_strings") { - std::vector const v = {"Hello", "world", "!"}; - CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); - } - - TEST_CASE("join_strings with container") { - std::vector const v = {"Hello", "world"}; - CHECK(join_strings(v, " ") == "Hello world"); - } - - TEST_CASE("find") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(find(v, 3) != v.cend()); - CHECK(find(v, 6) == v.cend()); - } - - TEST_CASE("sum") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(sum(v) == 15); - } - - TEST_CASE("sum with condition") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Sum of even numbers only - CHECK(sum_where(v, condition) == 6); - } - - TEST_CASE("product") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(product(v) == 120); - } - - TEST_CASE("product_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Product of even numbers only - CHECK(product_where(v, condition) == 8); - } - - TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); - } - - TEST_CASE("contains_key") { - std::unordered_map m = { - {"one", 1}, {"two", 2}, {"three", 3}}; - CHECK(contains_key(m, "one")); - CHECK(!contains_key(m, "four")); - } - - TEST_CASE("map_keys") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; // Mapping function - auto result = map_keys(m, f); - CHECK(result.size() == 2); - CHECK(result[1] == "one"); - CHECK(result[4] == "two"); - } - - TEST_CASE("filter_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = [](int x) { return x % 2 == 1; }; // Filtering function - std::unordered_map result = filter_keys(m, f); - std::unordered_map expected = {{1, "one"}, {3, "three"}}; - CHECK(result == expected); - } - - TEST_CASE("map_values") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](std::string const &s) { return s.size(); }; // Mapping function - std::unordered_map result = map_values(m, f); - std::unordered_map expected = {{1, 3}, {2, 3}}; - CHECK(result == expected); - } - - TEST_CASE("keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {3, 2, 1}; - CHECK(result == expected); - } - - TEST_CASE("values") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::vector result = values(m); - std::vector expected = {"three", "two", "one"}; - CHECK(result == expected); - } - - // TEST_CASE("items") { - // std::unordered_map m = {{1, std::string("one")}, {2, - // std::string("two")}, {3,std::string("three")}}; - // std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; - std::unordered_set result = unique(v); - std::unordered_set expected = {1, 2, 3}; - CHECK(result == expected); - } - - TEST_CASE("unordered_multiset_of") { - std::vector v = {1, 4, 6, 4, 6}; - std::unordered_set expected = {1, 4, 6}; - CHECK(unordered_set_of(v) == expected); - } - - TEST_CASE("index_of") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3) == 2); - CHECK(!index_of(v, 6).has_value()); - } - - TEST_CASE("intersection") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {2, 3, 4}; - std::unordered_set result = intersection(l, r); - std::unordered_set expected = {2, 3}; - CHECK(result == expected); - } - - TEST_CASE("are_disjoint") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {4, 5, 6}; - CHECK(are_disjoint(l, r)); - r.insert(3); - CHECK_FALSE(are_disjoint(l, r)); - } - - TEST_CASE("restrict_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set mask = {2, 3, 4}; - std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); - } - - TEST_CASE("merge_maps(unordered_map)") { - std::unordered_map lhs = {{1, "one"}, {2, "two"}}; - std::unordered_map rhs = {{3, "three"}, {4, "four"}}; - std::unordered_map result = merge_maps(lhs, rhs); - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); - } - - TEST_CASE("merge_maps(bidict)") { - std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; - std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; - std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; - std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; - bidict lhs{fwd_map1, bwd_map1}; - bidict rhs{fwd_map2, bwd_map2}; - - std::unordered_map result = - merge_maps(lhs, rhs); // impicit conversion - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); - } - - TEST_CASE("lookup_in") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = lookup_in(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - CHECK(f(3) == "three"); - } - - TEST_CASE("lookup_in_l") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_l(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - } - - TEST_CASE("lookup_in_r") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_r(m); - CHECK(f("one") == 1); - CHECK(f("two") == 2); - } - - TEST_CASE("set_union") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {2, 3, 4}; - std::unordered_set result = set_union(s1, s2); - std::unordered_set expected = {1, 2, 3, 4}; - CHECK(result == expected); - } - - TEST_CASE("is_subseteq_of") { - std::unordered_set s1 = {1, 2}; - std::unordered_set s2 = {1, 2, 3}; - CHECK(is_subseteq_of(s1, s2) == true); - CHECK(is_subseteq_of(s2, s1) == false); - CHECK(is_subseteq_of(s1, s1) == true); - CHECK(is_subseteq_of(s2, s2) == true); - } - - TEST_CASE("is_superseteq_of") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {1, 2}; - CHECK(is_supserseteq_of(s1, s2) == true); - CHECK(is_supserseteq_of(s2, s1) == false); - } - - TEST_CASE("get_only") { - std::unordered_set s = {42}; - CHECK(get_only(s) == 42); - } - - TEST_CASE("get_first") { - std::unordered_set s = {1, 2, 3}; - CHECK(s.count(get_first(s)) == 1); - } - - TEST_CASE("extend") { - std::vector v = {1, 2, 3}; - std::unordered_set s = {4, 5, 6}; - extend(v, s); - CHECK(v.size() == 6); - std::vector expected = {1, 2, 3, 6, 5, 4}; - CHECK(v == expected); - } - - TEST_CASE("all_of") { - std::vector v = {2, 4, 6, 8}; - CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); - CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); - } - - TEST_CASE("count") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); - CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); - } - - TEST_CASE("are_all_same") { - std::vector v1 = {2, 2, 2, 2}; - std::vector v2 = {1, 2, 3, 4}; - CHECK(are_all_same(v1) == true); - CHECK(are_all_same(v2) == false); - } - - TEST_CASE("vector_transform") { - std::vector v = {1, 2, 3}; - auto result = vector_transform([](int x) { return x * 2; }, v); - CHECK(result == std::vector({2, 4, 6})); - } - - TEST_CASE("vector_of") { - std::unordered_set s = {1, 2, 3}; - std::vector result = vector_of(s); - CHECK(result == std::vector({3, 2, 1})); - } - - TEST_CASE("transform_vector") { - std::vector v = {1, 2, 3}; - auto result = transform(v, [](int x) { return x * 2; }); - CHECK(result == std::vector({2, 4, 6})); - } - - TEST_CASE("transform_unordered_set") { - std::unordered_set s = {1, 2, 3}; - auto result = transform(s, [](int x) { return x * 2; }); - CHECK(result == std::unordered_set({2, 4, 6})); - } - - TEST_CASE("transform_string") { - std::string s = "abc"; - auto result = transform(s, ::toupper); - CHECK(result == "ABC"); - } - - TEST_CASE("repeat") { - int ctr = 0; - std::vector result = repeat(5, [&] { return ctr++; }); - - CHECK(result == std::vector{0, 1, 2, 3, 4}); - } - - TEST_CASE("Testing the 'enumerate' function") { - std::unordered_set input_set = {1, 2, 3, 4, 5}; - std::unordered_map result = enumerate(input_set); - std::unordered_map expected = { - {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; - CHECK(result == expected); - } - - TEST_CASE("Testing the 'maximum' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - auto result = maximum(input_vec); - - // Checking the maximum is as expected - REQUIRE(result == 5); - } - - TEST_CASE("Testing the 'reversed' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - std::vector result = reversed(input_vec); - std::vector expected = {5, 4, 3, 2, 1}; - - // Checking the reversed sequence is as expected - CHECK(result == expected); - } - - TEST_CASE("Testing sorted_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); - CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); - - std::unordered_set s2 = {-5, -1, -3, -2, -4}; - auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); - CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); - } - - TEST_CASE("Testing compare_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - std::vector result = - sorted_by(s, compare_by([](int i) { return (-i); })); - CHECK(result == std::vector{5, 4, 3, 2, 1}); - } - - TEST_CASE("Testing vector_split function") { - std::vector v = {1, 2, 3, 4, 5}; - auto result = vector_split(v, 2); - std::vector prefix = result.first; - std::vector postfix = result.second; - CHECK(prefix == std::vector({1, 2})); - CHECK(postfix == std::vector({3, 4, 5})); - } - - TEST_CASE("Testing value_all function") { - std::vector> v = {1, 2, 3, 4, 5}; - auto value_all_v = value_all(v); - CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); - } - - TEST_CASE("Testing subvec function") { - std::vector v = {1, 2, 3, 4, 5}; - auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); - - CHECK(subvec_v == std::vector({2, 3, 4})); - - auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); - CHECK(subvec_v2 == std::vector({1, 2, 3})); - } - - auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); - } - } - return factors; - }; - - // Example for vector - TEST_CASE("Test for flatmap function on vectors") { - std::vector v = {2, 3, 4, 5}; - auto result = flatmap(v, get_factors); - CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); - } -} diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc deleted file mode 100644 index cc7ac1de32..0000000000 --- a/lib/utils/test/src/test_multidigraph.cc +++ /dev/null @@ -1,94 +0,0 @@ -#include "utils/graph/adjacency_multidigraph.h" -#include "utils/graph/multidiedge.h" -#include "utils/graph/multidigraph_interfaces.h" -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { - MultiDiGraph g = MultiDiGraph::create(); - - std::vector n = repeat(3, [&] { return g.add_node(); }); - std::vector p = repeat(3, [&] { return g.add_node_port(); }); - - std::vector e = {{n[1], p[1], n[0], p[0]}, - {n[2], p[2], n[0], p[0]}, - {n[0], p[0], n[2], p[2]}, - {n[1], p[1], n[2], p[2]}}; - for (MultiDiEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[0], n[1], n[2]}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[0], e[1], e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( - {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( - {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs( - query_set({p[1], p[2]}))) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs( - query_set({p[0], p[2]}))) == - std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all() - .with_src_nodes({n[1]}) - .with_dst_nodes({n[2]}) - .with_src_idxs({p[1]}) - .with_dst_idxs({p[2]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == - std::unordered_set{e[1]}); - - SUBCASE("remove node") { - g.remove_node_unsafe(n[0]); - - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[1], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == - std::unordered_set{}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == - std::unordered_set{e[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == - std::unordered_set{e[2]}); - } - - SUBCASE("remove_edge") { - g.remove_edge(e[0]); - - CHECK(g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( - {n[1]})) == std::unordered_set{}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == - std::unordered_set{e[1]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - } - } -} diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc deleted file mode 100644 index 2b816eea4f..0000000000 --- a/lib/utils/test/src/test_random_utils.cc +++ /dev/null @@ -1,67 +0,0 @@ -#include "utils/random_utils.h" -#include -#include - -void checkProbabilities(std::vector const &counts, - int numIterations, - std::vector const &weights, - float totalWeight) { - for (int i = 0; i < counts.size(); i++) { - float expectedProbability = weights[i] / totalWeight; - float observedProbability = static_cast(counts[i]) / numIterations; - CHECK(observedProbability == - doctest::Approx(expectedProbability).epsilon(0.01f)); - } -} - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("select_random") { - std::vector values = {1, 2, 3, 4, 5}; - - SUBCASE("Select random value") { - int result = select_random(values); - - CHECK(std::find(values.begin(), values.end(), result) != values.end()); - } - - SUBCASE("Invalid arguments") { - std::vector weights = {0.1f, 0.3f, 0.2f}; - CHECK(select_random(values, weights) == 2); - } - } - - TEST_CASE("select_random - Weighted Random Selection") { - SUBCASE("Test with equal weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; - } - - checkProbabilities(counts, numIterations, weights, values.size()); - } - - SUBCASE("Test with different weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; - - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; - } - - float totalWeight = 0.0f; - for (float weight : weights) { - totalWeight += weight; - } - - checkProbabilities(counts, numIterations, weights, totalWeight); - } - } -} diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc deleted file mode 100644 index e7ce12346a..0000000000 --- a/lib/utils/test/src/test_type_index.cc +++ /dev/null @@ -1,35 +0,0 @@ -#include "utils/type_index.h" -#include -#include - -using namespace FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("type_index function") { - SUBCASE("int type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(int); - CHECK(idx == expected_idx); - } - - SUBCASE("string type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(std::string); - CHECK(idx == expected_idx); - } - } - - TEST_CASE("matches function") { - std::type_index idx = typeid(float); - - SUBCASE("matching type") { - bool result = matches(idx); - CHECK(result == true); - } - - SUBCASE("non-matching type") { - bool result = matches(idx); - CHECK(result == false); - } - } -} diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc deleted file mode 100644 index ea519478d3..0000000000 --- a/lib/utils/test/src/test_undirected_graph.cc +++ /dev/null @@ -1,62 +0,0 @@ -#include "test/utils/rapidcheck.h" -#include "test/utils/rapidcheck/visitable.h" -#include "utils/graph/hashmap_undirected_graph.h" -#include "utils/graph/undirected.h" -#include - -/* namespace rc { */ - -/* template <> */ -/* struct Arbitrary { */ -/* static Gen arbitrary() { */ -/* int num_nodes = *gen::inRange( */ -/* } */ -/* }; */ - -/* } */ - -using namespace FlexFlow; - -using namespace rc; - -/* static_assert(supports_rc_arbitrary::value, ""); */ -/* static_assert(is_strong_typedef::value, ""); */ -/* static_assert(supports_rc_arbitrary>::value, ""); - */ -/* static_assert(supports_rc_arbitrary>::value, - * ""); */ -/* static_assert(supports_rc_arbitrary::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ -/* static_assert(is_streamable::value, ""); */ -/* static_assert(is_fmtable::value, ""); */ - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE( - "UndirectedGraph implementations", T, HashmapUndirectedGraph) { - - RC_SUBCASE("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct(gen::elementOf(n), - gen::elementOf(n))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == unordered_set_of(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(UndirectedEdgeQuery::all()) == unordered_set_of(e)); - }); - } -} diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..0a1babd9f9 --- /dev/null +++ b/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -0,0 +1,42 @@ +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("merge_disjoint_bidicts") { + + SUBCASE("disjoint keys and values") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{3, "three"}, {4, "four"}}; + + bidict result = merge_disjoint_bidicts(bd1, bd2); + bidict correct = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + + CHECK(result == correct); + } + + SUBCASE("overlapping key, different associated value") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{2, "three"}, {3, "four"}}; + + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + } + + SUBCASE("overlapping key, same associated value") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{2, "two"}, {3, "three"}}; + + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + } + + SUBCASE("overlapping values") { + bidict bd1 = {{1, "one"}, {2, "two"}}; + bidict bd2 = {{3, "two"}, {4, "four"}}; + + CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + } + } +} diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index fed655013f..d158af129f 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -12,7 +12,37 @@ TEST_SUITE(FF_TEST_SUITE) { dict.equate(1, "one"); dict.equate(2, "two"); - // Test the equate() function + SUBCASE("L type is the same as R type") { + bidict bd; + bd.equate(1, 3); + + SUBCASE("bidict::contains_l") { + CHECK(bd.contains_l(1)); + CHECK_FALSE(bd.contains_l(3)); + } + + SUBCASE("bidict::contains_r") { + CHECK(bd.contains_r(3)); + CHECK_FALSE(bd.contains_r(1)); + } + } + + SUBCASE("L type is not the same as R type") { + bidict dict; + dict.equate(1, "one"); + dict.equate(2, "two"); + + SUBCASE("bidict::contains_l") { + CHECK(dict.contains_l(1)); + CHECK_FALSE(dict.contains_l(3)); + } + + SUBCASE("bidict::contains_r") { + CHECK(dict.contains_r("one")); + CHECK_FALSE(dict.contains_r("three")); + } + } + SUBCASE("bidict::equate") { CHECK(dict.at_l(1) == "one"); CHECK(dict.at_r("one") == 1); @@ -20,7 +50,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_r("two") == 2); } - // Test the erase_l() function SUBCASE("bidict::erase_l") { dict.erase_l(1); CHECK(dict.size() == 1); @@ -28,7 +57,6 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_r("two") == 2); } - // Test the erase_r() function SUBCASE("bidict::erase_r") { dict.erase_r("one"); CHECK(dict.size() == 1); @@ -36,14 +64,12 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(dict.at_l(2) == "two"); } - // Test the reversed() function SUBCASE("bidict::reversed") { bidict reversed_dict = dict.reversed(); CHECK(reversed_dict.at_l("one") == 1); CHECK(reversed_dict.at_r(2) == "two"); } - // Test the size() function SUBCASE("bidict::size") { CHECK(dict.size() == 2); } diff --git a/lib/utils/test/src/utils/containers/all_of.cc b/lib/utils/test/src/utils/containers/all_of.cc new file mode 100644 index 0000000000..247dd62787 --- /dev/null +++ b/lib/utils/test/src/utils/containers/all_of.cc @@ -0,0 +1,13 @@ +#include "utils/containers/all_of.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("all_of") { + std::vector v = {2, 4, 6, 8}; + CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); + CHECK(all_of(v, [](int x) { return x % 4 == 0; }) == false); + } +} diff --git a/lib/utils/test/src/utils/containers/are_all_distinct.cc b/lib/utils/test/src/utils/containers/are_all_distinct.cc new file mode 100644 index 0000000000..6c2e9ea445 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_all_distinct.cc @@ -0,0 +1,24 @@ +#include "utils/containers/are_all_distinct.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_all_distinct") { + + SUBCASE("Empty Container") { + std::vector input = {}; + CHECK(are_all_distinct(input)); + } + SUBCASE("All elements are distinct") { + std::vector input = {1, 2, 3, 4}; + CHECK(are_all_distinct(input)); + } + + SUBCASE("Not all elements are distinct") { + std::vector input = {2, 2, 3, 4}; + CHECK_FALSE(are_all_distinct(input)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/are_disjoint.cc b/lib/utils/test/src/utils/containers/are_disjoint.cc new file mode 100644 index 0000000000..17516dbf13 --- /dev/null +++ b/lib/utils/test/src/utils/containers/are_disjoint.cc @@ -0,0 +1,31 @@ +#include "utils/containers/are_disjoint.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("are_disjoint") { + SUBCASE("disjoint") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {4, 5, 6}; + CHECK(are_disjoint(l, r)); + } + SUBCASE("not disjoint") { + std::unordered_set l = {1, 2, 3, 4}; + std::unordered_set r = {3, 4, 5, 6}; + CHECK_FALSE(are_disjoint(l, r)); + } + + SUBCASE("one empty set") { + std::unordered_set l = {1, 2}; + std::unordered_set r = {}; + CHECK(are_disjoint(l, r)); + } + SUBCASE("both empty sets") { + std::unordered_set l = {}; + std::unordered_set r = {}; + CHECK(are_disjoint(l, r)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/compare_by.cc b/lib/utils/test/src/utils/containers/compare_by.cc new file mode 100644 index 0000000000..8c7221ffb4 --- /dev/null +++ b/lib/utils/test/src/utils/containers/compare_by.cc @@ -0,0 +1,19 @@ +#include "utils/containers/compare_by.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("compare_by") { + std::vector input = {"abc", "a", "ab"}; + auto comp = compare_by( + [](std::string const &s) { return s.length(); }); + std::vector correct = {"a", "ab", "abc"}; + std::sort(input.begin(), input.end(), comp); + CHECK(correct == input); + } +} diff --git a/lib/utils/test/src/utils/containers/contains.cc b/lib/utils/test/src/utils/containers/contains.cc new file mode 100644 index 0000000000..6e0a84c7ab --- /dev/null +++ b/lib/utils/test/src/utils/containers/contains.cc @@ -0,0 +1,13 @@ +#include "utils/containers/contains.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("contains") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } +} diff --git a/lib/utils/test/src/utils/containers/count.cc b/lib/utils/test/src/utils/containers/count.cc new file mode 100644 index 0000000000..4d5e05fb9d --- /dev/null +++ b/lib/utils/test/src/utils/containers/count.cc @@ -0,0 +1,13 @@ +#include "utils/containers/count.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("count") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); + CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); + } +} diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index c6ce9942e9..2f9a5b3c02 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -1,10 +1,11 @@ #include "utils/containers/enumerate.h" #include "test/utils/doctest/fmt/map.h" #include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" #include "utils/containers/keys.h" -#include "utils/containers/unordered_set_of.h" +#include "utils/containers/unordered_multiset_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include @@ -42,25 +43,13 @@ TEST_SUITE(FF_TEST_SUITE) { } TEST_CASE("enumerate(std::unordered_set)") { - std::unordered_set input = {"zero", "one", "two", "three"}; - - std::map correct = { - {0, "zero"}, - {1, "one"}, - {2, "two"}, - {3, "three"}, - }; - - std::map result = enumerate(input); - - std::unordered_set result_keys = keys(correct); - std::unordered_set result_values = - unordered_set_of(values(correct)); + std::unordered_set input = {"A", "B", "C", "D"}; std::unordered_set correct_keys = {0, 1, 2, 3}; - std::unordered_set correct_values = input; + std::unordered_multiset correct_values = {"A", "B", "C", "D"}; + std::map result = enumerate(input); - CHECK(result_keys == correct_keys); - CHECK(result_values == correct_values); + CHECK(keys(result) == correct_keys); + CHECK(unordered_multiset_of(values(result)) == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/filter_keys.cc b/lib/utils/test/src/utils/containers/filter_keys.cc new file mode 100644 index 0000000000..00e327a6f1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/filter_keys.cc @@ -0,0 +1,18 @@ +#include "utils/containers/filter_keys.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("filter_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = [](int x) { return x % 2 == 1; }; + std::unordered_map result = filter_keys(m, f); + std::unordered_map correct = {{1, "one"}, {3, "three"}}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/find.cc b/lib/utils/test/src/utils/containers/find.cc new file mode 100644 index 0000000000..36d4b771d8 --- /dev/null +++ b/lib/utils/test/src/utils/containers/find.cc @@ -0,0 +1,54 @@ +#include "utils/containers/find.h" +#include "test/utils/doctest/check_without_stringify.h" +#include +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find") { + + SUBCASE("vector") { + std::vector v = {1, 2, 3, 3, 4, 5, 3}; + + SUBCASE("element found") { + CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); + } + + SUBCASE("element not found") { + CHECK_WITHOUT_STRINGIFY(find(v, 6) == std::find(v.begin(), v.end(), 6)); + } + + SUBCASE("multiple occurrences of element") { + CHECK_WITHOUT_STRINGIFY(find(v, 3) == std::find(v.begin(), v.end(), 3)); + } + } + + SUBCASE("unordered_set") { + std::unordered_set s = {1, 2, 3, 4, 5}; + + SUBCASE("element in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); + } + + SUBCASE("element not in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + } + } + + SUBCASE("set") { + std::set s = {1, 2, 3, 4, 5}; + + SUBCASE("element in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 3) == std::find(s.begin(), s.end(), 3)); + } + + SUBCASE("element not in container") { + CHECK_WITHOUT_STRINGIFY(find(s, 6) == std::find(s.begin(), s.end(), 6)); + } + } + } +} diff --git a/lib/utils/test/src/utils/containers/flatmap.cc b/lib/utils/test/src/utils/containers/flatmap.cc index c10cc5ae75..bd6d3ae5be 100644 --- a/lib/utils/test/src/utils/containers/flatmap.cc +++ b/lib/utils/test/src/utils/containers/flatmap.cc @@ -2,6 +2,7 @@ #include "test/utils/doctest/fmt/pair.h" #include "test/utils/doctest/fmt/unordered_map.h" #include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/vector.h" #include "utils/containers/map_keys.h" #include "utils/hash/pair.h" #include @@ -10,6 +11,39 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flatmap(std::vector, F)") { + SUBCASE("same data-type") { + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } + } + return factors; + }; + + std::vector input = {2, 3, 4, 5}; + std::vector result = flatmap(input, get_factors); + std::vector correct = {1, 2, 1, 3, 1, 2, 4, 1, 5}; + CHECK(result == correct); + } + + SUBCASE("different data-type") { + auto get_string_sequence = [](int x) -> std::vector { + return { + std::to_string(x - 1), std::to_string(x), std::to_string(2 * x)}; + }; + + std::vector input = {2, 4, 10}; + std::vector result = flatmap(input, get_string_sequence); + std::vector correct = { + "1", "2", "4", "3", "4", "8", "9", "10", "20"}; + CHECK(result == correct); + } + } + TEST_CASE("flatmap(std::unordered_set, F)") { auto get_chars = [](std::string const &s) { std::unordered_set result; diff --git a/lib/utils/test/src/utils/containers/get_one_of.cc b/lib/utils/test/src/utils/containers/get_one_of.cc new file mode 100644 index 0000000000..326a292560 --- /dev/null +++ b/lib/utils/test/src/utils/containers/get_one_of.cc @@ -0,0 +1,19 @@ +#include "utils/containers/get_one_of.h" +#include "utils/containers/contains.h" +#include +#include +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_one_of") { + SUBCASE("non-empty set") { + std::unordered_set s = {1, 2, 3}; + CHECK(contains(s, get_one_of(s))); + } + + SUBCASE("empty set") { + std::unordered_set s = {}; + CHECK_THROWS(get_one_of(s)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/index_of.cc b/lib/utils/test/src/utils/containers/index_of.cc new file mode 100644 index 0000000000..6ab49cfd42 --- /dev/null +++ b/lib/utils/test/src/utils/containers/index_of.cc @@ -0,0 +1,24 @@ +#include "utils/containers/index_of.h" +#include "test/utils/doctest/fmt/optional.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("index_of") { + + std::vector v = {1, 2, 3, 4, 3, 5}; + + SUBCASE("element occurs once in container") { + CHECK(index_of(v, 4).value() == 3); + } + SUBCASE("if element appears multiple times, return the first occurrence") { + CHECK(index_of(v, 3).value() == 2); + } + SUBCASE("element not in container") { + CHECK(index_of(v, 7) == std::nullopt); + } + } +} diff --git a/lib/utils/test/src/utils/containers/is_submapeq_of.cc b/lib/utils/test/src/utils/containers/is_submapeq_of.cc new file mode 100644 index 0000000000..df89444235 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_submapeq_of.cc @@ -0,0 +1,40 @@ +#include "utils/containers/is_submapeq_of.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_submapeq_of") { + std::unordered_map super = { + {1, "one"}, {2, "two"}, {3, "three"}}; + + SUBCASE("keys and values match") { + std::unordered_map sub = {{1, "one"}, {2, "two"}}; + CHECK(is_submapeq_of(sub, super)); + } + + SUBCASE("keys and values don't match") { + std::unordered_map sub = {{1, "one"}, {4, "four"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + + SUBCASE("keys match but values don't") { + std::unordered_map sub = {{1, "wrong_value"}, + {2, "two"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + + SUBCASE("values match but keys don't") { + std::unordered_map sub = {{5, "one"}, {6, "two"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + + SUBCASE("sub is a superset of super") { + std::unordered_map sub = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK_FALSE(is_submapeq_of(sub, super)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/is_subseteq_of.cc b/lib/utils/test/src/utils/containers/is_subseteq_of.cc new file mode 100644 index 0000000000..d762f171b6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_subseteq_of.cc @@ -0,0 +1,16 @@ +#include "utils/containers/is_subseteq_of.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_subseteq_of") { + std::unordered_set s1 = {1, 2}; + std::unordered_set s2 = {1, 2, 3}; + CHECK(is_subseteq_of(s1, s2) == true); + CHECK(is_subseteq_of(s2, s1) == false); + CHECK(is_subseteq_of(s1, s1) == true); + CHECK(is_subseteq_of(s2, s2) == true); + } +} diff --git a/lib/utils/test/src/utils/containers/is_superseteq_of.cc b/lib/utils/test/src/utils/containers/is_superseteq_of.cc new file mode 100644 index 0000000000..e3b429fa64 --- /dev/null +++ b/lib/utils/test/src/utils/containers/is_superseteq_of.cc @@ -0,0 +1,25 @@ +#include "utils/containers/is_superseteq_of.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_superseteq_of") { + std::unordered_set super = {1, 2, 3, 4}; + + SUBCASE("true containment") { + std::unordered_set sub = {1, 2, 3}; + CHECK(is_superseteq_of(super, sub)); + } + + SUBCASE("false containment") { + std::unordered_set sub = {1, 2, 5}; + CHECK_FALSE(is_superseteq_of(super, sub)); + } + + SUBCASE("reflexive") { + CHECK(is_superseteq_of(super, super)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/keys.cc b/lib/utils/test/src/utils/containers/keys.cc new file mode 100644 index 0000000000..5bdaef6d08 --- /dev/null +++ b/lib/utils/test/src/utils/containers/keys.cc @@ -0,0 +1,18 @@ +#include "utils/containers/keys.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set result = keys(m); + std::unordered_set expected = {1, 2, 3}; + CHECK(result == expected); + } +} diff --git a/lib/utils/test/src/utils/containers/map_keys.cc b/lib/utils/test/src/utils/containers/map_keys.cc new file mode 100644 index 0000000000..5c0a81d5e6 --- /dev/null +++ b/lib/utils/test/src/utils/containers/map_keys.cc @@ -0,0 +1,26 @@ +#include "utils/containers/map_keys.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("map_keys") { + SUBCASE("Distinct keys after transformation") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](int x) { return x * x; }; + std::unordered_map result = map_keys(m, f); + std::unordered_map correct = {{1, "one"}, {4, "two"}}; + CHECK(correct == result); + } + + SUBCASE("Non-distinct keys after transformation") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {-1, "minus one"}}; + auto f = [](int x) { return std::abs(x); }; + CHECK_THROWS(map_keys(m, f)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/map_values.cc b/lib/utils/test/src/utils/containers/map_values.cc new file mode 100644 index 0000000000..a21645d0d5 --- /dev/null +++ b/lib/utils/test/src/utils/containers/map_values.cc @@ -0,0 +1,17 @@ +#include "utils/containers/map_values.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("map_values") { + std::unordered_map m = {{1, "one"}, {3, "three"}}; + auto f = [](std::string const &s) { return s.size(); }; + std::unordered_map result = map_values(m, f); + std::unordered_map correct = {{1, 3}, {3, 5}}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/maximum.cc b/lib/utils/test/src/utils/containers/maximum.cc index 71e7395805..2309458069 100644 --- a/lib/utils/test/src/utils/containers/maximum.cc +++ b/lib/utils/test/src/utils/containers/maximum.cc @@ -1,60 +1,25 @@ #include "utils/containers/maximum.h" -#include "test/utils/doctest/fmt/multiset.h" -#include "test/utils/doctest/fmt/optional.h" -#include "test/utils/doctest/fmt/set.h" -#include "test/utils/doctest/fmt/unordered_multiset.h" -#include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" #include -#include -#include #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("maximum(T)", - T, - std::vector, - std::unordered_set, - std::unordered_multiset, - std::set, - std::multiset) { - SUBCASE("input is empty") { - T input = {}; + TEST_CASE("maximum") { - std::optional result = maximum(input); - std::optional correct = std::nullopt; - - CHECK(result == correct); - } - - SUBCASE("input does not have duplicates") { - T input = {1, 3, 2}; - - std::optional result = maximum(input); - std::optional correct = 3; - - CHECK(result == correct); + SUBCASE("non-empty container") { + std::vector input = {1, 5, 3, 4, 2}; + int correct = 5; + int result = maximum(input); + CHECK(correct == result); } - SUBCASE("input has duplicates") { - T input = {1, 2, 2, 0}; + SUBCASE("empty container") { + std::vector input = {}; - std::optional result = maximum(input); - std::optional correct = 2; - - CHECK(result == correct); + CHECK_THROWS(maximum(input)); } } - - TEST_CASE("maximum(std::vector)") { - std::vector input = {"hello", "world"}; - - std::optional result = maximum(input); - std::optional correct = "world"; - - CHECK(result == correct); - } } diff --git a/lib/utils/test/src/utils/containers/merge_maps.cc b/lib/utils/test/src/utils/containers/merge_maps.cc new file mode 100644 index 0000000000..a083e94de3 --- /dev/null +++ b/lib/utils/test/src/utils/containers/merge_maps.cc @@ -0,0 +1,30 @@ +#include "utils/containers/merge_maps.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("merge_maps") { + + SUBCASE("disjoint keys") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{3, "three"}, {4, "four"}}; + + std::unordered_map result = merge_maps(lhs, rhs); + std::unordered_map correct = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + + CHECK(result == correct); + } + + SUBCASE("overlapping keys") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{2, "three"}, {3, "four"}}; + + CHECK_THROWS(merge_maps(lhs, rhs)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/product.cc b/lib/utils/test/src/utils/containers/product.cc new file mode 100644 index 0000000000..3fa94c8e9e --- /dev/null +++ b/lib/utils/test/src/utils/containers/product.cc @@ -0,0 +1,32 @@ +#include "utils/containers/product.h" +#include +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE_TEMPLATE("product", + C, + std::vector, + std::vector, + std::set, + std::unordered_set) { + + SUBCASE("non-empty container") { + C input = {1, -2, 3, 5}; + auto correct = -30; + auto result = product(input); + CHECK(correct == result); + } + + SUBCASE("empty container") { + C input = {}; + auto correct = 1; + auto result = product(input); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/product_where.cc b/lib/utils/test/src/utils/containers/product_where.cc new file mode 100644 index 0000000000..098ae01252 --- /dev/null +++ b/lib/utils/test/src/utils/containers/product_where.cc @@ -0,0 +1,34 @@ +#include "utils/containers/product_where.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("product_where") { + + SUBCASE("empty starting container") { + std::vector input = {}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 1; + int result = product_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("non-empty filtered container") { + std::vector input = {1, -2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = -8; + int result = product_where(input, condition); + CHECK(correct == result); + } + SUBCASE("empty filtered container") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 1; + int result = product_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/restrict_keys.cc b/lib/utils/test/src/utils/containers/restrict_keys.cc new file mode 100644 index 0000000000..b4b376784e --- /dev/null +++ b/lib/utils/test/src/utils/containers/restrict_keys.cc @@ -0,0 +1,17 @@ +#include "utils/containers/restrict_keys.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("restrict_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set mask = {2, 3, 4}; + std::unordered_map result = restrict_keys(m, mask); + std::unordered_map correct = {{2, "two"}, {3, "three"}}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/set_union.cc b/lib/utils/test/src/utils/containers/set_union.cc new file mode 100644 index 0000000000..d842e4df96 --- /dev/null +++ b/lib/utils/test/src/utils/containers/set_union.cc @@ -0,0 +1,16 @@ +#include "utils/containers/set_union.h" +#include "test/utils/doctest/fmt/unordered_set.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("set_union") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {2, 3, 4}; + std::unordered_set result = set_union(s1, s2); + std::unordered_set correct = {1, 2, 3, 4}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/sorted_by.cc b/lib/utils/test/src/utils/containers/sorted_by.cc new file mode 100644 index 0000000000..0ae2e0da77 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sorted_by.cc @@ -0,0 +1,35 @@ +#include "utils/containers/sorted_by.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("sorted_by") { + SUBCASE("sort increasing") { + std::unordered_set s = {5, 2, 3, 4, 1}; + std::vector result = + sorted_by(s, [](int a, int b) { return a < b; }); + std::vector correct = {1, 2, 3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("sort decreasing") { + std::unordered_set input = {-5, -1, -3, -2, -4}; + std::vector result = + sorted_by(input, [](int a, int b) { return a > b; }); + std::vector correct = {-1, -2, -3, -4, -5}; + CHECK(result == correct); + } + + SUBCASE("container contains duplicate elements") { + std::vector input = {3, 1, 3, -4, 1}; + std::vector result = + sorted_by(input, [](int a, int b) { return a < b; }); + std::vector correct = {-4, 1, 1, 3, 3}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/subvec.cc b/lib/utils/test/src/utils/containers/subvec.cc new file mode 100644 index 0000000000..610fc55b5a --- /dev/null +++ b/lib/utils/test/src/utils/containers/subvec.cc @@ -0,0 +1,62 @@ +#include "utils/containers/subvec.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("subvec") { + std::vector v = {1, 2, 3, 4, 5}; + + SUBCASE("Basic subvector") { + auto result = subvec(v, 1, 4); + std::vector correct = {2, 3, 4}; + CHECK(result == correct); + } + + SUBCASE("From beginning to index") { + auto result = subvec(v, std::nullopt, 3); + std::vector correct = {1, 2, 3}; + CHECK(result == correct); + } + + SUBCASE("From index to end") { + auto result = subvec(v, 2, std::nullopt); + std::vector correct = {3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("All of the vector") { + auto result = subvec(v, std::nullopt, std::nullopt); + std::vector correct = {1, 2, 3, 4, 5}; + CHECK(result == correct); + } + + SUBCASE("Start greater than end") { + auto result = subvec(v, 3, 1); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("Start equal to end") { + auto result = subvec(v, 3, 3); + std::vector correct = {}; + CHECK(result == correct); + } + + SUBCASE("Negative indices") { + auto result = subvec(v, -3, -1); + std::vector correct = {3, 4}; + CHECK(result == correct); + } + + SUBCASE("Upper index is out of bounds by 1") { + CHECK_THROWS(subvec(v, 2, 6)); + } + + SUBCASE("Lower index is out of bounds by 1") { + CHECK_THROWS(subvec(v, -6, 2)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/sum_where.cc b/lib/utils/test/src/utils/containers/sum_where.cc new file mode 100644 index 0000000000..7a909aea39 --- /dev/null +++ b/lib/utils/test/src/utils/containers/sum_where.cc @@ -0,0 +1,36 @@ +#include "utils/containers/sum_where.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("sum_where") { + + SUBCASE("starting container is empty") { + std::vector input = {}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 0; + int result = sum_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("resulting container is non-empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x % 2 == 0; }; + int correct = 6; + int result = sum_where(input, condition); + CHECK(correct == result); + } + + SUBCASE("resulting container is empty") { + std::vector input = {1, 2, 3, 4, 5}; + auto condition = [](int x) { return x > 10; }; + int correct = 0; + int result = sum_where(input, condition); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/value_all.cc b/lib/utils/test/src/utils/containers/value_all.cc new file mode 100644 index 0000000000..1a5f2c508a --- /dev/null +++ b/lib/utils/test/src/utils/containers/value_all.cc @@ -0,0 +1,24 @@ +#include "utils/containers/value_all.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("value_all") { + SUBCASE("With nullopt") { + std::vector> input = {1, 2, std::nullopt, 4, 5}; + CHECK_THROWS(value_all(input)); + } + + SUBCASE("Without nullopt") { + std::vector> input = {1, 2, 3, 4, 5}; + std::vector correct = {1, 2, 3, 4, 5}; + std::vector result = value_all(input); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/containers/values.cc b/lib/utils/test/src/utils/containers/values.cc new file mode 100644 index 0000000000..5fe69ac5e9 --- /dev/null +++ b/lib/utils/test/src/utils/containers/values.cc @@ -0,0 +1,19 @@ +#include "utils/containers/values.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("values") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}, {33, "three"}}; + std::unordered_multiset result = values(m); + std::unordered_multiset correct = { + "one", "two", "three", "three"}; + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/vector_split.cc b/lib/utils/test/src/utils/containers/vector_split.cc new file mode 100644 index 0000000000..76bb21348e --- /dev/null +++ b/lib/utils/test/src/utils/containers/vector_split.cc @@ -0,0 +1,39 @@ +#include "utils/containers/vector_split.h" +#include "test/utils/doctest/fmt/vector.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Testing vector_split function") { + std::vector v = {1, 2, 3, 4, 5}; + + SUBCASE("Normal case: idx = 2") { + auto [prefix, postfix] = vector_split(v, 2); + CHECK(prefix == std::vector({1, 2})); + CHECK(postfix == std::vector({3, 4, 5})); + } + + SUBCASE("Boundary case: idx = 0") { + auto [prefix, postfix] = vector_split(v, 0); + CHECK(prefix.empty()); + CHECK(postfix == std::vector({1, 2, 3, 4, 5})); + } + + SUBCASE("Boundary case: idx == list_size") { + auto [prefix, postfix] = vector_split(v, 5); + CHECK(prefix == std::vector({1, 2, 3, 4, 5})); + CHECK(postfix.empty()); + } + + SUBCASE("Out of bounds case: idx = -1") { + CHECK_THROWS_AS(vector_split(v, -1), std::out_of_range); + } + + SUBCASE("Out of bounds case: idx == list_size + 1") { + CHECK_THROWS_AS(vector_split(v, 6), std::out_of_range); + } + } +} diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/utils/deduplicated_priority_queue.cc similarity index 96% rename from lib/utils/test/src/test_deduplicated_priority_queue.cc rename to lib/utils/test/src/utils/deduplicated_priority_queue.cc index 048e95acb7..d5b35cf654 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/utils/deduplicated_priority_queue.cc @@ -1,6 +1,8 @@ #include "utils/deduplicated_priority_queue.h" #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DeduplicatedPriorityQueue push and pop") { DeduplicatedPriorityQueue queue; diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/utils/disjoint_set.cc similarity index 73% rename from lib/utils/test/src/test_disjoint_set.cc rename to lib/utils/test/src/utils/disjoint_set.cc index 65037be3dd..be88a30cdd 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/utils/disjoint_set.cc @@ -1,4 +1,5 @@ #include "utils/disjoint_set.h" +#include "test/utils/doctest/fmt/optional.h" #include using namespace FlexFlow; @@ -18,10 +19,10 @@ std::string generate_element(int seed) { TEST_SUITE(FF_TEST_SUITE) { TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { - disjoint_set> ds; + disjoint_set> ds; SUBCASE("SingleElementSets") { - optional element = generate_element(1); + std::optional element = generate_element(1); CHECK(ds.find(element) == element); element = generate_element(2); @@ -29,10 +30,10 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("UnionAndFind") { - optional element1 = generate_element(1); - optional element2 = generate_element(2); - optional element3 = generate_element(3); - optional element4 = generate_element(4); + std::optional element1 = generate_element(1); + std::optional element2 = generate_element(2); + std::optional element3 = generate_element(3); + std::optional element4 = generate_element(4); ds.m_union(element1, element2); CHECK(ds.find(element1) == ds.find(element2)); @@ -55,11 +56,11 @@ TEST_SUITE(FF_TEST_SUITE) { ds.m_union(1, 4); ds.m_union(5, 6); - std::map, optional, OptionalComparator> + std::map, std::optional, OptionalComparator> expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; - std::map, optional, OptionalComparator> mapping = - ds.get_mapping(); + std::map, std::optional, OptionalComparator> + mapping = ds.get_mapping(); for (auto const &kv : mapping) { CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/utils/dot_file.cc similarity index 100% rename from lib/utils/test/src/test_dot_file.cc rename to lib/utils/test/src/utils/dot_file.cc diff --git a/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc new file mode 100644 index 0000000000..ce4d7a373b --- /dev/null +++ b/lib/utils/test/src/utils/graph/multidigraph/multidigraph.cc @@ -0,0 +1,171 @@ +#include "utils/graph/multidigraph/multidigraph.h" +#include "utils/containers/contains.h" +#include "utils/graph/instances/adjacency_multidigraph.h" +#include "utils/graph/multidigraph/multidiedge_query.h" +#include "utils/graph/query_set.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph") { + MultiDiGraph g = MultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + MultiDiEdge e0 = g.add_edge(n0, n2); + MultiDiEdge e1 = g.add_edge(n1, n0); + MultiDiEdge e2 = g.add_edge(n1, n0); + MultiDiEdge e3 = g.add_edge(n1, n2); + MultiDiEdge e4 = g.add_edge(n1, n2); + MultiDiEdge e5 = g.add_edge(n2, n0); + MultiDiEdge e6 = g.add_edge(n2, n2); + + SUBCASE("add_node") { + Node n3 = g.add_node(); + std::unordered_set result = g.query_nodes(NodeQuery{{n3}}); + std::unordered_set correct = {n3}; + CHECK(result == correct); + } + + SUBCASE("add_edge") { + SUBCASE("non-duplicate edge") { + MultiDiEdge e7 = g.add_edge(n2, n1); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + std::unordered_set correct = {e7}; + CHECK(result == correct); + } + + SUBCASE("duplicate edge") { + MultiDiEdge e7 = g.add_edge(n2, n1); + MultiDiEdge e8 = g.add_edge(n2, n1); + + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n2}, {n1})); + std::unordered_set correct = {e7, e8}; + CHECK(result == correct); + } + } + + SUBCASE("remove_node") { + g.remove_node(n0); + + std::unordered_set node_result = g.query_nodes(NodeQuery{{n0}}); + std::unordered_set node_correct = {}; + CHECK(node_result == node_correct); + + std::unordered_set edge_result = + g.query_edges(MultiDiEdgeQuery({n0}, {n1, n2})); + std::unordered_set edge_correct = {}; + CHECK(edge_result == edge_correct); + } + + SUBCASE("remove_edge") { + g.remove_edge(e3); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n2})); + std::unordered_set correct = {e4}; + CHECK(result == correct); + + SUBCASE("remove non-duplicate edge") { + g.remove_edge(e0); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0}, {n2})); + std::unordered_set correct = {}; + CHECK(result == correct); + } + + SUBCASE("remove duplicate edge") { + g.remove_edge(e1); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n0})); + std::unordered_set correct = {e2}; + CHECK(result == correct); + } + } + + SUBCASE("query_nodes") { + SUBCASE("all nodes") { + std::unordered_set result = + g.query_nodes(NodeQuery{{n0, n1, n2}}); + std::unordered_set correct = {n0, n1, n2}; + CHECK(result == correct); + } + + SUBCASE("specific nodes") { + std::unordered_set result = g.query_nodes(NodeQuery{{n0, n2}}); + std::unordered_set correct = {n0, n2}; + CHECK(result == correct); + } + + SUBCASE("matchall") { + std::unordered_set result = + g.query_nodes(NodeQuery{matchall()}); + std::unordered_set correct = {n0, n1, n2}; + CHECK(result == correct); + } + + SUBCASE("nodes not in graph") { + Node n3 = Node(3); + Node n4 = Node(4); + std::unordered_set result = g.query_nodes(NodeQuery{{n3, n4}}); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } + + SUBCASE("query_edges") { + SUBCASE("all edges") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n0, n1, n2})); + std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; + CHECK(result == correct); + } + + SUBCASE("edges from n1") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1}, {n0, n1, n2})); + std::unordered_set correct = {e1, e2, e3, e4}; + CHECK(result == correct); + } + + SUBCASE("edges to n2") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n0, n1, n2}, {n2})); + std::unordered_set correct = {e0, e3, e4, e6}; + CHECK(result == correct); + } + + SUBCASE("matchall") { + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery(matchall(), matchall())); + std::unordered_set correct = {e0, e1, e2, e3, e4, e5, e6}; + CHECK(result == correct); + } + + SUBCASE("nodes that don't exist") { + Node n3 = Node(3); + Node n4 = Node(4); + std::unordered_set result = + g.query_edges(MultiDiEdgeQuery({n1, n3}, {n4})); + std::unordered_set correct = {}; + CHECK(result == correct); + } + } + SUBCASE("get_multidiedge_src") { + Node result = g.get_multidiedge_src(e0); + Node correct = n0; + CHECK(result == correct); + } + + SUBCASE("get_multidiedge_dst") { + Node result = g.get_multidiedge_dst(e0); + Node correct = n2; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/hash/map.cc b/lib/utils/test/src/utils/hash/map.cc new file mode 100644 index 0000000000..b4da6ddb68 --- /dev/null +++ b/lib/utils/test/src/utils/hash/map.cc @@ -0,0 +1,20 @@ +#include "utils/hash/map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::map map1{{1, 2}}; + std::map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({3, 4}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/set.cc b/lib/utils/test/src/utils/hash/set.cc new file mode 100644 index 0000000000..f9ccd925c6 --- /dev/null +++ b/lib/utils/test/src/utils/hash/set.cc @@ -0,0 +1,20 @@ +#include "utils/hash/set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::set set1{1, 2, 3}; + std::set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/tuple.cc b/lib/utils/test/src/utils/hash/tuple.cc new file mode 100644 index 0000000000..61240fd7b1 --- /dev/null +++ b/lib/utils/test/src/utils/hash/tuple.cc @@ -0,0 +1,21 @@ +#include "utils/hash/tuple.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::tuple tuple1{1, "test", 3.14}; + std::tuple tuple2{2, "test", 3.14}; + + size_t hash1 = get_std_hash(tuple1); + size_t hash2 = get_std_hash(tuple2); + + CHECK(hash1 != hash2); + + std::get<0>(tuple1) = 2; + hash1 = get_std_hash(tuple1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/utils/hash/unordered_map.cc similarity index 69% rename from lib/utils/test/src/test_hash.cc rename to lib/utils/test/src/utils/hash/unordered_map.cc index decf405e7a..bc0b2879a4 100644 --- a/lib/utils/test/src/test_hash.cc +++ b/lib/utils/test/src/utils/hash/unordered_map.cc @@ -1,10 +1,10 @@ -#include "utils/hash-utils.h" +#include "utils/hash/unordered_map.h" #include -using namespace FlexFlow; +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("hash:unordered_map") { + TEST_CASE("std::hash>") { std::unordered_map map1{{1, 2}}; std::unordered_map map2{{1, 2}, {3, 4}}; @@ -13,7 +13,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(hash1 != hash2); - map1.insert({1, 2}); + map1.insert({3, 4}); hash1 = get_std_hash(map1); CHECK(hash1 == hash2); } diff --git a/lib/utils/test/src/utils/hash/unordered_set.cc b/lib/utils/test/src/utils/hash/unordered_set.cc new file mode 100644 index 0000000000..299b3e5d13 --- /dev/null +++ b/lib/utils/test/src/utils/hash/unordered_set.cc @@ -0,0 +1,20 @@ +#include "utils/hash/unordered_set.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::unordered_set set1{1, 2, 3}; + std::unordered_set set2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(set1); + size_t hash2 = get_std_hash(set2); + + CHECK(hash1 != hash2); + + set1.insert(4); + hash1 = get_std_hash(set1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/hash/vector.cc b/lib/utils/test/src/utils/hash/vector.cc new file mode 100644 index 0000000000..ee5bb456eb --- /dev/null +++ b/lib/utils/test/src/utils/hash/vector.cc @@ -0,0 +1,20 @@ +#include "utils/hash/vector.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("std::hash>") { + std::vector vec1{1, 2, 3}; + std::vector vec2{1, 2, 3, 4}; + + size_t hash1 = get_std_hash(vec1); + size_t hash2 = get_std_hash(vec2); + + CHECK(hash1 != hash2); + + vec1.push_back(4); + hash1 = get_std_hash(vec1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/utils/join_strings.cc b/lib/utils/test/src/utils/join_strings.cc new file mode 100644 index 0000000000..ca1887e84d --- /dev/null +++ b/lib/utils/test/src/utils/join_strings.cc @@ -0,0 +1,47 @@ +#include "utils/join_strings.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("join_strings") { + std::vector v = {"Hello", "world", "!"}; + + SUBCASE("iterator") { + std::string result = join_strings(v.begin(), v.end(), " "); + std::string correct = "Hello world !"; + CHECK(result == correct); + } + + SUBCASE("join_strings with container") { + std::string result = join_strings(v, " "); + std::string correct = "Hello world !"; + CHECK(result == correct); + } + + SUBCASE("join_strings with transforming function") { + auto add_exclamation = [](std::string const &str) { return str + "!"; }; + std::string result = join_strings(v, " ", add_exclamation); + std::string correct = "Hello! world! !!"; + CHECK(result == correct); + } + + SUBCASE("join_strings with transforming function, iterator") { + auto add_exclamation = [](std::string const &str) { return str + "!"; }; + std::string result = + join_strings(v.begin(), v.end(), " ", add_exclamation); + std::string correct = "Hello! world! !!"; + CHECK(result == correct); + } + + SUBCASE("empty sequence") { + v = {}; + std::string result = join_strings(v, "!"); + std::string correct = ""; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/random_utils.cc b/lib/utils/test/src/utils/random_utils.cc new file mode 100644 index 0000000000..8e7d22138f --- /dev/null +++ b/lib/utils/test/src/utils/random_utils.cc @@ -0,0 +1,60 @@ +#include "utils/random_utils.h" +#include "utils/containers/contains.h" +#include "utils/containers/filter.h" +#include "utils/containers/repeat.h" +#include "utils/containers/sum.h" +#include "utils/containers/zip.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("select_random(std::vector)") { + std::vector values = {1, 2, 3, 4, 5}; + + SUBCASE("selected value is in container") { + SUBCASE("equal weights") { + int result = select_random(values); + CHECK(contains(values, result)); + } + + SUBCASE("unequal weights") { + std::vector weights = {0.1f, 0.3f, 0.2f, 0.2f, 0.2f}; + int result = select_random(values, weights); + CHECK(contains(values, result)); + } + } + + SUBCASE("correct distribution") { + auto check_probabilities = [](std::vector const &values, + std::vector const &weights) { + int num_iterations = 10'000; + std::vector trials = repeat( + num_iterations, [&]() { return select_random(values, weights); }); + + for (std::pair const &p : zip(values, weights)) { + int v = p.first; + float w = p.second; + float expectedProbability = w / sum(weights); + int num_occurrences = + filter(trials, [&](int c) { return (c == v); }).size(); + float observedProbability = + static_cast(num_occurrences) / num_iterations; + CHECK(observedProbability == + doctest::Approx(expectedProbability).epsilon(0.01f)); + } + }; + + SUBCASE("equal weights") { + std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + check_probabilities(values, weights); + } + + SUBCASE("unequal weights") { + std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; + check_probabilities(values, weights); + } + } + } +} diff --git a/lib/utils/test/src/utils/rapidcheck/variant.cc b/lib/utils/test/src/utils/rapidcheck/variant.cc new file mode 100644 index 0000000000..ca830fb3ab --- /dev/null +++ b/lib/utils/test/src/utils/rapidcheck/variant.cc @@ -0,0 +1,13 @@ +#include "utils/rapidcheck/variant.h" +#include "test/utils/rapidcheck.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("Arbitrary") { + RC_SUBCASE("valid type", [](std::variant v) { + return std::holds_alternative(v) || std::holds_alternative(v); + }); + } +} diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/utils/record_formatter.cc similarity index 100% rename from lib/utils/test/src/test_format.cc rename to lib/utils/test/src/utils/record_formatter.cc diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/utils/sequence.cc similarity index 100% rename from lib/utils/test/src/test_sequence.cc rename to lib/utils/test/src/utils/sequence.cc diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/utils/stack_map.cc similarity index 73% rename from lib/utils/test/src/test_stack_map.cc rename to lib/utils/test/src/utils/stack_map.cc index f117820c5d..3ab4faaa80 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/utils/stack_map.cc @@ -1,4 +1,6 @@ #include "utils/stack_map.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/vector.h" #include using namespace FlexFlow; @@ -6,8 +8,8 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("stack_map") { stack_map map; - // Test the [] operator to insert and access elements - SUBCASE("BracketOperator") { + + SUBCASE("operator[]") { map[1] = 10; map[2] = 20; @@ -15,8 +17,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(map[2] == 20); } - // Test the insert() function - SUBCASE("Insert") { + SUBCASE("insert") { map.insert(1, 10); map.insert(2, 20); @@ -24,21 +25,19 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(map[2] == 20); } - // Test the at() function to access elements - SUBCASE("At") { + SUBCASE("at") { map[1] = 10; map[2] = 20; CHECK(map.at(1) == 10); CHECK(map.at(2) == 20); CHECK(map.at(1) != 20); - // Test const version of at() function + stack_map const &const_map = map; CHECK(const_map.at(1) == 10); CHECK(const_map.at(2) == 20); } - // Test the begin() and end() functions for iterator SUBCASE("Iterator") { map[1] = 10; map[2] = 20; diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/utils/stack_string.cc similarity index 95% rename from lib/utils/test/src/test_stack_string.cc rename to lib/utils/test/src/utils/stack_string.cc index b89e3277cd..8dbfe36c9d 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/utils/stack_string.cc @@ -1,5 +1,5 @@ -#include "test/utils/rapidcheck.h" #include "utils/stack_string.h" +#include "test/utils/rapidcheck.h" #include using namespace FlexFlow; @@ -84,6 +84,8 @@ TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("Arbitrary") { constexpr std::size_t MAXSIZE = 10; - RCSUBCASE([](stack_string const &s) {}); + RC_SUBCASE([&](stack_string const &s) { + RC_ASSERT(s.size() <= MAXSIZE); + }); } } diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/utils/stack_vector.cc similarity index 68% rename from lib/utils/test/src/test_stack_vector.cc rename to lib/utils/test/src/utils/stack_vector.cc index 577e61092c..6cdd91ece1 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/utils/stack_vector.cc @@ -1,28 +1,31 @@ -#include "test/utils/rapidcheck.h" #include "utils/stack_vector.h" +#include "test/utils/doctest/fmt/vector.h" +#include "test/utils/rapidcheck.h" #include #include using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::push_back", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; vector.push_back(10); - std::vector res = vector; - std::vector expected = {10}; - CHECK(res == expected); + std::vector result = vector; + std::vector correct = {10}; + CHECK(result == correct); vector.push_back(20); - expected = {10, 20}; - res = vector; - CHECK(res == expected); + correct = {10, 20}; + result = vector; + CHECK(result == correct); } - TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::operator[]", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -36,7 +39,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector[2] == 30); } - TEST_CASE_TEMPLATE("Size", T, int, double, char) { + TEST_CASE_TEMPLATE("stack_vector::size", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -50,7 +53,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector.size() == 2); } - TEST_CASE_TEMPLATE("==", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector::operator==", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector1, vector2; @@ -66,7 +70,7 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector1 == vector2); } - TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { + TEST_CASE_TEMPLATE("stack_vector::back", T, int, double, char) { constexpr std::size_t MAXSIZE = 5; using StackVector = stack_vector; StackVector vector; @@ -78,7 +82,8 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(vector.back() == 20); } - TEST_CASE_TEMPLATE("Arbitrary", T, int, double, char) { + TEST_CASE_TEMPLATE( + "stack_vector - check for size bound", T, int, double, char) { constexpr std::size_t MAXSIZE = 10; RC_SUBCASE("within bound", [&](stack_vector v) { RC_ASSERT(v.size() <= MAXSIZE); diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/utils/tuple.cc similarity index 73% rename from lib/utils/test/src/test_tuple.cc rename to lib/utils/test/src/utils/tuple.cc index 96171510a7..01a1ebca18 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/utils/tuple.cc @@ -1,6 +1,5 @@ #include "utils/tuple.h" #include - #include #include @@ -37,23 +36,23 @@ TEST_SUITE(FF_TEST_SUITE) { } } - TEST_CASE("tuple_prepend function") { - std::tuple t1(3.14f, 2.71828); + TEST_CASE("tuple_prepend") { + std::tuple t1 = {3.14f, 2.71828}; int value = 42; - auto result = tuple_prepend(value, t1); - std::tuple expected(42, 3.14f, 2.71828); - CHECK(result == expected); + std::tuple result = tuple_prepend(value, t1); + std::tuple correct = {42, 3.14f, 2.71828}; + CHECK(tuple_compare(result, correct)); } - TEST_CASE("Testing tuple_head_t") { + TEST_CASE("tuple_head_t") { CHECK(std::is_same>, std::tuple>::value); CHECK(std::is_same>, std::tuple<>>::value); } - TEST_CASE("Testing tuple_slice_t") { + TEST_CASE("tuple_slice_t") { CHECK(std::is_same>, std::tuple>::value); CHECK(std::is_same>, @@ -62,17 +61,17 @@ TEST_SUITE(FF_TEST_SUITE) { std::tuple>::value); } - TEST_CASE("Testing tuple_compare function") { - std::tuple tup1{1, 3.14, 'a'}; - std::tuple tup2{1, 3.14, 'a'}; - std::tuple tup3{2, 3.14, 'b'}; + TEST_CASE("tuple_compare") { + std::tuple tup1 = {1, 3.14, 'a'}; + std::tuple tup2 = {1, 3.14, 'a'}; + std::tuple tup3 = {2, 3.14, 'b'}; CHECK(tuple_compare(tup1, tup2)); CHECK(!tuple_compare(tup1, tup3)); } - TEST_CASE("Testing get function with valid index") { - std::tuple tup{1, 3.14, 'a'}; + TEST_CASE("get") { + std::tuple tup = {1, 3.14, 'a'}; CHECK(get(tup) == 1); CHECK(get(tup) == 3.14); diff --git a/lib/utils/test/src/utils/type_index.cc b/lib/utils/test/src/utils/type_index.cc new file mode 100644 index 0000000000..0d53868a92 --- /dev/null +++ b/lib/utils/test/src/utils/type_index.cc @@ -0,0 +1,34 @@ +#include "utils/type_index.h" +#include "test/utils/doctest/check_without_stringify.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_type_index_for_type") { + SUBCASE("int type") { + std::type_index idx = get_type_index_for_type(); + std::type_index expected_idx = typeid(int); + CHECK_WITHOUT_STRINGIFY(idx == expected_idx); + } + + SUBCASE("string type") { + std::type_index idx = get_type_index_for_type(); + std::type_index expected_idx = typeid(std::string); + CHECK_WITHOUT_STRINGIFY(idx == expected_idx); + } + } + + TEST_CASE("matches(std::type_index)") { + std::type_index idx = typeid(float); + + SUBCASE("matching type") { + CHECK(matches(idx)); + } + + SUBCASE("non-matching type") { + CHECK_FALSE(matches(idx)); + } + } +} diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/utils/variant.cc similarity index 91% rename from lib/utils/test/src/test_variant.cc rename to lib/utils/test/src/utils/variant.cc index 0bd01b8dfe..3f6feadda0 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/utils/variant.cc @@ -1,7 +1,10 @@ -#include "test/utils/rapidcheck.h" #include "utils/variant.h" +#include "test/utils/doctest/fmt/optional.h" +#include "test/utils/doctest/fmt/variant.h" #include +using namespace ::FlexFlow; + TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("widen and narrow functions") { SUBCASE("widen function") { @@ -70,10 +73,4 @@ TEST_SUITE(FF_TEST_SUITE) { // Check the result CHECK(get(wider_variant) == 42); } - - TEST_CASE("Arbitrary") { - RC_SUBCASE("valid type", [](std::variant v) { - return std::holds_alternative(v) || std::holds_alternative(v); - }); - } } diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/utils/vector.cc similarity index 91% rename from lib/utils/test/src/test_vector.cc rename to lib/utils/test/src/utils/vector.cc index c6eb0828b8..18eda49543 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/utils/vector.cc @@ -1,5 +1,9 @@ #include "utils/vector.h" +#include "test/utils/doctest/fmt/vector.h" #include +#include + +using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("concat function") {