
Commit

tp
xinhaoc committed Feb 2, 2024
1 parent a31f8e9 commit 9141c46
Showing 20 changed files with 635 additions and 12 deletions.
1 change: 1 addition & 0 deletions examples/python/pytorch/mt5/mt5_ff.py
@@ -122,6 +122,7 @@ def top_level_task():
input_names = ["input_ids", "attention_mask"]

print("Tracing the model...")
print(batch_size)
hf_model = PyTorchModel(
model, is_hf_model=True, input_names=input_names,
batch_size=batch_size, seq_length=seq_length,
2 changes: 2 additions & 0 deletions include/flexflow/ffconst.h
@@ -157,6 +157,7 @@ enum OperatorType {
OP_REPLICATE,
OP_REDUCTION,
OP_PIPELINE,
OP_ALLREDUCE,
OP_FUSED_PARALLEL,
OP_INVALID,
};
@@ -189,6 +190,7 @@ enum PMParameter {
PM_COMBINE_DEGREE, // Combine
PM_REDUCTION_DIM, // Reduction
PM_REDUCTION_DEGREE, // Reduction
PM_ALLREDUCE_DIM, // AllReduce
PM_SOFTMAX_DIM, // Softmax
PM_NUM_HEADS, // MultiHeadAttention
PM_INVALID,
4 changes: 4 additions & 0 deletions include/flexflow/model.h
@@ -191,6 +191,10 @@ enum TaskIDs {
PIPELINE_INIT_TASK_ID,
PIPELINE_FWD_TASK_ID,
PIPELINE_BWD_TASK_ID,
ALLREDUCE_INIT_TASK_ID,
ALLREDUCE_INF_TASK_ID,
ALLREDUCE_FWD_TASK_ID,
ALLREDUCE_BWD_TASK_ID,
FUSED_PARALLELOP_INIT_TASK_ID,
FUSED_PARALLELOP_FWD_TASK_ID,
FUSED_PARALLELOP_BWD_TASK_ID,
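
The four new task IDs above still need to be bound to the static task bodies declared in include/flexflow/parallel_ops/allreduce.h; that wiring lives in model.cc and is not shown in this excerpt. It presumably follows FlexFlow's usual Legion pre-registration pattern, roughly as in this hedged sketch (registrar names and processor constraints are illustrative assumptions, not taken from this diff):

// Hypothetical sketch of binding the AllReduce task IDs to their task bodies.
// Assumes `using namespace Legion;` and placement inside FlexFlow's task
// registration routine; names and constraints are not taken from this commit.
{
  TaskVariantRegistrar registrar(ALLREDUCE_FWD_TASK_ID, "AllReduce Forward");
  registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
  registrar.set_leaf();
  Runtime::preregister_task_variant<AllReduce::forward_task>(
      registrar, "AllReduce Forward Task");
}
{
  TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward");
  registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
  registrar.set_leaf();
  Runtime::preregister_task_variant<AllReduce::backward_task>(
      registrar, "AllReduce Backward Task");
}
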
2 changes: 2 additions & 0 deletions include/flexflow/operator_params.h
@@ -7,6 +7,7 @@
#include "flexflow/ops/batch_matmul_params.h"
#include "flexflow/ops/cast_params.h"
#include "flexflow/ops/concat_params.h"
#include "flexflow/parallel_ops/allreduce_params.h"
#include "flexflow/ops/conv_2d_params.h"
#include "flexflow/ops/dropout_params.h"
#include "flexflow/ops/element_binary_params.h"
@@ -62,6 +63,7 @@ using OperatorParameters = mp::variant<AggregateParams,
ReplicateParams,
ReductionParams,
CombineParams,
AllReduceParams,
FusedParallelOpParams>;

tl::optional<OperatorParameters> get_op_parameters(Op const *op);
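
With AllReduceParams added to the OperatorParameters variant, get_op_parameters (implemented elsewhere in the runtime) presumably also gains a branch for the new operator type, mirroring the existing parallel ops. A hedged sketch, not shown in this diff:

// Inside get_op_parameters(Op const *op)'s switch over op->op_type (assumed):
case OP_ALLREDUCE:
  return ((AllReduce *)op)->get_params();
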
57 changes: 57 additions & 0 deletions include/flexflow/parallel_ops/allreduce.h
@@ -0,0 +1,57 @@
#ifndef _FLEXFLOW_ALLREDUCE_H
#define _FLEXFLOW_ALLREDUCE_H

#include "flexflow/layer.h"
#include "flexflow/node.h"
#include "flexflow/op_meta.h"
#include "flexflow/operator.h"
#include "flexflow/parallel_ops/allreduce_params.h"
#include "parallel_op.h"

namespace FlexFlow {

class AllReduce : public ParallelOp {
public:
  using Params = AllReduceParams;
  using Input = ParallelTensor;

  AllReduce(FFModel &model,
            const ParallelTensor input,
            int allreduce_legion_dim,
            char const *name = NULL);
  AllReduce(FFModel &model,
            Params const &params,
            Input const input,
            char const *name = nullptr);
  void create_input_partition(FFModel &model) override;
  void init(FFModel const &) override;
  void forward(FFModel const &) override;
  void backward(FFModel const &) override;
  bool get_int_parameter(PMParameter, int *) const override;
  bool append_parallel_op_info(
      std::vector<ParallelOpInfo> &parallel_ops) const override;
  static OpMeta *init_task(Legion::Task const *task,
                           std::vector<Legion::PhysicalRegion> const &regions,
                           Legion::Context ctx,
                           Legion::Runtime *runtime);
  static void forward_task(Legion::Task const *task,
                           std::vector<Legion::PhysicalRegion> const &regions,
                           Legion::Context ctx,
                           Legion::Runtime *runtime);
  static void backward_task(Legion::Task const *task,
                            std::vector<Legion::PhysicalRegion> const &regions,
                            Legion::Context ctx,
                            Legion::Runtime *runtime);
  bool measure_operator_cost(Simulator *sim,
                             MachineView const &pc,
                             CostMetrics &cost_metrics) const override;

  Params get_params() const;

public:
  int allreduce_dim;
};

}; // namespace FlexFlow

#endif // _FLEXFLOW_ALLREDUCE_H
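
The member definitions live in the new allreduce.cc, which is among the changed files not rendered in this excerpt. Given the fields declared above and the new PM_ALLREDUCE_DIM parameter, the parameter plumbing plausibly reduces to the following sketch (assumed, not copied from the commit):

// Hedged sketch of the parameter accessors declared in allreduce.h.
namespace FlexFlow {

AllReduceParams AllReduce::get_params() const {
  AllReduceParams params;
  params.allreduce_legion_dim = this->allreduce_dim;
  return params;  // name copy omitted in this sketch
}

bool AllReduce::get_int_parameter(PMParameter para, int *value) const {
  switch (para) {
    case PM_ALLREDUCE_DIM:
      *value = allreduce_dim;
      return true;
    default:
      return Op::get_int_parameter(para, value);  // assumed base-class fallback
  }
}

} // namespace FlexFlow
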
22 changes: 22 additions & 0 deletions include/flexflow/parallel_ops/allreduce_params.h
@@ -0,0 +1,22 @@
#ifndef _FLEXFLOW_ALLREDUCE_PARAMS_H
#define _FLEXFLOW_ALLREDUCE_PARAMS_H

namespace FlexFlow {

struct AllReduceParams {
  int allreduce_legion_dim;
  char name[MAX_OPNAME];
  bool is_valid(ParallelTensorShape const &) const;
};
bool operator==(AllReduceParams const &, AllReduceParams const &);

} // namespace FlexFlow

namespace std {
template <>
struct hash<FlexFlow::AllReduceParams> {
  size_t operator()(FlexFlow::AllReduceParams const &) const;
};
} // namespace std

#endif // _FLEXFLOW_ALLREDUCE_PARAMS_H
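
Equality, validity, and hashing are declared here but defined in the corresponding source file, which this excerpt does not include. A minimal implementation, mirroring the other parallel-op params and assumed rather than taken from the commit, would be:

namespace FlexFlow {

// Only the Legion dimension distinguishes two AllReduce configurations here.
bool AllReduceParams::is_valid(ParallelTensorShape const &input) const {
  return input.is_valid();  // assumes ParallelTensorShape::is_valid() exists
}

bool operator==(AllReduceParams const &lhs, AllReduceParams const &rhs) {
  return lhs.allreduce_legion_dim == rhs.allreduce_legion_dim;
}

} // namespace FlexFlow

namespace std {
size_t hash<FlexFlow::AllReduceParams>::operator()(
    FlexFlow::AllReduceParams const &params) const {
  return hash<int>{}(params.allreduce_legion_dim);
}
} // namespace std
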
31 changes: 31 additions & 0 deletions include/flexflow/parallel_ops/kernels/allreduce_kernels.h
@@ -0,0 +1,31 @@
#ifndef _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H
#define _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H

#include "flexflow/device.h"
#include "flexflow/fftype.h"
#include "flexflow/op_meta.h"
#include "flexflow/parallel_ops/allreduce.h"

namespace FlexFlow {

class AllReduceMeta : public OpMeta {
public:
  AllReduceMeta(FFHandler handle, AllReduce const *reduct);
};

namespace Kernels {
namespace AllReduce {

void forward_kernel_wrapper(AllReduceMeta const *m,
                            GenericTensorAccessorR const &input,
                            GenericTensorAccessorW const &output);

void backward_kernel_wrapper(AllReduceMeta const *m,
                             GenericTensorAccessorW const &input_grad,
                             GenericTensorAccessorR const &output_grad);

} // namespace AllReduce
} // namespace Kernels
} // namespace FlexFlow

#endif // _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H
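
These wrapper declarations hide the actual collective, which is implemented in the corresponding allreduce kernels source file (one of the changed files not rendered in this excerpt). Conceptually the forward pass is a single sum-allreduce over the tensor's elements on the handler's NCCL communicator, and the backward pass applies the same collective to the gradients. A hedged CUDA-side sketch, where the accessor helpers, handler member, and omitted FF_USE_NCCL guard are assumptions based on the surrounding codebase:

// Minimal sketch of forward_kernel_wrapper; not the commit's exact code.
void forward_kernel_wrapper(AllReduceMeta const *m,
                            GenericTensorAccessorR const &input,
                            GenericTensorAccessorW const &output) {
  cudaStream_t stream;
  checkCUDA(get_legion_stream(&stream));              // assumed helper
  size_t num_elements = input.domain.get_volume();    // assumed accessor field
  checkNCCL(ncclAllReduce(input.get_float_ptr(),      // assumed accessor helper
                          output.get_float_ptr(),
                          num_elements,
                          ncclFloat,
                          ncclSum,
                          m->handle.ncclComm,          // assumed FFHandler member
                          stream));
}
// backward_kernel_wrapper would issue the same ncclAllReduce over
// input_grad / output_grad.
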
4 changes: 2 additions & 2 deletions python/flexflow/core/flexflow_cffi.py
@@ -2124,9 +2124,9 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1):
for d in dataloaders:
d.next_batch(self)
self.forward()
self.zero_gradients()
# self.zero_gradients()
self.backward()
self.update()
self.unified_update()
self._ffconfig.end_trace(self._tracing_id)

def eval(self, x=None, y=None, batch_size=None):
2 changes: 1 addition & 1 deletion src/ops/fused.cc
@@ -129,7 +129,7 @@ bool FusedOp::add_operator(FFModel &model, Op *op) {
// op->name, op_config));
// Cannot fuse parallel operators since they have different parallel_is
// in forward and backward
assert(!op->is_parallel_op());
assert(!op->is_parallel_op() || op->op_type == OP_ALLREDUCE);
// Currently don't consider nested fusion
if (op->op_type == OP_FUSED) {
return false;
17 changes: 17 additions & 0 deletions src/ops/fused.cpp
@@ -31,6 +31,7 @@
#include "flexflow/ops/kernels/reshape_kernels.h"
#include "flexflow/ops/kernels/softmax_kernels.h"
#include "flexflow/ops/kernels/transpose_kernels.h"
#include "flexflow/parallel_ops/kernels/allreduce_kernels.h"
#include "flexflow/ops/layer_norm.h"
#include "flexflow/ops/linear.h"
#include "flexflow/utils/hip_helper.h"
@@ -420,6 +421,14 @@ __host__ void FusedOp::forward_task(Task const *task,
}
break;
}
case OP_ALLREDUCE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
Kernels::AllReduce::forward_kernel_wrapper(
m, my_input_accessor[0], my_output_accessor[0]);
break;
}
case OP_RESHAPE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
@@ -957,6 +966,14 @@ __host__ void FusedOp::backward_task(Task const *task,
}
break;
}
case OP_ALLREDUCE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
Kernels::AllReduce::backward_kernel_wrapper(
m, my_input_grad_accessor[0], my_output_grad_accessor[0]);
break;
}
case OP_TRANSPOSE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
17 changes: 17 additions & 0 deletions src/ops/fused.cu
@@ -33,6 +33,7 @@
#include "flexflow/ops/kernels/reshape_kernels.h"
#include "flexflow/ops/kernels/softmax_kernels.h"
#include "flexflow/ops/kernels/transpose_kernels.h"
#include "flexflow/parallel_ops/kernels/allreduce_kernels.h"
#include "flexflow/ops/layer_norm.h"
#include "flexflow/utils/cuda_helper.h"

@@ -460,6 +461,14 @@ __host__ void FusedOp::forward_task(Task const *task,
}
break;
}
case OP_ALLREDUCE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
Kernels::AllReduce::forward_kernel_wrapper(
m, my_input_accessor[0], my_output_accessor[0]);
break;
}
case OP_TRANSPOSE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_weights[op] == 0);
@@ -1002,6 +1011,14 @@ __host__ void FusedOp::backward_task(Task const *task,
}
break;
}
case OP_ALLREDUCE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
Kernels::AllReduce::backward_kernel_wrapper(
m, my_input_grad_accessor[0], my_output_grad_accessor[0]);
break;
}
case OP_TRANSPOSE: {
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_weights[op] == 0);