From 1b0c32eb715daca846e4fd7602d74d4ee08910fb Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Sat, 30 Mar 2024 14:06:05 -0700 Subject: [PATCH] Re-merge #1229 (#1346) * compiler build * unity dp works * format * fmt * fix * add substitutions, compiler, and their unit tests to CI * disable runtime unit test * minor fix * (not compilable) visitable issue for OptimalCostState * fix machine mapping hash & refactor dp algorithm * minor fix * fix variant issue * fmt * fix * fmt * fix * add more unit tests * fmt * Fix post-merge * Add shell hook for sapling development * changed from nullopt to std::nullopt * fix cast issue * Fix spdlog cmake issue * Re-remove submodules * minor fix & fmt * upd tests name to match ci * Add TEST_SUITE declaration to make tests findable by ctest * Remove unnecessary nix files, add utils test to ci * Fix utils tests name, format --------- Co-authored-by: wmdi Co-authored-by: Pietro Max Marsella --- .flake/patches/doctest-template-test.patch | 50 + .flake/pkgs/tokenizers-cpp.nix | 43 - .github/workflows/helpers/build_libs.sh | 9 + .../helpers/{build_cuda.sh => cmake_cuda.sh} | 17 +- .github/workflows/helpers/test_libs.sh | 14 + .github/workflows/per-lib-check.yml | 37 +- CMakeLists.txt | 2 +- cmake/doctest.cmake | 9 - cmake/doctestlib.cmake | 11 + cmake/flexflow-utils.cmake | 4 +- cmake/fmt.cmake | 3 +- cmake/nccl.cmake | 1 + cmake/rapidcheck.cmake | 6 +- cmake/spdlog.cmake | 7 +- flake.nix | 83 +- lib/compiler/CMakeLists.txt | 3 +- lib/compiler/include/compiler/compiler.h | 4 +- lib/compiler/include/compiler/cost_estimate.h | 5 +- .../include/compiler/machine_mapping.h | 27 +- .../include/compiler/unity_algorithm.h | 7 +- lib/compiler/src/graph_utils.cc | 17 +- lib/compiler/src/graph_utils.h | 3 +- lib/compiler/src/machine_mapping.cc | 294 +- lib/compiler/src/old/basic_graph.h | 158 - lib/compiler/src/old/dominators.h | 494 --- lib/compiler/src/old/graph.cc | 1255 ------ lib/compiler/src/old/graph.h | 248 -- lib/compiler/src/old/graph_structures.h | 269 -- lib/compiler/src/old/node.h | 47 - .../src/old/parallel_dim_mapping_record.h | 4 - lib/compiler/src/old/search_helper.cc | 525 --- lib/compiler/src/old/search_helper.h | 122 - lib/compiler/src/old/simplification.cc | 189 - lib/compiler/src/old/simplification.h | 34 - lib/compiler/src/old/split_types.cc | 36 - lib/compiler/src/old/split_types.h | 32 - lib/compiler/src/old/substitution.cc | 3733 ----------------- lib/compiler/src/old/substitution.h | 309 -- lib/compiler/src/unity_algorithm.cc | 33 +- ...ive_logger.cc => recursive_logger.cc.todo} | 0 ...rsive_logger.h => recursive_logger.h.todo} | 0 lib/compiler/test/CMakeLists.txt | 3 +- .../test/{ => src}/test_cost_estimator.h | 0 lib/compiler/test/src/test_generator.h | 174 + .../test/src/test_labelled_open_graph.cc | 130 + lib/compiler/test/src/test_machine_mapping.cc | 23 + lib/compiler/test/src/test_open_graph.cc | 76 + lib/compiler/test/src/test_optimal_cost.cc | 69 + lib/compiler/test/src/test_unity_algorithm.cc | 28 + lib/compiler/test/test_disjoint_set.cc | 19 - lib/compiler/test/test_dominators.cc | 322 -- lib/compiler/test/test_dot.cc | 23 - lib/compiler/test/test_dp.cc | 54 - lib/compiler/test/test_generator.h | 168 - lib/compiler/test/test_labelled_open_graph.cc | 76 - lib/compiler/test/test_machine_mapping.cc | 21 - lib/compiler/test/test_machine_view.cc | 33 - lib/compiler/test/test_open_graph.cc | 102 - lib/compiler/test/test_optimal_cost.cc | 24 - lib/compiler/test/test_parallel_config.cc | 25 - lib/compiler/test/test_random_utils.cc | 47 - lib/compiler/test/test_substitution_loader.cc | 144 - lib/compiler/test/test_unity_algorithm.cc | 23 - .../include/op-attrs/operator_attrs.h | 1 + lib/op-attrs/src/attention.cc | 7 + lib/op-attrs/src/embedding.cc | 8 +- lib/op-attrs/src/get_output_shapes.cc | 6 + .../src/parallel_dim_mapping_record_solver.cc | 8 + lib/pcg/include/pcg/device_id.h | 1 + lib/pcg/include/pcg/machine_specification.h | 19 +- lib/pcg/include/pcg/machine_view.h | 10 +- lib/pcg/include/pcg/operator.h | 19 +- lib/pcg/include/pcg/optimizer.h | 22 +- .../include/pcg/parallel_computation_graph.h | 11 + lib/pcg/include/pcg/parallel_tensor.h | 2 + lib/pcg/include/pcg/strided_rectangle.h | 22 +- lib/pcg/src/machine_view.cc | 3 - lib/pcg/src/operator.cc | 4 - lib/pcg/src/parallel_computation_graph.cc | 40 + lib/pcg/src/parallel_tensor.cc | 4 + lib/pcg/src/strided_rectangle.cc | 6 +- lib/runtime/CMakeLists.txt | 26 +- .../include/substitutions/attribute_expr.h | 2 +- .../include/substitutions/get_attribute.h | 104 +- .../include/substitutions/operator_pattern.h | 33 +- .../include/substitutions/output_graph.h | 2 +- .../substitutions/parallel_tensor_pattern.h | 4 +- .../include/substitutions/substitution.h | 8 + lib/substitutions/src/graph_pattern.cc | 102 +- lib/substitutions/src/graph_pattern_match.cc | 24 +- lib/substitutions/src/operator_attributes.cc | 162 +- lib/substitutions/src/substitution.cc | 296 +- lib/substitutions/test/CMakeLists.txt | 2 +- .../test/src/test_pattern_matches.cc | 70 +- .../test/src/test_substitution.cc | 249 +- lib/utils/include/utils/containers.decl.h | 7 +- lib/utils/include/utils/containers.h | 12 +- lib/utils/include/utils/dot_file.h | 7 +- lib/utils/include/utils/fmt.h | 8 + lib/utils/include/utils/graph/algorithms.h | 6 + .../utils/graph/labelled/labelled_open.decl.h | 124 - .../utils/graph/labelled/labelled_open.h | 173 - .../graph/labelled/labelled_open_interfaces.h | 62 - .../utils/graph/labelled/node_labelled.h | 56 +- .../graph/labelled/node_labelled_interfaces.h | 36 + .../utils/graph/labelled/node_labelled_open.h | 63 +- .../include/utils/graph/labelled/open_views.h | 58 +- .../utils/graph/labelled/output_labelled.h | 86 +- .../labelled/output_labelled_interfaces.h | 15 +- .../graph/labelled/output_labelled_open.h | 117 +- .../output_labelled_open_interfaces.h | 34 + .../utils/graph/labelled/standard_labelled.h | 84 +- .../labelled/unordered_labelled_graphs.h | 249 +- .../include/utils/graph/labelled/views.h | 42 +- .../include/utils/graph/labelled_graphs.h | 1 + lib/utils/include/utils/graph/open_graphs.h | 3 +- lib/utils/include/utils/graph/views.h | 16 +- lib/utils/include/utils/hash-utils.h | 4 +- lib/utils/include/utils/variant.h | 16 +- lib/utils/src/graph/algorithms.cc | 68 +- lib/utils/src/graph/digraph.cc | 7 +- lib/utils/src/graph/multidigraph.cc | 7 +- lib/utils/src/graph/node.cc | 4 +- lib/utils/src/graph/open_edge.cc | 6 +- lib/utils/src/graph/open_graphs.cc | 22 +- lib/utils/src/graph/serialparallel.cc | 21 +- lib/utils/src/graph/undirected.cc | 6 +- lib/utils/src/graph/views.cc | 35 +- lib/utils/test/CMakeLists.txt | 24 +- lib/utils/test/src/test_algorithms.cc | 410 +- lib/utils/test/src/test_bidict.cc | 100 +- lib/utils/test/src/test_containers.cc | 653 +-- lib/utils/test/src/test_cow_ptr.cc | 62 + .../src/test_deduplicated_priority_queue.cc | 48 +- lib/utils/test/src/test_disjoint_set.cc | 75 +- lib/utils/test/src/test_dot_file.cc | 76 +- lib/utils/test/src/test_format.cc | 46 +- lib/utils/test/src/test_hash.cc | 20 + lib/utils/test/src/test_multidigraph.cc | 140 +- lib/utils/test/src/test_random_utils.cc | 72 +- lib/utils/test/src/test_sequence.cc | 308 +- lib/utils/test/src/test_stack_map.cc | 88 +- lib/utils/test/src/test_stack_string.cc | 124 +- lib/utils/test/src/test_stack_vector.cc | 142 +- lib/utils/test/src/test_tuple.cc | 118 +- lib/utils/test/src/test_type_index.cc | 42 +- lib/utils/test/src/test_undirected_graph.cc | 54 +- lib/utils/test/src/test_variant.cc | 106 +- lib/utils/test/src/test_vector.cc | 46 +- 149 files changed, 3709 insertions(+), 11505 deletions(-) create mode 100644 .flake/patches/doctest-template-test.patch delete mode 100644 .flake/pkgs/tokenizers-cpp.nix create mode 100755 .github/workflows/helpers/build_libs.sh rename .github/workflows/helpers/{build_cuda.sh => cmake_cuda.sh} (67%) create mode 100755 .github/workflows/helpers/test_libs.sh delete mode 100644 cmake/doctest.cmake create mode 100644 cmake/doctestlib.cmake delete mode 100644 lib/compiler/src/old/basic_graph.h delete mode 100644 lib/compiler/src/old/dominators.h delete mode 100644 lib/compiler/src/old/graph.cc delete mode 100644 lib/compiler/src/old/graph.h delete mode 100644 lib/compiler/src/old/graph_structures.h delete mode 100644 lib/compiler/src/old/node.h delete mode 100644 lib/compiler/src/old/parallel_dim_mapping_record.h delete mode 100644 lib/compiler/src/old/search_helper.cc delete mode 100644 lib/compiler/src/old/search_helper.h delete mode 100644 lib/compiler/src/old/simplification.cc delete mode 100644 lib/compiler/src/old/simplification.h delete mode 100644 lib/compiler/src/old/split_types.cc delete mode 100644 lib/compiler/src/old/split_types.h delete mode 100644 lib/compiler/src/old/substitution.cc delete mode 100644 lib/compiler/src/old/substitution.h rename lib/compiler/src/utils/{recursive_logger.cc => recursive_logger.cc.todo} (100%) rename lib/compiler/src/utils/{recursive_logger.h => recursive_logger.h.todo} (100%) rename lib/compiler/test/{ => src}/test_cost_estimator.h (100%) create mode 100644 lib/compiler/test/src/test_generator.h create mode 100644 lib/compiler/test/src/test_labelled_open_graph.cc create mode 100644 lib/compiler/test/src/test_machine_mapping.cc create mode 100644 lib/compiler/test/src/test_open_graph.cc create mode 100644 lib/compiler/test/src/test_optimal_cost.cc create mode 100644 lib/compiler/test/src/test_unity_algorithm.cc delete mode 100644 lib/compiler/test/test_disjoint_set.cc delete mode 100644 lib/compiler/test/test_dominators.cc delete mode 100644 lib/compiler/test/test_dot.cc delete mode 100644 lib/compiler/test/test_dp.cc delete mode 100644 lib/compiler/test/test_generator.h delete mode 100644 lib/compiler/test/test_labelled_open_graph.cc delete mode 100644 lib/compiler/test/test_machine_mapping.cc delete mode 100644 lib/compiler/test/test_machine_view.cc delete mode 100644 lib/compiler/test/test_open_graph.cc delete mode 100644 lib/compiler/test/test_optimal_cost.cc delete mode 100644 lib/compiler/test/test_parallel_config.cc delete mode 100644 lib/compiler/test/test_random_utils.cc delete mode 100644 lib/compiler/test/test_substitution_loader.cc delete mode 100644 lib/compiler/test/test_unity_algorithm.cc create mode 100644 lib/pcg/src/parallel_computation_graph.cc delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open.decl.h delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open.h delete mode 100644 lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h create mode 100644 lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h create mode 100644 lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h create mode 100644 lib/utils/test/src/test_cow_ptr.cc create mode 100644 lib/utils/test/src/test_hash.cc diff --git a/.flake/patches/doctest-template-test.patch b/.flake/patches/doctest-template-test.patch new file mode 100644 index 0000000000..ca4d0d9a18 --- /dev/null +++ b/.flake/patches/doctest-template-test.patch @@ -0,0 +1,50 @@ +diff --git a/scripts/cmake/doctestAddTests.cmake b/scripts/cmake/doctestAddTests.cmake +index 3b25485..d3ba906 100644 +--- a/scripts/cmake/doctestAddTests.cmake ++++ b/scripts/cmake/doctestAddTests.cmake +@@ -56,12 +56,14 @@ foreach(line ${output}) + if("${line}" STREQUAL "===============================================================================" OR "${line}" MATCHES [==[^\[doctest\] ]==]) + continue() + endif() +- set(test ${line}) ++ set(unescaped_test ${line}) ++ # use escape commas to handle properly test cases with commas inside the name ++ string(REPLACE "," "\\," escaped_test ${unescaped_test}) + set(labels "") + if(${add_labels}) + # get test suite that test belongs to + execute_process( +- COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" --test-case=${test} --list-test-suites ++ COMMAND ${TEST_EXECUTOR} "${TEST_EXECUTABLE}" --test-case=${escaped_test} --list-test-suites + OUTPUT_VARIABLE labeloutput + RESULT_VARIABLE labelresult + WORKING_DIRECTORY "${TEST_WORKING_DIR}" +@@ -85,24 +87,22 @@ foreach(line ${output}) + + if(NOT "${junit_output_dir}" STREQUAL "") + # turn testname into a valid filename by replacing all special characters with "-" +- string(REGEX REPLACE "[/\\:\"|<>]" "-" test_filename "${test}") ++ string(REGEX REPLACE "[/\\:\"|<>]" "-" test_filename "${unescaped_test}") + set(TEST_JUNIT_OUTPUT_PARAM "--reporters=junit" "--out=${junit_output_dir}/${prefix}${test_filename}${suffix}.xml") + else() + unset(TEST_JUNIT_OUTPUT_PARAM) + endif() +- # use escape commas to handle properly test cases with commas inside the name +- string(REPLACE "," "\\," test_name ${test}) + # ...and add to script + add_command(add_test +- "${prefix}${test}${suffix}" ++ "${prefix}${unescaped_test}${suffix}" + ${TEST_EXECUTOR} + "${TEST_EXECUTABLE}" +- "--test-case=${test_name}" ++ "--test-case=${escaped_test}" + "${TEST_JUNIT_OUTPUT_PARAM}" + ${extra_args} + ) + add_command(set_tests_properties +- "${prefix}${test}${suffix}" ++ "${prefix}${unescaped_test}${suffix}" + PROPERTIES + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + ${properties} diff --git a/.flake/pkgs/tokenizers-cpp.nix b/.flake/pkgs/tokenizers-cpp.nix deleted file mode 100644 index a705667ae6..0000000000 --- a/.flake/pkgs/tokenizers-cpp.nix +++ /dev/null @@ -1,43 +0,0 @@ -{ lib -, stdenv -, fetchFromGitHub -, cmake -, rustc -, cargo -}: - -stdenv.mkDerivation rec { - pname = "tokenizers-cpp"; - version = "2024-03-13"; - - src = fetchFromGitHub { - owner = "mlc-ai"; - repo = "tokenizers-cpp"; - rev = "4f42c9fa74946d70af86671a3804b6f2433e5dac"; - sha256 = "sha256-p7OYx9RVnKUAuMexy3WjW2zyfMJ/Q9ss4xFLsbQK7wA="; - fetchSubmodules = true; - }; - - nativeBuildInputs = [ - cmake - rustc - ]; - - # cmakeFlags = [ - # "-DLegion_USE_Python=1" - # "-DLegion_BUILD_BINDINGS=1" - # "-DLegion_USE_CUDA=1" - # "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" - # ]; - - buildInputs = [ ]; - # python3 - # cudatoolkit - # ]; - - meta = with lib; { - description = "Universal cross-platform tokenizers binding to HF and sentencepiece"; - homepage = "https://github.com/mlc-ai/tokenizers-cpp"; - license = licenses.asl20; - }; -} diff --git a/.github/workflows/helpers/build_libs.sh b/.github/workflows/helpers/build_libs.sh new file mode 100755 index 0000000000..cc4e25cc0b --- /dev/null +++ b/.github/workflows/helpers/build_libs.sh @@ -0,0 +1,9 @@ +#! /usr/bin/env bash + +set -euo pipefail + +DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")" +REPO="$(realpath -- "$DIR/../../../")" + +cd "$REPO/build-ci" +make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "$@" diff --git a/.github/workflows/helpers/build_cuda.sh b/.github/workflows/helpers/cmake_cuda.sh similarity index 67% rename from .github/workflows/helpers/build_cuda.sh rename to .github/workflows/helpers/cmake_cuda.sh index 3524f885a7..e549859a5a 100755 --- a/.github/workflows/helpers/build_cuda.sh +++ b/.github/workflows/helpers/cmake_cuda.sh @@ -8,22 +8,21 @@ REPO="$(realpath -- "$DIR/../../../")" export FF_GPU_BACKEND="cuda" export FF_CUDA_ARCH=70 -cd "$REPO" -mkdir build -cd build + +if [[ -d "$REPO/build-ci" ]]; then + rm -rf "$REPO/build-ci" +fi +mkdir "$REPO/build-ci" +cd "$REPO/build-ci" #if [[ "${FF_GPU_BACKEND}" == "cuda" ]]; then # export FF_BUILD_ALL_EXAMPLES=ON # export FF_BUILD_UNIT_TESTS=ON #fi +IFS=" " read -r -a FLAGS <<< "$CMAKE_FLAGS" ../config/config.linux \ - -DCMAKE_CXX_COMPILER="clang++" \ - -DCMAKE_C_COMPILER="clang" \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache \ - -DFF_USE_EXTERNAL_LEGION=ON \ - -DFF_USE_EXTERNAL_JSON=ON \ - -DFF_USE_EXTERNAL_FMT=ON \ - -DFF_USE_EXTERNAL_SPDLOG=ON + "${FLAGS[@]}" # vim: set tabstop=2 shiftwidth=2 expandtab: diff --git a/.github/workflows/helpers/test_libs.sh b/.github/workflows/helpers/test_libs.sh new file mode 100755 index 0000000000..7662a7e601 --- /dev/null +++ b/.github/workflows/helpers/test_libs.sh @@ -0,0 +1,14 @@ +#! /usr/bin/env bash + +set -euo pipefail +set -x + +DIR="$(realpath -- "$(dirname "${BASH_SOURCE[0]}")")" +REPO="$(realpath -- "$DIR/../../../")" + +TEST_LIBS=("${@/%/-tests}") +REGEX="^$(IFS='|'; echo "${TEST_LIBS[*]}")\$" + +cd "$REPO/build-ci" +make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) "${TEST_LIBS[@]}" +ctest --progress --output-on-failure -L "$REGEX" diff --git a/.github/workflows/per-lib-check.yml b/.github/workflows/per-lib-check.yml index fa8252bc20..874a298587 100644 --- a/.github/workflows/per-lib-check.yml +++ b/.github/workflows/per-lib-check.yml @@ -20,6 +20,9 @@ jobs: with: submodules: recursive + - name: Add helpers directory to path + run: echo "${PWD}/.github/workflows/helpers" >> $GITHUB_PATH + - name: Install nix uses: cachix/install-nix-action@v25 with: @@ -51,24 +54,40 @@ jobs: - name: Run cmake run: | - .github/workflows/helpers/build_${{ matrix.gpu_backend }}.sh + cmake_${{ matrix.gpu_backend }}.sh - name: Build utils run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) utils + build_libs.sh utils - name: Build op-attrs run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) op-attrs + build_libs.sh op-attrs - name: Build pcg run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) pcg + build_libs.sh pcg - name: Build kernels run: | - cd build - make -j $(( $(nproc) < 2 ? 1 : $(nproc)-1 )) kernels + build_libs.sh kernels + + - name: Build substitutions + run: | + build_libs.sh substitutions + + - name: Build compiler + run: | + build_libs.sh compiler + + - name: Test utils + run: | + test_libs.sh utils + + - name: Test substitutions + run: | + test_libs.sh substitutions + + - name: Test compiler + run: | + test_libs.sh compiler diff --git a/CMakeLists.txt b/CMakeLists.txt index e04aa622c2..032bf1ac55 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,7 +84,7 @@ include(nccl) include(json) include(expected) include(spdlog) -include(doctest) +include(doctestlib) # named doctestlib to avoid a name collision with doctest.cmake in rapidcheck include(visit_struct) include(CTest) include(fmt) diff --git a/cmake/doctest.cmake b/cmake/doctest.cmake deleted file mode 100644 index b2d5243574..0000000000 --- a/cmake/doctest.cmake +++ /dev/null @@ -1,9 +0,0 @@ -include(aliasing) - -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest) -include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake) - -add_library(doctest-ff INTERFACE) -target_compile_definitions(doctest-ff INTERFACE DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS) -target_link_libraries(doctest-ff INTERFACE doctest::doctest) -alias_library(doctest doctest-ff) diff --git a/cmake/doctestlib.cmake b/cmake/doctestlib.cmake new file mode 100644 index 0000000000..5f29d94fd0 --- /dev/null +++ b/cmake/doctestlib.cmake @@ -0,0 +1,11 @@ +include(aliasing) + +if (FF_USE_EXTERNAL_DOCTEST) + find_package(doctest REQUIRED) + include(doctest) # import doctest_discover_tests +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest) + include(${CMAKE_CURRENT_SOURCE_DIR}/deps/doctest/scripts/cmake/doctest.cmake) +endif() + +alias_library(doctest doctest::doctest) diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index d41573acab..4cf5450942 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -118,7 +118,9 @@ function(ff_add_test_executable) ${FF_TEST_EXEC_NAME} ${FF_TEST_EXEC_DEPS}) + target_compile_definitions(${FF_TEST_EXEC_NAME} PRIVATE FF_TEST_SUITE="${FF_TEST_EXEC_NAME}") + define_ff_vars(${FF_TEST_EXEC_NAME}) ff_set_cxx_properties(${FF_TEST_EXEC_NAME}) - doctest_discover_tests(${FF_TEST_EXEC_NAME}) + doctest_discover_tests(${FF_TEST_EXEC_NAME} ADD_LABELS 1) endfunction() diff --git a/cmake/fmt.cmake b/cmake/fmt.cmake index 283caad69d..470de6a847 100644 --- a/cmake/fmt.cmake +++ b/cmake/fmt.cmake @@ -4,6 +4,5 @@ if (FF_USE_EXTERNAL_FMT) find_package(fmt REQUIRED) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/fmt) - - alias_library(fmt fmt::fmt) endif() +alias_library(fmt fmt::fmt) diff --git a/cmake/nccl.cmake b/cmake/nccl.cmake index e89bee04c6..755fe00f1b 100644 --- a/cmake/nccl.cmake +++ b/cmake/nccl.cmake @@ -8,6 +8,7 @@ else() message(STATUS "Building NCCL from source") list(TRANSFORM CUDA_GENCODE PREPEND "NVCC_GENCODE=" OUTPUT_VARIABLE NCCL_BUILD_NVCC_GENCODE) + include(ExternalProject) ExternalProject_Add(nccl_source_build SOURCE_DIR ${PROJECT_SOURCE_DIR}/deps/${NCCL_NAME} PREFIX ${CMAKE_BINARY_DIR}/deps/${NCCL_NAME} diff --git a/cmake/rapidcheck.cmake b/cmake/rapidcheck.cmake index 1ff64bd974..bf8f058e63 100644 --- a/cmake/rapidcheck.cmake +++ b/cmake/rapidcheck.cmake @@ -1 +1,5 @@ -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/rapidcheck) +if (FF_USE_EXTERNAL_RAPIDCHECK) + find_package(rapidcheck REQUIRED) +else() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/rapidcheck) +endif() diff --git a/cmake/spdlog.cmake b/cmake/spdlog.cmake index cd18944460..5ba1d6cc15 100644 --- a/cmake/spdlog.cmake +++ b/cmake/spdlog.cmake @@ -4,6 +4,9 @@ if (FF_USE_EXTERNAL_SPDLOG) find_package(spdlog REQUIRED) else() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/deps/spdlog) - - alias_library(spdlog spdlog::spdlog) endif() + +add_library(ff_spdlog INTERFACE) +target_link_libraries(ff_spdlog INTERFACE spdlog::spdlog) +target_compile_definitions(ff_spdlog INTERFACE SPDLOG_FMT_EXTERNAL) +alias_library(spdlog ff_spdlog) diff --git a/flake.nix b/flake.nix index 3d357ca86c..d402d3c271 100644 --- a/flake.nix +++ b/flake.nix @@ -13,7 +13,6 @@ ]; }; - # Nixpkgs / NixOS version to use. inputs = { nixpkgs.url = "nixpkgs/nixos-23.11"; flake-utils.url = "github:numtide/flake-utils"; @@ -25,51 +24,88 @@ inherit system; config.allowUnfree = true; }; + lib = pkgs.lib; mkShell = pkgs.mkShell.override { - stdenv = pkgs.llvmPackages.libcxxStdenv; + stdenv = pkgs.cudaPackages.backendStdenv; }; in - { - packages = { - legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + { + packages = { + legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + rapidcheckFull = pkgs.symlinkJoin { + name = "rapidcheckFull"; + paths = (with pkgs; [ rapidcheck.out rapidcheck.dev ]); }; + doctest = pkgs.doctest.overrideAttrs ( old: rec { + version = "2.4.9"; + src = pkgs.fetchFromGitHub { + owner = "doctest"; + repo = "doctest"; + rev = "v${version}"; + sha256 = "sha256-ugmkeX2PN4xzxAZpWgswl4zd2u125Q/ADSKzqTfnd94="; + }; + patches = [ + ./.flake/patches/doctest-template-test.patch + ]; + }); + }; - devShells = rec { - ci = mkShell { - buildInputs = (with pkgs; [ - llvmPackages_17.clang - cmakeCurses - gcc10Stdenv - gcc10 - ccache - cudatoolkit + devShells = rec { + ci = mkShell { + shellHook = '' + export PATH="$HOME/ff/.scripts/:$HOME/ff/.modules/proj/bin/:$PATH" + ''; + + CMAKE_FLAGS = lib.strings.concatStringsSep " " [ + "-DFF_USE_EXTERNAL_LEGION=ON" + "-DFF_USE_EXTERNAL_NCCL=ON" + "-DFF_USE_EXTERNAL_JSON=ON" + "-DFF_USE_EXTERNAL_FMT=ON" + "-DFF_USE_EXTERNAL_SPDLOG=ON" + "-DFF_USE_EXTERNAL_DOCTEST=ON" + "-DFF_USE_EXTERNAL_RAPIDCHECK=ON" + "-DFF_USE_EXTERNAL_RANGEV3=ON" + "-DFF_USE_EXTERNAL_BOOST_PREPROCESSOR=ON" + "-DFF_USE_EXTERNAL_TYPE_INDEX=ON" + ]; + + buildInputs = builtins.concatLists [ + (with pkgs; [ zlib - pkg-config - python3 - self.packages.${system}.legion + boost nlohmann_json spdlog range-v3 - rapidcheck - doctest fmt + cmakeCurses + ccache + pkg-config + python3 + cudatoolkit cudaPackages.cuda_nvcc cudaPackages.cudnn cudaPackages.nccl cudaPackages.libcublas cudaPackages.cuda_cudart - ]) ++ (with pkgs.python3Packages; [ - ]); + ]) + (with self.packages.${system}; [ + legion + rapidcheckFull + doctest + ]) + ]; }; default = mkShell { inputsFrom = [ ci ]; - + inherit (ci) CMAKE_FLAGS; + buildInputs = builtins.concatLists [ (with pkgs; [ - clang-tools_17 + clang-tools gh-markdown-preview + shellcheck plantuml gdb ruff @@ -96,4 +132,3 @@ } ); } -# vim: set tabstop=2 shiftwidth=2 expandtab: diff --git a/lib/compiler/CMakeLists.txt b/lib/compiler/CMakeLists.txt index daa96b08bc..a2933efa50 100644 --- a/lib/compiler/CMakeLists.txt +++ b/lib/compiler/CMakeLists.txt @@ -11,9 +11,10 @@ ff_add_library( op-attrs utils json - optional pcg spdlog + substitutions ) add_subdirectory(ffi) +add_subdirectory(test) diff --git a/lib/compiler/include/compiler/compiler.h b/lib/compiler/include/compiler/compiler.h index 3a75e3a9bf..a4f7b0ecd3 100644 --- a/lib/compiler/include/compiler/compiler.h +++ b/lib/compiler/include/compiler/compiler.h @@ -12,8 +12,8 @@ enum class SearchAlgorithm { DATA_PARALLEL, }; -using SearchAlgorithmConfig = variant<>; -using SearchSolution = variant<>; +using SearchAlgorithmConfig = std::variant<>; +using SearchSolution = std::variant<>; struct SearchResult { ParallelComputationGraph pcg; diff --git a/lib/compiler/include/compiler/cost_estimate.h b/lib/compiler/include/compiler/cost_estimate.h index 27f963db50..557f51a7ca 100644 --- a/lib/compiler/include/compiler/cost_estimate.h +++ b/lib/compiler/include/compiler/cost_estimate.h @@ -16,10 +16,11 @@ struct ICostEstimator { MachineView const &src, MachineView const &dst) const = 0; + ICostEstimator() = default; ICostEstimator(ICostEstimator const &) = delete; ICostEstimator &operator=(ICostEstimator const &) = delete; - virtual ~ICostEstimator(); + virtual ~ICostEstimator() = default; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); @@ -44,6 +45,8 @@ struct CostEstimator { } private: + CostEstimator(std::shared_ptr implementation_ptr) + : implementation_ptr(implementation_ptr) {} std::shared_ptr implementation_ptr; }; diff --git a/lib/compiler/include/compiler/machine_mapping.h b/lib/compiler/include/compiler/machine_mapping.h index 4089260735..8b21b9522f 100644 --- a/lib/compiler/include/compiler/machine_mapping.h +++ b/lib/compiler/include/compiler/machine_mapping.h @@ -5,10 +5,13 @@ #include "pcg/machine_specification.h" #include "pcg/machine_view.h" #include "pcg/parallel_computation_graph.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" namespace FlexFlow { +using SubParallelComputationGraphView = + OutputLabelledOpenMultiDiGraphView; + struct MachineMapping { static MachineMapping combine(MachineMapping const &, MachineMapping const &); static bool nodes_are_disjoint(MachineMapping const &m1, @@ -21,13 +24,14 @@ FF_VISITABLE_STRUCT(MachineMapping, machine_views); struct OptimalCostState { SerialParallelDecomposition subgraph; MachineSpecification resource; - req> source_machine_view, sink_machine_view; + std::unordered_map given_machine_views; + req> frontier_machine_views; }; FF_VISITABLE_STRUCT(OptimalCostState, subgraph, resource, - source_machine_view, - sink_machine_view); + given_machine_views, + frontier_machine_views); struct OptimalCostResult { static OptimalCostResult sequential_combine(OptimalCostResult const &s1, @@ -37,7 +41,7 @@ struct OptimalCostResult { static OptimalCostResult infinity(); float runtime; - MachineMapping machine_mapping; + req machine_mapping; }; FF_VISITABLE_STRUCT(OptimalCostResult, runtime, machine_mapping); @@ -49,7 +53,7 @@ class OptimalCostCache { public: OptimalCostCache() = default; - optional load(OptimalCostState const &) const; + std::optional load(OptimalCostState const &) const; void save(OptimalCostState const &, OptimalCostResult const &); private: @@ -67,4 +71,15 @@ OptimalCostResult } // namespace FlexFlow +namespace std { + +template <> +struct hash> { + size_t operator()( + std::unordered_map const &g) + const; +}; + +}; // namespace std + #endif diff --git a/lib/compiler/include/compiler/unity_algorithm.h b/lib/compiler/include/compiler/unity_algorithm.h index 57f1c8c063..7d7a7a74dc 100644 --- a/lib/compiler/include/compiler/unity_algorithm.h +++ b/lib/compiler/include/compiler/unity_algorithm.h @@ -4,17 +4,16 @@ #include "cost_estimate.h" #include "machine_mapping.h" #include "pcg/computation_graph.h" -#include "sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph.h" namespace FlexFlow { -struct Substitution {}; - struct Strategy { ParallelComputationGraph pcg; MachineMapping machine_mapping; req runtime; }; + FF_VISITABLE_STRUCT(Strategy, pcg, machine_mapping, runtime); struct StrategyRuntimeCmp { @@ -30,7 +29,7 @@ struct OptimizerConfig { Strategy graph_optimize(ComputationGraph &cg, - ICostEstimator const &cost_estimator, + CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( Operator const &, MachineSpecification const &)> const diff --git a/lib/compiler/src/graph_utils.cc b/lib/compiler/src/graph_utils.cc index 4f22490ffa..5b76beb8c0 100644 --- a/lib/compiler/src/graph_utils.cc +++ b/lib/compiler/src/graph_utils.cc @@ -4,7 +4,16 @@ namespace FlexFlow { SerialParallelDecomposition get_serial_parallel_decomposition(ParallelComputationGraph const &pcg) { - return get_serial_parallel_decomposition(as_digraph(pcg)); + return get_serial_parallel_decomposition(pcg.value()); +} + +ParallelComputationGraph cg_to_pcg(ComputationGraph const &g) { + NOT_IMPLEMENTED(); +} + +SubParallelComputationGraphView + pcg_to_subpcg(ParallelComputationGraph const &pcg) { + return view_output_labelled_as_output_labelled_open(pcg.value()); } std::vector @@ -45,7 +54,7 @@ std::unordered_map } } - assert(result.size() == get_edges(pcg).size()); + assert(result.size() == get_edges(pcg.value()).size()); return result; } @@ -116,14 +125,14 @@ std::unordered_set get_nodes(SerialParallelDecomposition const &sp) { std::unordered_set get_nodes(Serial const &serial) { return set_union( - transform(serial.children, [](variant const child) { + transform(serial.children, [](std::variant const child) { return visit(GetNodes{}, child); })); } std::unordered_set get_nodes(Parallel const ¶llel) { return set_union( - transform(parallel.children, [](variant const child) { + transform(parallel.children, [](std::variant const child) { return visit(GetNodes{}, child); })); } diff --git a/lib/compiler/src/graph_utils.h b/lib/compiler/src/graph_utils.h index 88515ef950..711a253b61 100644 --- a/lib/compiler/src/graph_utils.h +++ b/lib/compiler/src/graph_utils.h @@ -9,7 +9,8 @@ SerialParallelDecomposition get_serial_parallel_decomposition(ParallelComputationGraph const &pcg); ParallelComputationGraph cg_to_pcg(ComputationGraph const &g); -SubParallelComputationGraph pcg_to_subpcg(ParallelComputationGraph const &g); +SubParallelComputationGraphView + pcg_to_subpcg(ParallelComputationGraph const &g); // NOTE(@wmdi): I think we should have the following interfaces in the graph // library eventually. diff --git a/lib/compiler/src/machine_mapping.cc b/lib/compiler/src/machine_mapping.cc index 2f6af8a62b..2b08e9fe23 100644 --- a/lib/compiler/src/machine_mapping.cc +++ b/lib/compiler/src/machine_mapping.cc @@ -43,12 +43,13 @@ bool OptimalCostRuntimeCmp::operator()(OptimalCostResult const &lhs, return lhs.runtime < rhs.runtime; } -optional +std::optional OptimalCostCache::load(OptimalCostState const &state) const { if (contains_key(cache, state)) { - return make_optional(cache.at(state)); + OptimalCostResult result = cache.at(state); + return std::make_optional(result); } - return nullopt; + return std::nullopt; } void OptimalCostCache::save(OptimalCostState const &state, @@ -88,201 +89,218 @@ GraphSplit return {get_nodes(pre_decomposition), get_nodes(post_decomposition)}; } -std::pair - apply_split(SubParallelComputationGraph const &g, GraphSplit const &split) { - OpenMultiDiGraphView g1 = get_subgraph(g, split.first); - OpenMultiDiGraphView g2 = get_subgraph(g, split.second); - - if (get_edge_splits(g, split).size() > 0) { - // Sequential split - if (get_open_sinks(g1).size() <= get_open_sources(g2).size()) { - // get_open_sinks(*g1).size() should be 1 in perfect sp graphs - return {get_subgraph(g, split.first), - get_subgraph(g, split.second)}; - } else { - return {get_subgraph(g, split.first), - get_subgraph(g, split.first)}; - } - } else { - // Parallel split - return {get_subgraph(g, split.first), - get_subgraph(g, split.second)}; - } -} - -float estimate_cost(SubParallelComputationGraph const &g, +float estimate_cost(SubParallelComputationGraphView const &g, CostEstimator const &estimator, - MachineMapping const &device_mapping) { - NOT_IMPLEMENTED(); + MachineMapping const &device_mapping, + std::unordered_map const + &frontier_machine_views) { + // TODO: Consider parallelism + float cost = 0; + for (Node const &node : get_nodes(g)) { + std::unordered_set incoming_edges = + get_incoming_edges(g, node); + std::vector inputs = + transform(as_vector(incoming_edges), + [&](UpwardOpenMultiDiEdge const &input_edge) { + return g.at(input_edge).get_shape(); + }); + cost += estimator.estimate_cost( + g.at(node).attrs, inputs, device_mapping.machine_views.at(node)); + } + return cost; } void minimize_runtime(OptimalCostResult &m1, OptimalCostResult const &m2) { minimize(m1, m2, OptimalCostRuntimeCmp{}); } -struct OptimalCost { - OptimalCost( - SubParallelComputationGraph const &g, - CostEstimator const &cost_estimator, - MachineSpecification const &resource, - optional const &source_machine_view, // assume perfect SP - optional const &sink_machine_view, +struct MachineMappingSearcher { + MachineMappingSearcher( + CostEstimator cost_estimator, std::function( Operator const &, MachineSpecification const &)> const &allowed_machine_views, OptimalCostCache &cached_subgraph_costs) - : g(g), cost_estimator(cost_estimator), resource(resource), - source_machine_view(source_machine_view), - sink_machine_view(sink_machine_view), + : cost_estimator(cost_estimator), allowed_machine_views(allowed_machine_views), cached_subgraph_costs(cached_subgraph_costs) {} - SubParallelComputationGraph const &g; - CostEstimator const &cost_estimator; - MachineSpecification const &resource; - optional const &source_machine_view; - optional const &sink_machine_view; - std::function( - Operator const &, MachineSpecification const &)> const - &allowed_machine_views; + CostEstimator cost_estimator; + std::function(Operator const &, + MachineSpecification const &)> + allowed_machine_views; OptimalCostCache &cached_subgraph_costs; - template - OptimalCostResult operator()(T const &t) const { - OptimalCostState state{g, resource, source_machine_view, sink_machine_view}; - optional cached_result = - cached_subgraph_costs.load(state); + struct OptimalCostFunctor { + OptimalCostFunctor( + MachineMappingSearcher *searcher, + SubParallelComputationGraphView const &g, + MachineSpecification resource, + std::unordered_map given_machine_views, + std::unordered_map frontier_machine_views) + : searcher(searcher), g(g), resource(resource), + given_machine_views(given_machine_views), + frontier_machine_views(frontier_machine_views) {} + + MachineMappingSearcher *searcher; + SubParallelComputationGraphView const &g; + MachineSpecification resource; + std::unordered_map given_machine_views; + std::unordered_map frontier_machine_views; + + template + OptimalCostResult operator()(T const &t) { + OptimalCostState state{ + t, resource, given_machine_views, frontier_machine_views}; + std::optional cached_result = + searcher->cached_subgraph_costs.load(state); + + if (cached_result) { + return cached_result.value(); + } + OptimalCostResult result = searcher->optimal_cost( + t, g, resource, given_machine_views, frontier_machine_views); - if (cached_result) { - return cached_result.value(); + searcher->cached_subgraph_costs.save(state, result); + return result; } - - OptimalCostResult result = this->optimal_cost(t); - - cached_subgraph_costs.save(state, result); - return result; + }; + + OptimalCostResult + optimal_cost(SubParallelComputationGraphView const &g, + MachineSpecification resource, + SerialParallelDecomposition const &sp_decomposition) { + return visit(OptimalCostFunctor(this, g, resource, {}, {}), + sp_decomposition); } - OptimalCostResult optimal_cost(Serial const &serial) const { + OptimalCostResult optimal_cost( + Serial const &serial, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { + auto decomposed = decompose(serial); SerialParallelDecomposition pre_decompn = decomposed.first; SerialParallelDecomposition post_decompn = decomposed.second; - auto subgraphs = apply_split(g, get_graph_split(pre_decompn, post_decompn)); - SubParallelComputationGraph pre_graph = subgraphs.first, - post_graph = subgraphs.second; + GraphSplit graph_split = get_graph_split(pre_decompn, post_decompn); + SubParallelComputationGraphView pre_graph = + get_subgraph(g, graph_split.first); + SubParallelComputationGraphView post_graph = + get_subgraph(g, graph_split.second); - std::unordered_set pre_graph_sinks = get_closed_sinks(pre_graph); std::unordered_set post_graph_sources = get_closed_sources(post_graph); - assert(pre_graph_sinks.size() + post_graph_sources.size() == - 1); // assume perfect SP + assert(post_graph_sources.size() == 1); // assume perfect SP - Node const &split_point = - get_only(set_union(pre_graph_sinks, post_graph_sources)); + Node split_point = get_only(post_graph_sources); + OutputMultiDiEdge split_edge = get_only(get_open_outputs(pre_graph)); OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (MachineView const &mv : allowed_machine_views(g.at(split_point), resource)) { - optional pre_sink_mv = - contains(pre_graph_sinks, split_point) ? make_optional(mv) : nullopt; - optional post_source_mv = - contains(post_graph_sources, split_point) ? make_optional(mv) - : nullopt; + std::unordered_map new_given_machine_views = + given_machine_views; + new_given_machine_views.emplace(split_point, mv); + std::unordered_map + new_frontier_machine_views = frontier_machine_views; + new_frontier_machine_views.emplace(split_edge, mv); minimize_runtime(optimal_result, OptimalCostResult::sequential_combine( - visit(OptimalCost(pre_graph, - cost_estimator, - resource, - source_machine_view, - pre_sink_mv, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + pre_graph, + resource, + given_machine_views, + new_frontier_machine_views), pre_decompn), - visit(OptimalCost(post_graph, - cost_estimator, - resource, - post_source_mv, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + post_graph, + resource, + new_given_machine_views, + frontier_machine_views), post_decompn))); } return optimal_result; } - OptimalCostResult optimal_cost(Parallel const ¶llel) const { + OptimalCostResult optimal_cost( + Parallel const ¶llel, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { auto decomposed = decompose(parallel); SerialParallelDecomposition decompn1 = decomposed.first; SerialParallelDecomposition decompn2 = decomposed.second; - auto subgraphs = apply_split(g, get_graph_split(decompn1, decompn2)); - SubParallelComputationGraph g1 = subgraphs.first, g2 = subgraphs.second; + GraphSplit graph_split = get_graph_split(decompn1, decompn2); + SubParallelComputationGraphView g1 = get_subgraph( + g, graph_split.first), + g2 = get_subgraph( + g, graph_split.second); OptimalCostResult optimal_result = OptimalCostResult::sequential_combine( - visit(OptimalCost(g1, - cost_estimator, - resource, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g1, + resource, + given_machine_views, + frontier_machine_views), decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g2, + resource, + given_machine_views, + frontier_machine_views), decompn2)); for (auto const &resource_split : get_resource_split(resource)) { minimize_runtime(optimal_result, OptimalCostResult::parallel_combine( - visit(OptimalCost(g1, - cost_estimator, - resource_split.first, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g1, + resource_split.first, + given_machine_views, + frontier_machine_views), decompn1), - visit(OptimalCost(g2, - cost_estimator, - resource_split.second, - source_machine_view, - sink_machine_view, - allowed_machine_views, - cached_subgraph_costs), + visit(OptimalCostFunctor(this, + g2, + resource_split.second, + given_machine_views, + frontier_machine_views), decompn2))); } return optimal_result; } - OptimalCostResult optimal_cost(Node const &node) const { - if (source_machine_view) { - assert(get_closed_sources(g).empty()); - assert(contains(allowed_machine_views(g.at(node), resource), - source_machine_view.value())); - MachineMapping mv_map{{{node, source_machine_view.value()}}}; - return {estimate_cost(g, cost_estimator, mv_map), mv_map}; - } else if (sink_machine_view) { - assert(get_closed_sinks(g).empty()); + OptimalCostResult optimal_cost( + Node const &node, + SubParallelComputationGraphView const &g, + MachineSpecification const &resource, + std::unordered_map const &given_machine_views, + std::unordered_map const + &frontier_machine_views) { + if (contains_key(given_machine_views, node)) { assert(contains(allowed_machine_views(g.at(node), resource), - sink_machine_view.value())); - MachineMapping mv_map{{{node, sink_machine_view.value()}}}; - return {estimate_cost(g, cost_estimator, mv_map), mv_map}; + given_machine_views.at(node))); + MachineMapping mv_map{given_machine_views}; + return {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), + mv_map}; } else { OptimalCostResult optimal_result = OptimalCostResult::infinity(); for (auto mv : allowed_machine_views(g.at(node), resource)) { MachineMapping mv_map{{{node, mv}}}; - minimize_runtime(optimal_result, - {estimate_cost(g, cost_estimator, mv_map), mv_map}); + minimize_runtime( + optimal_result, + {estimate_cost(g, cost_estimator, mv_map, frontier_machine_views), + mv_map}); } return optimal_result; } @@ -297,14 +315,12 @@ OptimalCostResult CostEstimator const &cost_estimator, MachineSpecification const &resources, OptimalCostCache &cached_subgraph_costs) { - return visit(OptimalCost(pcg_to_subpcg(g), - cost_estimator, - resources, - nullopt, - nullopt, - allowed_machine_views, - cached_subgraph_costs), - get_serial_parallel_decomposition(g)); + SerialParallelDecomposition sp_decomposition = + get_serial_parallel_decomposition(g); + SubParallelComputationGraphView subpcg = pcg_to_subpcg(g); + MachineMappingSearcher searcher( + cost_estimator, allowed_machine_views, cached_subgraph_costs); + return searcher.optimal_cost(subpcg, resources, sp_decomposition); } } // namespace FlexFlow diff --git a/lib/compiler/src/old/basic_graph.h b/lib/compiler/src/old/basic_graph.h deleted file mode 100644 index fca575e42a..0000000000 --- a/lib/compiler/src/old/basic_graph.h +++ /dev/null @@ -1,158 +0,0 @@ -#ifndef _BASIC_GRAPH_H -#define _BASIC_GRAPH_H - -#include "utils/hash-utils.h" -#include -#include - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template -struct GraphStructure; -/* -{ - using graph_type = ...; - using node_type = - using tGraph = G; - using tNode = N; - using tEdge = E; - - std::unordered_set get_nodes(G const &) const; - std::unordered_set get_incoming_edges(G const &, N const &) const; - std::unordered_set get_outgoing_edges(G const &, N const &) const; - N get_src(G const &, E const &) const; - N get_dst(G const &, E const &) const; -}; -*/ - -template -struct BasicGraph { - using N = T; - using E = std::pair; - - std::unordered_set nodes; - std::unordered_map> in_edges, out_edges; - - BasicGraph() : BasicGraph({}, {}) {} - - BasicGraph(std::unordered_set const &nodes, std::unordered_set edges) - : nodes(), in_edges(), out_edges() { - this->add_nodes(nodes); - this->add_edges(edges); - } - - void add_edge(N const &src, N const &dst) { - nodes.insert(src); - nodes.insert(dst); - out_edges[src].insert({src, dst}); - in_edges[dst].insert({src, dst}); - } - - void add_edge(E const &e) { - nodes.insert(e.first); - nodes.insert(e.second); - out_edges[e.first].insert(e); - in_edges[e.second].insert(e); - } - - bool has_edge(N const &src, N const &dst) const { - auto iter = this->in_edges.find(dst); - if (iter == this->in_edges.end()) { - return false; - } - - std::unordered_set const &dst_in_edges = iter->second; - return dst_in_edges.find({src, dst}) != dst_in_edges.end(); - } - - bool has_edge(E const &e) const { - return this->has_edge(e.first, e.second); - } - - void remove_edge(N const &src, N const &dst) { - out_edges[src].erase({src, dst}); - in_edges[dst].erase({src, dst}); - } - - void remove_edge(E const &e) { - out_edges[e.first].erase(e); - in_edges[e.second].erase(e); - } - - void add_node(N const &n) { - nodes.insert(n); - } - - template > - void add_nodes(Container const &nodes) { - for (auto const &n : nodes) { - this->add_node(n); - } - } - - template > - void add_edges(Container const &edges) { - for (auto const &e : edges) { - this->add_edge(e); - } - } - - bool operator==(BasicGraph const &other) const { - return this->nodes == other.nodes && this->in_edges == other.in_edges && - this->out_edges == other.out_edges; - } -}; - -template -struct GraphStructure> { - using graph_type = BasicGraph; - using vertex_type = T; - using edge_type = std::pair; - - std::unordered_set get_nodes(graph_type const &g) const { - std::unordered_set nodes(g.nodes); - return nodes; - } - - std::unordered_set get_incoming_edges(graph_type const &g, - vertex_type const &n) const { - std::unordered_set edges; - if (g.in_edges.find(n) != g.in_edges.end()) { - edges.insert(g.in_edges.at(n).begin(), g.in_edges.at(n).end()); - } - return edges; - } - - std::unordered_set get_outgoing_edges(graph_type const &g, - vertex_type const &n) const { - std::unordered_set edges; - if (g.out_edges.find(n) != g.out_edges.end()) { - edges.insert(g.out_edges.at(n).begin(), g.out_edges.at(n).end()); - } - return edges; - } - - vertex_type get_src(graph_type const &g, edge_type const &e) const { - return e.first; - } - - vertex_type get_dst(graph_type const &g, edge_type const &e) const { - return e.second; - } - - void set_src(graph_type const &g, edge_type &e, vertex_type const &n) const { - e.first = n; - } - - void set_dst(graph_type const &g, edge_type &e, vertex_type const &n) const { - e.second = n; - } -}; - -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -#endif // _BASIC_GRAPH_H diff --git a/lib/compiler/src/old/dominators.h b/lib/compiler/src/old/dominators.h deleted file mode 100644 index 70449ee001..0000000000 --- a/lib/compiler/src/old/dominators.h +++ /dev/null @@ -1,494 +0,0 @@ -#ifndef _DOMINATORS_H -#define _DOMINATORS_H - -#include "basic_graph.h" -#include "graph_structures.h" -#include "tl/optional.hpp" -#include "utils/dot_file.h" -#include "utils/record_formatter.h" -#include -#include -#include -#include - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template > -std::unordered_set nodes(G const &g) { - Structure s; - - return s.get_nodes(g); -} - -template > -bool has_edge(G const &g, - typename Structure::vertex_type const &src, - typename Structure::vertex_type const &dst) { - Structure s; - - for (auto const &e : s.get_outgoing_edges(g, src)) { - if (s.get_dst(g, e) == dst) { - return true; - } - } - - return false; -} - -template > -std::unordered_set - outgoing_edges(G const &g, typename Structure::vertex_type const &n) { - Structure s; - return s.get_outgoing_edges(g, n); -} - -template > -std::pair - get_basic_edge(G const &g, typename Structure::edge_type const &e) { - Structure s; - - return {s.get_src(g, e), s.get_dst(g, e)}; -} - -template > -std::vector get_edges(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - - std::vector edges; - - for (N const &n : s.get_nodes(g)) { - for (E const &e : s.get_outgoing_edges(g, n)) { - edges.push_back(e); - } - } - - return edges; -} - -template > -void successors(G const &g, - typename Structure::vertex_type const &node, - std::unordered_set *succ) { - Structure s; - for (auto const &edge : s.get_outgoing_edges(g, node)) { - succ->insert(s.get_dst(g, edge)); - } -} - -template > -std::unordered_set - successors(G const &g, typename Structure::vertex_type const &node) { - // using N = typename Structure::vertex_type; - - std::unordered_set succ; - successors(g, node, &succ); - - return succ; -} - -template > -tl::optional - successor(G const &g, typename Structure::vertex_type const &node) { - auto succs = successors(g, node); - if (succs.size() == 1) { - return *succs.begin(); - } else { - return tl::nullopt; - } -} - -template > -void predecessors(G const &g, - typename Structure::vertex_type const &node, - std::unordered_set *pred) { - Structure s; - for (auto const &edge : s.get_incoming_edges(g, node)) { - pred->insert(s.get_src(g, edge)); - } -} - -template > -std::unordered_set - predecessors(G const &g, typename Structure::vertex_type const &node) { - // using N = typename Structure::vertex_type; - - std::unordered_set pred; - predecessors(g, node, &pred); - - return pred; -} - -template > -tl::optional - predecessor(G const &g, typename Structure::vertex_type const &node) { - auto preds = predecessors(g, node); - if (preds.size() == 1) { - return *preds.begin(); - } else { - return tl::nullopt; - } -} - -template > -std::unordered_set roots(G const &g) { - using N = typename Structure::vertex_type; - - Structure s; - - std::unordered_set nodes = s.get_nodes(g); - std::unordered_set roots; - for (auto const &node : nodes) { - if (s.get_incoming_edges(g, node).empty()) { - roots.insert(node); - } - } - - return roots; -} - -template > -std::unordered_set leaves(G const &g) { - return roots>(g); -} - -template > -void topo_sort(G const &g, - std::vector *ordering) { - using N = typename Structure::vertex_type; - - Structure s; - std::unordered_map> predecessors; - - std::queue q; - for (auto const &node : s.get_nodes(g)) { - predecessors[node]; - for (auto const &edge : s.get_incoming_edges(g, node)) { - predecessors.at(node).insert(s.get_src(g, edge)); - } - } - - for (auto it = predecessors.begin(); it != predecessors.end();) { - if (it->second.empty()) { - q.push(it->first); - it = predecessors.erase(it); - } else { - it++; - } - } - - std::unordered_set node_successors; - while (!q.empty()) { - N const ¤t = q.front(); - - ordering->push_back(current); - - node_successors.clear(); - successors(g, current, &node_successors); - for (auto const &succ : node_successors) { - if (predecessors.find(succ) != predecessors.end()) { - predecessors.at(succ).erase(current); - if (predecessors.at(succ).empty()) { - predecessors.erase(succ); - q.push(succ); - } - } - } - - q.pop(); - } -} - -template > -std::unordered_map> - dominators(G const &g) { - using N = typename Structure::vertex_type; - // using E = typename Structure::edge_type; - - // Structure s; - - std::vector nodes; - topo_sort(g, &nodes); - std::unordered_map> dom; - - std::unordered_set pred_part; - for (auto const &node : nodes) { - pred_part.clear(); - predecessors(g, node, &pred_part); - for (auto const &p : pred_part) { - if (dom.find(node) == dom.end()) { - dom[node] = dom.at(p); - } else { - auto &node_dom_set = dom.at(node); - auto const &p_dom_set = dom.at(p); - for (auto it = node_dom_set.begin(); it != node_dom_set.end();) { - if (p_dom_set.find(*it) == p_dom_set.end()) { - it = node_dom_set.erase(it); - } else { - it++; - } - } - } - } - dom[node].insert(node); - } - - return dom; -} - -template > -std::unordered_map> - post_dominators(G const &g) { - return dominators>(g); -} - -template > -std::unordered_map - imm_dominators(G const &g) { - using N = typename Structure::vertex_type; - // using E = typename Structure::edge_type; - - std::vector topo; - topo_sort(g, &topo); - std::unordered_map topo_rank; - for (int i = 0; i < (int)topo.size(); i++) { - topo_rank[topo[i]] = i; - } - std::unordered_map> dom = - dominators(g); - - std::unordered_map imm_dom; - for (auto const &kv : dom) { - N const &n = kv.first; - std::unordered_set const &n_doms = kv.second; - - // if a node is only dominated by itself, set the dominator to itself to - // signify that it has no immediate dominator - if (n_doms.size() == 1) { - imm_dom[n] = n; - continue; - } - - N const *n_imm_dom = nullptr; - int current_topo_rank = std::numeric_limits::min(); - for (auto const &d : n_doms) { - if (topo_rank.at(d) > current_topo_rank && d != n) { - n_imm_dom = &d; - current_topo_rank = topo_rank.at(d); - } - } - imm_dom[n] = *n_imm_dom; - } - - return imm_dom; -} - -template > -void dfs(G const &g, - typename Structure::vertex_type const &n, - std::function const - &visitor) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - - /* auto i_visitor = std::bind(visitor, g, s, n); */ - auto i_visitor = [&](N const &nn) { return visitor(g, s, n, nn); }; - - std::queue q; - std::unordered_set visited; - - auto visit = [&](N const &n) { - if (visited.find(n) == visited.end()) { - q.push(n); - visited.insert(n); - } - }; - - visit(n); - - while (!q.empty()) { - N current = q.front(); - q.pop(); - - i_visitor(current); - - for (E const &edge : s.get_outgoing_edges(g, current)) { - N const &dst = s.get_dst(g, edge); - visit(dst); - } - } - - return; -} - -template > -std::unordered_set - descendants(G const &g, typename Structure::vertex_type const &n) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - std::unordered_set descendants; - - auto dfs_visitor = [&](G const &gg, - Structure const &ss, - N const &dfs_src, - N const ¤t_node) { - descendants.insert(current_node); - }; - - dfs(g, n, dfs_visitor); - - return descendants; -} - -template > -std::vector> - weakly_connected_components(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - std::vector> result; - std::unordered_set seen; - - for (N const &n : nodes>(g)) { - if (seen.find(n) != seen.end()) { - continue; - } - - auto component = descendants>(g, n); - seen.insert(component.begin(), component.end()); - result.emplace_back(component); - } - - return result; -} - -template > -std::unordered_map - imm_post_dominators(G const &g) { - return imm_dominators>(g); -} - -template > -BasicGraph transitive_reduction(G const &g) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - Structure s; - BasicGraph reduction; - - std::unordered_set nodes = s.get_nodes(g); - - reduction.add_nodes(nodes); - - std::unordered_set> to_delete; - - auto dfs_visitor = [&](N const &src, - G const &gg, - Structure const &ss, - N const &dfs_src, - N const &nn) { - if (nn != dfs_src && to_delete.find({src, nn}) == to_delete.end() && - has_edge(gg, src, nn)) { - to_delete.insert({src, nn}); - } - }; - - for (N const &n : nodes) { - /* auto n_dfs_visitor = std::bind(dfs_visitor, n); */ - auto n_dfs_visitor = - [&](G const &gg, Structure const &ss, N const &dfs_src, N const &nn) { - return dfs_visitor(n, gg, ss, dfs_src, nn); - }; - - for (N const &child : successors(g, n)) { - dfs(g, child, n_dfs_visitor); - } - } - - for (E const &e : get_edges(g)) { - std::pair basic_edge = get_basic_edge(g, e); - - if (to_delete.find(basic_edge) == to_delete.end()) { - reduction.add_edge(basic_edge); - } - } - - return reduction; -} - -template -void inplace_transitive_reduction(BasicGraph &g) { - using Structure = GraphStructure>; - using G = BasicGraph; - using E = std::pair; - - std::unordered_set to_delete; - - auto dfs_visitor = [&](N const &src, - G const &gg, - Structure const &ss, - N const &dfs_src, - N const &nn) { - if (nn != dfs_src && to_delete.find({src, nn}) == to_delete.end() && - has_edge(gg, src, nn)) { - to_delete.insert({src, nn}); - } - }; - - for (N const &n : g.nodes) { - auto n_dfs_visitor = - [&](G const &gg, Structure const &ss, N const &dfs_src, N const &nn) { - return dfs_visitor(n, gg, ss, dfs_src, nn); - }; - - for (N const &child : successors(g, n)) { - dfs(g, child, n_dfs_visitor); - } - } - - for (E const &e : to_delete) { - g.remove_edge(e); - } -}; - -template > -void export_as_dot( - DotFile &dotfile, - G const &g, - std::function const - &pretty) { - using N = typename Structure::vertex_type; - using E = typename Structure::edge_type; - - GraphStructure s; - - for (N const &n : s.get_nodes(g)) { - dotfile.add_record_node(n, pretty(n)); - - for (E const &edge : s.get_incoming_edges(g, n)) { - dotfile.add_edge(s.get_src(g, edge), s.get_dst(g, edge)); - } - } - - dotfile.close(); -} - -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -#endif // _DOMINATORS_H diff --git a/lib/compiler/src/old/graph.cc b/lib/compiler/src/old/graph.cc deleted file mode 100644 index 191b1028b7..0000000000 --- a/lib/compiler/src/old/graph.cc +++ /dev/null @@ -1,1255 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "graph.h" -#include "dominators.h" -#include "op-attrs/op-attrs.h" -#include "utils/disjoint_set.h" -#include "utils/unique.h" -#include - -// using FlexFlow::utils::Node; -// using FlexFlow::opmeta::OperatorParameters; - -namespace FlexFlow { - -ParallelComputationGraph::Graph(std::string const &logger_name) - : Graph(spdlog::get(logger_name)) {} - -ParallelComputationGraph::Graph(std::shared_ptr const &logger) - : logger(logger) {} - -Graph::Graph(utils::AdjacencyMultiDiGraph const &g, - utils::bidict const &nodeMap, - std::shared_ptr const &logger) - : g(g), nodeMap(nodeMap), logger(logger) {} - -/* using namespace Legion; */ -/* using FlexFlow::MachineView; */ - -/* LegionRuntime::Logger::Category log_graph("graph"); */ -/* LegionRuntime::Logger::Category log_simplify("graph_simplify"); */ - -void Graph::add_edge(Node const &srcOp, - Node const &dstOp, - int srcIdx, - int dstIdx) { - this->g.add_edge({srcOp, dstOp, (std::size_t)srcIdx, (std::size_t)dstIdx}); -} - -Node Graph::add_node(PCGOperatorAttrs const ¶ms) { - Node n = this->g.add_node(); - this->nodeMap.equate(n, params); - return n; -} - -void Graph::add_edge(utils::MultiDiEdge const &e) { - this->g.add_edge(e); -} - -void Graph::remove_edge(utils::MultiDiEdge const &e, - bool remove_node_if_unused) { - this->g.remove_edge(e); - utils::remove_node_if_unused(this->g, e.src); - utils::remove_node_if_unused(this->g, e.dst); -} - -bool Graph::has_edge(utils::MultiDiEdge const &e) const { - return utils::contains_edge(this->g, e); -} - -void Graph::print_dot() const { - this->print_dot(std::cout); -} - -void Graph::print_dot(std::ostream &s) const { - auto directed = unsafe_view_as_digraph(this->g); - - DotFile dot(s); - - export_as_dot(dot, directed, [&](utils::Node const &node) -> RecordFormatter { - RecordFormatter rf; - rf << node.to_string(); - tl::optional sub_rf = as_dot(this->nodeMap.at_l(node)); - if (sub_rf.has_value()) { - rf << sub_rf.value(); - } - - return rf; - }); - s << std::endl; -} - -bool Graph::has_loop() { - return !utils::is_acyclic(this->g).value_or(true); -} - -/* Node Graph::find_bottleneck_node(Node const &sink_node, */ -/* Node const &source_node) const { */ -/* using FlexFlow::PCG::Utils::GraphStructure; */ -/* using FlexFlow::PCG::Utils::imm_post_dominators; */ -/* using FlexFlow::PCG::Utils::MultisourceGraphStructure; */ -/* using FlexFlow::PCG::Utils::roots; */ - -/* Node source(source_node); */ -/* std::unordered_map ipd; */ -/* std::unordered_set graph_roots = roots(*this); */ -/* if (source_node != Node::INVALID_NODE) { */ -/* ipd = imm_post_dominators(*this); */ -/* } else if (graph_roots.size() == 1) { */ -/* ipd = imm_post_dominators(*this); */ -/* source = *graph_roots.begin(); */ -/* } else { */ -/* ipd = imm_post_dominators>(*this); */ -/* } */ - -/* Node bn_node = ipd.at(source); */ -/* if (bn_node == source || bn_node == sink_node) { */ -/* return Node::INVALID_NODE; */ -/* } */ - -/* return bn_node; */ -/* } */ - -Graph Graph::subgraph(std::unordered_set const &nodes) const { - AdjacencyMultiDiGraph sub_g = subgraph(this->g, nodes); - - bidict sub_nodeMap; - for (auto const &kv : this->nodeMap) { - if (contains(nodes, kv.first)) { - sub_nodeMap.equate(kv.first, kv.second); - } - } - - return {sub_g, sub_nodeMap, this->logger}; -} - -void Graph::remove_node(Node const &node, bool purge_edges) { - assert(purge_edges == true); - utils::remove_node(this->g, node); - this->nodeMap.erase_l(node); -} - -/*static*/ -Graph Graph::singleton(PCGOperatorAttrs const ¶ms) { - Graph g; - g.add_node(params); - return g; -} - -bool Graph::empty() const { - return utils::empty(this->g); -} - -void Graph::replace_subgraph(std::unordered_set const ¤tNodes, - Graph const &replaceWith) { - assert(currentNodes.size() > 0); - if (replaceWith.empty()) { - Graph subgraph = this->subgraph(currentNodes); - assert(!subgraph.empty()); - Node source_node = subgraph.find_source_node(); - Node noop = - this->model->get_or_create_noop_node(source_node.ptr->inputs[0]); - this->replace_subgraph_with_nonempty(currentNodes, - Graph::singleton(this->model, noop)); - this->contract_out_node(noop); - } else { - this->replace_subgraph_with_nonempty(currentNodes, replaceWith); - } -} - -void Graph::replace_subgraph_with_nonempty( - std::unordered_set const ¤tNodes, Graph const &replaceWith) { - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::nodes; - - Node new_sink_node = replaceWith.find_sink_node(); - - Graph old_subgraph = this->subgraph(currentNodes); - Node old_sink_node = old_subgraph.find_sink_node(); - Node old_source_node = old_subgraph.find_source_node(); - - std::unordered_set all_nodes = nodes(*this); - - for (Edge const &old_inner_edge : get_edges(old_subgraph)) { - this->remove_edge(old_inner_edge, false); - } - for (Edge const &new_inner_edge : get_edges(replaceWith)) { - this->add_edge(new_inner_edge); - } - - std::unordered_set old_in_edges = this->inEdges.at(old_source_node); - if (!old_in_edges.empty()) { - Node new_source_node = replaceWith.find_source_node(); - for (Edge const &old_in_edge : old_in_edges) { - Edge new_in_edge(old_in_edge); - new_in_edge.dstOp = new_source_node; - this->remove_edge(old_in_edge, false); - this->add_edge(new_in_edge); - } - } - - std::unordered_set old_out_edges = this->outEdges.at(old_sink_node); - for (Edge const &old_out_edge : old_out_edges) { - Edge new_out_edge(old_out_edge); - new_out_edge.srcOp = new_sink_node; - this->remove_edge(old_out_edge, false); - this->add_edge(new_out_edge); - } - - for (Node const &node : currentNodes) { - this->remove_node(node); - } - - assert(this->check_correctness()); -} - -void Graph::contract_out_node(Node const &node) { - contract_node(this->g, node); - this->nodeMap.erase_l(node); -} - -/* std::pair, std::unique_ptr> */ -/* Graph::split_at_node(Node const &bottleneck) const { */ -/* using FlexFlow::PCGe:Utils::topo_sort; */ - -/* auto first_graph = std::unique_ptr(new Graph(this->model)); */ -/* auto second_graph = std::unique_ptr(new Graph(this->model)); */ - -/* std::unordered_set used_nodes; */ -/* { */ -/* std::vector topo_sorted; */ -/* topo_sort(*this, &topo_sorted); */ - -/* for (auto const &node : topo_sorted) { */ -/* if (node == bottleneck) { */ -/* break; */ -/* } */ - -/* used_nodes.insert(node); */ -/* } */ -/* used_nodes.insert(bottleneck); */ - -/* assert(used_nodes.size() < topo_sorted.size()); */ -/* } */ - -/* for (auto const &it : this->inEdges) { */ -/* auto const &inList = it.second; */ -/* if (used_nodes.find(it.first) != used_nodes.end()) { */ -/* // Add all in-edges of used_nodes in to the first_graph */ -/* for (auto const &it2 : inList) { */ -/* first_graph->add_edge(it2); */ -/* } */ -/* } else { */ -/* // Add all in-edges of not_used_nodes into the second_graph */ -/* for (auto const &it2 : inList) { */ -/* second_graph->add_edge(it2); */ -/* } */ -/* } */ -/* } */ - -/* return {std::move(first_graph), std::move(second_graph)}; */ -/* } */ - -void Graph::remove_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - for (auto const &n : nodes(*this)) { - if (n.ptr->op_type == OP_INPUT) { - this->remove_node(n, true /*purge_edges*/); - } - } -} - -Node Graph::clone_node(Node const &n) { - Node cloned = n; - cloned.original_guid = n.guid; - cloned.guid = this->model->node_global_guid++; - this->add_node(cloned); - return cloned; -} - -Node Graph::declone_node(Node const &n) { - assert(n.original_guid.has_value()); - Node decloned = n; - decloned.guid = n.original_guid.value(); - decloned.original_guid = tl::nullopt; - this->add_node(decloned); - return decloned; -} - -std::pair> - Graph::deduplicate_input_node(Node const &n) { - using FlexFlow::PCG::Utils::nodes; - using FlexFlow::PCG::Utils::outgoing_edges; - - assert(n.original_guid.has_value()); - std::unordered_set old_all_nodes = nodes(*this); - Node decloned = this->declone_node(n); - - std::unordered_set old_nodes; - std::unordered_set new_edges; - for (Node const &nn : old_all_nodes) { - if (nn.original_guid == n.original_guid) { - old_nodes.insert(nn); - for (Edge const &e : outgoing_edges(*this, nn)) { - Edge decloned_edge(e); - decloned_edge.replace_node(nn, decloned); - new_edges.insert(decloned_edge); - } - this->remove_node(nn, true /*purge_edges*/); - } - } - - for (Edge const &e : new_edges) { - this->add_edge(e); - } - - return {decloned, old_nodes}; -} - -std::unordered_map Graph::deduplicate_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - std::unordered_map deduplication_map; - - bool done; - while (true) { - done = true; - for (Node const &n : nodes(*this)) { - if (n.original_guid.has_value()) { - done = false; - auto kv = this->deduplicate_input_node(n); - for (auto const &r : kv.second) { - deduplication_map[r] = kv.first; - } - break; - } - } - if (done) { - break; - } - } - - return deduplication_map; -} - -void Graph::duplicate_input_node(Node const &n) { - using FlexFlow::PCG::Utils::outgoing_edges; - using FlexFlow::PCG::Utils::successors; - - assert(n.ptr->op_type == OP_INPUT); - - std::unordered_map clones; - - for (auto const &s : successors(*this, n)) { - clones[s] = this->clone_node(n); - } - - for (auto const &e : outgoing_edges(*this, n)) { - Edge cloned(e); - cloned.srcOp = clones.at(e.dstOp); - this->add_edge(cloned); - } - this->remove_node(n, true /*purge_edges*/); -} - -void Graph::duplicate_input_nodes() { - using FlexFlow::PCG::Utils::nodes; - - for (auto const &n : nodes(*this)) { - if (n.ptr->op_type == OP_INPUT) { - this->duplicate_input_node(n); - } - } -} - -std::pair, std::unique_ptr> - Graph::split_horizontal(Node const &source_node, - Node const &sink_node) const { - using FlexFlow::PCG::Utils::weakly_connected_components; - - Graph trimmed_graph(*this); - assert(sink_node != - Node::INVALID_NODE); // sink node should never be invalid node - if (source_node != Node::INVALID_NODE) { - trimmed_graph.remove_node(source_node, true /*purge_edges*/); - } - trimmed_graph.remove_node(sink_node, true /*purge_edges*/); - std::vector> wccs = - weakly_connected_components(trimmed_graph); - assert(wccs.size() >= 2); - std::unordered_set first_branch = wccs.back(); - wccs.pop_back(); - std::unordered_set rest; - for (auto const &wcc : wccs) { - rest.insert(wcc.begin(), wcc.end()); - } - if (source_node != Node::INVALID_NODE) { - first_branch.insert(source_node); - rest.insert(source_node); - } - first_branch.insert(sink_node); - rest.insert(sink_node); - - auto first_graph = - std::unique_ptr(new Graph(this->subgraph(first_branch))); - auto second_graph = std::unique_ptr(new Graph(this->subgraph(rest))); - - return {std::move(first_graph), std::move(second_graph)}; -} - -GraphCostResult GraphCostResult::invalid() { - return {std::numeric_limits::infinity(), {}}; -} - -bool GraphCostResult::operator<(GraphCostResult const &other) const { - return this->cost < other.cost; -} - -std::ostream &operator<<(std::ostream &s, GraphCostResult const &r) { - s << "GraphCostResult{cost=" << r.cost << "}"; - return s; -} - -std::ostream &operator<<(std::ostream &s, GraphOptimizeResult const &r) { - s << "GraphOptimizeResult{cost=" << r.cost << "}"; - return s; -} - -template <> -GraphCostResult sequence_cost(GraphCostResult const &first, - GraphCostResult const &second) { - GraphCostResult result(first); - result.cost += second.cost; - result.views.insert(second.views.cbegin(), second.views.cend()); - return result; -} - -template <> -float sequence_cost(float const &first, float const &second) { - return first + second; -} - -template <> -GraphOptimizeResult - sequence_cost(GraphOptimizeResult const &first, - GraphOptimizeResult const &second) { - GraphOptimizeResult result; - result.cost = first.cost + second.cost; - result.views.insert(first.views.cbegin(), first.views.cend()); - result.views.insert(second.views.cbegin(), second.views.cend()); - - result.graph = second.graph; - Node second_src = result.graph.value().find_source_node(); - result.graph.value().replace_subgraph({second_src}, first.graph.value()); - return result; -} - -template <> -GraphCostResult parallel_cost(GraphCostResult const &first, - GraphCostResult const &second) { - GraphCostResult result; - result.cost = std::max(first.cost, second.cost); - result.views.insert(first.views.cbegin(), first.views.cend()); - result.views.insert(second.views.cbegin(), second.views.cend()); - - return result; -} - -template <> -float parallel_cost(float const &first, float const &second) { - return std::max(first, second); -} - -float Graph::optimal_cost() const { - return this->generic_optimal_cost(); -} - -std::unordered_map Graph::optimal_views() const { - return this->generic_optimal_cost().views; -} - -Graph Graph::reduced() const { - using FlexFlow::PCG::Utils::BasicGraph; - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::transitive_reduction; - - BasicGraph transitive_skeleton = transitive_reduction(*this); - - Graph reduced_graph(this->model); - - for (Edge const &e : get_edges(*this)) { - if (transitive_skeleton.has_edge(e.srcOp, e.dstOp)) { - reduced_graph.add_edge(e); - } - } - - return reduced_graph; -} - -/** - * @brief A generic cost function for a graph capable of finding both the cost - * and the optimal views - * - * @note A templated function is used here because while the caching behaviors - * of the cost and the optimal views are different, much of the code between the - * two versions is almost identical. By using a few template specializations we - * can avoid duplicating all this code. - * - * @tparam T the result type (can be either float or GraphCostResult) - * @return T the cost of the graph (along with any additional data in the return - * type) - */ -template -T Graph::generic_optimal_cost() const { - using FlexFlow::PCG::Utils::GraphStructure; - - Graph reduced_graph = this->reduced(); - // GraphStructure s; - // if (source_node.ptr->op_type == OP_INPUT) { - // for (auto const &e : s.get_outgoing_edges(reduced_graph, source_node)) { - // reduced_graph.remove_edge(e, false/*remove_node_if_unused*/); - // } - // reduced_graph.remove_node(source_node); - // } - - Node sink_node = reduced_graph.find_sink_node(); - this->search->logger->info() << "Found sink node: " << sink_node.to_string(); - - MachineResource resource(model->config); - - std::vector valid_views = - search->get_valid_machine_views(sink_node, resource, true); - - T optimal = search->infinity(); - - this->search->logger->info() - << "Exploring " << valid_views.size() << " valid views"; - for (MachineView const &sink_view : valid_views) { - this->search->logger->info() << " Exploring valid view " << sink_view; - T new_cost = - search->graph_cost(&reduced_graph, - {Node::INVALID_NODE, MachineView::NO_VIEW}, - {sink_node, sink_view}, - resource, - true); - if (new_cost < optimal) { - optimal = new_cost; - } - } - - return optimal; -} - -size_t Graph::hash(void) const { - // Graph hash should be additive and independent to the ordering of the nodes - size_t total_hash = 0; - for (auto const &it : inEdges) { - auto const &inList = it.second; - size_t node_hash = std::hash()((size_t)it.first.ptr); - for (auto const &e : inList) { - size_t edge_hash = 17; - edge_hash = edge_hash * 31 + std::hash()((size_t)e.srcOp.ptr); - edge_hash = edge_hash * 31 + std::hash()(e.srcIdx); - edge_hash = edge_hash * 31 + std::hash()(e.dstIdx); - node_hash *= edge_hash; - } - total_hash += node_hash; - } - return total_hash; -} - -size_t dp_state_hash(Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resource) { - size_t key = graph->hash(); - hash_combine(key, sink_node.ptr); - hash_combine(key, sink_view.hash()); - hash_combine(key, source_node.ptr); - hash_combine(key, resource.hash()); - return key; -} - -GraphOptimalViewSerialized - Graph::graph_optimize_task(Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - FFModel *model = *((FFModel **)task->args); - if (model->config.search_num_nodes.has_value()) { - model->config.numNodes = model->config.search_num_nodes.value(); - } - if (model->config.search_num_workers.has_value()) { - model->config.workersPerNode = model->config.search_num_workers.value(); - } - model->all_valid_views.clear(); - model->register_all_machine_views(model->config.numNodes, - model->config.workersPerNode, - model->config.cpusPerNode, - model->all_valid_views); - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - MachineModel *machine; - if (model->config.machine_model_version == 0) { - machine = - (MachineModel *)new SimpleMachineModel(model->config.numNodes, - model->config.workersPerNode, - gpu_mem.capacity()); - } else if (model->config.machine_model_version == 1 and - !model->config.machine_model_file.empty()) { - machine = (MachineModel *)new EnhancedMachineModel( - model->config.machine_model_file, gpu_mem.capacity()); - } else { - assert(false && - "machine model creation error: currently only support " - "machine-model-version = 0 or 1. When machine-model-version = 1, " - "machine-model-file should not be empty."); - } - model->simulator = - make_unique(model, model->handlers[0], gpu_mem, machine); - std::unique_ptr best_graph; - std::unordered_map optimal_views; - if (model->config.only_data_parallel) { - Graph *graph = new Graph(model); - std::unordered_map op_to_node_map; - for (FlexFlow::Op const *dstOp : model->operators) { - Node dstNode; - dstNode.ptr = dstOp; - dstNode.guid = model->node_global_guid++; - op_to_node_map[dstOp] = dstNode; - for (int j = 0; j < dstOp->numInputs; j++) { - FlexFlow::Op const *srcOp = dstOp->inputs[j]->owner_op; - assert(op_to_node_map.find(srcOp) != op_to_node_map.end()); - Node srcNode = op_to_node_map[srcOp]; - graph->add_edge(srcNode, dstNode, dstOp->inputs[j]->owner_idx, j); - } - } - best_graph = std::unique_ptr(graph); - MachineView data_parallel_view; - data_parallel_view.device_type = MachineView::GPU; - data_parallel_view.ndims = 1; - data_parallel_view.dim[0] = - model->config.numNodes * model->config.workersPerNode; - data_parallel_view.stride[0] = 1; - data_parallel_view.start_device_id = 0; - for (auto const &node : best_graph->inEdges) { - optimal_views[node.first] = data_parallel_view; - } - } else { - model->graph_optimize(model->config.search_budget, - model->config.only_data_parallel, - best_graph, - optimal_views); - } - /* Serializer sez; */ - /* // First serialize graph */ - /* sez.serialize(best_graph->inEdges.size()); */ - /* std::unordered_map todos; */ - /* std::vector opList; */ - /* for (auto const &it : best_graph->inEdges) { */ - /* auto const &inList = it.second; */ - /* todos[it.first] = (int)inList.size(); */ - /* if (todos[it.first] == 0) { */ - /* opList.push_back(it.first); */ - /* } */ - /* } */ - /* size_t node_idx = 0; */ - /* while (node_idx < opList.size()) { */ - /* Node cur_node = opList[node_idx++]; */ - /* auto const &outList = best_graph->outEdges[cur_node]; */ - /* for (auto const &e : outList) { */ - /* todos[e.dstOp]--; */ - /* if (todos[e.dstOp] == 0) { */ - /* opList.push_back(e.dstOp); */ - /* } */ - /* } */ - /* auto const &inList = best_graph->inEdges[cur_node]; */ - /* sez.serialize(inList.size()); */ - /* for (auto const &e : inList) { */ - /* sez.serialize(e.srcOp.guid); */ - /* assert(e.dstOp.guid == cur_node.guid); */ - /* sez.serialize(e.srcIdx); */ - /* sez.serialize(e.dstIdx); */ - /* } */ - /* sez.serialize((size_t)10101010); // safe guard for the end of inedges */ - /* Op const *op = cur_node.ptr; */ - /* assert(op != NULL); */ - /* sez.serialize(cur_node.guid); */ - /* sez.serialize(op->op_type); */ - /* switch (op->op_type) { */ - /* case OP_INPUT: { */ - /* assert(op->numOutputs == 1); */ - /* NoOp *noop = (NoOp *)op; */ - /* sez.serialize(noop->op_type); */ - /* sez.serialize(noop->input_tensor_guid); */ - /* sez.serialize(noop->outputs[0]->data_type); */ - /* sez.serialize(noop->outputs[0]->num_dims); */ - /* for (int i = 0; i < noop->outputs[0]->num_dims; i++) { */ - /* sez.serialize(noop->outputs[0]->dims[i]); */ - /* } */ - /* break; */ - /* } */ - /* case OP_NOOP: { */ - /* break; */ - /* } */ - /* case OP_CONCAT: { */ - /* Concat *concat = (Concat *)op; */ - /* sez.serialize(concat->legion_axis); */ - /* break; */ - /* } */ - /* case OP_SPLIT: { */ - /* Split *split = (Split *)op; */ - /* sez.serialize(split->legion_axis); */ - /* sez.serialize(split->numOutputs); */ - /* for (int i = 0; i < split->numOutputs; i++) { */ - /* sez.serialize(split->outputs[i]->dims[split->legion_axis].size); */ - /* } */ - /* break; */ - /* } */ - /* case OP_EMBEDDING: { */ - /* Embedding *embed = (Embedding *)op; */ - /* sez.serialize(embed->layer_guid.id); */ - /* sez.serialize(embed->num_entries); */ - /* sez.serialize(embed->out_channels); */ - /* sez.serialize(embed->aggr); */ - /* sez.serialize(embed->data_type); */ - /* break; */ - /* } */ - /* case OP_EW_ADD: */ - /* case OP_EW_SUB: */ - /* case OP_EW_MUL: */ - /* case OP_EW_MAX: */ - /* case OP_EW_MIN: { */ - /* sez.serialize(op->op_type); */ - /* break; */ - /* } */ - /* case OP_MULTIHEAD_ATTENTION: { */ - /* MultiHeadAttention *attn = (MultiHeadAttention *)op; */ - /* sez.serialize(attn->layer_guid.id); */ - /* sez.serialize(attn->oProjSize); */ - /* sez.serialize(attn->num_heads); */ - /* sez.serialize(attn->qProjSize); */ - /* sez.serialize(attn->vProjSize); */ - /* sez.serialize(attn->dropout); */ - /* sez.serialize(attn->bias); */ - /* sez.serialize(attn->add_bias_kv); */ - /* sez.serialize(attn->add_zero_attn); */ - /* break; */ - /* } */ - /* case OP_SOFTMAX: { */ - /* Softmax *softmax = (Softmax *)op; */ - /* sez.serialize(softmax->dim); */ - /* break; */ - /* } */ - /* case OP_REPARTITION: { */ - /* Repartition *repart = (Repartition *)op; */ - /* sez.serialize(repart->repartition_dim); */ - /* sez.serialize(repart->repartition_degree); */ - /* break; */ - /* } */ - /* case OP_REPLICATE: { */ - /* Replicate *replicate = (Replicate *)op; */ - /* sez.serialize(replicate->replicate_dim); */ - /* sez.serialize(replicate->replicate_degree); */ - /* break; */ - /* } */ - /* case OP_REDUCTION: { */ - /* Reduction *reduction = (Reduction *)op; */ - /* sez.serialize(reduction->reduction_dim); */ - /* sez.serialize(reduction->reduction_degree); */ - /* break; */ - /* } */ - /* case OP_COMBINE: { */ - /* Combine *combine = (Combine *)op; */ - /* sez.serialize(combine->combine_dim); */ - /* sez.serialize(combine->combine_degree); */ - /* break; */ - /* } */ - /* case OP_FUSED_PARALLEL: { */ - /* FusedParallelOp *fused = (FusedParallelOp *)op; */ - /* sez.serialize(fused->num_parallel_ops); */ - /* for (int i = 0; i < fused->num_parallel_ops; i++) { */ - /* sez.serialize(fused->parallel_ops[i]); */ - /* } */ - /* break; */ - /* } */ - /* default: { */ - /* op->serialize(sez); */ - /* } */ - /* } */ - /* sez.serialize((size_t)12345678); // safe guard for the end of an op */ - /* } */ - /* assert(node_idx == best_graph->inEdges.size()); */ - /* // Second, serialize optimal machine view */ - /* printf("opotimal_views.size = %zu\n", optimal_views.size()); */ - /* sez.serialize(optimal_views.size()); */ - /* for (auto const &it : optimal_views) { */ - /* sez.serialize((size_t)98765432); // safe guard */ - /* sez.serialize(it.first.guid); */ - /* sez.serialize(it.second); */ - /* } */ - /* assert(sez.get_used_bytes() < GraphOptimalViewSerialized::buffer_size); */ - /* GraphOptimalViewSerialized ret; */ - /* ret.total_bytes = sez.get_used_bytes(); */ - /* memcpy(ret.data, sez.get_buffer(), ret.total_bytes); */ - /* // Deallocate best_graph */ - /* // delete best_graph; */ - /* return ret; */ -} - -}; // namespace FlexFlow - -namespace FlexFlow { - -using PCG::Edge; -using PCG::Graph; -using PCG::GraphCostResult; -using PCG::Node; - -void FFModel::register_all_machine_views( - int num_nodes, - int gpus_per_node, - int cpus_per_node, - std::vector &valid_views) { - // Single-parallelism-dimension views - for (int i = 1; i <= num_nodes * gpus_per_node; i++) { - if (num_nodes * gpus_per_node % i == 0) { - MachineView view; - view.device_type = MachineView::GPU; - view.ndims = 1; - view.dim[0] = i; - view.stride[0] = 1; - view.start_device_id = 0; - valid_views.push_back(view); - } - } - // Two-dimensional views - /* for (int i = 1; i <= num_nodes; i++) { */ - /* for (int j = 1; j <= gpus_per_node; j++) { */ - /* MachineView view; */ - /* view.device_type = MachineView::GPU; */ - /* view.ndims = 2; */ - /* view.dim[0] = i; */ - /* view.stride[0] = 1; */ - /* view.dim[1] = j; */ - /* view.stride[1] = 1; */ - /* view.start_device_id = 0; */ - /* valid_views.push_back(view); */ - /* } */ - /* } */ -} - -float FFModel::graph_cost(Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resources, - bool include_sink_compute_time, - bool constructing_optimal_view) { - assert(!graph->inEdges.empty()); - - return this->search->graph_cost(graph, - {source_node, source_view}, - {sink_node, sink_view}, - resources, - include_sink_compute_time); -} - -void FFModel::construct_optimal_view( - Graph const *graph, - Node const &sink_node, - MachineView const &sink_view, - Node const &source_node, - MachineView const &source_view, - MachineResource const &resources, - bool include_sink_compute_time, - float optimal_cost, - std::unordered_map &optimal_views) { - GraphCostResult result = - this->search->graph_cost(graph, - {source_node, source_view}, - {sink_node, sink_view}, - resources, - include_sink_compute_time); - - optimal_views.insert(result.views.begin(), result.views.end()); -} - -/* void FFModel::deserialize_graph_optimal_view( */ -/* Legion::Deserializer &dez, */ -/* Graph *graph, */ -/* std::unordered_map &optimal_views) { */ -/* // Deserializer dez(serialized.data, serialized.total_bytes); */ -/* std::unordered_map guid_to_nodes; */ -/* size_t num_nodes; */ -/* dez.deserialize(num_nodes); */ -/* // best_graph = new Graph(this); */ -/* for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { */ -/* Edge inedges[MAX_NUM_INPUTS]; */ -/* ParallelTensor inputs[MAX_NUM_INPUTS]; */ -/* size_t num_inputs; */ -/* dez.deserialize(num_inputs); */ -/* for (size_t j = 0; j < num_inputs; j++) { */ -/* size_t src_guid; */ -/* int src_idx, dst_idx; */ -/* dez.deserialize(src_guid); */ -/* assert(guid_to_nodes.find(src_guid) != guid_to_nodes.end()); */ -/* dez.deserialize(src_idx); */ -/* dez.deserialize(dst_idx); */ -/* assert(dst_idx < (int)num_inputs); */ -/* inedges[dst_idx].srcOp = guid_to_nodes[src_guid]; */ -/* inedges[dst_idx].srcIdx = src_idx; */ -/* inedges[dst_idx].dstIdx = dst_idx; */ -/* inputs[dst_idx] = inedges[dst_idx].srcOp.ptr->outputs[src_idx]; */ -/* } */ -/* { */ -/* size_t safecode; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 10101010); */ -/* } */ -/* Node node = Node::INVALID_NODE; */ -/* size_t guid; */ -/* OperatorType op_type; */ -/* dez.deserialize(guid); */ -/* dez.deserialize(op_type); */ -/* switch (op_type) { */ -/* case OP_INPUT: { */ -/* assert(num_inputs == 0); */ -/* int num_dims; */ -/* ParallelDim dims[MAX_TENSOR_DIM]; */ -/* OperatorType op_type; */ -/* dez.deserialize(op_type); */ -/* size_t input_tensor_guid; */ -/* dez.deserialize(input_tensor_guid); */ -/* DataType data_type; */ -/* dez.deserialize(data_type); */ -/* dez.deserialize(num_dims); */ -/* for (int i = 0; i < num_dims; i++) { */ -/* dez.deserialize(dims[i]); */ -/* } */ -/* ParallelTensor t = */ -/* create_parallel_tensor_legion_ordering(num_dims, */ -/* dims, */ -/* data_type, */ -/* nullptr, */ -/* 0, */ -/* true create_grad, */ -/* input_tensor_guid); */ -/* node.ptr = t->owner_op; */ -/* node.guid = node_global_guid++; */ -/* break; */ -/* } */ -/* case OP_NOOP: { */ -/* assert(num_inputs == 1); */ -/* node = get_or_create_noop_node(inputs[0]); */ -/* break; */ -/* } */ -/* case OP_BATCHMATMUL: { */ -/* node = BatchMatmul::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_CAST: { */ -/* node = Cast::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_CONCAT: { */ -/* int legion_axis; */ -/* dez.deserialize(legion_axis); */ -/* node = get_or_create_node( */ -/* {std::begin(inputs), std::begin(inputs) + num_inputs}, */ -/* {legion_axis}); */ -/* break; */ -/* } */ -/* case OP_SPLIT: { */ -/* int legion_axis; */ -/* dez.deserialize(legion_axis); */ -/* int num_outputs; */ -/* dez.deserialize(num_outputs); */ -/* std::vector splits; */ -/* for (int i = 0; i < num_outputs; i++) { */ -/* int dim_size; */ -/* dez.deserialize(dim_size); */ -/* splits.push_back(dim_size); */ -/* } */ -/* node = get_or_create_node(inputs[0], {splits, legion_axis}); - */ -/* break; */ -/* } */ -/* case OP_EMBEDDING: { */ -/* assert(num_inputs == 1); */ -/* AggrMode aggr; */ -/* int num_entries, out_channels; */ -/* size_t id; */ -/* DataType data_type; */ -/* dez.deserialize(id); */ -/* LayerID layer_guid(id); */ -/* dez.deserialize(num_entries); */ -/* dez.deserialize(out_channels); */ -/* dez.deserialize(aggr); */ -/* dez.deserialize(data_type); */ - -/* EmbeddingParams params; */ -/* params.aggr = aggr; */ -/* params.num_entries = num_entries; */ -/* params.out_channels = out_channels; */ -/* params.layer_guid = layer_guid; */ -/* params.data_type = data_type; */ -/* node = get_or_create_node(inputs[0], params); */ -/* break; */ -/* } */ -/* case OP_EW_ADD: */ -/* case OP_EW_SUB: */ -/* case OP_EW_MUL: */ -/* case OP_EW_MAX: */ -/* case OP_EW_MIN: { */ -/* assert(num_inputs == 2); */ -/* OperatorType op_type; */ -/* dez.deserialize(op_type); */ -/* node = get_or_create_node({inputs[0], inputs[1]}, */ -/* {op_type}); */ -/* break; */ -/* } */ -/* case OP_CONV2D: { */ -/* node = Conv2D::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_DROPOUT: { */ -/* node = Dropout::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_EXP: */ -/* case OP_SIN: */ -/* case OP_COS: */ -/* case OP_SCALAR_MULTIPLY: */ -/* case OP_SCALAR_FLOOR_DIV: */ -/* case OP_SCALAR_TRUE_DIV: */ -/* case OP_SCALAR_ADD: */ -/* case OP_SCALAR_SUB: */ -/* case OP_RELU: */ -/* case OP_SIGMOID: */ -/* case OP_TANH: */ -/* case OP_POW: */ -/* case OP_IDENTITY: */ -/* case OP_GELU: */ -/* case OP_ELU: { */ -/* node = ElementUnary::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_FLAT: { */ -/* node = Flat::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_GATHER: { */ -/* node = Gather::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_LAYERNORM: { */ -/* node = LayerNorm::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_LINEAR: { */ -/* node = Linear::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_MULTIHEAD_ATTENTION: { */ -/* assert(num_inputs == 3); */ -/* int embed_dim, num_heads, k_dim, v_dim; */ -/* float dropout; */ -/* bool bias, add_bias_kv, add_zero_attn; */ -/* size_t id; */ -/* dez.deserialize(id); */ -/* LayerID layer_guid(id); */ -/* dez.deserialize(embed_dim); */ -/* dez.deserialize(num_heads); */ -/* dez.deserialize(k_dim); */ -/* dez.deserialize(v_dim); */ -/* dez.deserialize(dropout); */ -/* dez.deserialize(bias); */ -/* dez.deserialize(add_bias_kv); */ -/* dez.deserialize(add_zero_attn); */ - -/* MultiHeadAttentionParams params; */ -/* params.embed_dim = embed_dim; */ -/* params.num_heads = num_heads; */ -/* params.kdim = k_dim; */ -/* params.vdim = v_dim; */ -/* params.dropout = dropout; */ -/* params.bias = bias; */ -/* params.add_bias_kv = add_bias_kv; */ -/* params.add_zero_attn = add_zero_attn; */ -/* params.layer_guid = layer_guid; */ -/* node = get_or_create_node( */ -/* {inputs[0], inputs[1], inputs[2]}, params); */ -/* break; */ -/* } */ -/* case OP_TOPK: { */ -/* node = TopK::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_GROUP_BY: { */ -/* node = Group_by::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_AGGREGATE: { */ -/* // node = Aggregate::deserialize(*this, dez, inputs, num_inputs); */ -/* int n; */ -/* float lambda_bal; */ -/* dez.deserialize(n); */ -/* dez.deserialize(lambda_bal); */ -/* assert(num_inputs == n + 4); */ -/* AggregateParams params; */ -/* params.n = n; */ -/* params.lambda_bal = lambda_bal; */ -/* node = get_or_create_node( */ -/* {std::begin(inputs), std::begin(inputs) + num_inputs}, params); - */ -/* break; */ -/* } */ -/* case OP_POOL2D: { */ -/* node = Pool2D::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_REDUCE_SUM: { */ -/* node = Reduce::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_RESHAPE: { */ -/* node = Reshape::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_SOFTMAX: { */ -/* assert(num_inputs == 1); */ -/* int softmax_dim; */ -/* dez.deserialize(softmax_dim); */ -/* node = get_or_create_node(inputs[0], {softmax_dim}); */ -/* break; */ -/* } */ -/* case OP_TRANSPOSE: { */ -/* node = Transpose::deserialize(*this, dez, inputs, num_inputs); */ -/* break; */ -/* } */ -/* case OP_COMBINE: { */ -/* assert(num_inputs == 1); */ -/* int combine_dim, combine_degree; */ -/* dez.deserialize(combine_dim); */ -/* dez.deserialize(combine_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {combine_dim, combine_degree}); */ -/* break; */ -/* } */ -/* case OP_REPARTITION: { */ -/* assert(num_inputs == 1); */ -/* int repartition_dim, repartition_degree; */ -/* dez.deserialize(repartition_dim); */ -/* dez.deserialize(repartition_degree); */ -/* node = get_or_create_node( */ -/* inputs[0], {repartition_dim, repartition_degree}); */ -/* break; */ -/* } */ -/* case OP_REPLICATE: { */ -/* assert(num_inputs == 1); */ -/* int replicate_dim, replicate_degree; */ -/* dez.deserialize(replicate_dim); */ -/* dez.deserialize(replicate_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {replicate_dim, - * replicate_degree}); */ -/* break; */ -/* } */ -/* case OP_REDUCTION: { */ -/* assert(num_inputs == 1); */ -/* int reduction_dim, reduction_degree; */ -/* dez.deserialize(reduction_dim); */ -/* dez.deserialize(reduction_degree); */ -/* node = get_or_create_node(inputs[0], */ -/* {reduction_dim, - * reduction_degree}); */ -/* break; */ -/* } */ -/* case OP_FUSED_PARALLEL: { */ -/* assert(num_inputs == 1); */ -/* std::vector parallel_ops; */ -/* int num_parallel_ops; */ -/* dez.deserialize(num_parallel_ops); */ -/* for (int i = 0; i < num_parallel_ops; i++) { */ -/* ParallelOpInfo info; */ -/* dez.deserialize(info); */ -/* parallel_ops.push_back(info); */ -/* } */ -/* node = get_or_create_node(inputs[0], - * {parallel_ops}); */ -/* break; */ -/* } */ -/* default: { */ -/* fprintf(stderr, */ -/* "The following operator type is currently not supported" */ -/* " for graph deserialization: %s\n" */ -/* "Report the issue to the FlexFlow developers\n", */ -/* get_operator_type_name(op_type).c_str()); */ -/* assert(false && "Unsupported operator type"); */ -/* } */ -/* } */ -/* { */ -/* size_t safecode; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 12345678); */ -/* } */ -/* assert(node.ptr != nullptr); */ -/* guid_to_nodes[guid] = node; */ -/* for (size_t i = 0; i < num_inputs; i++) { */ -/* inedges[i].dstOp = node; */ -/* graph->add_edge(inedges[i]); */ -/* } */ -/* } */ -/* // Second, deserialize optimal machine view */ -/* size_t num_views; */ -/* dez.deserialize(num_views); */ -/* printf("views.size() = %zu\n", num_views); */ -/* for (size_t i = 0; i < num_views; i++) { */ -/* size_t safecode, guid; */ -/* MachineView view; */ -/* dez.deserialize(safecode); */ -/* assert(safecode == 98765432); */ -/* dez.deserialize(guid); */ -/* assert(guid_to_nodes.find(guid) != guid_to_nodes.end()); */ -/* dez.deserialize(view); */ -/* optimal_views[guid_to_nodes[guid]] = view; */ -/* } */ -/* assert(dez.get_remaining_bytes() == 0); */ -/* printf("Deserialized Views...\n"); */ -/* for (auto const &it : optimal_views) { */ -/* printf("node[%zu]: type(%s) view(%d %d %d) ", */ -/* it.first.guid, */ -/* it.first.to_string().c_str(), */ -/* it.second.ndims, */ -/* it.second.dim[0], */ -/* it.second.start_device_id); */ -/* auto const &list = graph->inEdges.at(it.first); */ -/* for (auto const &it2 : list) { */ -/* Edge e = it2; */ -/* printf(" inEdge(node(%zu) idx(%d))", e.srcOp.guid, e.srcIdx); */ -/* } */ -/* printf("\n"); */ -/* } */ -/* } */ - -} // namespace FlexFlow diff --git a/lib/compiler/src/old/graph.h b/lib/compiler/src/old/graph.h deleted file mode 100644 index db313b080d..0000000000 --- a/lib/compiler/src/old/graph.h +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright 2021 CMU, Facebook, LANL, MIT, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _FLEXFLOW_GRAPH_H_ -#define _FLEXFLOW_GRAPH_H_ -#include "basic_graph.h" -/* #include "node.h" */ -#include "graph_structures.h" -#include "op-attrs/op-attrs.h" -#include "pcg/machine_view.h" -#include "utils/bidict.h" -#include "utils/dot_file.h" -#include "utils/graph.h" -#include "utils/graph/serialparallel.h" -#include "utils/recursive_logger.h" -#include -#include - -// extern LegionRuntime::Logger::Category log_dp; - -/* namespace FlexFlow { */ -/* namespace ffc { */ - -/* class SearchHelper; */ - -/* struct GraphOptimalViewSerialized { */ -/* #ifdef LEGION_MAX_RETURN_SIZE */ -/* static const size_t buffer_size = LEGION_MAX_RETURN_SIZE - 8; */ -/* #else */ -/* static const size_t buffer_size = 1024 * 1024 - 8; */ -/* #endif */ -/* size_t total_bytes; */ -/* char data[buffer_size]; */ -/* }; */ - -/* class Graph { */ -/* public: */ -/* Graph() = default; */ -/* Graph(std::string const &logger_name); */ -/* Graph(std::shared_ptr const &logger); */ - -/* void add_edge(utils::Node const &srcOp, utils::Node const &dstOp, int - * srcIdx, int dstIdx); */ -/* utils::Node add_node(opmeta::OperatorParameters const &); */ -/* void add_edge(utils::MultiDiEdge const &e); */ -/* void remove_node(utils::Node const &, bool purge_edges = false); */ -/* void remove_edge(utils::MultiDiEdge const &e, bool remove_node_if_unused = - * true); */ -/* bool has_edge(utils::MultiDiEdge const &e) const; */ -/* void replace_subgraph(std::unordered_set const - * ¤tNodes, */ -/* Graph const &replaceWith); */ -/* Graph subgraph(std::unordered_set const &nodes) const; */ -/* void contract_out_node(opmeta::OperatorParameters const &); */ -/* float optimal_cost() const; */ -/* std::unordered_map optimal_views() - * const; */ -/* void remove_input_nodes(); */ -/* void duplicate_input_node(opmeta::OperatorParameters const &); */ -/* void duplicate_input_nodes(); */ -/* opmeta::OperatorParameters clone_node(opmeta::OperatorParameters const &); - */ -/* std::pair> */ -/* deduplicate_input_node(opmeta::OperatorParameters const &); */ -/* std::unordered_map - * deduplicate_input_nodes(); */ -/* opmeta::OperatorParameters declone_node(opmeta::OperatorParameters const - * &); */ - -/* size_t hash(void) const; */ -/* void print(void) const; */ -/* void print_dot() const; */ -/* void print_dot(std::ostream &) const; */ - -/* bool check_correctness(void); */ -/* bool has_loop(void); */ -/* //bool map_operators_to_layers(std::vector &layers) const; */ -/* //static GraphOptimalViewSerialized */ -/* // graph_optimize_task(Legion::Task const *task, */ -/* // std::vector const - * ®ions, */ -/* // Legion::Context ctx, */ -/* // Legion::Runtime *runtime); */ -/* /1* opmeta::OperatorParameters - * find_bottleneck_node(opmeta::OperatorParameters const &sink_node, *1/ */ -/* /1* opmeta::OperatorParameters const - * &source_node) const; *1/ */ -/* void print_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy) const; */ -/* void export_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy, */ -/* std::string const &out_filename) const; */ -/* void export_strategy_computation_graph( */ -/* std::unordered_map const - * &strategy, */ -/* DotFile &dot) const; */ - -/* /1* std::pair, std::unique_ptr> *1/ */ -/* /1* split_at_node(opmeta::OperatorParameters const &bottleneck) const; - * *1/ */ -/* /1* std::pair, std::unique_ptr> *1/ */ -/* /1* split_horizontal(opmeta::OperatorParameters const &source_node, - * opmeta::OperatorParameters const &sink_node) const; *1/ */ - -/* Graph reduced() const; */ - -/* opmeta::OperatorParameters find_sink_node() const; */ -/* opmeta::OperatorParameters find_source_node() const; */ -/* void reshape_output_tensor(opmeta::ParallelTensorShape const &shape); */ -/* std::unique_ptr */ -/* with_output_tensor_reshaped_to(opmeta::ParallelTensorShape const - * &shape) const; */ - -/* static Graph singleton(opmeta::OperatorParameters const &); */ -/* bool empty() const; */ - -/* template */ -/* T generic_optimal_cost() const; */ - -/* private: */ -/* void remove_inverse_parallel_ops(); */ -/* void replace_subgraph_with_nonempty( */ -/* std::unordered_set const ¤tNodes, - * Graph const &replaceWith); */ -/* private: */ -/* Graph(utils::AdjacencyMultiDiGraph const &, utils::bidict const &, std::shared_ptr const - * &); */ - -/* utils::AdjacencyMultiDiGraph g; */ -/* utils::bidict nodeMap; */ -/* std::shared_ptr logger; */ -/* }; */ - -/* struct GraphOptimizeResult { */ -/* tl::optional graph; */ -/* float cost; */ -/* std::unordered_map views; */ - -/* friend std::ostream &operator<<(std::ostream &, GraphOptimizeResult const - * &); */ -/* }; */ - -/* /1* namespace Utils { *1/ */ -/* /1* template <> *1/ */ -/* /1* struct GraphStructure { *1/ */ -/* /1* using G = FlexFlow::PCG::Graph; *1/ */ -/* /1* using graph_type = FlexFlow::PCG::Graph; *1/ */ -/* /1* using vertex_type = FlexFlow::PCG::Node; *1/ */ -/* /1* using edge_type = FlexFlow::PCG::Edge; *1/ */ - -/* /1* std::unordered_set get_nodes(G const &g) const { *1/ */ -/* /1* std::unordered_set nodes; *1/ */ -/* /1* for (auto const &kv : g.inEdges) { *1/ */ -/* /1* nodes.insert(kv.first); *1/ */ -/* /1* } *1/ */ -/* /1* for (auto const &kv : g.outEdges) { *1/ */ -/* /1* nodes.insert(kv.first); *1/ */ -/* /1* } *1/ */ - -/* /1* return nodes; *1/ */ -/* /1* } *1/ */ - -/* /1* std::unordered_set get_incoming_edges(G const &g, *1/ */ -/* /1* vertex_type const &n) - * const { *1/ */ -/* /1* if (g.inEdges.find(n) == g.inEdges.end()) { *1/ */ -/* /1* return {}; *1/ */ -/* /1* } else { *1/ */ -/* /1* return {g.inEdges.at(n).begin(), g.inEdges.at(n).end()}; *1/ */ -/* /1* } *1/ */ -/* /1* } *1/ */ - -/* /1* std::unordered_set get_outgoing_edges(G const &g, *1/ */ -/* /1* vertex_type const &n) - * const { *1/ */ -/* /1* if (g.outEdges.find(n) == g.outEdges.end()) { *1/ */ -/* /1* return {}; *1/ */ -/* /1* } else { *1/ */ -/* /1* return {g.outEdges.at(n).begin(), g.outEdges.at(n).end()}; *1/ */ -/* /1* } *1/ */ -/* /1* } *1/ */ - -/* /1* vertex_type get_src(G const &g, edge_type const &e) const { *1/ */ -/* /1* return e.srcOp; *1/ */ -/* /1* } *1/ */ - -/* /1* vertex_type get_dst(G const &g, edge_type const &e) const { *1/ */ -/* /1* return e.dstOp; *1/ */ -/* /1* } *1/ */ - -/* /1* void set_src(G const &g, edge_type &e, vertex_type const &n) const { - * *1/ */ -/* /1* e.srcOp = n; *1/ */ -/* /1* } *1/ */ - -/* /1* void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - * *1/ */ -/* /1* e.dstOp = n; *1/ */ -/* /1* } *1/ */ -/* /1* }; *1/ */ - -/* size_t dp_state_hash(Graph const *graph, */ -/* opmeta::OperatorParameters const &sink_node, */ -/* MachineView const &sink_view, */ -/* opmeta::OperatorParameters const &source_node, */ -/* MachineView const &source_view, */ -/* MachineResource const &resource); */ - -/* // template <> */ -/* // struct invalid_node> { */ -/* // using G = Graph; */ -/* // using Structure = GraphStructure; */ -/* // using vertex_type = typename Structure::vertex_type; */ -/* // */ -/* // vertex_type operator()() const { */ -/* // return vertex_type::INVALID_NODE; */ -/* // } */ -/* // }; */ -/* // */ -/* // template <> */ -/* // struct invalid_node, GraphStructure>> { - */ -/* // Node operator()() const { */ -/* // return Node::INVALID_NODE; */ -/* // } */ -/* // }; */ - -/* /1* } // namespace Utils *1/ */ -/* } // namespace ffc */ -/* } // namespace FlexFlow */ - -#endif diff --git a/lib/compiler/src/old/graph_structures.h b/lib/compiler/src/old/graph_structures.h deleted file mode 100644 index 8b921794e1..0000000000 --- a/lib/compiler/src/old/graph_structures.h +++ /dev/null @@ -1,269 +0,0 @@ -#ifndef _GRAPH_STRUCTURES_H -#define _GRAPH_STRUCTURES_H - -#include "basic_graph.h" - -namespace FlexFlow { -namespace PCG { -namespace Utils { - -template -struct ReverseStructure { - using graph_type = typename BaseStructure::graph_type; - using G = graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using edge_type = typename BaseStructure::edge_type; - - std::unordered_set get_nodes(G const &g) const { - return this->base.get_nodes(g); - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - return this->base.get_outgoing_edges(g, n); - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - return this->base.get_incoming_edges(g, n); - } - - vertex_type get_src(G const &g, edge_type const &e) const { - return this->base.get_dst(g, e); - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - return this->base.get_src(g, e); - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_dst(g, e, n); - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_src(g, e, n); - } - - BaseStructure base; -}; - -template -struct UndirectedEdge { - union Edge { - NotReversed not_reversed; - Reversed reversed; - - Edge() {} - }; - - bool is_reversed; - Edge edge; - - UndirectedEdge() {} - - bool operator==(UndirectedEdge const &other) const { - if (other.is_reversed != this->is_reversed) { - return false; - } - if (this->is_reversed) { - return this->edge.reversed == other.edge.reversed; - } else { - return this->edge.not_reversed == other.edge.not_reversed; - } - } -}; - -template > -struct UndirectedStructure { - using graph_type = typename BaseStructure::graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using not_reversed_edge_type = typename BaseStructure::edge_type; - using reversed_edge_type = - typename ReverseStructure::edge_type; - using edge_type = UndirectedEdge; - - std::unordered_set get_nodes(G const &g) const { - return this->base.get_nodes(g); - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - std::unordered_set incoming; - auto base_edges = this->base.get_incoming_edges(g, n); - auto reversed_edges = this->reversed.get_incoming_edges(g, n); - - for (auto const &e : base_edges) { - edge_type lifted; - lifted.is_reversed = false; - lifted.edge.not_reversed = e; - incoming.insert(lifted); - } - - for (auto const &e : reversed_edges) { - edge_type lifted; - lifted.is_reversed = true; - lifted.edge.reversed = e; - incoming.insert(lifted); - } - - return incoming; - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - std::unordered_set outgoing; - auto base_edges = this->base.get_outgoing_edges(g, n); - auto reversed_edges = this->reversed.get_outgoing_edges(g, n); - - for (auto const &e : base_edges) { - edge_type lifted; - lifted.is_reversed = false; - lifted.edge.not_reversed = e; - outgoing.insert(lifted); - } - - for (auto const &e : reversed_edges) { - edge_type lifted; - lifted.is_reversed = true; - lifted.edge.reversed = e; - outgoing.insert(lifted); - } - - return outgoing; - } - - vertex_type get_src(G const &g, edge_type const &e) const { - if (e.is_reversed) { - return this->reversed.get_src(g, e.edge.reversed); - } else { - return this->base.get_src(g, e.edge.not_reversed); - } - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - if (e.is_reversed) { - return this->reversed.get_dst(g, e.edge.reversed); - } else { - return this->base.get_dst(g, e.edge.not_reversed); - } - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - if (e.is_reversed) { - this->reversed.set_src(g, e.edge.reversed, n); - } else { - this->base.set_src(g, e.edge.not_reversed, n); - } - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - if (e.is_reversed) { - this->reversed.set_src(g, e.edge.reversed, n); - } else { - this->base.set_src(g, e.edge.not_reversed, n); - } - } - - BaseStructure base; - ReverseStructure reversed; -}; - -template > -struct invalid_node; - -template , - typename Invalid = invalid_node> -struct MultisourceGraphStructure { - using graph_type = typename BaseStructure::graph_type; - using vertex_type = typename BaseStructure::vertex_type; - using edge_type = typename BaseStructure::edge_type; - - std::unordered_set get_nodes(G const &g) const { - Invalid invalid; - - std::unordered_set nodes = this->base.get_nodes(g); - nodes.insert(invalid()); - return nodes; - } - - std::unordered_set get_incoming_edges(G const &g, - vertex_type const &n) const { - Invalid invalid; - - if (n == invalid()) { - return {}; - } - - std::unordered_set edges = this->base.get_incoming_edges(g, n); - if (edges.empty()) { - edge_type e; - this->base.set_src(g, e, invalid()); - this->base.set_dst(g, e, n); - return {e}; - } - - return edges; - } - - std::unordered_set get_outgoing_edges(G const &g, - vertex_type const &n) const { - Invalid invalid; - - if (n == invalid()) { - std::unordered_set edges; - for (auto const &node : this->base.get_nodes(g)) { - if (this->base.get_incoming_edges(g, node).empty()) { - edge_type e; - this->base.set_src(g, e, invalid()); - this->base.set_dst(g, e, node); - edges.insert(e); - } - } - return edges; - } - - return this->base.get_outgoing_edges(g, n); - } - - vertex_type get_src(G const &g, edge_type const &e) const { - return this->base.get_src(g, e); - } - - vertex_type get_dst(G const &g, edge_type const &e) const { - return this->base.get_dst(g, e); - } - - void set_src(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_src(g, e, n); - } - - void set_dst(G const &g, edge_type &e, vertex_type const &n) const { - this->base.set_dst(g, e, n); - } - - BaseStructure base; -}; -} // namespace Utils -} // namespace PCG -} // namespace FlexFlow - -namespace std { -using FlexFlow::PCG::Utils::UndirectedEdge; - -template -struct hash> { - size_t operator()(UndirectedEdge const &e) const { - size_t result; - result = std::hash()(e.is_reversed); - if (e.is_reversed) { - hash_combine(result, e.edge.reversed); - } else { - hash_combine(result, e.edge.not_reversed); - } - return result; - } -}; -} // namespace std - -#endif // _GRAPH_STRUCTURES_H diff --git a/lib/compiler/src/old/node.h b/lib/compiler/src/old/node.h deleted file mode 100644 index eb33a39ae7..0000000000 --- a/lib/compiler/src/old/node.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef _FLEXFLOW_FFC_NODE_H -#define _FLEXFLOW_FFC_NODE_H - -#include - -#include "op-attrs/op-attrs.h" -#include "tl/optional.hpp" - -namespace FlexFlow { -namespace ffc { - -struct Node { - Node() = delete; - Node(size_t guid, PCGOperatorAttrs const &op_params); - - std::string to_string(void) const; - - using AsTuple = - std::tuple &>; - using AsConstTuple = std::tuple const &>; - - AsTuple as_tuple(); - AsConstTuple as_tuple() const; - -public: - size_t guid; - PCGOperatorAttrs op_params; - tl::optional original_guid = tl::nullopt; -}; - -bool operator==(Node const &, Node const &); -bool operator!=(Node const &, Node const &); -bool operator<(Node const &, Node const &); - -} // namespace ffc -} // namespace FlexFlow - -namespace std { -template <> -struct hash<::FlexFlow::ffc::Node> { - size_t operator()(::FlexFlow::ffc::Node const &n) const; -}; -} // namespace std - -#endif diff --git a/lib/compiler/src/old/parallel_dim_mapping_record.h b/lib/compiler/src/old/parallel_dim_mapping_record.h deleted file mode 100644 index 8e2c265489..0000000000 --- a/lib/compiler/src/old/parallel_dim_mapping_record.h +++ /dev/null @@ -1,4 +0,0 @@ -#ifndef _FLEXFLOW_FFC_PARALLEL_DIM_MAPPING_RECORD_H -#define _FLEXFLOW_FFC_PARALLEL_DIM_MAPPING_RECORD_H - -#endif diff --git a/lib/compiler/src/old/search_helper.cc b/lib/compiler/src/old/search_helper.cc deleted file mode 100644 index 2e7eafa5fd..0000000000 --- a/lib/compiler/src/old/search_helper.cc +++ /dev/null @@ -1,525 +0,0 @@ -#include "search_helper.h" - -namespace FlexFlow { -namespace PCG { - -SearchHelper::SearchHelper() { - this->logger = std::unique_ptr(new RecursiveLogger("DP")); -} - -template -T SearchHelper::execute_sequence_split(std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - SequenceSplit const &bn) const { - return sequence_cost( - this->graph_cost(pre_graph.get(), source, bn, resources, true), - this->graph_cost(post_graph.get(), bn, sink, resources, false)); -} - -template -T SearchHelper::find_optimal_sequence_graph_time( - Graph const *g, - Node const &bn_node, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const { - std::unique_ptr pre_graph; - std::unique_ptr post_graph; - std::tie(pre_graph, post_graph) = g->split_at_node(bn_node); - - T optimal = this->infinity(); - - std::vector valid_views = - this->get_valid_machine_views(bn_node.op_params, resources); - // A Corner Case: - // If bn_node is a parallel_op and an input to sink_node, - // Add sink_node's view to the list, since sink_node's view - // may not be a valid view for resources, but UniFlow support - // this case since parallel_op does not trigger computation - if (is_parallel_op(bn_node.op_params)) { - bool found = false; - auto const &inList = g->inEdges.find(sink.node)->second; - for (auto const &e : inList) { - if (e.srcOp == bn_node) { - found = true; - break; - } - } - if (found) { - for (int j = 0; j < bn_node.ptr->numOutputs; j++) { - if (!bn_node.ptr->outputs[j]->is_valid_machine_view(sink.view)) { - found = false; - } - } - } - if (found) { - valid_views.push_back(sink.view); - } - } - - if (valid_views.empty()) { - return optimal; - } - - float optimal_cost = std::numeric_limits::infinity(); - MachineView best_view; - - for (MachineView const &bn_view : valid_views) { - float cost = this->execute_sequence_split( - pre_graph, post_graph, source, sink, resources, {bn_node, bn_view}); - - if (cost < optimal_cost) { - best_view = bn_view; - optimal_cost = cost; - } - } - - if (optimal_cost != std::numeric_limits::infinity()) { - optimal = this->execute_sequence_split( - pre_graph, post_graph, source, sink, resources, {bn_node, best_view}); - } - - check_matches_graph(g, optimal, sink.node); - - return optimal; -} - -template -T SearchHelper::execute_nonsequence_split( - std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - NonsequenceSplit const &split) const { - Graph const *first = first_graph.get(); - Graph const *second = second_graph.get(); - if (split.flip_graphs) { - std::swap(first, second); - } - switch (split.type) { - case SplitType::SEQUENTIAL: - this->logger->debug() << "Exploring sequential nonsequence split"; - return sequence_cost( - this->graph_cost(first, source, sink, resources, false), - this->graph_cost(second, source, sink, resources, false)); - case SplitType::VERTICAL: { - this->logger->debug() << "Exploring vertical nonsequence split (" - << split.param << ", " << split.flip_graphs << ")"; - MachineResource firstRes = resources, secondRes = resources; - firstRes.num_nodes = split.param; - secondRes.num_nodes = resources.num_nodes - split.param; - secondRes.start_gpu_id = - resources.start_gpu_id + resources.all_gpus_per_node * split.param; - - return parallel_cost( - this->graph_cost(first, source, sink, firstRes, false), - this->graph_cost(second, source, sink, secondRes, false)); - } - case SplitType::HORIZONTAL: { - this->logger->debug() << "Exploring horizontal nonsequence split (" - << split.param << ", " << split.flip_graphs << ")"; - MachineResource firstRes = resources, secondRes = resources; - firstRes.available_gpus_per_node = split.param; - secondRes.available_gpus_per_node = - resources.available_gpus_per_node - split.param; - secondRes.start_gpu_id = resources.start_gpu_id + split.param; - - return parallel_cost( - this->graph_cost(first, source, sink, firstRes, false), - this->graph_cost(second, source, sink, secondRes, false)); - } - default: - assert(false); - } -} - -template -T SearchHelper::find_optimal_nonsequence_graph_time( - Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const { - std::unique_ptr first_graph; - std::unique_ptr second_graph; - std::tie(first_graph, second_graph) = - g->split_horizontal(source.node, sink.node); - - std::vector potential_splits; - - for (int i = 1; i < resources.num_nodes; i++) { - potential_splits.push_back(NonsequenceSplit::vertical(i, false)); - potential_splits.push_back(NonsequenceSplit::vertical(i, true)); - } - for (int i = 1; i < resources.available_gpus_per_node; i++) { - potential_splits.push_back(NonsequenceSplit::horizontal(i, false)); - potential_splits.push_back(NonsequenceSplit::horizontal(i, true)); - } - - NonsequenceSplit best_split = NonsequenceSplit::sequential(); - float best_cost = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, best_split); - for (NonsequenceSplit const &split : potential_splits) { - float cost = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, split); - this->logger->debug() << "Found cost: " << cost; - - if (cost < best_cost) { - best_cost = cost; - best_split = split; - } - } - - switch (best_split.type) { - case SplitType::SEQUENTIAL: - this->logger->debug() << "Best split: SEQUENTIAL"; - break; - case SplitType::VERTICAL: - this->logger->debug() << "Best split: VERTICAL(" << best_split.param - << ", " << best_split.flip_graphs << ")"; - break; - case SplitType::HORIZONTAL: - this->logger->debug() << "Best split: HORIZONTAL(" << best_split.param - << ", " << best_split.flip_graphs << ")"; - break; - } - T optimal = this->execute_nonsequence_split( - first_graph, second_graph, source, sink, resources, best_split); - - check_matches_graph(g, optimal, sink.node); - - return optimal; -} - -std::vector SearchHelper::get_valid_machine_views( - Node const &node, MachineResource const &resource, bool log) const { - this->logger->info() << "Getting valid machine views for " - << node.to_string(); - return this->get_valid_machine_views(node.ptr, resource, log); -} - -std::vector SearchHelper::get_valid_machine_views( - Op const *op, MachineResource const &resource, bool log) const { - std::vector const *cached_op_views = NULL; - std::vector valid_views; - - auto const &iter = cached_operator_valid_views.find(op->op_guid); - if (iter != cached_operator_valid_views.end()) { - cached_op_views = iter->second.get(); - } else { - auto to_cache = std::unique_ptr>( - new std::vector()); - if (log) { - this->logger->info() << "Considering a total of " - << this->model->all_valid_views.size() - << " potential valid views"; - } - for (size_t i = 0; i < this->model->all_valid_views.size(); i++) { - bool valid = true; - for (int j = 0; j < op->numOutputs; j++) { - if (!op->outputs[j]->is_valid_machine_view( - this->model->all_valid_views[i])) { - valid = false; - { - MachineView const &view = this->model->all_valid_views[i]; - std::ostringstream oss; - oss << "[" << view.ndims << "]("; - for (int i = 0; i < view.ndims; i++) { - oss << view.dim[i] << "/" << view.stride[i]; - if (i != view.ndims - 1) { - oss << " "; - } - } - oss << ")"; - if (log) { - this->logger->info() << "Rejecting machine view: " << oss.str(); - } - } - break; - } - } - if (valid) { - { - MachineView const &view = this->model->all_valid_views[i]; - std::ostringstream oss; - oss << "[" << view.ndims << "]("; - for (int i = 0; i < view.ndims; i++) { - oss << view.dim[i] << "/" << view.stride[i]; - if (i != view.ndims - 1) { - oss << " "; - } - } - oss << ")"; - if (log) { - this->logger->info() << "Accepting machine view: " << oss.str(); - } - } - to_cache->push_back(this->model->all_valid_views[i]); - } - } - cached_operator_valid_views[op->op_guid] = std::move(to_cache); - cached_op_views = cached_operator_valid_views.at(op->op_guid).get(); - } - if (log) { - this->logger->info() << "Found " << cached_op_views->size() - << " cached op views"; - } - for (size_t i = 0; i < cached_op_views->size(); i++) { - MachineView view = (*cached_op_views)[i]; - if (view.device_type == MachineView::GPU) { - view.start_device_id = resource.start_gpu_id; - } else if (view.device_type == MachineView::CPU) { - view.start_device_id = resource.start_cpu_id; - } else { - assert(false); - } - if (resource.is_valid_machine_view(view)) { - valid_views.push_back(view); - } - } - return valid_views; -} - -template <> -bool SearchHelper::is_invalid(float const &cost) const { - return cost == std::numeric_limits::infinity(); -} - -template <> -bool SearchHelper::is_invalid( - GraphCostResult const &cost) const { - return cost.cost == std::numeric_limits::infinity(); -} - -/** - * @brief Asserts that the results of graph optimization are valid for the graph - * - * @param g the graph to check against - * @param r the results to check - * @param sink the sink node of the graph g - * @param include_sink whether or not to include the sink node - */ -template <> -void SearchHelper::check_matches_graph( - Graph const *g, GraphCostResult const &r, Node const &sink) const { - using FlexFlow::PCG::Utils::nodes; - - if (this->is_invalid(r)) { - return; - } - - std::unordered_set g_nodes = nodes(*g); - g_nodes.erase(sink); - - std::unordered_set r_nodes; - for (auto const &kv : r.views) { - r_nodes.insert(kv.first); - } - - assert(g_nodes == r_nodes); -} - -template <> -void SearchHelper::check_matches_graph(Graph const *g, - float const &r, - Node const &sink) const {} - -template <> -std::pair - SearchHelper::try_get_cost_from_cache(size_t hash) const { - if (this->cached_graph_costs.find(hash) == this->cached_graph_costs.end()) { - return {false, std::numeric_limits::infinity()}; - } else { - return {true, this->cached_graph_costs.at(hash)}; - } -} - -template <> -std::pair - SearchHelper::try_get_cost_from_cache(size_t hash) const { - return {false, GraphCostResult::invalid()}; -} - -template <> -void SearchHelper::try_cache_result(size_t hash, - float const &value) const { - this->logger->debug() << "cached_graph_costs[" << hash << "] = " << value; - this->cached_graph_costs[hash] = value; -} - -template <> -void SearchHelper::try_cache_result( - size_t hash, GraphCostResult const &value) const { - this->logger->debug() << "cached_graph_costs[" << hash << "=" << value.cost - << "]"; - this->cached_graph_costs[hash] = value.cost; -} - -template <> -float SearchHelper::infinity() const { - return std::numeric_limits::infinity(); -} - -template <> -GraphCostResult SearchHelper::infinity() const { - return {std::numeric_limits::infinity(), {}}; -} - -template <> -float SearchHelper::empty() const { - return 0.0f; -} - -template <> -GraphCostResult SearchHelper::empty() const { - return {0.0f, {}}; -} - -template -T SearchHelper::estimate_xfer_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink) const { - T result = this->empty(); - - if (source.node != Node::INVALID_NODE) { - auto const &inList = graph->inEdges.find(sink.node)->second; - float op_cost = 0.0f; - for (auto const &it2 : inList) { - assert(it2.srcOp == source.node); - assert(sink.node.ptr->inputs[it2.dstIdx]->is_valid_machine_view( - source.view)); - - float estimated_xfer_cost = this->model->simulator->estimate_xfer_cost( - sink.node.ptr, it2.dstIdx, source.view, sink.view); - // printf("Estimated xfer cost from %s to %s: %fms\n", - // source.node.ptr->name, sink.node.ptr->name, estimated_xfer_cost); - op_cost += estimated_xfer_cost; - } - this->add_operator_cost(source, op_cost, &result); - } else { - Node real_source = graph->find_source_node(); - assert(real_source.ptr->op_type == OP_INPUT); - this->add_operator_cost({real_source, MachineView::NO_VIEW}, 0.0f, &result); - } - - return result; -} - -template <> -void SearchHelper::add_operator_cost(NodeAssignment const &node, - float node_cost, - float *cost) const { - *cost += node_cost; -} - -template <> -void SearchHelper::add_operator_cost( - NodeAssignment const &node, float node_cost, GraphCostResult *cost) const { - cost->cost += node_cost; - cost->views[node.node] = node.view; -} - -template <> -float SearchHelper::get_cost(float const &f) const { - return f; -} - -template <> -float SearchHelper::get_cost( - GraphCostResult const &gcr) const { - return gcr.cost; -} - -template -T SearchHelper::graph_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - bool include_sink_compute_time) const { - TAG_ENTER(this->logger); - this->logger->debug() << "sink(" << sink.node.guid << ") " - << "sink.view(" << sink.view.ndims << " " - << sink.view.start_device_id << " " << sink.view.dim[0] - << ") " - << "source(" << source.node.guid << ") " - << "source.view(" << source.view.ndims << " " - << source.view.start_device_id << " " - << source.view.dim[0] << ") " - << "resources(" << resources.num_nodes << " " - << resources.start_gpu_id << " " - << resources.available_gpus_per_node << ")"; - if (this->model->config.profiling) { - graph->print_dot(); - } - - assert(graph->inEdges.find(sink.node) != graph->inEdges.end()); - if (source.node != Node::INVALID_NODE) { - assert(graph->outEdges.find(source.node) != graph->outEdges.end()); - } - - size_t hash = dp_state_hash( - graph, sink.node, sink.view, source.node, source.view, resources); - this->logger->spew() << "hash = " << hash; - - T result; - - std::pair from_cache = this->try_get_cost_from_cache(hash); - if (from_cache.first) { - // cached_graph_costs does not include sink_compute_time - result = from_cache.second; - } else { - if (graph->inEdges.size() <= 2) { - result = this->estimate_xfer_cost(graph, source, sink); - this->logger->debug() - << "Estimated xfer cost is " << this->get_cost(result); - } else { - Node bn_node = graph->find_bottleneck_node(sink.node, source.node); - if (bn_node != Node::INVALID_NODE) { - // We found a bottleneck node - this->logger->debug() << "Found bn_node = " << bn_node.guid; - - result = this->find_optimal_sequence_graph_time( - graph, - bn_node, - {source.node, source.view}, - {sink.node, sink.view}, - resources); - } else { - // sink node must have multiple branches - // otherwise we should not be here - assert(graph->inEdges.find(sink.node)->second.size() > 1); - - result = this->find_optimal_nonsequence_graph_time( - graph, - {source.node, source.view}, - {sink.node, sink.view}, - resources); - } - } - - this->try_cache_result(hash, result); - } - - check_matches_graph(graph, result, sink.node); - - if (include_sink_compute_time) { - CostMetrics metrics = - this->model->simulator->measure_operator_cost(sink.node.ptr, sink.view); - this->logger->debug() << "Sink node cost: " - << "forward(" << metrics.forward_time << ") " - << "backward(" << metrics.backward_time << ") " - << "sync(" << metrics.sync_time << ")"; - this->add_operator_cost(sink, - metrics.forward_time + metrics.backward_time + - metrics.sync_time, - &result); - } - - return result; -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/search_helper.h b/lib/compiler/src/old/search_helper.h deleted file mode 100644 index 95350ce6af..0000000000 --- a/lib/compiler/src/old/search_helper.h +++ /dev/null @@ -1,122 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SRC_SEARCH_HELPER_H -#define _FLEXFLOW_FFC_SRC_SEARCH_HELPER_H - -#include "graph.h" -#include "split_types.h" - -namespace FlexFlow { - -struct GraphCostResult { - float cost; - std::unordered_map views; - - static GraphCostResult invalid(); - - bool operator<(GraphCostResult const &other) const; - - friend std::ostream &operator<<(std::ostream &, GraphCostResult const &); -}; - -template -T sequence_cost(T const &first, T const &second); - -template -T parallel_cost(T const &first, T const &second); - -class SearchHelper { -public: - SearchHelper(); - - template - T graph_cost(Graph const *graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - bool include_sink_compute_time) const; - template - T find_optimal_sequence_graph_time(Graph const *g, - Node const &bottleneck_node, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const; - template - T find_optimal_nonsequence_graph_time(Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources) const; - /* void find_optimal_nonsequence_graph_views(Graph const *g, */ - /* NodeAssignment const &source, */ - /* NodeAssignment const &sink, */ - /* MachineResource const &resources, - */ - /* float optimal_cost, */ - /* std::unordered_map& optimal_views) const; */ - std::vector - get_valid_machine_views(Node const &node, - MachineResource const &resource, - bool log = false) const; - std::vector - get_valid_machine_views(PCGOperatorAttrs const &op, - MachineResource const &resource, - bool log = false) const; - - template - std::pair try_get_cost_from_cache(size_t hash) const; - - template - void try_cache_result(size_t hash, T const &value) const; - - template - T infinity() const; - - template - T empty() const; - - template - bool is_invalid(T const &) const; - - template - T estimate_xfer_cost(Graph const *g, - NodeAssignment const &source, - NodeAssignment const &sink) const; - - template - void add_operator_cost(NodeAssignment const &, float, T *) const; - - template - float get_cost(T const &) const; - - template - void check_matches_graph(Graph const *, T const &, Node const &) const; - -public: - mutable std::unique_ptr logger; - -private: - template - T execute_nonsequence_split(std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - NonsequenceSplit const &split) const; - - template - T execute_sequence_split(std::unique_ptr const &first_graph, - std::unique_ptr const &second_graph, - NodeAssignment const &source, - NodeAssignment const &sink, - MachineResource const &resources, - SequenceSplit const &split) const; - -private: - mutable std::unordered_map cached_graph_costs; - mutable std::unordered_map>> - cached_operator_valid_views; -}; - -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/simplification.cc b/lib/compiler/src/old/simplification.cc deleted file mode 100644 index 18fc2fb71a..0000000000 --- a/lib/compiler/src/old/simplification.cc +++ /dev/null @@ -1,189 +0,0 @@ -#include "simplification.h" -#include "spdlog/spdlog.h" -#include - -namespace FlexFlow { -namespace PCG { - -Simplifier::Simplifier(std::string const &logger_name) - : logger(spdlog::get(logger_name)) {} - -void Simplifier::simplify_parallel_ops() { - logger->debug("Trying to simplify parallel ops"); - - /* using FlexFlow::PCG::Utils::nodes; */ - /* using FlexFlow::PCG::Utils::predecessor; */ - /* using FlexFlow::PCG::Utils::predecessors; */ - /* using FlexFlow::PCG::Utils::successor; */ - - std::queue work_queue; - for (Node const &node : nodes(*this)) { - if (node.ptr->is_parallel_op()) { - work_queue.push(node); - } - } - - while (!work_queue.empty()) { - Node node = work_queue.front(); - log_simplify.debug() << "Trying to simplify starting from " - << node.to_string(); - work_queue.pop(); - - auto opt_succ = successor(*this, node); - if (!opt_succ.has_value()) { - log_simplify.debug() << "Skipping because does not have single successor"; - continue; - } - Node succ = opt_succ.value(); - if (!succ.ptr->is_parallel_op()) { - log_simplify.debug() << "Skipping because successor is not a parallel op"; - continue; - } - - std::vector node_parallel_op_info, - successor_parallel_op_info; - ((ParallelOp *)node.ptr)->append_parallel_op_info(node_parallel_op_info); - ((ParallelOp *)succ.ptr) - ->append_parallel_op_info(successor_parallel_op_info); - ParallelOpJoinResult result = try_join_parallel_ops( - node_parallel_op_info.front(), successor_parallel_op_info.front()); - - if (!result.join_did_succeed) { - log_simplify.debug() << "Skipping because join did not succeed"; - continue; - } - log_simplify.debug() << "Did join nodes"; - log_simplify.debug() << " " << node.to_string(); - log_simplify.debug() << " " << succ.to_string(); - - for (Node const &p : predecessors(*this, node)) { - if (p.ptr->is_parallel_op()) { - work_queue.push(p); - } - } - - Graph new_g(this->model); - if (result.op.has_value()) { - Node new_op = this->model->get_or_create_parallel_op_node( - node.ptr->inputs[0], result.op.value()); - work_queue.push(new_op); - new_g.add_node(new_op); - } - this->replace_subgraph({node, succ}, new_g); - } - log_simplify.debug() << "Finished simplifying parallel ops"; -} - -void Graph::simplify(SimplificationSettings const &settings) { - // Simplify the graph by eliminating reverse parallel ops - // and fusing multiple parallel ops - // old graph: e1->n1->e2->n2->en - // new graph: e1->new_node->en - // TODO: temporarily disabled graph simplification - if (settings.simplify_parallel_ops) { - this->simplify_parallel_ops(); - } - if (settings.fuse_parallel_ops) { - bool simplify = true; - while (simplify) { - simplify = false; - for (auto const &it : this->inEdges) { - if (it.first.ptr == NULL) { - continue; - } - if (it.first.ptr->is_parallel_op()) { - Node n2 = it.first; - assert(it.second.size() == 1); - Edge e2 = *it.second.begin(); - Node n1 = e2.srcOp; - // Check that n1 is a parallel op - // Check that n1 must have a single out edge - if (n1.ptr->is_parallel_op() && - this->outEdges.find(n1)->second.size() == 1) { - // merge n1 and n2 - std::vector parallel_ops; - ((ParallelOp *)n1.ptr)->append_parallel_op_info(parallel_ops); - ((ParallelOp *)n2.ptr)->append_parallel_op_info(parallel_ops); - Node new_node = model->get_or_create_fused_parallel_node( - n1.ptr->inputs[0], parallel_ops); - auto const &inList = this->inEdges.find(n1)->second; - assert(inList.size() == 1); - Edge e1 = *inList.begin(); - // Update graph by adding edges - this->add_edge(e1.srcOp, new_node, e1.srcIdx, 0); - this->remove_edge(e1); - this->remove_edge(e2); - // make a copy of outList - if (this->outEdges.find(n2) != this->outEdges.end()) { - auto const outList = this->outEdges.find(n2)->second; - for (auto const &e : outList) { - this->add_edge(new_node, e.dstOp, 0, e.dstIdx); - this->remove_edge(e); - } - } - simplify = true; - } - } - if (simplify) { - break; - } - } - } - } - - if (settings.remove_trailing_parallel_ops) { - // Remove final parallel ops - std::vector candidates; - for (auto const &it : this->outEdges) { - if (it.second.size() == 0 && it.first.ptr->op_type != OP_REDUCTION && - it.first.ptr->op_type != OP_FUSED_PARALLEL && - it.first.ptr->is_parallel_op()) { - candidates.push_back(it.first); - } - } - size_t index = 0; - while (index < candidates.size()) { - Node parallel_op = candidates[index++]; - auto const &inList = this->inEdges.find(parallel_op)->second; - assert(inList.size() == 1); - Edge e = *inList.begin(); - this->remove_edge(e); - if (this->outEdges.find(e.srcOp)->second.size() == 0 && - e.srcOp.ptr->is_parallel_op()) { - candidates.push_back(e.srcOp); - } - } - } - - if (settings.remove_noops) { - // Remove NoOps - std::vector noop_nodes; - for (auto const &it : this->inEdges) { - if (it.first.ptr == NULL) { - continue; - } - if (it.first.ptr->op_type == OP_NOOP) { - noop_nodes.push_back(it.first); - } - } - size_t index = 0; - while (index < noop_nodes.size()) { - Node noop = noop_nodes[index++]; - auto const &inList = this->inEdges.find(noop)->second; - assert(inList.size() == 1); - Edge in_edge = *inList.begin(); - // make a copy of outList - if (this->outEdges.find(noop) != this->outEdges.end()) { - auto const outList = this->outEdges.find(noop)->second; - for (auto const &e : outList) { - this->add_edge(in_edge.srcOp, e.dstOp, in_edge.srcIdx, e.dstIdx); - this->remove_edge(e); - } - } - this->remove_edge(in_edge); - } - } -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/simplification.h b/lib/compiler/src/old/simplification.h deleted file mode 100644 index d83c16eb91..0000000000 --- a/lib/compiler/src/old/simplification.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SIMPLIFICATION_H -#define _FLEXFLOW_FFC_SIMPLIFICATION_H - -#include "graph.h" -#include "spdlog/spdlog.h" -#include - -namespace FlexFlow { -namespace PCG { - -struct SimplificationSettings { - bool simplify_parallel_ops = false; - bool fuse_parallel_ops = false; - bool remove_trailing_parallel_ops = false; - bool remove_noops = false; -}; - -class Simplifier { -public: - Simplifier(std::string const &logger_name); - - Graph const &simplify(SimplificationSettings const &, Graph const &); - -private: - void simplify_parallel_ops(); - -private: - std::shared_ptr logger; -}; - -} // namespace PCG -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/split_types.cc b/lib/compiler/src/old/split_types.cc deleted file mode 100644 index e9648344d4..0000000000 --- a/lib/compiler/src/old/split_types.cc +++ /dev/null @@ -1,36 +0,0 @@ -#include "split_types.h" - -namespace FlexFlow { -namespace PCG { - -/*static*/ -NonsequenceSplit NonsequenceSplit::sequential() { - NonsequenceSplit s; - s.type = SplitType::SEQUENTIAL; - s.flip_graphs = false; - - return s; -} - -/*static*/ -NonsequenceSplit NonsequenceSplit::vertical(int param, bool flip_graphs) { - NonsequenceSplit s; - s.type = SplitType::VERTICAL; - s.param = param; - s.flip_graphs = flip_graphs; - - return s; -} - -/*static*/ -NonsequenceSplit NonsequenceSplit::horizontal(int param, bool flip_graphs) { - NonsequenceSplit s; - s.type = SplitType::HORIZONTAL; - s.param = param; - s.flip_graphs = flip_graphs; - - return s; -} - -} // namespace PCG -} // namespace FlexFlow diff --git a/lib/compiler/src/old/split_types.h b/lib/compiler/src/old/split_types.h deleted file mode 100644 index 3c49ad5b7a..0000000000 --- a/lib/compiler/src/old/split_types.h +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef _FLEXFLOW_FFC_SPLIT_TYPES_H -#define _FLEXFLOW_FFC_SPLIT_TYPES_H - -#include "node.h" -#include "pcg/machine_view.h" - -namespace FlexFlow { -namespace PCG { - -enum class SplitType { SEQUENTIAL, VERTICAL, HORIZONTAL }; - -struct NonsequenceSplit { - SplitType type; - int param; - bool flip_graphs; - - static NonsequenceSplit sequential(); - static NonsequenceSplit vertical(int param, bool flip_graphs); - static NonsequenceSplit horizontal(int param, bool flip_graphs); -}; - -struct NodeAssignment { - Node node; - MachineView view; -}; - -using SequenceSplit = NodeAssignment; - -} // namespace PCG -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/old/substitution.cc b/lib/compiler/src/old/substitution.cc deleted file mode 100644 index 9f8381093c..0000000000 --- a/lib/compiler/src/old/substitution.cc +++ /dev/null @@ -1,3733 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "substitution.h" -#include "graph.h" -#include "graph_structures.h" -#include "op-meta/op-meta.h" -#include "parallel_ops/combine.h" -#include "parallel_ops/fused_parallel_op.h" -#include "parallel_ops/partition.h" -#include "parallel_ops/reduction.h" -#include "parallel_ops/replicate.h" -#include "utils/dot/dot_file.h" -#include -#include - -using namespace ::FlexFlow::substitutions; - -namespace FlexFlow { -namespace ffc { - -const TensorX TensorX::NO_TX = TensorX(); - -bool TensorX::operator==(TensorX const &other) const { - return this->op == other.op && this->idx == other.idx; -} - -bool TensorX::operator!=(TensorX const &other) const { - return !this->operator==(other); -} - -Rule create_combine_inception(int num_convs, int num_dims, int num_parts); -Rule create_combine_concat(int num_inputs, int num_dims, int num_parts); -Rule create_replicate_linear_combine(int num_dims, - int num_parts, - ActiMode activation, - bool use_bias); -Rule create_partition_linear_combine(int num_dims, - int num_parts, - ActiMode activation, - bool use_bias); -Rule create_partition_conv2d_combine(int num_dims, int num_parts); -Rule create_partition_attention_combine(int num_heads, int num_parts); -Rule create_replicate_attention_reduce(int num_heads, int num_parts); -Rule create_partition_add_combine(int parallel_dim, int num_parts); -Rule create_partition_relu_combine(int parallel_dim, int num_parts); -Rule create_partition_concat_combine(int num_inputs, - int concat_dim, - int parallel_dim, - int num_parts); -Rule create_partition_softmax_combine(int softmax_dim, - int part_dim, - int num_parts); -Rule leading_relu_branch_combine(int parallel_dim, - int num_parts, - int num_combines); -Rule leading_relu_branch_partition(int parallel_dim, - int num_parts, - int num_partitions); -Rule create_linear_relu_merge(int num_dims, bool use_bias); - -PMConstraint::PMConstraint(Compare c, PMParameter p, int v) - : comp(c), para(p), value(v) {} - -TNConstraint::TNConstraint(Compare c, TNParameter p, DIMParameter d, int v) - : singlePara(true), comp(c), para1(p), dim1(d), value(v) {} - -TNConstraint::TNConstraint( - Compare c, TNParameter p1, DIMParameter d1, TNParameter p2, DIMParameter d2) - : singlePara(false), comp(c), para1(p1), para2(p2), dim1(d1), dim2(d2) {} - -tl::optional TensorX::to_tensor(GraphXfer const *xfer) const { - if (op != NULL) { - assert(op->mapOp.ptr != NULL); - return op->mapOp.ptr->outputs[idx]; - } else { - auto const &it = xfer->mappedInputs.find(idx); - if (it == xfer->mappedInputs.end()) { - return tl::nullopt; - } - assert(it != xfer->mappedInputs.end()); - Node op = it->second.first; - int outIdx = it->second.second; - return op.ptr->outputs[outIdx]; - } -} - -OpX::OpX(const OperatorType _type, - int num_inputs, - int num_outputs, - TensorX const &input0, - TensorX const &input1, - TensorX const &input2, - TensorX const &input3) - : type(_type), mapOp(Node::INVALID_NODE), matchOpX(NULL) { - TensorX all_inputs[MAX_NUM_INPUTS]; - all_inputs[0] = input0; - all_inputs[1] = input1; - all_inputs[2] = input2; - all_inputs[3] = input3; - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(all_inputs[i]); - } - for (int i = 0; i < num_outputs; i++) { - TensorX out(this, i); - outputs.push_back(out); - } -} - -OpX::OpX(const OperatorType _type, - int num_inputs, - int num_outputs, - TensorX const *input_array) - : type(_type), mapOp(Node::INVALID_NODE), matchOpX(NULL) { - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(input_array[i]); - } - for (int i = 0; i < num_outputs; i++) { - TensorX out(this, i); - outputs.push_back(out); - } -} - -bool OpX::add_pm_constraint(Compare comp, PMParameter para, int value) { - PMConstraint pmc(comp, para, value); - pmConstraints.push_back(pmc); - return true; -} - -bool OpX::add_input_constraint(Compare comp, - TNParameter para, - DIMParameter dim, - int value) { - TNConstraint tnc(comp, para, dim, value); - tnConstraints.push_back(tnc); - return true; -} - -bool OpX::add_input_constraint(Compare comp, - TNParameter para1, - DIMParameter dim1, - TNParameter para2, - DIMParameter dim2) { - TNConstraint tnc(comp, para1, dim1, para2, dim2); - tnConstraints.push_back(tnc); - return true; -} - -bool OpX::get_pm_constraint(PMParameter para, int &value) const { - for (size_t i = 0; i < pmConstraints.size(); i++) { - if ((pmConstraints[i].comp == COMPARE_EQ) && - (pmConstraints[i].para == para)) { - value = pmConstraints[i].value; - return true; - } - } - return false; -} - -GraphXfer::GraphXfer(FFModel *_model) : model(_model), tensorId(10) {} - -TensorX GraphXfer::new_tensor(void) { - TensorX t; - t.op = NULL; - t.idx = tensorId++; - return t; -} - -bool GraphXfer::map_output(TensorX const &src, TensorX const &dst) { - mappedOutputs[src] = dst; - return true; -} - -bool GraphXfer::can_match(OpX *srcOp, Node const &op, Graph const *graph) { - if (srcOp->type != op.ptr->op_type) { - return false; - } - // check num input tensors - if ((int)srcOp->inputs.size() != op.ptr->numInputs) { - return false; - } - // check pmConstraints - for (size_t i = 0; i < srcOp->pmConstraints.size(); i++) { - PMConstraint pmc = srcOp->pmConstraints[i]; - int actValue = 0; - assert(op.ptr->get_int_parameter(pmc.para, &actValue)); - // printf("pmc[%d] para(%d) comp(%d) value(%d) actValue(%d)\n", - // i, pmc.para, pmc.comp, pmc.value, actValue); - switch (pmc.comp) { - case COMPARE_EQ: { - if (actValue != pmc.value) { - return false; - } - break; - } - case COMPARE_NE: { - if (actValue == pmc.value) { - return false; - } - break; - } - case COMPARE_LT: { - if (actValue >= pmc.value) { - return false; - } - break; - } - case COMPARE_LE: { - if (actValue > pmc.value) { - return false; - } - break; - } - case COMPARE_GT: { - if (actValue <= pmc.value) { - return false; - } - break; - } - case COMPARE_GE: { - if (actValue < pmc.value) { - return false; - } - break; - } - default: - assert(false); - } - } - // check inputs - std::map> newMapInputs; - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // input tensor - std::multimap>::const_iterator it; - it = mappedInputs.find(in.idx); - if (it != mappedInputs.end()) { - Node mappedOp = it->second.first; - int mappedIdx = it->second.second; - if (!(graph->has_edge(mappedOp, op, mappedIdx, i))) { - return false; - } - } else { - std::map>::const_iterator newit; - newit = newMapInputs.find(in.idx); - if (newit != newMapInputs.end()) { - Node mappedOp = newit->second.first; - int mappedIdx = newit->second.second; - if (!(graph->has_edge(mappedOp, op, mappedIdx, i))) { - return false; - } - } else { - auto const &list = graph->inEdges.find(op)->second; - for (auto const &e : list) { - if (e.dstIdx == (int)i) { - newMapInputs.insert( - std::make_pair(in.idx, std::make_pair(e.srcOp, e.srcIdx))); - } - } - } - // Do nothing when we check the match - /* mapped in.idx to an op - std::set list = graph->inEdges.find(op)->second; - std::set::const_iterator it2; - for (it2 = list.begin(); it2 != list.end(); it2++) { - Edge e = *it2; - if (e.dstIdx == i) - mappedInputs[in.idx] = std::make_pair(e.srcOp, e.srcIdx); - }*/ - } - } else { - // intermediate tensor - assert(in.op->mapOp != Node::INVALID_NODE); - if (!(graph->has_edge(in.op->mapOp, op, in.idx, i))) { - return false; - } - } - } - // check tnConstraints - for (size_t i = 0; i < srcOp->tnConstraints.size(); i++) { - TNConstraint tnc = srcOp->tnConstraints[i]; - int actValue = 0, expValue = 0; - if (tnc.singlePara) { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - expValue = tnc.value; - } else { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - assert(op.ptr->get_tensor_parameter(tnc.para2, tnc.dim2, &expValue)); - } - switch (tnc.comp) { - case COMPARE_EQ: { - if (actValue != expValue) { - return false; - } - break; - } - case COMPARE_NE: { - if (actValue == expValue) { - return false; - } - break; - } - case COMPARE_LT: { - if (actValue >= expValue) { - return false; - } - break; - } - case COMPARE_LE: { - if (actValue > expValue) { - return false; - } - break; - } - case COMPARE_GT: { - if (actValue <= expValue) { - return false; - } - break; - } - case COMPARE_GE: { - if (actValue < expValue) { - return false; - } - break; - } - default: - assert(false); - } - } - return true; -} - -void GraphXfer::match(OpX *srcOp, Node const &op, Graph const *graph) { - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // Update mappedInputs - auto const &list = graph->inEdges.find(op)->second; - for (auto const &e : list) { - if (e.dstIdx == (int)i) { - mappedInputs.insert( - std::make_pair(in.idx, std::make_pair(e.srcOp, e.srcIdx))); - } - } - } - } - // Map srcOp to Op - srcOp->mapOp = op; - mappedOps[op] = srcOp; -} - -void GraphXfer::unmatch(OpX *srcOp, Node const &op, Graph const *graph) { - for (size_t i = 0; i < srcOp->inputs.size(); i++) { - log_xfer_matches.spew() << "umatch iteration " << i; - TensorX in = srcOp->inputs[i]; - if (in.op == NULL) { - // Update mappedInputsa - std::multimap>::iterator it; - log_xfer_matches.spew() << "Starting find"; - it = mappedInputs.find(in.idx); - log_xfer_matches.spew() << "Finished find"; - if (it != mappedInputs.end()) { - mappedInputs.erase(it); - } - } - } - log_xfer_matches.spew() << "Finished the unmatch loop"; - // Unmap op - mappedOps.erase(op); - srcOp->mapOp.guid = 0; - srcOp->mapOp.ptr = NULL; - log_xfer_matches.spew() << "Returning from unmatch"; -} - -GraphXferMatch::GraphXferMatch(GraphXfer const *xfer) : xfer(xfer) {} - -void GraphXferMatch::add_mapping(Node const &node, OpX *opx) { - this->nodeToOpX[node] = opx; - this->opXToNode[opx] = node; -} - -void GraphXferMatch::add_mapping(OpX *opx, Node const &node) { - this->add_mapping(node, opx); -} - -void GraphXferMatch::add_output_mapping(TensorX const &src, - TensorX const &dst) { - this->mappedOutputs[src] = dst; -} - -OpX *GraphXferMatch::at(Node const &n) const { - return this->nodeToOpX.at(n); -} - -Node GraphXferMatch::at(OpX *opx) const { - return this->opXToNode.at(opx); -} - -void GraphXferMatch::set_graph(Graph const *g) { - this->graph_hash = g->hash(); -} - -bool GraphXferMatch::containsNode(Graph const *g, Node const &n) const { - assert(g->hash() == this->graph_hash); - - return this->nodeToOpX.find(n) != this->nodeToOpX.end(); -} - -bool GraphXferMatch::containsEdge(Graph const *g, Edge const &e) const { - assert(g->hash() == this->graph_hash); - - bool contains_src = this->containsNode(g, e.srcOp); - bool contains_dst = this->containsNode(g, e.dstOp); - - return contains_src && contains_dst; -} - -GraphXfer const *GraphXferMatch::get_xfer() const { - return this->xfer; -} - -std::unordered_set GraphXferMatch::get_nodes() const { - std::unordered_set nodes; - for (auto const &kv : nodeToOpX) { - nodes.insert(kv.first); - } - - return nodes; -} - -GraphXferMatch GraphXfer::get_match_record(Graph const *g) const { - GraphXferMatch match(this); - - for (auto const &kv : this->mappedOps) { - match.add_mapping(kv.first, kv.second); - } - - for (auto const &kv : this->mappedOutputs) { - match.add_output_mapping(kv.first, kv.second); - } - - match.set_graph(g); - - return match; -} - -void GraphXfer::find_matches(Graph const *graph, - std::vector &matches) { - this->find_matches(0, graph, matches); -} - -void GraphXfer::find_matches(int depth, - Graph const *graph, - std::vector &matches) { - log_xfer_matches.spew() << "find_matches at depth: " << depth; - if (depth >= (int)srcOps.size()) { - log_xfer_matches.spew() << "Achieved adequate depth"; - // Create dst operators - bool pass = true; - for (OpX *dstOp : this->dstOps) { - pass &= create_new_operator(dstOp, dstOp->mapOp); - if (!pass) { - break; - } - } - log_xfer_matches.spew() << "Completed create dst operators"; - if (!pass) { - log_xfer_matches.spew() << "Did not pass. Returning."; - return; - } - log_xfer_matches.spew() << "Checking external edges"; - // Check that output tensors with external edges are mapped - for (auto const &opIt : mappedOps) { - auto const &list = graph->outEdges.at(opIt.first); - for (auto const &e : list) { - if (mappedOps.find(e.dstOp) == mappedOps.end()) { - // dstOp is external, (srcOp, srcIdx) must be in mappedOutputs - TensorX srcTen; - srcTen.op = opIt.second; - srcTen.idx = e.srcIdx; - if (mappedOutputs.find(srcTen) == mappedOutputs.end()) { - pass = false; - return; - } - } - } - } - log_xfer_matches.spew() << "Completed checking external edges"; - // Generate a new graph by applying xfer rule - log_xfer_matches.spew() << "Creating new graph"; - SimplificationSettings - settings; // leave everything disabeld since we don't care about cost - Graph *newGraph = this->create_new_graph(graph, settings); - log_xfer_matches.spew() << "Completed creating new graph"; - - // Check that the new graph should not have any loop - log_xfer_matches.spew() << "Checking for loop"; - if (newGraph->has_loop()) { - printf("Found a new graph with LOOP!!!!\n"); - newGraph->print(); - delete newGraph; - return; - } - log_xfer_matches.spew() << "Finished checking for loop"; - // TODO: remove me for better performance - log_xfer_matches.spew() << "Checking correctness"; - assert(newGraph->check_correctness()); - log_xfer_matches.spew() << "Finished checking correctness"; - log_xfer_matches.spew() << "Getting match record"; - GraphXferMatch match_record = this->get_match_record(graph); - log_xfer_matches.spew() << "Finished getting match record"; - matches.push_back(match_record); - } else { - OpX *srcOp = srcOps[depth]; - for (auto const &it : graph->inEdges) { - log_xfer_matches.spew() << "Exploring node " << it.first.to_string(); - // printf("can_match(%d)\n", can_match(srcOp, it->first, graph)); - if (can_match(srcOp, it.first, graph) && - (mappedOps.find(it.first) == mappedOps.end())) { - Node op = it.first; - // Check mapOutput - this->match(srcOp, op, graph); - this->find_matches(depth + 1, graph, matches); - log_xfer_matches.spew() << "Completed find matches. Unmatching"; - this->unmatch(srcOp, op, graph); - log_xfer_matches.spew() << "Finished unmatching"; - } - } - } -} - -template -void GraphXfer::run( - int depth, - Graph *graph, - std::priority_queue, GraphComparator> - &candidates, - std::unordered_set &hashmap, - float threshold, - int maxNumOps, - SimplificationSettings const &simplification_settings, - int &num_matches_found, - int &num_matches_rejected) { - // printf("run: depth(%d) srcOps.size(%zu) graph.size(%zu) candidates(%zu)\n", - // depth, srcOps.size(), graph->inEdges.size(), candidates.size()); - if (depth >= (int)srcOps.size()) { - // Create dst operators - bool pass = true; - for (OpX *dstOp : this->dstOps) { - if (pass) { - pass &= create_new_operator(dstOp, dstOp->mapOp); - } - } - if (!pass) { - return; - } - // Check that output tensors with external edges are mapped - for (auto const &opIt : mappedOps) { - auto const &list = graph->outEdges[opIt.first]; - for (auto const &e : list) { - if (mappedOps.find(e.dstOp) == mappedOps.end()) { - // dstOp is external, (srcOp, srcIdx) must be in mappedOutputs - TensorX srcTen; - srcTen.op = opIt.second; - srcTen.idx = e.srcIdx; - if (mappedOutputs.find(srcTen) == mappedOutputs.end()) { - pass = false; - return; - } - } - } - } - // Generate a new graph by applying xfer rule - log_xfers.spew() << "Found a match for xfer: " << this->get_name(); - num_matches_found++; - Graph *newGraph = this->create_new_graph(graph, simplification_settings); - // Check that the new graph should not have any loop - if (newGraph->has_loop()) { - printf("Found a new graph with LOOP!!!!\n"); - newGraph->print(); - delete newGraph; - return; - } - // TODO: remove me for better performance - assert(newGraph->check_correctness()); - if (newGraph->optimal_cost() < threshold && - (int)newGraph->inEdges.size() < maxNumOps) { - if (hashmap.find(newGraph->hash()) == hashmap.end()) { - hashmap.insert(newGraph->hash()); - log_xfers.spew() << "Found new candidate"; - // newGraph->print_dot(); - candidates.push(newGraph); - } - } else { - num_matches_rejected++; - delete newGraph; - } - } else { - OpX *srcOp = srcOps[depth]; - for (auto const &it : graph->inEdges) { - // printf("can_match(%d)\n", can_match(srcOp, it->first, graph)); - if (can_match(srcOp, it.first, graph) && - (mappedOps.find(it.first) == mappedOps.end())) { - Node op = it.first; - // Check mapOutput - match(srcOp, op, graph); - run(depth + 1, - graph, - candidates, - hashmap, - threshold, - maxNumOps, - simplification_settings, - num_matches_found, - num_matches_rejected); - unmatch(srcOp, op, graph); - } - } - } -} - -void Graph::reshape_output_tensor(ParallelTensorShape const &desired_shape) { - Node output_node = this->find_sink_node(); - - assert(output_node.ptr->numOutputs == 1); - ParallelTensor output_tensor = output_node.ptr->outputs[0]; - - assert(output_tensor->num_dims == desired_shape.num_dims); - - for (int i = 0; i < output_tensor->num_dims; i++) { - int current_size = output_tensor->dims[i].size; - int current_degree = output_tensor->dims[i].degree; - - int desired_size = desired_shape.dims[i].size; - int desired_degree = desired_shape.dims[i].degree; - - assert(current_size == desired_size); - - if (current_degree < desired_degree) { - // we need to partition - assert(desired_degree % current_degree == 0); - int partition_factor = desired_degree / current_degree; - - Node partition_node = model->get_or_create_node( - output_tensor, {i /*legion_dim*/, partition_factor}); - this->add_edge(output_node, partition_node, 0, 0); - - output_node = partition_node; - output_tensor = partition_node.ptr->outputs[0]; - current_degree *= partition_factor; - - } else if (current_degree > desired_degree) { - // we need to combine - assert(current_degree % desired_degree == 0); - int combine_factor = current_degree / desired_degree; - - Node combine_node = model->get_or_create_node( - output_tensor, {i /*legion_dim*/, combine_factor}); - this->add_edge(output_node, combine_node, 0, 0); - - output_node = combine_node; - output_tensor = combine_node.ptr->outputs[0]; - current_degree /= combine_factor; - } - - assert(current_degree == desired_degree); - } - - assert(output_tensor == output_node.ptr->outputs[0]); - assert(output_tensor->num_dims == desired_shape.num_dims); - for (int i = 0; i < desired_shape.num_dims; i++) { - assert(output_tensor->dims[i].size == desired_shape.dims[i].size); - assert(output_tensor->dims[i].degree == desired_shape.dims[i].degree); - } -} - -std::unique_ptr Graph::with_output_tensor_reshaped_to( - ParallelTensorShape const &shape) const { - auto g = std::unique_ptr(new Graph(*this)); - g->reshape_output_tensor(shape); - return g; -} - -/* Graph::Graph(Graph const &graph) */ -/* : Graph(&graph) */ -/* { } */ - -/* Graph::Graph(Graph const *graph) */ -/* : Graph(graph->model) */ -/* { */ -/* for (auto const &kv : graph->inEdges) { */ -/* Node const &node = kv.first; */ -/* std::unordered_set const &edge_set = kv.second; */ - -/* for (auto const &edge : edge_set) { */ -/* this->add_edge(edge.srcOp, edge.dstOp, edge.srcIdx) */ -/* } */ -/* } */ -/* } */ - -Graph *GraphXfer::create_new_graph( - Graph const *graph, SimplificationSettings const &simplification_settings) { - Graph *newGraph = new Graph(model); - // Step 1: map dst ops - std::vector::const_iterator dstIt; - // Step 2: add edges to the graph - for (auto const &opIt : graph->inEdges) { - if (mappedOps.find(opIt.first) == mappedOps.end()) { - // Unmapped ops - auto const &list = opIt.second; - for (auto const &it : list) { - if (mappedOps.find(it.srcOp) != mappedOps.end()) { - // mapped src -> unmapped dst - TensorX srcTen; - srcTen.op = mappedOps[it.srcOp]; - srcTen.idx = it.srcIdx; - assert(mappedOutputs.find(srcTen) != mappedOutputs.end()); - TensorX dstTen = mappedOutputs[srcTen]; - newGraph->add_edge(dstTen.op->mapOp, it.dstOp, dstTen.idx, it.dstIdx); - } else { - // unmapped src -> unmmaped dst - newGraph->add_edge(it.srcOp, it.dstOp, it.srcIdx, it.dstIdx); - } - } - } - } - // Step 3: add edges for mapped ops - for (dstIt = dstOps.begin(); dstIt != dstOps.end(); dstIt++) { - OpX *dstOp = *dstIt; - for (size_t i = 0; i < dstOp->inputs.size(); i++) { - if (dstOp->inputs[i].op == NULL) { - // unmapped src -> mapped dst - std::multimap>::const_iterator it = - mappedInputs.find(dstOp->inputs[i].idx); - assert(it != mappedInputs.end()); - std::pair const &srcEdge = it->second; - newGraph->add_edge(srcEdge.first, dstOp->mapOp, srcEdge.second, i); - } else { - // mapped src -> mapped dst - OpX *srcOp = dstOp->inputs[i].op; - int srcIdx = dstOp->inputs[i].idx; - newGraph->add_edge(srcOp->mapOp, dstOp->mapOp, srcIdx, i); - } - } - } - newGraph->simplify(simplification_settings); - - return newGraph; -} - -bool GraphXfer::create_new_operator(OpX const *opx, Node &op) { - ParallelTensor inputs[MAX_NUM_INPUTS]; - for (size_t i = 0; i < opx->inputs.size(); i++) { - tl::optional mapped = opx->inputs[i].to_tensor(this); - if (!mapped.has_value()) { - return false; - } - inputs[i] = mapped.value(); - } - // Check that the total degree of inputs[0] does not exceed available - // resources - if (opx->inputs.size() > 0) { - int degree = 1; - for (int i = 0; i < inputs[0]->num_dims; i++) { - degree *= inputs[0]->dims[i].degree; - } - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - return false; - } - } - int num_inputs; - if (opx->get_pm_constraint(PM_NUM_INPUTS, num_inputs) && - opx->inputs.size() != num_inputs) { - return false; - } - int num_outputs; - if (opx->get_pm_constraint(PM_NUM_OUTPUTS, num_outputs) && - opx->outputs.size() != num_outputs) { - return false; - } - switch (opx->type) { - case OP_NOOP: { - op = model->get_or_create_noop_node(inputs[0]); - break; - } - case OP_CONCAT: { - int axis; - assert(opx->get_pm_constraint(PM_AXIS, axis)); - op = model->get_or_create_node( - {std::begin(inputs), std::end(inputs)}, {axis}); - break; - } - case OP_SPLIT: { - int axis; - assert(opx->get_pm_constraint(PM_AXIS, axis)); - int num_outputs = opx->outputs.size(); - int input_size = inputs[0]->dims[axis].size; - - if (input_size % num_outputs != 0) { - op = Node::INVALID_NODE; - } else { - int split_size = input_size / num_outputs; - std::vector split_sizes(num_outputs, split_size); - assert(split_sizes.size() == num_outputs); - op = model->get_or_create_node(inputs[0], {split_sizes, axis}); - } - break; - } - case OP_EW_ADD: - case OP_EW_SUB: - case OP_EW_MUL: - case OP_EW_MAX: - case OP_EW_MIN: { - op = model->get_or_create_node({inputs[0], inputs[1]}, - {opx->type}); - break; - } - case OP_RELU: { - ElementUnaryParams params; - params.op_type = opx->type; - params.inplace = false; - params.scalar = 0.0f; - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_CONV2D: { - Conv2D *conv = (Conv2D *)opx->matchOpX->mapOp.ptr; - Conv2DParams params = conv->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_POOL2D: { - Pool2D *pool = (Pool2D *)opx->matchOpX->mapOp.ptr; - Pool2DParams params = pool->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_FLAT: { - Flat *flat = (Flat *)opx->matchOpX->mapOp.ptr; - op = model->get_or_create_node(inputs[0], {}); - break; - } - case OP_LINEAR: { - int activation; - assert(opx->matchOpX != NULL); - assert(opx->matchOpX->mapOp.ptr != NULL); - Linear *linear = (Linear *)opx->matchOpX->mapOp.ptr; - // assert(opx->get_pm_constraint(PM_OUTPUT_CHANNELS, output_channels)); - assert(opx->get_pm_constraint(PM_ACTI, activation)); - LinearParams params = linear->get_params(); - op = model->get_or_create_node(inputs[0], params); - break; - } - case OP_MULTIHEAD_ATTENTION: { - int num_heads; - assert(opx->matchOpX != NULL); - assert(opx->matchOpX->mapOp.ptr != NULL); - MultiHeadAttention *attn = (MultiHeadAttention *)opx->matchOpX->mapOp.ptr; - assert(opx->get_pm_constraint(PM_NUM_HEADS, num_heads)); - MultiHeadAttentionParams params = attn->get_params(); - op = model->get_or_create_node( - {inputs[0], inputs[1], inputs[2]}, params); - break; - } - case OP_SOFTMAX: { - int softmax_dim; - assert(opx->get_pm_constraint(PM_SOFTMAX_DIM, softmax_dim)); - op = model->get_or_create_node(inputs[0], {softmax_dim}); - break; - } - case OP_REPARTITION: { - int repartition_dim, repartition_degree; - assert(opx->get_pm_constraint(PM_REPARTITION_DIM, repartition_dim)); - assert(opx->get_pm_constraint(PM_REPARTITION_DEGREE, repartition_degree)); - - int degree = inputs[0]->get_total_num_parts() * repartition_degree; - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - op = Node::INVALID_NODE; - } else { - op = model->get_or_create_node( - inputs[0], {repartition_dim, repartition_degree}); - } - break; - } - case OP_REPLICATE: { - int replicate_dim, replicate_degree; - assert(opx->get_pm_constraint(PM_REPLICATE_DIM, replicate_dim)); - assert(opx->get_pm_constraint(PM_REPLICATE_DEGREE, replicate_degree)); - - if (inputs[0]->dims[replicate_dim].degree * replicate_degree > - model->config.workersPerNode) { - op = Node::INVALID_NODE; - } else { - int degree = inputs[0]->get_total_num_parts() * replicate_degree; - if (degree > model->config.workersPerNode * model->config.numNodes && - (degree > model->config.cpusPerNode * model->config.numNodes)) { - op = Node::INVALID_NODE; - } else { - op = model->get_or_create_node( - inputs[0], {replicate_dim, replicate_degree}); - } - } - break; - } - case OP_REDUCTION: { - int reduction_dim, reduction_degree; - assert(opx->get_pm_constraint(PM_REDUCTION_DIM, reduction_dim)); - assert(opx->get_pm_constraint(PM_REDUCTION_DEGREE, reduction_degree)); - op = model->get_or_create_node( - inputs[0], {reduction_dim, reduction_degree}); - break; - } - case OP_COMBINE: { - int combine_dim, combine_degree; - assert(opx->get_pm_constraint(PM_COMBINE_DIM, combine_dim)); - assert(opx->get_pm_constraint(PM_COMBINE_DEGREE, combine_degree)); - op = model->get_or_create_node(inputs[0], - {combine_dim, combine_degree}); - break; - } - default: { - std::cout << "opx->type = " << get_operator_type_name(opx->type) - << std::endl; - assert(false); - } - } - // Check operator validness - if (op == Node::INVALID_NODE) { - return false; - } - // Check tnConstraints - for (size_t i = 0; i < opx->tnConstraints.size(); i++) { - TNConstraint tnc = opx->tnConstraints[i]; - int actValue = 0, expValue = 0; - if (tnc.singlePara) { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - expValue = tnc.value; - } else { - assert(op.ptr->get_tensor_parameter(tnc.para1, tnc.dim1, &actValue)); - assert(op.ptr->get_tensor_parameter(tnc.para2, tnc.dim2, &expValue)); - } - switch (tnc.comp) { - case COMPARE_EQ: - if (actValue != expValue) { - return false; - } - break; - case COMPARE_NE: - if (actValue == expValue) { - return false; - } - break; - case COMPARE_LT: - if (actValue >= expValue) { - return false; - } - break; - case COMPARE_LE: - if (actValue > expValue) { - return false; - } - break; - case COMPARE_GT: - if (actValue <= expValue) { - return false; - } - break; - case COMPARE_GE: - if (actValue < expValue) { - return false; - } - break; - default: - assert(false); - } - } - return true; -} - -OpX *GraphXfer::create_noop(TensorX const &input) { - OpX *noop = new OpX(OP_NOOP, 1, 1, input); - return noop; -} - -OpX *GraphXfer::create_concat(TensorX const *inputs, - int num_inputs, - OpX const *_matchOpX, - int concat_dim) { - OpX *concat = new OpX(OP_CONCAT, num_inputs, 1 /*outputs*/, inputs); - concat->matchOpX = _matchOpX; - concat->add_pm_constraint(COMPARE_EQ, PM_AXIS, concat_dim); - return concat; -} - -OpX *GraphXfer::create_element_unary(TensorX const &input, - OperatorType op_type) { - OpX *eu = new OpX(op_type, 1 /*numInputs*/, 1, input); - return eu; -} - -OpX *GraphXfer::create_relu(TensorX const &input) { - return this->create_element_unary(input, OP_RELU); -} - -OpX *GraphXfer::create_element_binary(TensorX const &input1, - TensorX const &input2, - OperatorType op_type) { - OpX *eb = new OpX(op_type, 2 /*numInputs*/, 1, input1, input2); - return eb; -} - -OpX *GraphXfer::create_linear(TensorX const &input, - OpX const *_matchOpX, - int num_dims, - ActiMode acti_mode, - bool use_bias) { - // TODO FIXME @lockshaw @zhihao use_bias is completely unused - OpX *li = new OpX(OP_LINEAR, 1, 1, input); - li->matchOpX = _matchOpX; - // li->add_pm_constraint(COMPARE_EQ, PM_OUTPUT_CHANNELS, out_channels); - li->add_pm_constraint(COMPARE_EQ, PM_ACTI, acti_mode); - li->add_input_constraint(COMPARE_EQ, INPUT_0, DIM_ND, num_dims); - return li; -} - -OpX *GraphXfer::create_conv2d(TensorX const &input, OpX const *matchOpX) { - OpX *conv = new OpX(OP_CONV2D, 1, 1, input); - conv->matchOpX = matchOpX; - return conv; -} - -OpX *GraphXfer::create_pool2d(TensorX const &input, OpX const *matchOpX) { - OpX *pool = new OpX(OP_POOL2D, 1, 1, input); - pool->matchOpX = matchOpX; - return pool; -} - -OpX *GraphXfer::create_attention(TensorX const &query, - TensorX const &key, - TensorX const &value, - OpX const *_matchOpX, - int num_heads) { - OpX *attn = new OpX(OP_MULTIHEAD_ATTENTION, 3, 1, query, key, value); - attn->matchOpX = _matchOpX; - attn->add_pm_constraint(COMPARE_EQ, PM_NUM_HEADS, num_heads); - attn->add_input_constraint(COMPARE_EQ, INPUT_0, DIM_ND, 4); - attn->add_input_constraint(COMPARE_EQ, INPUT_1, DIM_ND, 4); - attn->add_input_constraint(COMPARE_EQ, INPUT_2, DIM_ND, 4); - return attn; -} - -OpX *GraphXfer::create_softmax(TensorX const &input, int softmax_dim) { - OpX *softmax = new OpX(OP_SOFTMAX, 1, 1, input); - softmax->add_pm_constraint(COMPARE_EQ, PM_SOFTMAX_DIM, softmax_dim); - return softmax; -} - -OpX *GraphXfer::create_repartition(TensorX const &input, - int repartition_dim, - int num_parts) { - OpX *part = new OpX(OP_REPARTITION, 1, 1, input); - part->add_pm_constraint(COMPARE_EQ, PM_REPARTITION_DIM, repartition_dim); - part->add_pm_constraint(COMPARE_EQ, PM_REPARTITION_DEGREE, num_parts); - return part; -} - -OpX *GraphXfer::create_replicate(TensorX const &input, - int replicate_dim, - int num_parts) { - OpX *replicate = new OpX(OP_REPLICATE, 1, 1, input); - replicate->add_pm_constraint(COMPARE_EQ, PM_REPLICATE_DIM, replicate_dim); - replicate->add_pm_constraint(COMPARE_EQ, PM_REPLICATE_DEGREE, num_parts); - return replicate; -} - -OpX *GraphXfer::create_reduction(TensorX const &input, - int reduction_dim, - int num_parts) { - OpX *reduction = new OpX(OP_REDUCTION, 1, 1, input); - reduction->add_pm_constraint(COMPARE_EQ, PM_REDUCTION_DIM, reduction_dim); - reduction->add_pm_constraint(COMPARE_EQ, PM_REDUCTION_DEGREE, num_parts); - return reduction; -} - -OpX *GraphXfer::create_combine(TensorX const &input, - int combine_dim, - int num_parts) { - OpX *part = new OpX(OP_COMBINE, 1, 1, input); - part->add_pm_constraint(COMPARE_EQ, PM_COMBINE_DIM, combine_dim); - part->add_pm_constraint(COMPARE_EQ, PM_COMBINE_DEGREE, num_parts); - return part; -} - -void Graph::print_strategy_computation_graph( - std::unordered_map const &strategy) const { - DotFile dot(std::cout); - this->export_strategy_computation_graph(strategy, dot); -} - -void Graph::export_strategy_computation_graph( - std::unordered_map const &strategy, - std::string const &out_filename) const { - DotFile dot(out_filename); - - this->export_strategy_computation_graph(strategy, dot); -} - -void Graph::export_strategy_computation_graph( - std::unordered_map const &strategy, - DotFile &dot) const { - using FlexFlow::PCG::Utils::GraphStructure; - - GraphStructure s; - - for (auto const &node : s.get_nodes(*this)) { - // Add node - if (strategy.find(node) == strategy.end()) { - // Check FusedParallel node here and print out the detailed information - if (node.ptr->op_type == OperatorType::OP_FUSED_PARALLEL) { - RecordFormatter rf; - std::vector rows{}; - - FusedParallelOp *fused_op = (FusedParallelOp *)node.ptr; - for (int i = 0; i < fused_op->num_parallel_ops; i++) { - RecordFormatter row{}; - ParallelOpInfo op_info = fused_op->parallel_ops[i]; - std::string op_type_str = get_operator_type_name(op_info.op_type); - row << op_type_str << "dim: " + std::to_string(op_info.parallel_dim) - << "degree: " + std::to_string(op_info.parallel_degree); - rows.emplace_back(row); - } - rf << node.to_string(); - for (auto &r : rows) { - rf << r; - } - dot.add_record_node(node, rf); - } else { - dot.add_node(node, {{"label", node.to_string()}}); - } - } else { - RecordFormatter rf, meta_row, machine_view_row, runtime_code, memory_code, - runtime_cost_row, memory_cost_row; - MachineView mv = strategy.at(node); - std::ostringstream oss; - CostMetrics op_cost = - this->model->simulator->measure_operator_cost(node.ptr, mv); - switch (node.ptr->op_type) { - case OP_REPARTITION: { - Repartition *rp = (Repartition *)node.ptr; - meta_row << std::to_string(rp->repartition_dim) - << std::to_string(rp->repartition_degree); - break; - } - case OP_COMBINE: { - Combine *c = (Combine *)node.ptr; - meta_row << std::to_string(c->combine_dim) - << std::to_string(c->combine_degree); - break; - } - case OP_REPLICATE: { - Replicate *r = (Replicate *)node.ptr; - meta_row << std::to_string(r->replicate_dim) - << std::to_string(r->replicate_degree); - break; - } - case OP_REDUCTION: { - Reduction *r = (Reduction *)node.ptr; - meta_row << std::to_string(r->reduction_dim) - << std::to_string(r->reduction_degree); - break; - } - default: { - if (mv.ndims == 0) { - meta_row << "N/A"; - } else { - for (int i = 0; i < mv.ndims; i++) { - meta_row << std::to_string(mv.dim[i]); - } - } - } - } - - // Fetch machine view information - for (int device_id : mv.device_ids()) { - machine_view_row << std::to_string(device_id); - } - rf << node.to_string() << std::to_string(node.guid) << meta_row - << machine_view_row; - - // get memory cost - if (this->model->config.include_costs_dot_graph) { - float input_mem = (float)op_cost.inputs_memory; - if (node.ptr->numInputs > 0) { - input_mem /= (*node.ptr->inputs)->get_total_num_parts(); - } - float output_mem = (float)op_cost.outputs_memory; - if (node.ptr->numOutputs > 0) { - output_mem /= (*node.ptr->outputs)->get_total_num_parts(); - } - float weight_mem = (float)op_cost.weights_memory; - if (node.ptr->numWeights > 0) { - weight_mem /= (*node.ptr->weights)->get_total_num_parts(); - } - - runtime_code << "fwd" - << "bwd" - << "sync" - << "secs"; - runtime_cost_row << op_cost.forward_time << op_cost.backward_time - << op_cost.sync_time; - memory_code << "in" - << "out" - << "weight" - << "bytes"; - memory_cost_row << input_mem << output_mem << weight_mem; - rf << runtime_code << runtime_cost_row << memory_code - << memory_cost_row; - } - - dot.add_record_node(node, rf); - } - - // Add edges - for (auto const &edge : s.get_incoming_edges(*this, node)) { - dot.add_edge(s.get_src(*this, edge), s.get_dst(*this, edge)); - } - } - - dot.close(); -} - -template -void create_mapping_xfers( - FFModel *model, - int degree, - std::vector &xfers, - tl::optional> dims = tl::nullopt) { - std::vector records; - T::construct_output_mappings(records); - std::unordered_map output_mappings; - - std::unordered_set all_dims; - for (ParallelDimMappingRecord const &record : records) { - assert(record.input_idx == 0); - assert(record.get_type() == MappingRecordType::INPUT_OUTPUT); - assert(record.output_idx == 0); - assert(record.operation.has_value()); - - all_dims.insert(record.input_dim); - output_mappings.insert({record.input_dim, record}); - } - - if (dims.has_value()) { - all_dims = dims.value(); - } - - for (int const input_dim : all_dims) { - int output_dim = output_mappings.at(input_dim).output_dim; - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - - OpX *original_op = subst->create_opx(input, NULL /*matchOpX*/); - subst->srcOps.push_back(original_op); - - OpX *pre; - std::string pre_name; - switch (output_mappings.at(input_dim).operation.value()) { - case MappingOperation::PARTITION: - pre = subst->create_repartition(input, input_dim, degree); - pre_name = "partition"; - break; - case MappingOperation::REPLICATE: - pre = subst->create_replicate(input, input_dim, degree); - pre_name = "replicate"; - break; - } - subst->dstOps.push_back(pre); - - OpX *new_op = - subst->create_opx(pre->outputs[0], original_op /*matchOpX*/); - subst->dstOps.push_back(new_op); - - OpX *post; - std::string post_name; - switch (output_mappings.at(input_dim).operation.value()) { - case MappingOperation::PARTITION: - post = subst->create_combine(new_op->outputs[0], output_dim, degree); - post_name = "combine"; - break; - case MappingOperation::REPLICATE: - post = subst->create_reduction(new_op->outputs[0], output_dim, degree); - post_name = "reduce"; - break; - } - subst->dstOps.push_back(post); - - subst->map_output(original_op->outputs[0], post->outputs[0]); - - std::ostringstream oss; - std::string op_type_name = get_operator_type_name(new_op->type); - std::transform(op_type_name.begin(), - op_type_name.end(), - op_type_name.begin(), - [](unsigned char c) { return std::tolower(c); }); - oss << "mapping::" << pre_name << "_" << op_type_name << "_" << post_name - << "[" - << "input_dim=" << input_dim << ",degree=" << degree << "]"; - subst->name = oss.str(); - - xfers.push_back(subst); - } -} - -std::string GraphXfer::get_name() const { - if (this->name.has_value()) { - return this->name.value(); - } else { - std::ostringstream oss; - oss << "unknown_xfer(" << this << ")"; - return oss.str(); - } -} - -/* int get_num_outputs(sl::Operator const &op) { */ -/* switch (op.op_type) { */ -/* case OP_SPLIT: */ -/* return op.at(PM_NUM_OUTPUTS).value(); */ -/* default: */ -/* return 1; */ -/* } */ -/* } */ - -/* int get_num_inputs(sl::Operator const &op) { */ -/* switch (op.op_type) { */ -/* case OP_EW_ADD: // binary ops */ -/* case OP_EW_SUB: */ -/* case OP_EW_MUL: */ -/* case OP_EW_DIV: */ -/* case OP_EW_EQUAL: */ -/* case OP_EW_GREATER: */ -/* case OP_EW_LESS: */ -/* case OP_EW_MAX: */ -/* case OP_EW_MIN: */ -/* return 2; */ -/* case OP_SPLIT: */ -/* return 1; */ -/* case OP_LINEAR: */ -/* return 1; */ -/* case OP_CONV2D: */ -/* return 1; */ -/* case OP_RELU: */ -/* case OP_IDENTITY: */ -/* case OP_SIGMOID: */ -/* case OP_TANH: */ -/* case OP_ELU: */ -/* return 1; */ -/* case OP_CONCAT: */ -/* return op.at(PM_NUM_INPUTS).value(); */ -/* case OP_INPUT: */ -/* return 0; */ -/* case OP_REPARTITION: */ -/* case OP_COMBINE: */ -/* case OP_REPLICATE: */ -/* case OP_REDUCTION: */ -/* case OP_PIPELINE: */ -/* return 1; */ -/* default: */ -/* throw std::runtime_error("Unknown num_inputs for operator " + */ -/* get_operator_type_name(op.op_type)); */ -/* } */ -/* } */ - -OpX *create_opx(sl::Operator const &op, - int parallel_degree, - TensorX const &input1, - TensorX const &input2, - TensorX const &input3, - TensorX const &input4) { - int num_inputs = get_num_inputs(op); - int num_outputs = get_num_outputs(op); - - OpX *opx = new OpX( - op.op_type, num_inputs, num_outputs, input1, input2, input3, input4); - for (sl::Parameter const &p : op.para) { - if (p.key == PM_PARALLEL_DEGREE) { - tl::optional degree_key = tl::nullopt; - switch (op.op_type) { - case OP_REPARTITION: - degree_key = PM_REPARTITION_DEGREE; - break; - case OP_COMBINE: - degree_key = PM_COMBINE_DEGREE; - break; - case OP_REDUCTION: - degree_key = PM_REDUCTION_DEGREE; - break; - case OP_REPLICATE: - degree_key = PM_REPLICATE_DEGREE; - break; - } - - if (degree_key.has_value()) { - // Assume the generator only consider a parallel degree of 2 - assert(p.value == 2); - opx->add_pm_constraint(COMPARE_EQ, degree_key.value(), parallel_degree); - } - } else if (p.key == PM_PARALLEL_DIM) { - tl::optional dim_key = tl::nullopt; - switch (op.op_type) { - case OP_REPARTITION: - dim_key = PM_REPARTITION_DIM; - break; - case OP_COMBINE: - dim_key = PM_COMBINE_DIM; - break; - case OP_REDUCTION: - dim_key = PM_REDUCTION_DIM; - break; - case OP_REPLICATE: - dim_key = PM_REPLICATE_DIM; - break; - } - - if (dim_key.has_value()) { - opx->add_pm_constraint(COMPARE_EQ, dim_key.value(), p.value); - } - } else if (p.key == PM_PAD) { - opx->add_pm_constraint(COMPARE_EQ, PM_PADDING_H, p.value); - opx->add_pm_constraint(COMPARE_EQ, PM_PADDING_W, p.value); - } else { - opx->add_pm_constraint(COMPARE_EQ, p.key, p.value); - } - } - - return opx; -} - -OpX *find_opx_with_type(std::vector const &src_ops, - OperatorType op_type) { - OpX *matchOpX = nullptr; - for (size_t k = 0; k < src_ops.size(); k++) { - if (src_ops[k]->type == op_type) { - assert(matchOpX == nullptr); - matchOpX = src_ops[k]; - } - } - assert(matchOpX != nullptr); - return matchOpX; -} - -std::vector - create_rule_graph(GraphXfer &xfer, - std::vector const &ops, - std::function const &get_input_tensor, - std::vector *const src_ops, - int parallel_degree) { - std::vector rule_graph; - - for (int i = 0; i < ops.size(); i++) { - sl::Operator const &op = ops[i]; - std::array inputs; - std::fill(inputs.begin(), inputs.end(), TensorX::NO_TX); - - for (int j = 0; j < op.input.size(); j++) { - int opId = op.input[j].opId; - int tsId = op.input[j].tsId; - if (opId < 0) { - inputs[j] = get_input_tensor(opId, tsId); - } else { - inputs[j] = rule_graph[opId]->outputs[tsId]; - } - } - - // We need the matched OpX for constructing conv2d/pool2d/linear - OpX *opx = nullptr; - switch (ops[i].op_type) { - case OP_CONV2D: { - OpX *matchOpX = src_ops == nullptr - ? nullptr - : find_opx_with_type(*src_ops, ops[i].op_type); - opx = xfer.create_conv2d(inputs[0], matchOpX); - break; - } - case OP_POOL2D: { - OpX *matchOpX = src_ops == nullptr - ? nullptr - : find_opx_with_type(*src_ops, ops[i].op_type); - opx = xfer.create_pool2d(inputs[0], matchOpX); - break; - } - default: - opx = create_opx(ops[i], - parallel_degree, - inputs[0], - inputs[1], - inputs[2], - inputs[3]); - } - rule_graph.push_back(opx); - } - - return rule_graph; -} - -void create_xfer(GraphXfer &xfer, sl::Rule const &r, int parallel_degree) { - std::unordered_map, TensorX> input_tensors; - std::function get_input_tensor = - [&xfer, &input_tensors](int opId, int tsId) -> TensorX { - if (input_tensors.find({opId, tsId}) == input_tensors.end()) { - input_tensors[{opId, tsId}] = xfer.new_tensor(); - } - return input_tensors.at({opId, tsId}); - }; - - xfer.srcOps = create_rule_graph( - xfer, r.srcOp, get_input_tensor, nullptr, parallel_degree); - xfer.dstOps = create_rule_graph( - xfer, r.dstOp, get_input_tensor, &xfer.srcOps, parallel_degree); - xfer.name = r.name; - if (xfer.srcOps.size() == 1) { - printf("Here!\n"); - } - - for (sl::MapOutput const &m : r.mappedOutput) { - TensorX srcTensorX = xfer.srcOps[m.srcOpId]->outputs[m.srcTsId]; - TensorX dstTensorX = xfer.dstOps[m.dstOpId]->outputs[m.dstTsId]; - xfer.map_output(srcTensorX, dstTensorX); - } -} - -bool check_opxes_have_same_type_and_constraints(OpX const &src_opx, - OpX const &dst_opx) { - if (src_opx.type != dst_opx.type) { - return false; - } - if (src_opx.pmConstraints.size() != dst_opx.pmConstraints.size()) { - return false; - } - if (src_opx.tnConstraints.size() != dst_opx.tnConstraints.size()) { - return false; - } - for (auto const &c1 : src_opx.pmConstraints) { - bool found_same = false; - for (auto const &c2 : dst_opx.pmConstraints) { - if (c1.comp == c2.comp && c1.para == c2.para && c1.value == c2.value) { - found_same = true; - } - } - if (!found_same) { - return false; - } - } - for (auto const &c1 : src_opx.tnConstraints) { - bool found_same = false; - for (auto const &c2 : dst_opx.tnConstraints) { - if (c1.singlePara && c2.singlePara) { - if (c1.comp == c2.comp && c1.para1 == c2.para1 && c1.dim1 == c2.dim1 && - c1.value == c2.value) { - found_same = true; - } - } else if ((!c1.singlePara) && (!c2.singlePara)) { - if (c1.comp == c2.comp && c1.para1 == c2.para1 && - c1.para2 == c2.para2 && c1.dim1 == c2.dim1 && c1.dim2 == c2.dim2) { - found_same = true; - } - } - } - if (!found_same) { - return false; - } - } - - return true; -} - -std::vector create_xfers(FFModel *model, - sl::RuleCollection const &rules, - int parallel_degree) { - std::vector xfers; - for (sl::Rule const &r : rules.rules) { - GraphXfer *xfer = new GraphXfer(model); - create_xfer(*xfer, r, parallel_degree); - if (xfer->srcOps.size() == 1 && xfer->dstOps.size() == 1) { - delete xfer; - continue; - } - // Pruning redundant xfer - bool found_same_xfer = false; - for (auto const &old_xfer : xfers) { - bool same = true; - if (old_xfer->srcOps.size() != xfer->srcOps.size()) { - same = false; - continue; - } - for (size_t i = 0; i < old_xfer->srcOps.size(); i++) { - if (!check_opxes_have_same_type_and_constraints(*old_xfer->srcOps[i], - *xfer->srcOps[i])) { - same = false; - } - } - if (!same) { - continue; - } - if (old_xfer->dstOps.size() != xfer->dstOps.size()) { - same = false; - continue; - } - for (size_t i = 0; i < old_xfer->dstOps.size(); i++) { - if (!check_opxes_have_same_type_and_constraints(*old_xfer->dstOps[i], - *xfer->dstOps[i])) { - same = false; - } - } - if (same) { - found_same_xfer = true; - break; - } - } - if (!found_same_xfer && xfer->srcOps.size() == 1) { - xfers.push_back(xfer); - } else { - delete (xfer); - } - } - return xfers; -} - -GraphSearchHelper::GraphSearchHelper(FFModel *model) - : model(model), config(model->config), mem_config(1.0) { - this->logger = std::unique_ptr(new RecursiveLogger("gs")); - generate_all_pcg_xfers(); -} - -void GraphSearchHelper::clear_cache() { - cached_optimized_graphs.clear(); -} - -void GraphSearchHelper::load_graph_substitutions( - std::vector &xfers) const { - xfers = all_pcg_xfers; -} - -void GraphSearchHelper::generate_all_pcg_xfers() { - std::vector all_parallel_degrees, single_node_parallel_degrees; - auto const &config = this->model->config; - int workersPerNode = - config.search_num_workers.value_or(config.workersPerNode); - int numNodes = config.search_num_nodes.value_or(config.numNodes); - log_xfers.debug() << "Generating parallel degrees for workersPerNode " - << workersPerNode << " and numNodes " << numNodes; - for (int i = 2; i <= workersPerNode; i++) { - if (workersPerNode % i == 0) { - single_node_parallel_degrees.push_back(i); - all_parallel_degrees.push_back(i); - } - } - for (int i = 2; i <= numNodes; i++) { - if (numNodes % i == 0) { - all_parallel_degrees.push_back(i * workersPerNode); - } - } - { - std::ostringstream oss; - oss << "Generating all_pcg_xfers for all parallel degrees: "; - for (int parallel_degree : all_parallel_degrees) { - oss << parallel_degree << " "; - } - - log_xfers.debug() << oss.str(); - } - - for (auto const &it : single_node_parallel_degrees) { - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_replicate_linear_combine( - this->model, 3, it, AC_MODE_NONE, false)); - if (16 % it == 0) { - all_pcg_xfers.push_back( - create_replicate_attention_reduce(this->model, 16 /*num_heads*/, it)); - } - } - for (auto const &it : all_parallel_degrees) { - all_pcg_xfers.push_back( - create_partition_attention_combine(this->model, 16 /*num_heads*/, it)); - } - - if (config.substitution_json_path.has_value()) { - // Currently only consider a subset of all_parallel_degrees - std::vector considered_parallel_degrees; - considered_parallel_degrees.push_back(workersPerNode); - if (numNodes > 1) { - considered_parallel_degrees.push_back(numNodes * workersPerNode); - } - sl::RuleCollection rule_collection = sl::load_rule_collection_from_path( - config.substitution_json_path.value()); - for (int degree : considered_parallel_degrees) { - std::vector xfers = - create_xfers(this->model, rule_collection, degree); - all_pcg_xfers.insert(all_pcg_xfers.end(), xfers.begin(), xfers.end()); - } - } else { - // Manual substitutions - for (int num_dims = 3; num_dims <= 4; num_dims++) { - all_pcg_xfers.push_back( - create_linear_relu_merge(this->model, num_dims, true)); - all_pcg_xfers.push_back( - create_linear_relu_merge(this->model, num_dims, false)); - } - for (int const degree : all_parallel_degrees) { - create_mapping_xfers(this->model, degree, all_pcg_xfers); - create_mapping_xfers(this->model, degree, all_pcg_xfers); - create_mapping_xfers(this->model, degree, all_pcg_xfers); - } - for (auto const &it : all_parallel_degrees) { - // rewrites for the inception model - for (int i = 3; i <= 6; i++) { - all_pcg_xfers.push_back(create_combine_inception( - this->model, i - 1 /*num_convs*/, 5 /*num_dims*/, it)); - all_pcg_xfers.push_back(create_combine_concat( - this->model, i /*num_inputs*/, 5 /*num_dims*/, it)); - } - // all_pcg_xfers.push_back(create_partition_conv2d_combine(this->model, - // 5/*num_dims*/, it)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 3 /*num_dims*/, it, AC_MODE_NONE, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_RELU, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_SIGMOID, false)); - all_pcg_xfers.push_back(create_partition_linear_combine( - this->model, 4 /*num_dims*/, it, AC_MODE_NONE, false)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 1 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 2 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 3 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_add_combine( - this->model, 4 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_relu_combine( - this->model, 3 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back(create_partition_relu_combine( - this->model, 4 /*parallel_dims*/, it /*num_parts*/)); - all_pcg_xfers.push_back( - create_partition_softmax_combine(this->model, - 0 /*softmax_dim*/, - 1 /*parallel_dims*/, - it /*num_parts*/)); - for (int num_combines = 1; num_combines < 5; num_combines++) { - all_pcg_xfers.push_back(leading_relu_branch_combine( - this->model, 3 /*parallel_dim*/, it /*num_parts*/, num_combines)); - all_pcg_xfers.push_back(leading_relu_branch_partition( - this->model, 3 /*parallel_dim*/, it /*num_parts*/, num_combines)); - } - { - std::unordered_set concat_num_inputs; - for (size_t i = 0; i < this->model->operators.size(); i++) { - if (this->model->operators[i]->op_type == OP_CONCAT) { - concat_num_inputs.insert(this->model->operators[i]->numInputs); - } - } - for (auto const &it2 : concat_num_inputs) { - all_pcg_xfers.push_back( - create_partition_concat_combine(this->model, - it2 /*num_inputs*/, - 0 /*concat_dim*/, - 1 /*parallel_dims*/, - it /*num_parts*/)); - all_pcg_xfers.push_back( - create_partition_concat_combine(this->model, - it2 /*num_inputs*/, - 2 /*concat_dim*/, - 3 /*parallel_dims*/, - it /*num_parts*/)); - } - } - } - } -} - -Graph *GraphSearchHelper::construct_graph() { - Graph *graph = new Graph(this->model); - std::unordered_map op_to_node_map; - for (FlexFlow::Op const *dstOp : this->model->operators) { - Node dstNode; - dstNode.ptr = dstOp; - dstNode.guid = this->model->node_global_guid++; - op_to_node_map[dstOp] = dstNode; - for (int j = 0; j < dstOp->numInputs; j++) { - FlexFlow::Op const *srcOp = dstOp->inputs[j]->owner_op; - assert(op_to_node_map.find(srcOp) != op_to_node_map.end()); - Node srcNode = op_to_node_map[srcOp]; - graph->add_edge(srcNode, dstNode, dstOp->inputs[j]->owner_idx, j); - } - } - - return graph; -} - -/** - * @brief Unity search algorithm main entrance. - * - * @param[in] budget Not used - * @param[in] only_data_parallel Not used - * @param[out] best_graph The best possible PCG after optimization - * @param[out] optimal_views The corresponding device placement views of the - * best graph - */ -void GraphSearchHelper::graph_optimize( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views) { - // Construct graph structure - this->logger->debug() << "Starting graph optimization"; - - Graph *graph = this->construct_graph(); - graph->duplicate_input_nodes(); - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - Node sink_node = graph->find_sink_node(); - GraphOptimizeResult optimal = - this->generic_sequence_optimize( - graph, - sink_node, - tl::nullopt /*output_shape*/, - tl::nullopt /*input_shape*/); - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal cost: " << optimal.cost << std::endl; - SimplificationSettings settings; - settings.fuse_parallel_ops = true; - settings.remove_noops = true; - settings.remove_trailing_parallel_ops = true; - settings.simplify_parallel_ops = true; - best_graph = std::unique_ptr(new Graph(optimal.graph.value())); - best_graph->simplify(settings); - std::unordered_map duplicated_optimal_views = - best_graph->optimal_views(); - std::unordered_map deduplication_map = - best_graph->deduplicate_input_nodes(); - std::unordered_map real_optimal_views; - for (auto const &kv : duplicated_optimal_views) { - if (deduplication_map.find(kv.first) != deduplication_map.end()) { - real_optimal_views[deduplication_map.at(kv.first)] = kv.second; - } else { - real_optimal_views[kv.first] = kv.second; - } - } - best_graph->print_strategy_computation_graph(optimal.views); - optimal_views = real_optimal_views; -} - -/** - * @brief Experimental DP algorithm to optimize PCG with the consideration of - * memory usage. This is to avoid polluting the current Unity search algorithm - * above. And this should be merged to GraphSearchHelper::graph_optimize - * eventually. - * - * @param[in] budget Not used - * @param[in] only_data_parallel Not used - * @param[out] best_graph The best possible PCG after optimization - * @param[out] optimal_views The corresponding device placement views of the - * best graph - * @param[out] search_result The performance result of the search - */ -void GraphSearchHelper::graph_optimize_with_memory( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views, - MemorySearchResult &search_result) { - this->logger->debug() - << "Starting graph optimization with memory consideration"; - - // Construct graph structure - Graph *graph = this->construct_graph(); - - // The input nodes may need to be duplicated because the PCG was constructed - // to have one input node for one input, but the actual execution graph should - // have the distributed version of inputs (i.e. multiple nodes). - graph->duplicate_input_nodes(); - - // Export an empty schedule if needed. - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - Node sink_node = graph->find_sink_node(); - - auto const start = std::chrono::system_clock::now(); - GraphOptimizeResultWithMemory optimal = - this->generic_sequence_optimize_with_memory< - GraphOptimizeResultWithMemory>( - graph, sink_node, tl::nullopt, tl::nullopt); - auto const end = std::chrono::system_clock::now(); - - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal run time cost: " << optimal.cost - << ", Memory usage: " << optimal.mem_cost - << " | run_time_cost_factor: " - << this->mem_config.run_time_cost_factor << std::endl; - - // Save the search performance results to the output argument - search_result.run_time_cost = optimal.cost; - search_result.memory_cost = optimal.mem_cost.num; - search_result.search_time = - std::chrono::duration_cast(end - start) - .count(); - - // Further simplify the "optimal" graph/schedule to have a more efficient - // graph and more accurate cost. - best_graph = std::unique_ptr(new Graph(optimal.graph.value())); - SimplificationSettings settings; - // Simplify to consider parallel op fusion - settings.fuse_parallel_ops = true; - settings.remove_noops = true; - settings.remove_trailing_parallel_ops = true; - settings.simplify_parallel_ops = true; - best_graph->simplify(settings); - - // Get the real optimal machine views. - std::unordered_map duplicated_optimal_views = - best_graph->optimal_views(); - std::unordered_map deduplication_map = - best_graph->deduplicate_input_nodes(); - std::unordered_map real_optimal_views; - for (auto const &kv : duplicated_optimal_views) { - if (deduplication_map.find(kv.first) != deduplication_map.end()) { - real_optimal_views[deduplication_map.at(kv.first)] = kv.second; - } else { - real_optimal_views[kv.first] = kv.second; - } - } - std::cout << "Dot graph of searched strategy:" << std::endl; - best_graph->print_strategy_computation_graph(optimal.views); - std::cout << std::endl; - - optimal_views = real_optimal_views; -} - -void GraphSearchHelper::graph_optimize_no_split( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views) { - // Construct graph structure - this->logger->debug() << "Starting graph optimization without split"; - - Graph *graph = this->construct_graph(); - std::unordered_map empty_strategy; - if (!this->config.export_strategy_computation_graph_file.empty()) { - graph->export_strategy_computation_graph( - empty_strategy, this->config.export_strategy_computation_graph_file); - } - - SimplificationSettings settings; - settings.simplify_parallel_ops = true; - best_graph = this->base_optimize(graph, settings); - optimal_views = best_graph->optimal_views(); - - this->logger->debug() << "Total cache size: " - << this->cached_optimized_graphs.size(); - std::cout << "Optimal cost: " << best_graph->optimal_cost() << std::endl; -} - -static void graph_log_representation(Graph const *graph, - RecursiveLogger &logger) { - using FlexFlow::PCG::Utils::topo_sort; - - std::vector topo_sorted; - topo_sort(*graph, &topo_sorted); - std::ostringstream oss; - for (Node const &n : topo_sorted) { - logger.spew() << n.to_string(); - } -} - -void GraphSearchHelper::update_mem_optim_config( - MemoryOptimConfig const &new_config) { - mem_config = new_config; -} - -void GraphSearchHelper::find_rewrite_matches( - Graph const *graph, std::vector &matches) const { - std::vector xfers; - this->load_graph_substitutions(xfers); - - for (GraphXfer *xfer : xfers) { - log_xfer_matches.debug() - << "Finding matches for xfer: " << xfer->get_name(); - xfer->find_matches(graph, matches); - } - log_xfer_matches.debug() << "Finished finding xfer matches"; -} - -tl::optional - GraphSearchHelper::find_split_node(Graph const *graph, - int base_optimize_threshold) const { - using FlexFlow::PCG::Utils::get_edges; - using FlexFlow::PCG::Utils::MultisourceGraphStructure; - using FlexFlow::PCG::Utils::nodes; - using FlexFlow::PCG::Utils::post_dominators; - using FlexFlow::PCG::Utils::roots; - - TAG_ENTER(this->logger); - - int graph_size = nodes(*graph).size(); - this->logger->debug() << "Finding split node for graph (size " << graph_size - << ") with threshold " << base_optimize_threshold; - - if (graph_size <= base_optimize_threshold) { - this->logger->debug() - << "Graph size underneath threshold. Returning nullopt"; - return tl::nullopt; - } - - std::vector edges = get_edges(*graph); - std::unordered_map edge_scores; - - for (Edge const &e : edges) { - edge_scores[e] = 0; - } - - std::vector matches; - this->find_rewrite_matches(graph, matches); - this->logger->debug() << "Found " << matches.size() << " rewrite matches"; - { - TAG_ENTER(this->logger); - for (GraphXferMatch const &match : matches) { - auto msg = this->logger->spew(); - msg << match.get_xfer()->get_name() << " : "; - std::unordered_set nodes = match.get_nodes(); - for (Node const &node : nodes) { - msg << node.to_string() << " "; - } - } - } - - for (GraphXferMatch const &match : matches) { - for (Edge const &e : edges) { - if (match.containsEdge(graph, e)) { - edge_scores[e]++; - } - } - } - - this->logger->debug() << "Edge weights: "; - - { - TAG_ENTER(this->logger); - for (Edge const &e : edges) { - this->logger->debug() << e.srcOp.to_string() << "/" << e.srcIdx << " -> " - << e.dstOp.to_string() << "/" << e.dstIdx << " : " - << edge_scores.at(e); - } - } - - std::unordered_map> post_dominator_map = - post_dominators>(*graph); - Node source_node; - { - std::unordered_set source_nodes = roots(*graph); - if (source_nodes.size() != 1) { - source_nodes = roots>(*graph); - } - assert(source_nodes.size() == 1); - source_node = *source_nodes.begin(); - } - std::unordered_set possible_bottlenecks = - post_dominator_map.at(source_node); - Node sink_node = graph->find_sink_node(); - - int best_weight = 0; - tl::optional best = tl::nullopt; - int best_size = graph_size; - { - TAG_ENTER(this->logger); - - for (Node const &possible_bottleneck : possible_bottlenecks) { - if (possible_bottleneck == sink_node || - possible_bottleneck == source_node) { - continue; - } - - int weight = 0; - for (Edge const &e : graph->outEdges.at(possible_bottleneck)) { - weight += edge_scores.at(e); - } - this->logger->debug() - << "Potential bottleneck node " << possible_bottleneck.to_string() - << " has weight " << weight; - if (weight < best_weight) { - best_weight = weight; - best = possible_bottleneck; - } else if (weight == best_weight) { - // break ties by trying to choosing the split that produces the - // pre_graph with size closest to the threshold, favoring everything - // with smaller size over everything with larger size - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(possible_bottleneck); - int current_size = nodes(*pre_graph).size(); - - bool best_is_under = best_size <= base_optimize_threshold; - bool current_is_under = current_size <= base_optimize_threshold; - - bool condition1 = current_is_under && !best_is_under; - bool condition2 = - current_is_under && best_is_under && current_size > best_size; - bool condition3 = - !current_is_under && !best_is_under && current_size < best_size; - - if (condition1 || condition2 || condition3) { - best_weight = weight; - best = possible_bottleneck; - best_size = current_size; - } - } - } - } - - return best; -} - -/** - * @brief Base case of Unity's DP search algorithm. - * - * @param r_graph Graph to be optimized - * @param simplification_settings Settings to simplify the PCG - * @return std::unique_ptr Optimized PCG - */ -std::unique_ptr GraphSearchHelper::base_optimize( - Graph const *r_graph, - SimplificationSettings const &simplification_settings) { - // Construct graph substitutions - TAG_ENTER(this->logger); - - this->logger->debug() << "Optimizing base graph: "; - { - TAG_ENTER(this->logger); - /* graph_log_representation(r_graph, *this->logger); */ - // r_graph->print_dot(); - } - this->logger->debug() << "Starting cost: " << r_graph->optimal_cost(); - - std::vector xfers; - this->load_graph_substitutions(xfers); - - Graph *graph = new Graph(*r_graph); - - std::priority_queue, GraphCompare> candidates; - std::unordered_set hashmap; - candidates.push(graph); - hashmap.insert(graph->hash()); - Graph *best_graph = new Graph(*graph); - float best_cost = best_graph->optimal_cost(); - int counter = 0; - float const alpha = this->model->config.search_alpha; - - int budget = model->config.search_budget; - if (budget == 0) { - log_xfers.warning() - << "Base search budget is set to 0. This is probably not what you want " - "(use the --budget flag to set the base search budget)"; - } - for (int iter = 0; iter < budget || budget == -1; iter++) { - log_xfers.spew() << "Considering " << candidates.size() << " candidates"; - if (candidates.empty()) { - break; - } - - Graph *cur_graph = candidates.top(); - candidates.pop(); - if (cur_graph->optimal_cost() < best_graph->optimal_cost()) { - delete best_graph; - best_graph = cur_graph; - best_cost = cur_graph->optimal_cost(); - } else if (cur_graph->optimal_cost() > best_cost * alpha) { - continue; - } - - log_xfers.info("[%d] cur_cost(%.4lf) best_cost(%.4lf) candidates.size(%zu)", - counter, - cur_graph->optimal_cost(), - best_cost, - candidates.size()); - - log_xfers.debug() << "Considering " << xfers.size() << " possible xfers"; - for (size_t i = 0; i < xfers.size(); i++) { - int num_matches_found = 0, num_matches_rejected = 0; - log_xfers.debug() << "Considering xfer: " << xfers[i]->get_name(); - xfers[i]->run(0, - cur_graph, - candidates, - hashmap, - best_cost * alpha, - 1000, - simplification_settings, - num_matches_found, - num_matches_rejected); - log_xfers.debug() << "Rejected [ " << num_matches_rejected << " / " - << num_matches_found << " ] matches"; - /* std::cout << "." << std::flush; */ - } - /* std::cout << std::endl; */ - if (best_graph != cur_graph) { - delete cur_graph; - } - } - - this->logger->debug() << "Optimized cost: " << best_graph->optimal_cost(); - // best_graph->print_dot(); - return std::unique_ptr(best_graph); -} - -/** - * @brief Experimental. Base case of Unity's DP search algorithm with - * memory consideration. - * - * @param r_graph Graph to be optimized - * @param simplification_settings Settings to simplify the resulting PCG - * @return std::unique_ptr Optimized PCG - */ -std::unique_ptr GraphSearchHelper::base_optimize_with_memory( - Graph const *r_graph, - SimplificationSettings const &simplification_settings) { - TAG_ENTER(this->logger); - this->logger->debug() << "Optimizing base graph with memory: "; - { - TAG_ENTER(this->logger); - /* graph_log_representation(r_graph, *this->logger); */ - // r_graph->print_dot(); - } - this->logger->debug() << "Starting cost: " - << r_graph->optimal_cost_with_memory( - mem_config.run_time_cost_factor); - - // Construct graph substitutions - std::vector xfers; - this->load_graph_substitutions(xfers); - - // Prepare for the search - std::priority_queue, GraphCompareWithMemory> - candidates(GraphCompareWithMemory{mem_config.run_time_cost_factor}); - std::unordered_set hashmap; - - Graph *graph = new Graph(*r_graph); - candidates.push(graph); - hashmap.insert(graph->hash()); - - Graph *best_graph = new Graph(*graph); - float best_cost = - best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - - int counter = 0; - float const alpha = this->model->config.search_alpha; - int budget = model->config.search_budget; - if (budget == 0) { - log_xfers.warning() - << "Base search budget is set to 0. This is probably not what you want " - "(use the --budget flag to set the base search budget)"; - } - - // Actual exploration - for (int iter = 0; iter < budget || budget == -1; iter++) { - log_xfers.spew() << "Considering " << candidates.size() - << " candidates in base_optimize_with_memory"; - if (candidates.empty()) { - break; - } - - Graph *cur_graph = candidates.top(); - candidates.pop(); - if (cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor) < - best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor)) { - delete best_graph; - best_graph = cur_graph; - best_cost = - cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - } else if (cur_graph->optimal_cost_with_memory( - mem_config.run_time_cost_factor) > best_cost * alpha) { - continue; - } - - log_xfers.info( - "[%d] cur_cost(%.4lf) best_cost(%.4lf) candidates.size(%zu)", - counter, - cur_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor), - best_cost, - candidates.size()); - - log_xfers.debug() << "Considering " << xfers.size() - << " possible xfers in base_optimize_with_memory"; - for (size_t i = 0; i < xfers.size(); i++) { - int num_matches_found = 0, num_matches_rejected = 0; - log_xfers.debug() << "Considering xfer: " << xfers[i]->get_name(); - xfers[i]->run(0, - cur_graph, - candidates, - hashmap, - best_cost * alpha, - 1000, - simplification_settings, - num_matches_found, - num_matches_rejected); - log_xfers.debug() << "Rejected [ " << num_matches_rejected << " / " - << num_matches_found << " ] matches"; - } - - if (best_graph != cur_graph) { - delete cur_graph; - } - } - - this->logger->debug() - << "Optimized cost at the end of base_optimize_with_memory: " - << best_graph->optimal_cost_with_memory(mem_config.run_time_cost_factor); - - return std::unique_ptr(best_graph); -} - -size_t gs_dp_state_hash(Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - size_t key = graph->hash(); - hash_combine(key, sink_node.ptr); - hash_combine(key, output_shape); - hash_combine(key, input_shape); - return key; -} - -float GraphSearchHelper::sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - return this->generic_sequence_optimize( - graph, sink_node, output_shape, input_shape); -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache(size_t hash) const { - if (this->cached_optimized_graphs.find(hash) == - this->cached_optimized_graphs.end()) { - return tl::nullopt; - } else { - return this->cached_optimized_graphs.at(hash); - } -} - -template <> -float GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - return optimized->generic_optimal_cost(); -} - -template <> -GraphCostResult GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - return optimized->generic_optimal_cost(); -} - -template <> -GraphOptimizeResult GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - GraphOptimizeResult result; - result.graph = *optimized; - GraphCostResult gcr = optimized->generic_optimal_cost(); - result.cost = gcr.cost; - result.views = gcr.views; - return result; -} - -template <> -GraphOptimizeResultWithMemory - GraphSearchHelper::get_optimal_cost( - std::unique_ptr optimized) const { - GraphOptimizeResultWithMemory result; - result.graph = *optimized; - GraphCostResultWithMemory gcr = - optimized->generic_optimal_cost(); - result.cost = gcr.cost; - result.views = gcr.views; - result.mem_cost = gcr.mem_cost; - return result; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -tl::optional - GraphSearchHelper::try_get_cost_from_cache( - size_t hash) const { - return tl::nullopt; -} - -template <> -void GraphSearchHelper::try_cache_result(size_t hash, - float const &value) { - this->cached_optimized_graphs[hash] = value; -} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphCostResult const &value) {} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphOptimizeResult const &value) {} - -template <> -void GraphSearchHelper::try_cache_result( - size_t hash, GraphOptimizeResultWithMemory const &value) {} - -/** - * @brief Get the cost/result of PCG if sequentially split it. - * - * @details This function is to combine the search results from DP sub-problems. - * The sub-problems are solved by generic_sequence_optimize(). - */ -template -T GraphSearchHelper::execute_sequence_split( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape) { - return sequence_cost( - this->generic_sequence_optimize( - pre_graph.get(), bottleneck, bottleneck_output_shape, input_shape), - this->generic_sequence_optimize( - post_graph.get(), sink_node, output_shape, bottleneck_output_shape)); -} - -/** - * @brief Experimental. Consider memory usage when spliting the PCG during the - * DP search. This should be merged with execute_sequence_split(). - */ -template -T GraphSearchHelper::execute_sequence_split_with_memory( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape) { - return sequence_cost( - this->generic_sequence_optimize_with_memory( - pre_graph.get(), bottleneck, bottleneck_output_shape, input_shape), - this->generic_sequence_optimize_with_memory( - post_graph.get(), sink_node, output_shape, bottleneck_output_shape)); -} - -/** - * @brief Top level DP search procedure for Unity. - */ -template -T GraphSearchHelper::generic_sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - /* int starting_depth = this->logger->get_depth(); */ - - TAG_ENTER(this->logger); - - size_t hash = gs_dp_state_hash(graph, sink_node, output_shape, input_shape); - tl::optional cached = this->try_get_cost_from_cache(hash); - if (cached.has_value()) { - this->logger->spew() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->spew() << "Retrieved value from cache: " << cached.value(); - } - - /* this->logger->check_same_as(starting_depth); */ - return cached.value(); - } - - this->logger->debug() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - T return_value; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->debug() << "Graph hash: " << std::setw(32) - << std::setfill('0') << graph->hash(); - if (input_shape.has_value()) { - this->logger->debug() << "Input shape: " << input_shape.value(); - } else { - this->logger->debug() << "Input shape: "; - } - if (output_shape.has_value()) { - this->logger->debug() << "Output shape: " << output_shape.value(); - } else { - this->logger->debug() << "Output shape: "; - } - - tl::optional bottleneck = - this->find_split_node(graph, this->config.base_optimize_threshold); - - if (!bottleneck.has_value()) { - this->logger->debug() << "Applying base case"; - Graph to_optimize(*graph); - if (input_shape.has_value()) { - Node input_node = - this->model->get_or_create_input_node(input_shape.value()); - Node noop_node = - this->model->get_or_create_noop_node(input_node.ptr->outputs[0]); - Graph input_graph(this->model); - Edge e(input_node, noop_node, 0, 0); - input_graph.add_edge(e); - - Node old_source_node = graph->find_source_node(); - ParallelTensorShape old_source_output_shape = - old_source_node.ptr->outputs[0]->get_shape(); - input_graph.reshape_output_tensor(old_source_output_shape); - - Node new_sink_node = input_graph.find_sink_node(); - assert(new_sink_node.ptr->numOutputs == 1); - assert(new_sink_node.ptr->outputs[0]->get_shape() == - old_source_output_shape); - - to_optimize.replace_subgraph({old_source_node}, input_graph); - } - SimplificationSettings settings; - if (output_shape.has_value()) { - to_optimize.reshape_output_tensor(output_shape.value()); - Node sink_node = to_optimize.find_sink_node(); - Node noop_node = - this->model->get_or_create_noop_node(sink_node.ptr->outputs[0]); - to_optimize.add_edge(sink_node, noop_node, 0, 0); - } else { - settings.remove_trailing_parallel_ops = true; - } - settings.simplify_parallel_ops = true; - std::unique_ptr optimized = - this->base_optimize(&to_optimize, settings); - return_value = get_optimal_cost( - std::move(optimized)); // optimized->generic_optimal_cost(); - } else { - this->logger->debug() << "Applying recursive case on bottleneck " - << bottleneck.value().guid; - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(bottleneck.value()); - - MachineResource resources(this->model->config); - std::vector valid_machine_views = - this->model->search->get_valid_machine_views(bottleneck.value().ptr, - resources); - - float best_cost = std::numeric_limits::infinity(); - tl::optional best_shape = tl::nullopt; - { - TAG_ENTER(this->logger); - for (ParallelTensorShape const &bottleneck_output_shape : - this->possible_split_output_tensor_shapes(bottleneck.value())) { - this->logger->debug() - << "Considering boundary shape " << bottleneck_output_shape; - float current_cost; - { - TAG_ENTER(this->logger); - // TODO @lockshaw we really should create the merged graph here - // since it's possible though unlikely for there to be hidden - // transfer costs between modules due to device assignment changes - // across the boundaries - - // We wait to add the communication nodes between boundaries so we - // don't accidentally split on them and keep processing the pure - // computation graph The bottleneck node is kept in the postgraph - // purely as a placeholder and will be replaced with an Input/NoOp - // sequence before any rewrites are actually performed - // this->logger->debug() << "Finding cost of pre_graph (" << - // bottleneck_output_shape << ")"; float pre_cost = - // this->generic_sequence_optimize(pre_graph.get(), - // bottleneck.value(), bottleneck_output_shape, input_shape); - // this->logger->debug() << "Cost of pre_graph (" << - // bottleneck_output_shape << "): " << pre_cost; - // this->logger->debug() << "Finding cost of post_graph (" << - // bottleneck_output_shape << ")"; float post_cost = - // this->generic_sequence_optimize(post_graph.get(), - // sink_node, output_shape, bottleneck_output_shape); - // this->logger->debug() << "Cost of post_graph (" << - // bottleneck_output_shape << "): " << post_cost; float current_cost - // = pre_cost + post_cost; - current_cost = - this->execute_sequence_split(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - bottleneck_output_shape); - - if (current_cost < best_cost) { - best_cost = current_cost; - best_shape = bottleneck_output_shape; - } - } - this->logger->debug() << "Boundary shape " << bottleneck_output_shape - << " has cost: " << current_cost; - } - } - - if (best_shape.has_value()) { - this->logger->debug() - << "Best intermediate shape found: " << best_shape.value(); - } else { - this->logger->debug() << "No valid intermediate shapes found"; - } - - if (best_cost != std::numeric_limits::infinity()) { - return_value = this->execute_sequence_split(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - best_shape.value()); - } - } - - this->try_cache_result(hash, return_value); - } - return return_value; -} - -/** - * @brief Top level DP search procedure for Unity with the consideration of - * memory usage. - * - * @tparam T Returned type - * @param graph Pre-optimization PCG - * @param sink_node Sink node of the PCG - * @param output_shape ??? - * @param input_shape ??? - * @return T Optimal result - */ -template -T GraphSearchHelper::generic_sequence_optimize_with_memory( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape) { - TAG_ENTER(this->logger); - - // Try to find the result from cache first. But this will only get the cached - // result if the returned type is float. The float number means the best run - // time cost with only machine quantity (without distinguishing machine - // identities). - size_t hash = gs_dp_state_hash(graph, sink_node, output_shape, input_shape); - tl::optional cached = this->try_get_cost_from_cache(hash); - if (cached.has_value()) { - this->logger->spew() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - { - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->spew() << "Retrieved value from cache: " << cached.value(); - } - return cached.value(); - } - - // Couldn't find the result from cache. Try to optimize and get one. - this->logger->debug() << "Optimizing graph with " << graph->inEdges.size() - << " nodes"; - T return_value; - { - // Print out debug information - TAG_ENTER(this->logger); - this->logger->spew() << "Nodes: "; - { - TAG_ENTER(this->logger); - graph_log_representation(graph, *this->logger); - } - this->logger->debug() << "Graph hash: " << std::setw(32) - << std::setfill('0') << graph->hash(); - if (input_shape.has_value()) { - this->logger->debug() << "Input shape: " << input_shape.value(); - } else { - this->logger->debug() << "Input shape: "; - } - if (output_shape.has_value()) { - this->logger->debug() << "Output shape: " << output_shape.value(); - } else { - this->logger->debug() << "Output shape: "; - } - - // Find the node to sequentially split the PCG. - // Decide if the search reaches the base condition by this. - tl::optional bottleneck = - this->find_split_node(graph, this->config.base_optimize_threshold); - - if (!bottleneck.has_value()) { - this->logger->debug() << "Applying base case"; - - // Construct the PCG to optimize based on input_shape and output_shape - // information. - Graph to_optimize(*graph); - if (input_shape.has_value()) { - Node input_node = - this->model->get_or_create_input_node(input_shape.value()); - Node noop_node = - this->model->get_or_create_noop_node(input_node.ptr->outputs[0]); - Graph input_graph(this->model); - Edge e(input_node, noop_node, 0, 0); - input_graph.add_edge(e); - - Node old_source_node = graph->find_source_node(); - ParallelTensorShape old_source_output_shape = - old_source_node.ptr->outputs[0]->get_shape(); - input_graph.reshape_output_tensor(old_source_output_shape); - - Node new_sink_node = input_graph.find_sink_node(); - assert(new_sink_node.ptr->numOutputs == 1); - assert(new_sink_node.ptr->outputs[0]->get_shape() == - old_source_output_shape); - - to_optimize.replace_subgraph({old_source_node}, input_graph); - } - SimplificationSettings settings; - if (output_shape.has_value()) { - to_optimize.reshape_output_tensor(output_shape.value()); - Node sink_node = to_optimize.find_sink_node(); - Node noop_node = - this->model->get_or_create_noop_node(sink_node.ptr->outputs[0]); - to_optimize.add_edge(sink_node, noop_node, 0, 0); - } else { - settings.remove_trailing_parallel_ops = true; - } - settings.simplify_parallel_ops = true; - - // Call base optimization to perform graph substitution. - std::unique_ptr optimized = - this->base_optimize_with_memory(&to_optimize, settings); - return_value = get_optimal_cost(std::move(optimized)); - } else { - this->logger->debug() << "Applying recursive case on bottleneck " - << bottleneck.value().guid; - - std::unique_ptr pre_graph, post_graph; - std::tie(pre_graph, post_graph) = - graph->split_at_node(bottleneck.value()); - - MachineResource resources(this->model->config); - std::vector valid_machine_views = - this->model->search->get_valid_machine_views(bottleneck.value().ptr, - resources); - - // Try to find the best cost and corresponding best bottleneck shape. - // This search process is based on the float version of - // execute_sequence_split_with_memory(). - float best_cost = std::numeric_limits::infinity(); - tl::optional best_shape = tl::nullopt; - { - TAG_ENTER(this->logger); - for (auto const &bottleneck_output_shape : - this->possible_split_output_tensor_shapes(bottleneck.value())) { - this->logger->debug() - << "Considering boundary shape " << bottleneck_output_shape; - float current_cost; - { - TAG_ENTER(this->logger); - // Get the cost from execute_sequence_split_with_memory by - // only changing bottleneck_output_shape. - current_cost = this->execute_sequence_split_with_memory( - pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - bottleneck_output_shape); - - if (current_cost < best_cost) { - best_cost = current_cost; - best_shape = bottleneck_output_shape; - } - } - this->logger->debug() << "Boundary shape " << bottleneck_output_shape - << " has cost: " << current_cost; - } - } - - if (best_shape.has_value()) { - this->logger->debug() - << "Best intermediate shape found: " << best_shape.value(); - } else { - this->logger->debug() << "No valid intermediate shapes found"; - } - - // ? What if best_cost is infinity ? - if (best_cost != std::numeric_limits::infinity()) { - // Get the return value of correct type with previously found - // best_shape. - return_value = - this->execute_sequence_split_with_memory(pre_graph, - post_graph, - output_shape, - input_shape, - sink_node, - bottleneck.value(), - best_shape.value()); - } - } - // Try to cache the float result - this->try_cache_result(hash, return_value); - } - return return_value; -} - -std::vector - GraphSearchHelper::possible_split_output_tensor_shapes( - Node const &source_node) const { - TAG_ENTER(this->logger); - - this->logger->debug() << "Finding possible output tensor shapes for node " - << source_node.guid; - assert(source_node.ptr->numOutputs == 1); - ParallelTensor output_tensor = source_node.ptr->outputs[0]; - for (int i = 0; i < output_tensor->num_dims; i++) { - assert(output_tensor->dims[i].degree == 1); - } - - std::vector without_replicas; - - int num_devices = this->config.numNodes * this->config.workersPerNode; - int degrees[MAX_TENSOR_DIM]; - std::fill_n(degrees, MAX_TENSOR_DIM, 1); - - ParallelTensorShape base_shape; - base_shape.num_dims = output_tensor->num_dims; - for (int i = 0; i < output_tensor->num_dims; i++) { - base_shape.dims[i].degree = 1; - base_shape.dims[i].size = output_tensor->dims[i].size; - } - without_replicas.push_back(base_shape); - - { - TAG_ENTER(this->logger); - while (true) { - bool is_done = true; - for (int i = 0; i < output_tensor->num_dims; i++) { - degrees[i] *= 2; - if (degrees[i] > num_devices) { - degrees[i] = 1; - } else { - is_done = false; - break; - } - } - std::ostringstream oss; - for (int i = 0; i < output_tensor->num_dims; i++) { - oss << degrees[i] << " "; - } - this->logger->spew() << "Considering: " << oss.str(); - if (is_done) { - break; - } - - bool is_valid = true; - int total_degree = 1; - ParallelTensorShape shape; - shape.num_dims = output_tensor->num_dims; - for (int i = 0; i < output_tensor->num_dims; i++) { - total_degree *= degrees[i]; - shape.dims[i].degree = degrees[i]; - shape.dims[i].size = output_tensor->dims[i].size; - if (shape.dims[i].size % shape.dims[i].degree != 0) { - is_valid = false; - } - } - if (total_degree <= num_devices && is_valid) { - without_replicas.push_back(shape); - } - } - } - - this->logger->debug() << "Found " << without_replicas.size() - << " possible tensor output shapes without replicas"; - this->logger->debug() << "They are:"; - { - TAG_ENTER(this->logger); - for (auto const &shape : without_replicas) { - this->logger->debug() << shape; - } - } - return without_replicas; -} - -void GraphSearchHelper::subgraph_optimize(Graph *subgraph) {} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - return this->create_conv2d(input, matchOpX); -} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - OpX *pool = new OpX(OP_POOL2D, 1, 1, input); - pool->matchOpX = matchOpX; - return pool; -} - -template <> -OpX *GraphXfer::create_opx(TensorX const &input, OpX const *matchOpX) { - OpX *flat = new OpX(OP_FLAT, 1, 1, input); - flat->matchOpX = matchOpX; - return flat; -} - -GraphXfer *create_partition_linear_combine(FFModel *model, - int num_dims, - int num_parts, - ActiMode activation, - bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *linear1 = subst->create_linear( - input, NULL /*matchOpX*/, num_dims, activation, use_bias); - OpX *repartition = subst->create_repartition(input, num_dims - 2, num_parts); - OpX *linear2 = subst->create_linear(repartition->outputs[0], - linear1 /*matchOpX*/, - num_dims, - activation, - use_bias); - OpX *combine = - subst->create_combine(linear2->outputs[0], num_dims - 2, num_parts); - subst->map_output(linear1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(linear1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(linear2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_conv2d_combine(FFModel *model, - int num_dims, - int num_parts) { - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *conv1 = subst->create_conv2d(input, NULL /*matchOpX*/); - OpX *repartition = subst->create_repartition(input, num_dims - 2, num_parts); - OpX *conv2 = - subst->create_conv2d(repartition->outputs[0], conv1 /*matchOpX*/); - OpX *combine = - subst->create_combine(conv2->outputs[0], num_dims - 2, num_parts); - subst->map_output(conv1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(conv1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(conv2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_conv2d_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_inception(FFModel *model, - int num_convs, - int num_dims, - int num_parts) { - // 3 convs and 1 pool2d - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *src_combine = subst->create_combine(input, num_dims - 2, num_parts); - subst->srcOps.push_back(src_combine); - std::vector src_convs; - for (int i = 0; i < num_convs; i++) { - OpX *conv = - subst->create_conv2d(src_combine->outputs[0], NULL /*matchOpX*/); - src_convs.push_back(conv); - subst->srcOps.push_back(conv); - } - OpX *src_pool = - subst->create_pool2d(src_combine->outputs[0], NULL /*matchOpX*/); - subst->srcOps.push_back(src_pool); - // dst ops - std::vector dst_convs; - for (int i = 0; i < num_convs; i++) { - OpX *conv = subst->create_conv2d(input, src_convs[i] /*matchOpX*/); - OpX *comb = - subst->create_combine(conv->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(conv); - subst->dstOps.push_back(comb); - subst->map_output(src_convs[i]->outputs[0], comb->outputs[0]); - } - OpX *dst_pool = subst->create_pool2d(input, src_pool /*matchOpX*/); - OpX *dst_comb = - subst->create_combine(dst_pool->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(dst_pool); - subst->dstOps.push_back(dst_comb); - subst->map_output(src_pool->outputs[0], dst_comb->outputs[0]); - subst->name = "create_combine_inceptionA"; - return subst; -} - -GraphXfer *create_combine_concat(FFModel *model, - int num_inputs, - int num_dims, - int num_parts) { - // assert 5D - assert(num_dims == 5); - GraphXfer *subst = new GraphXfer(model); - std::vector inputs, concat_inputs; - std::vector combines; - for (int i = 0; i < num_inputs; i++) { - inputs.push_back(subst->new_tensor()); - combines.push_back( - subst->create_combine(inputs[i], num_dims - 2, num_parts)); - concat_inputs.push_back(combines[i]->outputs[0]); - subst->srcOps.push_back(combines[i]); - } - OpX *concat1 = subst->create_concat( - concat_inputs.data(), num_inputs, NULL /*matchOpX*/, 2); - subst->srcOps.push_back(concat1); - OpX *concat2 = - subst->create_concat(inputs.data(), num_inputs, concat1 /*matchOpX*/, 2); - OpX *combine = - subst->create_combine(concat2->outputs[0], num_dims - 2, num_parts); - subst->dstOps.push_back(concat2); - subst->dstOps.push_back(combine); - subst->map_output(concat1->outputs[0], combine->outputs[0]); - subst->name = "create_combine_concat"; - return subst; -} - -GraphXfer *create_partition_attention_combine(FFModel *model, - int num_heads, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *attn1 = subst->create_attention( - input, input, input, NULL /*matchOpX*/, num_heads); - OpX *repart = subst->create_repartition(input, 2, num_parts); - OpX *attn2 = subst->create_attention(repart->outputs[0], - repart->outputs[0], - repart->outputs[0], - attn1 /*matchOpX*/, - num_heads); - OpX *combine = subst->create_combine(attn2->outputs[0], 2, num_parts); - subst->map_output(attn1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(attn1); - subst->dstOps.push_back(repart); - subst->dstOps.push_back(attn2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_attention_combine[" - << "num_heads=" << num_heads << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_replicate_attention_reduce(FFModel *model, - int num_heads, - int num_parts) { - assert(num_heads % num_parts == 0); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *attn1 = subst->create_attention( - input, input, input, NULL /*matchOpX*/, num_heads); - OpX *repl = subst->create_replicate(input, 3, num_parts); - OpX *attn2 = subst->create_attention(repl->outputs[0], - repl->outputs[0], - repl->outputs[0], - attn1 /*matchOpX*/, - num_heads / num_parts); - OpX *reduce = subst->create_reduction(attn2->outputs[0], 3, num_parts); - subst->map_output(attn1->outputs[0], reduce->outputs[0]); - subst->srcOps.push_back(attn1); - subst->dstOps.push_back(repl); - subst->dstOps.push_back(attn2); - subst->dstOps.push_back(reduce); - - std::ostringstream oss; - oss << "replicate_attention_reduce[" - << "num_heads=" << num_heads << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_replicate_linear_combine(FFModel *model, - int num_dims, - int num_parts, - ActiMode activation, - bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *linear1 = subst->create_linear( - input, NULL /*matchOpX*/, num_dims, activation, use_bias); - OpX *replicate = subst->create_replicate(input, num_dims - 1, num_parts); - OpX *linear2 = subst->create_linear(replicate->outputs[0], - linear1 /*matchOpX*/, - num_dims, - activation, - use_bias); - OpX *combine = subst->create_combine(linear2->outputs[0], 0, num_parts); - subst->map_output(linear1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(linear1); - subst->dstOps.push_back(replicate); - subst->dstOps.push_back(linear2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "replicate_linear_combine[" - << "num_dims=" << num_dims << ",num_parts=" << num_parts - << ",activation=" << activation << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_add_combine(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input1 = subst->new_tensor(); - TensorX input2 = subst->new_tensor(); - OpX *add1 = subst->create_element_binary(input1, input2, OP_EW_ADD); - OpX *repartition1 = - subst->create_repartition(input1, parallel_dim, num_parts); - OpX *repartition2 = - subst->create_repartition(input2, parallel_dim, num_parts); - OpX *add2 = subst->create_element_binary( - repartition1->outputs[0], repartition2->outputs[0], OP_EW_ADD); - OpX *combine = - subst->create_combine(add2->outputs[0], parallel_dim, num_parts); - subst->map_output(add1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(add1); - subst->dstOps.push_back(repartition1); - subst->dstOps.push_back(repartition2); - subst->dstOps.push_back(add2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_add_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_add_partition(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input1 = subst->new_tensor(); - TensorX input2 = subst->new_tensor(); - OpX *add1 = subst->create_element_binary(input1, input2, OP_EW_ADD); - - OpX *combine1 = subst->create_combine(input1, parallel_dim, num_parts); - OpX *combine2 = subst->create_combine(input2, parallel_dim, num_parts); - OpX *add2 = subst->create_element_binary( - combine1->outputs[0], combine2->outputs[0], OP_EW_ADD); - OpX *repartition = - subst->create_repartition(add2->outputs[0], parallel_dim, num_parts); - subst->map_output(add1->outputs[0], repartition->outputs[0]); - subst->srcOps.push_back(add1); - subst->dstOps.push_back(combine1); - subst->dstOps.push_back(combine2); - subst->dstOps.push_back(add2); - subst->dstOps.push_back(repartition); - - std::ostringstream oss; - oss << "combine_add_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_relu_combine(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *relu1 = subst->create_element_unary(input, OP_RELU); - - OpX *partition = subst->create_repartition(input, parallel_dim, num_parts); - OpX *relu2 = subst->create_element_unary(partition->outputs[0], OP_RELU); - OpX *combine = - subst->create_combine(relu2->outputs[0], parallel_dim, num_parts); - - subst->map_output(relu1->outputs[0], combine->outputs[0]); - - subst->srcOps.push_back(relu1); - - subst->dstOps.push_back(partition); - subst->dstOps.push_back(relu2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_relu_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_relu_partition(FFModel *model, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *relu1 = subst->create_element_unary(input, OP_RELU); - - OpX *combine = subst->create_combine(input, parallel_dim, num_parts); - OpX *relu2 = subst->create_element_unary(combine->outputs[0], OP_RELU); - OpX *partition = - subst->create_repartition(relu2->outputs[0], parallel_dim, num_parts); - - subst->map_output(relu1->outputs[0], partition->outputs[0]); - - subst->srcOps.push_back(relu1); - - subst->dstOps.push_back(combine); - subst->dstOps.push_back(relu2); - subst->dstOps.push_back(partition); - - std::ostringstream oss; - oss << "combine_relu_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_concat_combine(FFModel *model, - int num_inputs, - int concat_dim, - int parallel_dim, - int num_parts) { - GraphXfer *subst = new GraphXfer(model); - assert(num_inputs <= MAX_NUM_INPUTS); - TensorX inputs[MAX_NUM_INPUTS]; - for (int i = 0; i < num_inputs; i++) { - inputs[i] = subst->new_tensor(); - } - OpX *concat = - subst->create_concat(inputs, num_inputs, NULL /*matchOpX*/, concat_dim); - subst->srcOps.push_back(concat); - TensorX new_inputs[MAX_NUM_INPUTS]; - for (int i = 0; i < num_inputs; i++) { - OpX *repartition = - subst->create_repartition(inputs[i], parallel_dim, num_parts); - new_inputs[i] = repartition->outputs[0]; - subst->dstOps.push_back(repartition); - } - OpX *concat2 = subst->create_concat( - new_inputs, num_inputs, concat /*matchOpX*/, concat_dim); - subst->dstOps.push_back(concat2); - OpX *combine = - subst->create_combine(concat2->outputs[0], parallel_dim, num_parts); - subst->dstOps.push_back(combine); - subst->map_output(concat->outputs[0], combine->outputs[0]); - - std::ostringstream oss; - oss << "partition_concat_combine[" - << "num_inputs=" << num_inputs << ",concat_dim=" << concat_dim - << ",parallel_dim=" << parallel_dim << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_partition_softmax_combine(FFModel *model, - int softmax_dim, - int parallel_dim, - int num_parts) { - assert(parallel_dim != softmax_dim); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *softmax1 = subst->create_softmax(input, softmax_dim); - OpX *repartition = subst->create_repartition(input, parallel_dim, num_parts); - OpX *softmax2 = subst->create_softmax(repartition->outputs[0], softmax_dim); - OpX *combine = - subst->create_combine(softmax2->outputs[0], parallel_dim, num_parts); - subst->map_output(softmax1->outputs[0], combine->outputs[0]); - subst->srcOps.push_back(softmax1); - subst->dstOps.push_back(repartition); - subst->dstOps.push_back(softmax2); - subst->dstOps.push_back(combine); - - std::ostringstream oss; - oss << "partition_softmax_combine[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *create_combine_softmax_partition(FFModel *model, - int softmax_dim, - int parallel_dim, - int num_parts) { - assert(parallel_dim != softmax_dim); - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *softmax1 = subst->create_softmax(input, softmax_dim); - OpX *combine = subst->create_combine(input, parallel_dim, num_parts); - OpX *softmax2 = subst->create_softmax(combine->outputs[0], softmax_dim); - OpX *repartition = - subst->create_repartition(softmax2->outputs[0], parallel_dim, num_parts); - subst->map_output(softmax1->outputs[0], repartition->outputs[0]); - subst->srcOps.push_back(softmax1); - subst->dstOps.push_back(combine); - subst->dstOps.push_back(softmax2); - subst->dstOps.push_back(repartition); - - std::ostringstream oss; - oss << "combine_softmax_partition[" - << "softmax_dim=" << softmax_dim << ",parallel_dim=" << parallel_dim - << ",num_parts=" << num_parts << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *leading_relu_branch_combine(FFModel *model, - int parallel_dim, - int num_parts, - int num_combines) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_partition = - subst->create_repartition(input, parallel_dim, num_parts); - std::vector old_combines; - for (int i = 0; i < num_combines; i++) { - old_combines.push_back( - subst->create_combine(input, parallel_dim, num_parts)); - } - - OpX *new_partition = - subst->create_repartition(input, parallel_dim, num_parts); - std::vector new_noops; - for (int i = 0; i < num_combines; i++) { - new_noops.push_back(subst->create_noop(input)); - } - - subst->map_output(old_partition->outputs[0], new_partition->outputs[0]); - for (int i = 0; i < num_combines; i++) { - subst->map_output(old_combines[i]->outputs[0], new_noops[i]->outputs[0]); - } - - subst->srcOps.push_back(old_partition); - subst->srcOps.insert( - subst->srcOps.end(), old_combines.begin(), old_combines.end()); - subst->dstOps.push_back(new_partition); - subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); - - std::ostringstream oss; - oss << "leading_relu_branch_combine[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_combines=" << num_combines << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer *leading_relu_branch_partition(FFModel *model, - int parallel_dim, - int num_parts, - int num_partitions) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_combine = subst->create_combine(input, parallel_dim, num_parts); - std::vector old_partitions; - for (int i = 0; i < num_partitions; i++) { - old_partitions.push_back( - subst->create_repartition(input, parallel_dim, num_parts)); - } - - OpX *new_combine = subst->create_combine(input, parallel_dim, num_parts); - std::vector new_noops; - for (int i = 0; i < num_partitions; i++) { - new_noops.push_back(subst->create_noop(input)); - } - - subst->map_output(old_combine->outputs[0], new_combine->outputs[0]); - for (int i = 0; i < num_partitions; i++) { - subst->map_output(old_partitions[i]->outputs[0], new_noops[i]->outputs[0]); - } - - subst->srcOps.push_back(old_combine); - subst->srcOps.insert( - subst->srcOps.end(), old_partitions.begin(), old_partitions.end()); - subst->dstOps.push_back(new_combine); - subst->dstOps.insert(subst->dstOps.end(), new_noops.begin(), new_noops.end()); - - std::ostringstream oss; - oss << "leading_relu_branch_partition[" - << "parallel_dim=" << parallel_dim << ",num_parts=" << num_parts - << ",num_partitions=" << num_partitions << "]"; - subst->name = oss.str(); - - return subst; -} - -GraphXfer * - create_linear_relu_merge(FFModel *model, int num_dims, bool use_bias) { - GraphXfer *subst = new GraphXfer(model); - TensorX input = subst->new_tensor(); - OpX *old_linear = - subst->create_linear(input, nullptr, num_dims, AC_MODE_NONE, use_bias); - OpX *old_relu = subst->create_relu(old_linear->outputs[0]); - - OpX *new_linear = - subst->create_linear(input, old_linear, num_dims, AC_MODE_RELU, use_bias); - - subst->map_output(old_relu->outputs[0], new_linear->outputs[0]); - subst->srcOps.push_back(old_linear); - subst->srcOps.push_back(old_relu); - subst->dstOps.push_back(new_linear); - - std::ostringstream oss; - oss << "linear_relu_merge[" - << "num_dims=" << num_dims << ",use_bias=" << use_bias << "]"; - subst->name = oss.str(); - - return subst; -} - -} // namespace ffc - -using PCG::Edge; -using PCG::Graph; -using PCG::Node; - -/** - * @brief Optimize the graph stored in FFModel. - * - * @param[in] budget The search budget - * @param[in] only_data_parallel True if only doing data parallel training - * @param[out] best_graph The searched best graph - * @param[out] optimal_views The corresponding machine view of the best_graph - * @param[in] perform_memory_search True if we want to consider memory during - * the search - * @param[in] new_config Memory optimization config to use if this is a memory - * search - * @param[out] search_result The performance result of this search - */ -void FFModel::graph_optimize( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views, - bool perform_memory_search, - MemoryOptimConfig new_config, - MemorySearchResult &search_result) { - if (perform_memory_search) { - this->graph_search->update_mem_optim_config(new_config); - this->graph_search->graph_optimize_with_memory( - budget, only_data_parallel, best_graph, optimal_views, search_result); - } else { - this->graph_search->graph_optimize( - budget, only_data_parallel, best_graph, optimal_views); - } -} - -bool FFModel::convert_graph_to_operators( - Graph const *graph, - std::unordered_map const &optimal_views) { - // Clear operators - operators.clear(); - std::unordered_map todos; - std::unordered_map node_to_op; - std::vector queue; - for (auto const &it : graph->inEdges) { - auto const &inList = it.second; - if (inList.size() == 0) { - queue.push_back(it.first); - } else { - todos[it.first] = (int)inList.size(); - } - } - size_t index = 0; - while (index < queue.size()) { - Node node = queue[index++]; - assert(node.ptr != NULL); - auto const &inList = graph->inEdges.find(node)->second; - ParallelTensor inputs[MAX_NUM_INPUTS]; - int num_inputs = 0; - for (auto const &e : inList) { - inputs[e.dstIdx] = node_to_op[e.srcOp]->outputs[e.srcIdx]; - assert(e.dstIdx < (int)inList.size()); - num_inputs++; - } - Op *new_op = NULL; - switch (node.ptr->op_type) { - case OP_INPUT: { - NoOp *noop = (NoOp *)node.ptr; - new_op = new NoOp( - *this, OP_INPUT, noop->input_tensor_guid, node.ptr->outputs[0]); - break; - } - case OP_CONCAT: { - Concat *concat = (Concat *)node.ptr; - new_op = new Concat( - *this, (int)inList.size(), inputs, concat->legion_axis, NULL); - break; - } - case OP_AGGREGATE: { - Aggregate *aggr = (Aggregate *)node.ptr; - new_op = new Aggregate(*this, inputs, aggr->n, aggr->lambda_bal, NULL); - break; - } - case OP_SPLIT: { - Split *split = (Split *)node.ptr; - std::vector splits; - for (int i = 0; i < split->numOutputs; i++) { - splits.push_back(split->outputs[i]->dims[split->legion_axis].size); - } - new_op = new Split(*this, inputs[0], splits, split->legion_axis, NULL); - break; - } - case OP_EMBEDDING: { - new_op = new Embedding(*this, *(Embedding *)node.ptr, inputs[0], true); - break; - } - case OP_EW_ADD: - case OP_EW_SUB: - case OP_EW_MUL: - case OP_EW_MAX: - case OP_EW_MIN: { - assert(inList.size() == 2); - ElementBinary *eb = (ElementBinary *)node.ptr; - new_op = new ElementBinary( - *this, eb->op_type, inputs[0], inputs[1], eb->inplace_a, NULL); - break; - } - case OP_POOL2D: { - new_op = new Pool2D(*this, *(Pool2D *)node.ptr, inputs[0]); - break; - } - case OP_CONV2D: { - new_op = new Conv2D(*this, *(Conv2D *)node.ptr, inputs[0], true); - break; - } - case OP_DROPOUT: { - new_op = new Dropout(*this, *(Dropout *)node.ptr, inputs[0]); - break; - } - case OP_LINEAR: { - new_op = new Linear(*this, *(Linear *)node.ptr, inputs[0], true); - break; - } - case OP_MULTIHEAD_ATTENTION: { - assert(inList.size() == 3); - MultiHeadAttention *attn = (MultiHeadAttention *)node.ptr; - new_op = new MultiHeadAttention( - *this, *attn, inputs[0], inputs[1], inputs[2], true); - break; - break; - } - case OP_SOFTMAX: { - assert(inList.size() == 1); - Softmax *softmax = (Softmax *)node.ptr; - new_op = new Softmax(*this, inputs[0], softmax->dim, NULL); - break; - } - case OP_COMBINE: { - assert(inList.size() == 1); - Combine *combine = (Combine *)node.ptr; - new_op = new Combine( - *this, inputs[0], combine->combine_dim, combine->combine_degree); - break; - } - case OP_REPARTITION: { - assert(inList.size() == 1); - Repartition *repart = (Repartition *)node.ptr; - new_op = new Repartition(*this, - inputs[0], - repart->repartition_dim, - repart->repartition_degree); - break; - } - case OP_REPLICATE: { - assert(inList.size() == 1); - Replicate *replicate = (Replicate *)node.ptr; - new_op = new Replicate(*this, - inputs[0], - replicate->replicate_dim, - replicate->replicate_degree); - break; - } - case OP_REDUCTION: { - assert(inList.size() == 1); - Reduction *reduction = (Reduction *)node.ptr; - new_op = new Reduction(*this, - inputs[0], - reduction->reduction_dim, - reduction->reduction_degree); - break; - } - case OP_FUSED_PARALLEL: { - assert(inList.size() == 1); - FusedParallelOp *fused = (FusedParallelOp *)node.ptr; - std::vector parallel_ops; - for (int i = 0; i < fused->num_parallel_ops; i++) { - parallel_ops.push_back(fused->parallel_ops[i]); - } - new_op = new FusedParallelOp(*this, inputs[0], parallel_ops); - break; - } - default: { - new_op = node.ptr->materialize(*this, inputs, num_inputs); - break; - } - } - // Set machine view for the output tensors of this operator - assert(optimal_views.find(node) != optimal_views.end()); - MachineView view = optimal_views.find(node)->second; - for (int i = 0; i < new_op->numOutputs; i++) { - new_op->outputs[i]->machine_view = view; - } - // Set machine view for the weight tensors of this operator - for (int i = 0; i < new_op->numWeights; i++) { - new_op->weights[i]->machine_view = view; - } - node_to_op[node] = new_op; - operators.push_back(new_op); - // Decrease the todos - auto const &outList = graph->outEdges.find(node)->second; - for (auto const &it : outList) { - todos[it.dstOp] -= 1; - if (todos[it.dstOp] == 0) { - queue.push_back(it.dstOp); - } - } - } - assert(queue.size() == graph->inEdges.size()); - // Remove the final parallel operators - while (operators[operators.size() - 1]->is_parallel_op()) { - Op *op = operators[operators.size() - 1]; - if (op->op_type == OP_REDUCTION) { - break; - } - if (op->op_type == OP_FUSED_PARALLEL) { - FusedParallelOp *fused_op = (FusedParallelOp *)op; - bool has_reduction = false; - for (int i = 0; i < fused_op->num_parallel_ops; i++) { - if (fused_op->parallel_ops[i].op_type == OP_REDUCTION) { - has_reduction = true; - } - } - if (has_reduction) { - break; - } - } - operators.pop_back(); - } - return true; -} - -} // namespace FlexFlow diff --git a/lib/compiler/src/old/substitution.h b/lib/compiler/src/old/substitution.h deleted file mode 100644 index 95a59e952c..0000000000 --- a/lib/compiler/src/old/substitution.h +++ /dev/null @@ -1,309 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef _FLEXFLOW_SUBSTITUTION_H_ -#define _FLEXFLOW_SUBSTITUTION_H_ -#include "graph.h" -#include "substitutions/substitutions.h" -#include "tl/optional.hpp" -#include "utils/recursive_logger.h" -#include -#include - -namespace FlexFlow { -namespace ffc { - -/* struct PMConstraint { */ -/* PMConstraint(Compare comp, PMParameter para, int value); */ -/* Compare comp; */ -/* PMParameter para; */ -/* int value; */ -/* }; */ - -struct TNConstraint { - TNConstraint(Compare comp, TNParameter para, DIMParameter dim, int value); - TNConstraint(Compare comp, - TNParameter para1, - DIMParameter dim1, - TNParameter para2, - DIMParameter dim2); - bool singlePara; - Compare comp; - TNParameter para1, para2; - DIMParameter dim1, dim2; - int value; -}; - -/* class Op; */ -/* class OpX; */ -/* class GraphXfer; */ - -struct TensorX { - static const TensorX NO_TX; - TensorX(void) : op(NULL), idx(0) {} - TensorX(OpX *_op, int _idx) : op(_op), idx(_idx) {} - tl::optional - to_tensor(GraphXfer const *xfer) const; - OpX *op; - int idx; - - bool operator==(TensorX const &other) const; - bool operator!=(TensorX const &other) const; -}; - -struct TensorXCompare { - bool operator()(TensorX const &a, TensorX const &b) const { - if (a.op != b.op) { - return a.op < b.op; - } - return a.idx < b.idx; - }; -}; - -/* class OpX { */ -/* public: */ -/* OpX(OperatorType type, */ -/* int numInputs, */ -/* int numOutputs, */ -/* TensorX const &input1 = TensorX::NO_TX, */ -/* TensorX const &input2 = TensorX::NO_TX, */ -/* TensorX const &input3 = TensorX::NO_TX, */ -/* TensorX const &input4 = TensorX::NO_TX); */ -/* OpX(OperatorType type, */ -/* int num_inputs, */ -/* int num_outputs, */ -/* TensorX const *inputs); */ -/* bool add_pm_constraint(Compare, PMParameter para, int value); */ -/* bool add_input_constraint(Compare, TNParameter, DIMParameter, int); */ -/* bool add_input_constraint( */ -/* Compare, TNParameter, DIMParameter, TNParameter, DIMParameter); */ -/* bool get_pm_constraint(PMParameter para, int &value) const; */ - -/* public: */ -/* OperatorType type; */ -/* Node mapOp; */ -/* OpX const *matchOpX; */ -/* std::vector inputs, weights, outputs; */ -/* std::vector pmConstraints; */ -/* std::vector tnConstraints; */ -/* }; */ - -OpX *create_opx(substitutions::Operator const &op, - int parallel_degree, - TensorX const &input1 = TensorX::NO_TX, - TensorX const &input2 = TensorX::NO_TX, - TensorX const &input3 = TensorX::NO_TX, - TensorX const &input4 = TensorX::NO_TX); -void create_xfer(GraphXfer &xfer, - substitutions::Rule const &r, - int parallel_degree); -std::vector - create_xfers(substitutions::RuleCollection const &rules, - int parallel_degree); - -class GraphCompare { -public: - bool operator()(Graph *lhs, Graph *rhs) { - return lhs->optimal_cost() > rhs->optimal_cost(); - } -}; - -class GraphXferMatch { -public: - GraphXferMatch(GraphXfer const *); - - void add_mapping(Node const &, OpX *); - void add_mapping(OpX *, Node const &); - void add_input_mapping(int, std::pair const &); - void add_output_mapping(TensorX const &, TensorX const &); - OpX *at(Node const &) const; - Node at(OpX *) const; - void set_graph(Graph const *); - - bool containsNode(Graph const *, Node const &) const; - bool containsEdge(Graph const *, Edge const &) const; - - GraphXfer const *get_xfer() const; - std::unordered_set get_nodes() const; - -private: - std::map nodeToOpX; - std::map opXToNode; - std::map mappedOutputs; - size_t graph_hash; - GraphXfer const *xfer; -}; - -/* class GraphXfer { */ -/* public: */ -/* GraphXfer(); */ -/* TensorX new_tensor(void); */ -/* bool can_match(OpX *srcOp, Node const &op, Graph const *graph); */ -/* void match(OpX *srcOp, Node const &op, Graph const *graph); */ -/* void unmatch(OpX *srcOp, Node const &op, Graph const *graph); */ -/* // Compute Ops */ -/* template */ -/* OpX *create_opx(TensorX const &input, OpX const *matchOpX); */ - -/* OpX *create_noop(TensorX const &input); */ -/* OpX *create_concat(TensorX const *inputs, */ -/* int num_inputs, */ -/* OpX const *match_opx, */ -/* int concat_dim); */ -/* OpX *create_element_binary(TensorX const &input1, */ -/* TensorX const &input2, */ -/* OperatorType op_type); */ -/* OpX *create_element_unary(TensorX const &input, OperatorType op_type); */ -/* OpX *create_relu(TensorX const &input); */ -/* OpX *create_linear(TensorX const &input, */ -/* OpX const *match_opx, */ -/* int num_dims, */ -/* ActiMode acti_mode, */ -/* bool use_bias); */ -/* OpX *create_conv2d(TensorX const &input, OpX const *match_opx); */ -/* OpX *create_pool2d(TensorX const &input, OpX const *match_opx); */ -/* OpX *create_attention(TensorX const &query, */ -/* TensorX const &key, */ -/* TensorX const &value, */ -/* OpX const *match_opx, */ -/* int num_heads); */ -/* OpX *create_softmax(TensorX const &input, int softmax_dim); */ -/* // Parallel Ops */ -/* OpX *create_repartition(TensorX const &input, */ -/* int repartition_dim, */ -/* int num_parts); */ -/* OpX *create_replicate(TensorX const &input, int replicate_dim, int - * num_parts); */ -/* OpX *create_reduction(TensorX const &input, int reduction_dim, int - * num_parts); */ -/* OpX *create_combine(TensorX const &input, int combine_dim, int num_parts); - */ -/* bool map_output(TensorX const &src, TensorX const &dst); */ - -/* Graph *create_new_graph(Graph const *graph, */ -/* SimplificationSettings const &settings); */ -/* bool create_new_operator(OpX const *opx, Node &op); */ - -/* std::string get_name() const; */ - -/* void run(int depth, */ -/* Graph *graph, */ -/* std::priority_queue, GraphCompare> - * &, */ -/* std::unordered_set &, */ -/* float threshold, */ -/* int maxNumOps, */ -/* SimplificationSettings const &simplification_settings, */ -/* int &num_matches_found, */ -/* int &num_matches_rejected); */ - -/* void find_matches(Graph const *, std::vector &matches); */ -/* GraphXferMatch get_match_record(Graph const *) const; */ - -/* private: */ -/* void find_matches(int depth, */ -/* Graph const *graph, */ -/* std::vector &matches); */ - -/* public: */ -/* tl::optional name = tl::nullopt; */ -/* int tensorId; */ -/* std::map mappedOps; */ -/* std::multimap> mappedInputs; */ -/* std::map mappedOutputs; */ -/* std::vector srcOps; */ -/* std::vector dstOps; */ -/* }; */ - -struct SubstitutionMatch { - std::unordered_map node_assignment; - std::unordered_map edge_assignment; -}; - -std::unordered_set - find_matches(SubstitutionPattern const &pattern, - ParallelComputationGraph const &pcg); - -class GraphSearchHelper { -public: - GraphSearchHelper(); - void graph_optimize(size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views); - void graph_optimize_no_split( - size_t budget, - bool only_data_parallel, - std::unique_ptr &best_graph, - std::unordered_map &optimal_views); - -private: - template - T generic_sequence_optimize( - Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape); - - float sequence_optimize(Graph const *graph, - Node const &sink_node, - tl::optional const &output_shape, - tl::optional const &input_shape); - - template - T execute_sequence_split( - std::unique_ptr const &pre_graph, - std::unique_ptr const &post_graph, - tl::optional const &output_shape, - tl::optional const &input_shape, - Node const &sink_node, - Node const &bottleneck, - ParallelTensorShape const &bottleneck_output_shape); - void generate_all_pcg_xfers(); - void load_graph_substitutions(std::vector &xfers) const; - Graph *construct_graph(); - void subgraph_optimize(Graph *subgraph); - - std::unique_ptr - base_optimize(Graph const *, - SimplificationSettings const &simplification_settings); - - std::vector - possible_split_output_tensor_shapes(Node const &) const; - - void find_rewrite_matches(Graph const *graph, - std::vector &matches) const; - tl::optional find_split_node(Graph const *graph, - int base_optimize_threshold) const; - - template - tl::optional try_get_cost_from_cache(size_t hash) const; - - template - void try_cache_result(size_t hash, T const &value); - - template - T get_optimal_cost(std::unique_ptr optimized) const; - -private: - std::unordered_map cached_optimized_graphs; - std::vector all_pcg_xfers; - std::unique_ptr logger; -}; - -} // namespace ffc -} // namespace FlexFlow - -#endif diff --git a/lib/compiler/src/unity_algorithm.cc b/lib/compiler/src/unity_algorithm.cc index 86fdd88d92..c9666851db 100644 --- a/lib/compiler/src/unity_algorithm.cc +++ b/lib/compiler/src/unity_algorithm.cc @@ -9,9 +9,17 @@ bool StrategyRuntimeCmp::operator()(Strategy const &lhs, Strategy const &rhs) { return lhs.runtime < rhs.runtime; } +/* + * Gets all substitutions applicable to a PCG + */ std::unordered_set - get_all_substitutions(ParallelComputationGraph const &pcg); + get_all_applicable_substitutions(ParallelComputationGraph const &pcg) { + NOT_IMPLEMENTED(); +} +/* + * Applies a substitution to all possible positions in PCG + */ std::unordered_set apply_substitution(ParallelComputationGraph const &pcg, Substitution const &) { @@ -20,7 +28,7 @@ std::unordered_set Strategy graph_optimize(ComputationGraph &cg, - ICostEstimator const &cost_estimator, + CostEstimator const &cost_estimator, MachineSpecification const &resources, std::function( Operator const &, MachineSpecification const &)> const @@ -29,18 +37,19 @@ Strategy ParallelComputationGraph pcg = cg_to_pcg(cg); - std::unordered_set subs = get_all_substitutions(pcg); + std::unordered_set subs = get_all_applicable_substitutions(pcg); OptimalCostCache cached_subgraph_costs; DeduplicatedPriorityQueue, StrategyRuntimeCmp> candidates; - Strategy initial_result(pcg, - optimal_cost(pcg, - allowed_machine_views, - cost_estimator, - resources, - cached_subgraph_costs)); + OptimalCostResult initial_pcg_result = optimal_cost(pcg, + allowed_machine_views, + cost_estimator, + resources, + cached_subgraph_costs); + Strategy initial_result{ + pcg, initial_pcg_result.machine_mapping, initial_pcg_result.runtime}; Strategy best_result = initial_result; candidates.push(initial_result); @@ -50,7 +59,7 @@ Strategy Strategy const ¤t_result = candidates.top(); candidates.pop(); - if (StrategyRuntimeCmp(current_result, best_result)) { + if (current_result.runtime < best_result.runtime) { best_result = current_result; } else if (current_result.runtime > best_result.runtime * opt_config.alpha) { @@ -64,9 +73,9 @@ Strategy cost_estimator, resources, cached_subgraph_costs); - Strategy new_result(new_pcg, c.machine_mapping, c.runtime); + Strategy new_result{new_pcg, c.machine_mapping, c.runtime}; if (new_result.runtime <= opt_config.threshold && - new_result.pcg.query_nodes({}).size() <= opt_config.max_num_ops) { + get_nodes(new_pcg.value()).size() <= opt_config.max_num_ops) { candidates.push(new_result); } } diff --git a/lib/compiler/src/utils/recursive_logger.cc b/lib/compiler/src/utils/recursive_logger.cc.todo similarity index 100% rename from lib/compiler/src/utils/recursive_logger.cc rename to lib/compiler/src/utils/recursive_logger.cc.todo diff --git a/lib/compiler/src/utils/recursive_logger.h b/lib/compiler/src/utils/recursive_logger.h.todo similarity index 100% rename from lib/compiler/src/utils/recursive_logger.h rename to lib/compiler/src/utils/recursive_logger.h.todo diff --git a/lib/compiler/test/CMakeLists.txt b/lib/compiler/test/CMakeLists.txt index dbbd0a63ec..13b1fd3b83 100644 --- a/lib/compiler/test/CMakeLists.txt +++ b/lib/compiler/test/CMakeLists.txt @@ -1,11 +1,12 @@ ff_add_test_executable( NAME - compiler-test + compiler-tests SRC_PATTERNS src/*.cc PRIVATE_INCLUDE src/ DEPS + utils compiler doctest utils-test-common diff --git a/lib/compiler/test/test_cost_estimator.h b/lib/compiler/test/src/test_cost_estimator.h similarity index 100% rename from lib/compiler/test/test_cost_estimator.h rename to lib/compiler/test/src/test_cost_estimator.h diff --git a/lib/compiler/test/src/test_generator.h b/lib/compiler/test/src/test_generator.h new file mode 100644 index 0000000000..d6b8222968 --- /dev/null +++ b/lib/compiler/test/src/test_generator.h @@ -0,0 +1,174 @@ +#ifndef _FLEXFLOW_TEST_GENERATOR_H +#define _FLEXFLOW_TEST_GENERATOR_H + +#include "compiler/machine_mapping.h" +#include "pcg/computation_graph.h" +#include "rapidcheck.h" +#include "substitutions/sub_parallel_computation_graph.h" + +using namespace FlexFlow; + +// Rapidcheck does not work for now +// /* +// Generates computation graphs with trivial layers and tensors, which are +// used for tests focusing on graph structures. +// */ +// ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { +// return materialize_output_labelled_multidigraph_view( +// ViewMultiDiGraphAsOutputLabelled( +// g, +// [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, +// [](Tensor(MultiDiOutput const &)) { +// return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; +// })); +// } + +// /* +// Generates parallel computation graphs with trivial layers and tensors, +// which are used for tests focusing on graph structures. +// */ +// ParallelComputationGraph +// test_parallel_computation_graph(MultiDiGraphView const &g) { +// return materialize_output_labelled_multidigraph_view( +// ViewMultiDiGraphAsOutputLabelled( +// g, +// [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, +// [](Operator(MultiDiOutput const &)) { +// return ParallelTensor(ParallelTensorDims(TensorDims({})), +// DataType::FLOAT); +// })); +// } + +// rc::Gen small_integer_generator() { +// return rc::gen::inRange(1, 4); +// } + +// namespace rc { + +// Gen serialParallelMultiDiGraph() { +// return gen::map(gen::arbitrary(), +// multidigraph_from_sp_decomposition); +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return +// gen::map(gen::cast(serialParallelMultiDiGraph()), +// test_computataion_graph); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return +// gen::map(gen::cast(serialParallelMultiDiGraph()), +// test_parallel_computation_graph); +// } +// }; + +// template <> +// struct Arbitrary> { +// static Gen> arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_node) { +// return is_node +// ? gen::cast>(gen::arbitrary()) +// : gen::cast>(gen::arbitrary()); +// }); +// } +// }; + +// template <> +// struct Arbitrary> { +// static Gen> arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_node) { +// return is_node +// ? gen::cast>(gen::arbitrary()) +// : gen::cast>( +// gen::arbitrary()); +// }); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&Serial::children, +// gen::container>>( +// gen::arbitrary>()))); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&Parallel::children, +// gen::container>>( +// gen::arbitrary>()))); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::mapcat(gen::arbitrary(), [](bool is_serial) { +// return is_serial ? gen::construct( +// gen::arbitrary()) +// : gen::construct( +// gen::arbitrary()); +// }); +// } +// }; + +// template +// struct Arbitrary { +// static Gen< +// std::enable_if, +// Tag>::value>::type> arbitrary() { +// return gen::construct(gen::arbitrary()); +// } +// }; + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::apply(make_1d_machine_view, +// gen::arbitrary, +// gen::arbitrary, +// small_integer_generator()); +// } +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&MachineMapping::machine_views, +// gen::container>( +// gen::arbitrary(), +// gen::arbitrary()))); +// } +// } + +// template <> +// struct Arbitrary { +// static Gen arbitrary() { +// return gen::build( +// gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), +// gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, +// 64)), gen::set(&MachineSpecification::num_gpus_per_node, +// gen::inRange(1, 16)), +// gen::set(&MachineSpecification::inter_node_bandwidth, +// gen::nonZero()), +// gen::set(&MachineSpecification::intra_node_bandwidth, +// gen::nonZero())); +// } +// } + +// } // namespace rc + +#endif diff --git a/lib/compiler/test/src/test_labelled_open_graph.cc b/lib/compiler/test/src/test_labelled_open_graph.cc new file mode 100644 index 0000000000..ccad7b19ff --- /dev/null +++ b/lib/compiler/test/src/test_labelled_open_graph.cc @@ -0,0 +1,130 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +// #include "rapidcheck.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_subgraph(OpenMultiDiGraphView)") { + auto g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); + Node n4 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + NodePort p2 = g.add_node_port(); + NodePort p3 = g.add_node_port(); + NodePort p4 = g.add_node_port(); + NodePort p5 = g.add_node_port(); + NodePort p6 = g.add_node_port(); + NodePort p7 = g.add_node_port(); + NodePort p8 = g.add_node_port(); + NodePort p9 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + MultiDiEdge e1{n2, p2, n0, p0}; + MultiDiEdge e2{n3, p5, n1, p3}; + MultiDiEdge e3{n3, p6, n2, p4}; + MultiDiEdge e4{n4, p8, n3, p7}; + OutputMultiDiEdge e5{n4, p9, std::make_pair(p9.value(), p9.value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + std::unordered_set node_set0{n3, n4}; + + auto subgraph0 = get_subgraph(g, node_set0); + auto subgraph1 = get_subgraph(g, node_set0); + auto subgraph2 = + get_subgraph(g, node_set0); + auto subgraph3 = get_subgraph(g, node_set0); + + CHECK(bool(get_nodes(subgraph0) == node_set0)); + CHECK(bool(get_nodes(subgraph1) == node_set0)); + CHECK(bool(get_nodes(subgraph2) == node_set0)); + CHECK(bool(get_nodes(subgraph3) == node_set0)); + + std::unordered_set input_set{split_edge(e2).second, + split_edge(e3).second}; + std::unordered_set output_set{e5}; + + CHECK(bool(get_open_inputs(subgraph0) == input_set)); + CHECK(bool(get_open_inputs(subgraph1) == input_set)); + CHECK(bool(get_open_inputs(subgraph2).empty())); + CHECK(bool(get_open_inputs(subgraph3).empty())); + + CHECK(bool(get_open_outputs(subgraph0) == output_set)); + CHECK(bool(get_open_outputs(subgraph1).empty())); + CHECK(bool(get_open_outputs(subgraph2) == output_set)); + CHECK(bool(get_open_outputs(subgraph3).empty())); + + CHECK(bool(get_edges(subgraph0) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4, e5})); + CHECK(bool(get_edges(subgraph1) == + std::unordered_set{ + split_edge(e2).second, split_edge(e3).second, e4})); + CHECK(bool(get_edges(subgraph2) == + std::unordered_set{e4, e5})); + CHECK( + bool(get_edges(subgraph3) == std::unordered_set{e4})); + + CHECK(bool(get_closed_sources(subgraph2) == std::unordered_set{n3})); + } + + TEST_CASE("view OutputLabelledMultiDiGraph as open") { + OutputLabelledMultiDiGraph g = + OutputLabelledMultiDiGraph::create< + UnorderedOutputLabelledMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + + g.add_edge(e0); + g.add_output(e0, 2); + + CHECK(bool(get_edges(g).size() == 1)); + + OutputLabelledOpenMultiDiGraphView open_graph = + view_output_labelled_as_output_labelled_open(g); + + CHECK(bool(open_graph.at(n0) == 0)); + CHECK(bool(open_graph.at(n1) == 1)); + CHECK(bool(open_graph.at(e0) == 2)); + + CHECK(get_edges(open_graph).size() == 1); + } + + TEST_CASE("OutputLabelledOpenMultiDiGraph") { + OutputLabelledOpenMultiDiGraph g = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n0 = g.add_node(0); + Node n1 = g.add_node(1); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + MultiDiEdge e0{n1, p1, n0, p0}; + + g.add_edge(e0); + g.add_label(e0, 2); + + CHECK(bool(g.query_edges(OpenMultiDiEdgeQuery::all()).size() == 1)); + CHECK(bool(get_edges(g).size() == 1)); + } +} diff --git a/lib/compiler/test/src/test_machine_mapping.cc b/lib/compiler/test/src/test_machine_mapping.cc new file mode 100644 index 0000000000..365ed3e1db --- /dev/null +++ b/lib/compiler/test/src/test_machine_mapping.cc @@ -0,0 +1,23 @@ +#include "doctest/doctest.h" +#include "test_generator.h" + +TEST_SUITE(FF_TEST_SUITE) { + // TEST_CASE("MachineMapping::combine") { + // rc::check([](MachineMapping const &m0, MachineMapping const &m1) { + // RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); + + // MachineMapping comb = MachineMapping::combine(m0, m1); + + // RC_ASSERT(comb.machine_views.size() == + // m0.machine_views.size() + m1.machine_views.size()); + // RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); + // RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); + // }); + // } + + // TEST_CASE("OptimalCostResult::infinity") { + // rc::check([](OptimalCostResult const &c) { + // RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); + // }); + // } +} diff --git a/lib/compiler/test/src/test_open_graph.cc b/lib/compiler/test/src/test_open_graph.cc new file mode 100644 index 0000000000..db3630d316 --- /dev/null +++ b/lib/compiler/test/src/test_open_graph.cc @@ -0,0 +1,76 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "utils/graph/algorithms.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_source_sink_open_graph") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + NodePort p0 = g.add_node_port(); + InputMultiDiEdge e0{ + n0, g.add_node_port(), std::make_pair(n0.value(), n0.value())}; + g.add_edge(e0); + + CHECK(bool(get_closed_sources(g) == std::unordered_set{})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{})); + } + + TEST_CASE("get_source_sink_open_graph:unconnected") { + OpenMultiDiGraph g = OpenMultiDiGraph::create(); + + Node n0 = g.add_node(); + Node n1 = g.add_node(); + + NodePort p0 = g.add_node_port(); + NodePort p1 = g.add_node_port(); + + InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; + OutputMultiDiEdge e1{n1, p1, std::make_pair(p1.value(), p1.value())}; + g.add_edge(e0); + g.add_edge(e1); + + /* + g: ->n0 + n1-> + */ + + CHECK(bool(get_closed_sources(g) == std::unordered_set{n1})); + CHECK(bool(get_closed_sinks(g) == std::unordered_set{n0})); + + CHECK(bool(get_open_sources(g) == std::unordered_set{n0})); + CHECK(bool(get_open_sinks(g) == std::unordered_set{n1})); + } + + TEST_CASE("get_cut") { + auto g = OpenMultiDiGraph::create(); + + std::vector ns = add_nodes(g, 5); + + MultiDiEdge e0{ns[1], g.add_node_port(), ns[0], g.add_node_port()}; + MultiDiEdge e1{ns[2], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e2{ns[3], g.add_node_port(), ns[1], g.add_node_port()}; + MultiDiEdge e3{ns[4], g.add_node_port(), ns[2], g.add_node_port()}; + MultiDiEdge e4{ns[4], g.add_node_port(), ns[3], g.add_node_port()}; + OutputMultiDiEdge e5{ + ns[4], g.add_node_port(), std::make_pair(ns[4].value(), ns[4].value())}; + + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + g.add_edge(e3); + g.add_edge(e4); + g.add_edge(e5); + + GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; + CHECK(bool(get_cut_set(g, gs0) == std::unordered_set{e1, e2})); + + GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; + CHECK(bool(get_cut_set(g, gs1) == std::unordered_set{e3, e4})); + } +} diff --git a/lib/compiler/test/src/test_optimal_cost.cc b/lib/compiler/test/src/test_optimal_cost.cc new file mode 100644 index 0000000000..91c7a11888 --- /dev/null +++ b/lib/compiler/test/src/test_optimal_cost.cc @@ -0,0 +1,69 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "test_cost_estimator.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + // Rapidcheck infrastructures for graphs does not work for now + /* + Tests whether optimal_cost can give a valid result given random PCG, trivial + allowed machine views, trivial cost estimator and random machine + specification. + */ + // TEST_CASE("optimal_cost") { + // auto test_allowed_machine_views = [](Operator const &, + // MachineSpecification const &) { + // return std::unordered_set{make_1d_machine_view(0, 1, 1)}; + // }; + // rc::check([](ParallelComputationGraph const &g, + // MachineSpecification const &machine_spec) { + // OptimalCostCache cached_subgraph_costs; + // OptimalCostResult result = optimal_cost(g, + // test_allowed_machine_views, + // TestCostEstimator{}, + // machine_spec, + // cached_subgraph_costs); + // RC_ASSERT(result.runtime > 0); + // RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); + // }); + // } + + TEST_CASE("optimal_cost_0") { + auto pcg = + OutputLabelledMultiDiGraph::template create< + UnorderedOutputLabelledMultiDiGraph>(); + + Node n0 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n1 = pcg.add_node(Operator{ + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, + "linear"}); + + MultiDiEdge e{n1, pcg.add_node_port(), n0, pcg.add_node_port()}; + pcg.add_edge(e); + pcg.add_output(e, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); + + auto test_allowed_machine_views = [](Operator const &, + MachineSpecification const &) { + return std::unordered_set{ + make_1d_machine_view(gpu_id_t(1), gpu_id_t(2))}; + }; + + CostEstimator estimator = CostEstimator::create(); + + MachineSpecification machine_spec{1, 1, 1, 1, 1}; + + OptimalCostCache cached_results; + + OptimalCostResult result = optimal_cost(ParallelComputationGraph(pcg), + test_allowed_machine_views, + estimator, + machine_spec, + cached_results); + + CHECK(bool(result.runtime > 0)); + } +} diff --git a/lib/compiler/test/src/test_unity_algorithm.cc b/lib/compiler/test/src/test_unity_algorithm.cc new file mode 100644 index 0000000000..614e9bb182 --- /dev/null +++ b/lib/compiler/test/src/test_unity_algorithm.cc @@ -0,0 +1,28 @@ +#include "compiler/unity_algorithm.h" +#include "doctest/doctest.h" +#include "test_cost_estimator.h" +#include "test_generator.h" + +TEST_SUITE(FF_TEST_SUITE) { + // Rapidcheck does not work for now + // TEST_CASE("graph_optimize") { + // rc::check([](ComputationGraph const &g, + // float alpha, + // int budget, + // float threshold, + // int max_num_ops) { + // Strategy s = graph_optimize( + // g, + // TestCostEstimator{}, + // MachineSpecification{1, 1, 4, 0.1, 0.2}, + // [](Operator const &, MachineSpecification const &) { + // return std::unordered_set{make_1d_machine_view(0, 1, + // 1)}; + // }, + // OptimizerConfig{alpha, budget, threshold, max_num_ops}); + // RC_ASSERT(get_nodes(s.pcg).size() > 0); + // RC_ASSERT(s.machine_mapping.runtime > 0); + // RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); + // }); + // } +} diff --git a/lib/compiler/test/test_disjoint_set.cc b/lib/compiler/test/test_disjoint_set.cc deleted file mode 100644 index 796605f53f..0000000000 --- a/lib/compiler/test/test_disjoint_set.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "flexflow/utils/disjoint_set.h" -#include "gtest/gtest.h" - -TEST(disjoint_set, basic) { - int ctr = 0; - int a = ctr++, b = ctr++, c = ctr++, d = ctr++, e = ctr++, f = ctr++; - - disjoint_set ds; - ds.m_union(a, b); - ds.m_union(b, c); - ds.m_union(e, f); - ds.m_union(d, a); - - assert(ds.find(a) == ds.find(b)); - assert(ds.find(a) == ds.find(c)); - assert(ds.find(a) == ds.find(d)); - assert(ds.find(e) == ds.find(f)); - assert(ds.find(e) != ds.find(a)); -} diff --git a/lib/compiler/test/test_dominators.cc b/lib/compiler/test/test_dominators.cc deleted file mode 100644 index 60ac33696f..0000000000 --- a/lib/compiler/test/test_dominators.cc +++ /dev/null @@ -1,322 +0,0 @@ -#include "flexflow/basic_graph.h" -#include "flexflow/dominators.h" -#include "flexflow/utils/hash-utils.h" -#include "gtest/gtest.h" - -using namespace FlexFlow::PCG::Utils; - -namespace FlexFlow::PCG::Utils { -template <> -struct invalid_node<::BasicGraph, GraphStructure<::BasicGraph>> { - int operator()() const { - return -1; - } -}; -} // namespace FlexFlow::PCG::Utils - -TEST(pred_succ_cessors, basic) { - BasicGraph g; - g.add_node(0); - g.add_node(1); - g.add_node(2); - g.add_node(3); - g.add_node(4); - - g.add_edge(0, 2); - g.add_edge(1, 2); - g.add_edge(2, 3); - g.add_edge(2, 4); - - using AnswerMap = std::unordered_map>; - - AnswerMap expected_predecessors; - - expected_predecessors = {{0, {}}, {1, {}}, {2, {0, 1}}, {3, {2}}, {4, {2}}}; - - AnswerMap expected_successors = { - {0, {2}}, {1, {2}}, {2, {3, 4}}, {3, {}}, {4, {}}}; - - std::unordered_set answer; - for (auto const &kv : expected_predecessors) { - answer.clear(); - predecessors>(g, kv.first, &answer); - EXPECT_EQ(kv.second, answer) - << "^^^ Predecessors for node " << kv.first << std::endl; - } - for (auto const &kv : expected_successors) { - answer.clear(); - successors>(g, kv.first, &answer); - EXPECT_EQ(kv.second, answer) - << "^^^ Successors for node " << kv.first << std::endl; - } -} - -TEST(topo_sort, basic) { - BasicGraph g; - g.add_nodes({0, 1, 2, 3}); - g.add_edges({{3, 1}, {3, 0}, {1, 0}, {0, 2}}); - - std::vector topo_answer = {3, 1, 0, 2}; - - std::vector topo_result; - topo_sort(g, &topo_result); - EXPECT_EQ(topo_result, topo_answer); -} - -BasicGraph get_dominator_test_graph() { - BasicGraph g; - g.add_nodes({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); - g.add_edges({{1, 2}, - {1, 7}, - {2, 3}, - {2, 4}, - {3, 6}, - {4, 5}, - {4, 6}, - {5, 6}, - {6, 8}, - {7, 8}, - {8, 9}, - {8, 10}, - {9, 11}, - {10, 11}}); - - return g; -} - -TEST(dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map> answer = {{1, {1}}, - {2, {1, 2}}, - {3, {1, 2, 3}}, - {4, {1, 2, 4}}, - {5, {1, 2, 4, 5}}, - {6, {1, 2, 6}}, - {7, {1, 7}}, - {8, {1, 8}}, - {9, {1, 8, 9}}, - {10, {1, 8, 10}}, - {11, {1, 8, 11}}}; - - EXPECT_EQ(dominators(g), answer); -} - -TEST(post_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map> answer = {{1, {1, 8, 11}}, - {2, {2, 6, 8, 11}}, - {3, {3, 6, 8, 11}}, - {4, {4, 6, 8, 11}}, - {5, {5, 6, 8, 11}}, - {6, {6, 8, 11}}, - {7, {7, 8, 11}}, - {8, {8, 11}}, - {9, {9, 11}}, - {10, {10, 11}}, - {11, {11}}}; - - EXPECT_EQ(post_dominators(g), answer); -} - -TEST(imm_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map answer = {{1, 1}, // no immediate dominator - {2, 1}, - {3, 2}, - {4, 2}, - {5, 4}, - {6, 2}, - {7, 1}, - {8, 1}, - {9, 8}, - {10, 8}, - {11, 8}}; - - EXPECT_EQ(imm_dominators(g), answer); -} - -TEST(imm_post_dominators, basic) { - BasicGraph g = get_dominator_test_graph(); - - std::unordered_map answer = { - {1, 8}, - {2, 6}, - {3, 6}, - {4, 6}, - {5, 6}, - {6, 8}, - {7, 8}, - {8, 11}, - {9, 11}, - {10, 11}, - {11, 11} // no immediate post - // dominator - }; - - EXPECT_EQ(imm_post_dominators(g), answer); -} - -TEST(imm_post_dominators, multisource) { - BasicGraph g; - - g.add_nodes({1, 2, 3, 4, 5}); - g.add_edges({{1, 3}, {2, 3}, {3, 4}, {3, 5}}); - - std::unordered_map answer = { - {-1, 3}, {1, 3}, {2, 3}, {3, 3}, {4, 4}, {5, 5}}; - - auto result = - imm_post_dominators>( - g); - EXPECT_EQ(result, answer); -} - -TEST(transitive_reduction, basic) { - BasicGraph g({1, 2, 3}, {{1, 2}, {2, 3}, {1, 3}}); - - BasicGraph answer({1, 2, 3}, {{1, 2}, {2, 3}}); - - auto result = transitive_reduction(g); - - EXPECT_EQ(result, answer); -} - -TEST(transitive_reduction, medium) { - BasicGraph g({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {1, 5}, - {2, 3}, - {2, 4}, - {2, 6}, - {3, 4}, - {4, 5}, - {4, 6}, - {5, 6}, - }); - - BasicGraph answer({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {2, 3}, - {3, 4}, - {4, 5}, - {5, 6}, - }); - - auto result = transitive_reduction(g); - - EXPECT_EQ(result, answer); -} - -TEST(inplace_transitive_reduction, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {1, 5}, - {2, 3}, - {2, 4}, - {2, 6}, - {3, 4}, - {4, 5}, - {4, 6}, - {5, 6}, - }); - - BasicGraph answer({1, 2, 3, 4, 5, 6, 7}, - { - {1, 4}, - {2, 3}, - {3, 4}, - {4, 5}, - {5, 6}, - }); - - inplace_transitive_reduction(g); - - EXPECT_EQ(g, answer); -} - -TEST(roots, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - { - {1, 3}, - {2, 3}, - {3, 4}, - {3, 5}, - {3, 6}, - }); - - std::unordered_set answer{1, 2}; - - auto result = roots(g); - - EXPECT_EQ(result, answer); -} - -TEST(leaves, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 3}, {2, 3}, {3, 4}, {3, 5}, {3, 6}}); - - std::unordered_set answer{4, 5, 6}; - - auto result = leaves(g); - - EXPECT_EQ(result, answer); -} - -TEST(descendants, directed) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 2}, {2, 3}, {2, 4}, {3, 5}, {4, 5}}); - - std::unordered_set answer{2, 3, 4, 5}; - - auto result = descendants(g, 2); - - EXPECT_EQ(result, answer); -} - -TEST(descendants, undirected) { - BasicGraph g({1, 2, 3, 4, 5, 6}, - {{1, 2}, {2, 3}, {2, 4}, {3, 5}, {4, 5}}); - - std::unordered_set answer{1, 2, 3, 4, 5}; - - auto result = - descendants>(g, 2); - - EXPECT_EQ(result, answer); -} - -TEST(weakly_connected_components, basic) { - BasicGraph g({1, 2, 3, 4, 5, 6}, {{1, 3}, {2, 3}, {4, 5}, {5, 4}}); - - std::unordered_set component1{1, 2, 3}; - std::unordered_set component2{4, 5}; - std::unordered_set component3{6}; - auto result = weakly_connected_components(g); - - EXPECT_EQ(result.size(), 3); - bool component1_found = false; - bool component2_found = false; - bool component3_found = false; - for (std::unordered_set &component : result) { - if (component.size() == component1.size()) { - component1_found = true; - EXPECT_EQ(component, component1); - } else if (component.size() == component2.size()) { - component2_found = true; - EXPECT_EQ(component, component2); - } else if (component.size() == component3.size()) { - component3_found = true; - EXPECT_EQ(component, component3); - } - } - - EXPECT_TRUE(component1_found); - EXPECT_TRUE(component2_found); - EXPECT_TRUE(component3_found); -} diff --git a/lib/compiler/test/test_dot.cc b/lib/compiler/test/test_dot.cc deleted file mode 100644 index 3212971255..0000000000 --- a/lib/compiler/test/test_dot.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "flexflow/utils/dot/record_formatter.h" -#include "gtest/gtest.h" - -TEST(record_formatters, basic) { - RecordFormatter rf, rf2, rf3; - std::ostringstream oss; - oss << "Wo" - << "rld"; - rf << "Hello" - << "World" - << (rf2 << "Inner" - << "World" - << (rf3 << "Even" - << "More" - << "Inner World")) - << "Goodbye" << oss; - - std::ostringstream oss_final; - oss_final << rf; - EXPECT_EQ(oss_final.str(), - "{ Hello | World | { Inner | World | { Even | More | Inner World } " - "} | Goodbye | World }"); -} diff --git a/lib/compiler/test/test_dp.cc b/lib/compiler/test/test_dp.cc deleted file mode 100644 index 1878ade0b6..0000000000 --- a/lib/compiler/test/test_dp.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -struct TestCostEstimator : public ICostEstimator { - float estimate_cost(PCGOperatorAttrs const &op, - std::vector const &inputs, - MachineView const &mv) const override { - return 0.1; - } - float estimate_cost(ParallelTensorShape const &tensor_shape, - MachineView const &src, - MachineView const &dst) const override { - return 1; - } -}; - -TEST_CASE("optimal_cost") { - auto g(NodeLabelledMultiDiGraph::create< - UnorderedNodeLabelledMultiDiGraph>()); - - Node n0 = g.add_node(InputAttrs()); - Node n1 = g.add_node(RepartitionAttrs(ff_dim_t(0), 2)); - Node n2 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 0)); - Node n3 = g.add_node(ElementUnaryAttrs(OP_SCALAR_ADD, 1)); - Node n4 = g.add_node(ConcatAttrs(ff_dim_t(1))); - Node n5 = g.add_node(CombineAttrs(ff_dim_t(0), 2)); - - MultiDiEdge e0(n0, n1, 0, 0); - MultiDiEdge e1(n1, n2, 0, 0); - MultiDiEdge e2(n1, n3, 1, 0); - MultiDiEdge e3(n2, n4, 0, 0); - MultiDiEdge e4(n3, n4, 0, 1); - MultiDiEdge e5(n4, n5, 0, 0); - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - g.add_edge(e3); - g.add_edge(e4); - - OptimizerPCG pcg = infer_tensor_shape(g); - auto allowed_machine_views = [](PCGOperatorAttrs const &, - MachineResource const &) { - // TODO - return std::unordered_set{}; - }; - MachineResource resource(1, 1, 2); - Strategy s = - optimal_cost(pcg, allowed_machine_views, TestCostEstimator{}, resource); - - // TODO: check result -} diff --git a/lib/compiler/test/test_generator.h b/lib/compiler/test/test_generator.h deleted file mode 100644 index 374bb89455..0000000000 --- a/lib/compiler/test/test_generator.h +++ /dev/null @@ -1,168 +0,0 @@ -#ifndef _FLEXFLOW_TEST_GENERATOR_H -#define _FLEXFLOW_TEST_GENERATOR_H - -#include "compiler/machine_mapping.h" -#include "compiler/sub_parallel_computation_graph.h" -#include "pcg/computation_graph.h" -#include "rapidcheck.h" - -using namespace FlexFlow; - -/* - Generates computation graphs with trivial layers and tensors, which are used - for tests focusing on graph structures. -*/ -ComputationGraph test_computataion_graph(MultiDiGraphView const &g) { - return materialize_output_labelled_multidigraph_view( - ViewMultiDiGraphAsOutputLabelled( - g, - [](Layer(Node const &)) { return Layer(NoopAttrs{}); }, - [](Tensor(MultiDiOutput const &)) { - return Tensor{0, DataType::FLOAT, nullopt, false, nullopt}; - })); -} - -/* - Generates parallel computation graphs with trivial layers and tensors, which - are used for tests focusing on graph structures. -*/ -ParallelComputationGraph - test_parallel_computation_graph(MultiDiGraphView const &g) { - return materialize_output_labelled_multidigraph_view( - ViewMultiDiGraphAsOutputLabelled( - g, - [](Operator(Node const &)) { return ParallelTensor(NoopAttrs{}); }, - [](Operator(MultiDiOutput const &)) { - return ParallelTensor(ParallelTensorDims(TensorDims({})), - DataType::FLOAT); - })); -} - -rc::Gen small_integer_generator() { - return rc::gen::inRange(1, 4); -} - -namespace rc { - -Gen serialParallelMultiDiGraph() { - return gen::map(gen::arbitrary(), - multidigraph_from_sp_decomposition); -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::map(gen::cast(serialParallelMultiDiGraph()), - test_computataion_graph); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::map(gen::cast(serialParallelMultiDiGraph()), - test_parallel_computation_graph); - } -}; - -template <> -struct Arbitrary> { - static Gen> arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node - ? gen::cast>(gen::arbitrary()) - : gen::cast>(gen::arbitrary()); - }); - } -}; - -template <> -struct Arbitrary> { - static Gen> arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_node) { - return is_node - ? gen::cast>(gen::arbitrary()) - : gen::cast>( - gen::arbitrary()); - }); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&Serial::children, - gen::container>>( - gen::arbitrary>()))); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&Parallel::children, - gen::container>>( - gen::arbitrary>()))); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::mapcat(gen::arbitrary(), [](bool is_serial) { - return is_serial ? gen::construct( - gen::arbitrary()) - : gen::construct( - gen::arbitrary()); - }); - } -}; - -template -struct Arbitrary { - static Gen< - std::enable_if, Tag>::value>::type> - arbitrary() { - return gen::construct(gen::arbitrary()); - } -}; - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::apply(make_1d_machine_view, - gen::arbitrary, - gen::arbitrary, - small_integer_generator()); - } -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&MachineMapping::machine_views, - gen::container>( - gen::arbitrary(), gen::arbitrary()))); - } -} - -template <> -struct Arbitrary { - static Gen arbitrary() { - return gen::build( - gen::set(&MachineSpecification::num_nodes, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_cpus_per_node, gen::inRange(1, 64)), - gen::set(&MachineSpecification::num_gpus_per_node, gen::inRange(1, 16)), - gen::set(&MachineSpecification::inter_node_bandwidth, - gen::nonZero()), - gen::set(&MachineSpecification::intra_node_bandwidth, - gen::nonZero())); - } -} - -} // namespace rc - -#endif diff --git a/lib/compiler/test/test_labelled_open_graph.cc b/lib/compiler/test/test_labelled_open_graph.cc deleted file mode 100644 index 7d85514816..0000000000 --- a/lib/compiler/test/test_labelled_open_graph.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -TEST_CASE("get_subgraph_labelled_open_graph") { - auto g = LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>(); - - int t0 = 100000; - - Node n0 = g.add_node(0); - Node n1 = g.add_node(1); - Node n2 = g.add_node(2); - Node n3 = g.add_node(3); - Node n4 = g.add_node(4); - - MultiDiEdge e0(n0, n1, 0, 0); - MultiDiEdge e1(n0, n2, 1, 0); - MultiDiEdge e2(n1, n3, 0, 0); - MultiDiEdge e3(n2, n3, 0, 1); - MultiDiEdge e4(n3, n4, 0, 0); - OutputMultiDiEdge e5({n4.value(), t0}, n4, 0); - - g.add_edge(e0, 0); - g.add_edge(e1, 1); - g.add_edge(e2, 2); - g.add_edge(e3, 3); - g.add_edge(e4, 4); - g.add_edge(e5, 5); - - auto subgraph0 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::INCLUDE_INPUTS, - OutputSettings::INCLUDE_OUTPUTS); - auto subgraph1 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::INCLUDE_INPUTS, - OutputSettings::EXCLUDE_OUTPUTS); - auto subgraph2 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::EXCLUDE_INPUTS, - OutputSettings::INCLUDE_OUTPUTS); - auto subgraph3 = get_subgraph(g, - std::unordered_set{n3, n4}, - InputSettings::EXCLUDE_INPUTS, - OutputSettings::EXCLUDE_OUTPUTS); - - CHECK(get_nodes(subgraph0) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph1) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph2) == std::unordered_set{n3, n4}); - CHECK(get_nodes(subgraph3) == std::unordered_set{n3, n4}); - - std::unordered_set input_set{split_edge(e2).second, - split_edge(e3).second}; - std::unordered_set output_set{e5}; - - CHECK(get_inputs(subgraph0) == input_set); - CHECK(get_inputs(subgraph1) == input_set); - CHECK(get_inputs(subgraph2).empty()); - CHECK(get_inputs(subgraph3).empty()); - - CHECK(get_outputs(subgraph0) == output_set); - CHECK(get_outputs(subgraph1).empty()); - CHECK(get_outputs(subgraph2) == output_set); - CHECK(get_outputs(subgraph3).empty()); - - CHECK(get_edges(subgraph0) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4, e5}); - CHECK(get_edges(subgraph1) == - std::unordered_set{ - split_edge(e2).second, split_edge(e3).second, e4}); - CHECK(get_edges(subgraph2) == std::unordered_set{e4, e5}); - CHECK(get_edges(subgraph3) == std::unordered_set{e4}); -} diff --git a/lib/compiler/test/test_machine_mapping.cc b/lib/compiler/test/test_machine_mapping.cc deleted file mode 100644 index 4436a992d3..0000000000 --- a/lib/compiler/test/test_machine_mapping.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "doctest.h" -#include "test_generator.h" - -TEST_CASE("MachineMapping::combine") { - rc::check([](MachineMapping const &m0, MachineMapping const &m1) { - RC_PRE(MachineMapping::nodes_are_disjoint(m0, m1)); - - MachineMapping comb = MachineMapping::combine(m0, m1); - - RC_ASSERT(comb.machine_views.size() == - m0.machine_views.size() + m1.machine_views.size()); - RC_ASSERT(is_submap(comb.machine_views, m0.machine_views)); - RC_ASSERT(is_submap(comb.machine_views, m1.machine_views)); - }); -} - -TEST_CASE("OptimalCostResult::infinity") { - rc::check([](OptimalCostResult const &c) { - RC_ASSERT(c.runtime <= OptimalCostResult::infinity().runtime); - }); -} diff --git a/lib/compiler/test/test_machine_view.cc b/lib/compiler/test/test_machine_view.cc deleted file mode 100644 index eea084db48..0000000000 --- a/lib/compiler/test/test_machine_view.cc +++ /dev/null @@ -1,33 +0,0 @@ -#include "flexflow/config.h" -#include "flexflow/machine_view.h" -#include "gtest/gtest.h" - -using namespace Legion; -using namespace FlexFlow; - -TEST(machine_view_get_domain, basic) { - MachineView mv; - mv.ndims = 1; - mv.start_device_id = 2; - mv.dim[0] = 2; - mv.stride[0] = 1; - - Domain d; - d.dim = 1; - d.rect_data[0] = 0; - d.rect_data[0 + d.dim] = - 1; // Domain is includes, MachineView is exclusive on hi - - EXPECT_EQ(mv.get_domain(), d); -} - -TEST(machine_view_get_device_id, basic) { - MachineView mv; - mv.ndims = 1; - mv.start_device_id = 2; - mv.dim[0] = 2; - mv.stride[0] = 1; - - EXPECT_EQ(mv.get_device_id({0}), 2); - EXPECT_EQ(mv.get_device_id({1}), 3); -} diff --git a/lib/compiler/test/test_open_graph.cc b/lib/compiler/test/test_open_graph.cc deleted file mode 100644 index d96cdec467..0000000000 --- a/lib/compiler/test/test_open_graph.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "doctest.h" - -using namespace FlexFlow; - -TEST_CASE("get_source_sink_open_graph:basic") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - - int s0 = 100000; - - Node n0 = g.add_node(); - - g.add_edge(InputMultiDiEdge({s0, n0.value()}, n0, 0)); - - CHECK(get_closed_sources(g) == std::unordered_set{}); - CHECK(get_closed_sinks(g) == std::unordered_set{n0}); - - CHECK(get_open_sources(g) == std::unordered_set{n0}); - CHECK(get_open_sinks(g) == std::unordered_set{}); -} - -TEST_CASE("get_source_sink_open_graph:unconnected") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - int s0 = 100000; - int t0 = s0 + 1; - - Node n0 = g.add_node(); - Node n1 = g.add_node(); - - g.add_edge(InputMultiDiEdge({s0, n0.value()}, n0, 0)); - g.add_edge(OutputMultiDiEdge({n1.value(), t0}, n1, 0)); - - /* - g: ->n0 - n1-> - */ - - CHECK(get_closed_sources(g) == std::unordered_set{n1}); - CHECK(get_closed_sinks(g) == std::unordered_set{n0}); - - CHECK(get_open_sources(g) == std::unordered_set{n0}); - CHECK(get_open_sinks(g) == std::unordered_set{n1}); -} - -TEST_CASE("get_source_sink_open_graph:complex") { - OpenMultiDiGraph g(LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>()); - int s0 = 100000; - int s1 = s0 + 1; - int t0 = s1 + 1; - int t1 = t0 + 1; - - std::vector ns; - for (int i = 0; i < 8; ++i) { - ns.push_back(g.add_node()); - } - - g.add_edge(InputMultiDiEdge({s0, ns[0].value()}, ns[0], 0)); - g.add_edge(MultiDiEdge(ns[0], ns[1], 0, 0)); - g.add_edge(OutputMultiDiEdge({ns[1].value(), t0}, ns[1], 0)); - g.add_edge(OutputMultiDiEdge({ns[1].value(), t1}, ns[1], 1)); - - g.add_edge(MultiDiEdge(ns[2], ns[3], 0, 0)); - g.add_edge(MultiDiEdge(ns[2], ns[4], 1, 0)); - g.add_edge(MultiDiEdge(ns[4], ns[3], 0, 1)); - g.add_edge(OutputMultiDiEdge({ns[3].value(), t1}, ns[3], 0)); - - g.add_edge(InputMultiDiEdge({s0, ns[5].value()}, ns[5], 0)); - g.add_edge(InputMultiDiEdge({s1, ns[5].value()}, ns[5], 1)); - g.add_edge(MultiDiEdge(ns[5], ns[6], 0, 0)); - g.add_edge(MultiDiEdge(ns[6], ns[7], 0, 0)); - - CHECK(get_closed_sources(g) == std::unordered_set{ns[2]}); - CHECK(get_closed_sinks(g) == std::unordered_set{ns[7]}); - - CHECK(get_open_sources(g) == std::unordered_set{ns[1], ns[5]}); - CHECK(get_open_sinks(g) == std::unordered_set{ns[1], ns[3]}); -} - -TEST_CASE("get_cut") { - auto g = LabelledOpenMultiDiGraph::create< - UnorderedLabelledOpenMultiDiGraph>; - - std::vector ns = add_nodes(g, 5); - - int t0 = 100000; - - MultiDiEdge e0(ns[0], ns[1], 0, 0); - MultiDiEdge e1(ns[1], ns[2], 0, 0); - MultiDiEdge e2(ns[1], ns[3], 1, 0); - MultiDiEdge e3(ns[2], ns[4], 0, 0); - MultiDiEdge e4(ns[3], ns[4], 0, 1); - OutputMultiDiEdge e5({ns[4].value(), t0}, ns[4], 0); - - GraphSplit gs0{{ns[0], ns[1]}, {ns[2], ns[3], ns[4]}}; - CHECK(get_cut(g, gs0) == std::unordered_set{e1, e2}); - - GraphSplit gs1{{ns[0], ns[1], ns[2], ns[3]}, {ns[4]}}; - CHECK(get_cut(g, gs1) == std::unordered_set{e3, e4}); -} diff --git a/lib/compiler/test/test_optimal_cost.cc b/lib/compiler/test/test_optimal_cost.cc deleted file mode 100644 index 2d9414ba27..0000000000 --- a/lib/compiler/test/test_optimal_cost.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "test_cost_estimator.h" -#include "test_generator.h" - -/* -Tests whether optimal_cost can give a valid result given random PCG, trivial -allowed machine views, trivial cost estimator and random machine specification. -*/ -TEST_CASE("optimal_cost") { - auto test_allowed_machine_views = [](Operator const &, - MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(0, 1, 1)}; - }; - rc::check([](ParallelComputationGraph const &g, - MachineSpecification const &machine_spec) { - OptimalCostCache cached_subgraph_costs; - OptimalCostResult result = optimal_cost(g, - test_allowed_machine_views, - TestCostEstimator{}, - machine_spec, - cached_subgraph_costs); - RC_ASSERT(result.runtime > 0); - RC_ASSERT(keys(result.machine_mapping.machine_views) == get_nodes(g)); - }); -} diff --git a/lib/compiler/test/test_parallel_config.cc b/lib/compiler/test/test_parallel_config.cc deleted file mode 100644 index 843879bb0d..0000000000 --- a/lib/compiler/test/test_parallel_config.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "flexflow/config.h" -#include "flexflow/model.h" -#include "gtest/gtest.h" - -using namespace FlexFlow; - -TEST(change_data_parallel_dimensionality, basic_reduce) { - ParallelConfig pc = get_basic_data_parallel_config(8, 4); - - ParallelConfig expected = get_basic_data_parallel_config(8, 2); - - ParallelConfig result = pc.change_data_parallel_dimensionality(2); - - EXPECT_EQ(result, expected); -} - -TEST(change_data_parallel_dimensionality, basic_expand) { - ParallelConfig pc = get_basic_data_parallel_config(8, 2); - - ParallelConfig expected = get_basic_data_parallel_config(8, 4); - - ParallelConfig result = pc.change_data_parallel_dimensionality(4); - - EXPECT_EQ(result, expected); -} diff --git a/lib/compiler/test/test_random_utils.cc b/lib/compiler/test/test_random_utils.cc deleted file mode 100644 index c7b4f9e5c2..0000000000 --- a/lib/compiler/test/test_random_utils.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "flexflow/utils/random_utils.h" -#include "gtest/gtest.h" - -TEST(select_random, basic) { - std::vector values{1, 2, 3, 4}; - std::vector weights{0.1, 0.2, 0.3, 0.4}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.05), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.25), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 3); - EXPECT_EQ(select_random_determistic(values, weights, 0.9), 4); -} - -TEST(select_random, bounds) { - std::vector values{1, 2, 3}; - std::vector weights{0.2, 0.3, 0.5}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.0), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.2), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 3); - EXPECT_EQ(select_random_determistic(values, weights, 1.0), 3); -} - -TEST(select_random, singleton) { - std::vector values{1}; - std::vector weights{1.0}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.0), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 1); - EXPECT_EQ(select_random_determistic(values, weights, 1.0), 1); -} - -TEST(select_random, empty) { - std::vector values{}; - std::vector weights{}; - EXPECT_THROW(select_random_determistic(values, weights, 0.5), - std::invalid_argument); -} - -TEST(select_random, unnormalized_weights) { - std::vector values{1, 2, 3}; - std::vector weights{1.0, 2.0, 2.0}; - - EXPECT_EQ(select_random_determistic(values, weights, 0.1), 1); - EXPECT_EQ(select_random_determistic(values, weights, 0.5), 2); - EXPECT_EQ(select_random_determistic(values, weights, 0.9), 3); -} diff --git a/lib/compiler/test/test_substitution_loader.cc b/lib/compiler/test/test_substitution_loader.cc deleted file mode 100644 index b0531b598a..0000000000 --- a/lib/compiler/test/test_substitution_loader.cc +++ /dev/null @@ -1,144 +0,0 @@ -#include "flexflow/substitution.h" -#include "flexflow/substitution_loader.h" -#include "gtest/gtest.h" - -namespace sl = FlexFlow::substitution_loader; -// using namespace FlexFlow::substitution_loader; -using json = nlohmann::json; -using FlexFlow::PCG::create_xfer; -using FlexFlow::PCG::create_xfers; -using FlexFlow::PCG::GraphXfer; - -TEST(substitution_loader, basic) { - // Yes, I know this substitution is not correct. It's just for testing. - - sl::Rule example_rule; - - example_rule.name = "test_rule"; - - sl::Tensor input_tensor1; - input_tensor1.opId = -1; - input_tensor1.tsId = 0; - - sl::Tensor input_tensor2; - input_tensor2.opId = -2; - input_tensor2.tsId = 0; - - sl::Operator srcOp1; - srcOp1.op_type = OP_EW_ADD; - srcOp1.input = {input_tensor1, input_tensor2}; - srcOp1.para = {}; - - sl::Tensor srcOp1Output; - srcOp1Output.opId = 0; - srcOp1Output.tsId = 0; - - sl::Parameter activation_constraint; - activation_constraint.key = PM_ACTI; - activation_constraint.value = AC_MODE_NONE; - - sl::Operator srcOp2; - srcOp2.op_type = OP_LINEAR; - srcOp2.input = {srcOp1Output}; - srcOp2.para = {activation_constraint}; - - sl::Operator dstOp1; - dstOp1.op_type = OP_LINEAR; - dstOp1.input = {input_tensor1}; - dstOp1.para = {activation_constraint}; - - sl::Tensor dstOp1Output; - dstOp1Output.opId = 0; - dstOp1Output.tsId = 0; - - sl::Operator dstOp2; - dstOp2.op_type = OP_LINEAR; - dstOp2.input = {input_tensor2}; - dstOp2.para = {activation_constraint}; - - sl::Tensor dstOp2Output; - dstOp2Output.opId = 1; - dstOp2Output.tsId = 0; - - sl::Operator dstOp3; - dstOp3.op_type = OP_EW_ADD; - dstOp3.input = {dstOp1Output, dstOp2Output}; - dstOp3.para = {}; - - sl::MapOutput map_output; - map_output.srcOpId = 1; - map_output.srcTsId = 0; - map_output.dstOpId = 2; - map_output.dstTsId = 0; - - example_rule.srcOp = {srcOp1, srcOp2}; - example_rule.dstOp = {dstOp1, dstOp2, dstOp3}; - example_rule.mappedOutput = {map_output}; - - GraphXfer *xfer = new GraphXfer(nullptr); - create_xfer(*xfer, example_rule, 2); - - EXPECT_EQ(xfer->name, "test_rule"); - - EXPECT_EQ(xfer->srcOps.size(), 2); - EXPECT_EQ(xfer->srcOps[0]->type, OP_EW_ADD); - EXPECT_EQ(xfer->srcOps[1]->type, OP_LINEAR); - EXPECT_EQ(xfer->srcOps[0]->inputs.size(), 2); - EXPECT_NE(xfer->srcOps[0]->inputs[0], xfer->srcOps[0]->inputs[1]); - EXPECT_EQ(xfer->srcOps[0]->outputs.size(), 1); - EXPECT_EQ(xfer->srcOps[1]->inputs.size(), 1); - EXPECT_EQ(xfer->srcOps[0]->outputs[0], xfer->srcOps[1]->inputs[0]); - EXPECT_EQ(xfer->srcOps[1]->outputs.size(), 1); - - EXPECT_EQ(xfer->dstOps.size(), 3); - EXPECT_EQ(xfer->dstOps[0]->type, OP_LINEAR); - EXPECT_EQ(xfer->dstOps[1]->type, OP_LINEAR); - EXPECT_EQ(xfer->dstOps[2]->type, OP_EW_ADD); - EXPECT_EQ(xfer->dstOps[0]->inputs.size(), 1); - EXPECT_EQ(xfer->dstOps[0]->outputs.size(), 1); - EXPECT_EQ(xfer->dstOps[0]->inputs[0], xfer->srcOps[0]->inputs[0]); - EXPECT_EQ(xfer->dstOps[1]->inputs.size(), 1); - EXPECT_EQ(xfer->dstOps[1]->outputs.size(), 1); - EXPECT_EQ(xfer->dstOps[1]->inputs[0], xfer->srcOps[0]->inputs[1]); - EXPECT_EQ(xfer->dstOps[2]->inputs.size(), 2); - EXPECT_EQ(xfer->dstOps[2]->inputs[0], xfer->dstOps[0]->outputs[0]); - EXPECT_EQ(xfer->dstOps[2]->inputs[1], xfer->dstOps[1]->outputs[0]); - EXPECT_NE(xfer->dstOps[2]->inputs[0], xfer->dstOps[2]->inputs[1]); - EXPECT_EQ(xfer->dstOps[2]->outputs.size(), 1); - - EXPECT_EQ(xfer->mappedOutputs.size(), 1); - EXPECT_NE(xfer->srcOps[1]->outputs[0], xfer->dstOps[2]->outputs[0]); - EXPECT_EQ(xfer->mappedOutputs.at(xfer->srcOps[1]->outputs[0]), - xfer->dstOps[2]->outputs[0]); -} - -TEST(substitution_loader, operator_deserialization) { - json j = { - {"_t", "Operator"}, - {"input", - std::vector{{{"_t", "Tensor"}, {"opId", -2}, {"tsId", 0}}, - {{"_t", "Tensor"}, {"opId", -3}, {"tsId", 0}}}}, - {"para", std::vector{}}, - {"type", "OP_EW_ADD"}, - }; - - sl::Operator o; - from_json(j, o); - - EXPECT_EQ(o.op_type, OP_EW_ADD); - EXPECT_EQ(o.input.size(), 2); - EXPECT_EQ(o.input[0].opId, -2); - EXPECT_EQ(o.input[0].tsId, 0); - EXPECT_EQ(o.input[1].opId, -3); - EXPECT_EQ(o.input[1].tsId, 0); - EXPECT_EQ(o.para.size(), 0); -} - -// TEST(substitution_loader, load_full_file) { -// sl::RuleCollection collection = -// sl::load_rule_collection_from_path("tests/unit/graph_subst_3_v2.json"); -// EXPECT_EQ(collection.rules.size(), 640); - -// std::vector xfers = create_xfers(nullptr, collection, 2); -// EXPECT_EQ(xfers.size(), 640); -// } diff --git a/lib/compiler/test/test_unity_algorithm.cc b/lib/compiler/test/test_unity_algorithm.cc deleted file mode 100644 index 6a0131dd77..0000000000 --- a/lib/compiler/test/test_unity_algorithm.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "compiler/unity_algorithm.h" -#include "test_cost_estimator.h" -#include "test_generator.h" - -TEST_CASE("graph_optimize") { - rc::check([](ComputationGraph const &g, - float alpha, - int budget, - float threshold, - int max_num_ops) { - Strategy s = graph_optimize( - g, - TestCostEstimator{}, - MachineSpecification{1, 1, 4, 0.1, 0.2}, - [](Operator const &, MachineSpecification const &) { - return std::unordered_set{make_1d_machine_view(0, 1, 1)}; - }, - OptimizerConfig{alpha, budget, threshold, max_num_ops}); - RC_ASSERT(get_nodes(s.pcg).size() > 0); - RC_ASSERT(s.machine_mapping.runtime > 0); - RC_ASSERT(keys(s.machine_mapping.machine_views) == get_nodes(s.pcg)); - }); -} diff --git a/lib/op-attrs/include/op-attrs/operator_attrs.h b/lib/op-attrs/include/op-attrs/operator_attrs.h index 9da787cbf8..b63563cd67 100644 --- a/lib/op-attrs/include/op-attrs/operator_attrs.h +++ b/lib/op-attrs/include/op-attrs/operator_attrs.h @@ -32,6 +32,7 @@ #include "ops/topk.h" #include "ops/transpose.h" #include "utils/variant.h" +#include namespace FlexFlow { diff --git a/lib/op-attrs/src/attention.cc b/lib/op-attrs/src/attention.cc index 4b6c53897c..2c1500a477 100644 --- a/lib/op-attrs/src/attention.cc +++ b/lib/op-attrs/src/attention.cc @@ -91,7 +91,14 @@ TensorShape get_output_shape(MultiHeadAttentionAttrs const &attrs, static_cast(value_shape)); return get_tensor_shape_unsafe(parallel_shape); } +TensorShape get_output_shape(MultiHeadAttentionAttrs const &, + MultiHeadAttentionInputs const &) { + NOT_IMPLEMENTED(); +} +int get_oSize(ParallelTensorShape const &) { + NOT_IMPLEMENTED(); +} } // namespace FlexFlow // Tensor FFModel::multihead_attention(const Tensor query, diff --git a/lib/op-attrs/src/embedding.cc b/lib/op-attrs/src/embedding.cc index 02cbfaa031..56014fcc67 100644 --- a/lib/op-attrs/src/embedding.cc +++ b/lib/op-attrs/src/embedding.cc @@ -1,3 +1,9 @@ #include "op-attrs/ops/embedding.h" -namespace FlexFlow {} // namespace FlexFlow +namespace FlexFlow { + +TensorShape get_weights_shape(EmbeddingAttrs const &, TensorShape const &) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/op-attrs/src/get_output_shapes.cc b/lib/op-attrs/src/get_output_shapes.cc index d649856152..c20d4be34c 100644 --- a/lib/op-attrs/src/get_output_shapes.cc +++ b/lib/op-attrs/src/get_output_shapes.cc @@ -5,6 +5,12 @@ namespace FlexFlow { ParallelTensorShape as_parallel(TensorShape const &); std::vector as_parallel(std::vector const &); +std::vector get_output_shapes( + PCGOperatorAttrs const &op_params, + std::vector const &input_tensor_shapes) { + NOT_IMPLEMENTED(); +} + // TensorShape get_output_shape(AggregateAttrs const &attrs, // TensorShape const &gate_preds, // TensorShape const &gate_assign, diff --git a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc index 11cfbc125c..c7e70bb906 100644 --- a/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc +++ b/lib/op-attrs/src/parallel_dim_mapping_record_solver.cc @@ -351,4 +351,12 @@ void construct_output_parallel_dims( /* return solution; */ /* } */ +ParallelDimMappingSolution solve_parallel_dim_mappings( + std::vector const &mappings, + std::vector const &input, + int numWeights, + int numOutputs) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index 50c2558e39..b118d69259 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -3,6 +3,7 @@ #include "device_type.h" #include "utils/strong_typedef.h" +#include namespace FlexFlow { diff --git a/lib/pcg/include/pcg/machine_specification.h b/lib/pcg/include/pcg/machine_specification.h index 55f80e3cc0..1b2a02b070 100644 --- a/lib/pcg/include/pcg/machine_specification.h +++ b/lib/pcg/include/pcg/machine_specification.h @@ -11,22 +11,21 @@ struct BandwidthNetworkModelConfig int bandwidth; }; -struct MachineSpecification : public use_visitable_cmp { +struct MachineSpecification { int num_nodes; int num_cpus_per_node; int num_gpus_per_node; float inter_node_bandwidth; - float intra_node_bandwidth; + req intra_node_bandwidth; }; -} // namespace FlexFlow +FF_VISITABLE_STRUCT(MachineSpecification, + num_nodes, + num_cpus_per_node, + num_gpus_per_node, + inter_node_bandwidth, + intra_node_bandwidth); -VISITABLE_STRUCT(::FlexFlow::MachineSpecification, - num_nodes, - num_cpus_per_node, - num_gpus_per_node, - inter_node_bandwidth, - intra_node_bandwidth); -MAKE_VISIT_HASHABLE(::FlexFlow::MachineSpecification); +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 1a5c2bc3f8..7521cd209a 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -12,10 +12,7 @@ namespace FlexFlow { -struct MachineView : public use_visitable_cmp { - MachineView() = delete; - MachineView(device_id_t const &, StridedRectangle const &); - +struct MachineView { std::vector device_ids() const; device_id_t at(FFOrdered const &coord) const; @@ -26,6 +23,8 @@ struct MachineView : public use_visitable_cmp { StridedRectangle rect; }; +FF_VISITABLE_STRUCT(MachineView, start, rect); + std::size_t num_dims(MachineView const &); std::size_t num_devices(MachineView const &); DeviceType get_device_type(MachineView const &); @@ -43,7 +42,4 @@ MachineView make_1d_machine_view(device_id_t start, size_t interval_size); } // namespace FlexFlow -VISITABLE_STRUCT(::FlexFlow::MachineView, start, rect); -MAKE_VISIT_HASHABLE(::FlexFlow::MachineView); - #endif diff --git a/lib/pcg/include/pcg/operator.h b/lib/pcg/include/pcg/operator.h index 5804e38f95..bb9a4cf5e4 100644 --- a/lib/pcg/include/pcg/operator.h +++ b/lib/pcg/include/pcg/operator.h @@ -2,31 +2,26 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_OPERATOR_H #include "op-attrs/operator_attrs.h" -#include "utils/optional.h" #include "utils/stack_string.h" #include "utils/visitable.h" +#include + namespace FlexFlow { -struct Operator : public use_visitable_cmp { +struct Operator { public: - Operator() = delete; - Operator(PCGOperatorAttrs const &attrs, - std::optional const &name); - operator PCGOperatorAttrs() const; public: PCGOperatorAttrs attrs; + req> name; }; -} // namespace FlexFlow +FF_VISITABLE_STRUCT(Operator, attrs, name); -VISITABLE_STRUCT(::FlexFlow::Operator, attrs); -MAKE_VISIT_HASHABLE(::FlexFlow::Operator); +static_assert(is_well_behaved_value_type::value); -namespace FlexFlow { -static_assert(is_well_behaved_value_type::value, ""); -} +} // namespace FlexFlow #endif diff --git a/lib/pcg/include/pcg/optimizer.h b/lib/pcg/include/pcg/optimizer.h index df5bddf729..0bb3fab974 100644 --- a/lib/pcg/include/pcg/optimizer.h +++ b/lib/pcg/include/pcg/optimizer.h @@ -7,21 +7,21 @@ namespace FlexFlow { struct SGDOptimizer { - req lr; - req momentum; - req nesterov; + double lr; + double momentum; + bool nesterov; req weight_decay; }; FF_VISITABLE_STRUCT(SGDOptimizer, lr, momentum, nesterov, weight_decay); struct AdamOptimizer { - req alpha; - req beta1; - req beta2; - req weight_decay; - req epsilon; - req alpha_t; - req beta_t; + double alpha; + double beta1; + double beta2; + double weight_decay; + double epsilon; + double alpha_t; + double beta_t; req beta2_t; }; FF_VISITABLE_STRUCT(AdamOptimizer, @@ -34,7 +34,7 @@ FF_VISITABLE_STRUCT(AdamOptimizer, beta_t, beta2_t); -using Optimizer = variant; +using Optimizer = std::variant; } // namespace FlexFlow diff --git a/lib/pcg/include/pcg/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph.h index 7e332933c7..39a69a80ab 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph.h @@ -15,6 +15,17 @@ struct ParallelComputationGraph }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_HASH(ParallelComputationGraph); +bool operator==(ParallelComputationGraph const &, + ParallelComputationGraph const &); + } // namespace FlexFlow +namespace std { + +template <> +struct hash { + size_t operator()(FlexFlow::ParallelComputationGraph const &g) const; +}; +} // namespace std + #endif diff --git a/lib/pcg/include/pcg/parallel_tensor.h b/lib/pcg/include/pcg/parallel_tensor.h index c3f7ebdfed..652b408c15 100644 --- a/lib/pcg/include/pcg/parallel_tensor.h +++ b/lib/pcg/include/pcg/parallel_tensor.h @@ -47,6 +47,8 @@ struct ParallelTensor : public use_visitable_cmp { std::optional sync_type = std::nullopt, std::optional initializer = std::nullopt); + ParallelTensorShape get_shape() const; + public: ParallelTensorDims dims; DataType data_type; diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 28331f441c..d123d7c6ac 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -17,7 +17,7 @@ struct side_size_t : public strong_typedef { using strong_typedef::strong_typedef; }; -struct StridedRectangleSide : public use_visitable_cmp { +struct StridedRectangleSide { public: StridedRectangleSide() = delete; StridedRectangleSide(num_points_t const &, int stride); @@ -32,14 +32,15 @@ struct StridedRectangleSide : public use_visitable_cmp { public: num_points_t num_points; - int stride; + req stride; }; -struct StridedRectangle : public use_visitable_cmp { -public: - StridedRectangle() = delete; - StridedRectangle(std::vector const &); +FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(StridedRectangleSide, + num_points, + stride); +struct StridedRectangle { +public: size_t at(FFOrdered const &) const; StridedRectangleSide at(ff_dim_t const &) const; size_t num_dims() const; @@ -47,6 +48,9 @@ struct StridedRectangle : public use_visitable_cmp { public: FFOrdered sides; }; + +FF_VISITABLE_STRUCT(StridedRectangle, sides); + } // namespace FlexFlow MAKE_TYPEDEF_HASHABLE(::FlexFlow::num_points_t); @@ -55,10 +59,4 @@ MAKE_TYPEDEF_PRINTABLE(::FlexFlow::num_points_t, "num_points"); MAKE_TYPEDEF_HASHABLE(::FlexFlow::side_size_t); MAKE_TYPEDEF_PRINTABLE(::FlexFlow::side_size_t, "side_size"); -VISITABLE_STRUCT(::FlexFlow::StridedRectangleSide, num_points, stride); -MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangleSide); - -VISITABLE_STRUCT(::FlexFlow::StridedRectangle, sides); -MAKE_VISIT_HASHABLE(::FlexFlow::StridedRectangle); - #endif diff --git a/lib/pcg/src/machine_view.cc b/lib/pcg/src/machine_view.cc index 9edfb09a8e..46f87833f0 100644 --- a/lib/pcg/src/machine_view.cc +++ b/lib/pcg/src/machine_view.cc @@ -3,9 +3,6 @@ namespace FlexFlow { -MachineView::MachineView(device_id_t const &start, StridedRectangle const &rect) - : start(start), rect(rect) {} - static StridedRectangle make_1d_rect(int start, int stop, int stride) { assert(stop > start); assert(stride > 0); diff --git a/lib/pcg/src/operator.cc b/lib/pcg/src/operator.cc index 92ece9a2bf..9d36ae1b25 100644 --- a/lib/pcg/src/operator.cc +++ b/lib/pcg/src/operator.cc @@ -2,10 +2,6 @@ namespace FlexFlow { -Operator::Operator(PCGOperatorAttrs const &attrs, - std::optional const &name) - : attrs(attrs) {} - Operator::operator PCGOperatorAttrs() const { return attrs; } diff --git a/lib/pcg/src/parallel_computation_graph.cc b/lib/pcg/src/parallel_computation_graph.cc new file mode 100644 index 0000000000..011c40eb4c --- /dev/null +++ b/lib/pcg/src/parallel_computation_graph.cc @@ -0,0 +1,40 @@ +#include "pcg/parallel_computation_graph.h" +#include "utils/graph/algorithms.h" + +namespace FlexFlow { + +bool operator==(ParallelComputationGraph const &lhs, + ParallelComputationGraph const &rhs) { + return std::hash{}(lhs) == + std::hash{}(rhs); +} + +} // namespace FlexFlow + +namespace std { + +size_t hash::operator()( + FlexFlow::ParallelComputationGraph const &g) const { + using namespace FlexFlow; + + size_t h = 0; + + std::vector ordered_nodes = get_topological_ordering(g.value()); + hash_combine(h, ordered_nodes.size()); + + std::unordered_map node_index; + for (int i = 0; i < ordered_nodes.size(); ++i) { + node_index[ordered_nodes[i]] = i; + hash_combine(h, g.value().at(ordered_nodes[i])); + } + + for (MultiDiEdge const &edge : get_edges(g.value())) { + hash_combine(h, node_index.at(edge.src)); + hash_combine(h, node_index.at(edge.dst)); + hash_combine(h, g.value().at(edge)); + } + + return h; +} + +} // namespace std diff --git a/lib/pcg/src/parallel_tensor.cc b/lib/pcg/src/parallel_tensor.cc index 19dc1e96d3..ff53e456ec 100644 --- a/lib/pcg/src/parallel_tensor.cc +++ b/lib/pcg/src/parallel_tensor.cc @@ -10,4 +10,8 @@ ParallelTensor::ParallelTensor(ParallelTensorDims const &dims, : dims(dims), data_type(data_type), sync_type(sync_type), initializer(initializer), create_gradients(create_gradients) {} +ParallelTensorShape ParallelTensor::get_shape() const { + return ParallelTensorShape(dims, data_type); +} + } // namespace FlexFlow diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 29dcae6151..27ef9a7f5b 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -30,8 +30,8 @@ side_size_t StridedRectangleSide::get_size() const { NOT_IMPLEMENTED(); } -StridedRectangle::StridedRectangle( - std::vector const &sides) - : sides(sides) {} +size_t StridedRectangle::num_dims() const { + NOT_IMPLEMENTED(); +} } // namespace FlexFlow diff --git a/lib/runtime/CMakeLists.txt b/lib/runtime/CMakeLists.txt index 49b052ec2b..fd5b4991ef 100644 --- a/lib/runtime/CMakeLists.txt +++ b/lib/runtime/CMakeLists.txt @@ -17,18 +17,18 @@ ff_add_library( pcg ) -ff_add_test_executable( - NAME - runtime-test - SRC_PATTERNS - test/src/*.cc - PUBLIC_INCLUDE - include/ - PRIVATE_INCLUDE - test/src/ src/ - DEPS - runtime - doctest -) +# ff_add_test_executable( +# NAME +# runtime-test +# SRC_PATTERNS +# test/src/*.cc +# PUBLIC_INCLUDE +# include/ +# PRIVATE_INCLUDE +# test/src/ src/ +# DEPS +# runtime +# doctest +# ) add_subdirectory(ffi) diff --git a/lib/substitutions/include/substitutions/attribute_expr.h b/lib/substitutions/include/substitutions/attribute_expr.h index d6902d1274..0afd48b431 100644 --- a/lib/substitutions/include/substitutions/attribute_expr.h +++ b/lib/substitutions/include/substitutions/attribute_expr.h @@ -19,7 +19,7 @@ struct ListSize { }; template -using AttributeExpr = variant, ListSize>; +using AttributeExpr = std::variant, ListSize>; template struct AttributeConstraint { diff --git a/lib/substitutions/include/substitutions/get_attribute.h b/lib/substitutions/include/substitutions/get_attribute.h index 50c4108a67..0e6dd4c69b 100644 --- a/lib/substitutions/include/substitutions/get_attribute.h +++ b/lib/substitutions/include/substitutions/get_attribute.h @@ -7,58 +7,58 @@ namespace FlexFlow { -optional get_attribute(PCGOperatorAttrs const &, - OperatorAttributeKey); -optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey); -optional get_attribute(CastAttrs const &p, - OperatorAttributeKey); -optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey); -optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey); -optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ElementScalarUnaryAttrs const &p, - OperatorAttributeKey); -optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey); -optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey); -optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey); -optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey); -optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey); -optional get_attribute(MultiHeadAttentionAttrs const &p, - OperatorAttributeKey); -optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey); -optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey); -optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey); -optional get_attribute(SplitAttrs const &p, - OperatorAttributeKey); -optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey); -optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey); -optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey); +std::optional get_attribute(PCGOperatorAttrs const &, + OperatorAttributeKey); +std::optional get_attribute(BatchMatmulAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(CastAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(CombineAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(ConcatAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(Conv2DAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(ElementBinaryAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(ElementUnaryAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(DropoutAttrs const &p, + OperatorAttributeKey); +std::optional + get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(EmbeddingAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(FlatAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(GatherAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(LayerNormAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(LinearAttrs const &p, + OperatorAttributeKey); +std::optional + get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey); +std::optional get_attribute(Pool2DAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(ReduceAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(ReductionAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(RepartitionAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(ReplicateAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(ReshapeAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(SplitAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(SoftmaxAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(TopKAttrs const &p, + OperatorAttributeKey); +std::optional get_attribute(TransposeAttrs const &p, + OperatorAttributeKey); // optional get_attribute(FusedParallelOpAttrs const &p, // OperatorAttributeKey); diff --git a/lib/substitutions/include/substitutions/operator_pattern.h b/lib/substitutions/include/substitutions/operator_pattern.h index 9392a7876e..8fc4ebefc2 100644 --- a/lib/substitutions/include/substitutions/operator_pattern.h +++ b/lib/substitutions/include/substitutions/operator_pattern.h @@ -70,21 +70,22 @@ enum class OperatorAttributeKey { NUM_INPUTS }; -using OperatorAttributeValue = variant, - stack_vector, - OperatorType, - Activation, - ff_dim_t, - unsigned long long, - AggregateOp, - stack_vector, - optional, - PoolOp, - TensorShape, - DataType>; +using OperatorAttributeValue = + std::variant, + stack_vector, + OperatorType, + Activation, + ff_dim_t, + unsigned long long, + AggregateOp, + stack_vector, + std::optional, + PoolOp, + TensorShape, + DataType>; FF_VISITABLE_STRUCT(ListIndexAccess, attribute_key, @@ -97,7 +98,7 @@ using OperatorAttributeConstraint = using OperatorPattern = AttributePattern; -optional +std::optional evaluate_attribute_expr(Operator const &attrs, AttributeExpr const &expr); diff --git a/lib/substitutions/include/substitutions/output_graph.h b/lib/substitutions/include/substitutions/output_graph.h index b9cf1f53f3..4ed90aed06 100644 --- a/lib/substitutions/include/substitutions/output_graph.h +++ b/lib/substitutions/include/substitutions/output_graph.h @@ -15,7 +15,7 @@ struct AttrConstant { OperatorAttributeValue value; }; -using OperatorAttributeExpr = variant; +using OperatorAttributeExpr = std::variant; // NOTE(@wmdi): Not sure if it aligns with other design. Or alternatively we can // define the assignment for each operator type. diff --git a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h index d07a1da23b..741554142f 100644 --- a/lib/substitutions/include/substitutions/parallel_tensor_pattern.h +++ b/lib/substitutions/include/substitutions/parallel_tensor_pattern.h @@ -8,7 +8,7 @@ namespace FlexFlow { enum class TensorAttributeKey { DIM_SIZES, DIM_DEGREES }; -using TensorAttributeValue = variant>; +using TensorAttributeValue = std::variant>; using TensorAttributeConstraint = AttributeConstraint; @@ -16,7 +16,7 @@ using TensorAttributeConstraint = using ParallelTensorPattern = AttributePattern; -optional +std::optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, AttributeExpr const &expr); diff --git a/lib/substitutions/include/substitutions/substitution.h b/lib/substitutions/include/substitutions/substitution.h index a52906c612..8dbe4e66cf 100644 --- a/lib/substitutions/include/substitutions/substitution.h +++ b/lib/substitutions/include/substitutions/substitution.h @@ -28,4 +28,12 @@ SubParallelComputationGraph } // namespace FlexFlow +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::Substitution const &) const; +}; + +}; // namespace std + #endif diff --git a/lib/substitutions/src/graph_pattern.cc b/lib/substitutions/src/graph_pattern.cc index 1dba5c4af8..296a975626 100644 --- a/lib/substitutions/src/graph_pattern.cc +++ b/lib/substitutions/src/graph_pattern.cc @@ -9,51 +9,52 @@ namespace FlexFlow { -optional +std::optional evaluate_list_index_access(int index, - optional const &v) { + std::optional const &v) { if (!v.has_value() || - !holds_alternative>(v.value()) || - !holds_alternative>(v.value())) { - return nullopt; + !std::holds_alternative>(v.value()) || + !std::holds_alternative>( + v.value())) { + return std::nullopt; } if (index >= MAX_TENSOR_DIM) { - return nullopt; + return std::nullopt; } - if (holds_alternative>(v.value())) { + if (std::holds_alternative>(v.value())) { return get>(v.value()).at(index); } else { return get>(v.value()).at(index); } } -optional +std::optional evaluate_list_index_access(int const &index, - optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; + std::optional const &v) { + if (!v.has_value() || !std::holds_alternative>(v.value())) { + return std::nullopt; } auto vec = get>(v.value()); if (index >= vec.size()) { - return nullopt; + return std::nullopt; } return vec.at(index); } -optional - evaluate_list_size(optional const &v) { +std::optional + evaluate_list_size(std::optional const &v) { return MAX_TENSOR_DIM; } -optional - evaluate_list_size(optional const &v) { - if (!v.has_value() || !holds_alternative>(v.value())) { - return nullopt; +std::optional + evaluate_list_size(std::optional const &v) { + if (!v.has_value() || !std::holds_alternative>(v.value())) { + return std::nullopt; } return (int)get>(v.value()).size(); @@ -62,20 +63,21 @@ optional struct EvaluateOperatorAttributeExpr { EvaluateOperatorAttributeExpr(Operator const &attrs) : attrs(attrs) {} - optional operator()(OperatorAttributeKey const &key) { + std::optional + operator()(OperatorAttributeKey const &key) { return get_attribute(this->attrs.attrs, key); } - optional + std::optional operator()(ListIndexAccess const &index_access) { - optional v = + std::optional v = get_attribute(this->attrs.attrs, index_access.attribute_key); return evaluate_list_index_access(index_access.index, v); } - optional + std::optional operator()(ListSize const &list_size) { - optional v = + std::optional v = get_attribute(this->attrs.attrs, list_size.attribute_key); return evaluate_list_size(v); } @@ -84,7 +86,7 @@ struct EvaluateOperatorAttributeExpr { Operator attrs; }; -optional +std::optional evaluate_tensor_attribute_expr(ParallelTensor const &, AttributeExpr const &); @@ -93,11 +95,11 @@ struct EvaluateTensorAttributeExpr { : tensor_shape(tensor_shape) {} template - optional evaluate(T const &t) { + std::optional evaluate(T const &t) { return this->operator()(t); } - optional operator()(TensorAttributeKey key) { + std::optional operator()(TensorAttributeKey key) { switch (key) { case TensorAttributeKey::DIM_SIZES: { std::vector result; @@ -118,14 +120,14 @@ struct EvaluateTensorAttributeExpr { } } - optional + std::optional operator()(ListIndexAccess const &index_access) { - optional v = + std::optional v = this->evaluate(index_access.attribute_key); return evaluate_list_index_access(index_access.index, v); } - optional + std::optional operator()(ListSize const &list_size) { return evaluate_list_size(this->evaluate(list_size.attribute_key)); } @@ -134,29 +136,29 @@ struct EvaluateTensorAttributeExpr { ParallelTensor tensor_shape; }; -optional +std::optional evaluate_attribute_expr(ParallelTensor const &tensor_shape, AttributeExpr const &expr) { return visit(EvaluateTensorAttributeExpr(tensor_shape), expr); } -optional +std::optional evaluate_attribute_expr(Operator const &attrs, AttributeExpr const &expr) { return visit(EvaluateOperatorAttributeExpr(attrs), expr); } template -optional satisfies(ConstraintType constraint_type, - V const &constraint_value, - optional const &maybe_attribute_value) { +std::optional satisfies(ConstraintType constraint_type, + V const &constraint_value, + std::optional const &maybe_attribute_value) { if (!maybe_attribute_value.has_value()) { - return nullopt; + return std::nullopt; } V attr_val = maybe_attribute_value.value(); if (attr_val.index() != constraint_value.index()) { - return nullopt; + return std::nullopt; } if (constraint_type == ConstraintType::EQUAL) { @@ -166,15 +168,15 @@ optional satisfies(ConstraintType constraint_type, } } -optional satisfies(ParallelTensor const &tensor_shape, - TensorAttributeConstraint const &constraint) { +std::optional satisfies(ParallelTensor const &tensor_shape, + TensorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(tensor_shape, constraint.attribute_expr); return satisfies( constraint.constraint_type, constraint.attribute_value, value); } -optional satisfies(Operator const ¶ms, - OperatorAttributeConstraint const &constraint) { +std::optional satisfies(Operator const ¶ms, + OperatorAttributeConstraint const &constraint) { auto value = evaluate_attribute_expr(params, constraint.attribute_expr); OperatorAttributeValue v = value.value(); return satisfies( @@ -182,12 +184,12 @@ optional satisfies(Operator const ¶ms, } template -optional optional_all_of(Container const &container, - Function const &func) { +std::optional optional_all_of(Container const &container, + Function const &func) { for (auto const &element : container) { - optional condition = func(element); + std::optional condition = func(element); if (!condition.has_value()) { - return nullopt; + return std::nullopt; } if (!condition.value()) { @@ -197,16 +199,16 @@ optional optional_all_of(Container const &container, return true; } -optional satisfies(Operator const ¶ms, - OperatorPattern const &pattern) { +std::optional satisfies(Operator const ¶ms, + OperatorPattern const &pattern) { return optional_all_of(pattern.attribute_constraints, [&](OperatorAttributeConstraint const &c) { return satisfies(params, c); }); } -optional satisfies(ParallelTensor const ¶ms, - ParallelTensorPattern const &pattern) { +std::optional satisfies(ParallelTensor const ¶ms, + ParallelTensorPattern const &pattern) { return optional_all_of( pattern.attribute_constraints, [&](TensorAttributeConstraint const &c) { return satisfies(params, c); }); @@ -229,7 +231,7 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, for (auto const &kv : patternMatch.node_assignment) { Node patternNode = kv.first; Node pcgNode = kv.second; - optional constraintResult = + std::optional constraintResult = satisfies(pcg.at(pcgNode), pattern.value().at(patternNode)); result &= constraintResult.value_or(false); } @@ -237,7 +239,7 @@ bool assignment_satisfies(SubParallelComputationGraph const &pcg, for (auto const &kv : patternMatch.edge_assignment) { OpenMultiDiEdge patternEdge = kv.first; OpenMultiDiEdge pcgEdge = kv.second; - optional constraintResult = + std::optional constraintResult = satisfies(pcg.at(pcgEdge), pattern.value().at(patternEdge)); result &= constraintResult.value_or(false); } diff --git a/lib/substitutions/src/graph_pattern_match.cc b/lib/substitutions/src/graph_pattern_match.cc index 7114c2d8ce..f9c6b9a773 100644 --- a/lib/substitutions/src/graph_pattern_match.cc +++ b/lib/substitutions/src/graph_pattern_match.cc @@ -56,7 +56,7 @@ MatchSplit apply_split(OpenMultiDiGraphView const &pattern, } else { assert(is_standard_edge(pattern_edge)); assert(is_standard_edge(graph_edge)); - auto standard_edge = mpark::get(pattern_edge); + auto standard_edge = std::get(pattern_edge); auto divided = edge_splits.at_l(standard_edge); auto divided_graph_edge = split_edge(get(graph_edge)); handle_edge(divided.first, divided_graph_edge.first); @@ -98,7 +98,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, } UpwardOpenMultiDiEdge matched_edge = narrow(graph_matched_edge).value(); - InputMultiDiEdge input_edge = mpark::get(e); + InputMultiDiEdge input_edge = std::get(e); if (match.node_assignment.at_l(input_edge.dst) != get_dst_node(matched_edge)) { return false; @@ -109,7 +109,7 @@ bool pattern_matches(OpenMultiDiGraphView const &pattern, } DownwardOpenMultiDiEdge matched_edge = narrow(graph_matched_edge).value(); - OutputMultiDiEdge output_edge = mpark::get(e); + OutputMultiDiEdge output_edge = std::get(e); if (match.node_assignment.at_l(output_edge.src) != get_src_node(matched_edge)) { return false; @@ -148,7 +148,7 @@ bool src_compare(T const &lhs, T const &rhs) { return get_src_idx(lhs) < get_src_idx(rhs); } -optional +std::optional get_candidate_singleton_match(OpenMultiDiGraphView const &pattern, OpenMultiDiGraphView const &graph, Node const &graph_node) { @@ -170,11 +170,11 @@ optional get_outgoing_edges(pattern, pattern_node); if (!pattern_incoming.empty() && pattern_incoming.size() != incoming.size()) { - return nullopt; + return std::nullopt; } if (!pattern_outgoing.empty() && pattern_outgoing.size() != outgoing.size()) { - return nullopt; + return std::nullopt; } std::vector incoming_ordered = @@ -198,7 +198,7 @@ optional node_port_mapping.emplace(graph_port, pattern_port); } else { if (pattern_port != node_port_mapping.at(graph_port)) { - return nullopt; + return std::nullopt; } } match.edge_assignment.equate(widen(pattern_edge), @@ -217,7 +217,7 @@ optional node_port_mapping.insert({graph_port, pattern_port}); } else { if (pattern_port != node_port_mapping.at(graph_port)) { - return nullopt; + return std::nullopt; } } match.edge_assignment.equate(widen(pattern_edge), @@ -228,7 +228,7 @@ optional return match; } -optional unsplit_matches( +std::optional unsplit_matches( MultiDiGraphPatternMatch const &prefix, MultiDiGraphPatternMatch const &postfix, bidict> const @@ -248,7 +248,7 @@ optional unsplit_matches( if (output_graph_edge == input_graph_edge) { result.edge_assignment.equate(standard_edge, output_graph_edge); } else { - return nullopt; + return std::nullopt; } } @@ -272,7 +272,7 @@ std::vector std::vector matches; if (is_singleton_pattern(pattern)) { for (Node const &graph_node : get_nodes(graph)) { - optional candidate = + std::optional candidate = get_candidate_singleton_match(pattern, graph, graph_node); if (candidate.has_value() && pattern_matches( @@ -290,7 +290,7 @@ std::vector auto edge_splits = get_edge_splits(pattern, split); for (MultiDiGraphPatternMatch const &prefix_match : prefix_matches) { for (MultiDiGraphPatternMatch const &postfix_match : postfix_matches) { - optional unsplit = + std::optional unsplit = unsplit_matches(prefix_match, postfix_match, edge_splits); if (unsplit.has_value()) { matches.push_back(unsplit.value()); diff --git a/lib/substitutions/src/operator_attributes.cc b/lib/substitutions/src/operator_attributes.cc index 3922b091a7..8bd8688194 100644 --- a/lib/substitutions/src/operator_attributes.cc +++ b/lib/substitutions/src/operator_attributes.cc @@ -3,48 +3,48 @@ namespace FlexFlow { -optional get_attribute(BatchMatmulAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(BatchMatmulAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(CastAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(CastAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.dtype; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(CombineAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(CombineAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.combine_dim; case OperatorAttributeKey::PARALLEL_DIM: return p.combine_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ConcatAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(ConcatAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(Conv2DAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(Conv2DAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -65,44 +65,44 @@ optional get_attribute(Conv2DAttrs const &p, case OperatorAttributeKey::USE_BIAS: return p.use_bias; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementBinaryAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(ElementBinaryAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementUnaryAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(ElementUnaryAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ElementScalarUnaryAttrs const &p, - OperatorAttributeKey key) { +std::optional + get_attribute(ElementScalarUnaryAttrs const &p, OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(DropoutAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(DropoutAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(EmbeddingAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(EmbeddingAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::DATA_TYPE: return p.data_type; @@ -113,38 +113,38 @@ optional get_attribute(EmbeddingAttrs const &p, case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(FlatAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(FlatAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(GatherAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(GatherAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(LayerNormAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(LayerNormAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(LinearAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(LinearAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::OUT_CHANNELS: return p.out_channels; @@ -159,24 +159,24 @@ optional get_attribute(LinearAttrs const &p, case OperatorAttributeKey::REGULARIZER: return p.regularizer; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(MultiHeadAttentionAttrs const &p, - OperatorAttributeKey key) { +std::optional + get_attribute(MultiHeadAttentionAttrs const &p, OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::NUM_HEADS: return p.num_heads; case OperatorAttributeKey::USE_BIAS: return p.bias; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(Pool2DAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(Pool2DAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::KERNEL_H: return p.kernel_h; @@ -195,97 +195,97 @@ optional get_attribute(Pool2DAttrs const &p, case OperatorAttributeKey::ACTIVATION: return p.activation; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReduceAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(ReduceAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReductionAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(ReductionAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.reduction_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.reduction_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(RepartitionAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(RepartitionAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.repartition_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.repartition_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReplicateAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(ReplicateAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PARALLEL_OP_DIM: return p.replicate_dim; case OperatorAttributeKey::PARALLEL_OP_DEGREE: return p.replicate_degree; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(ReshapeAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(ReshapeAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(SplitAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(SplitAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.axis; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(SoftmaxAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(SoftmaxAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::AXIS: return p.dim; default: - return nullopt; + return std::nullopt; } } -optional get_attribute(TopKAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(TopKAttrs const &p, + OperatorAttributeKey key) { switch (key) { default: - return nullopt; + return std::nullopt; } } -optional get_attribute(TransposeAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(TransposeAttrs const &p, + OperatorAttributeKey key) { switch (key) { case OperatorAttributeKey::PERMUTATION: return p.perm; default: - return nullopt; + return std::nullopt; } } @@ -293,7 +293,7 @@ struct GetAttribute { GetAttribute(OperatorAttributeKey key) : key(key) {} template - optional operator()(T const &t) { + std::optional operator()(T const &t) { return get_attribute(t, this->key); } @@ -303,17 +303,17 @@ struct GetAttribute { struct GetOpType { template - optional operator()(T const &t) { + std::optional operator()(T const &t) { return get_op_type(t); } }; -optional get_attribute(PCGOperatorAttrs const &p, - OperatorAttributeKey key) { +std::optional get_attribute(PCGOperatorAttrs const &p, + OperatorAttributeKey key) { if (key == OperatorAttributeKey::OP_TYPE) { - return visit(GetOpType{}, p); + return std::visit(GetOpType{}, p); } - return visit(GetAttribute(key), p); + return std::visit(GetAttribute(key), p); } } // namespace FlexFlow diff --git a/lib/substitutions/src/substitution.cc b/lib/substitutions/src/substitution.cc index dd28a9aa5d..15816185ee 100644 --- a/lib/substitutions/src/substitution.cc +++ b/lib/substitutions/src/substitution.cc @@ -113,51 +113,53 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, assignments.emplace(key, value); } assert(contains_key(assignments, OperatorAttributeKey::OP_TYPE)); - assert(holds_alternative( + assert(std::holds_alternative( assignments.at(OperatorAttributeKey::OP_TYPE))); OperatorType op_type = - get(assignments.at(OperatorAttributeKey::OP_TYPE)); + std::get(assignments.at(OperatorAttributeKey::OP_TYPE)); switch (op_type) { case Op::BATCHMATMUL: - return Operator( - BatchMatmulAttrs{ - get(assignments.at(OperatorAttributeKey::A_SEQ_LENGTH_DIM)), - get(assignments.at(OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, - nullopt); + return Operator{ + BatchMatmulAttrs{std::get(assignments.at( + OperatorAttributeKey::A_SEQ_LENGTH_DIM)), + std::get(assignments.at( + OperatorAttributeKey::B_SEQ_LENGTH_DIM))}, + std::nullopt}; case Op::BATCHNORM: - return Operator( - BatchNormAttrs{get(assignments.at(OperatorAttributeKey::RELU))}, - nullopt); + return Operator{BatchNormAttrs{std::get( + assignments.at(OperatorAttributeKey::RELU))}, + std::nullopt}; case Op::CAST: - return Operator(CastAttrs{get( + return Operator{CastAttrs{std::get( assignments.at(OperatorAttributeKey::DATA_TYPE))}, - nullopt); + std::nullopt}; case Op::CONCAT: - return Operator( + return Operator{ ConcatAttrs{ - get(assignments.at(OperatorAttributeKey::AXIS)), - get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, - nullopt); + std::get(assignments.at(OperatorAttributeKey::AXIS)), + std::get(assignments.at(OperatorAttributeKey::NUM_INPUTS))}, + std::nullopt}; case Op::CONV2D: - return Operator( + return Operator{ Conv2DAttrs{ - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::KERNEL_H)), - get(assignments.at(OperatorAttributeKey::KERNEL_W)), - get(assignments.at(OperatorAttributeKey::STRIDE_H)), - get(assignments.at(OperatorAttributeKey::STRIDE_W)), - get(assignments.at(OperatorAttributeKey::PADDING_H)), - get(assignments.at(OperatorAttributeKey::PADDING_W)), - get(assignments.at(OperatorAttributeKey::GROUPS)), - get(assignments.at(OperatorAttributeKey::ACTIVATION)), - get(assignments.at(OperatorAttributeKey::USE_BIAS))}, - nullopt); + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + std::get(assignments.at(OperatorAttributeKey::GROUPS)), + std::get( + assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get(assignments.at(OperatorAttributeKey::USE_BIAS))}, + std::nullopt}; case Op::DROPOUT: - return Operator( - DropoutAttrs{get(assignments.at(OperatorAttributeKey::RATE)), - get( - assignments.at(OperatorAttributeKey::SEED))}, - nullopt); + return Operator{DropoutAttrs{std::get(assignments.at( + OperatorAttributeKey::RATE)), + std::get(assignments.at( + OperatorAttributeKey::SEED))}, + std::nullopt}; case Op::EW_ADD: case Op::EW_DIV: case Op::EW_EQUAL: @@ -167,25 +169,25 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::EW_MIN: case Op::EW_MUL: case Op::EW_SUB: - return Operator( - ElementBinaryAttrs{ - op_type, - get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - get( - assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_LHS)), - get( - assignments.at(OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, - nullopt); + return Operator{ + ElementBinaryAttrs{op_type, + std::get(assignments.at( + OperatorAttributeKey::DATA_TYPE)), + std::get(assignments.at( + OperatorAttributeKey::SHOULD_BROADCAST_LHS)), + std::get(assignments.at( + OperatorAttributeKey::SHOULD_BROADCAST_RHS))}, + std::nullopt}; case Op::SCALAR_ADD: case Op::SCALAR_FLOOR_DIV: case Op::SCALAR_MULTIPLY: case Op::SCALAR_SUB: case Op::SCALAR_TRUE_DIV: - return Operator( + return Operator{ ElementScalarUnaryAttrs{ op_type, - get(assignments.at(OperatorAttributeKey::SCALAR))}, - nullopt); + std::get(assignments.at(OperatorAttributeKey::SCALAR))}, + std::nullopt}; case Op::EXP: case Op::IDENTITY: case Op::GELU: @@ -193,69 +195,73 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::POW: case Op::SIN: case Op::COS: - return Operator(ElementUnaryAttrs{op_type}, nullopt); + return Operator{ElementUnaryAttrs{op_type}, std::nullopt}; case Op::EMBEDDING: - return Operator( + return Operator{ EmbeddingAttrs{ - get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::AGGR)), - get(assignments.at(OperatorAttributeKey::OP_TYPE))}, - nullopt); + std::get(assignments.at(OperatorAttributeKey::NUM_ENTRIES)), + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::AGGR)), + std::get( + assignments.at(OperatorAttributeKey::OP_TYPE))}, + std::nullopt}; case Op::FLAT: - return Operator(FlatAttrs{}, nullopt); + return Operator{FlatAttrs{}, std::nullopt}; case Op::GATHER: - return Operator( - GatherAttrs{get(assignments.at(OperatorAttributeKey::DIM))}, - nullopt); + return Operator{GatherAttrs{std::get( + assignments.at(OperatorAttributeKey::DIM))}, + std::nullopt}; case Op::INPUT: - return Operator(InputAttrs{}, nullopt); + return Operator{InputAttrs{}, std::nullopt}; case Op::LAYERNORM: - return Operator( + return Operator{ LayerNormAttrs{ - get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), - get( + std::get( assignments.at(OperatorAttributeKey::ELEMENTWISE_AFFINE)), - get(assignments.at(OperatorAttributeKey::EPSILON))}, - nullopt); + std::get(assignments.at(OperatorAttributeKey::EPSILON))}, + std::nullopt}; case Op::LINEAR: - return Operator( + return Operator{ LinearAttrs{ - get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), - get(assignments.at(OperatorAttributeKey::USE_BIAS)), - get(assignments.at(OperatorAttributeKey::DATA_TYPE)), - get(assignments.at(OperatorAttributeKey::ACTIVATION)), - get>( + std::get(assignments.at(OperatorAttributeKey::OUT_CHANNELS)), + std::get(assignments.at(OperatorAttributeKey::USE_BIAS)), + std::get( + assignments.at(OperatorAttributeKey::DATA_TYPE)), + std::get( + assignments.at(OperatorAttributeKey::ACTIVATION)), + std::get>( assignments.at(OperatorAttributeKey::REGULARIZER))}, - nullopt); + std::nullopt}; case Op::MULTIHEAD_ATTENTION: - return Operator( + return Operator{ MultiHeadAttentionAttrs{ - get(assignments.at(OperatorAttributeKey::EMBED_DIM)), - get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - get(assignments.at(OperatorAttributeKey::NUM_HEADS)), - get(assignments.at(OperatorAttributeKey::VDIM)), - get(assignments.at(OperatorAttributeKey::DROPOUT)), - get(assignments.at(OperatorAttributeKey::BIAS)), - get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), - get(assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, - nullopt); + std::get(assignments.at(OperatorAttributeKey::EMBED_DIM)), + std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + std::get(assignments.at(OperatorAttributeKey::NUM_HEADS)), + std::get(assignments.at(OperatorAttributeKey::VDIM)), + std::get(assignments.at(OperatorAttributeKey::DROPOUT)), + std::get(assignments.at(OperatorAttributeKey::BIAS)), + std::get(assignments.at(OperatorAttributeKey::ADD_BIAS_KV)), + std::get( + assignments.at(OperatorAttributeKey::ADD_ZERO_ATTN))}, + std::nullopt}; case Op::NOOP: - return Operator(NoopAttrs{}, nullopt); + return Operator{NoopAttrs{}, std::nullopt}; case Op::POOL2D: - return Operator( + return Operator{ Pool2DAttrs{ - get(assignments.at(OperatorAttributeKey::KERNEL_H)), - get(assignments.at(OperatorAttributeKey::KERNEL_W)), - get(assignments.at(OperatorAttributeKey::STRIDE_H)), - get(assignments.at(OperatorAttributeKey::STRIDE_W)), - get(assignments.at(OperatorAttributeKey::PADDING_H)), - get(assignments.at(OperatorAttributeKey::PADDING_W)), - get(assignments.at(OperatorAttributeKey::POOL_TYPE)), - get( + std::get(assignments.at(OperatorAttributeKey::KERNEL_H)), + std::get(assignments.at(OperatorAttributeKey::KERNEL_W)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_H)), + std::get(assignments.at(OperatorAttributeKey::STRIDE_W)), + std::get(assignments.at(OperatorAttributeKey::PADDING_H)), + std::get(assignments.at(OperatorAttributeKey::PADDING_W)), + std::get(assignments.at(OperatorAttributeKey::POOL_TYPE)), + std::get( assignments.at(OperatorAttributeKey::ACTIVATION))}, - nullopt); + std::nullopt}; case Op::REDUCE_ARGMAX: case Op::REDUCE_ARGMIN: case Op::REDUCE_MAX: @@ -263,67 +269,72 @@ Operator get_operator_attrs(SubParallelComputationGraph const &graph, case Op::REDUCE_MIN: case Op::REDUCE_PROD: case Op::REDUCE_SUM: - return Operator( + return Operator{ ReduceAttrs{ - get>( + std::get>( assignments.at(OperatorAttributeKey::AXES)), op_type, - get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, - nullopt); + std::get(assignments.at(OperatorAttributeKey::KEEP_DIMS))}, + std::nullopt}; case Op::REVERSE: - return Operator(ReverseAttrs{get( + return Operator{ReverseAttrs{std::get( assignments.at(OperatorAttributeKey::AXIS))}, - nullopt); + std::nullopt}; case Op::RESHAPE: - return Operator(ReshapeAttrs{get( + return Operator{ReshapeAttrs{std::get( assignments.at(OperatorAttributeKey::SHAPE))}, - nullopt); + std::nullopt}; case Op::SPLIT: - return Operator( - SplitAttrs{get>( - assignments.at(OperatorAttributeKey::SPLITS)), - get(assignments.at(OperatorAttributeKey::AXIS))}, - nullopt); + return Operator{ + SplitAttrs{ + std::get>( + assignments.at(OperatorAttributeKey::SPLITS)), + std::get(assignments.at(OperatorAttributeKey::AXIS))}, + std::nullopt}; case Op::SOFTMAX: - return Operator(SoftmaxAttrs{get( + return Operator{SoftmaxAttrs{std::get( assignments.at(OperatorAttributeKey::DIM))}, - nullopt); + std::nullopt}; case Op::TOPK: - return Operator( - TopKAttrs{get(assignments.at(OperatorAttributeKey::K)), - get(assignments.at(OperatorAttributeKey::SORTED))}, - nullopt); + return Operator{ + TopKAttrs{ + std::get(assignments.at(OperatorAttributeKey::K)), + std::get(assignments.at(OperatorAttributeKey::SORTED))}, + std::nullopt}; case Op::TRANSPOSE: - return Operator( - TransposeAttrs{get>( + return Operator{ + TransposeAttrs{std::get>( assignments.at(OperatorAttributeKey::PERMUTATION))}, - nullopt); + std::nullopt}; case Op::COMBINE: - return Operator( - CombineAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + return Operator{CombineAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, + std::nullopt}; case Op::REDUCTION: - return Operator( - ReductionAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + return Operator{ + ReductionAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, + std::nullopt}; case Op::REPARTITION: - return Operator( - RepartitionAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + return Operator{ + RepartitionAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, + std::nullopt}; case Op::REPLICATE: - return Operator( - ReplicateAttrs{ - get(assignments.at(OperatorAttributeKey::PARALLEL_DIM)), - get(assignments.at(OperatorAttributeKey::PARALLEL_DEGREE))}, - nullopt); + return Operator{ + ReplicateAttrs{std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DIM)), + std::get(assignments.at( + OperatorAttributeKey::PARALLEL_DEGREE))}, + std::nullopt}; default: - mk_runtime_error("Unknown Operator"); + throw mk_runtime_error("Unknown Operator"); } } @@ -413,11 +424,8 @@ SubParallelComputationGraph Substitution const &substitution, MultiDiGraphPatternMatch const &match) { SubParallelComputationGraph new_pcg = - OutputLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>(); + OutputLabelledOpenMultiDiGraph::template create< + UnorderedOutputLabelledOpenMultiDiGraph>(); bidict node_mapping; // Refactor it with global nodes for (Node const &node : get_nodes(pcg)) { if (!contains_r(match.node_assignment, node)) { @@ -438,23 +446,23 @@ SubParallelComputationGraph } for (OpenMultiDiEdge const &output_edge : get_edges(substitution.output_graph_expr.value())) { - if (holds_alternative(output_edge)) { - InputMultiDiEdge e = get(output_edge); + if (std::holds_alternative(output_edge)) { + InputMultiDiEdge e = std::get(output_edge); OpenMultiDiEdge original_edge = match.edge_assignment.at_l(substitution.input_mapping.at_r(e)); visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, original_edge, output_edge); - } else if (holds_alternative(output_edge)) { - OutputMultiDiEdge e = get(output_edge); + } else if (std::holds_alternative(output_edge)) { + OutputMultiDiEdge e = std::get(output_edge); OpenMultiDiEdge original_edge = match.edge_assignment.at_l(substitution.output_mapping.at_r(e)); visit(AddNewEdgeFunctor{pcg, new_pcg, match, node_mapping}, original_edge, output_edge); } else { - assert(holds_alternative(output_edge)); - MultiDiEdge e = get(output_edge); + assert(std::holds_alternative(output_edge)); + MultiDiEdge e = std::get(output_edge); new_pcg.add_edge(MultiDiEdge{node_mapping.at_l(e.dst), new_pcg.add_node_port(), node_mapping.at_l(e.src), diff --git a/lib/substitutions/test/CMakeLists.txt b/lib/substitutions/test/CMakeLists.txt index d7e35ef9af..cfd6383e95 100644 --- a/lib/substitutions/test/CMakeLists.txt +++ b/lib/substitutions/test/CMakeLists.txt @@ -1,6 +1,6 @@ ff_add_test_executable( NAME - substitutions-test + substitutions-tests SRC_PATTERNS src/*.cc PRIVATE_INCLUDE diff --git a/lib/substitutions/test/src/test_pattern_matches.cc b/lib/substitutions/test/src/test_pattern_matches.cc index cc8a5cd5bd..5d72bbff7e 100644 --- a/lib/substitutions/test/src/test_pattern_matches.cc +++ b/lib/substitutions/test/src/test_pattern_matches.cc @@ -62,46 +62,50 @@ struct Arbitrary { // }); // } -TEST_CASE("find_pattern_matches_small") { - MultiDiGraph g = MultiDiGraph::template create(); - - { - Node n0 = g.add_node(); - Node n1 = g.add_node(); - Node n2 = g.add_node(); - Node n3 = g.add_node(); - - MultiDiEdge e0{n1, g.add_node_port(), n0, g.add_node_port()}; - MultiDiEdge e1{n2, g.add_node_port(), n1, g.add_node_port()}; - MultiDiEdge e2{n3, g.add_node_port(), n2, g.add_node_port()}; - - g.add_edge(e0); - g.add_edge(e1); - g.add_edge(e2); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("find_pattern_matches_small") { + MultiDiGraph g = MultiDiGraph::template create(); - MultiDiGraph sg0 = MultiDiGraph::template create(); + { + Node n0 = g.add_node(); + Node n1 = g.add_node(); + Node n2 = g.add_node(); + Node n3 = g.add_node(); - { - Node n0 = sg0.add_node(); - Node n1 = sg0.add_node(); + MultiDiEdge e0{n1, g.add_node_port(), n0, g.add_node_port()}; + MultiDiEdge e1{n2, g.add_node_port(), n1, g.add_node_port()}; + MultiDiEdge e2{n3, g.add_node_port(), n2, g.add_node_port()}; - MultiDiEdge e0{n1, sg0.add_node_port(), n0, sg0.add_node_port()}; + g.add_edge(e0); + g.add_edge(e1); + g.add_edge(e2); + } - sg0.add_edge(e0); - } + MultiDiGraph sg0 = MultiDiGraph::template create(); + + { + Node n0 = sg0.add_node(); + Node n1 = sg0.add_node(); + + MultiDiEdge e0{n1, sg0.add_node_port(), n0, sg0.add_node_port()}; + + sg0.add_edge(e0); + } - MatchAdditionalCriterion always_true{ - [](Node const &, Node const &) { return true; }, - [](OpenMultiDiEdge const &, OpenMultiDiEdge const &) { return true; }}; + MatchAdditionalCriterion always_true{ + [](Node const &, Node const &) { return true; }, + [](OpenMultiDiEdge const &, OpenMultiDiEdge const &) { return true; }}; - std::vector matches = find_pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); + std::vector matches = find_pattern_matches( + as_openmultidigraph(sg0), as_openmultidigraph(g), always_true); - RC_ASSERT(matches.size() == 3); + RC_ASSERT(matches.size() == 3); - for (MultiDiGraphPatternMatch const &match : matches) { - RC_ASSERT(pattern_matches( - as_openmultidigraph(sg0), as_openmultidigraph(g), match, always_true)); + for (MultiDiGraphPatternMatch const &match : matches) { + RC_ASSERT(pattern_matches(as_openmultidigraph(sg0), + as_openmultidigraph(g), + match, + always_true)); + } } } diff --git a/lib/substitutions/test/src/test_substitution.cc b/lib/substitutions/test/src/test_substitution.cc index a33e9127cc..df22d8a620 100644 --- a/lib/substitutions/test/src/test_substitution.cc +++ b/lib/substitutions/test/src/test_substitution.cc @@ -5,129 +5,128 @@ using namespace FlexFlow; -TEST_CASE("apply_substitution") { - OperatorPattern operator_pattern_n0{ - std::vector{OperatorAttributeConstraint{ - ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; - - ParallelTensorPattern tensor_pattern_e0{ - std::vector{TensorAttributeConstraint{ - ConstraintType::EQUAL, - ListIndexAccess{TensorAttributeKey::DIM_SIZES, 0}, - 2}}}; - - ParallelTensorPattern tensor_pattern_empty{ - std::vector{}}; - - auto ig = - OutputLabelledOpenMultiDiGraph:: - create, - UnorderedLabelling, - UnorderedLabelling>(); - Node n0 = ig.add_node(operator_pattern_n0); - NodePort p0 = ig.add_node_port(); - InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; - ig.add_edge(e0); - ig.add_label(e0, tensor_pattern_e0); - - RC_ASSERT(get_nodes(ig).size() == 1); - RC_ASSERT(get_edges(ig).size() == 1); - - GraphPattern input_graph{ig}; - - OperatorAttrAssignment op_ass_n1{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - OperatorAttrAssignment op_ass_n2{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, - {OperatorAttributeKey::OUT_CHANNELS, - OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, - {OperatorAttributeKey::USE_BIAS, - OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, - {OperatorAttributeKey::DATA_TYPE, - OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, - {OperatorAttributeKey::ACTIVATION, - OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, - {OperatorAttributeKey::REGULARIZER, - OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; - - OperatorAttrAssignment op_ass_n3{ - {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, - {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, - {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; - - auto og = NodeLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling>(); - Node n1 = og.add_node(op_ass_n1); - Node n2 = og.add_node(op_ass_n2); - Node n3 = og.add_node(op_ass_n3); - NodePort p1 = og.add_node_port(); - NodePort p2 = og.add_node_port(); - NodePort p3 = og.add_node_port(); - InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; - MultiDiEdge e2{n2, p2, n1, p1}; - MultiDiEdge e3{n3, p3, n2, p2}; - og.add_edge(e1); - og.add_edge(e2); - og.add_edge(e3); - OutputGraphExpr output_graph_expr{og}; - - RC_ASSERT(get_nodes(og).size() == 3); - RC_ASSERT(get_edges(og).size() == 3); - - bidict input_mapping; - input_mapping.equate(e0, e1); - bidict output_mapping; - - Substitution substitution{ - input_graph, output_graph_expr, input_mapping, output_mapping}; - - SubParallelComputationGraph pcg = - OutputLabelledOpenMultiDiGraph::create< - AdjacencyOpenMultiDiGraph, - UnorderedLabelling, - UnorderedLabelling, - UnorderedLabelling>(); - - Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); - Node n5 = pcg.add_node(Operator{ - LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, nullopt}, - "linear"}); - NodePort p4 = pcg.add_node_port(); - NodePort p5 = pcg.add_node_port(); - - MultiDiEdge e4{n5, p5, n4, p4}; - pcg.add_edge(e4); - pcg.add_label(e4, - ParallelTensor(ParallelTensorDims({2, 1}), - DataType::FLOAT, - CreateGrad::YES)); - - MatchAdditionalCriterion criterion{ - [&](Node const &pattern_node, Node const &graph_node) { - return operator_satisfies(pcg.at(graph_node), - input_graph.value().at(pattern_node)); - }, - [&](OpenMultiDiEdge const &pattern_edge, - OpenMultiDiEdge const &graph_edge) { - return parallel_tensor_satisfies(pcg.at(graph_edge), - input_graph.value().at(pattern_edge)); - }}; - - RC_ASSERT(criterion.node_criterion(n0, n5)); - - std::vector matches = - find_pattern_matches(input_graph, pcg, criterion); - - RC_ASSERT(matches.size() == 1); - - SubParallelComputationGraph new_pcg = - apply_substitution(pcg, substitution, matches[0]); - - RC_ASSERT(get_nodes(new_pcg).size() == 4); - RC_ASSERT(get_edges(new_pcg).size() == 3); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("apply_substitution") { + OperatorPattern operator_pattern_n0{ + std::vector{OperatorAttributeConstraint{ + ConstraintType::EQUAL, OperatorAttributeKey::OP_TYPE, Op::LINEAR}}}; + + ParallelTensorPattern tensor_pattern_e0{ + std::vector{ + TensorAttributeConstraint{ConstraintType::EQUAL, + ListIndexAccess{ + TensorAttributeKey::DIM_SIZES, 0}, + 2}}}; + + ParallelTensorPattern tensor_pattern_empty{ + std::vector{}}; + + auto ig = + OutputLabelledOpenMultiDiGraph:: + create>(); + Node n0 = ig.add_node(operator_pattern_n0); + NodePort p0 = ig.add_node_port(); + InputMultiDiEdge e0{n0, p0, std::make_pair(p0.value(), p0.value())}; + ig.add_edge(e0); + ig.add_label(e0, tensor_pattern_e0); + + RC_ASSERT(get_nodes(ig).size() == 1); + RC_ASSERT(get_edges(ig).size() == 1); + + GraphPattern input_graph{ig}; + + OperatorAttrAssignment op_ass_n1{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REPARTITION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + + OperatorAttrAssignment op_ass_n2{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::LINEAR}}, + {OperatorAttributeKey::OUT_CHANNELS, + OperatorAttrAccess{n0, OperatorAttributeKey::OUT_CHANNELS}}, + {OperatorAttributeKey::USE_BIAS, + OperatorAttrAccess{n0, OperatorAttributeKey::USE_BIAS}}, + {OperatorAttributeKey::DATA_TYPE, + OperatorAttrAccess{n0, OperatorAttributeKey::DATA_TYPE}}, + {OperatorAttributeKey::ACTIVATION, + OperatorAttrAccess{n0, OperatorAttributeKey::ACTIVATION}}, + {OperatorAttributeKey::REGULARIZER, + OperatorAttrAccess{n0, OperatorAttributeKey::REGULARIZER}}}}; + + OperatorAttrAssignment op_ass_n3{ + {{OperatorAttributeKey::OP_TYPE, AttrConstant{Op::REDUCTION}}, + {OperatorAttributeKey::PARALLEL_DIM, AttrConstant{ff_dim_t{0}}}, + {OperatorAttributeKey::PARALLEL_DEGREE, AttrConstant{2}}}}; + + auto og = NodeLabelledOpenMultiDiGraph::create< + UnorderedNodeLabelledOpenMultiDiGraph>(); + Node n1 = og.add_node(op_ass_n1); + Node n2 = og.add_node(op_ass_n2); + Node n3 = og.add_node(op_ass_n3); + NodePort p1 = og.add_node_port(); + NodePort p2 = og.add_node_port(); + NodePort p3 = og.add_node_port(); + InputMultiDiEdge e1{n1, p1, {p1.value(), p1.value()}}; + MultiDiEdge e2{n2, p2, n1, p1}; + MultiDiEdge e3{n3, p3, n2, p2}; + og.add_edge(e1); + og.add_edge(e2); + og.add_edge(e3); + OutputGraphExpr output_graph_expr{og}; + + RC_ASSERT(get_nodes(og).size() == 3); + RC_ASSERT(get_edges(og).size() == 3); + + bidict input_mapping; + input_mapping.equate(e0, e1); + bidict output_mapping; + + Substitution substitution{ + input_graph, output_graph_expr, input_mapping, output_mapping}; + + SubParallelComputationGraph pcg = + OutputLabelledOpenMultiDiGraph::create< + UnorderedOutputLabelledOpenMultiDiGraph>(); + + Node n4 = pcg.add_node(Operator{InputAttrs{}, "input"}); + Node n5 = pcg.add_node(Operator{ + LinearAttrs{1, false, DataType::FLOAT, Activation::RELU, std::nullopt}, + "linear"}); + NodePort p4 = pcg.add_node_port(); + NodePort p5 = pcg.add_node_port(); + + MultiDiEdge e4{n5, p5, n4, p4}; + pcg.add_edge(e4); + pcg.add_label(e4, + ParallelTensor(ParallelTensorDims({2, 1}), + DataType::FLOAT, + CreateGrad::YES)); + + MatchAdditionalCriterion criterion{ + [&](Node const &pattern_node, Node const &graph_node) { + return operator_satisfies(pcg.at(graph_node), + input_graph.value().at(pattern_node)); + }, + [&](OpenMultiDiEdge const &pattern_edge, + OpenMultiDiEdge const &graph_edge) { + return parallel_tensor_satisfies( + pcg.at(graph_edge), input_graph.value().at(pattern_edge)); + }}; + + RC_ASSERT(criterion.node_criterion(n0, n5)); + + std::vector matches = + find_pattern_matches(input_graph, pcg, criterion); + + RC_ASSERT(matches.size() == 1); + + SubParallelComputationGraph new_pcg = + apply_substitution(pcg, substitution, matches[0]); + + RC_ASSERT(get_nodes(new_pcg).size() == 4); + RC_ASSERT(get_edges(new_pcg).size() == 3); + } } diff --git a/lib/utils/include/utils/containers.decl.h b/lib/utils/include/utils/containers.decl.h index 84fd4a5acc..0332a331b2 100644 --- a/lib/utils/include/utils/containers.decl.h +++ b/lib/utils/include/utils/containers.decl.h @@ -2,9 +2,9 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_DECL_H #include "utils/bidict.h" -#include "utils/optional.decl" #include "utils/required_core.h" #include "utils/type_traits_core.h" +#include #include #include @@ -108,7 +108,7 @@ template std::vector values(C const &c); template -std::unordered_set> +std::unordered_set> items(C const &c); template @@ -293,6 +293,9 @@ T reversed(T const &t); template std::vector value_all(std::vector> const &v); +template +std::unordered_set value_all(std::unordered_set> const &v); + template std::vector subvec(std::vector const &v, std::optional const &maybe_start, diff --git a/lib/utils/include/utils/containers.h b/lib/utils/include/utils/containers.h index cdf4591cdb..1606eb0605 100644 --- a/lib/utils/include/utils/containers.h +++ b/lib/utils/include/utils/containers.h @@ -226,7 +226,7 @@ std::vector values(C const &c) { } template -std::unordered_set> +std::unordered_set> items(C const &c) { return {c.begin(), c.end()}; } @@ -673,6 +673,16 @@ std::vector value_all(std::vector> const &v) { }); } +template +std::unordered_set value_all(std::unordered_set> const &v) { + return transform(v, [](std::optional const &element) { + return unwrap(element, [] { + throw mk_runtime_error( + "Encountered element without value in call to value_all"); + }); + }); +} + template std::vector subvec(std::vector const &v, std::optional const &maybe_start, diff --git a/lib/utils/include/utils/dot_file.h b/lib/utils/include/utils/dot_file.h index 6cdc78f6d4..1fd9813646 100644 --- a/lib/utils/include/utils/dot_file.h +++ b/lib/utils/include/utils/dot_file.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -28,16 +29,16 @@ class DotFile { return s.str(); } bool has_ostream() const { - return this->owned_fstream.has_value() || this->out.has_value(); + return this->owned_fstream.has_value() || this->out != nullptr; } std::ostream &get_ostream() { bool has_owned_stream = this->owned_fstream.has_value(); - bool has_stream_ref = this->out.has_value(); + bool has_stream_ref = (this->out != nullptr); assert(has_owned_stream != has_stream_ref); if (has_owned_stream) { return this->owned_fstream.value(); } else if (has_stream_ref) { - return this->out.value(); + return *this->out; } else { throw std::runtime_error("No ostream value set"); } diff --git a/lib/utils/include/utils/fmt.h b/lib/utils/include/utils/fmt.h index 58982d6f36..905b4622f1 100644 --- a/lib/utils/include/utils/fmt.h +++ b/lib/utils/include/utils/fmt.h @@ -6,6 +6,8 @@ #include "utils/test_types.h" #include "utils/type_traits_core.h" +#include + namespace FlexFlow { template @@ -26,6 +28,12 @@ struct already_has_ostream_operator : std::true_type {}; template <> struct already_has_ostream_operator : std::true_type {}; +template <> +struct already_has_ostream_operator> : std::true_type {}; + +template <> +struct already_has_ostream_operator : std::true_type {}; + // This will create an error /* template diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index bb70a9093c..87b42a90d2 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -23,6 +23,7 @@ std::vector add_nodes(Graph &, int); std::vector add_nodes(UndirectedGraph &, int); std::vector add_nodes(DiGraph &, int); std::vector add_nodes(MultiDiGraph &, int); +std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes); std::vector add_node_ports(MultiDiGraph &, int); @@ -106,6 +107,11 @@ std::unordered_set get_node_edges(UndirectedGraphView const &, std::unordered_set get_outputs(MultiDiGraphView const &); std::unordered_set get_inputs(MultiDiGraphView const &); +std::unordered_set + get_open_outputs(OpenMultiDiGraphView const &); +std::unordered_set + get_open_inputs(OpenMultiDiGraphView const &); + std::unordered_set get_incoming_edges(MultiDiGraphView const &, Node const &); std::unordered_set get_incoming_edges(DiGraphView const &, diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h b/lib/utils/include/utils/graph/labelled/labelled_open.decl.h deleted file mode 100644 index cdd22b7847..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open.decl.h +++ /dev/null @@ -1,124 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_DECL_H - -#include "labelled_open_interfaces.h" -#include "node_labelled.h" -#include "utils/graph/open_graphs.h" - -namespace FlexFlow { - -template -struct LabelledOpenMultiDiGraphView { -private: - using Interface = ILabelledOpenMultiDiGraphView; - -public: - LabelledOpenMultiDiGraphView() = delete; - - operator OpenMultiDiGraphView() const; - // operator MultiDiGraphView() const; - - NodeLabel const &at(Node const &n) const; - EdgeLabel const &at(MultiDiEdge const &e) const; - InputLabel const &at(InputMultiDiEdge const &e) const; - OutputLabel const &at(OutputMultiDiEdge const &e) const; - - template - static typename std::enable_if::value, - LabelledOpenMultiDiGraphView>::type - create(); - -private: - std::shared_ptr ptr; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ( - LabelledOpenMultiDiGraphView); - -template -struct LabelledOpenMultiDiGraph { -private: - using Interface = - ILabelledOpenMultiDiGraph; - -public: - LabelledOpenMultiDiGraph() = delete; - LabelledOpenMultiDiGraph(LabelledOpenMultiDiGraph const &other) = default; - LabelledOpenMultiDiGraph & - operator=(LabelledOpenMultiDiGraph const &other) = default; - - operator LabelledOpenMultiDiGraphView() const; - - operator OpenMultiDiGraphView() const; - - friend void swap(LabelledOpenMultiDiGraph &lhs, - LabelledOpenMultiDiGraph &rhs) { - using std::swap; - - swap(lhs.ptr, rhs.ptr); - } - - Node add_node(NodeLabel const &l); - NodeLabel &at(Node const &n); - - NodePort add_node_port(); - - NodeLabel const &at(Node const &n) const; - - void add_node_unsafe(Node const &n, NodeLabel const &l); - - std::unordered_set query_nodes(NodeQuery const &q) const; - std::unordered_set - query_edges(OpenMultiDiEdgeQuery const &q) const; - - void add_edge( - MultiDiEdge const &e); // We should allow adding edges without labels. For - // example, we may want to first construct a PCG - // and infer its tensor shapes later. - void add_edge(InputMultiDiEdge const &e); - void add_edge(OutputMultiDiEdge const &e); - - void add_label(MultiDiEdge const &e, EdgeLabel const &l); - void add_label(InputMultiDiEdge const &e, EdgeLabel const &l); - void add_label(OutputMultiDiEdge const &e, EdgeLabel const &l); - - void add_edge(MultiDiEdge const &e, EdgeLabel const &l); - EdgeLabel &at(MultiDiEdge const &e); - EdgeLabel const &at(MultiDiEdge const &e) const; - - void add_edge(InputMultiDiEdge const &e, InputLabel const &l); - InputLabel &at(InputMultiDiEdge const &e); - InputLabel const &at(InputMultiDiEdge const &e) const; - - void add_edge(OutputMultiDiEdge const &, OutputLabel const &); - OutputLabel &at(OutputMultiDiEdge const &); - OutputLabel const &at(OutputMultiDiEdge const &) const; - - template - static typename std::enable_if::value, - LabelledOpenMultiDiGraph>::type - create(); - -private: - LabelledOpenMultiDiGraph(cow_ptr_t ptr); - -private: - cow_ptr_t ptr; -}; -CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ( - LabelledOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/labelled_open.h b/lib/utils/include/utils/graph/labelled/labelled_open.h deleted file mode 100644 index 58fd5416f7..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open.h +++ /dev/null @@ -1,173 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_H - -#include "labelled_open.decl.h" -#include "labelled_open_interfaces.h" -#include "node_labelled.h" -#include "utils/graph/open_graph_interfaces.h" -#include "utils/graph/open_graphs.h" - -namespace FlexFlow { - -// LabelledOpenMultiDiGraphView -template -LabelledOpenMultiDiGraphView::operator OpenMultiDiGraphView() - const { - return GraphInternal::create_open_multidigraph_view(this->ptr); -} - -// template -// LabelledOpenMultiDiGraphView::operator MultiDiGraphView() const { -// return GraphInternal::create_multidigraphview(this->ptr); -// } - -template -NodeLabel const & - LabelledOpenMultiDiGraphView::at(Node const &n) const { - return this->ptr->at(n); -} - -template -EdgeLabel const &LabelledOpenMultiDiGraphView::at( - MultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -InputLabel const &LabelledOpenMultiDiGraphView::at( - InputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -OutputLabel const &LabelledOpenMultiDiGraphView::at( - OutputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -template -enable_if_t::Interface, - BaseImpl>::value, - LabelledOpenMultiDiGraphView> - LabelledOpenMultiDiGraphView::create() { - return LabelledOpenMultiDiGraphView(std::make_shared()); -} - -// LabelledOpenMultiDiGraph -template -LabelledOpenMultiDiGraph:: - operator LabelledOpenMultiDiGraphView() const { - return GraphInternal::create_labelled_open_multidigraph_view( - this->ptr); -} - -template -LabelledOpenMultiDiGraph::operator OpenMultiDiGraphView() const { - return GraphInternal::create_open_multidigraph_view(this->ptr.get()); -} - -template -Node LabelledOpenMultiDiGraph::add_node( - NodeLabel const &l) { - return this->ptr.get_mutable()->add_node(l); -} - -template -NodeLabel &LabelledOpenMultiDiGraph::at(Node const &n) { - return this->ptr->at(n); -} - -template -NodeLabel const & - LabelledOpenMultiDiGraph::at(Node const &n) const { - return this->ptr->ILabelledMultiDiGraph::at(n); -} - -template -void LabelledOpenMultiDiGraph::add_node_unsafe( - Node const &n, NodeLabel const &l) { - this->ptr->add_node_unsafe(n, l); -} - -template -std::unordered_set LabelledOpenMultiDiGraph::query_nodes( - NodeQuery const &q) const { - return this->ptr->query_nodes(q); -} - -template -std::unordered_set - LabelledOpenMultiDiGraph::query_edges( - OpenMultiDiEdgeQuery const &q) const { - return this->ptr->query_edges(q); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - MultiDiEdge const &e, EdgeLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -EdgeLabel & - LabelledOpenMultiDiGraph::at(MultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -EdgeLabel const &LabelledOpenMultiDiGraph::at( - MultiDiEdge const &e) const { - return this->ptr->ILabelledMultiDiGraph::at(e); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - InputMultiDiEdge const &e, InputLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -InputLabel &LabelledOpenMultiDiGraph::at( - InputMultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -InputLabel const &LabelledOpenMultiDiGraph::at( - InputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -void LabelledOpenMultiDiGraph::add_edge( - OutputMultiDiEdge const &e, OutputLabel const &l) { - return this->ptr->add_edge(e, l); -} - -template -OutputLabel &LabelledOpenMultiDiGraph::at( - OutputMultiDiEdge const &e) { - return this->ptr->at(e); -} - -template -OutputLabel const &LabelledOpenMultiDiGraph::at( - OutputMultiDiEdge const &e) const { - return this->ptr->at(e); -} - -template -template -enable_if_t< - std::is_base_of::Interface, - BaseImpl>::value, - LabelledOpenMultiDiGraph> - LabelledOpenMultiDiGraph::create() { - return LabelledOpenMultiDiGraph(make_cow_ptr()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h deleted file mode 100644 index 2db654c615..0000000000 --- a/lib/utils/include/utils/graph/labelled/labelled_open_interfaces.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_INTERFACES_H -#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_LABELLED_OPEN_INTERFACES_H - -#include "standard_labelled_interfaces.h" -#include "utils/containers.h" -#include "utils/graph/open_graph_interfaces.h" - -namespace FlexFlow { - -template -struct ILabelledOpenMultiDiGraphView - : public IOpenMultiDiGraphView, - public ILabelledMultiDiGraphView { -public: - std::unordered_set - query_edges(MultiDiEdgeQuery const &q) const final { - return map_over_unordered_set( - [](OpenMultiDiEdge const &e) { return get(e); }, - IOpenMultiDiGraphView::query_edges( - static_cast(q))); - } - - using ILabelledMultiDiGraphView::at; - virtual InputLabel const &at(InputMultiDiEdge const &e) const = 0; - virtual OutputLabel const &at(OutputMultiDiEdge const &e) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT( - ILabelledOpenMultiDiGraphView); - -template -struct ILabelledOpenMultiDiGraph - : public ILabelledMultiDiGraph, - public ILabelledOpenMultiDiGraphView { -public: - virtual ILabelledOpenMultiDiGraph *clone() const = 0; - - virtual void add_edge(InputMultiDiEdge const &e, InputLabel const &label) = 0; - virtual void add_edge(OutputMultiDiEdge const &e, - OutputLabel const &label) = 0; - - virtual InputLabel const &at(InputMultiDiEdge const &e) const = 0; - virtual InputLabel &at(InputMultiDiEdge const &e) = 0; - - virtual OutputLabel const &at(OutputMultiDiEdge const &e) const = 0; - virtual OutputLabel &at(OutputMultiDiEdge const &e) = 0; - - using ILabelledMultiDiGraph::add_node; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledOpenMultiDiGraph); - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/graph/labelled/node_labelled.h b/lib/utils/include/utils/graph/labelled/node_labelled.h index f8ac988e73..856dd4434e 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled.h @@ -1,23 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_H -#include "label_interfaces.h" +#include "node_labelled_interfaces.h" #include "utils/graph/multidigraph.h" namespace FlexFlow { -template -struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { - INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; - INodeLabelledMultiDiGraphView & - operator=(INodeLabelledMultiDiGraphView const &) = delete; - - virtual ~INodeLabelledMultiDiGraphView() {} - - virtual NodeLabel const &at(Node const &n) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); - template struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: @@ -53,8 +41,7 @@ struct NodeLabelledMultiDiGraphView : virtual public MultiDiGraphView { private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraphView); @@ -64,7 +51,6 @@ struct NodeLabelledMultiDiGraph : virtual NodeLabelledMultiDiGraphView { private: using Interface = IMultiDiGraph; - using NodeLabelIf = ILabelling; public: NodeLabelledMultiDiGraph(NodeLabelledMultiDiGraph const &) = default; @@ -72,60 +58,50 @@ struct NodeLabelledMultiDiGraph operator=(NodeLabelledMultiDiGraph const &) = default; NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->get_ptr().at(n); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(); + return this->get_ptr().query_nodes(); } std::unordered_set query_edges(MultiDiEdge const &q) const { - return get_ptr().query_edges(); + return this->get_ptr().query_edges(); } Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } void add_edge(MultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of>::value, - NodeLabelledMultiDiGraph>::type + template + static typename std::enable_if::value, + NodeLabelledMultiDiGraph>::type create() { - return NodeLabelledMultiDiGraph(make_cow_ptr(), - make_cow_ptr()); + return NodeLabelledMultiDiGraph(make_cow_ptr()); } protected: - NodeLabelledMultiDiGraph(cow_ptr_t ptr, cow_ptr_t nl) - : NodeLabelledMultiDiGraphView(ptr), nl(nl) {} + NodeLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } - - cow_ptr_t nl; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(NodeLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h new file mode 100644 index 0000000000..c371a9a3bd --- /dev/null +++ b/lib/utils/include/utils/graph/labelled/node_labelled_interfaces.h @@ -0,0 +1,36 @@ +#ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_INTERFACES_H +#define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_NODE_LABELLED_INTERFACES_H + +#include "utils/graph/multidigraph.h" + +namespace FlexFlow { + +template +struct INodeLabelledMultiDiGraphView : virtual public IMultiDiGraphView { + INodeLabelledMultiDiGraphView() = default; + INodeLabelledMultiDiGraphView(INodeLabelledMultiDiGraphView const &) = delete; + INodeLabelledMultiDiGraphView & + operator=(INodeLabelledMultiDiGraphView const &) = delete; + + virtual ~INodeLabelledMultiDiGraphView() {} + + virtual NodeLabel const &at(Node const &n) const = 0; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledMultiDiGraphView); + +template +struct INodeLabelledMultiDiGraph + : virtual INodeLabelledMultiDiGraphView { + virtual NodeLabel &at(Node const &) = 0; + virtual Node add_node(NodeLabel const &l) = 0; + virtual NodePort add_node_port() = 0; + virtual void add_edge(MultiDiEdge const &) = 0; + + virtual INodeLabelledMultiDiGraph *clone() const = 0; + + using INodeLabelledMultiDiGraphView::at; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled/node_labelled_open.h b/lib/utils/include/utils/graph/labelled/node_labelled_open.h index 1ab14f5b3e..c864c7dacf 100644 --- a/lib/utils/include/utils/graph/labelled/node_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/node_labelled_open.h @@ -9,6 +9,7 @@ template struct INodeLabelledOpenMultiDiGraphView : virtual INodeLabelledMultiDiGraphView, virtual IOpenMultiDiGraphView { + INodeLabelledOpenMultiDiGraphView() = default; INodeLabelledOpenMultiDiGraphView(INodeLabelledOpenMultiDiGraphView const &) = delete; INodeLabelledOpenMultiDiGraphView & @@ -54,81 +55,77 @@ struct NodeLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; +template +struct INodeLabelledOpenMultiDiGraph + : virtual INodeLabelledOpenMultiDiGraphView { + virtual Node add_node(NodeLabel const &) = 0; + virtual NodePort add_node_port() = 0; + virtual NodeLabel &at(Node const &) = 0; + virtual void add_edge(OpenMultiDiEdge const &e) = 0; + + using INodeLabelledOpenMultiDiGraphView::at; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(INodeLabelledOpenMultiDiGraphView); + template struct NodeLabelledOpenMultiDiGraph : virtual NodeLabelledOpenMultiDiGraphView { private: - using Interface = IOpenMultiDiGraph; - using INodeLabel = ILabelling; + using Interface = INodeLabelledOpenMultiDiGraph; public: - // NodeLabelledOpenMultiDiGraph() = delete; NodeLabelledOpenMultiDiGraph(NodeLabelledOpenMultiDiGraph const &) = default; NodeLabelledOpenMultiDiGraph & operator=(NodeLabelledOpenMultiDiGraph const &) = default; - NodeLabel const &at(Node const &n) const { - return nl->get_label(n); - } - NodeLabel &at(Node const &n) { - return nl->get_label(n); + return this->get_ptr().at(n); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdge const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } void add_edge(OpenMultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of>::value, - NodeLabelledOpenMultiDiGraph>::type + using NodeLabelledOpenMultiDiGraphView::at; + + template + static typename std::enable_if::value, + NodeLabelledOpenMultiDiGraph>::type create() { - return NodeLabelledOpenMultiDiGraph(make_cow_ptr(), - make_cow_ptr()); + return NodeLabelledOpenMultiDiGraph(make_cow_ptr()); } private: - NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl) - : GraphView(ptr), nl(nl) {} + NodeLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } - - cow_ptr_t nl; }; } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/open_views.h b/lib/utils/include/utils/graph/labelled/open_views.h index 4a4c81aef9..494d8d9f9d 100644 --- a/lib/utils/include/utils/graph/labelled/open_views.h +++ b/lib/utils/include/utils/graph/labelled/open_views.h @@ -26,6 +26,10 @@ struct OutputLabelledOpenMultiDiSubgraphView return g.at(n); } + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + return g.at(i); + } + EdgeLabel const &at(MultiDiOutput const &o) const override { return g.at(o); } @@ -39,11 +43,61 @@ struct OutputLabelledOpenMultiDiSubgraphView return SubgraphView(g, nodes).query_edges(q); } + OutputLabelledOpenMultiDiSubgraphView *clone() const override { + return new OutputLabelledOpenMultiDiSubgraphView(g, nodes); + } + private: - OutputLabelledOpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OutputLabelledOpenMultiDiGraphView g; + std::unordered_set nodes; }; +template +struct ViewOutputLabelledAsOutputLabelledOpen + : virtual IOutputLabelledOpenMultiDiGraphView { + ViewOutputLabelledAsOutputLabelledOpen( + OutputLabelledMultiDiGraphView const &g) + : g(g) {} + + NodeLabel const &at(Node const &n) const override { + return g.at(n); + } + + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + assert(false); + } + + EdgeLabel const &at(MultiDiOutput const &o) const override { + return g.at(o); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); + } + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return transform(g.query_edges(q.standard_edge_query), + [](MultiDiEdge const &e) { return OpenMultiDiEdge(e); }); + } + + ViewOutputLabelledAsOutputLabelledOpen *clone() const override { + return new ViewOutputLabelledAsOutputLabelledOpen(g); + } + +private: + OutputLabelledMultiDiGraphView g; +}; + +template +OutputLabelledOpenMultiDiGraphView + view_output_labelled_as_output_labelled_open( + OutputLabelledMultiDiGraphView const &g) { + return OutputLabelledOpenMultiDiGraphView:: + template create< + ViewOutputLabelledAsOutputLabelledOpen>(g); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled.h b/lib/utils/include/utils/graph/labelled/output_labelled.h index 4d959782dc..ac5648c2e1 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled.h @@ -1,23 +1,11 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_H -#include "standard_labelled.h" +#include "node_labelled.h" +#include "output_labelled_interfaces.h" namespace FlexFlow { -template -struct IOutputLabelledMultiDiGraphView - : public INodeLabelledMultiDiGraphView { - IOutputLabelledMultiDiGraphView() = default; - IOutputLabelledMultiDiGraphView(IOutputLabelledMultiDiGraphView const &) = - delete; - IOutputLabelledMultiDiGraphView & - operator=(IOutputLabelledMultiDiGraphView const &) = delete; - - virtual OutputLabel const &at(MultiDiOutput const &) = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); - template struct OutputLabelledMultiDiGraphView : virtual public NodeLabelledMultiDiGraphView { @@ -31,19 +19,19 @@ struct OutputLabelledMultiDiGraphView operator=(OutputLabelledMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr().at(n); + return this->get_ptr().at(n); } OutputLabel const &at(MultiDiOutput const &o) const { - return get_ptr().at(o); + return this->get_ptr().at(o); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -55,13 +43,11 @@ struct OutputLabelledMultiDiGraphView } protected: - OutputLabelledMultiDiGraphView(cow_ptr_t ptr) - : NodeLabelledMultiDiGraphView(ptr) {} + using NodeLabelledMultiDiGraphView::NodeLabelledMultiDiGraphView; private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -69,9 +55,7 @@ template struct OutputLabelledMultiDiGraph : virtual OutputLabelledMultiDiGraphView { private: - using Interface = IMultiDiGraph; - using INodeLabel = ILabelling; - using IOutputLabel = ILabelling; + using Interface = IOutputLabelledMultiDiGraph; public: OutputLabelledMultiDiGraph(OutputLabelledMultiDiGraph const &other) = default; @@ -79,81 +63,67 @@ struct OutputLabelledMultiDiGraph operator=(OutputLabelledMultiDiGraph const &other) = default; Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); + return this->get_ptr().at(n); } NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } void add_output(MultiDiOutput const &o, OutputLabel const &l) { - ol->add_label(o, l); + this->get_ptr().add_output(o, l); }; void add_edge(MultiDiOutput const &o, MultiDiInput const &i) { - return get_ptr().add_edge(o, i); + this->get_ptr().add_edge(o, i); }; void add_edge(MultiDiEdge const &e) { - return get_ptr().add_edge(e); + this->get_ptr().add_edge(e); } OutputLabel &at(MultiDiOutput const &o) { - return ol->get_label(o); + return this->get_ptr().at(o); } OutputLabel const &at(MultiDiOutput const &o) const { - return ol->get_label(o); + return this->get_ptr().at(o); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of>::value, - OutputLabelledMultiDiGraph>::type + template + static typename std::enable_if::value, + OutputLabelledMultiDiGraph>::type create() { - return OutputLabelledMultiDiGraph( - make_cow_ptr(), make_cow_ptr(), make_cow_ptr()); + return OutputLabelledMultiDiGraph(make_cow_ptr()); } private: - OutputLabelledMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t ol) - : OutputLabelledMultiDiGraphView(ptr), nl(nl), - ol(ol) {} + OutputLabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} private: Interface &get_ptr() { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } - - cow_ptr_t nl; - cow_ptr_t ol; }; template struct IOutputLabelledMultiDiGraphView : public INodeLabelledMultiDiGraphView { - virtual OutputLabel &at(MultiDiOutput const &) = 0; + virtual OutputLabel const &at(MultiDiOutput const &) const = 0; + + using INodeLabelledMultiDiGraphView::at; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraphView); template struct IOutputLabelledMultiDiGraph - : public IOutputLabelledMultiDiGraphView { + : public IOutputLabelledMultiDiGraphView, + public INodeLabelledMultiDiGraph { public: virtual IOutputLabelledMultiDiGraph *clone() const = 0; virtual void add_output(MultiDiOutput const &output, OutputLabel const &label) = 0; - virtual void add_edge(MultiDiOutput const &output, - MultiDiInput const &input) = 0; - virtual NodePort add_node_ports() = 0; + virtual NodePort add_node_port() = 0; virtual NodeLabel &at(Node const &) = 0; virtual NodeLabel const &at(Node const &) const = 0; + virtual OutputLabel &at(MultiDiOutput const &) = 0; virtual OutputLabel const &at(MultiDiOutput const &) const = 0; }; CHECK_RC_COPY_VIRTUAL_COMPLIANT(IOutputLabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open.h b/lib/utils/include/utils/graph/labelled/output_labelled_open.h index eb406d1804..bc4fe3d828 100644 --- a/lib/utils/include/utils/graph/labelled/output_labelled_open.h +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open.h @@ -1,23 +1,15 @@ #ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN #define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN -#include "node_labelled.h" -#include "utils/graph/adjacency_openmultidigraph.h" +#include "node_labelled_open.h" +#include "output_labelled_open_interfaces.h" namespace FlexFlow { -template -struct IOutputLabelledOpenMultiDiGraphView - : virtual INodeLabelledOpenMultiDiGraphView { - virtual EdgeLabel const &at(InputMultiDiEdge const &) const = 0; - virtual EdgeLabel const &at(MultiDiOutput const &) const = 0; - - using INodeLabelledOpenMultiDiGraphView::at; -}; - template struct OutputLabelledOpenMultiDiGraphView - : virtual NodeLabelledOpenMultiDiGraphView { + : virtual NodeLabelledOpenMultiDiGraphView, + virtual OutputLabelledMultiDiGraphView { private: using Interface = IOutputLabelledOpenMultiDiGraphView; @@ -28,24 +20,34 @@ struct OutputLabelledOpenMultiDiGraphView operator=(OutputLabelledOpenMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr().at(n); + return this->get_ptr().at(n); } EdgeLabel const &at(InputMultiDiEdge const &i) const { - return get_ptr().at(i); + return this->get_ptr().at(i); } EdgeLabel const &at(MultiDiOutput const &o) const { - return get_ptr().at(o); + return this->get_ptr().at(o); + } + + template + EdgeLabel const &at(std::variant const &e) const { + return visit([&](auto const &e) -> auto const & { return this->at(e); }, e); + } + + template + EdgeLabel &at(std::variant const &e) { + return visit([&](auto const &e) -> auto & { return this->at(e); }, e); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } template @@ -62,8 +64,7 @@ struct OutputLabelledOpenMultiDiGraphView private: Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; @@ -71,10 +72,7 @@ template struct OutputLabelledOpenMultiDiGraph : virtual OutputLabelledOpenMultiDiGraphView { private: - using Interface = IOpenMultiDiGraph; - using INodeLabel = ILabelling; - using IInputLabel = ILabelling; - using IOutputLabel = ILabelling; + using Interface = IOutputLabelledOpenMultiDiGraph; public: OutputLabelledOpenMultiDiGraph() = delete; @@ -84,48 +82,35 @@ struct OutputLabelledOpenMultiDiGraph operator=(OutputLabelledOpenMultiDiGraph const &) = default; Node add_node(NodeLabel const &l) { - Node n = get_ptr().add_node(); - nl.get_mutable()->add_label(n, l); - return n; + return this->get_ptr().add_node(l); } NodePort add_node_port() { - return get_ptr().add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl.get_mutable()->get_label(n); - } - - NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } void add_label(MultiDiOutput const &o, EdgeLabel const &l) { - ol.get_mutable()->add_label(o, l); + this->get_ptr().add_label(o, l); }; void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) { - il.get_mutable()->add_label(e, l); + this->get_ptr().add_label(e, l); } void add_edge(OpenMultiDiEdge const &e) { - return get_ptr().add_edge(e); + return this->get_ptr().add_edge(e); } EdgeLabel &at(MultiDiOutput const &o) { - return ol.get_mutable()->get_label(o); - } - EdgeLabel const &at(MultiDiOutput const &o) const { - return ol->get_label(o); + return this->get_ptr().at(o); } EdgeLabel &at(InputMultiDiEdge const &e) { - return il.get_mutable()->get_label(e); - } - - EdgeLabel const &at(InputMultiDiEdge const &e) const { - return il->get_label(e); + return this->get_ptr().at(e); } template @@ -139,49 +124,41 @@ struct OutputLabelledOpenMultiDiGraph } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr().query_nodes(q); + return this->get_ptr().query_nodes(q); } std::unordered_set query_edges(OpenMultiDiEdgeQuery const &q) const { - return get_ptr().query_edges(q); + return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of, - std::is_base_of>::value, - OutputLabelledOpenMultiDiGraph>::type + template + static typename std::enable_if::value, + OutputLabelledOpenMultiDiGraph>::type create() { - return OutputLabelledOpenMultiDiGraph(make_cow_ptr(), - make_cow_ptr(), - make_cow_ptr(), - make_cow_ptr()); + return OutputLabelledOpenMultiDiGraph(make_cow_ptr()); } + using OutputLabelledOpenMultiDiGraphView::at; + private: - OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t il, - cow_ptr_t ol) - : GraphView(ptr), nl(nl), il(il), ol(ol) {} + OutputLabelledOpenMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } - - cow_ptr_t nl; - cow_ptr_t il; - cow_ptr_t ol; }; +template +void add_label(OutputLabelledOpenMultiDiGraph &g, + OpenMultiDiEdge const &e, + EdgeLabel const &l) { + visit([&](auto const &e) { g.add_label(e, l); }, e); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h b/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h new file mode 100644 index 0000000000..501805fe2a --- /dev/null +++ b/lib/utils/include/utils/graph/labelled/output_labelled_open_interfaces.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN_INTERFACES +#define _FLEXFLOW_UTILS_GRAPH_LABELLED_OUTPUT_LABELLED_OPEN_INTERFACES + +#include "node_labelled_open.h" + +namespace FlexFlow { + +template +struct IOutputLabelledOpenMultiDiGraphView + : virtual INodeLabelledOpenMultiDiGraphView { + virtual EdgeLabel const &at(InputMultiDiEdge const &) const = 0; + virtual EdgeLabel const &at(MultiDiOutput const &) const = 0; + + using INodeLabelledOpenMultiDiGraphView::at; +}; + +template +struct IOutputLabelledOpenMultiDiGraph + : virtual public IOutputLabelledOpenMultiDiGraphView { + virtual EdgeLabel &at(InputMultiDiEdge const &) = 0; + virtual EdgeLabel &at(MultiDiOutput const &) = 0; + virtual Node add_node(NodeLabel const &) = 0; + virtual NodePort add_node_port() = 0; + virtual NodeLabel &at(Node const &) = 0; + virtual void add_label(MultiDiOutput const &o, EdgeLabel const &l) = 0; + virtual void add_label(InputMultiDiEdge const &e, EdgeLabel const &l) = 0; + virtual void add_edge(OpenMultiDiEdge const &e) = 0; + + using IOutputLabelledOpenMultiDiGraphView::at; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/labelled/standard_labelled.h b/lib/utils/include/utils/graph/labelled/standard_labelled.h index 941a0470c2..34dabb5391 100644 --- a/lib/utils/include/utils/graph/labelled/standard_labelled.h +++ b/lib/utils/include/utils/graph/labelled/standard_labelled.h @@ -2,23 +2,10 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_STANDARD_LABELLED_H #include "node_labelled.h" +#include "standard_labelled_interfaces.h" namespace FlexFlow { -template -struct ILabelledMultiDiGraphView - : public INodeLabelledMultiDiGraphView { - ILabelledMultiDiGraphView() = default; - ILabelledMultiDiGraphView(ILabelledMultiDiGraphView const &) = delete; - ILabelledMultiDiGraphView & - operator=(ILabelledMultiDiGraphView const &) = delete; - - virtual ~ILabelledMultiDiGraphView() = default; - - virtual EdgeLabel const &at(MultiDiEdge const &) const = 0; -}; -CHECK_RC_COPY_VIRTUAL_COMPLIANT(ILabelledMultiDiGraphView); - template struct LabelledMultiDiGraphView : virtual public NodeLabelledMultiDiGraphView { @@ -32,19 +19,19 @@ struct LabelledMultiDiGraphView operator=(LabelledMultiDiGraphView const &) = default; NodeLabel const &at(Node const &n) const { - return get_ptr()->at(n); + return get_ptr().at(n); } EdgeLabel const &at(MultiDiEdge const &e) const { - return get_ptr()->at(e); + return get_ptr().at(e); } std::unordered_set query_nodes(NodeQuery const &q) const { - return get_ptr()->query_nodes(q); + return get_ptr().query_nodes(q); } std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return get_ptr()->query_edges(q); + return get_ptr().query_edges(q); } template @@ -60,8 +47,7 @@ struct LabelledMultiDiGraphView : NodeLabelledMultiDiGraphView(ptr) {} Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraphView); @@ -70,79 +56,59 @@ template struct LabelledMultiDiGraph : virtual LabelledMultiDiGraphView { private: - using Interface = IMultiDiGraph; - using INodeLabel = ILabelling; - using IEdgeLabel = ILabelling; + using Interface = ILabelledMultiDiGraph; public: - // LabelledMultiDiGraph() = delete; LabelledMultiDiGraph(LabelledMultiDiGraph const &other) = default; LabelledMultiDiGraph &operator=(LabelledMultiDiGraph const &other) = default; Node add_node(NodeLabel const &l) { - Node n = MultiDiGraph::add_node(); - nl->add_label(n, l); - return n; + return this->get_ptr().add_node(); } NodePort add_node_port() { - return this->get_ptr()->add_node_port(); + return this->get_ptr().add_node_port(); } NodeLabel &at(Node const &n) { - return nl->get_label(n); - } - - NodeLabel const &at(Node const &n) const { - return nl->get_label(n); + return this->get_ptr().at(n); } void add_edge(MultiDiEdge const &e, EdgeLabel const &l) { - return this->get_ptr()->add_edge(e, l); + return this->get_ptr().add_edge(e, l); } + EdgeLabel &at(MultiDiEdge const &e) { - return el->get_label(e); - } - EdgeLabel const &at(MultiDiEdge const &e) const { - return el->get_label(e); + return this->get_ptr().at(e); } std::unordered_set query_nodes(NodeQuery const &q) const { - return this->get_ptr()->query_nodes(q); + return this->get_ptr().query_nodes(q); } + std::unordered_set query_edges(MultiDiEdgeQuery const &q) const { - return this->get_ptr()->query_edges(q); + return this->get_ptr().query_edges(q); } - template - static typename std::enable_if< - std::conjunction, - std::is_base_of, - std::is_base_of>::value, - LabelledMultiDiGraph>::type + using LabelledMultiDiGraphView::at; + + template + static typename std::enable_if::value, + LabelledMultiDiGraph>::type create() { - return LabelledMultiDiGraph( - make_cow_ptr(), make_cow_ptr(), make_cow_ptr()); + return LabelledMultiDiGraph(make_cow_ptr()); } private: - LabelledMultiDiGraph(cow_ptr_t ptr, - cow_ptr_t nl, - cow_ptr_t el) - : LabelledMultiDiGraphView(ptr), nl(nl), el(el) {} + LabelledMultiDiGraph(cow_ptr_t ptr) : GraphView(ptr) {} Interface &get_ptr() { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } Interface const &get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } - - cow_ptr_t nl; - cow_ptr_t el; }; CHECK_WELL_BEHAVED_VALUE_TYPE_NO_EQ(LabelledMultiDiGraph); diff --git a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h index f7af522b3c..fe396e5989 100644 --- a/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled/unordered_labelled_graphs.h @@ -1,138 +1,227 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_UNORDERED_LABELLED_GRAPHS_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_UNORDERED_LABELLED_GRAPHS_H -#include "labelled_open_interfaces.h" -#include "node_labelled_interfaces.h" -#include "output_labelled_interfaces.h" -#include "standard_labelled_interfaces.h" -#include "utils/graph/open_graphs.h" +#include "output_labelled_open_interfaces.h" +#include "unordered_label.h" +#include "utils/graph/adjacency_openmultidigraph.h" namespace FlexFlow { template -struct UnorderedNodeLabelledMultiDiGraph - : public INodeLabelledMultiDiGraph, - protected MultiDiGraph { -public: - UnorderedNodeLabelledMultiDiGraph() = delete; +struct UnorderedNodeLabelledOpenMultiDiGraph + : public INodeLabelledOpenMultiDiGraph { - Node add_node(NodeLabel const &label) override { - Node n = MultiDiGraph::add_node(); - node_map.insert({n, label}); - return n; + UnorderedNodeLabelledOpenMultiDiGraph() + : g(OpenMultiDiGraph::create()) {} + + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; } - NodeLabel &at(Node const &n) override { - return this->node_map.at(n); + NodePort add_node_port() override { + return this->g.add_node_port(); } NodeLabel const &at(Node const &n) const override { - return this->node_map.at(n); + return this->node_labelling.get_label(n); } - using MultiDiGraph::query_edges; - using MultiDiGraph::query_nodes; + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } -private: - std::unordered_map node_map; -}; + void add_edge(OpenMultiDiEdge const &e) override { + this->g.add_edge(e); + } -template -struct UnorderedLabelledMultiDiGraph - : public ILabelledMultiDiGraph, - public UnorderedNodeLabelledMultiDiGraph { - void add_edge(MultiDiEdge const &e, EdgeLabel const &label) override { - MultiDiGraph::add_edge(e); - edge_map.insert({e, label}); + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); } - EdgeLabel &at(MultiDiEdge const &n) override { - return this->edge_map.at(n); + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return g.query_edges(q); } - EdgeLabel const &at(MultiDiEdge const &n) const override { - return this->edge_map.at(n); + using INodeLabelledOpenMultiDiGraph::query_edges; + + UnorderedNodeLabelledOpenMultiDiGraph *clone() const override { + return new UnorderedNodeLabelledOpenMultiDiGraph(g, + node_labelling); } private: - std::unordered_map edge_map; -}; + UnorderedNodeLabelledOpenMultiDiGraph( + OpenMultiDiGraph const &g, + UnorderedLabelling const &node_labelling) + : g(g), node_labelling(node_labelling) {} -MultiDiOutput get_output(MultiDiEdge const &e); + OpenMultiDiGraph g; + UnorderedLabelling node_labelling; +}; +CHECK_NOT_ABSTRACT(UnorderedNodeLabelledOpenMultiDiGraph); template struct UnorderedOutputLabelledMultiDiGraph - : public IOutputLabelledMultiDiGraph, - public UnorderedNodeLabelledMultiDiGraph { -public: + : public IOutputLabelledMultiDiGraph { + + UnorderedOutputLabelledMultiDiGraph() + : g(MultiDiGraph::create()) {} + + OutputLabel const &at(MultiDiOutput const &i) const override { + return this->output_labelling.get_label(i); + } + + OutputLabel &at(MultiDiOutput const &i) override { + return this->output_labelling.get_label(i); + } + + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; + } + + NodePort add_node_port() override { + return this->g.add_node_port(); + } + + NodeLabel const &at(Node const &n) const override { + return this->node_labelling.get_label(n); + } + + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } + + void add_edge(MultiDiEdge const &e) override { + this->g.add_edge(e); + } + void add_output(MultiDiOutput const &output, OutputLabel const &label) override { - this->output_map.insert({output, label}); + this->output_labelling.add_label(output, label); } - void add_edge(MultiDiEdge const &e) override { - MultiDiOutput output = get_output(e); - if (!contains_key(this->output_map, output)) { - throw mk_runtime_error("Could not find output {}", output); - } - this->add_edge(e); + std::unordered_set query_nodes(NodeQuery const &q) const override { + return g.query_nodes(q); } - void add_edge(MultiDiOutput const &output, - MultiDiInput const &input) override { - this->add_edge(MultiDiEdge{output.node, input.node, output.idx, input.idx}); + std::unordered_set + query_edges(MultiDiEdgeQuery const &q) const override { + return g.query_edges(q); + } + + using IOutputLabelledMultiDiGraph::query_edges; + + UnorderedOutputLabelledMultiDiGraph *clone() const override { + return new UnorderedOutputLabelledMultiDiGraph( + g, node_labelling, output_labelling); } private: - std::unordered_map output_map; + UnorderedOutputLabelledMultiDiGraph( + MultiDiGraph const &g, + UnorderedLabelling const &node_labelling, + UnorderedLabelling const &output_labelling) + : g(g), node_labelling(node_labelling), + output_labelling(output_labelling) {} + + MultiDiGraph g; + UnorderedLabelling node_labelling; + UnorderedLabelling output_labelling; }; +CHECK_NOT_ABSTRACT(UnorderedOutputLabelledMultiDiGraph); -template -struct UnorderedLabelledOpenMultiDiGraph - : public ILabelledOpenMultiDiGraph, - public UnorderedLabelledMultiDiGraph { -public: - void add_edge(InputMultiDiEdge const &e, InputLabel const &label) { - this->add_edge(e); - this->input_map.insert({e, label}); +template +struct UnorderedOutputLabelledOpenMultiDiGraph + : public IOutputLabelledOpenMultiDiGraph { + + UnorderedOutputLabelledOpenMultiDiGraph() + : g(OpenMultiDiGraph::create()) {} + + EdgeLabel const &at(InputMultiDiEdge const &i) const override { + return this->input_labelling.get_label(i); } - void add_edge(OutputMultiDiEdge const &e, OutputLabel const &label) { - this->add_edge(e); - this->output_map.insert({e, label}); + EdgeLabel &at(InputMultiDiEdge const &i) override { + return this->input_labelling.get_label(i); } - InputLabel const &at(InputMultiDiEdge const &e) const { - return this->input_map.at(e); + EdgeLabel const &at(MultiDiOutput const &i) const override { + return this->output_labelling.get_label(i); } - InputLabel &at(InputMultiDiEdge const &e) { - return this->input_map.at(e); + EdgeLabel &at(MultiDiOutput const &i) override { + return this->output_labelling.get_label(i); } - OutputLabel const &at(OutputMultiDiEdge const &e) const { - return this->output_map.at(e); + Node add_node(NodeLabel const &l) override { + Node node = g.add_node(); + this->node_labelling.add_label(node, l); + return node; } - OutputLabel &at(DownwardOpenMultiDiEdge const &e) { - return this->output_map.at(e); + NodePort add_node_port() override { + return this->g.add_node_port(); } - UnorderedLabelledOpenMultiDiGraph() { - NOT_IMPLEMENTED(); + NodeLabel const &at(Node const &n) const override { + return this->node_labelling.get_label(n); + } + + NodeLabel &at(Node const &n) override { + return this->node_labelling.get_label(n); + } + + void add_label(MultiDiOutput const &o, EdgeLabel const &l) override { + this->output_labelling.add_label(o, l); + } + + void add_label(InputMultiDiEdge const &i, EdgeLabel const &l) override { + this->input_labelling.add_label(i, l); + } + + void add_edge(OpenMultiDiEdge const &e) override { + this->g.add_edge(e); + } + + std::unordered_set query_nodes(NodeQuery const &q) const override { + return this->g.query_nodes(q); + } + + std::unordered_set + query_edges(OpenMultiDiEdgeQuery const &q) const override { + return this->g.query_edges(q); + } + + using IOutputLabelledOpenMultiDiGraph::query_edges; + + UnorderedOutputLabelledOpenMultiDiGraph *clone() const override { + return new UnorderedOutputLabelledOpenMultiDiGraph( + g, node_labelling, input_labelling, output_labelling); } private: - OpenMultiDiGraph base_graph; - std::unordered_map input_map; - std::unordered_map output_map; + UnorderedOutputLabelledOpenMultiDiGraph( + OpenMultiDiGraph const &g, + UnorderedLabelling const &node_labelling, + UnorderedLabelling const &input_labelling, + UnorderedLabelling const &output_labelling) + : g(g), node_labelling(node_labelling), input_labelling(input_labelling), + output_labelling(output_labelling) {} + + OpenMultiDiGraph g; + UnorderedLabelling node_labelling; + UnorderedLabelling input_labelling; + UnorderedLabelling output_labelling; }; +CHECK_NOT_ABSTRACT( + UnorderedOutputLabelledOpenMultiDiGraph); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled/views.h b/lib/utils/include/utils/graph/labelled/views.h index 8455cd2bcb..e31afad916 100644 --- a/lib/utils/include/utils/graph/labelled/views.h +++ b/lib/utils/include/utils/graph/labelled/views.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_VIEWS_H #include "node_labelled.h" +#include "output_labelled_open.h" #include "standard_labelled.h" namespace FlexFlow { @@ -54,6 +55,10 @@ struct ViewMultiDiGraphAsOutputLabelled return output_label(o); } + OutputLabel const &at(MultiDiOutput const &o) const override { + return output_label(o); + } + ViewMultiDiGraphAsOutputLabelled *clone() const { return new ViewMultiDiGraphAsOutputLabelled(g, node_label, output_label); } @@ -69,7 +74,7 @@ CHECK_NOT_ABSTRACT(ViewMultiDiGraphAsOutputLabelled Impl materialize_output_labelled_multidigraph_view( - IOutputLabelledMultiDiGraphView const &g) { + OutputLabelledMultiDiGraphView const &g) { Impl result; for (Node const &n : get_nodes(g)) { result.add_node_unsafe(n); @@ -84,6 +89,41 @@ Impl materialize_output_labelled_multidigraph_view( return result; } +template +OutputLabelledOpenMultiDiGraph + materialize_output_labelled_multidigraph_view( + OutputLabelledOpenMultiDiGraphView const &g) { + OutputLabelledOpenMultiDiGraph result = + OutputLabelledOpenMultiDiGraph::template create< + Impl, + NodeLabelImpl, + InputLabelImpl, + OutputLabelImpl>(); + for (Node const &n : get_nodes(g)) { + result.add_node_unsafe(n, g.at(n)); + } + for (OpenMultiDiEdge const &e : get_edges(g)) { + result.add_edge(e); + if (is_input_edge(e)) { + InputMultiDiEdge input_edge = get(e); + result.add_label(input_edge, g.at(input_edge)); + } else { + MultiDiOutput output = + is_standard_edge(e) + ? static_cast(get(e)) + : static_cast(get(e)); + auto tensor = g.at(output); + result.add_label(output, tensor); + } + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/graph/labelled_graphs.h b/lib/utils/include/utils/graph/labelled_graphs.h index 5c4b29038a..9cf5f0d97e 100644 --- a/lib/utils/include/utils/graph/labelled_graphs.h +++ b/lib/utils/include/utils/graph/labelled_graphs.h @@ -10,6 +10,7 @@ #include "labelled/output_labelled_open.h" #include "labelled/standard_labelled.h" #include "labelled/unordered_label.h" +#include "labelled/unordered_labelled_graphs.h" #include "labelled/views.h" #endif diff --git a/lib/utils/include/utils/graph/open_graphs.h b/lib/utils/include/utils/graph/open_graphs.h index ce4ec8b1cc..0b0db44f93 100644 --- a/lib/utils/include/utils/graph/open_graphs.h +++ b/lib/utils/include/utils/graph/open_graphs.h @@ -50,6 +50,7 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { Node add_node(); void add_node_unsafe(Node const &); void remove_node_unsafe(Node const &); + NodePort add_node_port(); void add_edge(Edge const &); void remove_edge(Edge const &); @@ -60,7 +61,7 @@ struct OpenMultiDiGraph : virtual OpenMultiDiGraphView { static typename std::enable_if::value, OpenMultiDiGraph>::type create() { - return make_cow_ptr(); + return OpenMultiDiGraph(make_cow_ptr()); } private: diff --git a/lib/utils/include/utils/graph/views.h b/lib/utils/include/utils/graph/views.h index e891a948f0..a0ef837796 100644 --- a/lib/utils/include/utils/graph/views.h +++ b/lib/utils/include/utils/graph/views.h @@ -255,8 +255,8 @@ struct OpenMultiDiSubgraphView : public IOpenMultiDiGraphView { OpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set inputs; std::unordered_set outputs; }; @@ -273,8 +273,8 @@ struct UpwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { UpwardOpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set inputs; }; @@ -290,8 +290,8 @@ struct DownwardOpenMultiDiSubgraphView : public IOpenMultiDiGraphView { DownwardOpenMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; std::unordered_set outputs; }; @@ -307,8 +307,8 @@ struct ClosedMultiDiSubgraphView : public IOpenMultiDiGraphView { ClosedMultiDiSubgraphView *clone() const override; private: - OpenMultiDiGraphView const &g; - std::unordered_set const &nodes; + OpenMultiDiGraphView g; + std::unordered_set nodes; }; UndirectedEdge to_undirected_edge(DirectedEdge const &); diff --git a/lib/utils/include/utils/hash-utils.h b/lib/utils/include/utils/hash-utils.h index 923c8df840..d56ff34644 100644 --- a/lib/utils/include/utils/hash-utils.h +++ b/lib/utils/include/utils/hash-utils.h @@ -4,6 +4,8 @@ #include "containers.h" #include "hash-utils-core.h" +using namespace FlexFlow; + namespace std { template struct hash> { @@ -18,7 +20,7 @@ struct hash> { template struct hash> { size_t operator()(std::unordered_map const &m) const { - return get_std_hash(items(m)); + return get_std_hash(::FlexFlow::items(m)); } }; diff --git a/lib/utils/include/utils/variant.h b/lib/utils/include/utils/variant.h index b3ae3de115..272caaffde 100644 --- a/lib/utils/include/utils/variant.h +++ b/lib/utils/include/utils/variant.h @@ -42,7 +42,7 @@ struct elements_satisfy> : elements_satisfy_impl {}; template -struct is_in_variant; +struct is_in_variant : std::false_type {}; template struct is_in_variant> : std::true_type {}; template @@ -169,7 +169,7 @@ auto widen(Container const &c) -> decltype(transform( template < typename VariantOut, typename VariantIn, - typename = std::enable_if::value>> + typename = std::enable_if_t::value>> std::optional narrow(VariantIn const &v) { return visit(VariantNarrowFunctor{}, v); } @@ -178,7 +178,7 @@ template < typename VariantOut, typename Container, typename VariantIn = typename Container::value_type, - typename = std::enable_if::value>> + typename = std::enable_if_t::value>> auto narrow(Container const &c) -> decltype(transform( c, std::declval< @@ -186,12 +186,20 @@ auto narrow(Container const &c) -> decltype(transform( return transform(c, [](VariantIn const &i) { return narrow(i); }); } +template ::value>> +auto narrow(Container const &c) { + return transform(c, [](VariantIn const &e) { return get(e); }); +} + template < typename T1, typename T2, typename... Trest, typename VariantIn, - typename = std::enable_if< + typename = std::enable_if_t< !is_subeq_variant, VariantIn>::value>> std::optional> narrow(VariantIn const &v) { return visit(VariantNarrowFunctor>{}, v); diff --git a/lib/utils/src/graph/algorithms.cc b/lib/utils/src/graph/algorithms.cc index 449a9a8203..2223b120a7 100644 --- a/lib/utils/src/graph/algorithms.cc +++ b/lib/utils/src/graph/algorithms.cc @@ -9,6 +9,7 @@ #include "utils/graph/traversal.h" #include "utils/graph/undirected.h" #include "utils/graph/views.h" +#include "utils/variant.h" #include #include #include @@ -42,6 +43,10 @@ std::vector add_nodes(MultiDiGraph &g, int num_nodes) { return add_nodes_impl(g, num_nodes); } +std::vector add_nodes(OpenMultiDiGraph &g, int num_nodes) { + return add_nodes_impl(g, num_nodes); +} + std::vector add_node_ports(MultiDiGraph &g, int num_node_ports) { std::vector node_ports; for (int i = 0; i < num_node_ports; i++) { @@ -164,7 +169,9 @@ DiGraphView apply_contraction(DiGraphView const &g, for (auto const &kv : nodes) { Node from = kv.first; Node into = kv.second; - contractedView = contract_node(contractedView, from, into); + if (from != into) { + contractedView = contract_node(contractedView, from, into); + } } return contractedView; } @@ -250,7 +257,7 @@ std::unordered_set get_node_edges(UndirectedGraphView const &g, std::unordered_set get_outputs(MultiDiGraphView const &g) { return transform(get_edges(g), [&](MultiDiEdge const &e) -> MultiDiOutput { - return MultiDiOutput(e); + return static_cast(e); }); } @@ -327,24 +334,29 @@ std::unordered_map> std::unordered_set get_outgoing_edges(OpenMultiDiGraphView const &g, Node const &n) { - return transform(g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), - MultiDiEdgeQuery::all().with_src_nodes({n}), - OutputMultiDiEdgeQuery::all().with_src_nodes({n}))), - [](OpenMultiDiEdge const &e) { - return narrow(e).value(); - }); + return value_all( + narrow(g.query_edges(OpenMultiDiEdgeQuery( + InputMultiDiEdgeQuery::none(), + MultiDiEdgeQuery::all().with_src_nodes({n}), + OutputMultiDiEdgeQuery::all().with_src_nodes({n}))))); } std::unordered_set get_incoming_edges(OpenMultiDiGraphView const &g, Node const &n) { - return transform(g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::all().with_dst_nodes({n}), - MultiDiEdgeQuery::all().with_dst_nodes({n}), - OutputMultiDiEdgeQuery::none())), - [](OpenMultiDiEdge const &e) { - return narrow(e).value(); - }); + return value_all(narrow(g.query_edges( + OpenMultiDiEdgeQuery(InputMultiDiEdgeQuery::all().with_dst_nodes({n}), + MultiDiEdgeQuery::all().with_dst_nodes({n}), + OutputMultiDiEdgeQuery::none())))); +} + +std::unordered_set + get_open_outputs(OpenMultiDiGraphView const &g) { + return narrow( + g.query_edges(OutputMultiDiEdgeQuery::all())); +} +std::unordered_set + get_open_inputs(OpenMultiDiGraphView const &g) { + return narrow(g.query_edges(InputMultiDiEdgeQuery::all())); } std::unordered_map> @@ -758,4 +770,28 @@ std::unordered_set> return components; } +std::unordered_set get_closed_sources(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return get_incoming_edges(g, n).size() == 0; + }); +} + +std::unordered_set get_closed_sinks(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return get_outgoing_edges(g, n).size() == 0; + }); +} + +std::unordered_set get_open_sources(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return !get_incoming_edges(g, n).empty(); + }); +} + +std::unordered_set get_open_sinks(OpenMultiDiGraphView const &g) { + return filter(get_nodes(g), [&](Node const &n) { + return !get_outgoing_edges(g, n).empty(); + }); +} + } // namespace FlexFlow diff --git a/lib/utils/src/graph/digraph.cc b/lib/utils/src/graph/digraph.cc index dda9eef5e0..bdfe5ff599 100644 --- a/lib/utils/src/graph/digraph.cc +++ b/lib/utils/src/graph/digraph.cc @@ -14,8 +14,7 @@ std::unordered_set } IDiGraphView const &DiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } Node DiGraph::add_node() { @@ -48,11 +47,11 @@ std::unordered_set } IDiGraph &DiGraph::get_ptr() { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } IDiGraph const &DiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/multidigraph.cc b/lib/utils/src/graph/multidigraph.cc index 99a7ea86fa..771e01e573 100644 --- a/lib/utils/src/graph/multidigraph.cc +++ b/lib/utils/src/graph/multidigraph.cc @@ -24,7 +24,7 @@ std::unordered_set } IMultiDiGraphView const &MultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -66,12 +66,11 @@ std::unordered_set MultiDiGraph::query_nodes(NodeQuery const &q) const { } IMultiDiGraph const &MultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( - GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IMultiDiGraph &MultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } diff --git a/lib/utils/src/graph/node.cc b/lib/utils/src/graph/node.cc index 9854afffbf..72caa3136e 100644 --- a/lib/utils/src/graph/node.cc +++ b/lib/utils/src/graph/node.cc @@ -53,11 +53,11 @@ std::unordered_set Graph::query_nodes(NodeQuery const &q) const { } IGraph const &Graph::get_ptr() const { - return *std::reinterpret_pointer_cast(GraphView::ptr.get()); + return *std::dynamic_pointer_cast(GraphView::ptr.get()); } IGraph &Graph::get_ptr() { - return *std::reinterpret_pointer_cast(GraphView::ptr.get_mutable()); + return *std::dynamic_pointer_cast(GraphView::ptr.get_mutable()); } } // namespace FlexFlow diff --git a/lib/utils/src/graph/open_edge.cc b/lib/utils/src/graph/open_edge.cc index b12f87dd1c..1b571d5c6c 100644 --- a/lib/utils/src/graph/open_edge.cc +++ b/lib/utils/src/graph/open_edge.cc @@ -3,15 +3,15 @@ namespace FlexFlow { bool is_input_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } bool is_output_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } bool is_standard_edge(OpenMultiDiEdge const &e) { - return holds_alternative(e); + return std::holds_alternative(e); } OpenMultiDiEdgeQuery::OpenMultiDiEdgeQuery( diff --git a/lib/utils/src/graph/open_graphs.cc b/lib/utils/src/graph/open_graphs.cc index 8b74729d77..387dd7e75b 100644 --- a/lib/utils/src/graph/open_graphs.cc +++ b/lib/utils/src/graph/open_graphs.cc @@ -22,7 +22,7 @@ std::unordered_set } IOpenMultiDiGraphView const &OpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -51,13 +51,17 @@ std::unordered_set return this->get_ptr().query_edges(q); } +NodePort OpenMultiDiGraph::add_node_port() { + return this->get_ptr().add_node_port(); +} + IOpenMultiDiGraph &OpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } IOpenMultiDiGraph const &OpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -73,7 +77,7 @@ std::unordered_set } IUpwardOpenMultiDiGraphView const &UpwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -103,12 +107,12 @@ std::unordered_set UpwardOpenMultiDiGraph::query_edges( } IUpwardOpenMultiDiGraph const &UpwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } IUpwardOpenMultiDiGraph &UpwardOpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -125,7 +129,7 @@ std::unordered_set IDownwardOpenMultiDiGraphView const & DownwardOpenMultiDiGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } @@ -161,12 +165,12 @@ std::unordered_set } IDownwardOpenMultiDiGraph &DownwardOpenMultiDiGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } IDownwardOpenMultiDiGraph const &DownwardOpenMultiDiGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/serialparallel.cc b/lib/utils/src/graph/serialparallel.cc index 33fc99b079..f1c9e41005 100644 --- a/lib/utils/src/graph/serialparallel.cc +++ b/lib/utils/src/graph/serialparallel.cc @@ -19,7 +19,7 @@ Node find_sink_node(DiGraphView const &g) { std::optional find_bottleneck_node(DiGraphView const &g) { std::unordered_set sources = get_sources(g); - std::unordered_set sinks = get_sources(g); + std::unordered_set sinks = get_sinks(g); std::optional maybe_bottleneck = get_imm_post_dominator(g, sources); if (maybe_bottleneck.has_value()) { @@ -72,7 +72,7 @@ std::unordered_set if (include_src == SourceSettings::INCLUDE_SOURCE_NODES) { result = set_union(result, srcs); } - if (include_sink == SinkSettings::EXCLUDE_SINK_NODES) { + if (include_sink == SinkSettings::INCLUDE_SINK_NODES) { result = set_union(result, sinks); } return result; @@ -103,12 +103,12 @@ SplitAST sp_decomposition(DiGraphView const &g) { sources, {bottleneck.value()}, SourceSettings::INCLUDE_SOURCE_NODES, - SinkSettings::INCLUDE_SINK_NODES)), + SinkSettings::EXCLUDE_SINK_NODES)), sp_decomposition(source_to_sink_subgraph( g, {bottleneck.value()}, sinks, - SourceSettings::EXCLUDE_SOURCE_NODES, + SourceSettings::INCLUDE_SOURCE_NODES, SinkSettings::INCLUDE_SINK_NODES))); } else { return parallel_decomposition(g); @@ -142,7 +142,7 @@ SplitASTNode::SplitASTNode(SplitType type, struct FlattenAST { void add_flattened_child_to_parent(SplitASTNode &parent, SplitAST const &child) { - if (holds_alternative(child)) { + if (std::holds_alternative(child)) { parent.children.push_back(child); return; } @@ -178,11 +178,11 @@ struct ToFinalAST { std::variant operator()(SplitASTNode const &node) { if (node.type == SplitType::SERIAL) { return Serial{transform(node.children, [](SplitAST const &s) { - return narrow(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } else { return Parallel{transform(node.children, [](SplitAST const &s) { - return narrow(to_final_ast(s)).value(); + return narrow>(to_final_ast(s)).value(); })}; } } @@ -195,6 +195,13 @@ struct ToFinalAST { std::variant to_final_ast(SplitAST const &ast) { return visit(ToFinalAST{}, ast); } + +SerialParallelDecomposition + get_serial_parallel_decomposition(DiGraphView const &g) { + SplitAST ast = sp_decomposition(g); + return to_final_ast(ast); +} + struct GetNodes { template std::unordered_set operator()(T const &t) { diff --git a/lib/utils/src/graph/undirected.cc b/lib/utils/src/graph/undirected.cc index ce42cfe22c..b1e8be7f14 100644 --- a/lib/utils/src/graph/undirected.cc +++ b/lib/utils/src/graph/undirected.cc @@ -26,12 +26,12 @@ void UndirectedGraph::remove_edge(UndirectedEdge const &e) { } IUndirectedGraph const &UndirectedGraph::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } IUndirectedGraph &UndirectedGraph::get_ptr() { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get_mutable()); } @@ -56,7 +56,7 @@ std::unordered_set } IUndirectedGraphView const &UndirectedGraphView::get_ptr() const { - return *std::reinterpret_pointer_cast( + return *std::dynamic_pointer_cast( GraphView::ptr.get()); } diff --git a/lib/utils/src/graph/views.cc b/lib/utils/src/graph/views.cc index 062dca6858..af15b0d6aa 100644 --- a/lib/utils/src/graph/views.cc +++ b/lib/utils/src/graph/views.cc @@ -445,9 +445,10 @@ std::unordered_set OpenMultiDiSubgraphView::OpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), - inputs(transform(get_cut_set(g, nodes), to_inputmultidiedge)), - outputs(transform(get_cut_set(g, nodes), to_outputmultidiedge)) {} + : g(g), nodes(nodes) { + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); +} std::unordered_set OpenMultiDiSubgraphView::query_edges(OpenMultiDiEdgeQuery const &q) const { @@ -469,7 +470,9 @@ std::unordered_set UpwardOpenMultiDiSubgraphView::UpwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes), inputs(inputs) {} + : g(g), nodes(nodes) { + this->inputs = transform(get_cut_set(g, nodes), to_inputmultidiedge); +} UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { return new UpwardOpenMultiDiSubgraphView(g, nodes); @@ -477,11 +480,11 @@ UpwardOpenMultiDiSubgraphView *UpwardOpenMultiDiSubgraphView::clone() const { std::unordered_set UpwardOpenMultiDiSubgraphView::query_edges( OpenMultiDiEdgeQuery const &q) const { - std::unordered_set result = - g.query_edges(OpenMultiDiEdgeQuery( - q.input_edge_query.with_dst_nodes(nodes), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - OutputMultiDiEdgeQuery::none())); + OpenMultiDiEdgeQuery subgraph_query( + q.input_edge_query.with_dst_nodes(nodes), + q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), + OutputMultiDiEdgeQuery::none()); + std::unordered_set result = g.query_edges(subgraph_query); extend(result, query_edge(inputs, q.input_edge_query.with_dst_nodes(nodes))); return result; } @@ -493,16 +496,18 @@ std::unordered_set DownwardOpenMultiDiSubgraphView::DownwardOpenMultiDiSubgraphView( OpenMultiDiGraphView const &g, std::unordered_set const &nodes) - : g(g), nodes(nodes) {} + : g(g), nodes(nodes) { + this->outputs = transform(get_cut_set(g, nodes), to_outputmultidiedge); +} std::unordered_set DownwardOpenMultiDiSubgraphView::query_edges( OpenMultiDiEdgeQuery const &q) const { - std::unordered_set result = - g.query_edges(OpenMultiDiEdgeQuery( - InputMultiDiEdgeQuery::none(), - q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), - q.output_edge_query.with_src_nodes(nodes))); + OpenMultiDiEdgeQuery subgraph_query( + InputMultiDiEdgeQuery::none(), + q.standard_edge_query.with_src_nodes(nodes).with_dst_nodes(nodes), + q.output_edge_query.with_src_nodes(nodes)); + std::unordered_set result = g.query_edges(subgraph_query); extend(result, query_edge(outputs, q.output_edge_query.with_src_nodes(nodes))); return result; diff --git a/lib/utils/test/CMakeLists.txt b/lib/utils/test/CMakeLists.txt index be4b33129b..40ff07285e 100644 --- a/lib/utils/test/CMakeLists.txt +++ b/lib/utils/test/CMakeLists.txt @@ -1,14 +1,14 @@ -# ff_add_test_executable( -# NAME -# utils-test -# SRC_PATTERNS -# src/*.cc -# PRIVATE_INCLUDE -# src/ -# DEPS -# utils -# doctest -# utils-test-common -# ) +ff_add_test_executable( + NAME + utils-tests + SRC_PATTERNS + src/test_cow_ptr.cc + PRIVATE_INCLUDE + src/ + DEPS + utils + doctest + utils-test-common +) add_subdirectory(common) diff --git a/lib/utils/test/src/test_algorithms.cc b/lib/utils/test/src/test_algorithms.cc index 2e97496b6b..0fb258bf15 100644 --- a/lib/utils/test/src/test_algorithms.cc +++ b/lib/utils/test/src/test_algorithms.cc @@ -12,232 +12,236 @@ using namespace FlexFlow; -TEST_CASE("MultiDiGraph") { - MultiDiGraph g = MultiDiGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector p = add_node_ports(g, 4); - - MultiDiEdge e0{n[3], p[3], n[0], p[0]}; - MultiDiEdge e1{n[2], p[2], n[1], p[0]}; - MultiDiEdge e2{n[3], p[3], n[1], p[1]}; - MultiDiEdge e3{n[3], p[3], n[2], p[2]}; - - std::vector e = {e0, e1, e2, e3}; - - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[1], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{e[3]}); - std::unordered_map> expected_result = - std::unordered_map>{ - {n[1], {}}, - {n[2], {n[1]}}, - {n[3], {n[0], n[1], n[2]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); -} - -TEST_CASE("DiGraph") { - DiGraph g = DiGraph::create(); - - std::vector n = add_nodes(g, 4); - std::vector e = { - {n[0], n[3]}, - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[2]}, - }; - add_edges(g, e); - - CHECK(get_incoming_edges(g, {n[2], n[3]}) == - std::unordered_set{e[0], e[2], e[3]}); - CHECK(get_outgoing_edges(g, {n[2], n[3]}) == - std::unordered_set{}); - auto expected_result = std::unordered_map>{ - {n[1], {n[0]}}, - {n[2], {n[0], n[1]}}, - {n[3], {n[0]}}, - }; - CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); - - SUBCASE("get_imm_dominators") { - std::unordered_map> result = get_imm_dominators(g); - - std::unordered_map> expected_result = { - {n[2], n[0]}, - {n[1], n[0]}, - {n[3], n[0]}, - {n[0], nullopt}, - }; - CHECK(result == expected_result); - } - - SUBCASE("get_dominators") { - std::unordered_map> expected = { - {n[0], {n[0]}}, - {n[1], {n[0], n[1]}}, - {n[2], {n[0], n[2]}}, - {n[3], {n[0], n[3]}}, - }; - CHECK(get_dominators(g) == expected); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MultiDiGraph") { + MultiDiGraph g = MultiDiGraph::create(); + std::vector n = add_nodes(g, 4); + std::vector p = add_node_ports(g, 4); + + MultiDiEdge e0{n[3], p[3], n[0], p[0]}; + MultiDiEdge e1{n[2], p[2], n[1], p[0]}; + MultiDiEdge e2{n[3], p[3], n[1], p[1]}; + MultiDiEdge e3{n[3], p[3], n[2], p[2]}; + + std::vector e = {e0, e1, e2, e3}; + + add_edges(g, e); + + CHECK(get_incoming_edges(g, {n[1], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_incoming_edges(g, {n[1]}) == std::unordered_set{}); + CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{e[3]}); + std::unordered_map> expected_result = + std::unordered_map>{ + {n[1], {}}, + {n[2], {n[1]}}, + {n[3], {n[0], n[1], n[2]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); } - SUBCASE("get_sinks") { - auto expected = std::unordered_set{n[2], n[3]}; - CHECK(get_sinks(g) == expected); - } + TEST_CASE("DiGraph") { + DiGraph g = DiGraph::create(); - SUBCASE("get_bfs") { - std::unordered_set start_points = std::unordered_set{n[0]}; - auto expected = std::vector{n[0], n[2], n[1], n[3]}; - CHECK(get_bfs_ordering(g, start_points) == expected); - } + std::vector n = add_nodes(g, 4); + std::vector e = { + {n[0], n[3]}, + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[2]}, + }; + add_edges(g, e); - SUBCASE("get_predecessors") { - std::unordered_map> expected_result = { + CHECK(get_incoming_edges(g, {n[2], n[3]}) == + std::unordered_set{e[0], e[2], e[3]}); + CHECK(get_outgoing_edges(g, {n[2], n[3]}) == + std::unordered_set{}); + auto expected_result = std::unordered_map>{ {n[1], {n[0]}}, {n[2], {n[0], n[1]}}, + {n[3], {n[0]}}, }; - CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); - } -} + CHECK(get_predecessors(g, {n[1], n[2], n[3]}) == expected_result); -TEST_CASE("traversal") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 5); - std::vector edges = {{n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; - add_edges(g, edges); - - CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); - CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(get_bfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == true); - CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); - CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); - - SUBCASE("with root") { - g.add_edge({n[3], n[2]}); - - CHECK(get_dfs_ordering(g, {n[0]}) == - std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); + SUBCASE("get_imm_dominators") { + std::unordered_map> result = get_imm_dominators(g); + + std::unordered_map> expected_result = { + {n[2], n[0]}, + {n[1], n[0]}, + {n[3], n[0]}, + {n[0], nullopt}, + }; + CHECK(result == expected_result); + } + + SUBCASE("get_dominators") { + std::unordered_map> expected = { + {n[0], {n[0]}}, + {n[1], {n[0], n[1]}}, + {n[2], {n[0], n[2]}}, + {n[3], {n[0], n[3]}}, + }; + CHECK(get_dominators(g) == expected); + } + + SUBCASE("get_sinks") { + auto expected = std::unordered_set{n[2], n[3]}; + CHECK(get_sinks(g) == expected); + } + + SUBCASE("get_bfs") { + std::unordered_set start_points = std::unordered_set{n[0]}; + auto expected = std::vector{n[0], n[2], n[1], n[3]}; + CHECK(get_bfs_ordering(g, start_points) == expected); + } + + SUBCASE("get_predecessors") { + std::unordered_map> expected_result = { + {n[1], {n[0]}}, + {n[2], {n[0], n[1]}}, + }; + CHECK(get_predecessors(g, {n[1], n[2]}) == expected_result); + } } - SUBCASE("without root") { - g.add_edge({n[3], n[0]}); + TEST_CASE("traversal") { + DiGraph g = DiGraph::create(); + std::vector const n = add_nodes(g, 5); + std::vector edges = { + {n[0], n[1]}, {n[1], n[2]}, {n[2], n[3]}}; + add_edges(g, edges); - CHECK(get_dfs_ordering(g, {n[0]}) == + CHECK(get_sources(g) == std::unordered_set{n[0], n[4]}); + CHECK(get_unchecked_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2], n[3]}); - CHECK(is_acyclic(g) == false); - } - SUBCASE("nonlinear") { - g.add_edge({n[1], n[3]}); - CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + CHECK(get_bfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == true); + CHECK(get_bfs_ordering(g, {n[4]}) == std::vector{n[4]}); + CHECK(get_dfs_ordering(g, {n[4]}) == std::vector{n[4]}); + + SUBCASE("with root") { + g.add_edge({n[3], n[2]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == false); + } + + SUBCASE("without root") { + g.add_edge({n[3], n[0]}); + + CHECK(get_dfs_ordering(g, {n[0]}) == + std::vector{n[0], n[1], n[2], n[3]}); + CHECK(is_acyclic(g) == false); + } + SUBCASE("nonlinear") { + g.add_edge({n[1], n[3]}); + CHECK(is_acyclic(g) == true); // TODO, maybe a bug about the unchecked_dfs + } + + SUBCASE("not connected") { + g.remove_edge({n[2], n[3]}); + CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + } } - SUBCASE("not connected") { - g.remove_edge({n[2], n[3]}); - CHECK(get_dfs_ordering(g, {n[0]}) == std::vector{n[0], n[1], n[2]}); + TEST_CASE("bfs") { + DiGraph g = DiGraph::create(); + std::vector const n = add_nodes(g, 7); + + std::vector e = { + {n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[6]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}, + {n[5], n[6]}, + {n[6], n[0]}, + }; + + add_edges(g, e); + + std::vector ordering = get_bfs_ordering(g, {n[0]}); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).has_value()); + CHECK(index_of(ordering, n[r]).has_value()); + CHECK(index_of(ordering, n[l]).value() < + index_of(ordering, n[r]).value()); + }; + + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + + CHECK_BEFORE(1, 3); + CHECK_BEFORE(1, 6); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(2, 6); + + CHECK_BEFORE(3, 4); + CHECK_BEFORE(6, 4); + + CHECK_BEFORE(4, 5); } -} -TEST_CASE("bfs") { - DiGraph g = DiGraph::create(); - std::vector const n = add_nodes(g, 7); - - std::vector e = { - {n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[6]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}, - {n[5], n[6]}, - {n[6], n[0]}, - }; - - add_edges(g, e); - - std::vector ordering = get_bfs_ordering(g, {n[0]}); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]).value() < index_of(ordering, n[r]).value()); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - - CHECK_BEFORE(1, 3); - CHECK_BEFORE(1, 6); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(2, 6); - - CHECK_BEFORE(3, 4); - CHECK_BEFORE(6, 4); - - CHECK_BEFORE(4, 5); -} + TEST_CASE("get_topological_ordering") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + std::vector edges = {{n[0], n[1]}, + {n[0], n[2]}, + {n[1], n[5]}, + {n[2], n[3]}, + {n[3], n[4]}, + {n[4], n[5]}}; + add_edges(g, edges); + std::vector ordering = get_topological_ordering(g); + auto CHECK_BEFORE = [&](int l, int r) { + CHECK(index_of(ordering, n[l]).has_value()); + CHECK(index_of(ordering, n[r]).has_value()); + CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); + }; -TEST_CASE("get_topological_ordering") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 6); - std::vector edges = {{n[0], n[1]}, - {n[0], n[2]}, - {n[1], n[5]}, - {n[2], n[3]}, - {n[3], n[4]}, - {n[4], n[5]}}; - add_edges(g, edges); - std::vector ordering = get_topological_ordering(g); - auto CHECK_BEFORE = [&](int l, int r) { - CHECK(index_of(ordering, n[l]).has_value()); - CHECK(index_of(ordering, n[r]).has_value()); - CHECK(index_of(ordering, n[l]) < index_of(ordering, n[r])); - }; - - CHECK(ordering.size() == n.size()); - CHECK_BEFORE(0, 1); - CHECK_BEFORE(0, 2); - CHECK_BEFORE(1, 5); - CHECK_BEFORE(2, 3); - CHECK_BEFORE(3, 4); - CHECK_BEFORE(4, 5); -} + CHECK(ordering.size() == n.size()); + CHECK_BEFORE(0, 1); + CHECK_BEFORE(0, 2); + CHECK_BEFORE(1, 5); + CHECK_BEFORE(2, 3); + CHECK_BEFORE(3, 4); + CHECK_BEFORE(4, 5); + } -TEST_CASE("get_connected_components") { - UndirectedGraph g = UndirectedGraph::create(); - std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + TEST_CASE("get_connected_components") { + UndirectedGraph g = UndirectedGraph::create(); + std::vector n = add_nodes(g, 4); + std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; + add_edges(g, edges); + std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; - CHECK(get_connected_components(g) == expected_components); -} + CHECK(get_connected_components(g) == expected_components); + } -TEST_CASE("get_weakly_connected_components") { - DiGraph g = DiGraph::create(); - std::vector n = add_nodes(g, 4); + TEST_CASE("get_weakly_connected_components") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); - std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; + std::vector edges = {{n[0], n[1]}, {n[2], n[1]}}; - add_edges(g, edges); - std::unordered_set> expected_components = { - {n[0], n[1], n[2]}, - {n[3]}, - }; + add_edges(g, edges); + std::unordered_set> expected_components = { + {n[0], n[1], n[2]}, + {n[3]}, + }; - CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); + CHECK(get_outgoing_edges(as_digraph(as_undirected(g)), n[0]).size() == 1); - CHECK(get_weakly_connected_components(g) == expected_components); + CHECK(get_weakly_connected_components(g) == expected_components); + } } diff --git a/lib/utils/test/src/test_bidict.cc b/lib/utils/test/src/test_bidict.cc index 6c288089b6..afc32b3658 100644 --- a/lib/utils/test/src/test_bidict.cc +++ b/lib/utils/test/src/test_bidict.cc @@ -3,61 +3,63 @@ using namespace FlexFlow; -TEST_CASE("bidict") { - bidict dict; - dict.equate(1, "one"); - dict.equate(2, "two"); - - // Test the equate() function - SUBCASE("Equate") { - CHECK(dict.at_l(1) == "one"); - CHECK(dict.at_r("one") == 1); - CHECK(dict.at_l(2) == "two"); - CHECK(dict.at_r("two") == 2); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("bidict") { + bidict dict; + dict.equate(1, "one"); + dict.equate(2, "two"); - // Test the erase_l() function - SUBCASE("EraseL") { - dict.erase_l(1); - CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); - CHECK(dict.at_r("two") == 2); - } + // Test the equate() function + SUBCASE("Equate") { + CHECK(dict.at_l(1) == "one"); + CHECK(dict.at_r("one") == 1); + CHECK(dict.at_l(2) == "two"); + CHECK(dict.at_r("two") == 2); + } - // Test the erase_r() function - SUBCASE("EraseR") { - dict.erase_r("one"); - CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); - CHECK(dict.at_l(2) == "two"); - } + // Test the erase_l() function + SUBCASE("EraseL") { + dict.erase_l(1); + CHECK(dict.size() == 1); + CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); + CHECK(dict.at_r("two") == 2); + } - // Test the reversed() function - SUBCASE("Reversed") { - bidict reversed_dict = dict.reversed(); - CHECK(reversed_dict.at_l("one") == 1); - CHECK(reversed_dict.at_r(2) == "two"); - } + // Test the erase_r() function + SUBCASE("EraseR") { + dict.erase_r("one"); + CHECK(dict.size() == 1); + CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); + CHECK(dict.at_l(2) == "two"); + } - // Test the size() function - SUBCASE("Size") { - CHECK(dict.size() == 2); - } + // Test the reversed() function + SUBCASE("Reversed") { + bidict reversed_dict = dict.reversed(); + CHECK(reversed_dict.at_l("one") == 1); + CHECK(reversed_dict.at_r(2) == "two"); + } - SUBCASE("implicitly convert to std::unordered_map") { - std::unordered_map res = dict; - std::unordered_map expected = {{1, "one"}, {2, "two"}}; - CHECK(res == expected); - } + // Test the size() function + SUBCASE("Size") { + CHECK(dict.size() == 2); + } - SUBCASE("begin") { - auto it = dict.begin(); - CHECK(it->first == 2); - CHECK(it->second == "two"); - } + SUBCASE("implicitly convert to std::unordered_map") { + std::unordered_map res = dict; + std::unordered_map expected = {{1, "one"}, {2, "two"}}; + CHECK(res == expected); + } + + SUBCASE("begin") { + auto it = dict.begin(); + CHECK(it->first == 2); + CHECK(it->second == "two"); + } - SUBCASE("end") { - auto it = dict.end(); - CHECK(it == dict.end()); + SUBCASE("end") { + auto it = dict.end(); + CHECK(it == dict.end()); + } } } diff --git a/lib/utils/test/src/test_containers.cc b/lib/utils/test/src/test_containers.cc index 8c37abf877..a6776d492e 100644 --- a/lib/utils/test/src/test_containers.cc +++ b/lib/utils/test/src/test_containers.cc @@ -5,384 +5,389 @@ #include using namespace FlexFlow; -TEST_CASE("join_strings") { - std::vector const v = {"Hello", "world", "!"}; - CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); -} -TEST_CASE("join_strings with container") { - std::vector const v = {"Hello", "world"}; - CHECK(join_strings(v, " ") == "Hello world"); -} +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("join_strings") { + std::vector const v = {"Hello", "world", "!"}; + CHECK(join_strings(v.begin(), v.end(), " ") == "Hello world !"); + } -TEST_CASE("find") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(find(v, 3) != v.cend()); - CHECK(find(v, 6) == v.cend()); -} + TEST_CASE("join_strings with container") { + std::vector const v = {"Hello", "world"}; + CHECK(join_strings(v, " ") == "Hello world"); + } -TEST_CASE("sum") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(sum(v) == 15); -} + TEST_CASE("find") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(find(v, 3) != v.cend()); + CHECK(find(v, 6) == v.cend()); + } -TEST_CASE("sum with condition") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { return x % 2 == 0; }; // Sum of even numbers only - CHECK(sum_where(v, condition) == 6); -} + TEST_CASE("sum") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(sum(v) == 15); + } -TEST_CASE("product") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(product(v) == 120); -} + TEST_CASE("sum with condition") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { + return x % 2 == 0; + }; // Sum of even numbers only + CHECK(sum_where(v, condition) == 6); + } -TEST_CASE("product_where") { - std::vector v = {1, 2, 3, 4, 5}; - auto condition = [](int x) { - return x % 2 == 0; - }; // Product of even numbers only - CHECK(product_where(v, condition) == 8); -} + TEST_CASE("product") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(product(v) == 120); + } -TEST_CASE("contains") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(contains(v, 3)); - CHECK(!contains(v, 6)); -} + TEST_CASE("product_where") { + std::vector v = {1, 2, 3, 4, 5}; + auto condition = [](int x) { + return x % 2 == 0; + }; // Product of even numbers only + CHECK(product_where(v, condition) == 8); + } -TEST_CASE("contains_key") { - std::unordered_map m = { - {"one", 1}, {"two", 2}, {"three", 3}}; - CHECK(contains_key(m, "one")); - CHECK(!contains_key(m, "four")); -} + TEST_CASE("contains") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(contains(v, 3)); + CHECK(!contains(v, 6)); + } -TEST_CASE("map_keys") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](int x) { return x * x; }; // Mapping function - auto result = map_keys(m, f); - CHECK(result.size() == 2); - CHECK(result[1] == "one"); - CHECK(result[4] == "two"); -} + TEST_CASE("contains_key") { + std::unordered_map m = { + {"one", 1}, {"two", 2}, {"three", 3}}; + CHECK(contains_key(m, "one")); + CHECK(!contains_key(m, "four")); + } -TEST_CASE("filter_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = [](int x) { return x % 2 == 1; }; // Filtering function - std::unordered_map result = filter_keys(m, f); - std::unordered_map expected = {{1, "one"}, {3, "three"}}; - CHECK(result == expected); -} + TEST_CASE("map_keys") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](int x) { return x * x; }; // Mapping function + auto result = map_keys(m, f); + CHECK(result.size() == 2); + CHECK(result[1] == "one"); + CHECK(result[4] == "two"); + } -TEST_CASE("map_values") { - std::unordered_map m = {{1, "one"}, {2, "two"}}; - auto f = [](std::string const &s) { return s.size(); }; // Mapping function - std::unordered_map result = map_values(m, f); - std::unordered_map expected = {{1, 3}, {2, 3}}; - CHECK(result == expected); -} + TEST_CASE("filter_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = [](int x) { return x % 2 == 1; }; // Filtering function + std::unordered_map result = filter_keys(m, f); + std::unordered_map expected = {{1, "one"}, {3, "three"}}; + CHECK(result == expected); + } -TEST_CASE("keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {3, 2, 1}; - CHECK(result == expected); -} + TEST_CASE("map_values") { + std::unordered_map m = {{1, "one"}, {2, "two"}}; + auto f = [](std::string const &s) { return s.size(); }; // Mapping function + std::unordered_map result = map_values(m, f); + std::unordered_map expected = {{1, 3}, {2, 3}}; + CHECK(result == expected); + } -TEST_CASE("values") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::vector result = values(m); - std::vector expected = {"three", "two", "one"}; - CHECK(result == expected); -} + TEST_CASE("keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set result = keys(m); + std::unordered_set expected = {3, 2, 1}; + CHECK(result == expected); + } -// TEST_CASE("items") { -// std::unordered_map m = {{1, std::string("one")}, {2, -// std::string("two")}, {3,std::string("three")}}; -// std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; - std::unordered_set result = unique(v); - std::unordered_set expected = {1, 2, 3}; - CHECK(result == expected); -} + TEST_CASE("values") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::vector result = values(m); + std::vector expected = {"three", "two", "one"}; + CHECK(result == expected); + } -TEST_CASE("without_order") { - std::vector v = {1, 4, 6, 4, 6}; - std::unordered_set expected = {1, 4, 6}; - CHECK(without_order(v) == expected); -} + // TEST_CASE("items") { + // std::unordered_map m = {{1, std::string("one")}, {2, + // std::string("two")}, {3,std::string("three")}}; + // std::cout<<"result type:"< v = {1, 2, 3, 2, 1}; + std::unordered_set result = unique(v); + std::unordered_set expected = {1, 2, 3}; + CHECK(result == expected); + } -TEST_CASE("index_of") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(index_of(v, 3) == 2); - CHECK(!index_of(v, 6).has_value()); -} + TEST_CASE("without_order") { + std::vector v = {1, 4, 6, 4, 6}; + std::unordered_set expected = {1, 4, 6}; + CHECK(without_order(v) == expected); + } -TEST_CASE("intersection") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {2, 3, 4}; - std::unordered_set result = intersection(l, r); - std::unordered_set expected = {2, 3}; - CHECK(result == expected); -} + TEST_CASE("index_of") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(index_of(v, 3) == 2); + CHECK(!index_of(v, 6).has_value()); + } -TEST_CASE("are_disjoint") { - std::unordered_set l = {1, 2, 3}; - std::unordered_set r = {4, 5, 6}; - CHECK(are_disjoint(l, r)); - r.insert(3); - CHECK_FALSE(are_disjoint(l, r)); -} + TEST_CASE("intersection") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {2, 3, 4}; + std::unordered_set result = intersection(l, r); + std::unordered_set expected = {2, 3}; + CHECK(result == expected); + } -TEST_CASE("restrict_keys") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set mask = {2, 3, 4}; - std::unordered_map result = restrict_keys(m, mask); - std::unordered_map expected = {{2, "two"}, {3, "three"}}; - CHECK(result == expected); -} + TEST_CASE("are_disjoint") { + std::unordered_set l = {1, 2, 3}; + std::unordered_set r = {4, 5, 6}; + CHECK(are_disjoint(l, r)); + r.insert(3); + CHECK_FALSE(are_disjoint(l, r)); + } -TEST_CASE("merge_maps(unordered_map)") { - std::unordered_map lhs = {{1, "one"}, {2, "two"}}; - std::unordered_map rhs = {{3, "three"}, {4, "four"}}; - std::unordered_map result = merge_maps(lhs, rhs); - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); -} + TEST_CASE("restrict_keys") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + std::unordered_set mask = {2, 3, 4}; + std::unordered_map result = restrict_keys(m, mask); + std::unordered_map expected = {{2, "two"}, {3, "three"}}; + CHECK(result == expected); + } -TEST_CASE("merge_maps(bidict)") { - std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; - std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; - std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; - std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; - bidict lhs{fwd_map1, bwd_map1}; - bidict rhs{fwd_map2, bwd_map2}; - - std::unordered_map result = - merge_maps(lhs, rhs); // impicit conversion - std::unordered_map expected = { - {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; - CHECK(result == expected); -} + TEST_CASE("merge_maps(unordered_map)") { + std::unordered_map lhs = {{1, "one"}, {2, "two"}}; + std::unordered_map rhs = {{3, "three"}, {4, "four"}}; + std::unordered_map result = merge_maps(lhs, rhs); + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK(result == expected); + } -TEST_CASE("lookup_in") { - std::unordered_map m = { - {1, "one"}, {2, "two"}, {3, "three"}}; - auto f = lookup_in(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); - CHECK(f(3) == "three"); -} + TEST_CASE("merge_maps(bidict)") { + std::unordered_map fwd_map1 = {{1, "one"}, {2, "two"}}; + std::unordered_map bwd_map1 = {{"one", 1}, {"two", 2}}; + std::unordered_map fwd_map2 = {{3, "three"}, {4, "four"}}; + std::unordered_map bwd_map2 = {{"three", 3}, {"four", 4}}; + bidict lhs{fwd_map1, bwd_map1}; + bidict rhs{fwd_map2, bwd_map2}; + + std::unordered_map result = + merge_maps(lhs, rhs); // impicit conversion + std::unordered_map expected = { + {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; + CHECK(result == expected); + } -TEST_CASE("lookup_in_l") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_l(m); - CHECK(f(1) == "one"); - CHECK(f(2) == "two"); -} + TEST_CASE("lookup_in") { + std::unordered_map m = { + {1, "one"}, {2, "two"}, {3, "three"}}; + auto f = lookup_in(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + CHECK(f(3) == "three"); + } -TEST_CASE("lookup_in_r") { - bidict m; - m.equate(1, "one"); - m.equate(2, "two"); - auto f = lookup_in_r(m); - CHECK(f("one") == 1); - CHECK(f("two") == 2); -} + TEST_CASE("lookup_in_l") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_l(m); + CHECK(f(1) == "one"); + CHECK(f(2) == "two"); + } -TEST_CASE("set_union") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {2, 3, 4}; - std::unordered_set result = set_union(s1, s2); - std::unordered_set expected = {1, 2, 3, 4}; - CHECK(result == expected); -} + TEST_CASE("lookup_in_r") { + bidict m; + m.equate(1, "one"); + m.equate(2, "two"); + auto f = lookup_in_r(m); + CHECK(f("one") == 1); + CHECK(f("two") == 2); + } -TEST_CASE("is_subseteq_of") { - std::unordered_set s1 = {1, 2}; - std::unordered_set s2 = {1, 2, 3}; - CHECK(is_subseteq_of(s1, s2) == true); - CHECK(is_subseteq_of(s2, s1) == false); - CHECK(is_subseteq_of(s1, s1) == true); - CHECK(is_subseteq_of(s2, s2) == true); -} + TEST_CASE("set_union") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {2, 3, 4}; + std::unordered_set result = set_union(s1, s2); + std::unordered_set expected = {1, 2, 3, 4}; + CHECK(result == expected); + } -TEST_CASE("is_superseteq_of") { - std::unordered_set s1 = {1, 2, 3}; - std::unordered_set s2 = {1, 2}; - CHECK(is_supserseteq_of(s1, s2) == true); - CHECK(is_supserseteq_of(s2, s1) == false); -} + TEST_CASE("is_subseteq_of") { + std::unordered_set s1 = {1, 2}; + std::unordered_set s2 = {1, 2, 3}; + CHECK(is_subseteq_of(s1, s2) == true); + CHECK(is_subseteq_of(s2, s1) == false); + CHECK(is_subseteq_of(s1, s1) == true); + CHECK(is_subseteq_of(s2, s2) == true); + } -TEST_CASE("get_only") { - std::unordered_set s = {42}; - CHECK(get_only(s) == 42); -} + TEST_CASE("is_superseteq_of") { + std::unordered_set s1 = {1, 2, 3}; + std::unordered_set s2 = {1, 2}; + CHECK(is_supserseteq_of(s1, s2) == true); + CHECK(is_supserseteq_of(s2, s1) == false); + } -TEST_CASE("get_first") { - std::unordered_set s = {1, 2, 3}; - CHECK(s.count(get_first(s)) == 1); -} + TEST_CASE("get_only") { + std::unordered_set s = {42}; + CHECK(get_only(s) == 42); + } -TEST_CASE("extend") { - std::vector v = {1, 2, 3}; - std::unordered_set s = {4, 5, 6}; - extend(v, s); - CHECK(v.size() == 6); - std::vector expected = {1, 2, 3, 6, 5, 4}; - CHECK(v == expected); -} + TEST_CASE("get_first") { + std::unordered_set s = {1, 2, 3}; + CHECK(s.count(get_first(s)) == 1); + } -TEST_CASE("all_of") { - std::vector v = {2, 4, 6, 8}; - CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); - CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); -} + TEST_CASE("extend") { + std::vector v = {1, 2, 3}; + std::unordered_set s = {4, 5, 6}; + extend(v, s); + CHECK(v.size() == 6); + std::vector expected = {1, 2, 3, 6, 5, 4}; + CHECK(v == expected); + } -TEST_CASE("count") { - std::vector v = {1, 2, 3, 4, 5}; - CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); - CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); -} + TEST_CASE("all_of") { + std::vector v = {2, 4, 6, 8}; + CHECK(all_of(v, [](int x) { return x % 2 == 0; }) == true); + CHECK(all_of(v, [](int x) { return x % 2 == 1; }) == false); + } -TEST_CASE("are_all_same") { - std::vector v1 = {2, 2, 2, 2}; - std::vector v2 = {1, 2, 3, 4}; - CHECK(are_all_same(v1) == true); - CHECK(are_all_same(v2) == false); -} + TEST_CASE("count") { + std::vector v = {1, 2, 3, 4, 5}; + CHECK(count(v, [](int x) { return x % 2 == 0; }) == 2); + CHECK(count(v, [](int x) { return x % 2 == 1; }) == 3); + } -TEST_CASE("vector_transform") { - std::vector v = {1, 2, 3}; - auto result = vector_transform([](int x) { return x * 2; }, v); - CHECK(result == std::vector({2, 4, 6})); -} + TEST_CASE("are_all_same") { + std::vector v1 = {2, 2, 2, 2}; + std::vector v2 = {1, 2, 3, 4}; + CHECK(are_all_same(v1) == true); + CHECK(are_all_same(v2) == false); + } -TEST_CASE("as_vector") { - std::unordered_set s = {1, 2, 3}; - std::vector result = as_vector(s); - CHECK(result == std::vector({3, 2, 1})); -} + TEST_CASE("vector_transform") { + std::vector v = {1, 2, 3}; + auto result = vector_transform([](int x) { return x * 2; }, v); + CHECK(result == std::vector({2, 4, 6})); + } -TEST_CASE("transform_vector") { - std::vector v = {1, 2, 3}; - auto result = transform(v, [](int x) { return x * 2; }); - CHECK(result == std::vector({2, 4, 6})); -} + TEST_CASE("as_vector") { + std::unordered_set s = {1, 2, 3}; + std::vector result = as_vector(s); + CHECK(result == std::vector({3, 2, 1})); + } -TEST_CASE("transform_unordered_set") { - std::unordered_set s = {1, 2, 3}; - auto result = transform(s, [](int x) { return x * 2; }); - CHECK(result == std::unordered_set({2, 4, 6})); -} + TEST_CASE("transform_vector") { + std::vector v = {1, 2, 3}; + auto result = transform(v, [](int x) { return x * 2; }); + CHECK(result == std::vector({2, 4, 6})); + } -TEST_CASE("transform_string") { - std::string s = "abc"; - auto result = transform(s, ::toupper); - CHECK(result == "ABC"); -} + TEST_CASE("transform_unordered_set") { + std::unordered_set s = {1, 2, 3}; + auto result = transform(s, [](int x) { return x * 2; }); + CHECK(result == std::unordered_set({2, 4, 6})); + } -TEST_CASE("repeat") { - int ctr = 0; - std::vector result = repeat(5, [&] { return ctr++; }); + TEST_CASE("transform_string") { + std::string s = "abc"; + auto result = transform(s, ::toupper); + CHECK(result == "ABC"); + } - CHECK(result == std::vector{0, 1, 2, 3, 4}); -} + TEST_CASE("repeat") { + int ctr = 0; + std::vector result = repeat(5, [&] { return ctr++; }); -TEST_CASE("Testing the 'enumerate' function") { - std::unordered_set input_set = {1, 2, 3, 4, 5}; - std::unordered_map result = enumerate(input_set); - std::unordered_map expected = { - {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; - CHECK(result == expected); -} + CHECK(result == std::vector{0, 1, 2, 3, 4}); + } -TEST_CASE("Testing the 'maximum' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - auto result = maximum(input_vec); + TEST_CASE("Testing the 'enumerate' function") { + std::unordered_set input_set = {1, 2, 3, 4, 5}; + std::unordered_map result = enumerate(input_set); + std::unordered_map expected = { + {1, 4}, {2, 3}, {3, 2}, {4, 1}, {0, 5}}; + CHECK(result == expected); + } - // Checking the maximum is as expected - REQUIRE(result == 5); -} + TEST_CASE("Testing the 'maximum' function") { + std::vector input_vec = {1, 2, 3, 4, 5}; + auto result = maximum(input_vec); -TEST_CASE("Testing the 'reversed' function") { - std::vector input_vec = {1, 2, 3, 4, 5}; - std::vector result = reversed(input_vec); - std::vector expected = {5, 4, 3, 2, 1}; + // Checking the maximum is as expected + REQUIRE(result == 5); + } - // Checking the reversed sequence is as expected - CHECK(result == expected); -} + TEST_CASE("Testing the 'reversed' function") { + std::vector input_vec = {1, 2, 3, 4, 5}; + std::vector result = reversed(input_vec); + std::vector expected = {5, 4, 3, 2, 1}; + + // Checking the reversed sequence is as expected + CHECK(result == expected); + } -TEST_CASE("Testing sorted_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); - CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); + TEST_CASE("Testing sorted_by function") { + std::unordered_set s = {5, 2, 3, 4, 1}; + auto sorted_s = sorted_by(s, [](int a, int b) { return a < b; }); + CHECK(sorted_s == std::vector({1, 2, 3, 4, 5})); - std::unordered_set s2 = {-5, -1, -3, -2, -4}; - auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); - CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); -} + std::unordered_set s2 = {-5, -1, -3, -2, -4}; + auto sorted_s2 = sorted_by(s2, [](int a, int b) { return a > b; }); + CHECK(sorted_s2 == std::vector({-1, -2, -3, -4, -5})); + } -TEST_CASE("Testing compare_by function") { - std::unordered_set s = {5, 2, 3, 4, 1}; - std::vector result = - sorted_by(s, compare_by([](int i) { return (-i); })); - CHECK(result == std::vector{5, 4, 3, 2, 1}); -} + TEST_CASE("Testing compare_by function") { + std::unordered_set s = {5, 2, 3, 4, 1}; + std::vector result = + sorted_by(s, compare_by([](int i) { return (-i); })); + CHECK(result == std::vector{5, 4, 3, 2, 1}); + } -TEST_CASE("Testing vector_split function") { - std::vector v = {1, 2, 3, 4, 5}; - auto result = vector_split(v, 2); - std::vector prefix = result.first; - std::vector postfix = result.second; - CHECK(prefix == std::vector({1, 2})); - CHECK(postfix == std::vector({3, 4, 5})); -} + TEST_CASE("Testing vector_split function") { + std::vector v = {1, 2, 3, 4, 5}; + auto result = vector_split(v, 2); + std::vector prefix = result.first; + std::vector postfix = result.second; + CHECK(prefix == std::vector({1, 2})); + CHECK(postfix == std::vector({3, 4, 5})); + } -TEST_CASE("Testing value_all function") { - std::vector> v = {1, 2, 3, 4, 5}; - auto value_all_v = value_all(v); - CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); -} + TEST_CASE("Testing value_all function") { + std::vector> v = {1, 2, 3, 4, 5}; + auto value_all_v = value_all(v); + CHECK(value_all_v == std::vector({1, 2, 3, 4, 5})); + } -TEST_CASE("Testing subvec function") { - std::vector v = {1, 2, 3, 4, 5}; - auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); + TEST_CASE("Testing subvec function") { + std::vector v = {1, 2, 3, 4, 5}; + auto subvec_v = subvec(v, tl::optional(1), tl::optional(4)); - CHECK(subvec_v == std::vector({2, 3, 4})); + CHECK(subvec_v == std::vector({2, 3, 4})); - auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); - CHECK(subvec_v2 == std::vector({1, 2, 3})); -} + auto subvec_v2 = subvec(v, tl::nullopt, tl::optional(3)); + CHECK(subvec_v2 == std::vector({1, 2, 3})); + } -auto get_factors = [](int x) -> std::vector { - // Returns a vector of factors of x - std::vector factors; - for (int i = 1; i <= x; i++) { - if (x % i == 0) { - factors.push_back(i); + auto get_factors = [](int x) -> std::vector { + // Returns a vector of factors of x + std::vector factors; + for (int i = 1; i <= x; i++) { + if (x % i == 0) { + factors.push_back(i); + } } + return factors; + }; + + // Example for vector + TEST_CASE("Test for flatmap function on vectors") { + std::vector v = {2, 3, 4, 5}; + auto result = flatmap(v, get_factors); + CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); } - return factors; -}; - -// Example for vector -TEST_CASE("Test for flatmap function on vectors") { - std::vector v = {2, 3, 4, 5}; - auto result = flatmap(v, get_factors); - CHECK(result == std::vector({1, 2, 1, 3, 1, 2, 4, 1, 5})); } diff --git a/lib/utils/test/src/test_cow_ptr.cc b/lib/utils/test/src/test_cow_ptr.cc new file mode 100644 index 0000000000..de573d0c9b --- /dev/null +++ b/lib/utils/test/src/test_cow_ptr.cc @@ -0,0 +1,62 @@ +#include "test/utils/doctest.h" +#include "utils/graph/cow_ptr_t.h" +#include +#include +#include + +using namespace FlexFlow; + +struct TestObject { + TestObject(int x) : x(x) {} + int x; + virtual TestObject *clone() const { + return new TestObject(x); + } +}; + +struct TestObjectDerived : public TestObject { + TestObjectDerived(int x, int y) : TestObject(x), y(y) {} + int y; + TestObjectDerived *clone() const override { + return new TestObjectDerived(x, y); + } +}; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("cow_ptr_t constructor") { + std::shared_ptr sp = std::make_shared(1); + cow_ptr_t p1(sp); + cow_ptr_t p2(std::make_shared(3)); + cow_ptr_t p3(TestObject(2)); + cow_ptr_t p4(p3); + cow_ptr_t p5 = p1; + CHECK(p1->x == 1); + CHECK(p2->x == 3); + CHECK(p3->x == 2); + CHECK(p4->x == p3->x); + CHECK(p5->x == p1->x); + } + + TEST_CASE("cow_ptr_t copy") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(std::make_shared(2)); + p1 = p2; + CHECK(p1->x == p2->x); + } + + TEST_CASE("cow_ptr_t cast") { + cow_ptr_t p1(std::make_shared(1, 2)); + cow_ptr_t p2(p1); + CHECK(p2->x == 1); + } + + TEST_CASE("cow_ptr_t get_mutable") { + cow_ptr_t p1(std::make_shared(1)); + cow_ptr_t p2(p1); + p1.get_mutable()->x = 3; + CHECK(p1->x == 3); + CHECK(p2->x == 1); + p2.get_mutable()->x = 2; + CHECK(p1->x == 3); + } +} diff --git a/lib/utils/test/src/test_deduplicated_priority_queue.cc b/lib/utils/test/src/test_deduplicated_priority_queue.cc index a5c97fa0f8..66cfd395bc 100644 --- a/lib/utils/test/src/test_deduplicated_priority_queue.cc +++ b/lib/utils/test/src/test_deduplicated_priority_queue.cc @@ -1,34 +1,36 @@ #include "test/utils/doctest.h" #include "utils/deduplicated_priority_queue.h" -TEST_CASE("DeduplicatedPriorityQueue push and pop") { - DeduplicatedPriorityQueue queue; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DeduplicatedPriorityQueue push and pop") { + DeduplicatedPriorityQueue queue; - SUBCASE("Push elements") { - queue.push(5); - queue.push(2); - queue.push(7); - queue.push(2); + SUBCASE("Push elements") { + queue.push(5); + queue.push(2); + queue.push(7); + queue.push(2); - CHECK(queue.size() == 3); - CHECK(queue.top() == 7); - CHECK_FALSE(queue.empty()); - } + CHECK(queue.size() == 3); + CHECK(queue.top() == 7); + CHECK_FALSE(queue.empty()); + } - SUBCASE("Pop elements") { - queue.push(5); - queue.push(2); - queue.push(7); + SUBCASE("Pop elements") { + queue.push(5); + queue.push(2); + queue.push(7); - queue.pop(); - CHECK(queue.size() == 2); - CHECK(queue.top() == 5); + queue.pop(); + CHECK(queue.size() == 2); + CHECK(queue.top() == 5); - queue.pop(); - CHECK(queue.size() == 1); - CHECK(queue.top() == 2); + queue.pop(); + CHECK(queue.size() == 1); + CHECK(queue.top() == 2); - queue.pop(); - CHECK(queue.empty()); + queue.pop(); + CHECK(queue.empty()); + } } } diff --git a/lib/utils/test/src/test_disjoint_set.cc b/lib/utils/test/src/test_disjoint_set.cc index fe2c4bae33..80fcf87d6b 100644 --- a/lib/utils/test/src/test_disjoint_set.cc +++ b/lib/utils/test/src/test_disjoint_set.cc @@ -16,53 +16,54 @@ std::string generate_element(int seed) { return "Element" + std::to_string(seed); } -TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { - disjoint_set> ds; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("DisjointSetUnionAndFind", T, int, std::string) { + disjoint_set> ds; - SUBCASE("SingleElementSets") { - optional element = generate_element(1); - CHECK(ds.find(element) == element); + SUBCASE("SingleElementSets") { + optional element = generate_element(1); + CHECK(ds.find(element) == element); - element = generate_element(2); - CHECK(ds.find(element) == element); - } + element = generate_element(2); + CHECK(ds.find(element) == element); + } - SUBCASE("UnionAndFind") { - optional element1 = generate_element(1); - optional element2 = generate_element(2); - optional element3 = generate_element(3); - optional element4 = generate_element(4); + SUBCASE("UnionAndFind") { + optional element1 = generate_element(1); + optional element2 = generate_element(2); + optional element3 = generate_element(3); + optional element4 = generate_element(4); - ds.m_union(element1, element2); - CHECK(ds.find(element1) == ds.find(element2)); + ds.m_union(element1, element2); + CHECK(ds.find(element1) == ds.find(element2)); - ds.m_union(element3, element4); - CHECK(ds.find(element3) == ds.find(element4)); + ds.m_union(element3, element4); + CHECK(ds.find(element3) == ds.find(element4)); - ds.m_union(element1, element3); - CHECK(ds.find(element1) == ds.find(element3)); - CHECK(ds.find(element2) == ds.find(element4)); - CHECK(ds.find(element1) == ds.find(element2)); - CHECK(ds.find(element1) == ds.find(element4)); + ds.m_union(element1, element3); + CHECK(ds.find(element1) == ds.find(element3)); + CHECK(ds.find(element2) == ds.find(element4)); + CHECK(ds.find(element1) == ds.find(element2)); + CHECK(ds.find(element1) == ds.find(element4)); + } } -} -TEST_CASE_TEMPLATE("DisjointSetMapping", T, int, std::string) { - disjoint_set ds; - ds.m_union(1, 2); - ds.m_union(3, 4); - ds.m_union(1, 4); - ds.m_union(5, 6); + TEST_CASE_TEMPLATE("DisjointSetMapping", T, int, std::string) { + disjoint_set ds; + ds.m_union(1, 2); + ds.m_union(3, 4); + ds.m_union(1, 4); + ds.m_union(5, 6); - std::map, optional, OptionalComparator> - expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; + std::map, optional, OptionalComparator> + expectedMapping = {{1, 4}, {2, 4}, {3, 4}, {4, 4}, {5, 6}, {6, 6}}; - std::map, optional, OptionalComparator> mapping = - ds.get_mapping(); + std::map, optional, OptionalComparator> mapping = + ds.get_mapping(); - for (auto const &kv : mapping) { - CHECK( - *kv.second == - *expectedMapping[kv.first]); // Compare the values inside the optionals + for (auto const &kv : mapping) { + CHECK(*kv.second == *expectedMapping[kv.first]); // Compare the values + // inside the optionals + } } } diff --git a/lib/utils/test/src/test_dot_file.cc b/lib/utils/test/src/test_dot_file.cc index a65265afbd..ed4c32bb1c 100644 --- a/lib/utils/test/src/test_dot_file.cc +++ b/lib/utils/test/src/test_dot_file.cc @@ -2,67 +2,68 @@ #include "utils/dot_file.h" #include -TEST_CASE("DotFile") { - std::ostringstream oss; - DotFile dotFile(oss); - SUBCASE("add_node") { - dotFile.add_node("A", {{"shape", "circle"}, {"label", "Node A"}}); - dotFile.add_node("B", {{"shape", "rectangle"}, {"label", "Node B"}}); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DotFile") { + std::ostringstream oss; + DotFile dotFile(oss); + SUBCASE("add_node") { + dotFile.add_node("A", {{"shape", "circle"}, {"label", "Node A"}}); + dotFile.add_node("B", {{"shape", "rectangle"}, {"label", "Node B"}}); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { node0 [label=Node A,shape=circle]; node1 [label=Node B,shape=rectangle]; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_edge") { - dotFile.add_edge("A", "B"); - dotFile.add_edge("B", "C"); + SUBCASE("add_edge") { + dotFile.add_edge("A", "B"); + dotFile.add_edge("B", "C"); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { node0 -> node1; node1 -> node2; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_record_node") { - RecordFormatter rf; + SUBCASE("add_record_node") { + RecordFormatter rf; - rf << "Field1"; - rf << 42; - rf << "Field2"; - rf << float(3.14); + rf << "Field1"; + rf << 42; + rf << "Field2"; + rf << float(3.14); - dotFile.add_record_node("A", rf); + dotFile.add_record_node("A", rf); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = - R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = + R"EXPECTED_OUTPUT(digraph taskgraph { node0 [label="{ Field1 | 42 | Field2 | 3.140000e+00 }",shape=record]; })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); - } + CHECK(oss.str() == expectedOutput); + } - SUBCASE("add_node_to_subgraph") { - size_t subgraph1 = dotFile.add_subgraph(); - size_t subgraph2 = dotFile.add_subgraph(subgraph1); + SUBCASE("add_node_to_subgraph") { + size_t subgraph1 = dotFile.add_subgraph(); + size_t subgraph2 = dotFile.add_subgraph(subgraph1); - dotFile.add_node_to_subgraph("A", subgraph1); - dotFile.add_node_to_subgraph("B", subgraph2); + dotFile.add_node_to_subgraph("A", subgraph1); + dotFile.add_node_to_subgraph("B", subgraph2); - dotFile.close(); + dotFile.close(); - std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { + std::string expectedOutput = R"EXPECTED_OUTPUT(digraph taskgraph { subgraph cluster_0 { node1; node0; @@ -72,6 +73,7 @@ node1; } })EXPECTED_OUTPUT"; - CHECK(oss.str() == expectedOutput); + CHECK(oss.str() == expectedOutput); + } } } diff --git a/lib/utils/test/src/test_format.cc b/lib/utils/test/src/test_format.cc index 2f653c85af..eeed2eae81 100644 --- a/lib/utils/test/src/test_format.cc +++ b/lib/utils/test/src/test_format.cc @@ -7,32 +7,34 @@ std::string formatRecord(RecordFormatter const &formatter) { return oss.str(); } -TEST_CASE("RecordFormatter") { - RecordFormatter formatter; - SUBCASE("Appending string") { - formatter << "Hello"; - formatter << "World"; - CHECK(formatRecord(formatter) == "{ Hello | World }"); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RecordFormatter") { + RecordFormatter formatter; + SUBCASE("Appending string") { + formatter << "Hello"; + formatter << "World"; + CHECK(formatRecord(formatter) == "{ Hello | World }"); + } - SUBCASE("Appending integer and float") { - formatter << 42; - formatter << 3.14f; - CHECK(formatRecord(formatter) == "{ 42 | 3.140000e+00 }"); - } + SUBCASE("Appending integer and float") { + formatter << 42; + formatter << 3.14f; + CHECK(formatRecord(formatter) == "{ 42 | 3.140000e+00 }"); + } - SUBCASE("Appending another RecordFormatter") { - RecordFormatter subFormatter; - subFormatter << "Sub"; - subFormatter << "Formatter"; + SUBCASE("Appending another RecordFormatter") { + RecordFormatter subFormatter; + subFormatter << "Sub"; + subFormatter << "Formatter"; - RecordFormatter formatter; - formatter << "Hello"; - formatter << subFormatter; + RecordFormatter formatter; + formatter << "Hello"; + formatter << subFormatter; - std::ostringstream oss; - oss << formatter; + std::ostringstream oss; + oss << formatter; - CHECK(formatRecord(formatter) == "{ Hello | { Sub | Formatter } }"); + CHECK(formatRecord(formatter) == "{ Hello | { Sub | Formatter } }"); + } } } diff --git a/lib/utils/test/src/test_hash.cc b/lib/utils/test/src/test_hash.cc new file mode 100644 index 0000000000..b38c43fe30 --- /dev/null +++ b/lib/utils/test/src/test_hash.cc @@ -0,0 +1,20 @@ +#include "test/utils/doctest.h" +#include "utils/hash-utils.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("hash:unordered_map") { + std::unordered_map map1{{1, 2}}; + std::unordered_map map2{{1, 2}, {3, 4}}; + + size_t hash1 = get_std_hash(map1); + size_t hash2 = get_std_hash(map2); + + CHECK(hash1 != hash2); + + map1.insert({1, 2}); + hash1 = get_std_hash(map1); + CHECK(hash1 == hash2); + } +} diff --git a/lib/utils/test/src/test_multidigraph.cc b/lib/utils/test/src/test_multidigraph.cc index 944ff0b7ca..90e1bb2187 100644 --- a/lib/utils/test/src/test_multidigraph.cc +++ b/lib/utils/test/src/test_multidigraph.cc @@ -5,86 +5,90 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { - MultiDiGraph g = MultiDiGraph::create(); - - std::vector n = repeat(3, [&] { return g.add_node(); }); - std::vector p = repeat(3, [&] { return g.add_node_port(); }); - - std::vector e = {{n[1], p[1], n[0], p[0]}, - {n[2], p[2], n[0], p[0]}, - {n[0], p[0], n[2], p[2]}, - {n[1], p[1], n[2], p[2]}}; - for (MultiDiEdge const &edge : e) { - g.add_edge(edge); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("MultiDiGraph implementations", T, AdjacencyMultiDiGraph) { + MultiDiGraph g = MultiDiGraph::create(); + + std::vector n = repeat(3, [&] { return g.add_node(); }); + std::vector p = repeat(3, [&] { return g.add_node_port(); }); - CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[0], n[1], n[2]}); - - CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == - std::unordered_set{n[0], n[2]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[0], e[1], e[2], e[3]}); - - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == - std::unordered_set{e[0], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( - {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( - {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs(query_set( - {p[1], p[2]}))) == std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs(query_set( - {p[0], p[2]}))) == std::unordered_set{e[1], e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all() - .with_src_nodes({n[1]}) - .with_dst_nodes({n[2]}) - .with_src_idxs({p[1]}) - .with_dst_idxs({p[2]})) == - std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == - std::unordered_set{e[1]}); - - SUBCASE("remove node") { - g.remove_node_unsafe(n[0]); + std::vector e = {{n[1], p[1], n[0], p[0]}, + {n[2], p[2], n[0], p[0]}, + {n[0], p[0], n[2], p[2]}, + {n[1], p[1], n[2], p[2]}}; + for (MultiDiEdge const &edge : e) { + g.add_edge(edge); + } CHECK(g.query_nodes(NodeQuery::all()) == - std::unordered_set{n[1], n[2]}); + std::unordered_set{n[0], n[1], n[2]}); + + CHECK(g.query_nodes(NodeQuery{query_set{{n[0], n[2]}}}) == + std::unordered_set{n[0], n[2]}); CHECK(g.query_edges(MultiDiEdgeQuery::all()) == - std::unordered_set{e[2], e[3]}); + std::unordered_set{e[0], e[1], e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[1]})) == std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[1]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[1]})) == + std::unordered_set{e[0], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes(query_set( + {n[1], n[2]}))) == std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes(query_set( + {n[0], n[2]}))) == std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs( + query_set({p[1], p[2]}))) == + std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs( + query_set({p[0], p[2]}))) == + std::unordered_set{e[1], e[2]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all() + .with_src_nodes({n[1]}) + .with_dst_nodes({n[2]}) + .with_src_idxs({p[1]}) + .with_dst_idxs({p[2]})) == + std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[2]})) == + std::unordered_set{e[1]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == - std::unordered_set{e[2]}); + SUBCASE("remove node") { + g.remove_node_unsafe(n[0]); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == - std::unordered_set{e[2]}); - } + CHECK(g.query_nodes(NodeQuery::all()) == + std::unordered_set{n[1], n[2]}); - SUBCASE("remove_edge") { - g.remove_edge(e[0]); + CHECK(g.query_edges(MultiDiEdgeQuery::all()) == + std::unordered_set{e[2], e[3]}); - CHECK(g.query_edges( - MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( - {n[1]})) == std::unordered_set{}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_nodes({n[0]})) == + std::unordered_set{}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == - std::unordered_set{e[1]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[0]})) == + std::unordered_set{e[2]}); - CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == - std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_idxs({p[0]})) == + std::unordered_set{e[2]}); + } + + SUBCASE("remove_edge") { + g.remove_edge(e[0]); + + CHECK(g.query_edges( + MultiDiEdgeQuery::all().with_src_nodes({n[0]}).with_dst_nodes( + {n[1]})) == std::unordered_set{}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_dst_nodes({n[2]})) == + std::unordered_set{e[1]}); + + CHECK(g.query_edges(MultiDiEdgeQuery::all().with_src_idxs({p[2]})) == + std::unordered_set{e[2], e[3]}); + } } } diff --git a/lib/utils/test/src/test_random_utils.cc b/lib/utils/test/src/test_random_utils.cc index dd7c320d85..88a566a198 100644 --- a/lib/utils/test/src/test_random_utils.cc +++ b/lib/utils/test/src/test_random_utils.cc @@ -14,52 +14,54 @@ void checkProbabilities(std::vector const &counts, } } -TEST_CASE("select_random") { - std::vector values = {1, 2, 3, 4, 5}; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("select_random") { + std::vector values = {1, 2, 3, 4, 5}; - SUBCASE("Select random value") { - int result = select_random(values); + SUBCASE("Select random value") { + int result = select_random(values); - CHECK(std::find(values.begin(), values.end(), result) != values.end()); - } + CHECK(std::find(values.begin(), values.end(), result) != values.end()); + } - SUBCASE("Invalid arguments") { - std::vector weights = {0.1f, 0.3f, 0.2f}; - CHECK(select_random(values, weights) == 2); + SUBCASE("Invalid arguments") { + std::vector weights = {0.1f, 0.3f, 0.2f}; + CHECK(select_random(values, weights) == 2); + } } -} -TEST_CASE("select_random - Weighted Random Selection") { - SUBCASE("Test with equal weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + TEST_CASE("select_random - Weighted Random Selection") { + SUBCASE("Test with equal weights") { + std::vector values = {1, 2, 3, 4, 5}; + std::vector weights = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; + std::vector counts(values.size(), 0); + int const numIterations = 10000; + for (int i = 0; i < numIterations; i++) { + int selected = select_random(values, weights); + counts[selected - 1]++; + } + + checkProbabilities(counts, numIterations, weights, values.size()); } - checkProbabilities(counts, numIterations, weights, values.size()); - } + SUBCASE("Test with different weights") { + std::vector values = {1, 2, 3, 4, 5}; + std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; - SUBCASE("Test with different weights") { - std::vector values = {1, 2, 3, 4, 5}; - std::vector weights = {0.1f, 0.2f, 0.3f, 0.2f, 0.2f}; + std::vector counts(values.size(), 0); + int const numIterations = 10000; + for (int i = 0; i < numIterations; i++) { + int selected = select_random(values, weights); + counts[selected - 1]++; + } - std::vector counts(values.size(), 0); - int const numIterations = 10000; - for (int i = 0; i < numIterations; i++) { - int selected = select_random(values, weights); - counts[selected - 1]++; - } + float totalWeight = 0.0f; + for (float weight : weights) { + totalWeight += weight; + } - float totalWeight = 0.0f; - for (float weight : weights) { - totalWeight += weight; + checkProbabilities(counts, numIterations, weights, totalWeight); } - - checkProbabilities(counts, numIterations, weights, totalWeight); } } diff --git a/lib/utils/test/src/test_sequence.cc b/lib/utils/test/src/test_sequence.cc index 576271a858..ee72febe05 100644 --- a/lib/utils/test/src/test_sequence.cc +++ b/lib/utils/test/src/test_sequence.cc @@ -3,169 +3,171 @@ using namespace FlexFlow; -TEST_CASE("seq_head") { - SUBCASE("seq_head with non-empty sequence") { - using Seq = seq<1, 2, 3, 4>; - constexpr int result = seq_head::value; - CHECK(result == 1); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("seq_head") { + SUBCASE("seq_head with non-empty sequence") { + using Seq = seq<1, 2, 3, 4>; + constexpr int result = seq_head::value; + CHECK(result == 1); + } + + SUBCASE("seq_head with empty sequence") { + using Seq = seq<>; + constexpr int result = seq_head::value; + CHECK(result == -1); + } } - SUBCASE("seq_head with empty sequence") { - using Seq = seq<>; - constexpr int result = seq_head::value; - CHECK(result == -1); + TEST_CASE("seq_tail") { + SUBCASE("seq_tail with non-empty sequence") { + using Seq = seq<1, 2, 3, 4>; + using ResultType = typename seq_tail::type; + using ExpectedType = seq<2, 3, 4>; + CHECK(std::is_same::value); + } + + SUBCASE("seq_tail with empty sequence") { + using Seq = seq<>; + using ResultType = typename seq_tail::type; + using ExpectedType = seq<>; + CHECK(std::is_same::value); + } } -} -TEST_CASE("seq_tail") { - SUBCASE("seq_tail with non-empty sequence") { - using Seq = seq<1, 2, 3, 4>; - using ResultType = typename seq_tail::type; - using ExpectedType = seq<2, 3, 4>; + TEST_CASE("seq_prepend") { + using ResultType = typename FlexFlow::seq_prepend<1, 2, 3>::type; + using ExpectedType = FlexFlow::seq<1, 2, 3>; CHECK(std::is_same::value); } - SUBCASE("seq_tail with empty sequence") { - using Seq = seq<>; - using ResultType = typename seq_tail::type; - using ExpectedType = seq<>; + TEST_CASE("seq_append") { + using Seq = seq<1, 2, 3>; + using ResultType = typename seq_append::type; + using ExpectedType = seq<1, 2, 3, 4>; CHECK(std::is_same::value); } -} -TEST_CASE("seq_prepend") { - using ResultType = typename FlexFlow::seq_prepend<1, 2, 3>::type; - using ExpectedType = FlexFlow::seq<1, 2, 3>; - CHECK(std::is_same::value); -} - -TEST_CASE("seq_append") { - using Seq = seq<1, 2, 3>; - using ResultType = typename seq_append::type; - using ExpectedType = seq<1, 2, 3, 4>; - CHECK(std::is_same::value); -} + TEST_CASE("seq_count") { + using ResultType = seq_count_t<5>; + using ExpectedType = seq<1, 2, 3, 4, 5>; + CHECK(!std::is_same::value); + } -TEST_CASE("seq_count") { - using ResultType = seq_count_t<5>; - using ExpectedType = seq<1, 2, 3, 4, 5>; - CHECK(!std::is_same::value); -} + TEST_CASE("seq_enumerate_args") { + using Args = std::tuple; + using ResultType = seq_enumerate_args_t; + using ExpectedType = seq<0, 1, 2>; + CHECK(std::is_same::value); + } -TEST_CASE("seq_enumerate_args") { - using Args = std::tuple; - using ResultType = seq_enumerate_args_t; - using ExpectedType = seq<0, 1, 2>; - CHECK(std::is_same::value); + // template + // int square(std::integral_constant) { + // return X * X; + // } + + // TEST_CASE("seq_select") { + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_select(square, 1, seq<1, 2, 3>); + // CHECK(result == 4); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_select(square, 3, Seq{}), std::runtime_error); + // } + // } + + // TEST_CASE("seq_get") { + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_get(square, 2, Seq{}); + // CHECK(result == 9); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_get(square, 3, Seq{}), std::runtime_error); + // } + // } + + // TEST_CASE("seq_get") { + // struct F { + // template + // int operator()(std::integral_constant) const { + // return X * X; + // } + // }; + + // SUBCASE("Valid index") { + // using Seq = seq<1, 2, 3>; + // int result = seq_get(F{}, 2, Seq{}); + // CHECK(result == 9); + // } + + // SUBCASE("Invalid index") { + // using Seq = seq<1, 2, 3>; + // CHECK_THROWS_AS(seq_get(F{}, 3, Seq{}), std::runtime_error); + // } + // } + + // struct F { + // template + // struct type { + // using result = std::integral_constant; + // }; + // }; + + // TEST_CASE("seq_transform_type") { + // using Seq = seq<1, 2, 3>; + // using ResultType = seq_transform_type_t; + // using ExpectedType = std::tuple, + // std::integral_constant, + // std::integral_constant>; + // CHECK(std::is_same::value); + // } + + // TEST_CASE("seq_transform") { + // struct F { + // template + // int operator()(std::integral_constant) { + // return X * X; + // } + // }; + + // using Seq = seq<1, 2, 3>; + // auto result = seq_transform(F{}, Seq{}); + // std::tuple expected{1, 4, 9}; + // CHECK(result == expected); + // } + + // TEST_CASE("seq_select") { + // struct F { + // template + // tl::optional operator()(std::integral_constant) { + // if (X % 2 == 0) { + // return X; + // } else { + // return tl::nullopt; + // } + // } + // }; + + // using Seq = seq<1, 2, 3, 4, 5>; + // int result = seq_select(F{}, Seq{}); + // CHECK(result == 2); + // } + + // TEST_CASE("seq_get") { + // struct F { + // template + // int operator()(std::integral_constant) { + // return X * X; + // } + // }; + + // using Seq = seq<1, 2, 3, 4, 5>; + // int result = seq_get(F{}, 3, Seq{}); + // CHECK(result == 16); + // } } - -// template -// int square(std::integral_constant) { -// return X * X; -// } - -// TEST_CASE("seq_select") { -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_select(square, 1, seq<1, 2, 3>); -// CHECK(result == 4); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_select(square, 3, Seq{}), std::runtime_error); -// } -// } - -// TEST_CASE("seq_get") { -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_get(square, 2, Seq{}); -// CHECK(result == 9); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_get(square, 3, Seq{}), std::runtime_error); -// } -// } - -// TEST_CASE("seq_get") { -// struct F { -// template -// int operator()(std::integral_constant) const { -// return X * X; -// } -// }; - -// SUBCASE("Valid index") { -// using Seq = seq<1, 2, 3>; -// int result = seq_get(F{}, 2, Seq{}); -// CHECK(result == 9); -// } - -// SUBCASE("Invalid index") { -// using Seq = seq<1, 2, 3>; -// CHECK_THROWS_AS(seq_get(F{}, 3, Seq{}), std::runtime_error); -// } -// } - -// struct F { -// template -// struct type { -// using result = std::integral_constant; -// }; -// }; - -// TEST_CASE("seq_transform_type") { -// using Seq = seq<1, 2, 3>; -// using ResultType = seq_transform_type_t; -// using ExpectedType = std::tuple, -// std::integral_constant, -// std::integral_constant>; -// CHECK(std::is_same::value); -// } - -// TEST_CASE("seq_transform") { -// struct F { -// template -// int operator()(std::integral_constant) { -// return X * X; -// } -// }; - -// using Seq = seq<1, 2, 3>; -// auto result = seq_transform(F{}, Seq{}); -// std::tuple expected{1, 4, 9}; -// CHECK(result == expected); -// } - -// TEST_CASE("seq_select") { -// struct F { -// template -// tl::optional operator()(std::integral_constant) { -// if (X % 2 == 0) { -// return X; -// } else { -// return tl::nullopt; -// } -// } -// }; - -// using Seq = seq<1, 2, 3, 4, 5>; -// int result = seq_select(F{}, Seq{}); -// CHECK(result == 2); -// } - -// TEST_CASE("seq_get") { -// struct F { -// template -// int operator()(std::integral_constant) { -// return X * X; -// } -// }; - -// using Seq = seq<1, 2, 3, 4, 5>; -// int result = seq_get(F{}, 3, Seq{}); -// CHECK(result == 16); -// } diff --git a/lib/utils/test/src/test_stack_map.cc b/lib/utils/test/src/test_stack_map.cc index 11d332afa4..21c1b07d1b 100644 --- a/lib/utils/test/src/test_stack_map.cc +++ b/lib/utils/test/src/test_stack_map.cc @@ -3,48 +3,50 @@ using namespace FlexFlow; -TEST_CASE("stack_map") { - stack_map map; - // Test the [] operator to insert and access elements - SUBCASE("BracketOperator") { - map[1] = 10; - map[2] = 20; - - CHECK(map[1] == 10); - CHECK(map[2] == 20); - } - - // Test the insert() function - SUBCASE("Insert") { - map.insert(1, 10); - map.insert(2, 20); - - CHECK(map[1] == 10); - CHECK(map[2] == 20); - } - - // Test the at() function to access elements - SUBCASE("At") { - map[1] = 10; - map[2] = 20; - - CHECK(map.at(1) == 10); - CHECK(map.at(2) == 20); - CHECK(map.at(1) != 20); - // Test const version of at() function - stack_map const &const_map = map; - CHECK(const_map.at(1) == 10); - CHECK(const_map.at(2) == 20); - } - - // Test the begin() and end() functions for iterator - SUBCASE("Iterator") { - map[1] = 10; - map[2] = 20; - map[3] = 30; - - std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; - std::vector> actual = map; - CHECK(actual == expected); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("stack_map") { + stack_map map; + // Test the [] operator to insert and access elements + SUBCASE("BracketOperator") { + map[1] = 10; + map[2] = 20; + + CHECK(map[1] == 10); + CHECK(map[2] == 20); + } + + // Test the insert() function + SUBCASE("Insert") { + map.insert(1, 10); + map.insert(2, 20); + + CHECK(map[1] == 10); + CHECK(map[2] == 20); + } + + // Test the at() function to access elements + SUBCASE("At") { + map[1] = 10; + map[2] = 20; + + CHECK(map.at(1) == 10); + CHECK(map.at(2) == 20); + CHECK(map.at(1) != 20); + // Test const version of at() function + stack_map const &const_map = map; + CHECK(const_map.at(1) == 10); + CHECK(const_map.at(2) == 20); + } + + // Test the begin() and end() functions for iterator + SUBCASE("Iterator") { + map[1] = 10; + map[2] = 20; + map[3] = 30; + + std::vector> expected = {{1, 10}, {2, 20}, {3, 30}}; + std::vector> actual = map; + CHECK(actual == expected); + } } } diff --git a/lib/utils/test/src/test_stack_string.cc b/lib/utils/test/src/test_stack_string.cc index 700b7d6a0f..1836e0824a 100644 --- a/lib/utils/test/src/test_stack_string.cc +++ b/lib/utils/test/src/test_stack_string.cc @@ -3,79 +3,81 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("StackStringConstruction", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("StackStringConstruction", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - SUBCASE("DefaultConstruction") { - StackString str; - CHECK(str.size() == 0); - CHECK(str.length() == 0); - CHECK(static_cast(str) == ""); - } + SUBCASE("DefaultConstruction") { + StackString str; + CHECK(str.size() == 0); + CHECK(str.length() == 0); + CHECK(static_cast(str) == ""); + } - SUBCASE("CStringConstruction") { - char const *cstr = "Hello"; - StackString str(cstr); - CHECK(str.size() == 5); - CHECK(str.length() == 5); - CHECK(static_cast(str) == "Hello"); - } + SUBCASE("CStringConstruction") { + char const *cstr = "Hello"; + StackString str(cstr); + CHECK(str.size() == 5); + CHECK(str.length() == 5); + CHECK(static_cast(str) == "Hello"); + } - SUBCASE("ShortCStringConstruction") { - char const *cstr = "CMU"; - StackString str(cstr); - CHECK(str.size() == 3); - CHECK(str.length() == 3); - CHECK(static_cast(str) == "CMU"); - } + SUBCASE("ShortCStringConstruction") { + char const *cstr = "CMU"; + StackString str(cstr); + CHECK(str.size() == 3); + CHECK(str.length() == 3); + CHECK(static_cast(str) == "CMU"); + } - SUBCASE("StdStringConstruction") { - std::basic_string stdStr = "World"; - StackString str(stdStr); - CHECK(str.size() == 5); - CHECK(str.length() == 5); - CHECK(static_cast(str) == "World"); + SUBCASE("StdStringConstruction") { + std::basic_string stdStr = "World"; + StackString str(stdStr); + CHECK(str.size() == 5); + CHECK(str.length() == 5); + CHECK(static_cast(str) == "World"); + } } -} -TEST_CASE_TEMPLATE("StackStringComparison", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringComparison", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - StackString str1{"abc"}; - StackString str2{"def"}; - StackString str3{"abc"}; + StackString str1{"abc"}; + StackString str2{"def"}; + StackString str3{"abc"}; - CHECK(str1 == str1); - CHECK(str1 == str3); - CHECK(str1 != str2); - CHECK(str2 != str3); - CHECK(str1 < str2); -} + CHECK(str1 == str1); + CHECK(str1 == str3); + CHECK(str1 != str2); + CHECK(str2 != str3); + CHECK(str1 < str2); + } -TEST_CASE_TEMPLATE("StackStringSize", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringSize", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - SUBCASE("EmptyString") { - StackString str; - CHECK(str.size() == 0); - CHECK(str.length() == 0); - } + SUBCASE("EmptyString") { + StackString str; + CHECK(str.size() == 0); + CHECK(str.length() == 0); + } - SUBCASE("NonEmptyString") { - StackString str{"Hello"}; - CHECK(str.size() == 5); - CHECK(str.length() == 5); + SUBCASE("NonEmptyString") { + StackString str{"Hello"}; + CHECK(str.size() == 5); + CHECK(str.length() == 5); + } } -} -TEST_CASE_TEMPLATE("StackStringConversion", T, char) { - constexpr std::size_t MAXSIZE = 5; - using StackString = stack_string; + TEST_CASE_TEMPLATE("StackStringConversion", T, char) { + constexpr std::size_t MAXSIZE = 5; + using StackString = stack_string; - StackString str{"Hello"}; - std::string stdStr = static_cast(str); - CHECK(stdStr == "Hello"); + StackString str{"Hello"}; + std::string stdStr = static_cast(str); + CHECK(stdStr == "Hello"); + } } diff --git a/lib/utils/test/src/test_stack_vector.cc b/lib/utils/test/src/test_stack_vector.cc index 08101527f9..6c0ecf36f3 100644 --- a/lib/utils/test/src/test_stack_vector.cc +++ b/lib/utils/test/src/test_stack_vector.cc @@ -4,74 +4,76 @@ using namespace FlexFlow; -TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - std::vector res = vector; - std::vector expected = {10}; - CHECK(res == expected); - - vector.push_back(20); - expected = {10, 20}; - res = vector; - CHECK(res == expected); -} - -TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - vector.push_back(20); - vector.push_back(30); - - CHECK(vector[0] == 10); - CHECK(vector[1] == 20); - CHECK(vector[2] == 30); -} - -TEST_CASE_TEMPLATE("Size", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - CHECK(vector.size() == 0); - - vector.push_back(10); - CHECK(vector.size() == 1); - - vector.push_back(20); - CHECK(vector.size() == 2); -} - -TEST_CASE_TEMPLATE("==", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector1, vector2; - - vector1.push_back(10); - vector1.push_back(15); - vector1.push_back(20); - - vector2.push_back(10); - vector2.push_back(15); - vector2.push_back(20); - - CHECK(vector1 == vector2); -} - -TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { - constexpr std::size_t MAXSIZE = 5; - using StackVector = stack_vector; - StackVector vector; - - vector.push_back(10); - CHECK(vector.back() == 10); - - vector.push_back(20); - CHECK(vector.back() == 20); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE("PushBack", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + std::vector res = vector; + std::vector expected = {10}; + CHECK(res == expected); + + vector.push_back(20); + expected = {10, 20}; + res = vector; + CHECK(res == expected); + } + + TEST_CASE_TEMPLATE("OperatorIndex", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + vector.push_back(20); + vector.push_back(30); + + CHECK(vector[0] == 10); + CHECK(vector[1] == 20); + CHECK(vector[2] == 30); + } + + TEST_CASE_TEMPLATE("Size", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + CHECK(vector.size() == 0); + + vector.push_back(10); + CHECK(vector.size() == 1); + + vector.push_back(20); + CHECK(vector.size() == 2); + } + + TEST_CASE_TEMPLATE("==", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector1, vector2; + + vector1.push_back(10); + vector1.push_back(15); + vector1.push_back(20); + + vector2.push_back(10); + vector2.push_back(15); + vector2.push_back(20); + + CHECK(vector1 == vector2); + } + + TEST_CASE_TEMPLATE("EmplaceBack", T, int, double, char) { + constexpr std::size_t MAXSIZE = 5; + using StackVector = stack_vector; + StackVector vector; + + vector.push_back(10); + CHECK(vector.back() == 10); + + vector.push_back(20); + CHECK(vector.back() == 20); + } } diff --git a/lib/utils/test/src/test_tuple.cc b/lib/utils/test/src/test_tuple.cc index 344a2cd0fb..31308dec2c 100644 --- a/lib/utils/test/src/test_tuple.cc +++ b/lib/utils/test/src/test_tuple.cc @@ -6,74 +6,76 @@ using namespace FlexFlow; -TEST_CASE("get function") { - std::tuple t(42, 3.14f, 2.71828); - - SUBCASE("get mutable reference") { - int &result = get(t); - CHECK(result == 42); - - result = 100; - CHECK(std::get<0>(t) == 100); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get function") { + std::tuple t(42, 3.14f, 2.71828); + + SUBCASE("get mutable reference") { + int &result = get(t); + CHECK(result == 42); + + result = 100; + CHECK(std::get<0>(t) == 100); + } + + SUBCASE("get rvalue reference") { + int &&result = get(std::move(t)); + CHECK(result == 42); + + // t is in a valid but unspecified state after move + CHECK(std::get<0>(t) == 42); // Uncomment this line to check the behavior + } + + SUBCASE("get const reference") { + int const &result = get(t); + CHECK(result == 42); + } + + SUBCASE("get const rvalue reference") { + int const &&result = get(std::move(t)); + CHECK(result == 42); + } } - SUBCASE("get rvalue reference") { - int &&result = get(std::move(t)); - CHECK(result == 42); + TEST_CASE("tuple_prepend function") { + std::tuple t1(3.14f, 2.71828); + int value = 42; - // t is in a valid but unspecified state after move - CHECK(std::get<0>(t) == 42); // Uncomment this line to check the behavior + auto result = tuple_prepend(value, t1); + std::tuple expected(42, 3.14f, 2.71828); + CHECK(result == expected); } - SUBCASE("get const reference") { - int const &result = get(t); - CHECK(result == 42); + TEST_CASE("Testing tuple_head_t") { + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple<>>::value); } - SUBCASE("get const rvalue reference") { - int const &&result = get(std::move(t)); - CHECK(result == 42); + TEST_CASE("Testing tuple_slice_t") { + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple>::value); + CHECK(std::is_same>, + std::tuple>::value); } -} - -TEST_CASE("tuple_prepend function") { - std::tuple t1(3.14f, 2.71828); - int value = 42; - auto result = tuple_prepend(value, t1); - std::tuple expected(42, 3.14f, 2.71828); - CHECK(result == expected); -} - -TEST_CASE("Testing tuple_head_t") { - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple<>>::value); -} + TEST_CASE("Testing tuple_compare function") { + std::tuple tup1{1, 3.14, 'a'}; + std::tuple tup2{1, 3.14, 'a'}; + std::tuple tup3{2, 3.14, 'b'}; -TEST_CASE("Testing tuple_slice_t") { - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple>::value); - CHECK(std::is_same>, - std::tuple>::value); -} - -TEST_CASE("Testing tuple_compare function") { - std::tuple tup1{1, 3.14, 'a'}; - std::tuple tup2{1, 3.14, 'a'}; - std::tuple tup3{2, 3.14, 'b'}; - - CHECK(tuple_compare(tup1, tup2)); - CHECK(!tuple_compare(tup1, tup3)); -} + CHECK(tuple_compare(tup1, tup2)); + CHECK(!tuple_compare(tup1, tup3)); + } -TEST_CASE("Testing get function with valid index") { - std::tuple tup{1, 3.14, 'a'}; + TEST_CASE("Testing get function with valid index") { + std::tuple tup{1, 3.14, 'a'}; - CHECK(get(tup) == 1); - CHECK(get(tup) == 3.14); - CHECK(get(tup) == 'a'); + CHECK(get(tup) == 1); + CHECK(get(tup) == 3.14); + CHECK(get(tup) == 'a'); + } } diff --git a/lib/utils/test/src/test_type_index.cc b/lib/utils/test/src/test_type_index.cc index 1b9a811846..b2d8aea848 100644 --- a/lib/utils/test/src/test_type_index.cc +++ b/lib/utils/test/src/test_type_index.cc @@ -4,30 +4,32 @@ using namespace FlexFlow; -TEST_CASE("type_index function") { - SUBCASE("int type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(int); - CHECK(idx == expected_idx); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("type_index function") { + SUBCASE("int type") { + std::type_index idx = type_index(); + std::type_index expected_idx = typeid(int); + CHECK(idx == expected_idx); + } - SUBCASE("string type") { - std::type_index idx = type_index(); - std::type_index expected_idx = typeid(std::string); - CHECK(idx == expected_idx); + SUBCASE("string type") { + std::type_index idx = type_index(); + std::type_index expected_idx = typeid(std::string); + CHECK(idx == expected_idx); + } } -} -TEST_CASE("matches function") { - std::type_index idx = typeid(float); + TEST_CASE("matches function") { + std::type_index idx = typeid(float); - SUBCASE("matching type") { - bool result = matches(idx); - CHECK(result == true); - } + SUBCASE("matching type") { + bool result = matches(idx); + CHECK(result == true); + } - SUBCASE("non-matching type") { - bool result = matches(idx); - CHECK(result == false); + SUBCASE("non-matching type") { + bool result = matches(idx); + CHECK(result == false); + } } } diff --git a/lib/utils/test/src/test_undirected_graph.cc b/lib/utils/test/src/test_undirected_graph.cc index c6f2003ee4..3616ee59aa 100644 --- a/lib/utils/test/src/test_undirected_graph.cc +++ b/lib/utils/test/src/test_undirected_graph.cc @@ -31,30 +31,32 @@ using namespace rc; /* static_assert(is_streamable::value, ""); */ /* static_assert(is_fmtable::value, ""); */ -TEST_CASE_TEMPLATE("UndirectedGraph implementations", - T, - HashmapUndirectedGraph) { - - rc::dc_check("Full", [&]() { - UndirectedGraph g = UndirectedGraph::create(); - int num_nodes = *gen::inRange(1, 10); - std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); - int num_edges = *gen::inRange(0, num_nodes); - std::vector e; - if (num_nodes > 0) { - e = *gen::unique>( - num_edges, - gen::construct(gen::elementOf(n), gen::elementOf(n))); - } - for (UndirectedEdge const &edge : e) { - g.add_edge(edge); - } - - CHECK(g.query_nodes(NodeQuery::all()) == without_order(n)); - - auto subset = *rc::subset_of(n); - CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); - - CHECK(g.query_edges(UndirectedEdgeQuery::all()) == without_order(e)); - }); +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE_TEMPLATE( + "UndirectedGraph implementations", T, HashmapUndirectedGraph) { + + rc::dc_check("Full", [&]() { + UndirectedGraph g = UndirectedGraph::create(); + int num_nodes = *gen::inRange(1, 10); + std::vector n = repeat(num_nodes, [&] { return g.add_node(); }); + int num_edges = *gen::inRange(0, num_nodes); + std::vector e; + if (num_nodes > 0) { + e = *gen::unique>( + num_edges, + gen::construct(gen::elementOf(n), + gen::elementOf(n))); + } + for (UndirectedEdge const &edge : e) { + g.add_edge(edge); + } + + CHECK(g.query_nodes(NodeQuery::all()) == without_order(n)); + + auto subset = *rc::subset_of(n); + CHECK(g.query_nodes(NodeQuery{query_set{subset}}) == subset); + + CHECK(g.query_edges(UndirectedEdgeQuery::all()) == without_order(e)); + }); + } } diff --git a/lib/utils/test/src/test_variant.cc b/lib/utils/test/src/test_variant.cc index 031defd417..0fef782c0e 100644 --- a/lib/utils/test/src/test_variant.cc +++ b/lib/utils/test/src/test_variant.cc @@ -1,64 +1,72 @@ #include "test/utils/doctest.h" #include "utils/variant.h" -TEST_CASE("widen and narrow functions") { - SUBCASE("widen function") { - variant v1 = 42; - variant result = widen>(v1); - variant expected = 42; - CHECK(result == expected); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("widen and narrow functions") { + SUBCASE("widen function") { + std::variant v1 = 42; + std::variant result = + widen>(v1); + std::variant expected = 42; + CHECK(result == expected); + } - SUBCASE("narrow function fail") { - variant v2 = - 3.14; // this is a doule, because 3.14 default to double - optional> result = narrow>(v2); - optional> expected = float(3.14); - CHECK(!result.has_value()); // result should be empty due to narrowing - } + SUBCASE("narrow function fail") { + std::variant v2 = + 3.14; // this is a doule, because 3.14 default to double + std::optional> result = + narrow>(v2); + std::optional> expected = float(3.14); + CHECK(!result.has_value()); // result should be empty due to narrowing + } - SUBCASE("narrow function success") { - variant v2 = - 3.14; // this is a doule, because 3.14 default to double - optional> result = narrow>(v2); - optional> expected = 3.14; - CHECK(result == expected); // - } + SUBCASE("narrow function success") { + std::variant v2 = + 3.14; // this is a doule, because 3.14 default to double + std::optional> result = + narrow>(v2); + std::optional> expected = 3.14; + CHECK(result == expected); // + } - SUBCASE("cast function") { - variant v3 = 42; - optional> result = cast>(v3); - optional> expected = 42; - CHECK(result == expected); + SUBCASE("cast function") { + std::variant v3 = 42; + std::optional> result = + cast>(v3); + std::optional> expected = 42; + CHECK(result == expected); + } } -} -TEST_CASE("Narrow and cast variants") { - variant original_variant = 42; + TEST_CASE("Narrow and cast variants") { + std::variant original_variant = 42; - // narrow - optional> narrow_result = - narrow>(original_variant); - CHECK(narrow_result.has_value()); // assert narrow has value + // narrow + std::optional> narrow_result = + narrow>(original_variant); + CHECK(narrow_result.has_value()); // assert narrow has value - // cast - optional> cast_result = - cast>(narrow_result.value()); - CHECK(cast_result.has_value()); // assert cast has value - CHECK(get(cast_result.value()) == 42); -} + // cast + std::optional> cast_result = + cast>(narrow_result.value()); + CHECK(cast_result.has_value()); // assert cast has value + CHECK(get(cast_result.value()) == 42); + } -TEST_CASE("casting and widening a variant") { - variant smaller_variant = 42; - variant wider_variant; + TEST_CASE("casting and widening a variant") { + std::variant smaller_variant = 42; + std::variant wider_variant; - // Perform the cast operation - optional> cast_result = cast>(smaller_variant); - REQUIRE(cast_result); // Ensure the cast was successful + // Perform the cast operation + std::optional> cast_result = + cast>(smaller_variant); + REQUIRE(cast_result); // Ensure the cast was successful - // Perform the widening operation - wider_variant = widen>(cast_result.value()); + // Perform the widening operation + wider_variant = + widen>(cast_result.value()); - // Check the result - CHECK(get(wider_variant) == 42); + // Check the result + CHECK(get(wider_variant) == 42); + } } diff --git a/lib/utils/test/src/test_vector.cc b/lib/utils/test/src/test_vector.cc index 5eba16c312..4bdc724dd8 100644 --- a/lib/utils/test/src/test_vector.cc +++ b/lib/utils/test/src/test_vector.cc @@ -1,29 +1,31 @@ #include "test/utils/doctest.h" #include "utils/vector.h" -TEST_CASE("concat function") { - SUBCASE("concatenates two vectors") { - std::vector v1 = {1, 2, 3}; - std::vector v2 = {4, 5, 6}; - std::vector result = concat(v1, v2); - std::vector expected = {1, 2, 3, 4, 5, 6}; - CHECK(result == expected); - } +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("concat function") { + SUBCASE("concatenates two vectors") { + std::vector v1 = {1, 2, 3}; + std::vector v2 = {4, 5, 6}; + std::vector result = concat(v1, v2); + std::vector expected = {1, 2, 3, 4, 5, 6}; + CHECK(result == expected); + } - SUBCASE("concatenates two string vectors") { - std::vector v1 = {"1", "2", "3"}; - std::vector v2 = {"4", "5", "6"}; - std::vector result = concat(v1, v2); - std::vector expected = {"1", "2", "3", "4", "5", "6"}; - CHECK(result == expected); - } + SUBCASE("concatenates two string vectors") { + std::vector v1 = {"1", "2", "3"}; + std::vector v2 = {"4", "5", "6"}; + std::vector result = concat(v1, v2); + std::vector expected = {"1", "2", "3", "4", "5", "6"}; + CHECK(result == expected); + } - SUBCASE("concatenates multiple vectors") { - std::vector v1 = {1, 2, 3}; - std::vector v2 = {4, 5, 6}; - std::vector v3 = {7, 8, 9}; - std::vector result = concat(v1, v2, v3); - std::vector expected = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - CHECK(result == expected); + SUBCASE("concatenates multiple vectors") { + std::vector v1 = {1, 2, 3}; + std::vector v2 = {4, 5, 6}; + std::vector v3 = {7, 8, 9}; + std::vector result = concat(v1, v2, v3); + std::vector expected = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + CHECK(result == expected); + } } }