From 9141c46a64617de893d986b694f1202841e86b27 Mon Sep 17 00:00:00 2001
From: xinhaoc
Date: Thu, 1 Feb 2024 23:23:14 -0500
Subject: [PATCH] tp

---
 examples/python/pytorch/mt5/mt5_ff.py         |   1 +
 include/flexflow/ffconst.h                    |   2 +
 include/flexflow/model.h                      |   4 +
 include/flexflow/operator_params.h            |   2 +
 include/flexflow/parallel_ops/allreduce.h     |  57 ++++
 .../flexflow/parallel_ops/allreduce_params.h  |  22 ++
 .../parallel_ops/kernels/allreduce_kernels.h  |  31 ++
 python/flexflow/core/flexflow_cffi.py         |   4 +-
 src/ops/fused.cc                              |   2 +-
 src/ops/fused.cpp                             |  17 ++
 src/ops/fused.cu                              |  17 ++
 src/parallel_ops/allreduce.cc                 | 279 ++++++++++++++++++
 .../kernels/allreduce_kernels.cpp             |  58 ++++
 src/parallel_ops/kernels/allreduce_kernels.cu |  56 ++++
 src/runtime/ffconst_utils.cc                  |   2 +
 src/runtime/graph.cc                          |  19 ++
 src/runtime/model.cc                          |  60 +++-
 src/runtime/operator_params.cc                |   3 +
 src/runtime/optimizer.cc                      |   4 +-
 src/runtime/substitution.cc                   |   7 +
 20 files changed, 635 insertions(+), 12 deletions(-)
 create mode 100644 include/flexflow/parallel_ops/allreduce.h
 create mode 100644 include/flexflow/parallel_ops/allreduce_params.h
 create mode 100644 include/flexflow/parallel_ops/kernels/allreduce_kernels.h
 create mode 100644 src/parallel_ops/allreduce.cc
 create mode 100644 src/parallel_ops/kernels/allreduce_kernels.cpp
 create mode 100644 src/parallel_ops/kernels/allreduce_kernels.cu

diff --git a/examples/python/pytorch/mt5/mt5_ff.py b/examples/python/pytorch/mt5/mt5_ff.py
index c2868e9d1e..5dff7415d3 100644
--- a/examples/python/pytorch/mt5/mt5_ff.py
+++ b/examples/python/pytorch/mt5/mt5_ff.py
@@ -122,6 +122,7 @@ def top_level_task():
     input_names = ["input_ids", "attention_mask"]
 
     print("Tracing the model...")
+    print(batch_size)
     hf_model = PyTorchModel(
         model, is_hf_model=True, input_names=input_names,
         batch_size=batch_size, seq_length=seq_length,
diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h
index 5658e2923d..060983b020 100644
--- a/include/flexflow/ffconst.h
+++ b/include/flexflow/ffconst.h
@@ -157,6 +157,7 @@ enum OperatorType {
   OP_REPLICATE,
   OP_REDUCTION,
   OP_PIPELINE,
+  OP_ALLREDUCE,
   OP_FUSED_PARALLEL,
   OP_INVALID,
 };
@@ -189,6 +190,7 @@ enum PMParameter {
   PM_COMBINE_DEGREE,   // Combine
   PM_REDUCTION_DIM,    // Reduction
   PM_REDUCTION_DEGREE, // Reduction
+  PM_ALLREDUCE_DIM,    // AllReduce
   PM_SOFTMAX_DIM,      // Softmax
   PM_NUM_HEADS,        // MultiHeadAttention
   PM_INVALID,
diff --git a/include/flexflow/model.h b/include/flexflow/model.h
index 82a29c8182..3f45de3d79 100644
--- a/include/flexflow/model.h
+++ b/include/flexflow/model.h
@@ -191,6 +191,10 @@ enum TaskIDs {
   PIPELINE_INIT_TASK_ID,
   PIPELINE_FWD_TASK_ID,
   PIPELINE_BWD_TASK_ID,
+  ALLREDUCE_INIT_TASK_ID,
+  ALLREDUCE_INF_TASK_ID,
+  ALLREDUCE_FWD_TASK_ID,
+  ALLREDUCE_BWD_TASK_ID,
   FUSED_PARALLELOP_INIT_TASK_ID,
   FUSED_PARALLELOP_FWD_TASK_ID,
   FUSED_PARALLELOP_BWD_TASK_ID,
diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h
index 24c84a85ed..84653ac9ca 100644
--- a/include/flexflow/operator_params.h
+++ b/include/flexflow/operator_params.h
@@ -7,6 +7,7 @@
 #include "flexflow/ops/batch_matmul_params.h"
 #include "flexflow/ops/cast_params.h"
 #include "flexflow/ops/concat_params.h"
+#include "flexflow/parallel_ops/allreduce_params.h"
 #include "flexflow/ops/conv_2d_params.h"
 #include "flexflow/ops/dropout_params.h"
 #include "flexflow/ops/element_binary_params.h"
@@ -62,6 +63,7 @@ using OperatorParameters = mp::variant;
 tl::optional get_op_parameters(Op const *op);
diff --git a/include/flexflow/parallel_ops/allreduce.h
b/include/flexflow/parallel_ops/allreduce.h new file mode 100644 index 0000000000..a28d4cef9e --- /dev/null +++ b/include/flexflow/parallel_ops/allreduce.h @@ -0,0 +1,57 @@ +#ifndef _FLEXFLOW_ALLREDUCE_H +#define _FLEXFLOW_ALLREDUCE_H + +#include "flexflow/layer.h" +#include "flexflow/node.h" +#include "flexflow/op_meta.h" +#include "flexflow/operator.h" +#include "flexflow/parallel_ops/allreduce_params.h" +#include "parallel_op.h" + +namespace FlexFlow { + +class AllReduce : public ParallelOp { +public: + using Params = AllReduceParams; + using Input = ParallelTensor; + + AllReduce(FFModel &model, + const ParallelTensor input, + int allreduce_legion_dim, + char const *name = NULL); + AllReduce(FFModel &model, + Params const ¶ms, + Input const input, + char const *name = nullptr); + void create_input_partition(FFModel &model) override; + void init(FFModel const &) override; + void forward(FFModel const &) override; + void backward(FFModel const &) override; + bool get_int_parameter(PMParameter, int *) const override; + bool append_parallel_op_info( + std::vector ¶llel_ops) const override; + static OpMeta *init_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void forward_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void backward_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + bool measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const override; + + Params get_params() const; + +public: + int allreduce_dim; +}; + +}; // namespace FlexFlow + +#endif // _FLEXFLOW_ALLREDUCE_H diff --git a/include/flexflow/parallel_ops/allreduce_params.h b/include/flexflow/parallel_ops/allreduce_params.h new file mode 100644 index 0000000000..a0daac8f9a --- /dev/null +++ b/include/flexflow/parallel_ops/allreduce_params.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_ALLREDUCE_PARAMS_H +#define _FLEXFLOW_ALLREDUCE_PARAMS_H + +namespace FlexFlow { + +struct AllReduceParams { + int allreduce_legion_dim; + char name[MAX_OPNAME]; + bool is_valid(ParallelTensorShape const &) const; +}; +bool operator==(AllReduceParams const &, AllReduceParams const &); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AllReduceParams const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_ALLREDUCE_PARAMS_H diff --git a/include/flexflow/parallel_ops/kernels/allreduce_kernels.h b/include/flexflow/parallel_ops/kernels/allreduce_kernels.h new file mode 100644 index 0000000000..02a5026fcf --- /dev/null +++ b/include/flexflow/parallel_ops/kernels/allreduce_kernels.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H +#define _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H + +#include "flexflow/device.h" +#include "flexflow/fftype.h" +#include "flexflow/op_meta.h" +#include "flexflow/parallel_ops/allreduce.h" + +namespace FlexFlow { + +class AllReduceMeta : public OpMeta { +public: + AllReduceMeta(FFHandler handle, AllReduce const *reduct); +}; + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output); + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad); + +} // namespace AllReduce +} // namespace Kernels +} // namespace 
FlexFlow + +#endif // _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index 49ad5b3dc8..c39b58e4b8 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -2124,9 +2124,9 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1): for d in dataloaders: d.next_batch(self) self.forward() - self.zero_gradients() + # self.zero_gradients() self.backward() - self.update() + self.unified_update() self._ffconfig.end_trace(self._tracing_id) def eval(self, x=None, y=None, batch_size=None): diff --git a/src/ops/fused.cc b/src/ops/fused.cc index d0c895c53d..b241ff1587 100644 --- a/src/ops/fused.cc +++ b/src/ops/fused.cc @@ -129,7 +129,7 @@ bool FusedOp::add_operator(FFModel &model, Op *op) { // op->name, op_config)); // Cannot fuse parallel operators since they have different paralel_is // in forward and backward - assert(!op->is_parallel_op()); + assert(!op->is_parallel_op() || op->op_type == OP_ALLREDUCE); // Currently don't consider nested fusion if (op->op_type == OP_FUSED) { return false; diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index d1de787fb4..9da93f0c65 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -31,6 +31,7 @@ #include "flexflow/ops/kernels/reshape_kernels.h" #include "flexflow/ops/kernels/softmax_kernels.h" #include "flexflow/ops/kernels/transpose_kernels.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" #include "flexflow/ops/layer_norm.h" #include "flexflow/ops/linear.h" #include "flexflow/utils/hip_helper.h" @@ -420,6 +421,14 @@ __host__ void FusedOp::forward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0]); + break; + } case OP_RESHAPE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); @@ -957,6 +966,14 @@ __host__ void FusedOp::backward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::backward_kernel_wrapper( + m, my_input_grad_accessor[0], my_output_grad_accessor[0]); + break; + } case OP_TRANSPOSE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); diff --git a/src/ops/fused.cu b/src/ops/fused.cu index e43ce51f9c..b78447ba41 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -33,6 +33,7 @@ #include "flexflow/ops/kernels/reshape_kernels.h" #include "flexflow/ops/kernels/softmax_kernels.h" #include "flexflow/ops/kernels/transpose_kernels.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" #include "flexflow/ops/layer_norm.h" #include "flexflow/utils/cuda_helper.h" @@ -460,6 +461,14 @@ __host__ void FusedOp::forward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0]); + break; + } case OP_TRANSPOSE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); @@ -1002,6 +1011,14 @@ __host__ void FusedOp::backward_task(Task const *task, } break; } + case OP_ALLREDUCE: { + 
assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::backward_kernel_wrapper( + m, my_input_grad_accessor[0], my_output_grad_accessor[0]); + break; + } case OP_TRANSPOSE: { assert(fused->op_num_inputs[op] == 1); assert(fused->op_num_weights[op] == 0); diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc new file mode 100644 index 0000000000..1bbc43ca0d --- /dev/null +++ b/src/parallel_ops/allreduce.cc @@ -0,0 +1,279 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/parallel_ops/allreduce.h" +#include "flexflow/ffconst_utils.h" +#include "flexflow/model.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/hash_utils.h" + +namespace FlexFlow { +// declare Legion names +using Legion::ArgumentMap; +using Legion::Context; +using Legion::coord_t; +using Legion::Domain; +using Legion::Future; +using Legion::FutureMap; +using Legion::IndexLauncher; +using Legion::LogicalPartition; +using Legion::LogicalRegion; +using Legion::Machine; +using Legion::Memory; +using Legion::PhysicalRegion; +using Legion::Predicate; +using Legion::Rect; +using Legion::RegionRequirement; +using Legion::Runtime; +using Legion::Task; +using Legion::TaskArgument; +using Legion::TaskLauncher; + +using namespace FlexFlow::Kernels::AllReduce; + +/* Params */ +bool operator==(AllReduceParams const &lhs, AllReduceParams const &rhs) { + return lhs.allreduce_legion_dim == rhs.allreduce_legion_dim; +} + +bool AllReduceParams::is_valid(ParallelTensorShape const &input) const { + return input.is_valid(); +} + +AllReduceParams AllReduce::get_params() const { + AllReduceParams params; + params.allreduce_legion_dim = this->allreduce_dim; + if (this->name != nullptr) { + strcpy(params.name, this->name); + } + return params; +} + +AllReduce::AllReduce(FFModel &model, + const ParallelTensor _input, + int _allreduce_legion_dim, + char const *name) + : ParallelOp(model, OP_ALLREDUCE, name, _input), + allreduce_dim(_allreduce_legion_dim) { + int numdim = _input->num_dims; + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdim; i++) { + dims[i] = _input->dims[i]; + } + assert(dims[allreduce_dim].degree > 1); + // ParallelTensorBase::update_parallel_ids(numdim, dims); + outputs[0] = model.create_parallel_tensor_legion_ordering( + numdim, dims, _input->data_type, this); +} + +AllReduce::AllReduce(FFModel &model, + AllReduceParams const ¶ms, + ParallelTensor const input, + char const *name) + : AllReduce(model, input, params.allreduce_legion_dim, params.name) {} + +void AllReduce::create_input_partition(FFModel &ff) { + // Do nothing + return; +} + + +OpMeta *AllReduce::init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + AllReduce *ar = (AllReduce *)task->args; + FFHandler handle = *((FFHandler const 
*)task->local_args); + AllReduceMeta *meta = new AllReduceMeta(handle, ar); + meta->input_type[0] = ar->inputs[0]->data_type; + meta->output_type[0] = ar->outputs[0]->data_type; + assert(meta->input_type[0] == meta->output_type[0]); + return meta; +} + +void AllReduce::init(FFModel const &ff) { + ArgumentMap argmap; + parallel_is = outputs[0]->parallel_is; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + set_argumentmap_for_init(ff, argmap); + IndexLauncher launcher(ALLREDUCE_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(AllReduce)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap(ff, fm); +} + +void AllReduce::forward(FFModel const &ff) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = outputs[0]->parallel_is; + assert(numOutputs == 1); + assert(numInputs == 1); + set_argumentmap_for_forward(ff, argmap); + IndexLauncher launcher(ALLREDUCE_FWD_TASK_ID, + outputs[0]->parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + runtime->execute_index_space(ctx, launcher); +} + +void AllReduce::backward(FFModel const &ff) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + IndexLauncher launcher(ALLREDUCE_BWD_TASK_ID, + inputs[0]->parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + inputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + inputs[0]->region_grad)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + outputs[0]->region_grad)); + launcher.add_field(1, FID_DATA); + runtime->execute_index_space(ctx, launcher); +} + +bool AllReduce::measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const { + cost_metrics = CostMetrics(); + cost_metrics.forward_time = 0.0f; + cost_metrics.backward_time = 0.0f; + + cost_metrics.sync_time = 0; + cost_metrics.inputs_memory = 0; + cost_metrics.outputs_memory = 0; + cost_metrics.weights_memory = 0; + return true; +} + +bool AllReduce::get_int_parameter(PMParameter para, int *value) const { + switch (para) { + case PM_ALLREDUCE_DIM: + *value = allreduce_dim; + return true; + default: + return Op::get_int_parameter(para, value); + } +} + +bool 
AllReduce::append_parallel_op_info( + std::vector ¶llel_ops) const { + ParallelOpInfo ret; + ret.op_type = op_type; + ret.parallel_dim = allreduce_dim; + ret.parallel_degree = -1; // AllReduce does not affect parallel degree + parallel_ops.push_back(ret); + return true; +} + +/*static*/ +void AllReduce::forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + + AllReduceMeta const *m = *((AllReduceMeta **)task->local_args); + + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + + assert(input.data_type == output.data_type); + forward_kernel_wrapper(m, input, output); +} + +void AllReduce::backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + AllReduceMeta const *m = *((AllReduceMeta **)task->local_args); + + GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + + assert(input_grad.data_type == output_grad.data_type); + backward_kernel_wrapper(m, input_grad, output_grad); +} + +}; // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AllReduceParams const ¶ms) const { + size_t key = 0; + hash_combine(key, params.allreduce_legion_dim); + return key; +} + +} // namespace std diff --git a/src/parallel_ops/kernels/allreduce_kernels.cpp b/src/parallel_ops/kernels/allreduce_kernels.cpp new file mode 100644 index 0000000000..0aea27107d --- /dev/null +++ b/src/parallel_ops/kernels/allreduce_kernels.cpp @@ -0,0 +1,58 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/hip_helper.h" +#include + +namespace FlexFlow { + +AllReduceMeta::AllReduceMeta(FFHandler handle, AllReduce const *reduct) + : OpMeta(handle) {} + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); + size_t hidden_dim_size = input.domain.hi()[0] - input.domain.lo()[0] + 1; +#ifdef FF_USE_NCCL + // ncclDataType_t nccl_data_type = ff_to_nccl_datatype(input.data_type); + checkNCCL(ncclAllReduce(input.ptr, + output.ptr, + input.domain.get_volume(), + ncclFloat, + ncclSum, + m->handle.ncclComm, + stream)); +#else + assert(false && "Must enable FF_USE_NCCL to use AllReduce operators"); +#endif +} + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + assert(false && "To be implemented"); +} + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow diff --git a/src/parallel_ops/kernels/allreduce_kernels.cu b/src/parallel_ops/kernels/allreduce_kernels.cu new file mode 100644 index 0000000000..bce970cbdc --- /dev/null +++ b/src/parallel_ops/kernels/allreduce_kernels.cu @@ -0,0 +1,56 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/cuda_helper.h" + +namespace FlexFlow { + +AllReduceMeta::AllReduceMeta(FFHandler handle, AllReduce const *reduct) + : OpMeta(handle) {} + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); +#ifdef FF_USE_NCCL + // ncclDataType_t nccl_data_type = ff_to_nccl_datatype(input.data_type); + checkNCCL(ncclAllReduce(input.ptr, + output.ptr, + input.domain.get_volume(), + ncclFloat, + ncclSum, + m->handle.ncclComm, + stream)); +#else + assert(false && "Must enable FF_USE_NCCL to use AllReduce operators"); +#endif +} + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + assert(false && "To be implemented"); +} + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index 7ab9201113..e2debfa2d5 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -168,6 +168,8 @@ std::string get_operator_type_name(OperatorType type) { return "Replicate"; case OP_REDUCTION: return "Reduction"; + case OP_ALLREDUCE: + return "AllReduce"; case OP_PIPELINE: return "Pipeline"; case OP_FUSED_PARALLEL: diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 7447125197..8be3877b48 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -39,6 +39,7 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/parallel_ops/combine.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" #include "flexflow/parallel_ops/reduction.h" @@ -2295,6 +2296,13 @@ GraphOptimalViewSerialized sez.serialize(reduction->reduction_dim); sez.serialize(reduction->reduction_degree); break; + } + case OP_ALLREDUCE: { + AllReduce *allreduce = (AllReduce *)op; + sez.serialize(allreduce->allreduce_dim); + sez.serialize(strlen(allreduce->name)); + sez.serialize(allreduce->name, strlen(allreduce->name)); + break; } case OP_COMBINE: { Combine *combine = (Combine *)op; @@ -2704,6 +2712,17 @@ void FFModel::deserialize_graph_optimal_view( {reduction_dim, reduction_degree}); break; } + case OP_ALLREDUCE: { + assert(num_inputs == 1); + int allreduce_dim; + dez.deserialize(allreduce_dim); + size_t name_len; + char name[MAX_OPNAME] = {0}; + dez.deserialize(name_len); + dez.deserialize(name, name_len); + node = get_or_create_node(inputs[0], {allreduce_dim}); + break; + } case OP_FUSED_PARALLEL: { assert(num_inputs == 1); std::vector parallel_ops; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index b16766291a..3807e85391 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -53,6 +53,7 @@ #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/reduction.h" #include "flexflow/parallel_ops/replicate.h" #include "flexflow/substitution.h" @@ -2508,9 +2509,10 @@ bool FFModel::apply_fusion(std::vector const &operators, operators[l]->op_type == OP_WEIGHT) { continue; } - // don't 
fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[l]->is_parallel_op()) { + // don't fuse parallel op except allReduce since they have different + // parallel_is in forward/backward + if (operators[l]->is_parallel_op() && + operators[l]->op_type != OP_ALLREDUCE) { continue; } size_t start = 0; @@ -2553,9 +2555,10 @@ bool FFModel::apply_fusion(std::vector const &operators, operators[i]->op_type == OP_WEIGHT) { continue; } - // don't fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[i]->is_parallel_op()) { + // don't fuse parallel op except allReduce since they have different + // parallel_is in forward/backward + if (operators[i]->is_parallel_op() && + operators[i]->op_type != OP_ALLREDUCE) { continue; } fused_op = new FusedOp(*this, operators[i]); @@ -3491,7 +3494,7 @@ struct DefaultConfig { const static bool profiling = false; constexpr static float learningRate = 0.01f; constexpr static float weightDecay = 0.0001f; - const static size_t workSpaceSize = (size_t)1 * 1024 * 1024 * 1024; // 2GB + const static size_t workSpaceSize = (size_t)2 * 1024 * 1024 * 1024; // 2GB const static int numNodes = 1; const static int workersPerNode = 0; const static int cpusPerNode = 0; @@ -5214,6 +5217,49 @@ void register_flexflow_internal_tasks(Runtime *runtime, runtime->register_task_variant(registrar); } } + // AllReduce + { + TaskVariantRegistrar registrar(ALLREDUCE_INIT_TASK_ID, "AllReduce Init"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "AllReduce init Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } + { + TaskVariantRegistrar registrar(ALLREDUCE_FWD_TASK_ID, "AllReduce Forward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "AllReduce Forward Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } + { + TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "AllReduce Backward Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } // FusedParallelOp { TaskVariantRegistrar registrar(FUSED_PARALLELOP_FWD_TASK_ID, diff --git a/src/runtime/operator_params.cc b/src/runtime/operator_params.cc index 41dd37dec7..322d7840fb 100644 --- a/src/runtime/operator_params.cc +++ b/src/runtime/operator_params.cc @@ -28,6 +28,7 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/parallel_ops/combine.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" #include "flexflow/parallel_ops/reduction.h" @@ -94,6 +95,8 @@ tl::optional get_op_parameters(Op const *op) { return ((Reduction *)op)->get_params(); case OP_COMBINE: return ((Combine *)op)->get_params(); + case OP_ALLREDUCE: + return ((AllReduce *)op)->get_params(); case OP_FUSED_PARALLEL: return ((FusedParallelOp 
*)op)->get_params(); case OP_TRANSPOSE: diff --git a/src/runtime/optimizer.cc b/src/runtime/optimizer.cc index a683d234e3..2b90bc3f44 100644 --- a/src/runtime/optimizer.cc +++ b/src/runtime/optimizer.cc @@ -536,10 +536,10 @@ void AdamOptimizer::unified_update(std::vector const parameters) assert(v_values.find(p->region) != v_values.end()); assert(m_values.find(p->region) != m_values.end()); assert(p->owner_op != NULL); - if (reservedWorkSpaceSize + p->get_volume() * sizeof(float) >= - workSpaceSize) { + if (reservedWorkSpaceSize + p->get_volume() * sizeof(float) >= workSpaceSize) { break; } + printf("continueee\n"); reservedWorkSpaceSize += p->get_volume() * sizeof(float); parameters_num += 1; assert(p->sync_type == ParameterSyncType::NCCL); diff --git a/src/runtime/substitution.cc b/src/runtime/substitution.cc index e3adfec5b7..4f44a3a574 100644 --- a/src/runtime/substitution.cc +++ b/src/runtime/substitution.cc @@ -34,6 +34,7 @@ #include "flexflow/ops/split.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/partition.h" #include "flexflow/parallel_ops/reduction.h" #include "flexflow/parallel_ops/replicate.h" @@ -3743,6 +3744,12 @@ bool FFModel::convert_graph_to_operators( reduction->reduction_degree); break; } + case OP_ALLREDUCE: { + assert(inList.size() == 1); + AllReduce *allreduce = (AllReduce *)node.ptr; + new_op = new AllReduce(*this, inputs[0], allreduce->allreduce_dim); + break; + } case OP_FUSED_PARALLEL: { assert(inList.size() == 1); FusedParallelOp *fused = (FusedParallelOp *)node.ptr;
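
For reference, the forward path of the new kernels in src/parallel_ops/kernels/allreduce_kernels.cu and .cpp issues a single NCCL sum over the whole tensor volume on the Legion stream, with the element type currently hard-coded to ncclFloat (the ff_to_nccl_datatype conversion is left commented out). Below is a minimal standalone sketch of that call, not part of the patch: it assumes an already-initialized ncclComm_t and CUDA stream, and the function and buffer names are illustrative.

    #include <cuda_runtime.h>
    #include <nccl.h>
    #include <cstddef>

    // Sum-reduce `count` float elements across all ranks sharing `comm`.
    // `sendbuf` and `recvbuf` may point to the same device buffer for an
    // in-place all-reduce; the result is valid once `stream` has completed.
    // (Error checking elided; the patch wraps the call with checkNCCL.)
    void allreduce_sum_f32(float const *sendbuf,
                           float *recvbuf,
                           size_t count,
                           ncclComm_t comm,
                           cudaStream_t stream) {
      ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, stream);
    }

In a tensor-parallel setting this is the reduction each worker applies to its partial activations, which is why the patch also relaxes the fused-op assertion in src/ops/fused.cc so OP_ALLREDUCE can be placed inside a FusedOp.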