diff --git a/.flake/pkgs/ffdb/default.nix b/.flake/pkgs/ffdb/default.nix new file mode 100644 index 0000000000..8e3989372a --- /dev/null +++ b/.flake/pkgs/ffdb/default.nix @@ -0,0 +1,40 @@ +{ lib +, stdenv +, makeWrapper +, gdb +, python3 +, proj +}: + +stdenv.mkDerivation rec { + pname = "ffdb"; + version = "0.1"; + + pythonPath = with python3.pkgs; makePythonPath [ + proj + ]; + + dontBuild = true; + + nativeBuildInputs = [ makeWrapper ]; + + src = ./.; + + installPhase = '' + mkdir -p $out/share/ffdb + cp ffdb.py $out/share/ffdb + makeWrapper ${gdb}/bin/gdb $out/bin/gdb \ + --add-flags "-q -x $out/share/ffdb/ffdb.py" \ + --set NIX_PYTHONPATH ${pythonPath} \ + --prefix PATH : ${lib.makeBinPath [ + python3 + ]} + cp $out/bin/gdb $out/bin/ffdb + ''; + + nativeCheckInputs = [ + gdb + python3 + proj + ]; +} diff --git a/.flake/pkgs/ffdb/ffdb.py b/.flake/pkgs/ffdb/ffdb.py new file mode 100644 index 0000000000..84354ccd82 --- /dev/null +++ b/.flake/pkgs/ffdb/ffdb.py @@ -0,0 +1,7 @@ +from proj.config_file import get_config_root +from pathlib import Path +import gdb + +gdb.execute(f'directory {get_config_root(Path.cwd())}') +gdb.prompt_hook = lambda x: '(ffdb) ' +gdb.execute('set history save on') diff --git a/.github/runs-on.yml b/.github/runs-on.yml new file mode 100644 index 0000000000..14f75549dd --- /dev/null +++ b/.github/runs-on.yml @@ -0,0 +1,11 @@ +images: + dlami-x64: + platform: "linux" + arch: "x64" + owner: "898082745236" # AWS + name: "Deep Learning Base OSS Nvidia Driver GPU AMI (Ubuntu 22.04)*" + +runners: + gpu-nvidia: + family: ["g4dn.xlarge"] + image: dlami-x64 diff --git a/.gitignore b/.gitignore index 397ac0974d..4b40a016df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# gdb history +.gdb_history + # dtgen files *.dtg.cc *.dtg.h diff --git a/README.md b/README.md index 8a4d852245..216b32fd52 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,21 @@ -# FlexFlow 
-![build](https://github.com/flexflow/flexflow/workflows/build/badge.svg?branch=master) ![gpu tests](https://github.com/flexflow/flexflow/workflows/gpu-ci/badge.svg?branch=master) ![multinode gpu tests](https://github.com/flexflow/flexflow/workflows/multinode-test/badge.svg?branch=master) ![docker](https://github.com/flexflow/flexflow/workflows/docker-build/badge.svg?branch=master) ![pip](https://github.com/flexflow/flexflow/workflows/pip-install/badge.svg?branch=master) ![shell-check](https://github.com/flexflow/flexflow/workflows/Shell%20Check/badge.svg?branch=master) ![clang-format](https://github.com/flexflow/flexflow/workflows/clang-format%20Check/badge.svg?branch=master) [![Documentation Status](https://readthedocs.org/projects/flexflow/badge/?version=latest)](https://flexflow.readthedocs.io/en/latest/?badge=latest) +# flexflow-train +[![clang-format Check](https://github.com/flexflow/flexflow-train/actions/workflows/clang-format-check.yml/badge.svg?branch=master)](https://github.com/flexflow/flexflow-train/actions/workflows/clang-format-check.yml) +[![per-lib-checks](https://github.com/flexflow/flexflow-train/actions/workflows/per-lib-check.yml/badge.svg)](https://github.com/flexflow/flexflow-train/actions/workflows/per-lib-check.yml) +[![shell-check](https://github.com/flexflow/flexflow-train/actions/workflows/shell-check.yml/badge.svg)](https://github.com/flexflow/flexflow-train/actions/workflows/shell-check.yml) +[![Documentation Status](https://readthedocs.org/projects/flexflow/badge/?version=latest)](https://flexflow.readthedocs.io/en/latest/?badge=latest) -FlexFlow is a deep learning framework that accelerates distributed DNN training by automatically searching for efficient parallelization strategies. FlexFlow provides a drop-in replacement for PyTorch and TensorFlow Keras. Running existing PyTorch and Keras programs in FlexFlow only requires [a few lines of changes to the program](https://flexflow.ai/keras). 
+> [!WARNING] +> The FlexFlow repository has been split into separate [flexflow-train](https://github.com/flexflow/flexflow-train) and [flexflow-serve](https://github.com/flexflow/flexflow-serve) repositories. +> You are currently viewing [flexflow-train](https://github.com/flexflow/flexflow-train). +> For anything inference/serving-related, go to [flexflow-serve](https://github.com/flexflow/flexflow-serve). +FlexFlow is a deep learning framework that accelerates distributed DNN training by automatically searching for efficient parallelization strategies. + + + + ## Contributing -Please let us know if you encounter any bugs or have any suggestions by [submitting an issue](https://github.com/flexflow/flexflow/issues). +Please let us know if you encounter any bugs or have any suggestions by [submitting an issue](https://github.com/flexflow/flexflow-train/issues). We welcome all contributions to FlexFlow from bug fixes to new features and extensions. diff --git a/cmake/flexflow-utils.cmake b/cmake/flexflow-utils.cmake index 90e100bb1b..7ba39e92c9 100644 --- a/cmake/flexflow-utils.cmake +++ b/cmake/flexflow-utils.cmake @@ -39,7 +39,7 @@ function(ff_set_cxx_properties target) CXX_EXTENSIONS NO ) target_compile_options(${target} - PRIVATE $<$:> # add C++ compile flags here + PRIVATE $<$:> "-ffile-prefix-map=${CMAKE_SOURCE_DIR}=." 
# add C++ compile flags here ) endfunction() diff --git a/flake.lock b/flake.lock index 87fae7f446..1fb4f26189 100644 --- a/flake.lock +++ b/flake.lock @@ -43,11 +43,11 @@ ] }, "locked": { - "lastModified": 1728341842, - "narHash": "sha256-XMS52KBSS6z3k2VaiVcHyZQD6b2QUm1wIvTClel4xwg=", + "lastModified": 1731206929, + "narHash": "sha256-5O85Ydkk4AG8F3Y5pFj3aywCZwGqmvOj1DFnIXgfyxs=", "owner": "lockshaw", "repo": "proj", - "rev": "830fb5b1a0c7087752693990e90bbbf021168dfe", + "rev": "99d4df1a81b3b7a6595e9e7913b20f9e6a7f5e21", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index afbc2c1e37..38e59a81be 100644 --- a/flake.nix +++ b/flake.nix @@ -35,10 +35,13 @@ mkShell = pkgs.mkShell.override { stdenv = pkgs.cudaPackages.backendStdenv; }; + + proj = proj-repo.packages.${system}.proj; in { packages = { legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + ffdb = pkgs.callPackage ./.flake/pkgs/ffdb { inherit proj; }; hpp2plantuml = pkgs.python3Packages.callPackage ./.flake/pkgs/hpp2plantuml.nix { }; rapidcheckFull = pkgs.symlinkJoin { name = "rapidcheckFull"; @@ -102,9 +105,7 @@ doxygen lcov # for code coverage ]) - (with proj-repo.packages.${system}; [ - proj - ]) + [ proj ] (with self.packages.${system}; [ legion hpp2plantuml @@ -128,7 +129,6 @@ gh-markdown-preview shellcheck plantuml - gdb ruff compdb jq @@ -148,6 +148,9 @@ black toml ]) + (with self.packages.${system}; [ + ffdb + ]) ]; }; }; diff --git a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h index 65bae0c76a..ecaffa337b 100644 --- a/lib/compiler/include/compiler/cost_estimator/cost_estimator.h +++ b/lib/compiler/include/compiler/cost_estimator/cost_estimator.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_COST_ESTIMATOR_COST_ESTIMATOR_H #include "compiler/cost_estimator/op_cost_estimate_key.dtg.h" +#include "compiler/cost_estimator/op_cost_metrics.dtg.h" #include 
"compiler/cost_estimator/tensor_set_movement.dtg.h" #include "op-attrs/parallel_tensor_shape.dtg.h" #include "op-attrs/pcg_operator_attrs.dtg.h" @@ -11,7 +12,7 @@ namespace FlexFlow { struct ICostEstimator { - virtual float estimate_cost(OpCostEstimateKey const &) const = 0; + virtual OpCostMetrics estimate_cost(OpCostEstimateKey const &) const = 0; virtual float estimate_cost(TensorSetMovement const &) const = 0; ICostEstimator() = default; @@ -23,7 +24,7 @@ struct ICostEstimator { CHECK_RC_COPY_VIRTUAL_COMPLIANT(ICostEstimator); struct CostEstimator { - float estimate_cost(OpCostEstimateKey const &k) const; + OpCostMetrics estimate_cost(OpCostEstimateKey const &) const; float estimate_cost(TensorSetMovement const &m) const; template diff --git a/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml new file mode 100644 index 0000000000..f137935a4d --- /dev/null +++ b/lib/compiler/include/compiler/cost_estimator/op_cost_metrics.struct.toml @@ -0,0 +1,18 @@ +namespace = "FlexFlow" +name = "OpCostMetrics" +features = [ + "eq", + "fmt", + "hash", +] + +includes = [ +] + +[[fields]] +name = "runtime" +type = "float" + +[[fields]] +name = "memory" +type = "size_t" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h b/lib/compiler/include/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h new file mode 100644 index 0000000000..d176d298db --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h @@ -0,0 +1,48 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_GET_OPTIMAL_MACHINE_MAPPING_WITH_MEMORY_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_GET_OPTIMAL_MACHINE_MAPPING_WITH_MEMORY_H + +#include "compiler/machine_mapping/machine_mapping_cache.dtg.h" +#include 
"compiler/machine_mapping/machine_mapping_constraints.dtg.h" +#include "compiler/machine_mapping/machine_mapping_context.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_parallel_split.dtg.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/mm_problem_tree_series_split.dtg.h" +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" +#include "pcg/machine_specification.dtg.h" + +namespace FlexFlow { + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + ¶llel_split_transformation); + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const ¶llel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.struct.toml new file mode 100644 index 0000000000..b61dd134c0 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "MachineMappingForSingleLayer" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.dtg.h", + "compiler/cost_estimator/op_cost_metrics.dtg.h", +] + +[[fields]] +name = "cost" +type = "::FlexFlow::OpCostMetrics" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::ParallelLayerGuidObliviousMachineMapping" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h new file mode 100644 index 0000000000..b749235c89 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_MACHINE_MAPPING_CACHE_WITH_MEMORY_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_MACHINE_MAPPING_CACHE_WITH_MEMORY_H + +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.dtg.h" + +namespace FlexFlow { + +MachineMappingWithMemoryCache empty_machine_mapping_with_memory_cache(); +std::optional + machine_mapping_with_memory_cache_load( + MachineMappingWithMemoryCache const &, MachineMappingState const &); +void machine_mapping_with_memory_cache_save( + MachineMappingWithMemoryCache &, + MachineMappingState const &, + MachineMappingWithMemoryResult const &); + +} // 
namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.struct.toml new file mode 100644 index 0000000000..c2fe393e99 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.struct.toml @@ -0,0 +1,22 @@ +namespace = "FlexFlow" +name = "MachineMappingWithMemoryCache" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "compiler/machine_mapping/machine_mapping_state.dtg.h", + "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_map.h", + "utils/hash/unordered_map.h", +] + +[[fields]] +name = "raw_map" +type = "std::unordered_map<::FlexFlow::MachineMappingState, ::FlexFlow::MachineMappingWithMemoryResult>" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h new file mode 100644 index 0000000000..0383376116 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_MACHINE_MAPPING_RESULT_WITH_MEMORY_H +#define _FLEXFLOW_COMPILER_MACHINE_MAPPING_MEMORY_OPTIMIZATION_MACHINE_MAPPING_RESULT_WITH_MEMORY_H + +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.dtg.h" +#include "compiler/machine_mapping/parallel_split_transformation.dtg.h" +#include + +namespace FlexFlow { + +[[nodiscard]] MachineMappingWithMemoryResult + empty_machine_mapping_with_memory_result(); +[[nodiscard]] bool is_empty(MachineMappingWithMemoryResult const &); + +[[nodiscard]] 
MachineMappingWithMemoryResult get_mapping_with_minimal_runtime( + std::unordered_set const &); + +[[nodiscard]] MachineMappingWithMemoryResult + remove_non_pareto_optimal_machine_mapping_result( + MachineMappingWithMemoryResult const &); + +[[nodiscard]] MachineMappingWithMemoryResult + series_combine(float comm_cost, + MachineMappingWithMemoryResult const &pre_result, + MachineMappingWithMemoryResult const &post_result, + std::optional const + &parallel_split_transformation); +[[nodiscard]] MachineMappingWithMemoryResult + parallel_combine(MachineMappingWithMemoryResult const &lhs_result, + MachineMappingWithMemoryResult const &rhs_result); + +[[nodiscard]] MachineMappingWithMemoryResult + minimize_runtime(MachineMappingWithMemoryResult const &m1, + MachineMappingWithMemoryResult const &m2); + +[[nodiscard]] MachineMappingWithMemoryResult + make_singleton_machine_mapping_with_memory_result( + OpCostMetrics cost, MachineView const &machine_view); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.struct.toml b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.struct.toml new file mode 100644 index 0000000000..c1e1ee1cac --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.struct.toml @@ -0,0 +1,20 @@ +namespace = "FlexFlow" +name = "MachineMappingWithMemoryResult" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "compiler/machine_mapping/memory_optimization/machine_mapping_for_single_layer.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_set.h", + "utils/fmt/unordered_set.h", +] + +[[fields]] +name = "machine_mappings" +type = "std::unordered_set<::FlexFlow::MachineMappingForSingleLayer>" diff --git a/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.struct.toml 
b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.struct.toml new file mode 100644 index 0000000000..0d2572c783 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/memory_optimization/machine_memory_constraints.struct.toml @@ -0,0 +1,13 @@ +namespace = "FlexFlow" +name = "MachineMemoryConstraints" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [] + +[[fields]] +name = "memory_limit" +type = "size_t" diff --git a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc index 051ffcd190..6ac6e3a8d6 100644 --- a/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc +++ b/lib/compiler/src/compiler/cost_estimator/cost_estimator.cc @@ -5,7 +5,7 @@ namespace FlexFlow { CostEstimator::CostEstimator(std::shared_ptr implementation_ptr) : implementation_ptr(implementation_ptr) {} -float CostEstimator::estimate_cost(OpCostEstimateKey const &k) const { +OpCostMetrics CostEstimator::estimate_cost(OpCostEstimateKey const &k) const { return this->implementation_ptr->estimate_cost(k); } diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 10abd7ff90..5bdd8645a5 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -240,7 +240,7 @@ MachineMappingResult auto get_mapping_result = [&](MachineView const &machine_view) { OpCostEstimateKey mapped = map_unmapped_op_cost_estimate_key(leaf, machine_view); - float cost = context.cost_estimator.estimate_cost(mapped); + float cost = context.cost_estimator.estimate_cost(mapped).runtime; return make_singleton_machine_mapping_result(cost, machine_view); }; diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc 
b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc new file mode 100644 index 0000000000..b67083e8cd --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -0,0 +1,264 @@ +#include "compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h" +#include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/get_machine_resource_splits.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h" +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h" +#include "compiler/machine_mapping/transitive_reduced_pcg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.dtg.h" +#include "compiler/series_parallel/pcg/pcg_binary_sp_decomposition.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_specification.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/containers/contains.h" +#include "utils/containers/flatmap.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_all_assignments.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/exception.h" +#include "utils/overload.h" + +namespace FlexFlow { + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &context, + MachineMappingProblemTree const &problem_tree, + MachineSpecification const 
&resources, + MachineMappingConstraints const &constraints) { + + MachineMappingState state = MachineMappingState{ + problem_tree, + resources, + constraints, + }; + + { + std::optional cached_result = + machine_mapping_with_memory_cache_load(result_cache, state); + if (cached_result) { + return cached_result.value(); + } + } + + MachineMappingWithMemoryResult result = + problem_tree.visit(overload{ + [&](MMProblemTreeSeriesSplit const &series_split) { + return get_optimal_machine_mapping_with_memory( + result_cache, + context, + series_split, + resources, + constraints, + /*parallel_split_transformation=*/std::nullopt); + }, + [&](auto const &decomp_tree_node) { + return get_optimal_machine_mapping_with_memory(result_cache, + context, + decomp_tree_node, + resources, + constraints); + }, + }); + + machine_mapping_with_memory_cache_save(result_cache, state, result); + return result; +} + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeSeriesSplit const &series_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints, + std::optional const + &parallel_split_transformation) { + + auto get_boundary_machine_view_assignments = + [&](std::unordered_set const &boundary_layers) + -> std::unordered_set { + std::unordered_map> + allowed = generate_map( + boundary_layers, + [&](BinaryTreePath const &l) -> std::unordered_set { + UnmappedOpCostEstimateKey leaf = + mm_problem_tree_get_subtree_at_path( + MachineMappingProblemTree{series_split}, l) + .value() + .get(); + return context.allowed_machine_views(leaf, resources); + }); + return transform( + get_all_assignments(allowed), + [](std::unordered_map const &m) { + return ParallelLayerGuidObliviousMachineMapping{m}; + }); + }; + + auto eval_pre_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views) { + 
MachineMappingConstraints pre_candidate = with_additional_constraints( + restrict_to_left_child(constraints), assigned_pre_machine_views); + + MachineMappingWithMemoryResult pre_result = + get_optimal_machine_mapping_with_memory( + result_cache, + context, + series_split.get_left_child(), + resources, + pre_candidate); + + return pre_result; + }; + + auto eval_post_boundary_mapping = + [&](ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views) { + MachineMappingConstraints post_candidate = with_additional_constraints( + restrict_to_right_child(constraints), assigned_post_machine_views); + + MachineMappingWithMemoryResult post_result = + get_optimal_machine_mapping_with_memory( + result_cache, + context, + series_split.get_right_child(), + resources, + post_candidate); + + return post_result; + }; + + MachineMappingWithMemoryResult result = + empty_machine_mapping_with_memory_result(); + AbstractedTensorSetMovement tensor_movement = + series_split.tensor_set_movement; + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_pre_machine_views : + get_boundary_machine_view_assignments(get_src_layers(tensor_movement))) { + + MachineMappingWithMemoryResult pre_result = + eval_pre_boundary_mapping(assigned_pre_machine_views); + + for (ParallelLayerGuidObliviousMachineMapping const + &assigned_post_machine_views : + get_boundary_machine_view_assignments( + get_dst_layers(tensor_movement))) { + + MachineMappingWithMemoryResult post_result = + eval_post_boundary_mapping(assigned_post_machine_views); + + TensorSetMovement comm_across_split = + concretize_abstracted_tensor_set_movement( + tensor_movement, + /*pre_mapping=*/assigned_pre_machine_views, + /*post_mapping=*/assigned_post_machine_views); + float cost_across_split = + context.cost_estimator.estimate_cost(comm_across_split); + + result = minimize_runtime(result, + series_combine(cost_across_split, + pre_result, + post_result, + parallel_split_transformation)); + } + } + + return 
result; +} + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &context, + MMProblemTreeParallelSplit const &parallel_split, + MachineSpecification const &resources, + MachineMappingConstraints const &constraints) { + + MachineMappingProblemTree lhs = parallel_split.get_left_child(); + MachineMappingProblemTree rhs = parallel_split.get_right_child(); + + MachineMappingWithMemoryResult series_result = [&] { + MMProblemTreeSeriesSplit series_split = MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/empty_abstracted_tensor_set_movement(), + /*left_child=*/lhs, + /*right_child=*/rhs, + }; + + return get_optimal_machine_mapping_with_memory( + result_cache, + context, + series_split, + resources, + constraints, + ParallelSplitTransformation::LthenR); + }(); + + MachineMappingConstraints left_constraints = + restrict_to_left_child(constraints); + MachineMappingConstraints right_constraints = + restrict_to_right_child(constraints); + + auto evaluate_resource_split = + [&](std::pair const + &resource_split) { + MachineMappingWithMemoryResult left_result = + get_optimal_machine_mapping_with_memory(result_cache, + context, + lhs, + resource_split.first, + left_constraints); + MachineMappingWithMemoryResult right_result = + get_optimal_machine_mapping_with_memory(result_cache, + context, + rhs, + resource_split.second, + right_constraints); + + return parallel_combine(left_result, right_result); + }; + + std::unordered_set parallel_results = + transform(get_machine_resource_splits(resources), + evaluate_resource_split); + + return minimize_runtime(series_result, + get_mapping_with_minimal_runtime(parallel_results)); +} + +MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( + MachineMappingWithMemoryCache &result_cache, + MachineMappingContext const &context, + UnmappedOpCostEstimateKey const &leaf, + MachineSpecification const &resource, + 
MachineMappingConstraints const &constraints) { + + std::unordered_set candidates = [&] { + std::optional machine_view = require_only_root(constraints); + if (machine_view.has_value()) { + return std::unordered_set{machine_view.value()}; + } else { + return context.allowed_machine_views(leaf, resource); + } + }(); + + auto get_mapping_result = [&](MachineView const &machine_view) { + OpCostEstimateKey mapped = + map_unmapped_op_cost_estimate_key(leaf, machine_view); + OpCostMetrics cost = context.cost_estimator.estimate_cost(mapped); + + return make_singleton_machine_mapping_with_memory_result(cost, + machine_view); + }; + + std::unordered_set candidate_results = + transform(candidates, get_mapping_result); + + return get_mapping_with_minimal_runtime(candidate_results); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc new file mode 100644 index 0000000000..617ba682be --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.cc @@ -0,0 +1,32 @@ +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/try_at.h" + +namespace FlexFlow { + +MachineMappingWithMemoryCache empty_machine_mapping_with_memory_cache() { + return MachineMappingWithMemoryCache{{}}; +} + +std::optional + machine_mapping_with_memory_cache_load( + MachineMappingWithMemoryCache const &cache, + MachineMappingState const &k) { + return try_at(cache.raw_map, k); +} + +void machine_mapping_with_memory_cache_save( + MachineMappingWithMemoryCache &cache, + MachineMappingState const &k, + MachineMappingWithMemoryResult const &v) { + if (contains_key(cache.raw_map, k)) { + throw mk_runtime_error(fmt::format( + "machine_mapping_with_memory_cache_save 
expected key to not already " + "exist, but received existing key {}", + k)); + } + + cache.raw_map.emplace(k, v); +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc new file mode 100644 index 0000000000..a6c2d1ed04 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.cc @@ -0,0 +1,142 @@ +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h" +#include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" +#include "utils/containers/set_union.h" +#include "utils/full_binary_tree/binary_tree_path.h" + +namespace FlexFlow { + +MachineMappingWithMemoryResult empty_machine_mapping_with_memory_result() { + return MachineMappingWithMemoryResult{ + {}, + }; +} + +MachineMappingWithMemoryResult get_mapping_with_minimal_runtime( + std::unordered_set const &candidates) { + MachineMappingWithMemoryResult result = + empty_machine_mapping_with_memory_result(); + + for (MachineMappingWithMemoryResult const &candidate : candidates) { + result = minimize_runtime(result, candidate); + } + + return result; +} + +MachineMappingWithMemoryResult remove_non_pareto_optimal_machine_mapping_result( + MachineMappingWithMemoryResult const &result) { + std::unordered_set non_pareto_optimal_mappings; + for (MachineMappingForSingleLayer const &mapping : result.machine_mappings) { + bool is_pareto_optimal = true; + for (MachineMappingForSingleLayer const &other_mapping : + result.machine_mappings) { + if (mapping.cost.runtime >= other_mapping.cost.runtime && + mapping.cost.memory >= other_mapping.cost.memory && + mapping != other_mapping) { + is_pareto_optimal = false; + break; + } + } + if (is_pareto_optimal) { + non_pareto_optimal_mappings.insert(mapping); + } + } + return 
MachineMappingWithMemoryResult{std::move(non_pareto_optimal_mappings)}; +} + +MachineMappingWithMemoryResult + series_combine(float comm_cost, + MachineMappingWithMemoryResult const &pre_result, + MachineMappingWithMemoryResult const &post_result, + std::optional const + &parallel_split_transformation) { + auto combine_machine_mapping = + [&](MachineMappingForSingleLayer const &pre_mm, + MachineMappingForSingleLayer const &post_mm) { + OpCostMetrics cost = OpCostMetrics{ + pre_mm.cost.runtime + comm_cost + post_mm.cost.runtime, + pre_mm.cost.memory + post_mm.cost.memory, + }; + + ParallelLayerGuidObliviousMachineMapping mapping = [&] { + if (parallel_split_transformation.has_value() && + parallel_split_transformation.value() == + ParallelSplitTransformation::RthenL) { + return binary_combine_mappings(/*lhs=*/post_mm.machine_mapping, + /*rhs=*/pre_mm.machine_mapping); + } else { + return binary_combine_mappings(/*lhs=*/pre_mm.machine_mapping, + /*rhs=*/post_mm.machine_mapping); + } + }(); + + return MachineMappingForSingleLayer{cost, mapping}; + }; + + MachineMappingWithMemoryResult result = + empty_machine_mapping_with_memory_result(); + for (MachineMappingForSingleLayer const &pre_mm : + pre_result.machine_mappings) { + for (MachineMappingForSingleLayer const &post_mm : + post_result.machine_mappings) { + result.machine_mappings.insert(combine_machine_mapping(pre_mm, post_mm)); + } + } + + return remove_non_pareto_optimal_machine_mapping_result(result); +} + +MachineMappingWithMemoryResult + parallel_combine(MachineMappingWithMemoryResult const &lhs_result, + MachineMappingWithMemoryResult const &rhs_result) { + auto combine_machine_mapping = + [&](MachineMappingForSingleLayer const &lhs_mm, + MachineMappingForSingleLayer const &rhs_mm) { + OpCostMetrics cost = OpCostMetrics{ + std::max(lhs_mm.cost.runtime, rhs_mm.cost.runtime), + std::max(lhs_mm.cost.memory, rhs_mm.cost.memory), + }; + + ParallelLayerGuidObliviousMachineMapping mapping = + 
binary_combine_mappings(lhs_mm.machine_mapping, + rhs_mm.machine_mapping); + + return MachineMappingForSingleLayer{cost, mapping}; + }; + + MachineMappingWithMemoryResult result = + empty_machine_mapping_with_memory_result(); + for (MachineMappingForSingleLayer const &lhs_mm : + lhs_result.machine_mappings) { + for (MachineMappingForSingleLayer const &rhs_mm : + rhs_result.machine_mappings) { + result.machine_mappings.insert(combine_machine_mapping(lhs_mm, rhs_mm)); + } + } + + return remove_non_pareto_optimal_machine_mapping_result(result); +} + +MachineMappingWithMemoryResult + minimize_runtime(MachineMappingWithMemoryResult const &m1, + MachineMappingWithMemoryResult const &m2) { + MachineMappingWithMemoryResult result = MachineMappingWithMemoryResult{ + set_union(m1.machine_mappings, m2.machine_mappings), + }; + return remove_non_pareto_optimal_machine_mapping_result(result); +} + +MachineMappingWithMemoryResult + make_singleton_machine_mapping_with_memory_result( + OpCostMetrics cost, MachineView const &machine_view) { + return MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + cost, + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), machine_view}, + }}, + }, + }}; +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc index 9ee596af3e..0431104878 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.cc @@ -5,13 +5,15 @@ namespace FlexFlow { TestCostEstimator::TestCostEstimator( - std::function const &get_operator_cost, + std::function const + &get_operator_cost, std::function const &get_communication_cost) : get_operator_cost(get_operator_cost), get_communication_cost(get_communication_cost) {} -float TestCostEstimator::estimate_cost(OpCostEstimateKey const &k) const { 
+OpCostMetrics + TestCostEstimator::estimate_cost(OpCostEstimateKey const &k) const { return this->get_operator_cost(k); } @@ -20,16 +22,16 @@ float TestCostEstimator::estimate_cost(TensorSetMovement const &m) const { } CostEstimator make_fake_cost_estimator( - std::function const &get_operator_cost, + std::function const + &get_operator_cost, std::function const &get_communication_cost) { - return CostEstimator::create(get_operator_cost, get_communication_cost); } CostEstimator make_fake_cost_estimator( - std::unordered_map const &op_cost_map, + std::unordered_map const &op_cost_map, std::unordered_map const &comm_cost_map) { return make_fake_cost_estimator( [op_cost_map](OpCostEstimateKey const &k) { return op_cost_map.at(k); }, diff --git a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h index 7c1d06207a..16ea3a85bc 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h +++ b/lib/compiler/test/src/compiler/machine_mapping/cost_estimator_for_test.h @@ -11,7 +11,7 @@ namespace FlexFlow { struct TestCostEstimator : public ICostEstimator { - std::function get_operator_cost; + std::function get_operator_cost; std::function get_communication_cost; TestCostEstimator() = delete; @@ -19,18 +19,19 @@ struct TestCostEstimator : public ICostEstimator { decltype(get_communication_cost) const &get_communication_cost); - float estimate_cost(OpCostEstimateKey const &) const override; + OpCostMetrics estimate_cost(OpCostEstimateKey const &) const override; float estimate_cost(TensorSetMovement const &) const override; }; CostEstimator make_fake_cost_estimator( - std::function const &get_operator_cost, + std::function const + &get_operator_cost, std::function const &get_communication_cost); CostEstimator make_fake_cost_estimator( - std::unordered_map const &op_cost_map, + std::unordered_map const &op_cost_map, std::unordered_map const &comm_cost_map); 
} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index a0d06fe930..f5d5a5ee1b 100644 --- a/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/test/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -144,13 +144,19 @@ TEST_SUITE(FF_TEST_SUITE) { {binary_tree_root_path(), mv2}, }}; + auto map1 = std::unordered_map{{ + {map_unmapped_op_cost_estimate_key(k1, mv1), + OpCostMetrics{/*runtime=*/1.0, /*memory=*/0}}, + {map_unmapped_op_cost_estimate_key(k2, mv1), + OpCostMetrics{/*runtime=*/2.0, /*memory=*/0}}, + {map_unmapped_op_cost_estimate_key(k1, mv2), + OpCostMetrics{/*runtime=*/1.5, /*memory=*/0}}, + {map_unmapped_op_cost_estimate_key(k2, mv2), + OpCostMetrics{/*runtime=*/2.5, /*memory=*/0}}, + }}; + CostEstimator cost_estimator = make_fake_cost_estimator( - std::unordered_map{{ - {map_unmapped_op_cost_estimate_key(k1, mv1), 1.0}, - {map_unmapped_op_cost_estimate_key(k2, mv1), 2.0}, - {map_unmapped_op_cost_estimate_key(k1, mv2), 1.5}, - {map_unmapped_op_cost_estimate_key(k2, mv2), 2.5}, - }}, + map1, std::unordered_map{{ {TensorSetMovement{{}}, 0.0}, {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc new file mode 100644 index 0000000000..8761116be2 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -0,0 +1,290 @@ +#include "compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.h" +#include "../cost_estimator_for_test.h" +#include 
"compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.h" +#include "compiler/machine_mapping/machine_mapping_constraints.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/machine_mapping_problem_tree.h" +#include "compiler/machine_mapping/machine_mapping_problem_tree/unmapped_op_cost_estimate_key.h" +#include "compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_cache.h" +#include "pcg/machine_view.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "utils/containers/get_only.h" +#include "utils/full_binary_tree/binary_tree_path.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_optimal_machine_mapping_with_memory") { + auto make_leaf = [](UnmappedOpCostEstimateKey const &k) { + return MachineMappingProblemTree{k}; + }; + + auto make_series_split = + [](AbstractedTensorSetMovement const &tensor_set_movement, + MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeSeriesSplit{ + /*tensor_set_movement=*/tensor_set_movement, + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + auto make_parallel_split = [](MachineMappingProblemTree const &lhs, + MachineMappingProblemTree const &rhs) { + return MachineMappingProblemTree{ + MMProblemTreeParallelSplit{ + /*left_child=*/lhs, + /*right_child=*/rhs, + }, + }; + }; + + MachineView mv1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView mv2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, 
+ }; + + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/2, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + MachineSpecification split_machine_spec = MachineSpecification{ + /*num_nodes=*/1, + /*num_cpus_per_node=*/1, + /*num_gpus_per_node=*/1, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + auto allowed_machine_views1 = [&](UnmappedOpCostEstimateKey const &, + MachineSpecification const &resources) { + if (resources == full_machine_spec) { + return std::unordered_set{mv1, mv2}; + } else { + return std::unordered_set{mv2}; + } + }; + + UnmappedOpCostEstimateKey k1 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{}}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + UnmappedOpCostEstimateKey k2 = UnmappedOpCostEstimateKey{ + /*op_attrs=*/PCGOperatorAttrs{ElementBinaryAttrs{ + /*type=*/OperatorType::EW_ADD, + /*compute_type=*/DataType::FLOAT, + /*should_broadcast_lhs=*/false, + /*should_broadcast_rhs=*/false, + }}, + /*input_shapes=*/{}, + /*weight_shapes=*/{}, + /*output_shapes=*/{}, + }; + + ParallelTensorShape tensor_shape1 = ParallelTensorShape{ + ParallelTensorDims{ + FFOrdered{}, + ReplicaParallelDimSet{ + SumDegree{1}, + DiscardCopyDegree{1}, + }, + }, + DataType::FLOAT, + }; + + AbstractedTensorSetMovement movement1 = AbstractedTensorSetMovement{{ + AbstractedSingleTensorMovement{ + /*parallel_tensor_shape=*/tensor_shape1, + /*src_machine_views=*/{}, + /*dst_machine_views=*/{}, + }, + }}; + + ParallelLayerGuidObliviousMachineMapping mm1 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}; + ParallelLayerGuidObliviousMachineMapping mm2 = + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv2}, + }}; + + CostEstimator cost_estimator = make_fake_cost_estimator( + std::unordered_map{{ + {map_unmapped_op_cost_estimate_key(k1, mv1), 
OpCostMetrics{1.0, 2}}, + {map_unmapped_op_cost_estimate_key(k2, mv1), OpCostMetrics{2.0, 3}}, + {map_unmapped_op_cost_estimate_key(k1, mv2), OpCostMetrics{1.5, 1}}, + {map_unmapped_op_cost_estimate_key(k2, mv2), OpCostMetrics{2.5, 2}}, + }}, + std::unordered_map{{ + {TensorSetMovement{{}}, 0.0}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm1), + 0.1}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm2), + 0.2}, + {concretize_abstracted_tensor_set_movement(movement1, mm1, mm2), + 0.3}, + {concretize_abstracted_tensor_set_movement(movement1, mm2, mm1), + 0.4}, + }}); + + MachineMappingContext context = MachineMappingContext{ + cost_estimator, + allowed_machine_views1, + }; + + MachineMappingWithMemoryCache cache = + empty_machine_mapping_with_memory_cache(); + + SUBCASE("single layer") { + MachineMappingProblemTree problem_tree = make_leaf(k1); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingWithMemoryResult result = + get_optimal_machine_mapping_with_memory( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + OpCostMetrics{1.0, 2}, + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv1}, + }}, + }, + MachineMappingForSingleLayer{ + OpCostMetrics{1.5, 1}, + ParallelLayerGuidObliviousMachineMapping{{ + {binary_tree_root_path(), mv2}, + }}, + }, + }}; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in sequence") { + MachineMappingProblemTree problem_tree = + make_series_split(movement1, make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingWithMemoryResult result = + get_optimal_machine_mapping_with_memory( + cache, context, problem_tree, full_machine_spec, constraints); + 
MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + OpCostMetrics{ + /*runtime=*/1.0 + 2.0 + 0.1, + /*memory=*/2 + 3, + }, + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv1, + }, + }}, + }, + MachineMappingForSingleLayer{ + OpCostMetrics{1.5 + 2.5 + 0.1, 1 + 2}, + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv2, + }, + }}, + }, + }}; + + CHECK(result == correct); + } + + SUBCASE("pair of layers in parallel") { + MachineMappingProblemTree problem_tree = + make_parallel_split(make_leaf(k1), make_leaf(k2)); + + MachineMappingConstraints constraints = + get_unconstrained_solution_for_layers( + get_all_leaf_paths(problem_tree)); + + MachineMappingWithMemoryResult result = + get_optimal_machine_mapping_with_memory( + cache, context, problem_tree, full_machine_spec, constraints); + MachineMappingWithMemoryResult correct = + MachineMappingWithMemoryResult{{MachineMappingForSingleLayer{ + OpCostMetrics{2.5, 2}, + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + mv2, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + mv2, + }, + }}, + + }}}; + + CHECK(result == correct); + } + } +} diff --git a/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc new file mode 100644 index 0000000000..a47d8713e9 --- /dev/null +++ b/lib/compiler/test/src/compiler/machine_mapping/memory_optimization/machine_mapping_result_with_memory.cc @@ -0,0 +1,593 @@ +#include 
"compiler/machine_mapping/memory_optimization/machine_mapping_with_memory_result.h" +#include "pcg/machine_view.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("remove_non_pareto_optimal_machine_mapping_result") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{4}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + OpCostMetrics cost1 = OpCostMetrics{ + /*runtime=*/2.0, + /*memory=*/2, + }; + OpCostMetrics cost2 = OpCostMetrics{ + /*runtime=*/4.0, + /*memory=*/1, + }; + OpCostMetrics cost3 = OpCostMetrics{ + /*runtime=*/2.0, + /*memory=*/3, + }; + + MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{ + cost1, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_0, + }, + }, + }, + }; + + MachineMappingForSingleLayer mm2 = MachineMappingForSingleLayer{ + cost2, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_1, + }, + }, + }, + }; + + MachineMappingForSingleLayer mm3 = MachineMappingForSingleLayer{ + cost3, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_2, + }, + }, + }, + }; + + SUBCASE("empty") { + MachineMappingWithMemoryResult before_remove = + 
empty_machine_mapping_with_memory_result(); + MachineMappingWithMemoryResult result = + remove_non_pareto_optimal_machine_mapping_result(before_remove); + MachineMappingWithMemoryResult correct = + empty_machine_mapping_with_memory_result(); + + CHECK(result == correct); + } + + SUBCASE("all solutions are pareto-optimal") { + MachineMappingWithMemoryResult before_remove = + MachineMappingWithMemoryResult{ + { + mm1, + mm2, + }, + }; + MachineMappingWithMemoryResult result = + remove_non_pareto_optimal_machine_mapping_result(before_remove); + MachineMappingWithMemoryResult correct = before_remove; + + CHECK(result == correct); + } + + SUBCASE("there exists a non-pareto-optimal solution") { + MachineMappingWithMemoryResult before_remove = + MachineMappingWithMemoryResult{ + { + mm1, + mm2, + mm3, + }, + }; + MachineMappingWithMemoryResult result = + remove_non_pareto_optimal_machine_mapping_result(before_remove); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ + { + mm1, + mm2, + }, + }; + + CHECK(result == correct); + } + } + + TEST_CASE("series_combine(float, MachineMappingWithMemoryResult const &, " + "MachineMappingWithMemoryResult const &, " + "std::optional const&)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + OpCostMetrics pre_cost = OpCostMetrics{ + /*runtime=*/2.0, + /*memory=*/2, + }; + MachineMappingWithMemoryResult pre = MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + pre_cost, + 
ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{ + {BinaryTreePathEntry::LEFT_CHILD}, + }, + machine_view_0, + }, + { + BinaryTreePath{ + {BinaryTreePathEntry::RIGHT_CHILD}, + }, + machine_view_1, + }, + }, + }, + }, + }}; + + OpCostMetrics post_cost = OpCostMetrics{ + /*runtime=*/4.0, + /*memory=*/1, + }; + + MachineMappingWithMemoryResult post = MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + post_cost, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_1, + }, + }, + }, + }, + }}; + + MachineMappingWithMemoryResult empty = + empty_machine_mapping_with_memory_result(); + + float comm_cost = 3.0; + + SUBCASE("pre is empty") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, empty, post, ParallelSplitTransformation::LthenR); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("post is empty") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, pre, empty, ParallelSplitTransformation::LthenR); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("both are nonempty") { + MachineMappingWithMemoryResult no_parallel_split_transform = + MachineMappingWithMemoryResult{ + { + MachineMappingForSingleLayer{ + /*cost=*/OpCostMetrics{ + /*runtime=*/pre_cost.runtime + comm_cost + + post_cost.runtime, + /*memory=*/pre_cost.memory + post_cost.memory, + }, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }, + }; + + SUBCASE("parallel_split_transformation = std::nullopt") { + MachineMappingWithMemoryResult result = + 
series_combine(comm_cost, pre, post, std::nullopt); + MachineMappingWithMemoryResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = LthenR") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::LthenR); + MachineMappingWithMemoryResult correct = no_parallel_split_transform; + + CHECK(result == correct); + } + + SUBCASE("parallel_split_transformation = RthenL") { + MachineMappingWithMemoryResult result = series_combine( + comm_cost, pre, post, ParallelSplitTransformation::RthenL); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ + { + MachineMappingForSingleLayer{ + /*cost=*/OpCostMetrics{ + /*runtime=*/pre_cost.runtime + comm_cost + + post_cost.runtime, + /*memory=*/pre_cost.memory + post_cost.memory, + }, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{{ + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_0, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::RIGHT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD, + }}, + machine_view_1, + }, + { + BinaryTreePath{{ + BinaryTreePathEntry::LEFT_CHILD, + }}, + machine_view_1, + }, + }}, + }, + }, + }; + + CHECK(result == correct); + } + } + } + + TEST_CASE("parallel_combine(float, MachineMappingWithMemoryResult const &, " + "MachineMappingWithMemoryResult const &, " + "std::optional const&)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + 
MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + OpCostMetrics lhs_cost = OpCostMetrics{ + /*runtime=*/2.0, + /*memory=*/2, + }; + MachineMappingWithMemoryResult lhs = MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + lhs_cost, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{ + {BinaryTreePathEntry::LEFT_CHILD}, + }, + machine_view_0, + }, + { + BinaryTreePath{ + {BinaryTreePathEntry::RIGHT_CHILD}, + }, + machine_view_1, + }, + }, + }, + }, + }}; + + OpCostMetrics rhs_cost = OpCostMetrics{ + /*runtime=*/4.0, + /*memory=*/1, + }; + MachineMappingWithMemoryResult rhs = MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + rhs_cost, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_1, + }, + }, + }, + }, + }}; + + MachineMappingWithMemoryResult empty = + empty_machine_mapping_with_memory_result(); + + SUBCASE("lhs is empty") { + MachineMappingWithMemoryResult result = parallel_combine(empty, rhs); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("rhs is empty") { + MachineMappingWithMemoryResult result = parallel_combine(lhs, empty); + MachineMappingWithMemoryResult correct = empty; + + CHECK(result == correct); + } + + SUBCASE("both are nonempty") { + MachineMappingWithMemoryResult result = parallel_combine(lhs, rhs); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{{ + MachineMappingForSingleLayer{ + /*cost=*/OpCostMetrics{ + /*runtime=*/std::max(lhs_cost.runtime, rhs_cost.runtime), + /*memory=*/std::max(lhs_cost.memory, rhs_cost.memory), + }, + /*machine_mapping=*/ + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::LEFT_CHILD}}, + machine_view_0, + }, + { + BinaryTreePath{{BinaryTreePathEntry::LEFT_CHILD, + BinaryTreePathEntry::RIGHT_CHILD}}, + machine_view_1, + }, + { + 
BinaryTreePath{{BinaryTreePathEntry::RIGHT_CHILD}}, + machine_view_1, + }, + }, + }, + }, + }}; + + CHECK(result == correct); + } + } + + TEST_CASE("minimize_runtime(memory)") { + MachineView machine_view_0 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{1}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_1 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{2}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + MachineView machine_view_2 = MachineView{ + /*start=*/MachineSpaceCoordinate{ + /*node_idx=*/0, + /*device_idx=*/0, + /*device_type=*/DeviceType::GPU, + }, + /*dimensions=*/ + { + MachineViewDimension{ + stride_t{4}, + MachineSpecificationDimension::INTRA_NODE, + }, + }, + }; + + OpCostMetrics cost1 = OpCostMetrics{ + /*runtime=*/2.0, + /*memory=*/2, + }; + OpCostMetrics cost2 = OpCostMetrics{ + /*runtime=*/4.0, + /*memory=*/1, + }; + OpCostMetrics cost3 = OpCostMetrics{ + /*runtime=*/2.0, + /*memory=*/3, + }; + + MachineMappingForSingleLayer mm1 = MachineMappingForSingleLayer{ + cost1, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_0, + }, + }, + }, + }; + + MachineMappingForSingleLayer mm2 = MachineMappingForSingleLayer{ + cost2, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_1, + }, + }, + }, + }; + + MachineMappingForSingleLayer mm3 = MachineMappingForSingleLayer{ + cost3, + ParallelLayerGuidObliviousMachineMapping{ + { + { + BinaryTreePath{{}}, + machine_view_2, + }, + }, + }, + }; + + MachineMappingWithMemoryResult result1 = MachineMappingWithMemoryResult{ + { + mm1, + mm2, + }, + }; + + MachineMappingWithMemoryResult result2 = 
MachineMappingWithMemoryResult{ + { + mm2, + mm3, + }, + }; + + MachineMappingWithMemoryResult result = minimize_runtime(result1, result2); + MachineMappingWithMemoryResult correct = MachineMappingWithMemoryResult{ + { + mm1, + mm2, + }, + }; + + CHECK(result == correct); + } +} diff --git a/lib/utils/include/utils/containers/recurse_n.h b/lib/utils/include/utils/containers/recurse_n.h new file mode 100644 index 0000000000..8dc22cb8a8 --- /dev/null +++ b/lib/utils/include/utils/containers/recurse_n.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_RECURSE_N_H + +#include "utils/exception.h" + +namespace FlexFlow { + +/** + * @brief + * Applies function `f` to value `initial_value` n times recursively. + * + * @example + * auto add_three = [](int x) { return x + 3; }; + * int result = recurse_n(add_three, 3, 5); + * result -> f(f(f(5))) = ((5+3)+3)+3 = 14 + * + * @throws RuntimeError if n is negative + */ +template +T recurse_n(F const &f, int n, T const &initial_value) { + if (n < 0) { + throw mk_runtime_error( + fmt::format("Supplied n={} should be non-negative", n)); + } + T t = initial_value; + for (int i = 0; i < n; i++) { + t = f(t); + } + return t; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/recurse_n.cc b/lib/utils/src/utils/containers/recurse_n.cc new file mode 100644 index 0000000000..182db6fd73 --- /dev/null +++ b/lib/utils/src/utils/containers/recurse_n.cc @@ -0,0 +1,12 @@ +#include "utils/containers/recurse_n.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using T = value_type<0>; +using F = std::function; // F :: T -> T + +template T recurse_n(F const &f, int n, T const &initial_value); + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/containers/recurse_n.cc b/lib/utils/test/src/utils/containers/recurse_n.cc new file mode 100644 index 0000000000..1805ee891f --- /dev/null +++ 
b/lib/utils/test/src/utils/containers/recurse_n.cc @@ -0,0 +1,29 @@ +#include "utils/containers/recurse_n.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("recurse_n") { + auto append_bar = [](std::string const &x) { + return x + std::string("Bar"); + }; + + SUBCASE("n = 0") { + std::string result = recurse_n(append_bar, 0, std::string("Foo")); + std::string correct = "Foo"; + CHECK(result == correct); + } + + SUBCASE("n = 3") { + std::string result = recurse_n(append_bar, 3, std::string("Foo")); + std::string correct = "FooBarBarBar"; + CHECK(result == correct); + } + + SUBCASE("n < 0") { + CHECK_THROWS(recurse_n(append_bar, -1, std::string("Foo"))); + } + } +} diff --git a/scripts/format.sh b/scripts/format.sh deleted file mode 100755 index e4f1ec1611..0000000000 --- a/scripts/format.sh +++ /dev/null @@ -1,77 +0,0 @@ -#! /usr/bin/env bash - -set -euo pipefail - -GIT_ROOT="$(git rev-parse --show-toplevel)" -cd "$GIT_ROOT" - -TOOLS_PATH="$GIT_ROOT/.tools" -RELEASE="master-f4f85437" -CLANG_FORMAT_VERSION="16" -CLANG_FORMAT_PATH="$TOOLS_PATH/clang-format-$CLANG_FORMAT_VERSION-$RELEASE" - -mkdir -p "$TOOLS_PATH" - -error() { - >&2 echo "$@" - exit 1 -} - -get_os() { - UNAME_OUTPUT="$(uname -s)" - case "$UNAME_OUTPUT" in - Linux*) - OS=Linux - ;; - Darwin*) - OS=Mac - ;; - *) - error "Unknown OS $UNAME_OUTPUT. Exiting..." - esac - - echo "$OS" -} - -download_clang_tool() { - TOOL="$1" - VERSION="$2" - TARGET_PATH="$3" - - BASE_URL="https://github.com/muttleyxd/clang-tools-static-binaries/releases/download/$RELEASE/" - - OS="$(get_os)" - case "$OS" in - Linux) - URL_OS="linux" - ;; - Mac) - URL_OS="macosx" - ;; - *) - error "Unknown return value from get_os: $OS. Exiting..." - esac - URL="$BASE_URL/clang-${TOOL}-${VERSION}_${URL_OS}-amd64" - echo "Downloading from $URL..." 
- - if command -v wget &> /dev/null; then - wget "$URL" -O "$TARGET_PATH" - elif command -v curl &> /dev/null; then - curl -L "$URL" -o "$TARGET_PATH" - else - error "Could not find either wget or curl. Exiting..." - fi -} - -if [[ ! -e $CLANG_FORMAT_PATH ]]; then - download_clang_tool format "$CLANG_FORMAT_VERSION" "$CLANG_FORMAT_PATH" - chmod u+x "$CLANG_FORMAT_PATH" -fi - -CLANG_FORMAT_CONFIG="$GIT_ROOT/.clang-format-for-format-sh" -mapfile -t FILES < <(git ls-files ':!:triton/**' '*.h' '*.cc' '*.cpp' '*.cu' '*.c' '*.decl') -if [[ -f $CLANG_FORMAT_CONFIG ]]; then - "$CLANG_FORMAT_PATH" --style=file:"$CLANG_FORMAT_CONFIG" -i "${FILES[@]}" -else - echo "error" -fi diff --git a/scripts/gdb/pretty_print.py b/scripts/gdb/pretty_print.py deleted file mode 100644 index 4cccc9b76b..0000000000 --- a/scripts/gdb/pretty_print.py +++ /dev/null @@ -1,95 +0,0 @@ -import gdb.printing - -class NodePrinter: - def __init__(self, val): - self.val = val - - def to_string(self): - ptr = self.val["ptr"] - if ptr != 0: - op_type = ptr.referenced_value()['op_type'] - return f'Node' - else: - return f'Node' - -class EdgePrinter: - def __init__(self, val): - self.val = val - - def to_string(self): - return f'Edge' - -class MachineViewPrinter: - def __init__(self, val): - self.val = val - - def to_string(self): - toks = [] - if self.val['device_type'] == 0: - toks.append('type=GPU') - else: - toks.append('type=CPU') - start_device_id = self.val['start_device_id'] - for i in range(self.val['ndims']): - dim = self.val['dim'][i] - stride = self.val['stride'][i] - toks.append(f'{i}=[{start_device_id}:{start_device_id+dim}:{stride}]') - return f'MachineView<{" ".join(toks)}>' - -class DomainPrinter: - def __init__(self, val): - self.val = val - - def to_string(self): - toks = [] - ndim = self.val['dim'] - for i in range(ndim): - lo = self.val['rect_data'][i] - hi = self.val['rect_data'][i + ndim] - toks.append(f'{i}=[{lo}:{hi}]') - return f'Domain<{" ".join(toks)}>' - -class 
TensorShapePrinter: - def __init__(self, val): - self.val = val - - def to_string(self): - toks = [] - ndim = self.val['num_dims'] - for i in range(ndim): - dim = self.val['dims'][i] - size = dim['size'] - degree = dim['degree'] - parallel_idx = dim['parallel_idx'] - toks.append(f'{i}=[s={size} d={degree} pi={parallel_idx}]') - return f'TensorShape<{" ".join(toks)}>' - -class ParallelTensorBasePrinter: - def __init__(self, val): - self.val = val - - def to_string(self): - toks = [] - toks.append(f'guid={self.val["parallel_tensor_guid"]}') - ndim = self.val['num_dims'] - for i in range(ndim): - dim = self.val['dims'][i] - size = dim['size'] - degree = dim['degree'] - parallel_idx = dim['parallel_idx'] - toks.append(f'{i}=[s={size} d={degree} pi={parallel_idx}]') - return f'ParallelTensorBase<{" ".join(toks)}>' - -def build_pretty_printer(): - pp = gdb.printing.RegexpCollectionPrettyPrinter( - "flexflow") - pp.add_printer('Node', '^FlexFlow::PCG::Node$', NodePrinter) - pp.add_printer('Edge', '^FlexFlow::PCG::Edge$', EdgePrinter) - pp.add_printer('MachineView', '^FlexFlow::MachineView$', MachineViewPrinter) - pp.add_printer('Domain', '^Legion::Domain$', DomainPrinter) - pp.add_printer('ParallelTensorShape', '^FlexFlow::ParallelTensorShape$', TensorShapePrinter) - pp.add_printer('ParallelTensorBase', '^FlexFlow::ParallelTensorBase$', ParallelTensorBasePrinter) - return pp - -gdb.printing.register_pretty_printer( - gdb.current_objfile(), build_pretty_printer(), replace=True)