forked from flexflow/flexflow-train
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
635 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#ifndef _FLEXFLOW_ALLREDUCE_H | ||
#define _FLEXFLOW_ALLREDUCE_H | ||
|
||
#include "flexflow/layer.h" | ||
#include "flexflow/node.h" | ||
#include "flexflow/op_meta.h" | ||
#include "flexflow/operator.h" | ||
#include "flexflow/parallel_ops/allreduce_params.h" | ||
#include "parallel_op.h" | ||
|
||
namespace FlexFlow { | ||
|
||
class AllReduce : public ParallelOp { | ||
public: | ||
using Params = AllReduceParams; | ||
using Input = ParallelTensor; | ||
|
||
AllReduce(FFModel &model, | ||
const ParallelTensor input, | ||
int allreduce_legion_dim, | ||
char const *name = NULL); | ||
AllReduce(FFModel &model, | ||
Params const ¶ms, | ||
Input const input, | ||
char const *name = nullptr); | ||
void create_input_partition(FFModel &model) override; | ||
void init(FFModel const &) override; | ||
void forward(FFModel const &) override; | ||
void backward(FFModel const &) override; | ||
bool get_int_parameter(PMParameter, int *) const override; | ||
bool append_parallel_op_info( | ||
std::vector<ParallelOpInfo> ¶llel_ops) const override; | ||
static OpMeta *init_task(Legion::Task const *task, | ||
std::vector<Legion::PhysicalRegion> const ®ions, | ||
Legion::Context ctx, | ||
Legion::Runtime *runtime); | ||
static void forward_task(Legion::Task const *task, | ||
std::vector<Legion::PhysicalRegion> const ®ions, | ||
Legion::Context ctx, | ||
Legion::Runtime *runtime); | ||
static void backward_task(Legion::Task const *task, | ||
std::vector<Legion::PhysicalRegion> const ®ions, | ||
Legion::Context ctx, | ||
Legion::Runtime *runtime); | ||
bool measure_operator_cost(Simulator *sim, | ||
MachineView const &pc, | ||
CostMetrics &cost_metrics) const override; | ||
|
||
Params get_params() const; | ||
|
||
public: | ||
int allreduce_dim; | ||
}; | ||
|
||
}; // namespace FlexFlow | ||
|
||
#endif // _FLEXFLOW_ALLREDUCE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#ifndef _FLEXFLOW_ALLREDUCE_PARAMS_H | ||
#define _FLEXFLOW_ALLREDUCE_PARAMS_H | ||
|
||
namespace FlexFlow { | ||
|
||
struct AllReduceParams { | ||
int allreduce_legion_dim; | ||
char name[MAX_OPNAME]; | ||
bool is_valid(ParallelTensorShape const &) const; | ||
}; | ||
bool operator==(AllReduceParams const &, AllReduceParams const &); | ||
|
||
} // namespace FlexFlow | ||
|
||
namespace std { | ||
template <> | ||
struct hash<FlexFlow::AllReduceParams> { | ||
size_t operator()(FlexFlow::AllReduceParams const &) const; | ||
}; | ||
} // namespace std | ||
|
||
#endif // _FLEXFLOW_ALLREDUCE_PARAMS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#ifndef _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H | ||
#define _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H | ||
|
||
#include "flexflow/device.h" | ||
#include "flexflow/fftype.h" | ||
#include "flexflow/op_meta.h" | ||
#include "flexflow/parallel_ops/allreduce.h" | ||
|
||
namespace FlexFlow { | ||
|
||
class AllReduceMeta : public OpMeta { | ||
public: | ||
AllReduceMeta(FFHandler handle, AllReduce const *reduct); | ||
}; | ||
|
||
namespace Kernels { | ||
namespace AllReduce { | ||
|
||
void forward_kernel_wrapper(AllReduceMeta const *m, | ||
GenericTensorAccessorR const &input, | ||
GenericTensorAccessorW const &output); | ||
|
||
void backward_kernel_wrapper(AllReduceMeta const *m, | ||
GenericTensorAccessorW const &input_grad, | ||
GenericTensorAccessorR const &output_grad); | ||
|
||
} // namespace AllReduce | ||
} // namespace Kernels | ||
} // namespace FlexFlow | ||
|
||
#endif // _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.