forked from flexflow/flexflow-train
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'inference' into optimize_attn
- Loading branch information
Showing
42 changed files
with
2,741 additions
and
541 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#ifndef _FLEXFLOW_OPS_KERNELS_RESIDUAL_RMSNORM_KERNELS_H | ||
#define _FLEXFLOW_OPS_KERNELS_RESIDUAL_RMSNORM_KERNELS_H | ||
|
||
#include "flexflow/accessor.h" | ||
#include "flexflow/device.h" | ||
#include "flexflow/fftype.h" | ||
#include "flexflow/op_meta.h" | ||
#include "flexflow/utils/memory_allocator.h" | ||
|
||
namespace FlexFlow { | ||
// Makes Legion::coord_t usable unqualified inside namespace FlexFlow for every | ||
// file that includes this header; no declaration visible in this header uses it. | ||
using Legion::coord_t; | ||
|
||
// Forward declaration: the ResidualRMSNormMeta constructor below only takes a | ||
// pointer to the operator, so the full class definition is not needed here. | ||
class ResidualRMSNorm; | ||
|
||
// Per-operator metadata for the residual RMS-norm kernels, derived from OpMeta. | ||
// A const pointer to it is the first argument of | ||
// Kernels::ResidualRMSNorm::forward_kernel_wrapper. | ||
// NOTE(review): the runtime contents of each field are set by the constructor, | ||
// whose definition is not in this header. | ||
class ResidualRMSNormMeta : public OpMeta { | ||
public: | ||
// Builds the metadata for operator `rms` using `handler`. `gpu_mem_allocator` | ||
// is taken by mutable reference (presumably to reserve the rms_ptr / norm_ptr | ||
// buffers -- confirm against the implementation). | ||
ResidualRMSNormMeta(FFHandler handler, | ||
ResidualRMSNorm const *rms, | ||
MemoryAllocator &gpu_mem_allocator); | ||
~ResidualRMSNormMeta(void); | ||
// Backend-specific tensor / reduction descriptors: cuDNN types when FF_USE_CUDA | ||
// or FF_USE_HIP_CUDA is defined, MIOpen types otherwise. | ||
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) | ||
cudnnTensorDescriptor_t inputTensor, outputTensor; | ||
cudnnReduceTensorDescriptor_t reduceDesc; | ||
#else | ||
miopenTensorDescriptor_t inputTensor, outputTensor; | ||
miopenReduceTensorDescriptor_t reduceDesc; | ||
#endif | ||
|
||
public: | ||
// Epsilon of the normalization (presumably added for numerical stability -- | ||
// confirm in the kernel). | ||
float eps; | ||
// Untyped buffers; presumably scratch space for the RMS values and the | ||
// normalized activations -- TODO confirm against the constructor. | ||
void *rms_ptr; | ||
void *norm_ptr; | ||
|
||
// Scaling factors; presumably the alpha / beta arguments handed to the | ||
// backend library calls -- TODO confirm. | ||
float alpha; | ||
float beta; | ||
|
||
// Problem sizes. NOTE(review): looks like num_elements == in_dim * batch_size, | ||
// but that relationship is not established in this header -- confirm. | ||
int in_dim; | ||
int batch_size; | ||
int num_elements; | ||
// Fixed-size buffer of MAX_OPNAME chars for the operator name. | ||
char op_name[MAX_OPNAME]; | ||
// Realm region instance; presumably backs the memory reserved through the | ||
// allocator so it can be released later -- TODO confirm. | ||
Realm::RegionInstance reserveInst; | ||
}; | ||
|
||
namespace Kernels { | ||
namespace ResidualRMSNorm { | ||
// Entry point for the forward pass of the residual RMS-norm kernels. | ||
// Takes the operator metadata `m`, two inputs and a weight as | ||
// GenericTensorAccessorR, and two results as GenericTensorAccessorW | ||
// (`residual_output` and `output`). | ||
// NOTE(review): presumably residual_output holds input1 + input2 and output the | ||
// weighted RMS norm of that sum -- confirm against the kernel implementation. | ||
void forward_kernel_wrapper(ResidualRMSNormMeta const *m, | ||
GenericTensorAccessorR const &input1, | ||
GenericTensorAccessorR const &input2, | ||
GenericTensorAccessorR const &weight, | ||
GenericTensorAccessorW const &residual_output, | ||
GenericTensorAccessorW const &output); | ||
} // namespace ResidualRMSNorm | ||
} // namespace Kernels | ||
} // namespace FlexFlow | ||
|
||
#endif // _FLEXFLOW_OPS_KERNELS_RESIDUAL_RMSNORM_KERNELS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#ifndef _FLEXFLOW_RESIDUAL_RMS_NORM_H | ||
#define _FLEXFLOW_RESIDUAL_RMS_NORM_H | ||
|
||
#include "flexflow/inference.h" | ||
#include "flexflow/model.h" | ||
#include "flexflow/ops/residual_rms_norm_params.h" | ||
#include "flexflow/utils/memory_allocator.h" | ||
|
||
namespace FlexFlow { | ||
|
||
// Forward declaration of the operator's metadata class, which is defined | ||
// outside this header. | ||
class ResidualRMSNormMeta; | ||
|
||
// Operator declaration for residual RMS normalization. It derives from Op, | ||
// takes two input tensors (see `Input`) and is configured by | ||
// ResidualRMSNormParams. | ||
// NOTE(review): the arithmetic itself lives in the kernel implementation, which | ||
// is not part of this header. | ||
class ResidualRMSNorm : public Op { | ||
public: | ||
using Params = ResidualRMSNormParams; | ||
// The operator's two inputs, bundled as a pair. | ||
using Input = std::pair<ParallelTensor, ParallelTensor>; | ||
// Constructs the operator from explicit arguments: the two inputs, the | ||
// epsilon `_eps`, the dimension `dim`, and whether to allocate weights. | ||
ResidualRMSNorm(FFModel &model, | ||
LayerID const &_layer_guid, | ||
const ParallelTensor _input1, | ||
const ParallelTensor _input2, | ||
float _eps, | ||
int dim, | ||
bool allocate_weights, | ||
char const *name); | ||
// Constructs the operator from a parameter bundle and an input pair; `name` | ||
// defaults to nullptr. | ||
ResidualRMSNorm(FFModel &model, | ||
ResidualRMSNormParams const &params, | ||
Input const &inputs, | ||
bool allocate_weights, | ||
char const *name = nullptr); | ||
|
||
// Constructs the operator from an existing operator `other` and a new input | ||
// pair (presumably reusing the settings of `other` -- confirm in the | ||
// implementation). | ||
ResidualRMSNorm(FFModel &model, | ||
ResidualRMSNorm const &other, | ||
Input const &inputs, | ||
bool allocate_weights); | ||
// Overrides of the Op interface; definitions are outside this header. | ||
void init(FFModel const &) override; | ||
void forward(FFModel const &) override; | ||
void backward(FFModel const &) override; | ||
// Inference-time counterparts. Both take two tensor vectors and an optional | ||
// MachineView that defaults to nullptr. | ||
void init_inference(FFModel const &, | ||
std::vector<ParallelTensor> const &, | ||
std::vector<ParallelTensor> const &, | ||
MachineView const *mv = nullptr) override; | ||
Legion::FutureMap inference(FFModel const &, | ||
BatchConfigFuture const &, | ||
std::vector<ParallelTensor> const &, | ||
std::vector<ParallelTensor> const &, | ||
MachineView const *mv = nullptr) override; | ||
// Not supported for this operator: trips assert(0) when called (the assert | ||
// compiles away when NDEBUG is defined, making this a no-op). | ||
void print_layer(FFModel const &model) override { | ||
assert(0); | ||
} | ||
|
||
// Factory that builds the operator from a Layer description and its inputs. | ||
static Op * | ||
create_operator_from_layer(FFModel &model, | ||
Layer const *layer, | ||
std::vector<ParallelTensor> const &inputs); | ||
// Serialization hooks based on Legion::Serializer / Legion::Deserializer. | ||
void serialize(Legion::Serializer &) const override; | ||
static PCG::Node deserialize(FFModel &ff, | ||
Legion::Deserializer &d, | ||
ParallelTensor inputs[], | ||
int num_inputs); | ||
Op *materialize(FFModel &ff, | ||
ParallelTensor inputs[], | ||
int num_inputs) const override; | ||
// Returns this operator's settings as a ResidualRMSNormParams value. | ||
ResidualRMSNormParams get_params() const; | ||
|
||
// Static task entry points sharing the Legion task signature | ||
// (task, regions, ctx, runtime). init_task returns the OpMeta it creates. | ||
static OpMeta *init_task(Legion::Task const *task, | ||
std::vector<Legion::PhysicalRegion> const &regions, | ||
Legion::Context ctx, | ||
Legion::Runtime *runtime); | ||
static void forward_task(Legion::Task const *task, | ||
std::vector<Legion::PhysicalRegion> const &regions, | ||
Legion::Context ctx, | ||
Legion::Runtime *runtime); | ||
static void inference_task(Legion::Task const *task, | ||
std::vector<Legion::PhysicalRegion> const &regions, | ||
Legion::Context ctx, | ||
Legion::Runtime *runtime); | ||
// Cost estimation for the simulator; results are reported through | ||
// `cost_metrics` and the bool return value. | ||
bool measure_operator_cost(Simulator *sim, | ||
MachineView const &pc, | ||
CostMetrics &cost_metrics) const override; | ||
|
||
public: | ||
// Epsilon of the normalization. | ||
float eps; | ||
// Fixed-size buffer of MAX_OPNAME chars for the operator name. | ||
char op_name[MAX_OPNAME]; | ||
// NOTE(review): effective_batch_size, dim and data_dim are filled in by the | ||
// constructors, which are not visible here -- confirm their exact meaning. | ||
int effective_batch_size; | ||
int dim, data_dim; | ||
}; | ||
} // namespace FlexFlow | ||
#endif // _FLEXFLOW_RESIDUAL_RMS_NORM_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#ifndef _FLEXFLOW_RESIDUAL_RMSNORM_PARAMS_H | ||
#define _FLEXFLOW_RESIDUAL_RMSNORM_PARAMS_H | ||
|
||
#include "flexflow/ffconst.h" | ||
#include "flexflow/fftype.h" | ||
#include "flexflow/parallel_tensor.h" | ||
|
||
namespace FlexFlow { | ||
|
||
// Parameter bundle for a ResidualRMSNorm operator: a layer identifier plus the | ||
// `eps` and `dim` settings. | ||
struct ResidualRMSNormParams { | ||
// Identifier of the layer these parameters belong to. | ||
LayerID layer_guid; | ||
// Epsilon of the normalization. | ||
float eps; | ||
// NOTE(review): presumably the size of the normalized dimension -- confirm | ||
// against the operator implementation. | ||
int dim; | ||
// Checks these parameters against the shapes of the two inputs; only declared | ||
// here, so the actual criteria are in the implementation file. | ||
bool is_valid( | ||
std::pair<ParallelTensorShape, ParallelTensorShape> const &input) const; | ||
}; | ||
|
||
// Equality comparison for ResidualRMSNormParams; only declared here. Paired | ||
// with the std::hash specialization further down in this header. | ||
bool operator==(ResidualRMSNormParams const &, ResidualRMSNormParams const &); | ||
|
||
} // namespace FlexFlow | ||
|
||
// Hash support so ResidualRMSNormParams can serve as a key in unordered | ||
// containers. The call operator is only declared here; its definition lives | ||
// outside this header. | ||
namespace std { | ||
template <> | ||
struct hash<FlexFlow::ResidualRMSNormParams> { | ||
size_t operator()(FlexFlow::ResidualRMSNormParams const &) const; | ||
}; | ||
} // namespace std | ||
|
||
#endif // _FLEXFLOW_RESIDUAL_RMSNORM_PARAMS_H |
Oops, something went wrong.