Commit 4066467: return/print loss at each finetuning step

goliaro committed Jul 21, 2024
1 parent 7dfb3d0 commit 4066467
Showing 4 changed files with 101 additions and 12 deletions.
1 change: 1 addition & 0 deletions include/flexflow/batch_config.h
@@ -162,6 +162,7 @@ class TreeVerifyBatchConfig : public BatchConfig {
struct InferenceResult {
static int const MAX_NUM_TOKENS = BatchConfig::MAX_NUM_TOKENS;
BatchConfig::TokenId token_ids[MAX_NUM_TOKENS];
float finetuning_loss;
};

class BeamSearchBatchConfig : public BatchConfig {
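With the new finetuning_loss field, the sampled token ids and the per-step finetuning loss travel back to the caller in the same InferenceResult. A rough sketch of how a caller might read it (the stand-in struct and the consumer function below are hypothetical, not part of this commit):

#include <cstdio>

// Hypothetical stand-in mirroring the struct above (the real definition lives
// in include/flexflow/batch_config.h).
struct InferenceResultLike {
  static int const MAX_NUM_TOKENS = 1024;
  int token_ids[MAX_NUM_TOKENS];
  float finetuning_loss;
};

// Hypothetical consumer: report the loss only for steps that processed
// finetuning (PEFT) tokens.
void log_finetuning_loss(InferenceResultLike const &ir, int step, bool has_peft_tokens) {
  if (has_peft_tokens) {
    std::printf("step %d finetuning loss: %.4f\n", step, ir.finetuning_loss);
  }
}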
11 changes: 8 additions & 3 deletions include/flexflow/ops/argmax.h
@@ -17,6 +17,7 @@ class ArgMaxMeta : public OpMeta {
size_t temp_storage_bytes = 0;
int *d_offsets;
void *d_out;
float *d_loss;
Realm::RegionInstance reserveInst;
ArgMaxMeta(FFHandler handler,
Op const *op,
@@ -89,18 +90,22 @@ class ArgMax : public Op {
CostMetrics &cost_metrics) const override;
template <typename DT>
static void forward_kernel(ArgMaxMeta const *m,
DT *input_ptr,
BatchConfig const *bc,
DT const *input_ptr,
int *indices_ptr,
float *prob_ptr,
int *parent_ptr,
int length,
int batch_size,
float *loss,
ffStream_t stream);
static void forward_kernel_wrapper(ArgMaxMeta const *m,
GenericTensorAccessorW const &input,
BatchConfig const *bc,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &indices,
GenericTensorAccessorW const &parent,
int batch_size);
int batch_size,
float *loss);
Params get_params() const;

public:
19 changes: 14 additions & 5 deletions src/ops/argmax.cc
@@ -317,7 +317,7 @@ FutureMap ArgMax::inference(FFModel const &ff,
launcher.add_future(bc);
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_WRITE,
READ_ONLY,
EXCLUSIVE,
batch_inputs[0]->region));
launcher.add_field(0, FID_DATA);
@@ -354,7 +354,9 @@ BeamInferenceResult
int batch_size = bc->num_active_infr_tokens();
GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO(
DT_INT32, regions[2], task->regions[2], FID_DATA, ctx, runtime);
ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);
float loss = 0.0f;
ArgMax::forward_kernel_wrapper(
m, bc, input, indices, parent, batch_size, &loss);
BeamInferenceResult ir;
copy_tensor_dev_to_host<BatchConfig::TokenId>(
indices.get_int32_ptr(), ir.token_ids, batch_size);
@@ -387,19 +389,26 @@ InferenceResult
return ir;
}

GenericTensorAccessorW input = helperGetGenericTensorAccessorRW(
GenericTensorAccessorR input = helperGetGenericTensorAccessorRO(
m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime);
GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO(
DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime);
GenericTensorAccessorW parent;
int batch_size = bc->num_active_infr_tokens();
ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size);
float loss = 0.0f;
ArgMax::forward_kernel_wrapper(
m, bc, input, indices, parent, batch_size, &loss);
if (bc->num_active_peft_tokens() > 0) {
printf("Epoch %i loss: %.4f\n", m->decoding_step, loss);
}
InferenceResult ir;
if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
ArgMax::save_inference_tensors_to_file(
m, shard_id, bc, {}, {}, {input, indices});
m, shard_id, bc, {input}, {}, {indices});
} else {
m->decoding_step++;
}
copy_tensor_dev_to_host<BatchConfig::TokenId>(
indices.get_int32_ptr(), ir.token_ids, batch_size);
82 changes: 78 additions & 4 deletions src/ops/argmax.cu
@@ -44,17 +44,33 @@ __global__ void copy_result(cub::KeyValuePair<int, DT> *d_out,
}
}

template <typename DT>
__global__ void compute_sparse_categorical_crossentropy_loss(
DT const *logits,
BatchConfig::TokenId const *labels,
float *loss,
int num_tokens,
int num_classes) {
float const LOG_MIN_VALUE = 0.00000001f;
CUDA_KERNEL_LOOP(b, num_tokens) {
float my_logit =
max((float)logits[b * num_classes + labels[b]], LOG_MIN_VALUE);
atomicAdd(loss, -log(my_logit));
}
}
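The kernel above computes a sparse categorical cross-entropy over the finetuning tokens: for each token it reads the value the model assigned to that token's label, clamps it to LOG_MIN_VALUE (1e-8), and atomically accumulates -log of it into a single float. A minimal CPU reference of the same accumulation (illustration only, not part of this commit; the division by the token count happens later on the host in forward_kernel):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference for the accumulation done by the kernel above (illustration
// only). probs is laid out as [num_tokens x num_classes]; labels holds one
// class index per token. The caller divides the returned sum by the number of
// tokens, mirroring the host-side averaging in forward_kernel.
float sparse_categorical_crossentropy_sum(std::vector<float> const &probs,
                                          std::vector<int> const &labels,
                                          int num_classes) {
  float const kLogMinValue = 1e-8f; // matches LOG_MIN_VALUE in the kernel
  float sum = 0.0f;
  for (std::size_t b = 0; b < labels.size(); ++b) {
    float p = std::max(probs[b * num_classes + labels[b]], kLogMinValue);
    sum += -std::log(p);
  }
  return sum;
}

The mean of these -log terms over the backward tokens is the per-step finetuning loss that the commit reports.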

/*static*/
template <typename DT>
void ArgMax::forward_kernel(ArgMaxMeta const *m,
DT *input_ptr,
BatchConfig const *bc,
DT const *input_ptr,
int *indices_ptr,
float *prob_ptr,
int *parent,
int const length,
int const batch_size,
float *loss,
cudaStream_t stream) {

checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
DT alpha = 1.0f, beta = 0.0f;
if (m->beam_search) {
@@ -84,14 +100,64 @@ void ArgMax::forward_kernel(ArgMaxMeta const *m,
batch_size,
m->beam_search);
// print_tensor<int>(indices_ptr, 32, "argmax op");
// compute cross-entropy loss if there is a finetuning request
assert(loss != nullptr);
BatchConfig::TokenId token_ids[BatchConfig::MAX_NUM_TOKENS];
int num_finetuning_requests = 0, num_bwd_tokens = 0;
int tokens_previous_requests = 0;
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
}
// Skip non-PEFT requests
if (bc->requestsInfo[i].peft_bwd) {
assert(num_finetuning_requests == 0 && num_bwd_tokens == 0);
num_bwd_tokens = bc->requestsInfo[i].num_tokens_in_batch - 1;
// shift labels by 1 position to the left (ignore first token label)
for (int j = 0; j < num_bwd_tokens; j++) {
token_ids[j] =
bc->tokensInfo[j + tokens_previous_requests + 1].token_id;
}
num_finetuning_requests += 1;
} else {
tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch;
}
}
assert(num_finetuning_requests <= 1);
if (num_bwd_tokens > 0) {
checkCUDA(cudaMemcpyAsync(m->handle.workSpace,
token_ids,
sizeof(BatchConfig::TokenId) * num_bwd_tokens,
cudaMemcpyHostToDevice,
stream));
// copy loss to d_loss
checkCUDA(cudaMemsetAsync(m->d_loss, 0, sizeof(float), stream));
compute_sparse_categorical_crossentropy_loss<<<GET_BLOCKS(num_bwd_tokens),
min(CUDA_NUM_THREADS,
num_bwd_tokens),
0,
stream>>>(
input_ptr,
static_cast<BatchConfig::TokenId *>(m->handle.workSpace),
m->d_loss,
num_bwd_tokens,
length);
// copy value from d_loss to loss
checkCUDA(cudaMemcpyAsync(
loss, m->d_loss, sizeof(float), cudaMemcpyDeviceToHost, stream));
*loss = *loss / (float)num_bwd_tokens;
}
}
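The request loop above gathers labels for the (at most one) finetuning request by shifting its token ids one position to the left, so position j is scored against token j + 1 and a request with n tokens in the batch contributes n - 1 backward tokens; the accumulated loss is then divided by that count. A small sketch of the shift (illustration only, not part of this commit):

#include <cstddef>
#include <vector>

// Illustration only: next-token labels for a finetuning request are its own
// token ids shifted left by one, dropping the first position's label.
std::vector<int> shift_labels_left(std::vector<int> const &token_ids) {
  std::vector<int> labels;
  for (std::size_t j = 0; j + 1 < token_ids.size(); ++j) {
    labels.push_back(token_ids[j + 1]); // label for position j is token j + 1
  }
  return labels; // size() == token_ids.size() - 1, matching num_bwd_tokens
}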
/*static*/
void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m,
GenericTensorAccessorW const &input,
BatchConfig const *bc,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &indices,
GenericTensorAccessorW const &parent,
int batch_size) {
int batch_size,
float *loss) {
cudaStream_t stream;
checkCUDA(get_legion_stream(&stream));
cudaEvent_t t_start, t_end;
@@ -104,24 +170,28 @@ void ArgMax::forward_kernel_wrapper(ArgMaxMeta const *m,
if (input.data_type == DT_HALF) {
ArgMax::forward_kernel<half>(m,
bc,
input.get_half_ptr(),
indices.get_int32_ptr(),
m->probs,
m->beam_search ? parent.get_int32_ptr()
: nullptr,
length,
batch_size,
loss,
stream);
} else if (input.data_type == DT_FLOAT) {
ArgMax::forward_kernel<float>(m,
bc,
input.get_float_ptr(),
indices.get_int32_ptr(),
m->probs,
m->beam_search ? parent.get_int32_ptr()
: nullptr,
length,
batch_size,
loss,
stream);
} else {
assert(false && "Unsupported data type");
@@ -202,6 +272,10 @@ ArgMaxMeta::ArgMaxMeta(FFHandler handler,
gpu_mem_allocator.create_legion_instance(reserveInst, temp_storage_bytes);
d_temp_storage =
gpu_mem_allocator.allocate_instance_untyped(temp_storage_bytes);
// allocate space for loss on device
gpu_mem_allocator.create_legion_instance(reserveInst, sizeof(float));
d_loss = gpu_mem_allocator.allocate_instance<float>(1);
}
ArgMaxMeta::~ArgMaxMeta(void) {
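The constructor change above reserves a single-float device buffer (d_loss) that forward_kernel reuses every step: it is zeroed with cudaMemsetAsync, accumulated into by the loss kernel via atomicAdd, copied back to the host, and averaged over the backward tokens. A host-side sketch of that lifecycle using plain CUDA runtime calls (illustration only; the commit itself allocates through the op's gpu_mem_allocator, backed by a Legion instance):

#include <cuda_runtime.h>

// Illustration only: per-step lifecycle of a single-float device loss
// accumulator, using cudaMalloc instead of FlexFlow's MemoryAllocator.
float finetuning_loss_step(cudaStream_t stream, int num_bwd_tokens) {
  static float *d_loss = nullptr;
  if (d_loss == nullptr) {
    cudaMalloc(reinterpret_cast<void **>(&d_loss), sizeof(float)); // one-time allocation
  }
  cudaMemsetAsync(d_loss, 0, sizeof(float), stream);
  // ... launch compute_sparse_categorical_crossentropy_loss<<<...>>> here,
  //     accumulating -log(p_label) into d_loss (see the kernel above) ...
  float loss = 0.0f;
  cudaMemcpyAsync(&loss, d_loss, sizeof(float), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream); // ensure the copy has completed before reading
  return num_bwd_tokens > 0 ? loss / static_cast<float>(num_bwd_tokens) : 0.0f;
}

Allocating the accumulator once in the ArgMaxMeta constructor and resetting it each step avoids a per-step device allocation.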
