Merge branch 'inference' into optimize_attn
jiazhihao authored Sep 23, 2023
2 parents 1f22002 + 322afa9 commit 6d1b5e9
Showing 42 changed files with 2,741 additions and 541 deletions.
2 changes: 2 additions & 0 deletions include/flexflow/ffconst.h
@@ -160,9 +160,11 @@ enum OperatorType {
OP_MEAN, // https://pytorch.org/docs/stable/generated/torch.mean.html
OP_LAYERNORM,
OP_ADD_BIAS_RESIDUAL_LAYERNORM,
OP_SIGMOID_SILU_MULTI,
OP_EXPERTS,
OP_GATHER, // https://pytorch.org/docs/stable/generated/torch.gather.html
OP_RMS_NORM,
OP_RESIDUAL_RMS_NORM,
OP_BEAM_TOPK,
OP_ARGMAX,
OP_INC_MULTIHEAD_SELF_ATTENTION,
14 changes: 14 additions & 0 deletions include/flexflow/flexflow_c.h
@@ -270,6 +270,12 @@ flexflow_tensor_t *flexflow_model_add_add_bias_residual_layer_norm(
bool use_bias,
char const *name);

flexflow_tensor_t
flexflow_model_add_sigmoid_silu_multi(flexflow_model_t handle,
const flexflow_tensor_t input1,
const flexflow_tensor_t input2,
char const *name);

flexflow_tensor_t
flexflow_model_add_batch_matmul(flexflow_model_t handle,
const flexflow_tensor_t a,
@@ -540,6 +546,14 @@ flexflow_tensor_t flexflow_model_add_rms_norm(flexflow_model_t handle_,
int dim,
char const *name);

flexflow_tensor_t *
flexflow_model_add_residual_rms_norm(flexflow_model_t handle_,
const flexflow_tensor_t input1_,
const flexflow_tensor_t input2_,
float eps,
int dim,
char const *name);

flexflow_tensor_t flexflow_model_add_arg_top_k(flexflow_model_t handle_,
const flexflow_tensor_t input_,
int k,
33 changes: 30 additions & 3 deletions include/flexflow/model.h
@@ -108,6 +108,8 @@ enum TaskIDs {
LAYERNORM_BWD_TASK_ID,
ADD_BIAS_RESIDUAL_LAYERNORM_INIT_TASK_ID,
ADD_BIAS_RESIDUAL_LAYERNORM_INF_TASK_ID,
SIGMOID_SILU_MULTI_INIT_TASK_ID,
SIGMOID_SILU_MULTI_INF_TASK_ID,
LINEAR_INIT_TASK_ID,
LINEAR_INIT_PARA_TASK_ID,
LINEAR_INF_TASK_ID,
@@ -153,9 +155,11 @@ enum TaskIDs {
ATTENTION_INIT_TASK_ID,
ATTENTION_FWD_TASK_ID,
ATTENTION_BWD_TASK_ID,
RMSNROM_INIT_TASK_ID,
RMSNROM_FWD_TASK_ID,
RMSNROM_INF_TASK_ID,
RMSNORM_INIT_TASK_ID,
RMSNORM_FWD_TASK_ID,
RMSNORM_INF_TASK_ID,
RESIDUAL_RMSNORM_INIT_TASK_ID,
RESIDUAL_RMSNORM_INF_TASK_ID,
BEAM_TOPK_INIT_TASK_ID,
BEAM_TOPK_INF_TASK_ID,
INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID,
@@ -312,6 +316,7 @@ class Gather;
class Group_by;
class LayerNorm;
class AddBiasResidualLayerNorm;
class SigmoidSiluMulti;
class Linear;
class MultiHeadAttention;
class IncMultiHeadSelfAttention;
@@ -325,6 +330,7 @@ class TopK;
class ArgTopK;
class Transpose;
class RMSNorm;
class ResidualRMSNorm;
class BeamTopK;
class SpecIncMultiHeadSelfAttention;
class Sampling;
@@ -545,6 +551,11 @@ class FFModel {
bool use_bias = true,
DataType data_type = DT_NONE,
char const *name = NULL);
// Add a sigmoid_silu_multi layer
Tensor sigmoid_silu_multi(const Tensor input1,
const Tensor input2,
DataType data_type = DT_NONE,
char const *name = NULL);
// Add a batch_norm layer
Tensor
batch_norm(const Tensor input, bool relu = true, char const *name = NULL);
@@ -560,6 +571,14 @@ class FFModel {
int dim,
DataType data_type = DT_NONE,
char const *name = NULL);
// Add a residual root mean square layer
void residual_rms_norm(const Tensor input1,
const Tensor input2,
Tensor *outputs,
float eps,
int dim,
DataType data_type = DT_NONE,
char const *name = NULL);
// Add a beam search top k layer
Tensor beam_top_k(const Tensor input,
int max_beam_size,
@@ -1132,6 +1151,10 @@ class FFModel {
std::pair<std::pair<ParallelTensorShape, ParallelTensorShape>,
AddBiasResidualLayerNormParams>,
AddBiasResidualLayerNorm *>,
std::unordered_map<
std::pair<std::pair<ParallelTensorShape, ParallelTensorShape>,
SigmoidSiluMultiParams>,
SigmoidSiluMulti *>,
std::unordered_map<std::pair<ParallelTensorShape, LinearParams>,
Linear *>,
std::unordered_map<std::pair<ParallelTensorShape, Pool2DParams>,
@@ -1170,6 +1193,10 @@ class FFModel {
Transpose *>,
std::unordered_map<std::pair<ParallelTensorShape, RMSNormParams>,
RMSNorm *>,
std::unordered_map<
std::pair<std::pair<ParallelTensorShape, ParallelTensorShape>,
ResidualRMSNormParams>,
ResidualRMSNorm *>,
std::unordered_map<std::pair<ParallelTensorShape, RepartitionParams>,
Repartition *>,
std::unordered_map<std::pair<ParallelTensorShape, ReplicateParams>,
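
The two FFModel methods added above, residual_rms_norm and sigmoid_silu_multi, are only declared in this hunk. Below is a minimal usage sketch, not part of this commit: it assumes the outputs array of residual_rms_norm holds the residual sum and the normalized tensor (in that order), and that sigmoid_silu_multi computes an elementwise SiLU(input1) * input2 gate; all variable and function names in the sketch are hypothetical placeholders.

#include "flexflow/model.h"
using namespace FlexFlow;

// Hypothetical helper, not part of this commit.
void build_gated_block(FFModel &ff,
                       Tensor attn_out,   // e.g. attention output
                       Tensor skip,       // residual stream
                       Tensor gate_proj,  // e.g. w1(x)
                       Tensor up_proj,    // e.g. w3(x)
                       int hidden_dim) {
  // Fused residual add + RMS normalization; the ordering of the two
  // outputs is an assumption, not stated in the header above.
  Tensor rms_outputs[2];
  ff.residual_rms_norm(attn_out, skip, rms_outputs, /*eps=*/1e-6f, hidden_dim);

  // Assumed semantics: SiLU(gate_proj) multiplied elementwise with up_proj.
  Tensor gated = ff.sigmoid_silu_multi(gate_proj, up_proj);
  (void)gated;
}
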
4 changes: 4 additions & 0 deletions include/flexflow/operator_params.h
@@ -26,8 +26,10 @@
#include "flexflow/ops/pool_2d_params.h"
#include "flexflow/ops/reduce_params.h"
#include "flexflow/ops/reshape_params.h"
#include "flexflow/ops/residual_rms_norm_params.h"
#include "flexflow/ops/rms_norm_params.h"
#include "flexflow/ops/sampling_params.h"
#include "flexflow/ops/sigmoid_silu_multi_params.h"
#include "flexflow/ops/softmax_params.h"
#include "flexflow/ops/spec_inc_multihead_self_attention_params.h"
#include "flexflow/ops/split_params.h"
@@ -61,13 +63,15 @@ using OperatorParameters = mp::variant<AggregateParams,
Group_byParams,
LayerNormParams,
AddBiasResidualLayerNormParams,
SigmoidSiluMultiParams,
LinearParams,
MultiHeadAttentionParams,
IncMultiHeadSelfAttentionParams,
BeamTopKParams,
SpecIncMultiHeadSelfAttentionParams,
TreeIncMultiHeadSelfAttentionParams,
RMSNormParams,
ResidualRMSNormParams,
Pool2DParams,
ReduceParams,
ReshapeParams,
56 changes: 56 additions & 0 deletions include/flexflow/ops/kernels/residual_rms_norm_kernels.h
@@ -0,0 +1,56 @@
#ifndef _FLEXFLOW_OPS_KERNELS_RESIDUAL_RMSNORM_KERNELS_H
#define _FLEXFLOW_OPS_KERNELS_RESIDUAL_RMSNORM_KERNELS_H

#include "flexflow/accessor.h"
#include "flexflow/device.h"
#include "flexflow/fftype.h"
#include "flexflow/op_meta.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {
using Legion::coord_t;

class ResidualRMSNorm;

class ResidualRMSNormMeta : public OpMeta {
public:
ResidualRMSNormMeta(FFHandler handler,
ResidualRMSNorm const *rms,
MemoryAllocator &gpu_mem_allocator);
~ResidualRMSNormMeta(void);
#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
cudnnTensorDescriptor_t inputTensor, outputTensor;
cudnnReduceTensorDescriptor_t reduceDesc;
#else
miopenTensorDescriptor_t inputTensor, outputTensor;
miopenReduceTensorDescriptor_t reduceDesc;
#endif

public:
float eps;
void *rms_ptr;
void *norm_ptr;

float alpha;
float beta;

int in_dim;
int batch_size;
int num_elements;
char op_name[MAX_OPNAME];
Realm::RegionInstance reserveInst;
};

namespace Kernels {
namespace ResidualRMSNorm {
void forward_kernel_wrapper(ResidualRMSNormMeta const *m,
GenericTensorAccessorR const &input1,
GenericTensorAccessorR const &input2,
GenericTensorAccessorR const &weight,
GenericTensorAccessorW const &residual_output,
GenericTensorAccessorW const &output);
} // namespace ResidualRMSNorm
} // namespace Kernels
} // namespace FlexFlow

#endif // _FLEXFLOW_OPS_KERNELS_RESIDUAL_RMSNORM_KERNELS_H
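
The forward_kernel_wrapper above writes both a residual_output and an output, but the header does not spell out the math. A sketch of the standard residual RMS-norm computation it presumably implements (an assumption, not stated in this file), with r the residual_output, y the output, w the weight vector, d the hidden size (in_dim), and ε the eps stored in the meta struct:

r_i = x^{(1)}_i + x^{(2)}_i, \qquad
\mathrm{RMS}(r) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} r_i^2 + \varepsilon}, \qquad
y_i = w_i \, \frac{r_i}{\mathrm{RMS}(r)}
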
88 changes: 88 additions & 0 deletions include/flexflow/ops/residual_rms_norm.h
@@ -0,0 +1,88 @@
#ifndef _FLEXFLOW_RESIDUAL_RMS_NORM_H
#define _FLEXFLOW_RESIDUAL_RMS_NORM_H

#include "flexflow/inference.h"
#include "flexflow/model.h"
#include "flexflow/ops/residual_rms_norm_params.h"
#include "flexflow/utils/memory_allocator.h"

namespace FlexFlow {

class ResidualRMSNormMeta;

class ResidualRMSNorm : public Op {
public:
using Params = ResidualRMSNormParams;
using Input = std::pair<ParallelTensor, ParallelTensor>;
ResidualRMSNorm(FFModel &model,
LayerID const &_layer_guid,
const ParallelTensor _input1,
const ParallelTensor _input2,
float _eps,
int dim,
bool allocate_weights,
char const *name);
ResidualRMSNorm(FFModel &model,
ResidualRMSNormParams const &params,
Input const &inputs,
bool allocate_weights,
char const *name = nullptr);

ResidualRMSNorm(FFModel &model,
ResidualRMSNorm const &other,
Input const &inputs,
bool allocate_weights);
void init(FFModel const &) override;
void forward(FFModel const &) override;
void backward(FFModel const &) override;
void init_inference(FFModel const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
Legion::FutureMap inference(FFModel const &,
BatchConfigFuture const &,
std::vector<ParallelTensor> const &,
std::vector<ParallelTensor> const &,
MachineView const *mv = nullptr) override;
void print_layer(FFModel const &model) override {
assert(0);
}

static Op *
create_operator_from_layer(FFModel &model,
Layer const *layer,
std::vector<ParallelTensor> const &inputs);
void serialize(Legion::Serializer &) const override;
static PCG::Node deserialize(FFModel &ff,
Legion::Deserializer &d,
ParallelTensor inputs[],
int num_inputs);
Op *materialize(FFModel &ff,
ParallelTensor inputs[],
int num_inputs) const override;
ResidualRMSNormParams get_params() const;

static OpMeta *init_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void forward_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void inference_task(Legion::Task const *task,
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
bool measure_operator_cost(Simulator *sim,
MachineView const &pc,
CostMetrics &cost_metrics) const override;

public:
float eps;
char op_name[MAX_OPNAME];
int effective_batch_size;
int dim, data_dim;
};
} // namespace FlexFlow
#endif // _FLEXFLOW_RESIDUAL_RMS_NORM_H
29 changes: 29 additions & 0 deletions include/flexflow/ops/residual_rms_norm_params.h
@@ -0,0 +1,29 @@
#ifndef _FLEXFLOW_RESIDUAL_RMSNORM_PARAMS_H
#define _FLEXFLOW_RESIDUAL_RMSNORM_PARAMS_H

#include "flexflow/ffconst.h"
#include "flexflow/fftype.h"
#include "flexflow/parallel_tensor.h"

namespace FlexFlow {

struct ResidualRMSNormParams {
LayerID layer_guid;
float eps;
int dim;
bool is_valid(
std::pair<ParallelTensorShape, ParallelTensorShape> const &input) const;
};

bool operator==(ResidualRMSNormParams const &, ResidualRMSNormParams const &);

} // namespace FlexFlow

namespace std {
template <>
struct hash<FlexFlow::ResidualRMSNormParams> {
size_t operator()(FlexFlow::ResidualRMSNormParams const &) const;
};
} // namespace std

#endif // _FLEXFLOW_RESIDUAL_RMSNORM_PARAMS_H
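
The operator== and std::hash specialization above are only declared in this header; their definitions live elsewhere in the commit and are not shown on this page. A plausible implementation pattern is sketched below, assuming a hash_combine helper and a layer_guid.id field as used in comparable FlexFlow params types; this is a sketch, not the actual code from this commit.

namespace FlexFlow {
// Sketch only: compare the three fields declared in ResidualRMSNormParams.
bool operator==(ResidualRMSNormParams const &lhs,
                ResidualRMSNormParams const &rhs) {
  return lhs.layer_guid == rhs.layer_guid && lhs.eps == rhs.eps &&
         lhs.dim == rhs.dim;
}
} // namespace FlexFlow

namespace std {
size_t hash<FlexFlow::ResidualRMSNormParams>::operator()(
    FlexFlow::ResidualRMSNormParams const &params) const {
  size_t key = 0;
  hash_combine(key, params.layer_guid.id); // hash_combine and layer_guid.id are assumed helpers
  hash_combine(key, params.eps);
  hash_combine(key, params.dim);
  return key;
}
} // namespace std
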