Commit
fix tp
xinhaoc committed Feb 22, 2024
1 parent d958805 commit 38dfd87
Showing 21 changed files with 484 additions and 108 deletions.
3 changes: 3 additions & 0 deletions include/flexflow/config.h
@@ -124,6 +124,7 @@ class FFConfig {
size_t workSpaceSize;
Legion::Context lg_ctx;
Legion::Runtime *lg_hlr;
Legion::IndexSpaceT<1> all_gpu_task_is;
Legion::FieldSpace field_space;
bool syntheticInput, profiling, perform_fusion;
size_t simulator_work_space_size;
@@ -137,6 +138,8 @@ class FFConfig {
bool enable_parameter_parallel;
bool enable_attribute_parallel;
bool enable_inplace_optimizations;
int data_parallelism_degree;
int tensor_parallelism_degree;
// Control Tensor Op Math Conversion
bool allow_tensor_op_math_conversion;
std::string dataset_path;
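For context on the two new FFConfig fields: in a tensor-parallel run the degrees typically factor the worker grid, so a configuration is only consistent when they multiply out to the number of devices. A minimal sketch of that invariant, assuming a total GPU count alongside the fields added above (the struct and helper below are illustrative, not FlexFlow API):

    // Sketch only: field names mirror the FFConfig hunk above;
    // total_gpus and validate_degrees are assumed for illustration.
    #include <cassert>

    struct DegreeConfig {
      int total_gpus;
      int data_parallelism_degree;
      int tensor_parallelism_degree;
    };

    inline void validate_degrees(DegreeConfig const &c) {
      assert(c.data_parallelism_degree >= 1);
      assert(c.tensor_parallelism_degree >= 1);
      // Each GPU holds exactly one (data, tensor) shard coordinate.
      assert(c.total_gpus ==
             c.data_parallelism_degree * c.tensor_parallelism_degree);
    }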
6 changes: 6 additions & 0 deletions include/flexflow/model.h
@@ -278,6 +278,7 @@ class Split;
class TopK;
class Transpose;
class Combine;
class AllReduce;
class Repartition;
class Reduction;
class Replicate;
@@ -834,6 +835,8 @@ class FFModel {
Legion::IndexSpace get_task_is(Legion::Domain const &domain) const;
Legion::IndexSpace get_task_is(ParallelConfig const &pc) const;
Legion::IndexSpace get_task_is(MachineView const &view) const;
bool is_transformer_block(int layer_idx) const;
bool is_mlp_block(int layer_idx) const;
void create_operators_from_layers();
Op *create_operator_from_layer(Layer *layer,
std::vector<ParallelTensor> const &inputs);
@@ -860,6 +863,7 @@
int metrics_input;
ParallelTensor parallel_label_tensor;
Tensor label_tensor;
int num_inputs = 0;

std::vector<Layer *> layers;
std::vector<Op *> operators;
@@ -929,6 +933,8 @@
Replicate *>,
std::unordered_map<std::pair<ParallelTensorShape, ReductionParams>,
Reduction *>,
std::unordered_map<std::pair<ParallelTensorShape, AllReduceParams>,
AllReduce *>,
std::unordered_map<std::pair<ParallelTensorShape, CombineParams>,
Combine *>,
std::unordered_map<std::pair<ParallelTensorShape, FusedParallelOpParams>,
5 changes: 5 additions & 0 deletions include/flexflow/ops/element_binary.h
@@ -53,6 +53,11 @@ class ElementBinary : public Op {
bool measure_operator_cost(Simulator *sim,
MachineView const &pc,
CostMetrics &cost_metrics) const override;
void serialize(Legion::Serializer &) const override;
static PCG::Node deserialize(FFModel &ff,
Legion::Deserializer &d,
ParallelTensor inputs[],
int num_inputs);
Params get_params() const;

public:
1 change: 1 addition & 0 deletions include/flexflow/ops/element_binary_params.h
@@ -8,6 +8,7 @@ namespace FlexFlow {

struct ElementBinaryParams {
OperatorType type;
bool inplace_a;

bool is_valid(
std::pair<ParallelTensorShape, ParallelTensorShape> const &) const;
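A note on the new inplace_a field: FFModel deduplicates operators through maps keyed by (ParallelTensorShape, Params), as the model.h hunk above shows, so the params type's equality and hash must cover the field or inplace and out-of-place nodes would collide in the cache. A sketch of what that entails (the actual FlexFlow overloads are outside the rendered diff, so treat the details as assumptions):

    // Sketch: equality and hashing must include inplace_a so the
    // cached-operator map distinguishes the two variants.
    #include <cstddef>
    #include <functional>

    namespace FlexFlow {
    bool operator==(ElementBinaryParams const &a,
                    ElementBinaryParams const &b) {
      return a.type == b.type && a.inplace_a == b.inplace_a;
    }
    } // namespace FlexFlow

    namespace std {
    template <>
    struct hash<FlexFlow::ElementBinaryParams> {
      size_t operator()(FlexFlow::ElementBinaryParams const &p) const {
        // boost-style combine of the two fields
        size_t seed = hash<int>{}(static_cast<int>(p.type));
        seed ^= hash<bool>{}(p.inplace_a) + 0x9e3779b9 + (seed << 6) +
                (seed >> 2);
        return seed;
      }
    };
    } // namespace std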
3 changes: 3 additions & 0 deletions include/flexflow/utils/cuda_helper.h
@@ -82,6 +82,9 @@ __global__ void assign_kernel(DT *ptr, Legion::coord_t size, DT value);
template <typename DT>
__global__ void copy_kernel(DT *dst, const DT *src, Legion::coord_t size);

template <typename DT>
__global__ void copy_kernel_with_replicate(DT *dst, const DT *src, Legion::coord_t origin_size, Legion::coord_t size);

template <typename T>
__global__ void add_kernel(T *data_ptr, T const *grad_ptr, size_t size);

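The header only declares copy_kernel_with_replicate; the body is outside the rendered diff. Given the call site in dataloader.cu below, which passes volume / replicate_num as origin_size and the full volume as size, a plausible implementation repeats the source across the replica dimension. A sketch, not necessarily the committed kernel:

    // Sketch: every output index reads the source modulo the
    // pre-replication extent, broadcasting the batch to all replicas.
    template <typename DT>
    __global__ void copy_kernel_with_replicate(DT *dst,
                                               DT const *src,
                                               Legion::coord_t origin_size,
                                               Legion::coord_t size) {
      Legion::coord_t i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < size) {
        dst[i] = src[i % origin_size];
      }
    }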
6 changes: 3 additions & 3 deletions python/flexflow/core/flexflow_cffi.py
@@ -2079,11 +2079,11 @@ def load_bert_pretrained(self, checkpoint=None):
layer = self._layers[i]
if (layer.name + "_weight") in weights_dict:
print('weight: ' + layer.name)
weight = layer.get_parameter_by_id(0);
weight = layer.get_parameter_by_id(0)
weight.set_tensor(self, weights_dict[layer.name + "_weight"])
if (layer.name + "_bias") in weights_dict:
print('bias: ' + layer.name)
bias = layer.get_parameter_by_id(1);
bias = layer.get_parameter_by_id(1)
bias.set_tensor(self, weights_dict[layer.name + "_bias"])
def fit(self, x=None, y=None, batch_size=None, epochs=1):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
@@ -2126,7 +2126,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1):
self.forward()
# self.zero_gradients()
self.backward()
self.unified_update()
self.update()
self._ffconfig.end_trace(self._tracing_id)

def eval(self, x=None, y=None, batch_size=None):
2 changes: 1 addition & 1 deletion src/dataloader/dataloader.cc
@@ -97,7 +97,7 @@ SingleDataLoader::SingleDataLoader(FFModel &ff,
datatype = datatype_;
// Currently assume that the leading dim of input is a replica dim of degree 1
assert(input->dims[input->num_dims - 1].is_replica_dim);
assert(input->dims[input->num_dims - 1].size == 1);
// assert(input->dims[input->num_dims - 1].size == 1);

batch_input = input;
ParallelDim dims[MAX_TENSOR_DIM];
15 changes: 11 additions & 4 deletions src/dataloader/dataloader.cu
@@ -40,10 +40,13 @@ void SingleDataLoader::load_input(Task const *task,
int num_dims = full_input_domain.get_dim();
assert(num_dims + 1 == batch_input_domain.get_dim());
// assert the leading replica dim has a degree of one
assert(batch_input_domain.hi()[num_dims] ==
batch_input_domain.lo()[num_dims]);
// assert(batch_input_domain.hi()[num_dims] ==
// batch_input_domain.lo()[num_dims]);
coord_t batch_size = batch_input_domain.hi()[num_dims - 1] -
batch_input_domain.lo()[num_dims - 1] + 1;

coord_t replicate_num =
batch_input_domain.hi()[num_dims] - batch_input_domain.lo()[num_dims] + 1;
coord_t num_elements_per_batch = batch_input_domain.get_volume() / batch_size;
// FIXME: currently assume continous indices
assert(batch_size == meta->num_samples);
@@ -60,11 +63,15 @@
// printf("ptr(%p, %p), idx0 %d nb_elements_per_batch %d, batch_size %d,
// %d\n", acc_full_input.ptr, input_zc, start_idx, num_elements_per_batch,
// batch_size, start_idx * num_elements_per_batch);
copy_kernel<DT>
assert(batch_input_domain.get_volume() % replicate_num == 0);
copy_kernel_with_replicate<DT>
<<<GET_BLOCKS(batch_input_domain.get_volume()),
CUDA_NUM_THREADS,
0,
stream>>>(batch_input_ptr, input_zc, batch_input_domain.get_volume());
stream>>>(batch_input_ptr,
input_zc,
batch_input_domain.get_volume() / replicate_num,
batch_input_domain.get_volume());
checkCUDA(cudaDeviceSynchronize());
}

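To put hypothetical numbers on the new call: with a batch_input volume of 1,024 elements and replicate_num = 2, the kernel receives origin_size = 512 and size = 1024, so dst[0..511] and dst[512..1023] would each hold a copy of src[0..511], one copy per tensor-parallel replica. The added assert guarantees the volume divides evenly by replicate_num.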
22 changes: 22 additions & 0 deletions src/ops/element_binary.cc
@@ -802,10 +802,32 @@ bool ElementBinary::measure_operator_cost(Simulator *sim,

return true;
}
void ElementBinary::serialize(Legion::Serializer &sez) const {
sez.serialize(this->op_type);
sez.serialize(this->inplace_a);
}

using PCG::Node;
/*static*/
Node ElementBinary::deserialize(FFModel &ff,
Legion::Deserializer &dez,
ParallelTensor inputs[],
int num_inputs) {
assert(num_inputs == 2);
OperatorType op_type;
bool inplace_a;
dez.deserialize(op_type);
dez.deserialize(inplace_a);
ElementBinaryParams params;
params.type = op_type;
params.inplace_a = inplace_a;
return ff.get_or_create_node<ElementBinary>({inputs[0], inputs[1]}, params);
}

ElementBinaryParams ElementBinary::get_params() const {
ElementBinaryParams params;
params.type = this->op_type;
params.inplace_a = this->inplace_a;
return params;
}

12 changes: 8 additions & 4 deletions src/ops/embedding.cc
@@ -148,23 +148,27 @@ int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) {
int const OUT_CHANNELS = Output::OUT_CHANNELS;
if (aggr == AGGR_MODE_NONE) {
int num_dims = input->num_dims + 1;
for (int i = 1; i < num_dims; i++) {
for (int i = 1; i < num_dims - 1; i++) {
output_dims[i] = input->dims[i - 1];
}
assert(OUT_CHANNELS == 0);
output_dims[OUT_CHANNELS].size = this->out_channels;
output_dims[OUT_CHANNELS].degree = 1;
output_dims[OUT_CHANNELS].parallel_idx = -1;
// Copy replica dim
output_dims[num_dims - 1] = input->dims[input->num_dims - 1];
return num_dims;
} else {
int num_dims = input->num_dims;
for (int i = 1; i < num_dims; i++) {
for (int i = 1; i < num_dims - 1; i++) {
output_dims[i] = input->dims[i];
}
assert(OUT_CHANNELS == 0);
output_dims[OUT_CHANNELS].size = this->out_channels;
output_dims[OUT_CHANNELS].degree = 1;
output_dims[OUT_CHANNELS].parallel_idx = -1;
// Copy replica dim
output_dims[num_dims - 1] = input->dims[input->num_dims - 1];
return num_dims;
}
// const int REPLICA = this->output_vocab_size_replica_dim();
@@ -179,13 +183,13 @@ int Embedding::weight_size(ParallelDim weight_dims[MAX_TENSOR_DIM]) {
weight_dims[Weight::VOCAB_SIZE].size = this->num_entries;
weight_dims[Weight::VOCAB_SIZE].degree = 1;
weight_dims[Weight::VOCAB_SIZE].parallel_idx = -1;
for (int i = 2; i < input->num_dims; i++) {
for (int i = 2; i < input->num_dims + 1; i++) {
weight_dims[i].size = input->dims[i - 1].degree;
weight_dims[i].degree = weight_dims[i].size;
weight_dims[i].parallel_idx = input->dims[i - 1].parallel_idx;
weight_dims[i].is_replica_dim = true;
}
return input->num_dims;
return input->num_dims + 1;
}

void Embedding::register_output_mappings() {
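In output_size, the shortened loops plus the explicit "Copy replica dim" line assign effectively the same dims as before; the substantive change is in weight_size, which now emits input->num_dims + 1 dims, adding one more replica dim whose degree mirrors the corresponding input dim's degree. Under tensor parallelism this marks the embedding weight as replicated across the input's replica degree as well, rather than silently dropping that dimension.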
4 changes: 2 additions & 2 deletions src/ops/layer_norm.cc
@@ -226,10 +226,10 @@ LayerNorm::LayerNorm(FFModel &model,
dims[i] = inputs[0]->dims[i];
}
assert(numInputs == 1);
dims[num_dims].degree = inputs[0]->dims[inputs[0]->num_dims - 2].degree;
dims[num_dims].degree = inputs[0]->dims[inputs[0]->num_dims - 1].degree;
dims[num_dims].size = dims[num_dims].degree;
dims[num_dims].parallel_idx =
inputs[0]->dims[inputs[0]->num_dims - 2].parallel_idx;
inputs[0]->dims[inputs[0]->num_dims - 1].parallel_idx;
dims[num_dims].is_replica_dim = true;
num_dims += 1;

20 changes: 18 additions & 2 deletions src/ops/linear.cc
@@ -190,6 +190,23 @@ Linear::Linear(FFModel &model,
params.construct_mappings(*this->parallel_dims_mapping, input_shape);
params.solve_dims(input_shape, output_shape, kernel_shape, bias_shape);

kernel_shape.dims[0].size = this->in_channels;
bias_shape.dims[0].degree = _input->dims[_input->num_dims - 1].degree;
bias_shape.dims[0].parallel_idx =
_input->dims[_input->num_dims - 1].parallel_idx;
bias_shape.dims[1].size = bias_shape.dims[1].degree = 1;
bias_shape.dims[1].parallel_idx = -1;
bias_shape.dims[bias_shape.num_dims - 1].size =
bias_shape.dims[bias_shape.num_dims - 1].degree = 1;
for (int i = 0; i < input_shape.num_dims - 1; i++) {
if (_input->dims[i].degree > 1) {
bias_shape.dims[bias_shape.num_dims - 1].size *= _input->dims[i].degree;
bias_shape.dims[bias_shape.num_dims - 1].degree *= _input->dims[i].degree;
bias_shape.dims[bias_shape.num_dims - 1].parallel_idx =
_input->dims[i].parallel_idx;
}
}

if (allocate_weights) {
Initializer *kernel_initializer = new GlorotUniform(std::rand() /*seed*/);

@@ -220,7 +237,6 @@
outputs[0] = model.create_parallel_tensor_legion_ordering(
output_shape.num_dims, output_shape.dims, _data_type, this);

assert(check_output_input_weight_parallel_dims(allocate_weights));
}

void Linear::init(FFModel const &ff) {
@@ -433,7 +449,7 @@ void Linear::forward_task_with_dim(Task const *task,
int out_dim = acc_output.rect.hi[0] - acc_output.rect.lo[0] + 1;
int batch_size = acc_output.rect.volume() / out_dim;
assert(acc_output.rect.volume() == static_cast<size_t>(out_dim * batch_size));
assert(acc_input.rect.volume() == static_cast<size_t>(in_dim * batch_size));
// assert(acc_input.rect.volume() == static_cast<size_t>(in_dim * batch_size));
assert(acc_kernel.rect.volume() == static_cast<size_t>(in_dim * out_dim));
float const *acc_bias_ptr = NULL;
if (m->use_bias) {
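A hedged reading of the new shape block, since only this hunk is visible: the kernel's dims[0] is pinned to the full in_channels, the bias's dims[0] inherits the degree and parallel index of the input's trailing replica dim, dims[1] is forced to a degree-1 singleton, and the bias's last dim accumulates the degrees of every partitioned non-channel input dim, i.e. the bias is treated as replicated across those partitions. Dropping the check_output_input_weight_parallel_dims assert and commenting out the input-volume assert in forward_task_with_dim suggest the generic dim solver does not yet validate this hand-built tensor-parallel layout.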
82 changes: 75 additions & 7 deletions src/ops/reshape.cc
@@ -80,9 +80,14 @@ Op *Reshape::create_operator_from_layer(
return new Reshape(model, layer->layer_guid, inputs[0], shape, layer->name);
}

bool match_pattern(std::vector<int> const &_shape) {
return (_shape.size() == 4 && _shape[1] == 1 && _shape[2] == 1 &&
_shape[3] == 512);
}

Reshape::Reshape(FFModel &model,
LayerID const &_layer_guid,
const ParallelTensor input,
ParallelTensor const input,
std::vector<int> const &_shape,
char const *name)
: Op(model,
@@ -106,19 +111,64 @@ Reshape::Reshape(FFModel &model,
if (input->dims[i].is_replica_dim) {
num_replica_dims++;
}
// std::cout << "reshape input size: " << input->dims[i].size
// << ", parallelidx: " << input->dims[i].parallel_idx << ". degree: " << input->dims[i].degree
// << "is replicate dim: " << input->dims[i].is_replica_dim <<
// "\n";
}

// assert(false);
// assert that all replica dims are leading dims
for (int i = 0; i < num_replica_dims; i++) {
assert(input->dims[input->num_dims - 1 - i].is_replica_dim);
}
int numdim = (int)_shape.size();
ParallelDim dims[MAX_TENSOR_DIM];
for (int i = 0; i < numdim; i++) {
dims[i].size = _shape[numdim - 1 - i];
dims[i].degree = 1;
dims[i].parallel_idx = -1;
dims[i].is_replica_dim = false;
}


bool expanded = numdim >= input->num_dims;
bool aggregation = numdim < input->num_dims - 1;

for (int i = 0; i < numdim; i++) {
if (expanded && i < numdim - 1 &&
_shape[i] * _shape[i + 1] == input->dims[numdim - i - 2].size) {
dims[numdim - i - 1].size = _shape[i];
dims[numdim - i - 1].degree = input->dims[numdim - i - 2].degree;
dims[numdim - i - 1].parallel_idx =
input->dims[numdim - i - 2].parallel_idx;
dims[numdim - i - 1].is_replica_dim =
input->dims[numdim - i - 2].is_replica_dim;
std::cout << "expand dim i:" << i << ", " << dims[numdim - i - 1].degree
<< ", " << dims[numdim - i - 1].size << "\n";
} else if (aggregation &&
(_shape[i] == input->dims[input->num_dims - 2 - i].size *
input->dims[input->num_dims - 3 - i].size)) {
// inherit
dims[numdim - i - 1].size = _shape[i];
dims[numdim - i - 1].degree =
input->dims[input->num_dims - 2 - i].degree;
dims[numdim - i - 1].parallel_idx =
input->dims[input->num_dims - 2 - i].parallel_idx;
dims[numdim - i - 1].is_replica_dim =
input->dims[input->num_dims - 2 - i].is_replica_dim;
// std::cout << "agree i: " << i <<", " << _shape[i] << "\n";
} else {
dims[numdim - i - 1].size = _shape[i];
dims[numdim - i - 1].degree = 1;
dims[numdim - i - 1].parallel_idx = -1;
dims[numdim - i - 1].is_replica_dim = false;
}
}




// for (int i = 0; i < numdim; i++) {
// dims[i].size = _shape[numdim - 1 - i];
// dims[i].degree = 1;
// dims[i].parallel_idx = -1;
// dims[i].is_replica_dim = false;
// }
// copy all replica dims
for (int i = 0; i < num_replica_dims; i++) {
dims[i + numdim] = input->dims[input->num_dims - 1 - i];
Expand All @@ -131,6 +181,24 @@ Reshape::Reshape(FFModel &model,
}
dims[numdim - 1 - i] = input->dims[input->num_dims - 1 - i];
}

//TODO temporary fix for input to attention QK, fix it after fuse the attention block
if(match_pattern(_shape) && model.config.tensor_parallelism_degree > 1){
//number of heads

dims[2].size = 12;
dims[2].degree = model.config.tensor_parallelism_degree;
dims[2].parallel_idx = 0;
dims[2].is_replica_dim = true;

dims[4].size = 1;
dims[4].degree = 1;
dims[4].parallel_idx = -1;
dims[4].is_replica_dim = false;

}


outputs[0] = model.create_parallel_tensor_legion_ordering(
numdim, dims, input->data_type, this);
assert(outputs[0]->get_volume() == inputs[0]->get_volume());
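As the TODO concedes, match_pattern is a model-specific stopgap: it fires only when the target shape is (N, 1, 1, 512), then hard-codes 12 attention heads and sets that dim's degree to the tensor-parallelism degree. Any model whose attention input does not hit those literal extents bypasses the fix entirely, which is presumably why it is slated for removal once the attention block is fused.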
1 change: 1 addition & 0 deletions src/parallel_ops/allreduce.cc
@@ -173,6 +173,7 @@ void AllReduce::backward(FFModel const &ff) {
Runtime *runtime = ff.config.lg_hlr;
assert(numOutputs == 1);
assert(numInputs == 1);
set_argumentmap_for_backward(ff, argmap);
IndexLauncher launcher(ALLREDUCE_BWD_TASK_ID,
inputs[0]->parallel_is,
TaskArgument(NULL, 0),
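For context, hedged since the rest of the file is not rendered: FlexFlow operators normally call set_argumentmap_for_backward before building the backward IndexLauncher so that each point task receives its per-device OpMeta handle through argmap. AllReduce::backward was evidently launching with an unpopulated ArgumentMap; this one-line fix brings it in line with that pattern.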
…