From 40664678c87443ee0e20d39312a3a684cb1e6000 Mon Sep 17 00:00:00 2001
From: Gabriele Oliaro
Date: Sun, 21 Jul 2024 06:22:59 +0000
Subject: [PATCH] return/print loss at each finetuning step

---
 include/flexflow/batch_config.h |  1 +
 include/flexflow/ops/argmax.h   | 11 +++--
 src/ops/argmax.cc               | 19 ++++++--
 src/ops/argmax.cu               | 82 +++++++++++++++++++++++++++++++--
 4 files changed, 101 insertions(+), 12 deletions(-)

diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h
index 15806fcf66..873fed0bdb 100644
--- a/include/flexflow/batch_config.h
+++ b/include/flexflow/batch_config.h
@@ -162,6 +162,7 @@ class TreeVerifyBatchConfig : public BatchConfig {
 struct InferenceResult {
   static int const MAX_NUM_TOKENS = BatchConfig::MAX_NUM_TOKENS;
   BatchConfig::TokenId token_ids[MAX_NUM_TOKENS];
+  float finetuning_loss;
 };
 
 class BeamSearchBatchConfig : public BatchConfig {
diff --git a/include/flexflow/ops/argmax.h b/include/flexflow/ops/argmax.h
index 298059e3ed..eca9943d20 100644
--- a/include/flexflow/ops/argmax.h
+++ b/include/flexflow/ops/argmax.h
@@ -17,6 +17,7 @@ class ArgMaxMeta : public OpMeta {
   size_t temp_storage_bytes = 0;
   int *d_offsets;
   void *d_out;
+  float *d_loss;
   Realm::RegionInstance reserveInst;
   ArgMaxMeta(FFHandler handler,
              Op const *op,
@@ -89,18 +90,22 @@ class ArgMax : public Op {
                              CostMetrics &cost_metrics) const override;
   template <typename DT>
   static void forward_kernel(ArgMaxMeta const *m,
-                             DT *input_ptr,
+                             BatchConfig const *bc,
+                             DT const *input_ptr,
                              int *indices_ptr,
                              float *prob_ptr,
                              int *parent_ptr,
                              int length,
                              int batch_size,
+                             float *loss,
                              ffStream_t stream);
   static void forward_kernel_wrapper(ArgMaxMeta const *m,
-                                     GenericTensorAccessorW const &input,
+                                     BatchConfig const *bc,
+                                     GenericTensorAccessorR const &input,
                                      GenericTensorAccessorW const &indices,
                                      GenericTensorAccessorW const &parent,
-                                     int batch_size);
+                                     int batch_size,
+                                     float *loss);
   Params get_params() const;
 
 public:
diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc
index c2646b226c..6cc500654c 100644
--- a/src/ops/argmax.cc
+++ b/src/ops/argmax.cc
@@ -317,7 +317,7 @@ FutureMap ArgMax::inference(FFModel const &ff,
   launcher.add_future(bc);
   launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                     0 /*projection id*/,
-                                                    READ_WRITE,
+                                                    READ_ONLY,
                                                     EXCLUSIVE,
                                                     batch_inputs[0]->region));
   launcher.add_field(0, FID_DATA);
@@ -354,7 +354,9 @@ BeamInferenceResult
   int batch_size = bc->num_active_infr_tokens();
   GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO(
       DT_INT32, regions[2], task->regions[2], FID_DATA, ctx, runtime);
-  ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);
+  float loss = 0.0f;
+  ArgMax::forward_kernel_wrapper(
+      m, bc, input, indices, parent, batch_size, &loss);
   BeamInferenceResult ir;
   copy_tensor_dev_to_host<BatchConfig::TokenId>(
       indices.get_int32_ptr(), ir.token_ids, batch_size);
@@ -387,19 +389,26 @@ InferenceResult
     return ir;
   }
