Commit

fix
goliaro committed Nov 7, 2024
1 parent f542fbb commit 139b643
Showing 11 changed files with 120 additions and 502 deletions.
2 changes: 2 additions & 0 deletions include/flexflow/ops/kernels/lora_linear_kernels.h
@@ -52,6 +52,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
int shard_id,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output_grad);

@@ -71,6 +72,7 @@ void peft_bwd_kernel(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
int shard_id,
DT *input_grad_ptr,
DT const *output_grad_ptr,
int in_dim,
68 changes: 0 additions & 68 deletions include/flexflow/utils/peft_weight_allocator.h
@@ -23,74 +23,6 @@

namespace FlexFlow {

#ifdef DEADCODE
class PEFTWeightAllocator {
public:
PEFTWeightAllocator(void *_base_ptr, size_t _total_size)
: base_ptr(_base_ptr), total_size(_total_size), sync_offset(0),
local_offset(_total_size) {}

inline void *allocate_sync_weights_untyped(PEFTModelID const &peft_model_id,
size_t datalen) {
const std::lock_guard<std::mutex> lock(peft_weight_allocator_mutex);
void *ptr = static_cast<char *>(base_ptr) + sync_offset;
off_t model_sync_weights_offset = sync_offset;
size_t model_sync_weights_size = datalen;
if (sync_weights.find(peft_model_id) != sync_weights.end()) {
// Assert that sync weights for each PEFT model is consecutive
std::pair<off_t, size_t> offset_and_size = sync_weights[peft_model_id];
assert(sync_offset == offset_and_size.first + offset_and_size.second);
model_sync_weights_offset = offset_and_size.first;
model_sync_weights_size = offset_and_size.second + datalen;
}
sync_offset += datalen;
assert(sync_offset < local_offset);
sync_weights[peft_model_id] =
std::make_pair(model_sync_weights_offset, model_sync_weights_size);
return ptr;
}

std::pair<void *, size_t>
get_sync_weights_ptr_and_size(PEFTModelID const &peft_model_id) {
const std::lock_guard<std::mutex> lock(peft_weight_allocator_mutex);
assert(sync_weights.find(peft_model_id) != sync_weights.end());
std::pair<off_t, size_t> offset_and_size = sync_weights[peft_model_id];
return std::make_pair(static_cast<char *>(base_ptr) + offset_and_size.first,
offset_and_size.second);
}

inline void *allocate_local_weights_untyped(PEFTModelID const &peft_model_id,
size_t datalen) {
const std::lock_guard<std::mutex> lock(peft_weight_allocator_mutex);
local_offset -= datalen;
assert(sync_offset < local_offset);
void *ptr = static_cast<char *>(base_ptr) + local_offset;
return ptr;
}

template <typename DT>
inline DT *allocate_sync_weights(PEFTModelID const &peft_model_id,
size_t count) {
return static_cast<DT *>(
allocate_sync_weights_untyped(peft_model_id, sizeof(DT) * count));
}

template <typename DT>
inline DT *allocate_local_weights(PEFTModelID const &peft_model_id,
size_t count) {
return static_cast<DT *>(
allocate_local_weights_untyped(peft_model_id, sizeof(DT) * count));
}

public:
void *base_ptr;
size_t total_size;
off_t sync_offset, local_offset;
std::unordered_map<PEFTModelID, std::pair<off_t, size_t>> sync_weights;
std::mutex peft_weight_allocator_mutex;
};
#endif

struct LoraLinearWeight {
// weights
void *w0_ptr, *w1_ptr;
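
[Aside, not part of this commit] The PEFTWeightAllocator deleted above packed all PEFT weights into one pre-allocated buffer from both ends: synchronized weights were bump-allocated upward from the base pointer, per-model local weights downward from the top, with an assertion that the two regions never cross. A minimal standalone sketch of that allocation scheme (simplified C++, not the FlexFlow API):

#include <cassert>
#include <cstddef>

// Two-ended bump allocator over one fixed buffer: "sync" allocations grow up
// from the base, "local" allocations grow down from the end, and the two
// regions must never cross -- the invariant the removed class asserted.
class TwoEndedArena {
public:
  TwoEndedArena(void *base, size_t total_size)
      : base_(static_cast<char *>(base)), sync_offset_(0),
        local_offset_(total_size) {}

  void *allocate_sync(size_t bytes) {
    void *ptr = base_ + sync_offset_;
    sync_offset_ += bytes;
    assert(sync_offset_ <= local_offset_ && "sync and local regions crossed");
    return ptr;
  }

  void *allocate_local(size_t bytes) {
    assert(bytes <= local_offset_ - sync_offset_ && "not enough space left");
    local_offset_ -= bytes;
    return base_ + local_offset_;
  }

private:
  char *base_;
  size_t sync_offset_, local_offset_;
};

int main() {
  static char buffer[1024];
  TwoEndedArena arena(buffer, sizeof(buffer));
  void *sync_w = arena.allocate_sync(256);   // e.g. weights shared across shards
  void *local_w = arena.allocate_local(128); // e.g. shard-local weights
  (void)sync_w;
  (void)local_w;
  return 0;
}
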
2 changes: 2 additions & 0 deletions src/ops/fused.cu
@@ -889,11 +889,13 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
// Assert that the output and the second input are at the same place
// since we ``inplace'' the output for LoRA
assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr);
int shard_id = task->index_point.point_data[0];
Kernels::LoraLinear::peft_bwd_kernel_wrapper(
ctx,
runtime,
m,
bc,
shard_id,
my_input_grad_accessor[0],
my_output_grad_accessor[0]);
break;
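
[Aside, not part of this commit] On the task->index_point.point_data[0] line above: in a Legion index launch every shard runs the same task body and reads its own coordinate from the launch domain, and that coordinate is what gets passed down as shard_id. A rough standalone illustration of the pattern (plain C++, hypothetical names, not the Legion API):

#include <cstdio>

// Stand-in for an index launch: the same task body runs once per point in the
// launch domain, and each instance sees only its own index.
void peft_bwd_task_body(int shard_id) {
  // In FlexFlow this is where the kernel wrapper receives shard_id, e.g. so
  // that per-shard debug tensors can be written to distinct files.
  std::printf("peft_bwd task running as shard %d\n", shard_id);
}

int main() {
  const int num_shards = 4; // illustrative tensor-parallel degree
  for (int shard_id = 0; shard_id < num_shards; ++shard_id) {
    peft_bwd_task_body(shard_id);
  }
  return 0;
}
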
192 changes: 35 additions & 157 deletions src/ops/kernels/lora_linear_kernels.cu
@@ -24,31 +24,34 @@ namespace FlexFlow {

LoraLinearMeta::LoraLinearMeta(FFHandler handler, LoraLinear const *li)
: OpMeta(handler, li) {
#ifdef DEADCODE
allocated_peft_buffer_size1 = 0;
allocated_peft_buffer_size2 = 0;
#endif
}

LoraLinearMeta::~LoraLinearMeta(void) {}

std::string get_peft_dbg_folder(LoraLinearMeta const *m,
int shard_id,
bool is_fwd) {
std::string op_name_without_uid = LoraLinear::get_op_name_without_uid(m);
fs::path dst_filepath;
if (is_fwd) {
dst_filepath = get_dst_folder("fwd", m->decoding_step, shard_id);
} else {
dst_filepath = get_dst_folder("bwd", m->bwd_step, shard_id);
}
if (m->layer_guid.model_id > 0) {
assert(false && "Model ID > 0 not supported yet");
}
std::string layername = "layers." +
std::to_string(m->layer_guid.transformer_layer_id) +
"." + op_name_without_uid;
dst_filepath /= layername;
return dst_filepath.string();
}

namespace Kernels {
namespace LoraLinear {

#ifdef DEADCODE
void init_kernel_wrapper(LoraLinearMeta *m, int seed) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));

if (m->input_type[0] == DT_FLOAT) {
Internal::init_kernel<float>(m, seed, stream);
} else if (m->input_type[0] == DT_HALF) {
Internal::init_kernel<half>(m, seed, stream);
} else {
assert(false && "Unsupported data type");
}
}
#endif

void inference_kernel_wrapper(LoraLinearMeta *m,
BatchConfig const *bc,
@@ -104,6 +107,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
int shard_id,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output_grad) {
cudaStream_t stream;
@@ -121,6 +125,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
runtime,
m,
bc,
shard_id,
input_grad.get_float_ptr(),
output_grad.get_float_ptr(),
in_dim,
@@ -131,6 +136,7 @@ void peft_bwd_kernel_wrapper(Context ctx,
runtime,
m,
bc,
shard_id,
input_grad.get_half_ptr(),
output_grad.get_half_ptr(),
in_dim,
@@ -168,146 +174,6 @@ bool lora_applies_to_this_layer(LoraLinearMeta *m,

namespace Internal {

#ifdef DEADCODE
template <typename DT>
void inference_kernel(LoraLinearMeta *m,
BatchConfig const *bc,
DT const *input_ptr,
DT *output_ptr,
int in_dim,
int out_dim,
ffStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
DT alpha = 1.0f, beta = 0.0f;
cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]);
cudaDataType_t output_type = ff_to_cuda_datatype(m->input_type[1]);
cudaDataType_t lr_actv_type = output_type;
assert(input_type == output_type);
cudaDataType_t weight_type = output_type;
cudaDataType_t compute_type = output_type;
// #if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
// cudaDataType_t compute_type = output_type;
// #else
// // For best performance, set the default cublas compute type to
// // CUBLAS_COMPUTE_16F for half precision and to
// // CUBLAS_COMPUTE_32F_FAST_16F for full precision
// cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
// if (m->input_type[0] == DT_FLOAT) {
// compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
// }
// #endif
int num_peft_requests = 0;
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
}
if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) {
continue;
}
if (bc->requestsInfo[i].peft_bwd) {
num_peft_requests++;
}
}
// Assert that we have at most one request that requires peft_bwd
assert(num_peft_requests <= 1);
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
}
// Skip non-PEFT requests
if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) {
continue;
}
int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
int max_peft_tokens = bc->requestsInfo[i].max_length;
int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch;
assert(m->model_state.find(bc->requestsInfo[i].peft_model_id) !=
m->model_state.end());
LoraLinearWeight weight =
m->model_state[bc->requestsInfo[i].peft_model_id].weights;
int rank = weight.rank;
void *intermediate_result_ptr = nullptr;
if (bc->requestsInfo[i].peft_bwd) {
size_t activation_size_needed1 =
data_type_size(m->input_type[0]) * max_peft_tokens * in_dim;
size_t activation_size_needed2 =
data_type_size(m->input_type[1]) * max_peft_tokens * rank;
MemoryAllocator *allocator = m->handle.peft_activation_allocator;
if (activation_size_needed1 > m->allocated_peft_buffer_size1) {
m->input_activation =
allocator->allocate_instance_untyped(activation_size_needed1);
m->allocated_peft_buffer_size1 = activation_size_needed1;
}
if (activation_size_needed2 > m->allocated_peft_buffer_size2) {
m->low_rank_activation =
allocator->allocate_instance_untyped(activation_size_needed2);
m->allocated_peft_buffer_size2 = activation_size_needed2;
}
// copy input activation
checkCUDA(cudaMemcpyAsync(m->input_activation,
input_ptr + first_token_offset * in_dim,
data_type_size(m->input_type[0]) *
num_peft_tokens * in_dim,
cudaMemcpyDeviceToDevice,
stream));
intermediate_result_ptr = m->low_rank_activation;
} else {
// use workspace to save intermediate result
assert(m->handle.workSpaceSize >=
data_type_size(m->input_type[1]) * num_peft_tokens * rank);
intermediate_result_ptr = m->handle.workSpace;
}
// buffer = weight_first * input
// [rank, num_peft_tokens] = [in_dim, rank].T * [in_dim, num_peft_tokens]
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
rank,
num_peft_tokens,
in_dim,
&alpha,
weight.w0_ptr,
weight_type,
in_dim,
input_ptr + first_token_offset * in_dim,
input_type,
in_dim,
&beta,
intermediate_result_ptr,
lr_actv_type,
rank,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// output = weight_second * buffer
// [out_dim, num_peft_tokens] = [rank, out_dim].T * [rank, num_peft_tokens]
// Note that we use alpha in both places since we do
// an in-place update for LoraLinear
float lora_alpha =
m->model_state[bc->requestsInfo[i].peft_model_id].lora_alpha;
DT scaling_constant = (DT)(lora_alpha / rank);
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_dim,
num_peft_tokens,
rank,
&scaling_constant,
weight.w1_ptr,
weight_type,
rank,
intermediate_result_ptr,
lr_actv_type,
rank,
&alpha,
output_ptr + first_token_offset * out_dim,
output_type,
out_dim,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
#endif

template <typename DT>
void inference_kernel(LoraLinearMeta *m,
@@ -342,6 +208,8 @@ void inference_kernel(LoraLinearMeta *m,
if (!lora_applies_to_this_layer(m, lora_config)) {
continue;
}
std::cout << "Lora layer activated!" << std::endl;
std::cout << "Lora Config: " << peft_model_config_str << std::endl;
assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd &&
"Trainable flag mismatch");
int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
@@ -443,6 +311,7 @@ void peft_bwd_kernel(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
int shard_id,
DT *input_grad_ptr,
DT const *output_grad_ptr,
int in_dim,
@@ -471,6 +340,8 @@ void peft_bwd_kernel(Context ctx,
if (!lora_applies_to_this_layer(m, lora_config)) {
continue;
}
std::cout << "Lora layer activated!" << std::endl;
std::cout << "Lora Config: " << peft_model_config_str << std::endl;
assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd &&
"Trainable flag mismatch");
m->peft_memory_manager->check_ft_model_id(
@@ -488,6 +359,13 @@
DT beta = (bc->requestsInfo[i].optimizer_tasks.reset_gradients_to_zero)
? 0.0f
: 1.0f;
std::cout << "Lora B gradient computation, beta = " << (float) beta << std::endl;
if (m->inference_debugging) {
// save result to file for checking
std::string filename = get_peft_dbg_folder(m, shard_id, false) + ".low_rank_activation";
std::cout << "Save low_rank_activation (" << lora_config.rank << ", " << num_peft_tokens << ") to " << filename << std::endl;
save_tensor(static_cast<const DT*>(weight.low_rank_activation), lora_config.rank*num_peft_tokens, filename.c_str());
}
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_T,
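
[Aside, not part of this commit] The GEMM pairs in this file implement the usual LoRA update: the base linear output is augmented in place with (lora_alpha / rank) * B(Ax), first computing the low-rank activation Ax and then accumulating the scaled B projection into the output. In the backward pass, the beta chosen from reset_gradients_to_zero exploits the cuBLAS epilogue C = alpha*op(A)*op(B) + beta*C, so beta = 0 overwrites the stored gradient and beta = 1 accumulates into it. A minimal host-side sketch of the forward math, with illustrative shapes and names (plain C++, not the FlexFlow kernels):

#include <vector>

// Reference LoRA forward: y += (lora_alpha / rank) * B(Ax), applied in place
// on top of the base linear output, mirroring the two GEMMs in the kernel
// (low-rank activation = A^T x, then y += scale * B^T * low-rank activation).
// Illustrative row-major shapes: x[tokens][in_dim], A[in_dim][rank],
// B[rank][out_dim], y[tokens][out_dim].
void lora_forward(std::vector<std::vector<float>> const &x,
                  std::vector<std::vector<float>> const &A,
                  std::vector<std::vector<float>> const &B,
                  std::vector<std::vector<float>> &y,
                  float lora_alpha) {
  int tokens = static_cast<int>(x.size());
  int in_dim = static_cast<int>(A.size());
  int rank = static_cast<int>(A[0].size());
  int out_dim = static_cast<int>(B[0].size());
  float scale = lora_alpha / rank;
  for (int t = 0; t < tokens; ++t) {
    std::vector<float> low_rank(rank, 0.0f); // what the kernel saves as low_rank_activation
    for (int r = 0; r < rank; ++r) {
      for (int i = 0; i < in_dim; ++i) {
        low_rank[r] += A[i][r] * x[t][i];
      }
    }
    for (int o = 0; o < out_dim; ++o) {
      for (int r = 0; r < rank; ++r) {
        y[t][o] += scale * B[r][o] * low_rank[r]; // in-place accumulation
      }
    }
  }
}

int main() {
  // Tiny illustrative sizes: 2 tokens, in_dim = 3, rank = 2, out_dim = 2.
  std::vector<std::vector<float>> x = {{1, 0, 2}, {0, 1, 1}};
  std::vector<std::vector<float>> A = {{1, 0}, {0, 1}, {1, 1}};
  std::vector<std::vector<float>> B = {{1, 2}, {3, 4}};
  std::vector<std::vector<float>> y(2, std::vector<float>(2, 0.0f)); // base output, zero here
  lora_forward(x, A, B, y, /*lora_alpha=*/4.0f);
  return 0;
}
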