-  GenericTensorAccessorW input = helperGetGenericTensorAccessorRW(
+  GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
       m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
   GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO(
       DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime);
   GenericTensorAccessorW parent;
   int batch_size = bc->num_active_infr_tokens();
-  ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);
+  float loss = 0.0f;
+  ArgMax::forward_kernel_wrapper(
+      m, bc, input, indices, parent, batch_size, &loss);
+  if (bc->num_active_peft_tokens() > 0) {
printf("Epoch %i loss: %.4f\n", m->decoding_step, loss); + } InferenceResult ir; if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; ArgMax::save_inference_tensors_to_file( - m, shard_id, bc, {}, {}, {input, indices}); + m, shard_id, bc, {input}, {}, {indices}); + } else { + m->decoding_step++; } copy_tensor_dev_to_host( indices.get_int32_ptr(), ir.token_ids, batch_size); diff --git a/src/ops/argmax.cu b/src/ops/argmax.cu index 05c84719c1..78f432acee 100644 --- a/src/ops/argmax.cu +++ b/src/ops/argmax.cu @@ -44,17 +44,33 @@ __global__ void copy_result(cub::KeyValuePair *d_out, } } +template +__global__ void compute_sparse_categorical_crossentropy_loss( + DT const *logits, + BatchConfig::TokenId const *labels, + float *loss, + int num_tokens, + int num_classes) { + float const LOG_MIN_VALUE = 0.00000001f; + CUDA_KERNEL_LOOP(b, num_tokens) { + float my_logit = + max((float)logits[b * num_classes + labels[b]], LOG_MIN_VALUE); + atomicAdd(loss, -log(my_logit)); + } +} + /*static*/ template void ArgMax::forward_kernel(ArgMaxMeta const *m, - DT *input_ptr, + BatchConfig const *bc, + DT const *input_ptr, int *indices_ptr, float *prob_ptr, int *parent, int const length, int const batch_size, + float *loss, cudaStream_t stream) { - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); DT alpha = 1.0f, beta = 0.0f; if (m->beam_search) { @@ -84,14 +100,64 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m, batch_size, m->beam_search); // print_tensor(indices_ptr, 32, "argmax op"); + + // compute cross-entropy loss if there is a finetuning request + assert(loss != nullptr); + BatchConfig::TokenId token_ids[BatchConfig::MAX_NUM_TOKENS]; + int num_finetuning_requests = 0, num_bwd_tokens = 0; + int tokens_previous_requests = 0; + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i]) { + continue; + } + // Skip non-PEFT requests + if (bc->requestsInfo[i].peft_bwd) { + assert(num_finetuning_requests == 0 && num_bwd_tokens == 0); + num_bwd_tokens = bc->requestsInfo[i].num_tokens_in_batch - 1; + // shift labels by 1 position to the left (ignore first token label) + for (int j = 0; j < num_bwd_tokens; j++) { + token_ids[j] = + bc->tokensInfo[j + tokens_previous_requests + 1].token_id; + } + num_finetuning_requests += 1; + } else { + tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + } + } + assert(num_finetuning_requests <= 1); + if (num_bwd_tokens > 0) { + checkCUDA(cudaMemcpyAsync(m->handle.workSpace, + token_ids, + sizeof(BatchConfig::TokenId) * num_bwd_tokens, + cudaMemcpyHostToDevice, + stream)); + // copy loss to d_loss + checkCUDA(cudaMemsetAsync(m->d_loss, 0, sizeof(float), stream)); + compute_sparse_categorical_crossentropy_loss<<>>( + input_ptr, + static_cast(m->handle.workSpace), + m->d_loss, + num_bwd_tokens, + length); + // copy value from d_loss to loss + checkCUDA(cudaMemcpyAsync( + loss, m->d_loss, sizeof(float), cudaMemcpyDeviceToHost, stream)); + *loss = *loss / (float)num_bwd_tokens; + } } /*static*/ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, - GenericTensorAccessorW const &input, + BatchConfig const *bc, + GenericTensorAccessorR const &input, GenericTensorAccessorW const &indices, GenericTensorAccessorW const &parent, - int batch_size) { + int batch_size, + float *loss) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); cudaEvent_t t_start, t_end; @@ -104,6 +170,7 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m, if (input.data_type == 
     ArgMax::forward_kernel<half>(m,
+                                 bc,
                                  input.get_half_ptr(),
                                  indices.get_int32_ptr(),
                                  m->probs,
@@ -111,10 +178,12 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m,
                                                 : nullptr,
                                  length,
                                  batch_size,
+                                 loss,
                                  stream);
   } else if (input.data_type == DT_FLOAT) {
     ArgMax::forward_kernel<float>(m,
+                                  bc,
                                   input.get_float_ptr(),
                                   indices.get_int32_ptr(),
                                   m->probs,
@@ -122,6 +191,7 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m,
                                                  : nullptr,
                                   length,
                                   batch_size,
+                                  loss,
                                   stream);
   } else {
     assert(false && "Unsupported data type");
@@ -202,6 +272,10 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler,
   gpu_mem_allocator.create_legion_instance(reserveInst, temp_storage_bytes);
   d_temp_storage =
       gpu_mem_allocator.allocate_instance_untyped(temp_storage_bytes);
+
+  // allocate space for loss on device
+  gpu_mem_allocator.create_legion_instance(reserveInst, sizeof(float));
+  d_loss = gpu_mem_allocator.allocate_instance<float>(1);
 }
 
 ArgMaxMeta::~ArgMaxMeta(void) {