From bc61e9ddbe33ced6574fbf91fafc26212b8a6f56 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Fri, 26 Jan 2024 20:39:06 -0500 Subject: [PATCH 01/18] Peft alignment & debugging tools (#1288) * Revert "several hacks for performance measurement; some of the changes should be reverted" This reverts commit b9c392631b596db788ead74fe76d08d80a487b7c. * backup * backup * updates * update * backup * backup * backup * fix * cleanup * fix * fix * fix * update * simplify tensor names * fix * fixes and updates * fixes * fix * cleanup * . * restore softmax * cleanup * update alignment scripts * newline --- .gitignore | 3 +- include/flexflow/operator.h | 26 +- src/ops/add_bias_residual_layer_norm.cu | 84 +- src/ops/argmax.cc | 5 - src/ops/inc_multihead_self_attention.cu | 102 + src/ops/kernels/linear_kernels.cu | 21 +- src/ops/kernels/residual_rms_norm_kernels.cu | 45 +- src/ops/kernels/rms_norm_kernels.cu | 76 +- src/ops/layer_norm.cu | 69 +- src/ops/linear.cc | 20 +- src/ops/lora_linear.cc | 34 +- src/ops/residual_layer_norm.cu | 188 +- tests/peft/alignment/align_test_utils.py | 240 ++ .../alignment/llama_alignment_tests.ipynb | 2039 +++++++++++++++++ .../peft/alignment/opt_alignment_tests.ipynb | 450 ++++ tests/peft/alignment_tests.ipynb | 1427 ------------ tests/peft/hf_finetune.py | 70 +- tests/peft/hf_serve.py | 8 - tests/peft/qk_prods_alignment.ipynb | 24 - 19 files changed, 3161 insertions(+), 1770 deletions(-) create mode 100644 tests/peft/alignment/align_test_utils.py create mode 100644 tests/peft/alignment/llama_alignment_tests.ipynb create mode 100644 tests/peft/alignment/opt_alignment_tests.ipynb delete mode 100644 tests/peft/alignment_tests.ipynb delete mode 100644 tests/peft/qk_prods_alignment.ipynb diff --git a/.gitignore b/.gitignore index 0579eb5a74..23da3c5899 100644 --- a/.gitignore +++ b/.gitignore @@ -189,4 +189,5 @@ python/flexflow/version.txt inference_tensors hf_peft_tensors -Untitled-1.ipynb \ No newline at end of file +Untitled-1.ipynb +Untitled-2.ipynb diff --git a/include/flexflow/operator.h b/include/flexflow/operator.h index e3f28756ec..2dfba77b77 100644 --- a/include/flexflow/operator.h +++ b/include/flexflow/operator.h @@ -267,7 +267,7 @@ class Op { bool fwd_pass = true, bool before_kernel = false) { // Check if output directory exists, and create it if it does not - char const *folder_path = "./inference_tensors"; + char const *folder_path = "./inference_tensors/"; struct stat st = {0}; if (stat(folder_path, &st) == -1) { // Directory does not exist, create it @@ -275,20 +275,26 @@ class Op { } // output base filepath, shared by all tensors from the same operator std::string op_name_without_uid = get_op_name_without_uid(m); - std::string base_filepath = - "./inference_tensors/model_" + std::to_string(m->layer_guid.model_id) + - (fwd_pass ? "_decoding-step_" : "_bwd-step_") + - (fwd_pass ? std::to_string(m->decoding_step) - : std::to_string(m->bwd_step)) + - "_layer-num_" + std::to_string(m->layer_guid.transformer_layer_id) + - "_layer-name_" + op_name_without_uid + "_shard-id_" + - std::to_string(shard_id); + std::cout << (fwd_pass ? 
"INF " : "BWD ") << op_name_without_uid + << std::endl; + std::string base_filepath = std::string(folder_path); + if (m->layer_guid.model_id > 0) { + base_filepath += "model_" + std::to_string(m->layer_guid.model_id) + "_"; + } + if (fwd_pass) { + base_filepath += "fwd_step_" + std::to_string(m->decoding_step); + } else { + base_filepath += "bwd_step_" + std::to_string(m->bwd_step); + } + base_filepath += "_layers_" + + std::to_string(m->layer_guid.transformer_layer_id) + "_" + + op_name_without_uid + "_shard_" + std::to_string(shard_id); if (before_kernel) { base_filepath += "_pre"; } // save batch config, if passed if (bc != nullptr) { - bc->save_to_file(base_filepath + "_batch-config"); + bc->save_to_file(base_filepath + "_batch_config"); } // save all inputs for (int i = 0; i < input_tensors.size(); i++) { diff --git a/src/ops/add_bias_residual_layer_norm.cu b/src/ops/add_bias_residual_layer_norm.cu index ab017ed46c..505806a2b9 100644 --- a/src/ops/add_bias_residual_layer_norm.cu +++ b/src/ops/add_bias_residual_layer_norm.cu @@ -91,25 +91,6 @@ __inline__ __device__ T BlockReduceSum(T val, T *shared) { return val; } -template -__inline__ __device__ T BlockReduceSum(T val, T *shared, int max_num_threads) { - int const lid = threadIdx.x % C10_WARP_SIZE; - int const wid = threadIdx.x / C10_WARP_SIZE; - val = WarpReduceSum(val); - __syncthreads(); - if (lid == 0) { - shared[wid] = val; - } - __syncthreads(); - val = (threadIdx.x < (min(blockDim.x, max_num_threads) / C10_WARP_SIZE)) - ? shared[lid] - : T(0); - if (wid == 0) { - val = WarpReduceSum(val); - } - return val; -} - template __global__ void LayerNormFusedForwardKernel(int64_t N, int64_t attn_bias_dim, @@ -128,20 +109,17 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, const int64_t i = blockIdx.x; float sum1 = 0.0f; float sum2 = 0.0f; - for (int64_t j = threadIdx.x; j < N; - j += min(blockDim.x, kCUDABlockReduceNumThreads)) { + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { const int64_t index = i * N + j; const int64_t bias_idx = index % attn_bias_dim; X[index] = input_ptr[index] + attn_bias_ptr[bias_idx] + residual_ptr[index]; sum1 += static_cast(X[index]); sum2 += static_cast(X[index]) * static_cast(X[index]); } - if (threadIdx.x < kCUDABlockReduceNumThreads) { - sum1 = BlockReduceSum( - sum1, m_shared, min(blockDim.x, kCUDABlockReduceNumThreads)); - sum2 = BlockReduceSum( - sum2, v_shared, min(blockDim.x, kCUDABlockReduceNumThreads)); - } + + sum1 = BlockReduceSum(sum1, m_shared); + sum2 = BlockReduceSum(sum2, v_shared); + if (threadIdx.x == 0) { float const scale = float(1) / static_cast(N); sum1 *= scale; @@ -153,7 +131,7 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, __syncthreads(); using T_ACC = T; - for (int64_t j = threadIdx.x; j < N; j += min(blockDim.x, kCUDANumThreads)) { + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? 
T_ACC(1) : static_cast(gamma[j]); @@ -179,30 +157,22 @@ void AddBiasResidualLayerNorm::inference_kernel( T const *gamma_ptr, T const *beta_ptr, cudaStream_t stream) { - - std::pair kernel1_parallelism = - std::make_pair(m->effective_batch_size, kCUDABlockReduceNumThreads); - std::pair kernel2_parallelism = - std::make_pair(m->effective_batch_size, kCUDANumThreads); - - int num_blocks = - std::max(kernel1_parallelism.first, kernel2_parallelism.first); - int num_threads = - std::max(kernel1_parallelism.second, kernel2_parallelism.second); - LayerNormFusedForwardKernel - <<>>(m->effective_num_elements, - attn_bias_dim, - m->eps, - input_ptr, - attn_bias_ptr, - residual_ptr, - added_output_ptr, - static_cast(m->mean_ptr), - static_cast(m->rstd_ptr), - gamma_ptr, - beta_ptr, - output_ptr); + <<effective_batch_size, + std::min(CUDA_NUM_THREADS, (int)m->effective_num_elements), + 0, + stream>>>(m->effective_num_elements, + attn_bias_dim, + m->eps, + input_ptr, + attn_bias_ptr, + residual_ptr, + added_output_ptr, + static_cast(m->mean_ptr), + static_cast(m->rstd_ptr), + gamma_ptr, + beta_ptr, + output_ptr); } /*static*/ @@ -242,20 +212,17 @@ void AddBiasResidualLayerNorm::inference_kernel_wrapper( } assert(num_peft_requests <= 1); - int tokens_previous_requests = 0; for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } // Skip non-PEFT requests if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { - // FIXME: use the new approach to computing token offset - tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; continue; } int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int in_dim = - added_output.domain.hi()[0] - added_output.domain.lo()[0] + 1; + int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; + int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; if (bc->requestsInfo[i].peft_bwd) { MemoryAllocator *allocator = m->handle.peft_activation_allocator; m->input_activation = allocator->allocate_instance_untyped( @@ -264,14 +231,14 @@ void AddBiasResidualLayerNorm::inference_kernel_wrapper( if (m->input_type[0] == DT_FLOAT) { checkCUDA(cudaMemcpyAsync( m->input_activation, - added_output.get_float_ptr() + tokens_previous_requests * in_dim, + added_output.get_float_ptr() + first_token_offset * in_dim, data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); } else if (m->input_type[0] == DT_HALF) { checkCUDA(cudaMemcpyAsync( m->input_activation, - added_output.get_half_ptr() + tokens_previous_requests * in_dim, + added_output.get_half_ptr() + first_token_offset * in_dim, data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); @@ -281,6 +248,7 @@ void AddBiasResidualLayerNorm::inference_kernel_wrapper( } } } + // inference kernel int attn_bias_dim = attn_bias.domain.hi()[0] - attn_bias.domain.lo()[0] + 1; int residual_volume = residual.domain.get_volume(); diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index dd0e2bb822..cabb8b204f 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -392,11 +392,6 @@ InferenceResult GenericTensorAccessorW parent; int batch_size = bc->num_active_infr_tokens(); ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size); - // Note that we free activation allocator here since argmax is the - // last operator in forward - if (m->handle.peft_activation_allocator != nullptr) { - m->handle.peft_activation_allocator->free_all(); - } InferenceResult ir; if 
(m->inference_debugging) { assert(task->index_point.get_dim() == 1); diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 452a8c09f6..4c3b0ee4b6 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -907,6 +907,22 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta *m, m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); } +std::string get_peft_dbg_folder(IncMultiHeadSelfAttentionMeta const *m, + int shard_id) { + std::string op_name_without_uid = + IncMultiHeadSelfAttention::get_op_name_without_uid(m); + char const *folder_path = "./inference_tensors/"; + std::string base_filepath = std::string(folder_path); + if (m->layer_guid.model_id > 0) { + base_filepath += "model_" + std::to_string(m->layer_guid.model_id) + "_"; + } + base_filepath += "bwd_step_" + std::to_string(m->bwd_step); + base_filepath += "_layers_" + + std::to_string(m->layer_guid.transformer_layer_id) + "_" + + op_name_without_uid + "_shard_" + std::to_string(shard_id); + return base_filepath; +} + template void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -934,6 +950,7 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, // compute_type = CUBLAS_COMPUTE_32F_FAST_16F; // } // #endif + for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; @@ -995,6 +1012,12 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + // save result to file for checking + std::string filename = + get_peft_dbg_folder(m, shard_id) + "_o_proj_in_grad"; + save_tensor(C, m_ * n_, filename.c_str()); + } } // Step 2: compute gradients w.r.t. value { @@ -1046,6 +1069,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // save result to file for checking + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + "_v_proj_in_grad"; + save_tensor(C, m_ * n_ * m->num_q_heads, filename.c_str()); + std::string filename2 = + get_peft_dbg_folder(m, shard_id) + "_qk_prods_softmax"; + save_tensor(A, m_ * k_ * m->num_q_heads, filename2.c_str()); + } } // Step 3: compute gradients w.r.t. the qk_prods_softmax tensor { @@ -1094,6 +1126,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + "_qk_prods_softmax_grad"; + save_tensor( + C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); + std::string filename2 = get_peft_dbg_folder(m, shard_id) + "_vcache"; + save_tensor( + B, m->vProjSize * m->num_q_heads * num_tokens, filename2.c_str()); + } } // Step 4: softmax backpropagation { @@ -1120,6 +1161,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, &beta, m->qk_tensor, m->qk_prods)); + + if (m->inference_debugging) { + DT *C = static_cast
(m->qk_prods);
+      std::string filename =
+          get_peft_dbg_folder(m, shard_id) + "_qk_prods_softmax_grad_in";
+      save_tensor(
+          C, num_tokens * num_tokens * m->num_q_heads, filename.c_str());
+    }
+
     // TODO: fill all elements above diagonal to force causal attention
     size_t entries_above_diagonal = num_tokens * (num_tokens - 1) / 2;
     if (entries_above_diagonal > 0) {
@@ -1135,6 +1185,13 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m,
         entries_above_diagonal,
         DT(0.0f));
     }
+    if (m->inference_debugging) {
+      DT *C = static_cast<DT *>
(m->qk_prods); + std::string filename = get_peft_dbg_folder(m, shard_id) + + "_qk_prods_softmax_grad_in_masked"; + save_tensor( + C, num_tokens * num_tokens * m->num_q_heads, filename.c_str()); + } } // Step 5: compute gradients w.r.t. key { @@ -1189,6 +1246,16 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + "_query_activation"; + save_tensor( + B, m->qProjSize * m->num_q_heads * num_tokens, filename.c_str()); + std::string filename2 = + get_peft_dbg_folder(m, shard_id) + "_devkproj_pre"; + save_tensor( + C, num_tokens * (m->qProjSize * m->num_q_heads), filename2.c_str()); + } } // Step 6: compute gradients w.r.t query { @@ -1239,7 +1306,15 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + "_devQKVPRojArray_pre"; + save_tensor(C, + num_tokens * m->qProjSize * m->num_q_heads * 3, + filename.c_str()); + } } + // Step 7: perform rotary position embeddings (RoPE) bwd { if (*m->apply_rotary_embedding) { @@ -1257,8 +1332,30 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, m->qProjSize, num_tokens, m->hidden_size); + DT *C = static_cast
(m->devQKVProjArray); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + "_devQKVPRojArray"; + save_tensor(C, + num_tokens * m->qProjSize * m->num_q_heads * 3, + filename.c_str()); + } + } + + // matrix C: gradients for key (saved as part of m->devQKVProjArray) + // matrix C's layout: [num_tokens, qProjsize * num_heads, 3] + DT *C = + static_cast
(m->devQKVProjArray) + + num_tokens * + (m->qProjSize * + m->num_q_heads); // skip over regions reserved for Q gradients + if (m->inference_debugging) { + std::string filename = get_peft_dbg_folder(m, shard_id) + "_devkproj"; + save_tensor( + C, num_tokens * (m->qProjSize * m->num_q_heads), filename.c_str()); } } + // Step 8: compute gradients w.r.t. input { float alpha = 1.0f, beta = 0.0f; @@ -1300,6 +1397,11 @@ void peft_bwd_kernel(IncMultiHeadSelfAttentionMeta const *m, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + if (m->inference_debugging) { + std::string filename = + get_peft_dbg_folder(m, shard_id) + "_attn_final_grad_in"; + save_tensor(C, num_tokens * m->qSize, filename.c_str()); + } } } } diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index 5306be9bdf..a3f5c797de 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -200,17 +200,16 @@ void inference_kernel_wrapper(LinearMeta *m, stream); } else if (m->input_type[0] == DT_HALF) { Internal::forward_kernel(m, - input_ptr, - output_ptr, - weight_ptr, - bias_ptr, - in_dim, - out_dim, - batch_size, - stream); + input_ptr, + output_ptr, + weight_ptr, + bias_ptr, + in_dim, + out_dim, + batch_size, + stream); } - if (m->activation == AC_MODE_RELU || m->activation == AC_MODE_SIGMOID) { // save input activation if needed for PEFT if (bc->num_active_peft_tokens() > 0) { @@ -247,14 +246,14 @@ void inference_kernel_wrapper(LinearMeta *m, if (m->output_type[0] == DT_FLOAT) { checkCUDA(cudaMemcpyAsync( m->output_activation_buffer, - static_cast(output_ptr) + first_token_offset * out_dim, + static_cast(output_ptr) + first_token_offset * out_dim, data_type_size(m->output_type[0]) * num_peft_tokens * out_dim, cudaMemcpyDeviceToDevice, stream)); } else if (m->output_type[0] == DT_HALF) { checkCUDA(cudaMemcpyAsync( m->output_activation_buffer, - static_cast(output_ptr) + first_token_offset * out_dim, + static_cast(output_ptr) + first_token_offset * out_dim, data_type_size(m->output_type[0]) * num_peft_tokens * out_dim, cudaMemcpyDeviceToDevice, stream)); diff --git a/src/ops/kernels/residual_rms_norm_kernels.cu b/src/ops/kernels/residual_rms_norm_kernels.cu index b12d105c1b..664c1ed13b 100644 --- a/src/ops/kernels/residual_rms_norm_kernels.cu +++ b/src/ops/kernels/residual_rms_norm_kernels.cu @@ -221,7 +221,28 @@ void inference_kernel_wrapper(ResidualRMSNormMeta *m, assert(weight.data_type == output.data_type); assert(residual_output.data_type == output.data_type); - // save input activation if needed for PEFT + if (output.data_type == DT_HALF) { + forward_kernel(m, + input1.get_half_ptr(), + input2.get_half_ptr(), + weight.get_half_ptr(), + residual_output.get_half_ptr(), + output.get_half_ptr(), + stream); + } else if (output.data_type == DT_FLOAT) { + forward_kernel(m, + input1.get_float_ptr(), + input2.get_float_ptr(), + weight.get_float_ptr(), + residual_output.get_float_ptr(), + output.get_float_ptr(), + stream); + } else { + assert(false && "Unsupported data type"); + } + + // save input activation if needed for PEFT. 
This must be done after the + // forward kernel since that's where we add the residual if (bc->num_active_peft_tokens() > 0) { // Check that we have at most one request that requires peft_bwd int num_peft_requests = 0; @@ -247,7 +268,7 @@ void inference_kernel_wrapper(ResidualRMSNormMeta *m, continue; } int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int first_token_offset = bc->requestsInfo[i].num_tokens_in_batch; + int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; int in_dim = input1.domain.hi()[0] - input1.domain.lo()[0] + 1; if (bc->requestsInfo[i].peft_bwd) { MemoryAllocator *allocator = m->handle.peft_activation_allocator; @@ -275,26 +296,6 @@ void inference_kernel_wrapper(ResidualRMSNormMeta *m, } } - if (output.data_type == DT_HALF) { - forward_kernel(m, - input1.get_half_ptr(), - input2.get_half_ptr(), - weight.get_half_ptr(), - residual_output.get_half_ptr(), - output.get_half_ptr(), - stream); - } else if (output.data_type == DT_FLOAT) { - forward_kernel(m, - input1.get_float_ptr(), - input2.get_float_ptr(), - weight.get_float_ptr(), - residual_output.get_float_ptr(), - output.get_float_ptr(), - stream); - } else { - assert(false && "Unsupported data type"); - } - if (m->profiling) { cudaEventRecord(t_end, stream); checkCUDA(cudaEventSynchronize(t_end)); diff --git a/src/ops/kernels/rms_norm_kernels.cu b/src/ops/kernels/rms_norm_kernels.cu index d0702d651e..b11e954622 100644 --- a/src/ops/kernels/rms_norm_kernels.cu +++ b/src/ops/kernels/rms_norm_kernels.cu @@ -201,53 +201,53 @@ void inference_kernel_wrapper(RMSNormMeta *m, // save input activation if needed for PEFT if (bc->num_active_peft_tokens() > 0) { - // check that at most one dimension after the first is > 1. TODO(goliaro): - // support case where this condition does not hold - int non_unit_dims_encountered = 0; - for (int i = 1; i < input.domain.get_dim(); i++) { - int dim_i = input.domain.hi()[i] - input.domain.lo()[i] + 1; - if (dim_i > 1) { - non_unit_dims_encountered++; + // Check that we have at most one request that requires peft_bwd + int num_peft_requests = 0; + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i]) { + continue; + } + if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { + continue; + } + if (bc->requestsInfo[i].peft_bwd) { + num_peft_requests++; } } - assert(non_unit_dims_encountered <= 1); - - // allocate space for all peft tokens - MemoryAllocator *allocator = m->handle.peft_activation_allocator; - int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; - m->input_activation = allocator->allocate_instance_untyped( - data_type_size(input.data_type) * bc->num_active_peft_tokens() * - in_dim); - - int tokens_previous_requests = 0; + assert(num_peft_requests <= 1); for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } - // Skip non-PEFT requests and PEFT forward-only requests - if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID || - !bc->requestsInfo[i].peft_bwd) { - tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + // Skip non-PEFT requests + if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { continue; } int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch; - - if (input.data_type == DT_FLOAT) { - checkCUDA(cudaMemcpyAsync( - m->input_activation, - input.get_float_ptr() + tokens_previous_requests * in_dim, - data_type_size(input.data_type) * num_peft_tokens * in_dim, - cudaMemcpyDeviceToDevice, - stream)); - } 
else if (input.data_type == DT_HALF) { - checkCUDA(cudaMemcpyAsync( - m->input_activation, - input.get_half_ptr() + tokens_previous_requests * in_dim, - data_type_size(input.data_type) * num_peft_tokens * in_dim, - cudaMemcpyDeviceToDevice, - stream)); - } else { - assert(false && "unsupport datatype in layernorm"); + int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; + int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; + if (bc->requestsInfo[i].peft_bwd) { + MemoryAllocator *allocator = m->handle.peft_activation_allocator; + m->input_activation = allocator->allocate_instance_untyped( + data_type_size(m->input_type[0]) * num_peft_tokens * in_dim); + + if (input.data_type == DT_FLOAT) { + checkCUDA(cudaMemcpyAsync( + m->input_activation, + input.get_float_ptr() + first_token_offset * in_dim, + data_type_size(input.data_type) * num_peft_tokens * in_dim, + cudaMemcpyDeviceToDevice, + stream)); + } else if (input.data_type == DT_HALF) { + checkCUDA(cudaMemcpyAsync( + m->input_activation, + input.get_half_ptr() + first_token_offset * in_dim, + data_type_size(input.data_type) * num_peft_tokens * in_dim, + cudaMemcpyDeviceToDevice, + stream)); + } else { + assert(false && "unsupport datatype in layernorm"); + } } } } diff --git a/src/ops/layer_norm.cu b/src/ops/layer_norm.cu index 1d4e94d7d5..bfbb2faae9 100644 --- a/src/ops/layer_norm.cu +++ b/src/ops/layer_norm.cu @@ -96,25 +96,6 @@ __inline__ __device__ T BlockReduceSum(T val, T *shared) { return val; } -template -__inline__ __device__ T BlockReduceSum(T val, T *shared, int max_num_threads) { - int const lid = threadIdx.x % C10_WARP_SIZE; - int const wid = threadIdx.x / C10_WARP_SIZE; - val = WarpReduceSum(val); - __syncthreads(); - if (lid == 0) { - shared[wid] = val; - } - __syncthreads(); - val = (threadIdx.x < (min(blockDim.x, max_num_threads) / C10_WARP_SIZE)) - ? shared[lid] - : T(0); - if (wid == 0) { - val = WarpReduceSum(val); - } - return val; -} - template __global__ void LayerNormFusedForwardKernel(int64_t N, float eps, @@ -129,18 +110,13 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, const int64_t i = blockIdx.x; float sum1 = 0.0f; float sum2 = 0.0f; - for (int64_t j = threadIdx.x; j < N; - j += min(blockDim.x, kCUDABlockReduceNumThreads)) { + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { const int64_t index = i * N + j; sum1 += static_cast(X[index]); sum2 += static_cast(X[index]) * static_cast(X[index]); } - if (threadIdx.x < kCUDABlockReduceNumThreads) { - sum1 = BlockReduceSum( - sum1, m_shared, min(blockDim.x, kCUDABlockReduceNumThreads)); - sum2 = BlockReduceSum( - sum2, v_shared, min(blockDim.x, kCUDABlockReduceNumThreads)); - } + sum1 = BlockReduceSum(sum1, m_shared); + sum2 = BlockReduceSum(sum2, v_shared); if (threadIdx.x == 0) { float const scale = float(1) / static_cast(N); sum1 *= scale; @@ -152,7 +128,7 @@ __global__ void LayerNormFusedForwardKernel(int64_t N, __syncthreads(); using T_ACC = T; - for (int64_t j = threadIdx.x; j < N; j += min(blockDim.x, kCUDANumThreads)) { + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? 
T_ACC(1) : static_cast(gamma[j]); @@ -173,25 +149,18 @@ void LayerNorm::forward_kernel(LayerNormMeta const *m, T const *beta_ptr, cudaStream_t stream) { - std::pair kernel1_parallelism = - std::make_pair(m->effective_batch_size, kCUDABlockReduceNumThreads); - std::pair kernel2_parallelism = - std::make_pair(m->effective_batch_size, kCUDANumThreads); - - int num_blocks = - std::max(kernel1_parallelism.first, kernel2_parallelism.first); - int num_threads = - std::max(kernel1_parallelism.second, kernel2_parallelism.second); - LayerNormFusedForwardKernel - <<>>(m->effective_num_elements, - m->eps, - in_ptr, - static_cast(m->mean_ptr), - static_cast(m->rstd_ptr), - gamma_ptr, - beta_ptr, - out_ptr); + <<effective_batch_size, + std::min(CUDA_NUM_THREADS, (int)m->effective_num_elements), + 0, + stream>>>(m->effective_num_elements, + m->eps, + in_ptr, + static_cast(m->mean_ptr), + static_cast(m->rstd_ptr), + gamma_ptr, + beta_ptr, + out_ptr); } /*static*/ @@ -276,18 +245,16 @@ void LayerNorm::inference_kernel_wrapper(LayerNormMeta *m, } assert(num_peft_requests <= 1); - int tokens_previous_requests = 0; for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } // Skip non-PEFT requests if (bc->requestsInfo[i].peft_model_id == PEFTModelID::NO_ID) { - // FIXME: use the new approach to computing token offset - tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; continue; } int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch; + int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; if (bc->requestsInfo[i].peft_bwd) { MemoryAllocator *allocator = m->handle.peft_activation_allocator; @@ -297,14 +264,14 @@ void LayerNorm::inference_kernel_wrapper(LayerNormMeta *m, if (m->input_type[0] == DT_FLOAT) { checkCUDA(cudaMemcpyAsync( m->input_activation, - input.get_float_ptr() + tokens_previous_requests * in_dim, + input.get_float_ptr() + first_token_offset * in_dim, data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); } else if (m->input_type[0] == DT_HALF) { checkCUDA(cudaMemcpyAsync( m->input_activation, - input.get_half_ptr() + tokens_previous_requests * in_dim, + input.get_half_ptr() + first_token_offset * in_dim, data_type_size(m->input_type[0]) * num_peft_tokens * in_dim, cudaMemcpyDeviceToDevice, stream)); diff --git a/src/ops/linear.cc b/src/ops/linear.cc index e23a6f48ca..209f514f65 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -621,6 +621,8 @@ void Linear::inference_task(Task const *task, ctx, task->regions[0].region.get_index_space()); LinearMeta *m = *((LinearMeta **)task->local_args); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); + std::string op_name_without_uid = Linear::get_op_name_without_uid(m); + printf("INF %s\n", op_name_without_uid.c_str()); if (bc->num_tokens == 0) { return; } @@ -653,14 +655,14 @@ void Linear::inference_task(Task const *task, assert(bias.domain.get_volume() == static_cast(out_dim)); } inference_kernel_wrapper(m, - bc, - input.ptr, - output.ptr, - weight.ptr, - bias.ptr, - in_dim, - out_dim, - batch_size); + bc, + input.ptr, + output.ptr, + weight.ptr, + bias.ptr, + in_dim, + out_dim, + batch_size); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; @@ -735,7 +737,7 @@ void Linear::peft_bwd_task(Task const *task, return; } assert(regions.size() == 3); - 
assert(task->regions.size() == 3 ); + assert(task->regions.size() == 3); if (m->quantization_type == DT_NONE) { assert(m->input_type[0] == m->weight_type[0]); } diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index e39b444af4..c02bddc5a6 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -473,7 +473,7 @@ void LoraLinear::inference_task(Task const *task, int shard_id = task->index_point.point_data[0]; // Check if output directory exists, and create it if it does not - char const *folder_path = "./inference_tensors"; + char const *folder_path = "./inference_tensors/"; struct stat st = {0}; if (stat(folder_path, &st) == -1) { // Directory does not exist, create it @@ -493,15 +493,18 @@ void LoraLinear::inference_task(Task const *task, lora_layername.substr(0, found + searchString.length()); // output base filepath, shared by all tensors from the same operator - std::string base_filepath = - "./inference_tensors/model_" + std::to_string(m->layer_guid.model_id) + - "_decoding-step_" + std::to_string(m->decoding_step) + "_layer-num_" + - std::to_string(m->layer_guid.transformer_layer_id) + "_layer-name_" + - lora_layername_substr + "_shard-id_" + std::to_string(shard_id); + std::string base_filepath = std::string(folder_path); + if (m->layer_guid.model_id > 0) { + base_filepath += "model_" + std::to_string(m->layer_guid.model_id) + "_"; + } + base_filepath += "fwd_step_" + std::to_string(m->decoding_step); + base_filepath += + "_layers_" + std::to_string(m->layer_guid.transformer_layer_id) + "_" + + lora_layername_substr + "_shard_" + std::to_string(shard_id); // save batch config, if passed if (bc != nullptr) { - bc->save_to_file(base_filepath + "_batch-config"); + bc->save_to_file(base_filepath + "_batch_config"); } std::string filename = base_filepath + "_input_" + std::to_string(0); @@ -634,7 +637,7 @@ void LoraLinear::peft_bwd_task(Task const *task, int shard_id = task->index_point.point_data[0]; // Check if output directory exists, and create it if it does not - char const *folder_path = "./inference_tensors"; + char const *folder_path = "./inference_tensors/"; struct stat st = {0}; if (stat(folder_path, &st) == -1) { // Directory does not exist, create it @@ -654,15 +657,18 @@ void LoraLinear::peft_bwd_task(Task const *task, lora_layername.substr(0, found + searchString.length()); // output base filepath, shared by all tensors from the same operator - std::string base_filepath = - "./inference_tensors/model_" + std::to_string(m->layer_guid.model_id) + - "_bwd-step_" + std::to_string(m->bwd_step) + "_layer-num_" + - std::to_string(m->layer_guid.transformer_layer_id) + "_layer-name_" + - lora_layername_substr + "_shard-id_" + std::to_string(shard_id); + std::string base_filepath = std::string(folder_path); + if (m->layer_guid.model_id > 0) { + base_filepath += "model_" + std::to_string(m->layer_guid.model_id) + "_"; + } + base_filepath += "bwd_step_" + std::to_string(m->bwd_step); + base_filepath += + "_layers_" + std::to_string(m->layer_guid.transformer_layer_id) + "_" + + lora_layername_substr + "_shard_" + std::to_string(shard_id); // save batch config, if passed if (bc != nullptr) { - bc->save_to_file(base_filepath + "_batch-config"); + bc->save_to_file(base_filepath + "_batch_config"); } std::string filename = base_filepath + "_input_" + std::to_string(0); diff --git a/src/ops/residual_layer_norm.cu b/src/ops/residual_layer_norm.cu index 1f87949234..0ba462cde5 100644 --- a/src/ops/residual_layer_norm.cu +++ b/src/ops/residual_layer_norm.cu @@ -91,25 +91,6 
@@ __inline__ __device__ T BlockReduceSum(T val, T *shared) { return val; } -template -__inline__ __device__ T BlockReduceSum(T val, T *shared, int max_num_threads) { - int const lid = threadIdx.x % C10_WARP_SIZE; - int const wid = threadIdx.x / C10_WARP_SIZE; - val = WarpReduceSum(val); - __syncthreads(); - if (lid == 0) { - shared[wid] = val; - } - __syncthreads(); - val = (threadIdx.x < min(blockDim.x, max_num_threads) / C10_WARP_SIZE) - ? shared[lid] - : 0; - if (wid == 0) { - val = WarpReduceSum(val); - } - return val; -} - template __global__ void ResidualLayerNormKernel(int64_t N, float eps, @@ -127,8 +108,7 @@ __global__ void ResidualLayerNormKernel(int64_t N, const int64_t i = blockIdx.x; float sum1 = 0.0f; float sum2 = 0.0f; - for (int64_t j = threadIdx.x; j < N; - j += min(blockDim.x, kCUDABlockReduceNumThreads)) { + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { const int64_t index = i * N + j; const T residual2_val = (residual2_ptr == nullptr) ? T(0) @@ -137,12 +117,10 @@ __global__ void ResidualLayerNormKernel(int64_t N, sum1 += static_cast(X[index]); sum2 += static_cast(X[index]) * static_cast(X[index]); } - if (threadIdx.x < kCUDABlockReduceNumThreads) { - sum1 = BlockReduceSum( - sum1, m_shared, min(blockDim.x, kCUDABlockReduceNumThreads)); - sum2 = BlockReduceSum( - sum2, v_shared, min(blockDim.x, kCUDABlockReduceNumThreads)); - } + + sum1 = BlockReduceSum(sum1, m_shared); + sum2 = BlockReduceSum(sum2, v_shared); + if (threadIdx.x == 0) { float const scale = float(1) / static_cast(N); sum1 *= scale; @@ -154,7 +132,7 @@ __global__ void ResidualLayerNormKernel(int64_t N, __syncthreads(); using T_ACC = T; - for (int64_t j = threadIdx.x; j < N; j += min(blockDim.x, kCUDANumThreads)) { + for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { const int64_t index = i * N + j; const T_ACC gamma_v = gamma == nullptr ? 
T_ACC(1) : static_cast(gamma[j]); @@ -178,28 +156,51 @@ void ResidualLayerNorm::inference_kernel(ResidualLayerNormMeta const *m, T const *beta_ptr, cudaStream_t stream) { - std::pair kernel1_parallelism = - std::make_pair(m->effective_batch_size, kCUDABlockReduceNumThreads); - std::pair kernel2_parallelism = - std::make_pair(m->effective_batch_size, kCUDANumThreads); - - int num_blocks = - std::max(kernel1_parallelism.first, kernel2_parallelism.first); - int num_threads = - std::max(kernel1_parallelism.second, kernel2_parallelism.second); - ResidualLayerNormKernel - <<>>(m->effective_num_elements, - m->eps, - input_ptr, - residual1_ptr, - residual2_ptr, - added_output_ptr, - static_cast(m->mean_ptr), - static_cast(m->rstd_ptr), - gamma_ptr, - beta_ptr, - output_ptr); + <<effective_batch_size, + std::min(CUDA_NUM_THREADS, (int)m->effective_num_elements), + 0, + stream>>>(m->effective_num_elements, + m->eps, + input_ptr, + residual1_ptr, + residual2_ptr, + added_output_ptr, + static_cast(m->mean_ptr), + static_cast(m->rstd_ptr), + gamma_ptr, + beta_ptr, + output_ptr); +} +template +void save_inference_tensors(ResidualLayerNormMeta const *m) { + if (m->inference_debugging) { + // save stuff here + std::string op_name_without_uid = + ResidualLayerNorm::get_op_name_without_uid(m); + char const *folder_path = "./inference_tensors/"; + std::string base_filepath = std::string(folder_path); + if (m->layer_guid.model_id > 0) { + base_filepath += "model_" + std::to_string(m->layer_guid.model_id) + "_"; + } + base_filepath += "fwd_step_" + std::to_string(m->decoding_step); + base_filepath += "_layers_" + + std::to_string(m->layer_guid.transformer_layer_id) + "_" + + op_name_without_uid + "_shard_" + std::to_string(0); + + std::string filename1 = base_filepath + "_mean"; + save_tensor(static_cast(m->mean_ptr), + m->effective_batch_size, + filename1.c_str()); + std::string filename2 = base_filepath + "_rstd"; + save_tensor(static_cast(m->rstd_ptr), + m->effective_batch_size, + filename2.c_str()); + std::string filename3 = base_filepath + "_input_activation"; + save_tensor(static_cast(m->input_activation), + m->effective_batch_size * m->effective_num_elements, + filename3.c_str()); + } } /*static*/ @@ -222,6 +223,33 @@ void ResidualLayerNorm::inference_kernel_wrapper( cudaEventCreate(&t_end); cudaEventRecord(t_start, stream); } + + if (m->input_type[0] == DT_FLOAT) { + ResidualLayerNorm::inference_kernel( + m, + input.get_float_ptr(), + residual1.get_float_ptr(), + m->use_two_residuals ? residual2.get_float_ptr() : nullptr, + added_output.get_float_ptr(), + output.get_float_ptr(), + m->elementwise_affine ? gamma.get_float_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? beta.get_float_ptr() : nullptr, + stream); + } else if (m->input_type[0] == DT_HALF) { + ResidualLayerNorm::inference_kernel( + m, + input.get_half_ptr(), + residual1.get_half_ptr(), + m->use_two_residuals ? residual2.get_half_ptr() : nullptr, + added_output.get_half_ptr(), + output.get_half_ptr(), + m->elementwise_affine ? gamma.get_half_ptr() : nullptr, + (m->elementwise_affine && m->use_bias) ? 
beta.get_half_ptr() : nullptr, + stream); + } else { + assert(false && "unsupport datatype in layernorm"); + } + // save input activation if needed for PEFT if (bc->num_active_peft_tokens() > 0) { // Check that we have at most one request that requires peft_bwd @@ -248,7 +276,7 @@ void ResidualLayerNorm::inference_kernel_wrapper( continue; } int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int first_token_offset = bc->requestsInfo[i].num_tokens_in_batch; + int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch; int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; if (bc->requestsInfo[i].peft_bwd) { MemoryAllocator *allocator = m->handle.peft_activation_allocator; @@ -276,30 +304,14 @@ void ResidualLayerNorm::inference_kernel_wrapper( } } - if (m->input_type[0] == DT_FLOAT) { - ResidualLayerNorm::inference_kernel( - m, - input.get_float_ptr(), - residual1.get_float_ptr(), - m->use_two_residuals ? residual2.get_float_ptr() : nullptr, - added_output.get_float_ptr(), - output.get_float_ptr(), - m->elementwise_affine ? gamma.get_float_ptr() : nullptr, - (m->elementwise_affine && m->use_bias) ? beta.get_float_ptr() : nullptr, - stream); - } else if (m->input_type[0] == DT_HALF) { - ResidualLayerNorm::inference_kernel( - m, - input.get_half_ptr(), - residual1.get_half_ptr(), - m->use_two_residuals ? residual2.get_half_ptr() : nullptr, - added_output.get_half_ptr(), - output.get_half_ptr(), - m->elementwise_affine ? gamma.get_half_ptr() : nullptr, - (m->elementwise_affine && m->use_bias) ? beta.get_half_ptr() : nullptr, - stream); - } else { - assert(false && "unsupport datatype in layernorm"); + if (m->inference_debugging) { + if (m->input_type[0] == DT_FLOAT) { + save_inference_tensors(m); + } else if (m->input_type[0] == DT_HALF) { + save_inference_tensors(m); + } else { + assert(false && "unsupport datatype in layernorm"); + } } if (m->profiling) { @@ -740,6 +752,34 @@ void peft_bwd_kernel(ResidualLayerNormMeta const *m, const int64_t M = m->effective_batch_size; const int64_t N = m->effective_num_elements; + if (m->inference_debugging) { + // save stuff here + std::string op_name_without_uid = + ResidualLayerNorm::get_op_name_without_uid(m); + char const *folder_path = "./inference_tensors/"; + std::string base_filepath = std::string(folder_path); + if (m->layer_guid.model_id > 0) { + base_filepath += "model_" + std::to_string(m->layer_guid.model_id) + "_"; + } + base_filepath += "bwd_step_" + std::to_string(m->bwd_step); + base_filepath += "_layers_" + + std::to_string(m->layer_guid.transformer_layer_id) + "_" + + op_name_without_uid + "_shard_" + std::to_string(0); + + std::string filename1 = base_filepath + "_mean"; + save_tensor(static_cast(m->mean_ptr), + m->effective_batch_size, + filename1.c_str()); + std::string filename2 = base_filepath + "_rstd"; + save_tensor(static_cast(m->rstd_ptr), + m->effective_batch_size, + filename2.c_str()); + std::string filename3 = base_filepath + "_input_activation"; + save_tensor(static_cast(m->input_activation), + m->effective_batch_size * m->effective_num_elements, + filename3.c_str()); + } + int const warp_size = C10_WARP_SIZE; int const num_threads = 128; const dim3 blocks(M); diff --git a/tests/peft/alignment/align_test_utils.py b/tests/peft/alignment/align_test_utils.py new file mode 100644 index 0000000000..b0cb5fe428 --- /dev/null +++ b/tests/peft/alignment/align_test_utils.py @@ -0,0 +1,240 @@ +import os, re, torch +import numpy as np +abs_dirname = os.path.dirname(os.path.abspath(__file__)) +hf_path = 
os.path.join(abs_dirname, "hf_peft_tensors") +ff_path = os.path.join(os.path.dirname(os.path.dirname(abs_dirname)), "build", "inference_tensors") +def print_unique_files_list(dirname): + files_list = os.listdir(dirname) + for f in sorted(files_list): + match = re.search(r'layers.\d+', f) + if match: + if "layers." in match[0]: + layer_num = int(match[0].split(".")[1]) + if layer_num > 0: + files_list.remove(f) + elif "layers_" in match[0]: + layer_num = int(match[0].split("_")[1]) + if layer_num > 0 and layer_num != 100: + files_list.remove(f) + return sorted(files_list) +def compare_tensors(hf_tensor_filepath, ff_tensor_filepath, tolerance=1e-2): + if not (os.path.exists(hf_tensor_filepath) and os.path.exists(ff_tensor_filepath)): + print(hf_tensor_filepath, os.path.exists(hf_tensor_filepath)) + print(ff_tensor_filepath, os.path.exists(ff_tensor_filepath)) + assert False + hf_tensor = torch.load(hf_tensor_filepath) + if type(hf_tensor) == tuple or type(hf_tensor) == list: + assert(len(hf_tensor) == 1) + hf_tensor = hf_tensor[0] + hf_tensor = torch.nan_to_num(hf_tensor) + hf_tensor = hf_tensor.flatten().detach().cpu().numpy() + ff_tensor = np.loadtxt(ff_tensor_filepath, delimiter=',') + + len_hf_tensor = hf_tensor.shape[0] + ff_tensor = ff_tensor[:len_hf_tensor] + + mismatches = [] + if not np.allclose(ff_tensor, hf_tensor, atol=tolerance): + print(f"mismatch between {hf_tensor_filepath} and {ff_tensor_filepath}") + print(f"HF: {hf_tensor}\nFF:{ff_tensor}") + print(np.isclose(ff_tensor, hf_tensor, atol=tolerance)) + mismatches = np.where(~np.isclose(ff_tensor, hf_tensor, atol=tolerance))[0] + print(mismatches) + #print(np.nonzero(hf_tensor)[0]) + # print(np.where(np.isclose(ff_tensor, hf_tensor, atol=tolerance) ==0)[0]) + # print(ff_tensor[36], hf_tensor[36]) + #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance)) + assert(len(mismatches) <= .05*len_hf_tensor) + print("Ok!") +def compare_tensors_difference(hf_tensor_filepath, ff_tensor1_filepath, ff_tensor2_filepath, tolerance=1e-2): + assert(os.path.exists(hf_tensor_filepath)) + assert(os.path.exists(ff_tensor1_filepath)) + assert(os.path.exists(ff_tensor2_filepath)) + hf_tensor = torch.load(hf_tensor_filepath) + if type(hf_tensor) == tuple or type(hf_tensor) == list: + assert(len(hf_tensor) == 1) + hf_tensor = hf_tensor[0] + hf_tensor = torch.nan_to_num(hf_tensor) + hf_tensor = hf_tensor.flatten().detach().cpu().numpy() + ff_tensor1 = np.loadtxt(ff_tensor1_filepath, delimiter=',') + ff_tensor2 = np.loadtxt(ff_tensor2_filepath, delimiter=',') + + len_hf_tensor = hf_tensor.shape[0] + ff_tensor1 = ff_tensor1[:len_hf_tensor] + ff_tensor2 = ff_tensor2[:len_hf_tensor] + ff_tensor = ff_tensor1 - ff_tensor2 + + mismatches = [] + if not np.allclose(ff_tensor, hf_tensor, atol=tolerance): + print(f"mismatch between {hf_tensor_filepath} and {ff_tensor1_filepath} - {ff_tensor2_filepath}") + print(f"HF: {hf_tensor}\nFF:{ff_tensor}") + print(np.isclose(ff_tensor, hf_tensor, atol=tolerance)) + mismatches = np.where(~np.isclose(ff_tensor, hf_tensor, atol=tolerance))[0] + print(mismatches) + #print(np.nonzero(hf_tensor)[0]) + # print(np.where(np.isclose(ff_tensor, hf_tensor, atol=tolerance) ==0)[0]) + # print(ff_tensor[36], hf_tensor[36]) + #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance)) + assert(len(mismatches) <= .05*len_hf_tensor) + print("Ok!") +def compare_hf_tensors(tensor1_fp, tensor2_fp): + assert(os.path.exists(tensor1_fp) and os.path.exists(tensor2_fp)) + hf_tensor1 = torch.load(tensor1_fp) + hf_tensor2 = torch.load(tensor2_fp) + 
if type(hf_tensor1) == tuple or type(hf_tensor1) == list: + assert(len(hf_tensor1) == 1) + hf_tensor1 = hf_tensor1[0] + if type(hf_tensor2) == tuple or type(hf_tensor2) == list: + assert(len(hf_tensor2) == 1) + hf_tensor2 = hf_tensor2[0] + assert(torch.squeeze(hf_tensor1).shape == torch.squeeze(hf_tensor2).shape) + hf_tensor1 = torch.nan_to_num(hf_tensor1) + hf_tensor2 = torch.nan_to_num(hf_tensor2) + if not (np.allclose(hf_tensor1.detach().cpu().numpy(), hf_tensor2.detach().cpu().numpy())): + print(f"mismatch between {tensor1_fp} and {tensor2_fp}") + print(hf_tensor1) + print(hf_tensor2) + print(np.isclose(hf_tensor1.detach().cpu().numpy(), hf_tensor2.detach().cpu().numpy())) + mismatches = np.where(~np.isclose(hf_tensor1.detach().cpu().numpy(), hf_tensor2.detach().cpu().numpy()))[0] + print(mismatches) + assert(False) + print("Ok!") + +def check_hf_sum_tensors(tensor_sum_fp, tensor1_fp, tensor2_fp): + assert(os.path.exists(tensor_sum_fp) and os.path.exists(tensor1_fp) and os.path.exists(tensor2_fp)) + hf_tensor_sum = torch.load(tensor_sum_fp) + hf_tensor1 = torch.load(tensor1_fp) + hf_tensor2 = torch.load(tensor2_fp) + if type(hf_tensor_sum) == tuple or type(hf_tensor_sum) == list: + assert(len(hf_tensor_sum) == 1) + hf_tensor_sum = hf_tensor_sum[0] + if type(hf_tensor1) == tuple or type(hf_tensor1) == list: + assert(len(hf_tensor1) == 1) + hf_tensor1 = hf_tensor1[0] + if type(hf_tensor2) == tuple or type(hf_tensor2) == list: + assert(len(hf_tensor2) == 1) + hf_tensor2 = hf_tensor2[0] + assert(torch.squeeze(hf_tensor_sum).shape == torch.squeeze(hf_tensor1).shape) + assert(torch.squeeze(hf_tensor1).shape == torch.squeeze(hf_tensor2).shape) + hf_tensor1 = torch.nan_to_num(hf_tensor1) + hf_tensor2 = torch.nan_to_num(hf_tensor2) + hf_tensor_sum = torch.nan_to_num(hf_tensor_sum) + sum_check_tensor = hf_tensor1 + hf_tensor2 + if not (np.allclose(sum_check_tensor.detach().cpu().numpy(), hf_tensor_sum.detach().cpu().numpy())): + print(f"mismatch between {sum_check_tensor} and {tensor1_fp} + {tensor2_fp}") + print(tensor_sum_fp) + print(sum_check_tensor) + print(hf_tensor1) + print(hf_tensor2) + print(np.isclose(sum_check_tensor.detach().cpu().numpy(), hf_tensor_sum.detach().cpu().numpy())) + mismatches = np.where(~np.isclose(sum_check_tensor.detach().cpu().numpy(), hf_tensor_sum.detach().cpu().numpy()))[0] + print(mismatches) + assert(False) + print("Ok!") +def check_hf_zero_tensor(hf_tensor_fp): + assert(os.path.exists(hf_tensor_fp)) + hf_tensor1 = torch.load(hf_tensor_fp) + if type(hf_tensor1) == tuple or type(hf_tensor1) == list: + assert(len(hf_tensor1) == 1) + hf_tensor1 = hf_tensor1[0] + assert(torch.count_nonzero(torch.nan_to_num(hf_tensor1)).sum() == 0) +def print_tensors(hf_tensor_filepath, ff_tensor_filepath, txt=""): + assert(os.path.exists(hf_tensor_filepath) and os.path.exists(ff_tensor_filepath)) + hf_tensor = torch.load(hf_tensor_filepath) + if type(hf_tensor) == tuple or type(hf_tensor) == list: + assert(len(hf_tensor) == 1) + hf_tensor = hf_tensor[0] + hf_tensor = torch.nan_to_num(hf_tensor) + hf_tensor = hf_tensor.flatten().detach().cpu().numpy() + ff_tensor = np.loadtxt(ff_tensor_filepath, delimiter=',') + + len_hf_tensor = hf_tensor.shape[0] + ff_tensor = ff_tensor[:len_hf_tensor] + + print(f"{txt} - HF tensor:") + print(hf_tensor) + print(f"{txt} - FF tensor: ") + print(ff_tensor) +def compare_flexflow_tensors(ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5, max_len=-1): + assert(os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp)) + ff_tensor1 = 
np.loadtxt(ff_tensor1_fp, delimiter=',') + ff_tensor2 = np.loadtxt(ff_tensor2_fp, delimiter=',') + + if (ff_tensor1.shape != ff_tensor2.shape): + print(ff_tensor1.shape, ff_tensor2.shape) + assert(ff_tensor1.shape == ff_tensor2.shape) + + if max_len > -1: + ff_tensor1 = ff_tensor1[:max_len] + ff_tensor2 = ff_tensor2[:max_len] + + mismatches = [] + if not np.allclose(ff_tensor1, ff_tensor2, atol=tolerance): + print(f"mismatch between {ff_tensor1_fp} and {ff_tensor2_fp}") + print(f"Tensor1: {ff_tensor1}\nTensor2:{ff_tensor2}") + print(np.isclose(ff_tensor1, ff_tensor2, atol=tolerance)) + mismatches = np.where(~np.isclose(ff_tensor1, ff_tensor2, atol=tolerance))[0] + print(mismatches) + #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance)) + assert(len(mismatches) <= .05*len(ff_tensor1)) + print("Ok!") +def compare_flexflow_tensors_shortest(ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5): + assert(os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp)) + ff_tensor1 = np.loadtxt(ff_tensor1_fp, delimiter=',') + ff_tensor2 = np.loadtxt(ff_tensor2_fp, delimiter=',') + minlen = min(ff_tensor1.shape[0], ff_tensor2.shape[0]) + ff_tensor1 = ff_tensor1[:minlen] + ff_tensor2 = ff_tensor2[:minlen] + mismatches = [] + if not np.allclose(ff_tensor1, ff_tensor2, atol=tolerance): + print(f"mismatch between {ff_tensor1_fp} and {ff_tensor2_fp}") + print(f"Tensor1: {ff_tensor1}\nTensor2:{ff_tensor2}") + print(np.isclose(ff_tensor1, ff_tensor2, atol=tolerance)) + mismatches = np.where(~np.isclose(ff_tensor1, ff_tensor2, atol=tolerance))[0] + print(mismatches) + #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance)) + assert(len(mismatches) <= .05*len(ff_tensor1)) + print("Ok!") +def check_flexflow_tensors_sum(ff_tensor_sum_fp, ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5): + assert(os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp)) + ff_tensor1 = np.loadtxt(ff_tensor1_fp, delimiter=',') + ff_tensor2 = np.loadtxt(ff_tensor2_fp, delimiter=',') + ff_tensor_sum = np.loadtxt(ff_tensor_sum_fp, delimiter=',') + + ff_sum = ff_tensor1 + ff_tensor2 + assert(ff_tensor1.shape == ff_tensor2.shape) + + mismatches = [] + if not np.allclose(ff_tensor_sum, ff_sum, atol=tolerance): + print(f"mismatch between {ff_tensor_sum_fp} and sum of {ff_tensor1_fp} + {ff_tensor2_fp}") + print(f"Tensor1: {ff_tensor1}\nTensor2:{ff_tensor2}") + print(f"Sum Tensor: {ff_tensor_sum}\nActual sum:{ff_sum}") + print(np.isclose(ff_tensor_sum, ff_sum, atol=tolerance)) + mismatches = np.where(~np.isclose(ff_tensor_sum, ff_sum, atol=tolerance))[0] + print(mismatches) + #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance)) + assert(len(mismatches) <= .05*len(ff_tensor1)) + print("Ok!") +def load_ff_tensor(filename, shape): + if ff_path not in filename: + filename = os.path.join(ff_path, filename) + ff_tensor = np.loadtxt(filename, delimiter=',').reshape(shape, order = 'F') + return ff_tensor +def load_hf_tensor(filename): + if hf_path not in filename: + filename = os.path.join(hf_path, filename) + hf_tensor = torch.load(filename) + hf_tensor = hf_tensor.detach().cpu().numpy() + return hf_tensor +def compare_loaded_tensors(hf_tensor, ff_tensor, tolerance=1e-2): + assert(hf_tensor.shape == ff_tensor.shape) + mismatches = [] + if not np.allclose(hf_tensor, ff_tensor, atol=tolerance): + print(f"mismatch between hf_tensor and ff_tensor") + print(f"HF: {hf_tensor}\nFF:{ff_tensor}") + print(np.isclose(hf_tensor, ff_tensor, atol=tolerance)) + mismatches = np.where(~np.isclose(hf_tensor, ff_tensor, atol=tolerance))[0] + 
print(mismatches) + len_hf_tensor = hf_tensor.flatten().shape[0] + assert(len(mismatches) <= .05*len_hf_tensor) + print("Ok!") \ No newline at end of file diff --git a/tests/peft/alignment/llama_alignment_tests.ipynb b/tests/peft/alignment/llama_alignment_tests.ipynb new file mode 100644 index 0000000000..414280cff5 --- /dev/null +++ b/tests/peft/alignment/llama_alignment_tests.ipynb @@ -0,0 +1,2039 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import os, torch\n", + "from align_test_utils import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n" + ] + } + ], + "source": [ + "tot_num_layers = 12\n", + "for i in range(tot_num_layers):\n", + " hf_input_ln_out = f\"{hf_path}/fwd_step_0_layers.{i}.input_layernorm.output_0\"\n", + " ff_input_ln_out = f\"{ff_path}/fwd_step_0_layers_{i}_RMSNorm_shard_0_output_0\"\n", + " if i > 0:\n", + " ff_input_ln_out = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_attention_norm_shard_0_output_1\"\n", + " compare_tensors(hf_input_ln_out, ff_input_ln_out)\n", + " hf_attn_out = f\"{hf_path}/fwd_step_0_layers.{i}.self_attn.o_proj.output_0\"\n", + " ff_attn_out = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_attention_shard_0_output_0\"\n", + " compare_tensors(hf_attn_out, ff_attn_out)\n", + " hf_ffn_norm_out = f\"{hf_path}/fwd_step_0_layers.{i}.post_attention_layernorm.output_0\"\n", + " ff_ffn_norm_out = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_ffn_norm_shard_0_output_1\"\n", + " compare_tensors(hf_ffn_norm_out, ff_ffn_norm_out)\n", + " # w1\n", + " hf_gate_proj_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.gate_proj.output_0\"\n", + " ff_gate_proj_out = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w1_shard_0_output_0\"\n", + " compare_tensors(hf_gate_proj_out, 
ff_gate_proj_out)\n", + " # w3\n", + " hf_up_proj_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.up_proj.output_0\" \n", + " ff_up_proj_out = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w3_shard_0_output_0\"\n", + " compare_tensors(hf_up_proj_out, ff_up_proj_out)\n", + " # w2\n", + " hf_down_proj_in = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.down_proj.input_0\"\n", + " hf_down_proj_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.down_proj.output_0\"\n", + " ff_down_proj_in = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_shard_0_input_0\"\n", + " ff_down_proj_out = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_shard_0_output_0\"\n", + " compare_tensors(hf_down_proj_in, ff_down_proj_in)\n", + " # compare_tensors(hf_down_proj_out, ff_down_proj_out)\n", + " # LORA input\n", + " hf_lora_A_in = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.down_proj.lora_A.default.input_0\"\n", + " ff_lora_A_in = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_input_0\"\n", + " compare_hf_tensors(hf_down_proj_in, hf_lora_A_in)\n", + " compare_tensors(hf_lora_A_in, ff_lora_A_in)\n", + " # LORA weights\n", + " hf_lora_A_weight_fp = f\"{hf_path}/layers.{i}.mlp.down_proj.lora_A.default.weight\"\n", + " ff_lora_A_weight_fp = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_weight_A\"\n", + " compare_tensors(hf_lora_A_weight_fp, ff_lora_A_weight_fp)\n", + " hf_lora_B_weight_fp = f\"{hf_path}/layers.{i}.mlp.down_proj.lora_B.default.weight\"\n", + " ff_lora_B_weight_fp = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_weight_B\"\n", + " compare_tensors(hf_lora_B_weight_fp, ff_lora_B_weight_fp)\n", + " # LORA intermediate hf\n", + " hf_lora_A_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.down_proj.lora_A.default.output_0\"\n", + " hf_lora_B_in = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.down_proj.lora_B.default.input_0\"\n", + " compare_hf_tensors(hf_lora_A_out, hf_lora_B_in)\n", + " # LORA output\n", + " hf_lora_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.down_proj.lora_B.default.output_0\"\n", + " ff_lora_out = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_output_0\"\n", + " # compare_tensors(hf_lora_out, ff_lora_out)\n", + " # compare_flexflow_tensors(ff_down_proj_out, ff_lora_out)\n", + " # compare_tensors(hf_down_proj_out, ff_lora_out)\n", + " compare_tensors_difference(hf_lora_out, ff_lora_out, ff_down_proj_out)\n", + " \n", + "\n", + "# After last layer only\n", + "hf_norm_out = f\"{hf_path}/fwd_step_0_norm.output_0\"\n", + "ff_norm_out = f\"{ff_path}/fwd_step_0_layers_{tot_num_layers-1}_norm_shard_0_output_1\"\n", + "compare_tensors(hf_norm_out, ff_norm_out)\n", + "hf_lm_head_out = f\"{hf_path}/fwd_step_0_base_model.model.lm_head.output_0\"\n", + "ff_lm_head_out = f\"{ff_path}/fwd_step_0_layers_{tot_num_layers-1}_output_shard_0_output_0\"\n", + "compare_tensors(hf_lm_head_out, ff_lm_head_out)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-- LM head --\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "-- Final Norm --\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n" + ] + } + ], + "source": [ + "tot_num_layers = 12\n", + "\n", + "# ff_BWD_softmax_in = f\"{ff_path}/model_0_bwd-step_0_layer-num_100_layer-name_Softmax_shard-id_0_input_0\"\n", + "print(\"-- LM head --\")\n", + "hf_BWD_lm_head_out = f\"{hf_path}/bwd_step_0_base_model.model.lm_head.go_0\"\n", + "ff_BWD_lm_head_out = 
f\"{ff_path}/bwd_step_0_layers_{tot_num_layers-1}_output_shard_0_output_0\"\n", + "compare_tensors(hf_BWD_lm_head_out, ff_BWD_lm_head_out, tolerance=1e-5)\n", + "# compare weights\n", + "hf_lm_head_weight = f\"{hf_path}/base_model.model.lm_head.weight\"\n", + "ff_lm_head_weight = f\"{ff_path}/fwd_step_0_layers_{tot_num_layers-1}_output_shard_0_weight_0\"\n", + "compare_tensors(hf_lm_head_weight, ff_lm_head_weight, tolerance=1e-5)\n", + "hf_BWD_lm_head_in = f\"{hf_path}/bwd_step_0_base_model.model.lm_head.gi_0\"\n", + "ff_BWD_lm_head_in = f\"{ff_path}/bwd_step_0_layers_{tot_num_layers-1}_output_shard_0_input_0\"\n", + "compare_tensors(hf_BWD_lm_head_in, ff_BWD_lm_head_in, tolerance=1e-5)\n", + "# # Manually check the matmul\n", + "# ff_tensor_out = np.loadtxt(ff_BWD_lm_head_out, delimiter=',')\n", + "# ff_weight = np.loadtxt(ff_lm_head_weight, delimiter=',').reshape((4096,32000), order='F')\n", + "# ff_tensor_out = ff_tensor_out[:32000*24].reshape((32000,24), order='F')\n", + "# print(ff_tensor_out.shape)\n", + "# print(ff_weight.shape)\n", + "# print(np.matmul(ff_weight, ff_tensor_out))\n", + "# compare_tensors(hf_BWD_lm_head_in, ff_BWD_lm_head_in)\n", + "# ff_tensor = np.loadtxt(ff_tensor_filepath, delimiter=',')\n", + "print(\"-- Final Norm --\")\n", + "hf_BWD_norm_out = f\"{hf_path}/bwd_step_0_norm.go_0\"\n", + "ff_BWD_norm_out = f\"{ff_path}/bwd_step_0_layers_{tot_num_layers-1}_norm_shard_0_output_0\"\n", + "compare_hf_tensors(hf_BWD_lm_head_in, hf_BWD_norm_out)\n", + "compare_tensors(hf_BWD_norm_out, ff_BWD_norm_out)\n", + "ff_BWD_norm_weight = f\"{ff_path}/fwd_step_0_layers_{tot_num_layers-1}_norm_shard_0_weight_0\"\n", + "hf_FWD_norm_weight = f\"{hf_path}/norm.weight\"\n", + "compare_tensors(hf_FWD_norm_weight, ff_BWD_norm_weight, tolerance=1e-5)\n", + "hf_BWD_norm_in = f\"{hf_path}/bwd_step_0_norm.gi_0\"\n", + "ff_BWD_norm_in = f\"{ff_path}/bwd_step_0_layers_{tot_num_layers-1}_norm_shard_0_input_1\"\n", + "compare_tensors(hf_BWD_norm_in, ff_BWD_norm_in, tolerance=1e-5)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn\n", + "class LlamaRotaryEmbedding(nn.Module):\n", + " def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n", + " super().__init__()\n", + "\n", + " self.dim = dim\n", + " self.max_position_embeddings = max_position_embeddings\n", + " self.base = base\n", + " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n", + " self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n", + "\n", + " # Build here to make `torch.jit.trace` work.\n", + " self._set_cos_sin_cache(\n", + " seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n", + " )\n", + "\n", + " def _set_cos_sin_cache(self, seq_len, device, dtype):\n", + " self.max_seq_len_cached = seq_len\n", + " t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n", + "\n", + " freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n", + " # Different from paper, but it uses a different permutation in order to obtain the same calculation\n", + " emb = torch.cat((freqs, freqs), dim=-1)\n", + " self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n", + " self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n", + "\n", + " def forward(self, x, seq_len=None):\n", + " # x: [bs, num_attention_heads, seq_len, head_size]\n", + " if seq_len > 
self.max_seq_len_cached:\n", + " self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n", + "\n", + " return (\n", + " self.cos_cached[:seq_len].to(dtype=x.dtype),\n", + " self.sin_cached[:seq_len].to(dtype=x.dtype),\n", + " )\n", + "def rotate_half(x):\n", + " \"\"\"Rotates half the hidden dims of the input.\"\"\"\n", + " x1 = x[..., : x.shape[-1] // 2] # first half\n", + " x2 = x[..., x.shape[-1] // 2 :] # second half\n", + " return torch.cat((x2, -x1), dim=-1)\n", + "def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n", + " \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n", + "\n", + " Args:\n", + " q (`torch.Tensor`): The query tensor.\n", + " k (`torch.Tensor`): The key tensor.\n", + " cos (`torch.Tensor`): The cosine part of the rotary embedding.\n", + " sin (`torch.Tensor`): The sine part of the rotary embedding.\n", + " position_ids (`torch.Tensor`):\n", + " The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n", + " used to pass offsetted position ids when working with a KV-cache.\n", + " unsqueeze_dim (`int`, *optional*, defaults to 1):\n", + " The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n", + " sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n", + " that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n", + " k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n", + " cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n", + " the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n", + " Returns:\n", + " `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n", + " \"\"\"\n", + " cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n", + " sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n", + " q_embed = (q * cos) + (rotate_half(q) * sin)\n", + " k_embed = (k * cos) + (rotate_half(k) * sin)\n", + " return q_embed, k_embed\n", + "head_dim = 64\n", + "max_position_embeddings = 2048\n", + "rope_theta=10_000\n", + "kv_seq_len = 24\n", + "rotary_emb = LlamaRotaryEmbedding(\n", + " head_dim,\n", + " max_position_embeddings=max_position_embeddings,\n", + " base=rope_theta,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Huggingface checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "FlexFlow checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "Huggingface-FlexFlow checks:\n", + "-- W2 --\n", + "Ok!\n", + "Ok!\n", + "-- Lora --\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "-- W2/W1/W3 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_11_SigmoidSiluMulti_shard_0_output_0\n", + "HF: [ 6.4350547e+03 -6.4898600e+05 1.1761116e+05 ... 2.1410337e+01\n", + " 1.2096541e+01 3.6424692e+00]\n", + "FF:[ 6.43525000e+03 -6.48986062e+05 1.17611250e+05 ... 2.14103413e+01\n", + " 1.20965385e+01 3.64246368e+00]\n", + "[False True True ... 
True True True]\n", + "[ 0 162 185 308 339 745 747 820 830 909 933 968 1008 1156\n", + " 1160 1190 1212 1296 1304 1311 1323 1353 1395 1421 1523 1578 1689 1717\n", + " 1736 1748 1836 2074 2124 2192 2221 2313 2394 2515 2518 2693 2758 2825\n", + " 2888 2894 2937 3024]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_11_layers_11_feed_forward_w2_shard_0_input_0\n", + "HF: [ 6.4350547e+03 -6.4898600e+05 1.1761116e+05 ... 2.1410337e+01\n", + " 1.2096541e+01 3.6424692e+00]\n", + "FF:[ 6.43525000e+03 -6.48986062e+05 1.17611250e+05 ... 2.14103413e+01\n", + " 1.20965385e+01 3.64246368e+00]\n", + "[False True True ... True True True]\n", + "[ 0 162 185 308 339 745 747 820 830 909 933 968 1008 1156\n", + " 1160 1190 1212 1296 1304 1311 1323 1353 1395 1421 1523 1578 1689 1717\n", + " 1736 1748 1836 2074 2124 2192 2221 2313 2394 2515 2518 2693 2758 2825\n", + " 2888 2894 2937 3024]\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "-- Attention --\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.self_attn.o_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_11_layers_11_attention_shard_0_o_proj_in_grad\n", + "HF: [ 1.2223595e+06 -2.6348565e+06 -5.0760525e+05 ... 6.8275871e+01\n", + " -5.8116108e+01 9.5347488e+01]\n", + "FF:[ 1.22235925e+06 -2.63485625e+06 -5.07605000e+05 ... 6.82758865e+01\n", + " -5.81161423e+01 9.53475494e+01]\n", + "[ True True True ... True True True]\n", + "[ 51 77 95 168 175 232 725]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[ 1.22235950e+06 9.93645859e+01 -2.82157593e+01 ... -3.94578514e+01\n", + " -1.98409653e+01 -1.33438044e+01]\n", + " [-2.63485650e+06 -1.13461929e+02 1.14223976e+02 ... 7.52578735e+01\n", + " 1.33362747e+02 6.78501587e+01]\n", + " [-5.07605250e+05 4.34111862e+01 8.10619354e+01 ... 4.70537224e+01\n", + " 4.02149696e+01 6.98045502e+01]\n", + " ...\n", + " [ 3.02792250e+06 3.31295319e+02 9.98417091e+00 ... 4.90895653e+01\n", + " 9.71413574e+01 6.82758713e+01]\n", + " [-3.64456375e+06 -2.43692596e+02 -6.85474396e+00 ... -3.71503868e+01\n", + " -1.34136658e+01 -5.81161079e+01]\n", + " [ 3.31921500e+06 2.24193970e+02 -6.64005566e+00 ... 2.11662292e+00\n", + " 3.37400856e+01 9.53474884e+01]]\n", + "FF:[[ 1.22235925e+06 9.93645630e+01 -2.82157211e+01 ... -3.94577713e+01\n", + " -1.98408775e+01 -1.33438234e+01]\n", + " [-2.63485625e+06 -1.13461960e+02 1.14224037e+02 ... 7.52577744e+01\n", + " 1.33362701e+02 6.78501205e+01]\n", + " [-5.07605000e+05 4.34111404e+01 8.10619278e+01 ... 4.70536804e+01\n", + " 4.02149124e+01 6.98045578e+01]\n", + " ...\n", + " [ 3.02792250e+06 3.31295227e+02 9.98412323e+00 ... 4.90895386e+01\n", + " 9.71413727e+01 6.82758865e+01]\n", + " [-3.64456400e+06 -2.43692627e+02 -6.85472488e+00 ... -3.71504822e+01\n", + " -1.34137001e+01 -5.81161423e+01]\n", + " [ 3.31921500e+06 2.24193970e+02 -6.64004517e+00 ... 2.11670875e+00\n", + " 3.37400322e+01 9.53475494e+01]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... 
True True True]]\n", + "[ 51 77 95 168 175 232 725]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[ 1.2223588e+06 -2.6348530e+06 -5.0760291e+05 ... 3.0279325e+06\n", + " -3.6445672e+06 3.3192180e+06]\n", + " [-4.2496326e+02 1.1576636e+03 9.8397858e+02 ... 1.6480791e+03\n", + " -5.9697235e+02 6.2627173e+02]\n", + " [-2.2012039e+01 6.6097900e+01 3.9933994e+01 ... 5.7103355e+01\n", + " -1.5968766e+01 3.6536639e+00]\n", + " ...\n", + " [-1.2302110e+00 5.3052688e+00 2.1982718e+00 ... 1.3990868e+00\n", + " -5.5132383e-01 4.8985812e-01]\n", + " [-1.0771493e+00 6.9571300e+00 2.7373023e+00 ... 4.9663010e+00\n", + " -9.9705428e-01 2.1829298e+00]\n", + " [-5.9534687e-01 3.0272012e+00 3.1143982e+00 ... 2.4072502e+00\n", + " -2.0490403e+00 3.3617332e+00]]\n", + "FF:[[ 1.22235850e+06 -2.63485275e+06 -5.07602656e+05 ... 3.02793250e+06\n", + " -3.64456750e+06 3.31921800e+06]\n", + " [-4.24962585e+02 1.15766296e+03 9.83978577e+02 ... 1.64807898e+03\n", + " -5.96972351e+02 6.26271790e+02]\n", + " [-2.20120354e+01 6.60979462e+01 3.99340210e+01 ... 5.71033745e+01\n", + " -1.59687757e+01 3.65366316e+00]\n", + " ...\n", + " [-1.23020661e+00 5.30526114e+00 2.19826817e+00 ... 1.39908671e+00\n", + " -5.51325083e-01 4.89858717e-01]\n", + " [-1.07714510e+00 6.95712519e+00 2.73729825e+00 ... 4.96630049e+00\n", + " -9.97055829e-01 2.18292713e+00]\n", + " [-5.95347941e-01 3.02720070e+00 3.11439991e+00 ... 2.40725493e+00\n", + " -2.04904509e+00 3.36174107e+00]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]]\n", + "[0 0 0 0 0 0 0]\n", + "Ok!\n", + "7.4363425925925934% mismatch in QK prods softmax out grad\n", + "Ok!\n", + "hf_attn_in: (768, 24)\n", + "[[-7.52523500e+06 -1.27625415e+03 -4.39338150e+01 ... -3.34414902e+01\n", + " 2.38160934e+01 3.15938339e+01]\n", + " [-9.55138900e+06 6.71377197e+02 2.06871887e+02 ... -3.86393509e+01\n", + " 2.14816055e+01 -6.58599396e+01]\n", + " [ 1.14522670e+07 2.19898975e+03 -6.89673233e+00 ... 9.51593590e+00\n", + " -1.68612709e+01 6.02474251e+01]\n", + " ...\n", + " [ 2.10891925e+06 3.78648706e+03 1.02701221e+03 ... 3.59794388e+01\n", + " 5.03902206e+01 4.19777756e+01]\n", + " [ 2.11695300e+06 -2.36283508e+02 -1.08002625e+02 ... 9.36443710e+00\n", + " 3.84094887e+01 -7.51948738e+00]\n", + " [ 7.39155050e+06 1.11731885e+03 3.38369843e+02 ... 3.70399475e+01\n", + " 1.77629051e+01 9.76780853e+01]]\n", + "ff_attn_in: (768, 24)\n", + "[[-7.52523600e+06 -1.27625293e+03 -4.39336700e+01 ... -3.34414597e+01\n", + " 2.38162422e+01 3.15938187e+01]\n", + " [-9.55138900e+06 6.71377319e+02 2.06871674e+02 ... -3.86393127e+01\n", + " 2.14817867e+01 -6.58600464e+01]\n", + " [ 1.14522660e+07 2.19898950e+03 -6.89660644e+00 ... 9.51594448e+00\n", + " -1.68611774e+01 6.02474518e+01]\n", + " ...\n", + " [ 2.10891850e+06 3.78648633e+03 1.02701196e+03 ... 3.59794846e+01\n", + " 5.03901253e+01 4.19777679e+01]\n", + " [ 2.11695400e+06 -2.36282440e+02 -1.08002762e+02 ... 9.36448860e+00\n", + " 3.84096107e+01 -7.51954842e+00]\n", + " [ 7.39155000e+06 1.11731921e+03 3.38370087e+02 ... 
3.70398293e+01\n", + " 1.77627277e+01 9.76782227e+01]]\n", + "6.011284722222222% mismatch in attention input grads\n", + "\n", + "Huggingface checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "FlexFlow checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "Huggingface-FlexFlow checks:\n", + "-- W2 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.10.mlp.down_proj.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_10_layers_10_feed_forward_w2_shard_0_output_0\n", + "HF: [-9.4779546e+09 -1.2174155e+10 1.4899113e+10 ... 4.9057606e+01\n", + " 4.7770348e+01 5.8564331e+01]\n", + "FF:[-9.47795558e+09 -1.21741548e+10 1.48991119e+10 ... 4.90575981e+01\n", + " 4.77703362e+01 5.85643845e+01]\n", + "[ True True True ... True True True]\n", + "[ 88 138 187 203 232 242 493 657 750 900 1198 1249\n", + " 1287 1305 1414 1428 1490 1588 1600 1612 1625 1657 1676 1677\n", + " 1692 1694 1724 1730 1772 1822 1825 1838 1853 1910 2035 2043\n", + " 2053 2059 2073 2078 2123 2145 2214 2238 2241 2285 2292 2389\n", + " 2542 2582 2589 2599 2674 2688 2711 2840 2856 2961 2963 2980\n", + " 3064 3176 3192 3255 3262 3278 3338 3341 3412 3419 3492 3590\n", + " 3624 3646 3657 3807 3840 3842 3846 3883 3887 4005 4049 4071\n", + " 4076 4077 4079 4137 4142 4192 4193 4202 4218 4224 4273 4355\n", + " 4358 4381 4401 4435 4469 4499 4514 4546 4598 4619 4747 4846\n", + " 4872 4916 4952 4966 5016 5067 5107 5112 5116 5194 5225 5350\n", + " 5364 5403 5515 5537 5550 5578 5650 5653 5654 5736 5751 5837\n", + " 5870 5881 5972 5998 6006 6051 6061 6107 6129 6204 6236 6292\n", + " 6296 6327 6382 6393 6403 6420 6424 6436 6468 6542 6599 6675\n", + " 6681 6711 6723 6767 6823 6914 6983 7047 7064 7133 7167 7197\n", + " 7198 7209 7528 7537 7538 7686 7850 7855 7889 7910 7919 7927\n", + " 7937 7939 8089 8101 8157 8169 8175 8223 8292 8304 8306 8342\n", + " 8351 8414 8475 8500 8543 8558 8609 8656 8687 8704 8724 8726\n", + " 8777 8816 8826 8871 8904 8934 8983 9012 9033 9043 9068 9093\n", + " 9125 9133 9144 9151 9154 9217 9222 9320 9335 9367 9398 9421\n", + " 9434 9521 9547 9633 9702 9726 9763 9949 10018 10053 10062 10079\n", + " 10137 10149 10203 10261 10269 10292 10312 10332 10471 10478 10514 10596\n", + " 10645 10676 10678 10781 10795 10810 10833 10891 10904 10935 10957 10977\n", + " 10982 11028 11095 11172 11223 11251 11283 11303 11319 11374 11392 11437\n", + " 11486 11627 11678 11750 11759 11979 11996 12019 12126 12237 12262 12288\n", + " 12303 12309 12315 12387 12543 12569 12613 12648 12786 12852 12866 12879\n", + " 12947 12963 13037 13058 13261 13284 13312 13394 13399 13427 13526 13527\n", + " 13592 13695 13741 13752 13775 13803 13812 13866 13902 14049 14170 14241\n", + " 14354 14382 14426 14451 14455 14486 14502 14582 14820 14934 14961 14976\n", + " 15000 15003 15014 15077 15096 15108 15135 15148 15165 15219 15232 15290\n", + " 15339 15345 15819 15945 15994 16077 16135 16218 16231 16233 16239 16243\n", + " 16295 16311 16339 16356 16366 16417 16456 16498 16502 16503 16506 16547\n", + " 16585 16603 16611 16633 16661 16683 16704 16710 16723 16724 16745 16754\n", + " 16773 16787 16789 16818 16829 16833 16913 16933 17025 17033 17037 17055\n", + " 17084 17098 17109 17176 17225 17240 17292 17294 17339 17390 17427 17437\n", + " 17579 17626 17630 17654 17719 17902 17912 18023 18025 18124 18203 18339\n", + " 18344]\n", + "Ok!\n", + "Ok!\n", + "-- Lora --\n", + "Ok!\n", + "Ok!\n", + "mismatch between 
/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.10.mlp.down_proj.lora_B.default.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_10_layers_10_feed_forward_w2_lora_shard_0_output_0\n", + "HF: [-9.4779546e+09 -1.2174155e+10 1.4899113e+10 ... 4.9057606e+01\n", + " 4.7770348e+01 5.8564331e+01]\n", + "FF:[-9.47795558e+09 -1.21741548e+10 1.48991119e+10 ... 4.90575981e+01\n", + " 4.77703362e+01 5.85643845e+01]\n", + "[ True True True ... True True True]\n", + "[ 88 138 187 203 232 242 493 657 750]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.10.mlp.down_proj.lora_A.default.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_10_layers_10_feed_forward_w2_lora_shard_0_input_0\n", + "HF: [ 4.7819588e+07 3.8833264e+07 4.7789860e+07 ... 1.0804405e+00\n", + " 2.7186510e-01 -2.9918199e+00]\n", + "FF:[ 4.78195960e+07 3.88332640e+07 4.77898600e+07 ... 1.08044124e+00\n", + " 2.71864563e-01 -2.99182224e+00]\n", + "[ True True True ... True True True]\n", + "[ 109 211 312 422 590 832 835 1016 1053 1076 1268 1353 1374 1693\n", + " 1701 1710 1722 1832 1954 1965 1997 2076 2124 2146 2378 2520 2605 2624\n", + " 2967 3007 3015]\n", + "Ok!\n", + "-- W2/W1/W3 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.10.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_10_SigmoidSiluMulti_shard_0_output_0\n", + "HF: [ 3.3558659e+09 1.3409817e+10 -1.4671958e+10 ... 7.2100967e+01\n", + " 6.5979071e+00 -2.1230124e+01]\n", + "FF:[ 3.35586406e+09 1.34098166e+10 -1.46719611e+10 ... 7.21009750e+01\n", + " 6.59790993e+00 -2.12301121e+01]\n", + "[ True True True ... True True True]\n", + "[ 4 95 111 163 179 191 279 305 363 406 447 487 489 494\n", + " 517 617 703 713 735 796 805 819 826 858 882 959 964 967\n", + " 986 1020 1035 1054 1067 1070 1077 1081 1095 1097 1123 1139 1181 1238\n", + " 1296 1342 1369 1489 1550 1557 1623 1669 1752 1757 1783 1819 1876 1949\n", + " 1963 1993 2034 2047 2091 2115 2153 2170 2306 2381 2419 2431 2456 2501\n", + " 2503 2591 2653 2768 2778 2791 2970 2980 3053 3067]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.10.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_10_layers_10_feed_forward_w2_shard_0_input_0\n", + "HF: [ 3.3558659e+09 1.3409817e+10 -1.4671958e+10 ... 7.2100967e+01\n", + " 6.5979071e+00 -2.1230124e+01]\n", + "FF:[ 3.35586406e+09 1.34098166e+10 -1.46719611e+10 ... 7.21009750e+01\n", + " 6.59790993e+00 -2.12301121e+01]\n", + "[ True True True ... 
True True True]\n", + "[ 4 95 111 163 179 191 279 305 363 406 447 487 489 494\n", + " 517 617 703 713 735 796 805 819 826 858 882 959 964 967\n", + " 986 1020 1035 1054 1067 1070 1077 1081 1095 1097 1123 1139 1181 1238\n", + " 1296 1342 1369 1489 1550 1557 1623 1669 1752 1757 1783 1819 1876 1949\n", + " 1963 1993 2034 2047 2091 2115 2153 2170 2306 2381 2419 2431 2456 2501\n", + " 2503 2591 2653 2768 2778 2791 2970 2980 3053 3067]\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "-- Attention --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.10.self_attn.o_proj.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_10_layers_10_attention_shard_0_output_0\n", + "HF: [-9.4779546e+09 -1.2174155e+10 1.4899113e+10 ... 9.3464905e+01\n", + " 7.5613129e+01 7.6598846e+01]\n", + "FF:[-9.47795558e+09 -1.21741548e+10 1.48991119e+10 ... 9.34649200e+01\n", + " 7.56131058e+01 7.65989227e+01]\n", + "[ True True True ... True True True]\n", + "[ 88 138 187 203 232 242 493 657 750]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.10.self_attn.o_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_10_layers_10_attention_shard_0_o_proj_in_grad\n", + "HF: [-9.4470595e+09 -7.3870331e+09 1.2659395e+10 ... -2.8149616e+01\n", + " 1.7019112e+02 -7.7236428e+00]\n", + "FF:[-9.44706150e+09 -7.38703309e+09 1.26593966e+10 ... -2.81496239e+01\n", + " 1.70191177e+02 -7.72364044e+00]\n", + "[ True True True ... True True True]\n", + "[ 11 98 109 134 262 266 274 309 310 327 328 364 398 409 429 605 645]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[-9.44705946e+09 2.28078384e+01 3.18554016e+02 ... 1.17267204e+02\n", + " 2.06791725e+01 1.13138672e+02]\n", + " [-7.38703309e+09 -7.36898804e+00 7.93705673e+01 ... 2.04039650e+01\n", + " 3.18331490e+01 5.44241562e+01]\n", + " [ 1.26593946e+10 1.77534424e+02 -2.97175941e+01 ... 1.16716766e+01\n", + " 7.70214081e+01 2.81902496e+02]\n", + " ...\n", + " [ 4.51210445e+10 3.63867615e+02 -8.04915466e+01 ... -1.34332123e+02\n", + " -1.22151840e+02 -2.81496162e+01]\n", + " [-1.39591885e+10 1.59216873e+02 6.11343079e+01 ... 1.56675262e+02\n", + " 9.68551483e+01 1.70191116e+02]\n", + " [-1.29442345e+10 -2.39441833e+02 2.73647644e+02 ... -4.41197014e+01\n", + " -9.48526230e+01 -7.72364283e+00]]\n", + "FF:[[-9.44706150e+09 2.28079376e+01 3.18553864e+02 ... 1.17267227e+02\n", + " 2.06791859e+01 1.13138741e+02]\n", + " [-7.38703309e+09 -7.36921692e+00 7.93703690e+01 ... 2.04038925e+01\n", + " 3.18332825e+01 5.44241333e+01]\n", + " [ 1.26593966e+10 1.77534454e+02 -2.97174206e+01 ... 1.16717224e+01\n", + " 7.70213699e+01 2.81902618e+02]\n", + " ...\n", + " [ 4.51210527e+10 3.63867554e+02 -8.04915695e+01 ... -1.34332092e+02\n", + " -1.22151901e+02 -2.81496239e+01]\n", + " [-1.39591834e+10 1.59216995e+02 6.11343040e+01 ... 1.56675293e+02\n", + " 9.68551559e+01 1.70191177e+02]\n", + " [-1.29442304e+10 -2.39441772e+02 2.73647644e+02 ... -4.41196594e+01\n", + " -9.48526916e+01 -7.72364044e+00]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... 
True True True]]\n", + "[ 11 98 109 134 262 266 274 309 310 327 328 364 398 409 429 605 645]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[-9.44705946e+09 -7.38703309e+09 1.26593946e+10 ... 4.51210445e+10\n", + " -1.39591885e+10 -1.29442345e+10]\n", + " [ 1.14852783e+03 4.39543152e+02 1.07877356e+03 ... -2.42416113e+03\n", + " 2.64504834e+03 4.68633453e+02]\n", + " [ 5.72417107e+01 4.12602806e+01 -2.27319489e+01 ... -3.40788422e+01\n", + " 4.86237946e+01 1.25752163e+01]\n", + " ...\n", + " [ 6.76848269e+00 8.23165894e+00 2.10253639e+01 ... -3.19590777e-01\n", + " 3.68098617e-01 -1.95310101e-01]\n", + " [ 4.08574820e+00 5.33035660e+00 1.41003275e+01 ... -1.35607815e+00\n", + " 4.06074905e+00 -7.67630756e-01]\n", + " [ 2.03186665e+01 9.77407932e+00 5.06271019e+01 ... -6.80029154e-01\n", + " 4.11142111e+00 -1.86585218e-01]]\n", + "FF:[[-9.44706150e+09 -7.38703309e+09 1.26593966e+10 ... 4.51210527e+10\n", + " -1.39591834e+10 -1.29442304e+10]\n", + " [ 1.14852808e+03 4.39542755e+02 1.07877344e+03 ... -2.42416138e+03\n", + " 2.64504932e+03 4.68633698e+02]\n", + " [ 5.72415771e+01 4.12602005e+01 -2.27318707e+01 ... -3.40787392e+01\n", + " 4.86236725e+01 1.25752039e+01]\n", + " ...\n", + " [ 6.76847696e+00 8.23167515e+00 2.10253181e+01 ... -3.19590837e-01\n", + " 3.68098557e-01 -1.95310280e-01]\n", + " [ 4.08574867e+00 5.33037567e+00 1.41003180e+01 ... -1.35607564e+00\n", + " 4.06074095e+00 -7.67629445e-01]\n", + " [ 2.03186874e+01 9.77407932e+00 5.06271439e+01 ... -6.80029511e-01\n", + " 4.11142349e+00 -1.86585203e-01]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]]\n", + "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", + "Ok!\n", + "6.640625% mismatch in QK prods softmax out grad\n", + "Ok!\n", + "hf_attn_in: (768, 24)\n", + "[[-5.1505955e+10 -4.7166772e+03 -1.3288132e+02 ... -3.0123844e+00\n", + " -5.5234032e+01 6.0299168e+00]\n", + " [-3.5960029e+10 -5.3263096e+03 -1.9434322e+02 ... -5.6601189e+01\n", + " -1.0787462e+02 -6.0718418e+01]\n", + " [ 4.8131662e+10 1.1578307e+04 1.7744476e+02 ... -5.6970375e+01\n", + " -1.7497168e+01 -7.2297249e+00]\n", + " ...\n", + " [-9.0346426e+08 6.4752144e+03 3.2408417e+02 ... 6.1075470e+01\n", + " 8.5356834e+01 8.3221588e+01]\n", + " [-5.0754217e+09 -2.2929268e+03 -1.4913528e+02 ... 8.6639397e+01\n", + " 1.1156468e+02 1.0695674e+02]\n", + " [ 5.5844772e+09 3.0225920e+03 -6.3137859e+01 ... -6.5270996e+01\n", + " 8.2730171e+01 -1.0107367e+02]]\n", + "ff_attn_in: (768, 24)\n", + "[[-5.15059548e+10 -4.71667773e+03 -1.32881012e+02 ... -3.01225996e+00\n", + " -5.52339973e+01 6.02991867e+00]\n", + " [-3.59600292e+10 -5.32630957e+03 -1.94343079e+02 ... -5.66010437e+01\n", + " -1.07874649e+02 -6.07182846e+01]\n", + " [ 4.81316659e+10 1.15783076e+04 1.77444519e+02 ... -5.69703102e+01\n", + " -1.74972763e+01 -7.22990799e+00]\n", + " ...\n", + " [-9.03455232e+08 6.47521484e+03 3.24083832e+02 ... 6.10753632e+01\n", + " 8.53567886e+01 8.32217255e+01]\n", + " [-5.07543654e+09 -2.29292749e+03 -1.49135025e+02 ... 8.66392517e+01\n", + " 1.11564789e+02 1.06956917e+02]\n", + " [ 5.58446592e+09 3.02259229e+03 -6.31376152e+01 ... 
-6.52709351e+01\n", + " 8.27302551e+01 -1.01073837e+02]]\n", + "7.025824652777778% mismatch in attention input grads\n", + "\n", + "Huggingface checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "FlexFlow checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "Huggingface-FlexFlow checks:\n", + "-- W2 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.9.mlp.down_proj.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_9_layers_9_feed_forward_w2_shard_0_output_0\n", + "HF: [-6.33203254e+13 -4.43651289e+13 6.35509366e+13 ... 1.08435585e+02\n", + " 9.42303467e+01 5.89958420e+01]\n", + "FF:[-6.33203296e+13 -4.43651289e+13 6.35509408e+13 ... 1.08435623e+02\n", + " 9.42303467e+01 5.89958954e+01]\n", + "[ True True True ... True True True]\n", + "[ 26 51 66 85 259 262 272 296 298 329 392 415\n", + " 428 482 492 514 526 531 671 731 763 777 893 927\n", + " 984 1105 1184 1206 1418 1541 1548 1572 1577 1613 1619 1643\n", + " 1658 1661 1691 1701 1706 1726 1757 1784 1815 1833 1849 1856\n", + " 1880 1891 1921 1956 1969 2012 2021 2028 2030 2059 2065 2144\n", + " 2149 2183 2210 2238 2292 2342 2357 2384 2414 2495 2531 2565\n", + " 2597 2662 2713 2781 2821 2829 2877 2904 2921 2927 2962 2973\n", + " 3044 3066 3094 3100 3106 3159 3193 3251 3377 3389 3397 3427\n", + " 3436 3570 3594 3703 3729 3770 3772 3780 3811 3840 3842 3860\n", + " 3907 3920 3929 3946 3955 3969 4005 4009 4034 4048 4077 4089\n", + " 4104 4129 4134 4178 4202 4212 4219 4239 4245 4256 4273 4373\n", + " 4407 4463 4464 4465 4481 4511 4537 4541 4543 4549 4597 4599\n", + " 4633 4759 4760 4789 4846 4884 4901 4930 4954 4971 4993 5024\n", + " 5030 5041 5050 5116 5130 5163 5207 5224 5282 5313 5322 5349\n", + " 5363 5403 5410 5412 5454 5543 5581 5590 5654 5673 5784 5821\n", + " 5849 5880 5911 5917 5982 6000 6062 6165 6178 6193 6200 6272\n", + " 6322 6351 6366 6376 6380 6382 6393 6412 6420 6430 6433 6446\n", + " 6476 6482 6488 6490 6519 6527 6540 6556 6563 6567 6577 6600\n", + " 6619 6680 6709 6735 6768 6777 6780 6823 6825 6826 6830 6863\n", + " 6880 6912 6988 7006 7030 7071 7077 7102 7123 7244 7264 7367\n", + " 7389 7390 7434 7451 7452 7455 7505 7532 7539 7589 7598 7620\n", + " 7651 7653 7659 7709 7714 7740 7751 7759 7803 7808 7820 7917\n", + " 7923 7926 7949 7962 7966 7978 8002 8004 8040 8050 8052 8068\n", + " 8180 8223 8250 8253 8265 8341 8344 8375 8376 8386 8449 8468\n", + " 8501 8509 8522 8535 8585 8590 8593 8642 8657 8674 8687 8707\n", + " 8714 8726 8729 8737 8756 8769 8801 8846 8850 8865 8907 8998\n", + " 9018 9043 9059 9066 9083 9093 9098 9130 9131 9165 9189 9216\n", + " 9285 9337 9368 9526 9539 9563 9620 9659 9723 9793 9804 9817\n", + " 9820 9827 9908 9995 10053 10128 10135 10143 10205 10253 10274 10292\n", + " 10300 10311 10327 10356 10406 10441 10491 10494 10551 10562 10563 10634\n", + " 10649 10674 10710 10734 10821 10831 10833 10838 10845 10911 10966 10981\n", + " 10988 10990 10998 11008 11044 11049 11100 11127 11141 11197 11250 11269\n", + " 11285 11308 11361 11383 11437 11460 11494 11502 11511 11522 11546 11557\n", + " 11564 11588 11649 11658 11671 11674 11703 11729 11749 11759 11832 11892\n", + " 11979 11988 12000 12038 12063 12078 12107 12119 12165 12259 12269 12270\n", + " 12347 12369 12386 12415 12475 12518 12566 12569 12574 12652 12693 12792\n", + " 12833 12834 12852 12872 12900 12946 13117 13121 13124 13321 13345 13357\n", + " 13427 13431 13446 13473 13526 13635 13638 13662 
13706 13733 13803 13807\n", + " 13852 13882 13912 13924 13962 13969 13986 14023 14036 14046 14085 14110\n", + " 14130 14141 14175 14183 14191 14220 14222 14223 14285 14310 14331 14336\n", + " 14354 14375 14425 14427 14451 14482 14493 14516 14560 14563 14581 14623\n", + " 14671 14677 14679 14680 14685 14688 14742 14799 14860 14868 14870 14872\n", + " 14900 14909 14916 14940 14964 14991 15003 15023 15027 15033 15038 15051\n", + " 15086 15100 15184 15214 15232 15290 15352 15363 15365 15407 15433 15451\n", + " 15522 15577 15707 15720 15725 15739 15830 15837 15875 15937 15965 15985\n", + " 16017 16054 16113 16136 16142 16169 16191 16232 16238 16250 16268 16282\n", + " 16285 16290 16295 16304 16327 16334 16353 16356 16363 16382 16403 16407\n", + " 16408 16409 16458 16459 16495 16497 16499 16500 16516 16532 16595 16603\n", + " 16611 16657 16678 16680 16695 16701 16704 16754 16768 16807 16818 16856\n", + " 16870 16951 16971 16986 16989 16992 17048 17134 17181 17208 17217 17236\n", + " 17243 17319 17363 17398 17448 17471 17497 17557 17646 17654 17659 17692\n", + " 17754 17947 17957 17969 17975 18029 18128 18146 18196 18206 18207 18250\n", + " 18265 18313 18406]\n", + "Ok!\n", + "Ok!\n", + "-- Lora --\n", + "Ok!\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.9.mlp.down_proj.lora_B.default.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_9_layers_9_feed_forward_w2_lora_shard_0_output_0\n", + "HF: [-6.33203254e+13 -4.43651289e+13 6.35509366e+13 ... 1.08435585e+02\n", + " 9.42303467e+01 5.89958420e+01]\n", + "FF:[-6.33203296e+13 -4.43651289e+13 6.35509408e+13 ... 1.08435623e+02\n", + " 9.42303467e+01 5.89958954e+01]\n", + "[ True True True ... True True True]\n", + "[ 26 51 66 85 259 262 272 296 298 329 392 415 428 482 492 514 526 531\n", + " 671 731 763]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.9.mlp.down_proj.lora_A.default.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_9_layers_9_feed_forward_w2_lora_shard_0_input_0\n", + "HF: [ 5.0590863e+10 3.7823513e+11 -5.0394451e+11 ... -5.5814421e-01\n", + " 2.2970559e-01 -1.2293311e+00]\n", + "FF:[ 5.05906831e+10 3.78235290e+11 -5.03944544e+11 ... -5.58144033e-01\n", + " 2.29705781e-01 -1.22933090e+00]\n", + "[ True True True ... True True True]\n", + "[ 189 254 317 418 515 546 577 634 636 675 712 808 1011 1030\n", + " 1080 1091 1132 1168 1254 1265 1285 1287 1354 1381 1427 1459 1506 1620\n", + " 1654 1752 1887 1897 1900 1937 1981 1985 1986 2003 2029 2152 2181 2295\n", + " 2395 2426 2445 2673 2687 2859 2947 2977 3037]\n", + "Ok!\n", + "-- W2/W1/W3 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.9.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_9_SigmoidSiluMulti_shard_0_output_0\n", + "HF: [ 2.5211001e+13 -5.6630301e+13 -2.3639437e+13 ... -4.6000423e+01\n", + " 1.2655228e+01 7.1020460e+00]\n", + "FF:[ 2.52109673e+13 -5.66302930e+13 -2.36394182e+13 ... -4.60003510e+01\n", + " 1.26551876e+01 7.10206795e+00]\n", + "[ True True True ... 
True True True]\n", + "[ 9 49 113 174 243 267 271 288 323 335 397 399 438 439\n", + " 457 475 506 568 569 652 680 689 715 735 739 758 766 777\n", + " 785 837 842 852 865 884 893 919 930 932 936 939 957 1018\n", + " 1095 1105 1112 1114 1129 1168 1217 1220 1229 1230 1233 1237 1283 1304\n", + " 1354 1453 1532 1542 1547 1550 1592 1597 1603 1615 1647 1679 1698 1699\n", + " 1712 1770 1819 1835 1875 1977 2007 2016 2039 2066 2078 2102 2153 2245\n", + " 2403 2447 2621 2698 2704 2728 2736 2743 2774 2792 2836 2858 2870 2881\n", + " 2932 2948 3018 3034 3066]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.9.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_9_layers_9_feed_forward_w2_shard_0_input_0\n", + "HF: [ 2.5211001e+13 -5.6630301e+13 -2.3639437e+13 ... -4.6000423e+01\n", + " 1.2655228e+01 7.1020460e+00]\n", + "FF:[ 2.52109673e+13 -5.66302930e+13 -2.36394182e+13 ... -4.60003510e+01\n", + " 1.26551876e+01 7.10206795e+00]\n", + "[ True True True ... True True True]\n", + "[ 9 49 113 174 243 267 271 288 323 335 397 399 438 439\n", + " 457 475 506 568 569 652 680 689 715 735 739 758 766 777\n", + " 785 837 842 852 865 884 893 919 930 932 936 939 957 1018\n", + " 1095 1105 1112 1114 1129 1168 1217 1220 1229 1230 1233 1237 1283 1304\n", + " 1354 1453 1532 1542 1547 1550 1592 1597 1603 1615 1647 1679 1698 1699\n", + " 1712 1770 1819 1835 1875 1977 2007 2016 2039 2066 2078 2102 2153 2245\n", + " 2403 2447 2621 2698 2704 2728 2736 2743 2774 2792 2836 2858 2870 2881\n", + " 2932 2948 3018 3034 3066]\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "-- Attention --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.9.self_attn.o_proj.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_9_layers_9_attention_shard_0_output_0\n", + "HF: [-6.3320325e+13 -4.4365129e+13 6.3550937e+13 ... 7.2449814e+01\n", + " 8.6617142e+01 8.3981407e+01]\n", + "FF:[-6.33203296e+13 -4.43651289e+13 6.35509408e+13 ... 7.24498901e+01\n", + " 8.66170959e+01 8.39814606e+01]\n", + "[ True True True ... True True True]\n", + "[ 26 51 66 85 259 262 272 296 298 329 392 415 428 482 492 514 526 531\n", + " 671 731 763]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.9.self_attn.o_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_9_layers_9_attention_shard_0_o_proj_in_grad\n", + "HF: [ 7.2885461e+13 -6.0835821e+13 -7.9732612e+13 ... 2.5297220e+02\n", + " -8.1722275e+01 -7.0014725e+01]\n", + "FF:[ 7.28854608e+13 -6.08357832e+13 -7.97326201e+13 ... 2.52972260e+02\n", + " -8.17222137e+01 -7.00146637e+01]\n", + "[ True True True ... True True True]\n", + "[ 6 36 43 55 60 82 101 110 117 217 221 229 236 256 289 392 421 429\n", + " 433 454 486 518 523 565 568 629 639 648 707 725 744]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[ 7.28854608e+13 6.37500977e+02 2.96775421e+02 ... 8.35403061e+01\n", + " 1.72460327e+02 2.90482426e+01]\n", + " [-6.08358210e+13 -5.23222847e+01 -2.34542664e+02 ... -1.87500763e+01\n", + " -8.99429398e+01 8.64021378e+01]\n", + " [-7.97326117e+13 -4.24736328e+02 -1.82208099e+02 ... 3.21808720e+00\n", + " -5.87415466e+01 -2.08511108e+02]\n", + " ...\n", + " [-1.13411917e+14 -3.48418640e+02 1.52205795e+02 ... 
1.51519928e+02\n", + " 2.45651031e+02 2.52972198e+02]\n", + " [-3.75985275e+12 2.39696625e+02 1.51989685e+02 ... -2.85605354e+01\n", + " -1.79121232e+00 -8.17222748e+01]\n", + " [ 1.11016038e+14 -1.96372967e+01 -1.27668396e+02 ... 3.35008011e+01\n", + " -7.46116943e+01 -7.00147247e+01]]\n", + "FF:[[ 7.28854608e+13 6.37500977e+02 2.96775513e+02 ... 8.35403976e+01\n", + " 1.72460068e+02 2.90483646e+01]\n", + " [-6.08357832e+13 -5.23225098e+01 -2.34542755e+02 ... -1.87501526e+01\n", + " -8.99431992e+01 8.64022217e+01]\n", + " [-7.97326201e+13 -4.24736572e+02 -1.82207733e+02 ... 3.21793270e+00\n", + " -5.87416573e+01 -2.08511139e+02]\n", + " ...\n", + " [-1.13411925e+14 -3.48418640e+02 1.52205902e+02 ... 1.51519714e+02\n", + " 2.45650864e+02 2.52972260e+02]\n", + " [-3.75988630e+12 2.39696686e+02 1.51989319e+02 ... -2.85606136e+01\n", + " -1.79138493e+00 -8.17222137e+01]\n", + " [ 1.11016046e+14 -1.96372318e+01 -1.27668480e+02 ... 3.35009079e+01\n", + " -7.46116791e+01 -7.00146637e+01]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]]\n", + "[ 6 36 43 55 60 82 101 110 117 217 221 229 236 256 289 392 421 429\n", + " 433 454 486 518 523 565 568 629 639 648 707 725 744]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[ 7.2885461e+13 -6.0835821e+13 -7.9732612e+13 ... -1.1341192e+14\n", + " -3.7598527e+12 1.1101604e+14]\n", + " [ 3.3241980e+03 -6.3044128e+02 -3.0447307e+03 ... 3.0137921e+02\n", + " 3.8262988e+02 -4.2889914e+02]\n", + " [ 3.5639046e+01 -1.6155790e+01 -2.4461178e+01 ... 2.7450909e+02\n", + " 1.6181946e+02 -2.5407137e+02]\n", + " ...\n", + " [ 4.6487908e+00 -9.6633381e-01 -2.7078497e-01 ... 3.6374569e+01\n", + " -1.7563061e+00 -7.1206141e+00]\n", + " [ 1.8901447e+00 8.9006472e-01 -4.3125896e+00 ... 2.6014965e+01\n", + " -3.7720141e-01 -7.8855257e+00]\n", + " [ 1.9513500e+00 5.8041654e+00 -1.4006979e+01 ... 7.2743622e+01\n", + " -2.3499712e+01 -2.0133139e+01]]\n", + "FF:[[ 7.28854608e+13 -6.08357832e+13 -7.97326201e+13 ... -1.13411925e+14\n", + " -3.75988630e+12 1.11016046e+14]\n", + " [ 3.32419922e+03 -6.30442505e+02 -3.04472998e+03 ... 3.01379364e+02\n", + " 3.82629669e+02 -4.28898712e+02]\n", + " [ 3.56390572e+01 -1.61558037e+01 -2.44611683e+01 ... 2.74509308e+02\n", + " 1.61819229e+02 -2.54071594e+02]\n", + " ...\n", + " [ 4.64879847e+00 -9.66338813e-01 -2.70792574e-01 ... 3.63745117e+01\n", + " -1.75632846e+00 -7.12060070e+00]\n", + " [ 1.89013767e+00 8.90062451e-01 -4.31257772e+00 ... 2.60149212e+01\n", + " -3.77217919e-01 -7.88551569e+00]\n", + " [ 1.95135939e+00 5.80417490e+00 -1.40069904e+01 ... 7.27435226e+01\n", + " -2.34996586e+01 -2.01330910e+01]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]]\n", + "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", + "Ok!\n", + "7.609953703703703% mismatch in QK prods softmax out grad\n", + "Ok!\n", + "hf_attn_in: (768, 24)\n", + "[[-1.17282076e+14 -2.12461621e+03 8.80099030e+01 ... 4.34470520e+01\n", + " 7.55885468e+01 -2.88791332e+01]\n", + " [-2.07757936e+14 -3.81796265e+02 -2.33774780e+02 ... 
8.11984329e+01\n", + " -4.41825638e+01 7.35064125e+00]\n", + " [ 4.11484165e+13 2.50572113e+02 1.91601822e+02 ... 1.00269365e+01\n", + " -3.41638985e+01 1.20433075e+02]\n", + " ...\n", + " [ 7.95562329e+13 1.55007373e+03 1.70351212e+02 ... -1.80320053e+01\n", + " 8.77533417e+01 2.14678173e+01]\n", + " [-1.86546485e+14 -5.18847070e+03 -3.34331085e+02 ... 2.51586838e+01\n", + " -4.06135368e+01 -6.27860641e+00]\n", + " [ 1.89751705e+14 -3.09853809e+03 -1.18278351e+01 ... -1.24640663e+02\n", + " 1.59719009e+01 -6.47173615e+01]]\n", + "ff_attn_in: (768, 24)\n", + "[[-1.17282034e+14 -2.12461694e+03 8.80101547e+01 ... 4.34468918e+01\n", + " 7.55886002e+01 -2.88791542e+01]\n", + " [-2.07757920e+14 -3.81795776e+02 -2.33774765e+02 ... 8.11985397e+01\n", + " -4.41825829e+01 7.35066986e+00]\n", + " [ 4.11484543e+13 2.50570099e+02 1.91601196e+02 ... 1.00270777e+01\n", + " -3.41638451e+01 1.20433121e+02]\n", + " ...\n", + " [ 7.95562413e+13 1.55007288e+03 1.70350784e+02 ... -1.80321960e+01\n", + " 8.77533112e+01 2.14678249e+01]\n", + " [-1.86546469e+14 -5.18847070e+03 -3.34331268e+02 ... 2.51588135e+01\n", + " -4.06132622e+01 -6.27861023e+00]\n", + " [ 1.89751521e+14 -3.09853711e+03 -1.18275299e+01 ... -1.24640862e+02\n", + " 1.59719791e+01 -6.47173767e+01]]\n", + "7.530381944444445% mismatch in attention input grads\n", + "\n", + "Huggingface checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "FlexFlow checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "Huggingface-FlexFlow checks:\n", + "-- W2 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.8.mlp.down_proj.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_8_layers_8_feed_forward_w2_shard_0_output_0\n", + "HF: [-1.3223293e+17 -2.3794983e+17 4.7027590e+16 ... 7.7873253e+01\n", + " 8.6085976e+01 6.8200005e+01]\n", + "FF:[-1.32232886e+17 -2.37949812e+17 4.70276284e+16 ... 7.78733292e+01\n", + " 8.60859299e+01 6.82000580e+01]\n", + "[ True True True ... 
True True True]\n", + "[ 3 24 66 71 94 95 124 134 141 150 163 181\n", + " 226 261 284 318 320 378 382 385 391 395 403 422\n", + " 434 495 515 523 524 549 579 610 644 710 764 772\n", + " 870 984 987 1045 1249 1330 1362 1489 1517 1550 1556 1588\n", + " 1595 1659 1672 1684 1689 1768 1792 1799 1808 1818 1842 1871\n", + " 1889 1899 1910 1915 1925 1936 1993 1997 2033 2041 2059 2062\n", + " 2066 2098 2111 2124 2129 2130 2146 2153 2159 2166 2197 2206\n", + " 2210 2212 2222 2234 2237 2320 2321 2357 2359 2362 2385 2428\n", + " 2518 2539 2553 2568 2598 2683 2689 2694 2711 2714 2733 2787\n", + " 2788 2795 2811 2815 2853 2881 2890 2917 2981 2997 3021 3037\n", + " 3089 3149 3163 3191 3196 3217 3225 3248 3277 3287 3292 3305\n", + " 3327 3361 3385 3402 3417 3425 3456 3479 3516 3521 3528 3555\n", + " 3587 3599 3608 3684 3702 3733 3770 3779 3819 3822 3823 3898\n", + " 3921 3942 3950 4012 4053 4077 4086 4091 4139 4185 4198 4225\n", + " 4241 4296 4347 4349 4368 4403 4407 4418 4453 4471 4472 4473\n", + " 4494 4537 4549 4555 4558 4598 4623 4648 4666 4698 4729 4782\n", + " 4848 4866 4886 4943 4959 5008 5010 5012 5057 5079 5177 5178\n", + " 5186 5211 5271 5281 5296 5313 5328 5356 5364 5409 5429 5440\n", + " 5453 5455 5457 5476 5529 5563 5591 5621 5625 5631 5654 5661\n", + " 5692 5705 5720 5740 5751 5758 5787 5799 5813 5835 5836 5867\n", + " 5872 5893 5953 5974 5980 5982 6000 6055 6082 6086 6102 6107\n", + " 6123 6159 6172 6193 6220 6230 6231 6263 6286 6297 6362 6396\n", + " 6401 6430 6436 6485 6497 6499 6502 6510 6537 6554 6555 6563\n", + " 6564 6579 6586 6598 6615 6625 6626 6649 6651 6661 6754 6764\n", + " 6776 6852 6863 6874 6883 6892 6913 6945 6969 7036 7057 7066\n", + " 7082 7138 7147 7150 7157 7197 7202 7231 7234 7235 7240 7270\n", + " 7278 7287 7322 7327 7345 7348 7361 7390 7402 7490 7539 7573\n", + " 7610 7714 7721 7758 7794 7812 7827 7829 7837 7839 7882 7894\n", + " 7943 7948 7952 7969 7975 7996 8024 8027 8037 8043 8055 8078\n", + " 8079 8088 8090 8095 8154 8258 8264 8283 8297 8313 8329 8336\n", + " 8359 8361 8376 8383 8416 8421 8428 8454 8475 8502 8521 8613\n", + " 8642 8653 8696 8756 8764 8777 8791 8837 8849 8859 8878 8955\n", + " 8991 8997 9006 9012 9040 9066 9093 9097 9098 9131 9158 9162\n", + " 9165 9214 9216 9280 9297 9301 9316 9355 9371 9412 9421 9475\n", + " 9510 9580 9620 9645 9696 9713 9732 9768 9802 9817 9819 9826\n", + " 9839 9846 9947 10004 10062 10065 10072 10103 10107 10108 10138 10167\n", + " 10173 10228 10262 10292 10326 10356 10360 10372 10421 10446 10466 10468\n", + " 10499 10505 10513 10517 10589 10606 10612 10645 10664 10669 10726 10777\n", + " 10835 10838 10839 10848 10855 10877 10897 10941 10963 10971 10977 10997\n", + " 11030 11060 11065 11076 11088 11140 11167 11174 11231 11252 11257 11259\n", + " 11275 11297 11302 11319 11331 11333 11357 11358 11380 11382 11402 11423\n", + " 11446 11447 11500 11501 11522 11585 11623 11670 11728 11736 11759 11761\n", + " 11772 11785 11839 11894 11916 11924 11936 11962 11968 11969 11977 11984\n", + " 12008 12030 12054 12074 12123 12175 12182 12194 12237 12262 12282 12285\n", + " 12341 12348 12351 12370 12376 12386 12399 12449 12507 12513 12518 12522\n", + " 12549 12572 12643 12648 12663 12689 12696 12710 12769 12780 12788 12792\n", + " 12793 12852 12864 12879 12884 12985 13018 13041 13057 13176 13264 13272\n", + " 13274 13275 13292 13303 13333 13379 13427 13428 13442 13451 13454 13500\n", + " 13510 13533 13564 13588 13607 13640 13655 13686 13687 13688 13732 13747\n", + " 13786 13801 13803 13826 13841 13846 13850 13892 13909 13946 14036 14040\n", + 
" 14046 14060 14080 14152 14161 14183 14195 14210 14240 14278 14331 14354\n", + " 14370 14372 14386 14395 14409 14432 14434 14497 14506 14531 14559 14589\n", + " 14648 14663 14686 14698 14715 14743 14757 14799 14808 14810 14849 14893\n", + " 14902 14929 14937 14947 14953 14958 15005 15012 15018 15036 15066 15069\n", + " 15083 15152 15154 15196 15197 15212 15292 15309 15323 15340 15343 15375\n", + " 15389 15396 15408 15410 15454 15499 15532 15557 15605 15647 15677 15736\n", + " 15745 15756 15769 15809 15824 15876 15882 15900 15906 15941 16027 16030\n", + " 16040 16116 16190 16192 16205 16207 16239 16279 16285 16295 16348 16358\n", + " 16367 16384 16386 16394 16399 16455 16457 16458 16471 16495 16500 16502\n", + " 16520 16541 16542 16598 16623 16643 16651 16665 16673 16679 16713 16725\n", + " 16734 16736 16739 16751 16756 16768 16861 16870 16939 16976 17007 17028\n", + " 17040 17069 17087 17108 17125 17139 17151 17158 17174 17175 17178 17182\n", + " 17189 17221 17258 17341 17360 17370 17381 17395 17396 17415 17432 17450\n", + " 17463 17470 17472 17473 17496 17507 17536 17608 17626 17627 17649 17653\n", + " 17664 17771 17815 17822 17831 17864 17883 17931 17994 17999 18035 18174\n", + " 18209 18250 18274 18307 18327 18403 18423]\n", + "Ok!\n", + "Ok!\n", + "-- Lora --\n", + "Ok!\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.8.mlp.down_proj.lora_B.default.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_8_layers_8_feed_forward_w2_lora_shard_0_output_0\n", + "HF: [-1.3223293e+17 -2.3794983e+17 4.7027590e+16 ... 7.7873253e+01\n", + " 8.6085976e+01 6.8200005e+01]\n", + "FF:[-1.32232886e+17 -2.37949812e+17 4.70276284e+16 ... 7.78733292e+01\n", + " 8.60859299e+01 6.82000580e+01]\n", + "[ True True True ... True True True]\n", + "[ 3 24 66 71 94 95 124 134 141 150 163 181 226 261 284 318 320 378\n", + " 382 385 391 395 403 422 434 495 515 523 524 549 579 610 644 710 764]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.8.mlp.down_proj.lora_A.default.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_8_layers_8_feed_forward_w2_lora_shard_0_input_0\n", + "HF: [ 6.5550952e+14 4.9376585e+14 3.8510841e+14 ... 1.6802770e+00\n", + " -1.1248941e+00 -1.1701980e+00]\n", + "FF:[ 6.55509317e+14 4.93765882e+14 3.85108377e+14 ... 1.68027747e+00\n", + " -1.12489426e+00 -1.17019880e+00]\n", + "[ True True True ... 
True True True]\n", + "[ 6 79 111 149 155 168 187 195 220 223 252 261 329 343\n", + " 347 369 386 392 403 438 439 450 461 524 535 643 656 659\n", + " 661 668 722 727 732 742 754 801 816 820 835 837 849 850\n", + " 978 993 997 1012 1019 1034 1044 1071 1088 1094 1114 1135 1151 1170\n", + " 1190 1212 1273 1275 1277 1289 1290 1308 1311 1337 1364 1379 1394 1430\n", + " 1454 1460 1469 1474 1703 1725 1728 1732 1733 1741 1754 1757 1804 1806\n", + " 1856 1862 1932 1945 1996 2030 2044 2045 2065 2071 2075 2094 2149 2152\n", + " 2163 2180 2182 2215 2254 2357 2362 2370 2392 2398 2428 2484 2519 2521\n", + " 2524 2582 2618 2641 2645 2664 2674 2681 2691 2735 2747 2779 2872 2899\n", + " 2909 2935 2957 3000 3033]\n", + "Ok!\n", + "-- W2/W1/W3 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.8.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_8_SigmoidSiluMulti_shard_0_output_0\n", + "HF: [-1.3871785e+17 -8.3164397e+16 4.9509505e+16 ... 4.3806694e+01\n", + " 9.4386072e+00 -2.4460859e+01]\n", + "FF:[-1.38717840e+17 -8.31644654e+16 4.95094495e+16 ... 4.38065948e+01\n", + " 9.43864822e+00 -2.44608364e+01]\n", + "[ True True True ... True True True]\n", + "[ 80 83 172 173 176 184 215 285 329 338 341 395 403 465\n", + " 468 565 572 601 614 636 639 651 660 749 750 806 828 844\n", + " 873 952 971 988 992 1014 1082 1083 1085 1123 1152 1195 1200 1227\n", + " 1391 1397 1462 1546 1548 1563 1584 1629 1704 1706 1759 1764 1820 1833\n", + " 1851 1857 1864 1899 1929 1943 1958 1967 1980 1985 2002 2030 2069 2076\n", + " 2120 2127 2130 2157 2180 2187 2195 2212 2243 2249 2256 2299 2393 2505\n", + " 2516 2525 2546 2562 2604 2702 2712 2731 2745 2764 2789 2821 2873 2915\n", + " 2936 2945 2951 3013 3016]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.8.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_8_layers_8_feed_forward_w2_shard_0_input_0\n", + "HF: [-1.3871785e+17 -8.3164397e+16 4.9509505e+16 ... 4.3806694e+01\n", + " 9.4386072e+00 -2.4460859e+01]\n", + "FF:[-1.38717840e+17 -8.31644654e+16 4.95094495e+16 ... 4.38065948e+01\n", + " 9.43864822e+00 -2.44608364e+01]\n", + "[ True True True ... True True True]\n", + "[ 80 83 172 173 176 184 215 285 329 338 341 395 403 465\n", + " 468 565 572 601 614 636 639 651 660 749 750 806 828 844\n", + " 873 952 971 988 992 1014 1082 1083 1085 1123 1152 1195 1200 1227\n", + " 1391 1397 1462 1546 1548 1563 1584 1629 1704 1706 1759 1764 1820 1833\n", + " 1851 1857 1864 1899 1929 1943 1958 1967 1980 1985 2002 2030 2069 2076\n", + " 2120 2127 2130 2157 2180 2187 2195 2212 2243 2249 2256 2299 2393 2505\n", + " 2516 2525 2546 2562 2604 2702 2712 2731 2745 2764 2789 2821 2873 2915\n", + " 2936 2945 2951 3013 3016]\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "-- Attention --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.8.self_attn.o_proj.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_8_layers_8_attention_shard_0_output_0\n", + "HF: [-1.3223293e+17 -2.3794983e+17 4.7027590e+16 ... 3.5121140e+01\n", + " -3.5587997e+00 9.5641022e+01]\n", + "FF:[-1.32232886e+17 -2.37949812e+17 4.70276284e+16 ... 3.51211472e+01\n", + " -3.55898285e+00 9.56410980e+01]\n", + "[ True True True ... 
True True True]\n", + "[ 3 24 66 71 94 95 124 134 141 150 163 181 226 261 284 318 320 378\n", + " 382 385 391 395 403 422 434 495 515 523 524 549 579 610 644 710 764]\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.8.self_attn.o_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_8_layers_8_attention_shard_0_o_proj_in_grad\n", + "HF: [-1.6186993e+17 -3.5698813e+17 3.4442975e+16 ... -2.5844165e+02\n", + " 2.0677340e+01 -2.4573349e+01]\n", + "FF:[-1.61869621e+17 -3.56988336e+17 3.44430865e+16 ... -2.58441467e+02\n", + " 2.06775093e+01 -2.45735531e+01]\n", + "[ True True True ... True True True]\n", + "[ 93 99 114 137 141 142 160 193 235 259 269 299 307 316 350 364 400 523\n", + " 608 702 720 731 759]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[-1.6186993e+17 -2.1968115e+02 8.5754425e+01 ... -6.9909119e+01\n", + " -2.6478451e+01 -7.4195160e+01]\n", + " [-3.5698813e+17 3.9582391e+02 5.5431940e+02 ... 1.9529277e+02\n", + " 1.2558211e+02 6.7965935e+01]\n", + " [ 3.4442975e+16 2.8310864e+02 -8.1522171e+01 ... -2.3606525e+01\n", + " -2.0410315e+01 -1.5228156e+02]\n", + " ...\n", + " [ 4.0923264e+16 -2.4507169e+02 -8.2614380e+02 ... -2.6583340e+02\n", + " -1.9878247e+02 -2.5844165e+02]\n", + " [ 6.9156258e+17 1.3969666e+02 -7.5639044e+02 ... -1.5231053e+02\n", + " -3.3650037e+02 2.0677340e+01]\n", + " [ 9.9511712e+16 -3.2348724e+01 3.0624988e+02 ... 1.0391423e+02\n", + " 6.0626881e+01 -2.4573349e+01]]\n", + "FF:[[-1.61869621e+17 -2.19681122e+02 8.57541504e+01 ... -6.99092026e+01\n", + " -2.64783611e+01 -7.41952515e+01]\n", + " [-3.56988336e+17 3.95823853e+02 5.54319275e+02 ... 1.95292725e+02\n", + " 1.25582062e+02 6.79659348e+01]\n", + " [ 3.44430865e+16 2.83108551e+02 -8.15224686e+01 ... -2.36064014e+01\n", + " -2.04101429e+01 -1.52281570e+02]\n", + " ...\n", + " [ 4.09233933e+16 -2.45071564e+02 -8.26143555e+02 ... -2.65833405e+02\n", + " -1.98782272e+02 -2.58441467e+02]\n", + " [ 6.91562577e+17 1.39696579e+02 -7.56390808e+02 ... -1.52310455e+02\n", + " -3.36500092e+02 2.06775093e+01]\n", + " [ 9.95114373e+16 -3.23486938e+01 3.06250122e+02 ... 1.03914482e+02\n", + " 6.06264191e+01 -2.45735531e+01]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]]\n", + "[ 93 99 114 137 141 142 160 193 235 259 269 299 307 316 350 364 400 523\n", + " 608 702 720 731 759]\n", + "Ok!\n", + "mismatch between hf_tensor and ff_tensor\n", + "HF: [[-1.6186993e+17 -3.5698813e+17 3.4442975e+16 ... 4.0923264e+16\n", + " 6.9156258e+17 9.9511712e+16]\n", + " [-5.3483575e+02 2.6249797e+03 -6.7268573e+02 ... -6.1204077e+03\n", + " -4.3047915e+03 -9.5139771e+01]\n", + " [-1.2200641e+01 1.0347147e+02 -2.6777636e+01 ... -1.4766699e+02\n", + " -9.8514114e+01 1.2616925e+01]\n", + " ...\n", + " [-3.2097631e+00 9.1431990e+00 -1.6333975e+00 ... -6.9996667e+00\n", + " -6.4008064e+00 1.9126304e+00]\n", + " [-3.0982289e+00 1.2355285e+01 -3.1715555e+00 ... -4.6754313e+00\n", + " -6.2553053e+00 1.0515085e+00]\n", + " [-2.9516125e+00 2.7038031e+00 -6.0580249e+00 ... -1.6555168e+01\n", + " 1.3245420e+00 -1.5741113e+00]]\n", + "FF:[[-1.61869621e+17 -3.56988336e+17 3.44430865e+16 ... 
4.09233933e+16\n", + " 6.91562577e+17 9.95114373e+16]\n", + " [-5.34834961e+02 2.62497900e+03 -6.72686401e+02 ... -6.12040576e+03\n", + " -4.30479297e+03 -9.51402283e+01]\n", + " [-1.22006664e+01 1.03471611e+02 -2.67777309e+01 ... -1.47666946e+02\n", + " -9.85141525e+01 1.26169167e+01]\n", + " ...\n", + " [-3.20977211e+00 9.14321709e+00 -1.63339353e+00 ... -6.99966621e+00\n", + " -6.40081263e+00 1.91262615e+00]\n", + " [-3.09821057e+00 1.23552399e+01 -3.17152786e+00 ... -4.67541933e+00\n", + " -6.25528765e+00 1.05149710e+00]\n", + " [-2.95161533e+00 2.70380235e+00 -6.05802393e+00 ... -1.65551491e+01\n", + " 1.32455230e+00 -1.57412362e+00]]\n", + "[[ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " ...\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]\n", + " [ True True True ... True True True]]\n", + "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n", + "Ok!\n", + "8.101851851851851% mismatch in QK prods softmax out grad\n", + "Ok!\n", + "hf_attn_in: (768, 24)\n", + "[[-7.3778828e+16 1.0956941e+03 1.1773144e+02 ... -4.0466427e+01\n", + " -3.1198654e+01 -1.7603550e+01]\n", + " [-1.2087128e+18 6.9384756e+03 6.1327003e+01 ... 1.5329468e+01\n", + " 7.6757736e+00 -4.5589094e+00]\n", + " [-6.7892266e+17 5.4895034e+03 7.6927376e+01 ... 9.1396770e+00\n", + " 2.3195824e+01 -6.1995559e+00]\n", + " ...\n", + " [ 2.6452032e+17 9.9761787e+03 2.2349066e+02 ... 5.7504387e+01\n", + " -8.6791611e-01 4.6890911e+01]\n", + " [-6.7528534e+16 3.3856902e+03 2.5189743e+02 ... 2.2824722e+01\n", + " 8.7917282e+01 -2.1569672e+01]\n", + " [-2.1779064e+17 5.2511855e+03 6.6282043e+01 ... 9.9689598e+00\n", + " -5.5022659e+00 -3.2573143e+01]]\n", + "ff_attn_in: (768, 24)\n", + "[[-7.37791458e+16 1.09569678e+03 1.17731285e+02 ... -4.04664154e+01\n", + " -3.11988506e+01 -1.76035423e+01]\n", + " [-1.20871251e+18 6.93847900e+03 6.13275528e+01 ... 1.53295393e+01\n", + " 7.67594433e+00 -4.55900288e+00]\n", + " [-6.78922523e+17 5.48950342e+03 7.69272308e+01 ... 9.13961220e+00\n", + " 2.31957569e+01 -6.19959354e+00]\n", + " ...\n", + " [ 2.64520284e+17 9.97617871e+03 2.23490509e+02 ... 5.75044785e+01\n", + " -8.67943764e-01 4.68908234e+01]\n", + " [-6.75287400e+16 3.38569165e+03 2.51897339e+02 ... 2.28247147e+01\n", + " 8.79171448e+01 -2.15696106e+01]\n", + " [-2.17790679e+17 5.25118652e+03 6.62821960e+01 ... 9.96885872e+00\n", + " -5.50213098e+00 -3.25731125e+01]]\n", + "9.809027777777777% mismatch in attention input grads\n", + "\n", + "Huggingface checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "FlexFlow checks:\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "\n", + "Huggingface-FlexFlow checks:\n", + "-- W2 --\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.7.mlp.down_proj.go_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/bwd_step_0_layers_7_layers_7_feed_forward_w2_shard_0_output_0\n", + "HF: [-7.5522525e+19 -1.3283726e+21 -7.2549753e+20 ... 4.9017162e+01\n", + " -9.7436657e+00 8.5870697e+01]\n", + "FF:[-7.55228501e+19 -1.32837218e+21 -7.25497390e+20 ... 4.90171394e+01\n", + " -9.74382782e+00 8.58707886e+01]\n", + "[ True True True ... True False True]\n", + "[ 19 64 75 ... 
18418 18428 18430]\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[23], line 95\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mHuggingface-FlexFlow checks:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 94\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m-- W2 --\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 95\u001b[0m \u001b[43mcompare_tensors\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhf_BWD_w2_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mff_BWD_w2_out\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtolerance\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-5\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 96\u001b[0m compare_tensors(hf_w2_weight, ff_w2_weight, tolerance\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m)\n\u001b[1;32m 98\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m-- Lora --\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Desktop/FlexFlow/tests/peft/align_test_utils.py:47\u001b[0m, in \u001b[0;36mcompare_tensors\u001b[0;34m(hf_tensor_filepath, ff_tensor_filepath, tolerance)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28mprint\u001b[39m(mismatches)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;66;03m#print(np.nonzero(hf_tensor)[0])\u001b[39;00m\n\u001b[1;32m 44\u001b[0m \u001b[38;5;66;03m# print(np.where(np.isclose(ff_tensor, hf_tensor, atol=tolerance) ==0)[0])\u001b[39;00m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;66;03m# print(ff_tensor[36], hf_tensor[36])\u001b[39;00m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m#assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\u001b[39;00m\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m(\u001b[38;5;28mlen\u001b[39m(mismatches) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m.05\u001b[39m\u001b[38;5;241m*\u001b[39mlen_hf_tensor)\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOk!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "tot_num_layers = 12\n", + "attention_tests=True\n", + "for i in range(tot_num_layers-1, -1, -1):\n", + " # HuggingFace filepaths\n", + " hf_BWD_norm_in = f\"{hf_path}/bwd_step_0_norm.gi_0\"\n", + " hf_BWD_loraB_out = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.lora_B.default.go_0\"\n", + " hf_BWD_loraB_in = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.lora_B.default.gi_0\"\n", + " hf_BWD_loraA_out = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.lora_A.default.go_0\"\n", + " hf_BWD_loraA_in = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.lora_A.default.gi_0\"\n", + " hf_loraA_weight = f\"{hf_path}/layers.{i}.mlp.down_proj.lora_A.default.weight\"\n", + " hf_loraB_weight = f\"{hf_path}/layers.{i}.mlp.down_proj.lora_B.default.weight\"\n", + " hf_BWD_lora_dropout_out = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.lora_dropout.default.go_0\"\n", + " hf_BWD_lora_dropout_in = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.lora_dropout.default.gi_0\"\n", + " hf_BWD_w2_out = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.go_0\"\n", + " 
hf_BWD_w2_in = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.down_proj.gi_0\"\n", + " hf_w2_weight = f\"{hf_path}/layers.{i}.mlp.down_proj.weight\"\n", + " hf_BWD_w3_out = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.up_proj.go_0\"\n", + " hf_BWD_w3_in = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.up_proj.gi_0\"\n", + " hf_BWD_w1_out = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.gate_proj.go_0\"\n", + " hf_BWD_w1_in = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.gate_proj.gi_0\"\n", + " hf_BWD_act_fn_in = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.act_fn.gi_0\"\n", + " hf_BWD_act_fn_out = f\"{hf_path}/bwd_step_0_layers.{i}.mlp.act_fn.go_0\"\n", + " hf_BWD_ffn_norm_out = f\"{hf_path}/bwd_step_0_layers.{i}.post_attention_layernorm.go_0\"\n", + " hf_BWD_ffn_norm_in = f\"{hf_path}/bwd_step_0_layers.{i}.post_attention_layernorm.gi_0\"\n", + " hf_BWD_attn_out_out = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.o_proj.go_0\"\n", + " hf_BWD_attn_q_in = f\"{hf_path}/bwd_step_0_layers.11.self_attn.q_proj.gi_0\"\n", + " hf_FWD_w1_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.gate_proj.output_0\"\n", + " hf_FWD_w3_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.up_proj.output_0\"\n", + " hf_FWD_act_fn_out = f\"{hf_path}/fwd_step_0_layers.{i}.mlp.act_fn.output_0\"\n", + " hf_BWD_attn_oproj_in = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.o_proj.gi_0\"\n", + " hf_attn_qproj_weight = f\"{hf_path}/layers.{i}.self_attn.q_proj.weight\"\n", + " hf_attn_kproj_weight = f\"{hf_path}/layers.{i}.self_attn.k_proj.weight\"\n", + " hf_attn_vproj_weight = f\"{hf_path}/layers.{i}.self_attn.v_proj.weight\"\n", + " hf_attn_oproj_weight = f\"{hf_path}/layers.{i}.self_attn.o_proj.weight\"\n", + " \n", + " # FlexFlow filepaths\n", + " ff_BWD_w2_out = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_shard_0_output_0\"\n", + " ff_BWD_w2_in = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_shard_0_input_0\"\n", + " ff_BWD_w2_in_pre = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_shard_0_pre_input_0\"\n", + " ff_w2_weight = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_shard_0_weight_0\"\n", + " ff_BWD_ssm_out = f\"{ff_path}/bwd_step_0_layers_{i}_SigmoidSiluMulti_shard_0_output_0\"\n", + " ff_BWD_ssm_in1 = f\"{ff_path}/bwd_step_0_layers_{i}_SigmoidSiluMulti_shard_0_input_0\"\n", + " ff_BWD_ssm_in2 = f\"{ff_path}/bwd_step_0_layers_{i}_SigmoidSiluMulti_shard_0_input_1\"\n", + " ff_BWD_w3_out = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w3_shard_0_output_0\"\n", + " ff_BWD_w3_in = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w3_shard_0_input_0\"\n", + " ff_BWD_lora_A_in = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_input_0\"\n", + " ff_BWD_lora_B_out = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_output_0\"\n", + " ff_lora_A_weight = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_weight_A\"\n", + " ff_lora_B_weight = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_feed_forward_w2_lora_shard_0_weight_B\"\n", + " ff_BWD_w1_out = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w1_shard_0_output_0\"\n", + " ff_BWD_w1_in = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w1_shard_0_input_0\"\n", + " ff_BWD_w1_in_pre = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w1_shard_0_pre_input_0\"\n", + " ff_w1_weight = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_feed_forward_w1_shard_0_weight_0\"\n", + " ff_BWD_ffn_norm_in1 = 
f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_ffn_norm_shard_0_input_0\"\n", + " ff_BWD_ffn_norm_in2 = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_ffn_norm_shard_0_input_1\"\n", + " ff_BWD_ffn_norm_out = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_ffn_norm_shard_0_output_0\"\n", + " ff_BWD_attn_out = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_output_0\"\n", + " ff_BWD_attn_in = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_input_0\"\n", + " ff_BWD_ssm_cached_w1_input = f\"{ff_path}/bwd_step_0_layers_{i}_SigmoidSiluMulti_shard_0_cached_w1_output\"\n", + " ff_BWD_ssm_cached_w3_input = f\"{ff_path}/bwd_step_0_layers_{i}_SigmoidSiluMulti_shard_0_cached_w3_output\"\n", + " ff_FWD_w1_out = f\"{ff_path}/fwd_step_0_layers_0_layers_0_feed_forward_w1_shard_0_output_0\"\n", + " ff_FWD_w3_out = f\"{ff_path}/fwd_step_0_layers_0_layers_0_feed_forward_w3_shard_0_output_0\"\n", + " ff_FWD_act_fnc_out = f\"{ff_path}/bwd_step_0_layers_{i}_SigmoidSiluMulti_shard_0_act_fn_output\"\n", + " ff_BWD_attn_o_proj_in = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_o_proj_in_grad\"\n", + " ff_attn_oproj_weight = f\"{ff_path}/fwd_step_0_layers_{i}_layers_{i}_attention_shard_0_weight_0\"\n", + " \n", + " \n", + " # HuggingFace checks\n", + " print(\"\\nHuggingface checks:\")\n", + " if i == tot_num_layers-1:\n", + " compare_hf_tensors(hf_BWD_norm_in, hf_BWD_loraB_out)\n", + " compare_hf_tensors(hf_BWD_norm_in, hf_BWD_w2_out)\n", + " compare_hf_tensors(hf_BWD_loraB_out, hf_BWD_w2_out)\n", + " compare_hf_tensors(hf_BWD_loraB_in, hf_BWD_loraA_out)\n", + "\n", + " compare_hf_tensors(hf_BWD_act_fn_in, hf_BWD_w1_out)\n", + " check_hf_sum_tensors(hf_BWD_ffn_norm_out, hf_BWD_w1_in, hf_BWD_w3_in)\n", + " if i == tot_num_layers-1:\n", + " check_hf_sum_tensors(hf_BWD_attn_out_out, hf_BWD_ffn_norm_in, hf_BWD_norm_in)\n", + "\n", + " # FlexFlow checks\n", + " print(\"\\nFlexFlow checks:\")\n", + " compare_flexflow_tensors(ff_BWD_w2_out, ff_BWD_lora_B_out)\n", + " compare_flexflow_tensors(ff_BWD_w2_in_pre, ff_BWD_lora_A_in)\n", + " compare_flexflow_tensors(ff_BWD_w2_in, ff_BWD_ssm_out)\n", + " compare_flexflow_tensors(ff_BWD_ssm_in2, ff_BWD_w3_out)\n", + " compare_flexflow_tensors(ff_BWD_ssm_in1, ff_BWD_w1_out)\n", + " compare_flexflow_tensors(ff_BWD_w1_in, ff_BWD_ffn_norm_out)\n", + " compare_flexflow_tensors(ff_BWD_w1_in_pre, ff_BWD_w3_in)\n", + " compare_flexflow_tensors(ff_BWD_ffn_norm_in1, ff_BWD_ffn_norm_in2, max_len=24*768)\n", + " \n", + " # HF-FlexFlow checks\n", + " print(\"\\nHuggingface-FlexFlow checks:\")\n", + " print(\"-- W2 --\")\n", + " compare_tensors(hf_BWD_w2_out, ff_BWD_w2_out, tolerance=1e-5)\n", + " compare_tensors(hf_w2_weight, ff_w2_weight, tolerance=1e-5)\n", + " \n", + " print(\"-- Lora --\")\n", + " compare_tensors(hf_loraA_weight, ff_lora_A_weight, tolerance=1e-5)\n", + " compare_tensors(hf_loraB_weight, ff_lora_B_weight, tolerance=1e-5)\n", + "\n", + " compare_tensors(hf_BWD_loraB_out, ff_BWD_lora_B_out)\n", + " compare_tensors(hf_BWD_loraA_in, ff_BWD_lora_A_in)\n", + " \n", + " print(\"-- W2/W1/W3 --\")\n", + " compare_tensors(hf_BWD_w2_in, ff_BWD_ssm_out)\n", + " compare_tensors(hf_BWD_w2_in, ff_BWD_w2_in)\n", + " compare_tensors(hf_BWD_w1_out, ff_BWD_w1_out)\n", + " compare_tensors_difference(hf_BWD_w1_in, ff_BWD_w1_in, ff_BWD_w1_in_pre)\n", + " compare_tensors(hf_BWD_w3_out, ff_BWD_w3_out)\n", + " compare_tensors(hf_BWD_w3_in, ff_BWD_w3_in)\n", + " compare_tensors(hf_BWD_w1_out, ff_BWD_w1_out)\n", + " \n", + " print(\"-- Attention 
--\")\n", + " compare_tensors(hf_BWD_attn_out_out, ff_BWD_attn_out)\n", + " hidden_size = 768\n", + " qProjSize = 64\n", + " num_heads = 12\n", + " num_new_tokens = num_tokens = 24\n", + " if attention_tests:\n", + " # compare attn weight tensors\n", + " ff_attn_weight_tensor = np.loadtxt(ff_attn_oproj_weight, delimiter=',')\n", + " ff_attn_qproj_weight_tensor = ff_attn_weight_tensor[:hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')\n", + " ff_attn_kproj_weight_tensor = ff_attn_weight_tensor[hidden_size*qProjSize*num_heads:2*hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')\n", + " ff_attn_vproj_weight_tensor = ff_attn_weight_tensor[2*hidden_size*qProjSize*num_heads:3*hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')\n", + " ff_attn_oproj_weight_tensor = ff_attn_weight_tensor[3*hidden_size*qProjSize*num_heads:].reshape((qProjSize*num_heads,hidden_size), order='F')\n", + " \n", + " hf_attn_qproj_weight_tensor = torch.load(hf_attn_qproj_weight).T.detach().cpu().numpy()\n", + " hf_attn_kproj_weight_tensor = torch.load(hf_attn_kproj_weight).T.detach().cpu().numpy()\n", + " hf_attn_vproj_weight_tensor = torch.load(hf_attn_vproj_weight).T.detach().cpu().numpy()\n", + " hf_attn_oproj_weight_tensor = torch.load(hf_attn_oproj_weight).T.detach().cpu().numpy()\n", + " \n", + " assert(np.allclose(ff_attn_qproj_weight_tensor, hf_attn_qproj_weight_tensor, atol=1e-5))\n", + " assert(np.allclose(ff_attn_kproj_weight_tensor, hf_attn_kproj_weight_tensor, atol=1e-5))\n", + " assert(np.allclose(ff_attn_vproj_weight_tensor, hf_attn_vproj_weight_tensor, atol=1e-5))\n", + " assert(np.allclose(ff_attn_oproj_weight_tensor, hf_attn_oproj_weight_tensor, atol=1e-5))\n", + " \n", + " # Compare attn outproj grad in tensors\n", + " compare_tensors(hf_BWD_attn_oproj_in, ff_BWD_attn_o_proj_in)\n", + " \n", + " ########### Compare value projs grads ######################\n", + " # 1. compare qk prods softmax\n", + " hf_qk_prods_softmax = f\"{hf_path}/fwd_step_0_layers.{i}.self_attn.qk_prods_softmax.output_0\"\n", + " ff_attn_qk_prods_softmax = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_qk_prods_softmax\"\n", + " \n", + " hf_qk_prods_softmax = torch.load(hf_qk_prods_softmax)\n", + " ff_qk_prods_softmax = np.loadtxt(ff_attn_qk_prods_softmax, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", + "\n", + " for head_idx in range(num_heads):\n", + " hf_qkps = hf_qk_prods_softmax.squeeze()[head_idx, :, :].detach().cpu().numpy()\n", + " ff_qkps = ff_qk_prods_softmax[:,:,head_idx]\n", + " assert(np.allclose(ff_qkps, hf_qkps, atol=1e-5))\n", + " \n", + " # 2. compare attn heads grads\n", + " hf_attn_heads_grads = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.o_proj.gi_0\"\n", + " ff_attn_heads_grads = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_o_proj_in_grad\"\n", + "\n", + " hf_attn_heads_grads = torch.load(hf_attn_heads_grads).T.squeeze().detach().cpu().numpy()\n", + " ff_attn_heads_grads = np.loadtxt(ff_attn_heads_grads, delimiter=',').reshape((qProjSize*num_heads, num_new_tokens), order = 'F')\n", + " # NEED TO VISUALLY INSPECT\n", + " compare_loaded_tensors(hf_attn_heads_grads, ff_attn_heads_grads)\n", + "\n", + " # 3. 
vproj grads\n", + " hf_vproj_grads = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.v_proj.go_0\"\n", + " ff_vproj_grads = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_v_proj_in_grad\"\n", + "\n", + " hf_vproj_grads = torch.load(hf_vproj_grads).squeeze().detach().cpu().numpy()\n", + " ff_vproj_grads = np.loadtxt(ff_vproj_grads, delimiter=',').reshape((num_tokens, qProjSize*num_heads), order='F')\n", + " compare_loaded_tensors(hf_vproj_grads, ff_vproj_grads)\n", + " \n", + " \n", + " ##############################\n", + " hf_value_states = f\"{hf_path}/fwd_step_0_layers.{i}.self_attn.value_states.output_0\"\n", + " hf_value_states = torch.load(hf_value_states).squeeze().permute(2,0,1).detach().cpu().numpy()\n", + " # print(hf_value_states.shape)\n", + " ff_value_states = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_vcache\"\n", + " ff_value_states = np.loadtxt(ff_value_states, delimiter=',').reshape((qProjSize, num_heads, num_tokens), order='F')\n", + " # print(ff_value_states.shape)\n", + " assert(np.allclose(hf_value_states, ff_value_states, atol=1e-2))\n", + " \n", + " \n", + " \n", + " ########## Compare key and query projs grads ##################\n", + " ff_devQKVPRojArray = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_devQKVPRojArray\"\n", + " ff_devQKVPRojArray = np.loadtxt(ff_devQKVPRojArray, delimiter=',').reshape((num_tokens, qProjSize*num_heads, 3), order = 'F')\n", + " ff_qProjGrads = ff_devQKVPRojArray[:,:,0]\n", + " ff_kProjGrads = ff_devQKVPRojArray[:,:,1]\n", + " ff_vProjGrads = ff_devQKVPRojArray[:,:,2]\n", + " assert(np.allclose(ff_vProjGrads, ff_vproj_grads, atol=1e-5))\n", + "\n", + " # simulate qk_prods_softmax\n", + " ff_attn_heads_grads = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_o_proj_in_grad\"\n", + " ff_attn_heads_grads = np.loadtxt(ff_attn_heads_grads, delimiter=',').reshape((qProjSize,num_heads, num_new_tokens), order = 'F')\n", + " ff_attn_heads_grads = torch.from_numpy(ff_attn_heads_grads)\n", + " ff_attn_heads_grads = ff_attn_heads_grads.permute(1,2,0)\n", + " ff_value_states = torch.from_numpy(ff_value_states)\n", + " ff_value_states = ff_value_states.permute(1,0,2)\n", + " # print(ff_attn_heads_grads.shape)\n", + " # print(ff_value_states.shape)\n", + " simulated_qk_prods_softmax_grads = torch.matmul(ff_attn_heads_grads, ff_value_states)\n", + " #simulated_qk_prods_softmax_grads = simulated_qk_prods_softmax_grads\n", + " #print(\"Simulated QK prods grads:\")\n", + " #print(simulated_qk_prods_softmax_grads[0,:,:])\n", + "\n", + " # qk prods softmax right before softmax\n", + " hf_qk_prods_softmax2 = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.qk_prods_softmax.go_0\"\n", + " hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)\n", + " ff_qk_prods_softmax2 = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_qk_prods_softmax_grad\"\n", + " ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", + " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)\n", + " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()\n", + " \n", + " mismatches = np.where(~np.isclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2))\n", + " mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]\n", + " pct_mismatch = len(mismatches) / (hf_qk_prods_softmax2.shape[0] * hf_qk_prods_softmax2.shape[1] * hf_qk_prods_softmax2.shape[2])\n", 
+ " print(f\"{pct_mismatch*100}% mismatch in QK prods softmax out grad\")\n", + " # print(hf_qk_prods_softmax2[:2,:,0])\n", + " # print(ff_qk_prods_softmax2[:2,:,0])\n", + " assert(pct_mismatch <= 0.1)\n", + "\n", + " # qk prods softmax right after softmax\n", + " hf_qk_prods_softmax2 = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.pre_softmax.gi_0\"\n", + " hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)\n", + " ff_qk_prods_softmax2 = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_qk_prods_softmax_grad_in\"\n", + " ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", + " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)\n", + " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()\n", + " compare_loaded_tensors(hf_qk_prods_softmax2, ff_qk_prods_softmax2)\n", + " \n", + " # qk prods softmax after mask\n", + " hf_qk_prods_softmax2 = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.matmul_op.go_0\"\n", + " hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)\n", + " ff_qk_prods_softmax2 = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_qk_prods_softmax_grad_in_masked\"\n", + " ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", + " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)\n", + " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()\n", + " assert(np.allclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2, atol=1e-2))\n", + "\n", + " # Compare query activation\n", + " hf_query_activation = hf_path + f\"/fwd_step_0_layers.11.self_attn.query_activation.output_0\"\n", + " hf_query_activation = torch.load(hf_query_activation)\n", + " ff_query_activation = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_query_activation\"\n", + " ff_query_activation = np.loadtxt(ff_query_activation, delimiter=',').reshape((qProjSize, num_heads, num_new_tokens), order = 'F')\n", + " hf_query_activation = hf_query_activation.squeeze().permute(2,0,1).detach().cpu().numpy()\n", + " # assert(np.allclose(ff_query_activation, hf_query_activation, atol=1e-2))\n", + " # print(hf_query_activation[:,0,:])\n", + " # print()\n", + " # print(ff_query_activation[:,0,:])\n", + " # assert False\n", + " # compare_loaded_tensors(hf_query_activation, ff_query_activation)\n", + " check_rope = False\n", + " if check_rope:\n", + " ########################################## ROPE and Kproj ##########################################\n", + "\n", + " # Compare FF kproj with intermediate kproj data from HF\n", + " hf_kproj_grads_post_rotary = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.identity_kv_post_rotary.go_0\"\n", + " hf_kproj_grads_post_rotary = torch.load(hf_kproj_grads_post_rotary)\n", + " hf_kproj_grads_post_rotary_copy = hf_kproj_grads_post_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()\n", + " # print(\"hf_kproj_grads_post_rotary: \", hf_kproj_grads_post_rotary_copy.shape)\n", + " # print(hf_kproj_grads_post_rotary_copy[:,:,0])\n", + " # Check hf ROPE \n", + " cos, sin = rotary_emb(hf_kproj_grads_post_rotary, seq_len=24)\n", + " cos = cos.cuda()\n", + " sin = sin.cuda()\n", + " # query_states: torch.Size([1, 12, 24, 64])\n", + " # key_states: torch.Size([1, 12, 24, 64])\n", + " # position_ids: torch.Size([1, 24])\n", + " # tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", + " # 18, 19, 20, 21, 22, 23]], 
device='cuda:0')\n", + " query_states = torch.zeros([1, 12, 24, 64]).cuda()\n", + " position_ids = torch.arange(24).unsqueeze(0).cuda()\n", + " query_states, hf_kproj_grads_post_rotary = apply_rotary_pos_emb(query_states, hf_kproj_grads_post_rotary, cos, sin, position_ids)\n", + " hf_kproj_grads_post_rotary = hf_kproj_grads_post_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()\n", + " # print(\"hf_kproj_grads_post_rotary: \", hf_kproj_grads_post_rotary.shape)\n", + " # print(hf_kproj_grads_post_rotary[:,:,0])\n", + " \n", + " hf_kproj_grads_before_rotary = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.identity_kv_before_rotary.go_0\"\n", + " hf_kproj_grads_before_rotary = torch.load(hf_kproj_grads_before_rotary)\n", + " hf_kproj_grads_before_rotary = hf_kproj_grads_before_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()\n", + " # print(\"hf_kproj_grads_before_rotary: \", hf_kproj_grads_before_rotary.shape)\n", + " # print(hf_kproj_grads_before_rotary[:,:,0])\n", + " # Compare HF rope with manual ROPE\n", + " assert(np.allclose(hf_kproj_grads_post_rotary, hf_kproj_grads_before_rotary, atol=1e-5))\n", + " # Compare HF Kproj with FF Kproj (before ROPE) \n", + " ff_kproj_pre = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_devkproj_pre\"\n", + " ff_kproj_pre = np.loadtxt(ff_kproj_pre, delimiter=',').reshape((num_tokens, qProjSize, num_heads), order = 'F')\n", + " # print(\"ff_kproj_pre: \", ff_kproj_pre.shape)\n", + " #print(ff_kproj_pre[:,:,0])\n", + " mismatches = np.where(~np.isclose(ff_kproj_pre, hf_kproj_grads_post_rotary_copy, atol=1e-5))\n", + " mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]\n", + " pct_mismatch = len(mismatches) / (ff_kproj_pre.shape[0] * ff_kproj_pre.shape[1] * ff_kproj_pre.shape[2])\n", + " print(f\"{pct_mismatch*100}% mismatch between HF and FF for kproj (before applying ROPE)\")\n", + " assert(pct_mismatch <= 0.05)\n", + " #assert(np.allclose(ff_kproj_pre, hf_kproj_grads_post_rotary_copy, atol=1e-5))\n", + " \n", + " ff_kproj = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_devkproj\"\n", + " ff_kproj = np.loadtxt(ff_kproj, delimiter=',').reshape((num_tokens, qProjSize, num_heads), order = 'F')\n", + " # print(\"ff_kproj: \", ff_kproj.shape)\n", + " #print(ff_kproj[:,:,0])\n", + " mismatches = np.where(~np.isclose(ff_kproj, hf_kproj_grads_before_rotary, atol=1e-5))\n", + " mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]\n", + " pct_mismatch = len(mismatches) / (ff_kproj.shape[0] * ff_kproj.shape[1] * ff_kproj.shape[2])\n", + " print(f\"{pct_mismatch*100}% mismatch between HF and FF for kproj (after applying ROPE)\")\n", + " assert(pct_mismatch <= 0.05)\n", + " #assert(np.allclose(ff_kproj, hf_kproj_grads_before_rotary, atol=1e-5))\n", + " \n", + " \n", + " #assert(np.allclose(hf_kproj_grads_post_rotary, hf_kproj_grads_before_rotary, atol=1e-2))\n", + " hf_kproj_grads = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.k_proj.go_0\"\n", + " hf_kproj_grads = torch.load(hf_kproj_grads).squeeze()\n", + " #print(\"hf_kproj_grads: \", hf_kproj_grads.shape)\n", + " #print(hf_kproj_grads[:,:64])\n", + " reshaped_tensor = hf_kproj_grads.view(24, 12, 64).transpose(1, 2).contiguous().detach().cpu().numpy()\n", + " #print(reshaped_tensor.shape)\n", + " assert(np.allclose(ff_kproj, reshaped_tensor, atol=1e-2))\n", + "\n", + " ########################################## Qproj (with ROPE) 
##########################################\n", + "\n", + " # Compare QProj\n", + " hf_qproj_grads = f\"{hf_path}/bwd_step_0_layers.{i}.self_attn.q_proj.go_0\"\n", + " hf_qproj_grads = torch.load(hf_qproj_grads).squeeze()\n", + " # print(\"HF Qproj:\")\n", + " # print(hf_qproj_grads.shape)\n", + " reshaped_tensor = hf_qproj_grads.view(24, 12, 64).transpose(1, 2).contiguous().detach().cpu().numpy()\n", + " # print(\"\\t reshaped: \", reshaped_tensor.shape)\n", + " # print(reshaped_tensor[:,:,0])\n", + " ff_qproj = ff_path + f\"/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_devQKVPRojArray\"\n", + " ff_qproj = np.loadtxt(ff_qproj, delimiter=',').reshape((num_tokens, qProjSize, num_heads, 3), order = 'F')[:,:,:,0]\n", + " # print(\"FF Qproj:\")\n", + " # print(ff_qproj.shape)\n", + " # print(ff_qproj[:,:,0])\n", + " assert(np.allclose(ff_qproj, reshaped_tensor, atol=1e-2))\n", + "\n", + " hf_attn_in = f\"{hf_path}/bwd_step_0_layers.{i}.input_layernorm.go_0\"\n", + " hf_attn_in = torch.load(hf_attn_in)\n", + " hf_attn_in = hf_attn_in.squeeze().T\n", + " hf_attn_in = hf_attn_in.detach().cpu().numpy()\n", + " print(\"hf_attn_in: \", hf_attn_in.shape)\n", + " print(hf_attn_in)\n", + "\n", + " ff_attn_in = f\"{ff_path}/bwd_step_0_layers_{i}_layers_{i}_attention_shard_0_attn_final_grad_in\"\n", + " ff_attn_in = np.loadtxt(ff_attn_in, delimiter=',').reshape((768,num_tokens), order = 'F')\n", + " print(\"ff_attn_in: \", ff_attn_in.shape)\n", + " print(ff_attn_in)\n", + " #assert(np.allclose(ff_attn_in, hf_attn_in, atol=1e-2))\n", + "\n", + " mismatches = np.where(~np.isclose(ff_attn_in, hf_attn_in))\n", + " mismatches = [(mismatches[0][i], mismatches[1][i]) for i in range(len(mismatches[0]))]\n", + " pct_mismatch = len(mismatches) / (hf_attn_in.shape[0] * hf_attn_in.shape[1])\n", + " print(f\"{pct_mismatch*100}% mismatch in attention input grads\")\n", + " assert(pct_mismatch <= 0.1)\n", + " \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[-0.01614726 0.01363804 0.01768043 ... 
0.00724926 -0.00149747\n", + " -0.01781223]\n" + ] + } + ], + "source": [ + "a = np.fromfile(\"/usr0/home/goliaro/.cache/flexflow/weights/goliaro/llama-160m-lora-full/full-precision/layers_11_feed_forward_w2_lora_A_weight\", dtype=np.float32)\n", + "print(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# value states: torch.Size([1, 12, 24, 64])\n", + "value_states=torch.from_numpy(hf_kproj_grads_post_rotary).permute(2,0,1).unsqueeze(0)\n", + "key_states = value_states\n", + "cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)\n", + "# query_states: torch.Size([1, 12, 24, 64])\n", + "# key_states: torch.Size([1, 12, 24, 64])\n", + "# position_ids: torch.Size([1, 24])\n", + "# tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", + "# 18, 19, 20, 21, 22, 23]], device='cuda:0')\n", + "query_states = torch.zeros([1, 12, 24, 64])\n", + "position_ids = torch.arange(24).unsqueeze(0)\n", + "query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n", + "key_states = key_states.squeeze()\n", + "print(key_states.shape)\n", + "print(key_states[0,:,:])\n", + "print(hf_kproj_grads_before_rotary.shape)\n", + "print(hf_kproj_grads_before_rotary[:,:,0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", + " 18, 19, 20, 21, 22, 23]], device='cuda:0')" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.arange(24).unsqueeze(0).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 12, 24, 24])\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/alignment_tests.ipynb Cell 6\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 17\u001b[0m ff_qkps \u001b[39m=\u001b[39m ff_qk_prods_softmax[:,:,head_idx]\n\u001b[1;32m 18\u001b[0m \u001b[39massert\u001b[39;00m(np\u001b[39m.\u001b[39mallclose(ff_qkps, hf_qkps, atol\u001b[39m=\u001b[39m\u001b[39m1e-5\u001b[39m))\n\u001b[0;32m---> 19\u001b[0m \u001b[39massert\u001b[39;00m(\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m 21\u001b[0m hf_value_states \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mload(hf_value_states)\u001b[39m#.squeeze().T.detach().cpu().numpy()\u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[39mprint\u001b[39m(hf_value_states\u001b[39m.\u001b[39mshape)\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "layer_num = 11\n", + "hf_qk_prods_softmax = f\"{hf_path}/fwd_step_0_layers.11.self_attn.qk_prods_softmax\"\n", + "ff_qk_prods_softmax = f\"{ff_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax\"\n", + "\n", + "hf_value_states = f\"{hf_path}/fwd_step_0_layers.11.self_attn.value_states\"\n", + "\n", + "hf_qk_prods_softmax = torch.load(hf_qk_prods_softmax)#.squeeze().T.detach().cpu().numpy()\n", + "ff_qk_prods_softmax = np.loadtxt(ff_qk_prods_softmax, delimiter=',').reshape((24, 24, 12), order = 'F')\n", + 
"print(hf_qk_prods_softmax.shape)\n", + "#print(ff_qk_prods_softmax.shape)\n", + "#print(hf_qk_prods_softmax[:,:,0])\n", + "#print()\n", + "#print(ff_qk_prods_softmax[:,:,0])\n", + "\n", + "for head_idx in range(12):\n", + " hf_qkps = hf_qk_prods_softmax.squeeze()[head_idx, :, :].detach().cpu().numpy()\n", + " ff_qkps = ff_qk_prods_softmax[:,:,head_idx]\n", + " assert(np.allclose(ff_qkps, hf_qkps, atol=1e-5))\n", + "\n", + "\n", + "hf_value_states = torch.load(hf_value_states)#.squeeze().T.detach().cpu().numpy()\n", + "print(hf_value_states.shape)\n", + "attn_output = torch.matmul(hf_qk_prods_softmax, hf_value_states)\n", + "print()\n", + "print(attn_output.shape)\n", + "print(attn_output.transpose(1, 2).contiguous().shape)\n", + "print(\"Hf attn heads\")\n", + "print(torch.load(\"/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/fwd_step_0_layers.11.self_attn.o_proj.input_0\").shape)\n", + "\n", + "print(\"Attn heads grads:\")\n", + "hf_attn_heads_grads = f\"{hf_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.gi_0\"\n", + "print(torch.load(hf_attn_heads_grads).shape)\n", + "print(\"HF value grads:\")\n", + "vproj_grads = f\"{hf_path}/bwd_step_0_layers.{layer_num}.self_attn.v_proj.gi_0\"\n", + "print(torch.load(vproj_grads).shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 3, 4])\n", + "torch.Size([4, 3, 2])\n" + ] + } + ], + "source": [ + "a = torch.randn(2,3,4)\n", + "print(a.shape)\n", + "print(a.T.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000,\n", + " 0.0000],\n", + " [ 27.8890, -21.5089, 45.8214, ..., 5.4010, -10.8787,\n", + " 39.7619],\n", + " [ 19.2197, 27.4681, -68.7141, ..., 102.3280, 66.7925,\n", + " -160.8711],\n", + " ...,\n", + " [ 63.9532, 17.4273, -29.4416, ..., 101.6105, 67.5937,\n", + " -198.4432],\n", + " [ 31.2799, 13.0724, -44.7179, ..., 132.4898, 42.3135,\n", + " -194.4037],\n", + " [ 42.3453, -16.2693, -55.7386, ..., 90.5921, 52.2032,\n", + " -124.1802]]], device='cuda:0')\n", + "tensor([[[-1.1845e+06, -6.7460e+05, 7.4494e+05, ..., -9.1441e+05,\n", + " -1.4912e+05, 3.5769e+06],\n", + " [-7.3920e+01, -7.9389e+01, 1.1027e+02, ..., -7.3020e+01,\n", + " -2.3540e+01, 3.4587e+02],\n", + " [-5.3885e+01, -1.7373e+01, -1.9780e+01, ..., 4.1291e+01,\n", + " 5.5099e+01, 5.5910e+01],\n", + " ...,\n", + " [-2.1948e+01, -3.2109e+01, 2.8364e+01, ..., 3.4321e+01,\n", + " 5.0713e+01, 5.6592e+01],\n", + " [-4.4339e+01, -2.8339e+01, 1.4070e+01, ..., 6.2797e+01,\n", + " 3.0760e+01, 6.1743e+01],\n", + " [-1.6287e+01, -5.0413e+01, -1.9940e+01, ..., 4.3766e+01,\n", + " 4.7833e+01, 4.7295e+01]]], device='cuda:0')\n" + ] + } + ], + "source": [ + "a = \"./hf_peft_tensors/bwd_step_0_layers.11.post_attention_layernorm.gi_0\"\n", + "b = \"./hf_peft_tensors/bwd_step_0_layers.11.self_attn.o_proj.go_0\"\n", + "a = torch.load(a)\n", + "b = torch.load(b)\n", + "print(a)\n", + "print(b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + 
"Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n" + ] + } + ], + "source": [ + "for layer_num in range(12):\n", + " hf_lora_A_weight_fp = f\"{hf_path}/layers.{layer_num}.mlp.down_proj.lora_A.default.weight\"\n", + " ff_lora_A_weight_fp = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_A\"\n", + " compare_tensors(hf_lora_A_weight_fp, ff_lora_A_weight_fp, tolerance=1e-5)\n", + " hf_lora_B_weight_fp = f\"{hf_path}/layers.{layer_num}.mlp.down_proj.lora_B.default.weight\"\n", + " ff_lora_B_weight_fp = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_B\"\n", + " compare_tensors(hf_lora_B_weight_fp, ff_lora_B_weight_fp, tolerance=1e-5)\n", + " hf_w1_weight = f\"{hf_path}/layers.{layer_num}.mlp.gate_proj.weight\"\n", + " ff_w1_weight = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_weight_0\"\n", + " compare_tensors(hf_w1_weight, ff_w1_weight, tolerance=1e-5)\n", + " hf_w3_weight = f\"{hf_path}/layers.{layer_num}.mlp.up_proj.weight\"\n", + " ff_w3_weight = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_weight_0\"\n", + " compare_tensors(hf_w3_weight, ff_w3_weight, tolerance=1e-5)\n", + " hf_w2_weight = f\"{hf_path}/layers.{layer_num}.mlp.down_proj.weight\"\n", + " ff_w2_weight = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_weight_0\"\n", + " compare_tensors(hf_w2_weight, ff_w2_weight, tolerance=1e-5)\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/peft/alignment/opt_alignment_tests.ipynb b/tests/peft/alignment/opt_alignment_tests.ipynb new file mode 100644 index 0000000000..ca679b1857 --- /dev/null +++ b/tests/peft/alignment/opt_alignment_tests.ipynb @@ -0,0 +1,450 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import os, torch\n", + "from align_test_utils import *" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + 
"Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- Attn bias + residual ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "--- MLP ---\n", + "Ok!\n", + "Ok!\n", + "\n", + "--- LM head ---\n", + "Ok!\n", + "Ok!\n", + "\n", + "--- Final Norm ---\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n" + ] + } + ], + "source": [ + "tot_num_layers = 12\n", + "qProjSize = 64\n", + "num_heads = 12\n", + "num_tokens = 25\n", + "for i in range(tot_num_layers):\n", + " hf_base = os.path.join(hf_path, f\"fwd_step_0_decoder.layers.{i}.\")\n", + " ff_base = os.path.join(ff_path, f\"fwd_step_0_layers_{i}_layers_{i}_\")\n", + " \n", + " # LayerNorm\n", + " hf_tensor = hf_base + \"self_attn_layer_norm.input_0\"\n", + " ff_tensor = ff_base + \"attention_layer_norm_shard_0_output_0\"\n", + " compare_tensors(hf_tensor, ff_tensor)\n", + " hf_tensor = hf_base + \"self_attn_layer_norm.output_0\"\n", + " ff_tensor = ff_base + \"attention_layer_norm_shard_0_output_1\"\n", + " compare_tensors(hf_tensor, ff_tensor)\n", + "\n", + " # # Attention QKV proj\n", + " # print(\"---Attn---\")\n", + " # ff_tensor = ff_base + \"attention_shard_0_qkv_proj_output\"\n", + " # ff_tensor = load_ff_tensor(ff_tensor, [qProjSize, num_heads, 3, num_tokens])\n", + " # ff_q_proj = ff_tensor[:,:,0,:]\n", + " # ff_k_proj = ff_tensor[:,:,1,:]\n", + " # ff_v_proj = ff_tensor[:,:,2,:]\n", + " # hf_q_proj = hf_base + \"self_attn.q_proj.output_0\"\n", + " # hf_q_proj = load_hf_tensor(hf_q_proj).squeeze().T\n", + " # hf_q_proj = hf_q_proj.reshape(12,64,25)\n", + " # hf_q_proj = np.transpose(hf_q_proj, (1,0,2))\n", + " # hf_k_proj = hf_base + \"self_attn.k_proj.output_0\"\n", + " # hf_k_proj = load_hf_tensor(hf_k_proj).squeeze().T\n", + " # hf_k_proj = hf_k_proj.reshape(12,64,25)\n", + " # hf_k_proj = np.transpose(hf_k_proj, (1,0,2))\n", + " # hf_v_proj = hf_base + \"self_attn.v_proj.output_0\"\n", + " # hf_v_proj = load_hf_tensor(hf_v_proj).squeeze().T\n", + " # hf_v_proj = hf_v_proj.reshape(12,64,25)\n", + " # hf_v_proj = np.transpose(hf_v_proj, (1,0,2))\n", + " # compare_loaded_tensors(hf_q_proj/np.sqrt(qProjSize), ff_q_proj)\n", + " # compare_loaded_tensors(hf_k_proj, ff_k_proj)\n", + " # compare_loaded_tensors(hf_v_proj, ff_v_proj)\n", + "\n", + " # Compare attn bias, residuals\n", + " print(\"--- Attn bias + residual ---\")\n", + " ff_residual1 = ff_path + f\"/fwd_step_0_layers_{i}_AddBiasResidualLayerNorm_shard_0_input_1\"\n", + " ff_residual2 = ff_base + \"attention_layer_norm_shard_0_output_0\"\n", + " compare_flexflow_tensors(ff_residual1, ff_residual2)\n", + " hf_tensor = hf_base + 
\"self_attn_layer_norm.input_0\"\n", + " compare_tensors(hf_tensor, ff_residual2)\n", + " ff_tensor = ff_path + f\"/fwd_step_0_layers_{i}_AddBiasResidualLayerNorm_shard_0_output_0\"\n", + " hf_tensor = hf_base + \"final_layer_norm.input_0\"\n", + " compare_tensors(hf_tensor, ff_tensor)\n", + " \n", + " print(\"--- MLP ---\")\n", + " hf_tensor = hf_base + \"fc1.input_0\"\n", + " ff_tensor = ff_base + \"fc1_shard_0_input_0\"\n", + " compare_tensors(hf_tensor, ff_tensor)\n", + " hf_tensor = hf_base + \"fc2.input_0\"\n", + " ff_tensor = ff_base + \"fc2_shard_0_input_0\"\n", + " compare_tensors(hf_tensor, ff_tensor)\n", + "# LM head\n", + "print(\"\\n--- LM head ---\")\n", + "hf_tensor = hf_path + \"/fwd_step_0_base_model.model.lm_head.input_0\"\n", + "ff_tensor = ff_path + \"/fwd_step_0_layers_11_embed_tokens_weight_lm_head_shard_0_input_0\"\n", + "compare_tensors(hf_tensor, ff_tensor)\n", + "hf_tensor = hf_path + \"/fwd_step_0_base_model.model.lm_head.output_0\"\n", + "ff_tensor = ff_path + \"/fwd_step_0_layers_11_embed_tokens_weight_lm_head_shard_0_output_0\"\n", + "compare_tensors(hf_tensor, ff_tensor)\n", + "# Final layer norm\n", + "print(\"\\n--- Final Norm ---\")\n", + "hf_tensor = hf_path + \"/fwd_step_0_decoder.final_layer_norm.input_0\"\n", + "ff_tensor = ff_path + \"/fwd_step_0_layers_11_final_layer_norm_shard_0_output_0\"\n", + "compare_tensors(hf_tensor, ff_tensor)\n", + "ff_tensor1 = ff_path + \"/fwd_step_0_layers_11_final_layer_norm_shard_0_input_activation\"\n", + "# compare_flexflow_tensors_shortest(ff_tensor, ff_tensor1)\n", + "hf_tensor = hf_path + \"/fwd_step_0_decoder.final_layer_norm.output_0\"\n", + "ff_tensor = ff_path + \"/fwd_step_0_layers_11_final_layer_norm_shard_0_output_1\"\n", + "compare_tensors(hf_tensor, ff_tensor)\n", + "hf_tensor = hf_path + \"/fwd_step_0_decoder.final_layer_norm.saved_result_1\"\n", + "ff_tensor = ff_path + \"/fwd_step_0_layers_11_final_layer_norm_shard_0_mean\"\n", + "compare_tensors(hf_tensor, ff_tensor)\n", + "hf_tensor = hf_path + \"/fwd_step_0_decoder.final_layer_norm.saved_result_2\"\n", + "ff_tensor = ff_path + \"/fwd_step_0_layers_11_final_layer_norm_shard_0_rstd\"\n", + "compare_tensors(hf_tensor, ff_tensor)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[17], line 22\u001b[0m\n\u001b[1;32m 19\u001b[0m compare_flexflow_tensors(ff_tensor, ff_tensor1)\n\u001b[1;32m 20\u001b[0m compare_tensors(hf_tensor, ff_tensor) \u001b[38;5;66;03m# fails\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;66;03m# Compare fwd input/output of layernorm\u001b[39;00m\n\u001b[1;32m 25\u001b[0m hf_FWD_norm_in \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhf_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/fwd_step_0_decoder.final_layer_norm.input_0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "# Compare backward pass\n", + "hf_tensor = hf_path 
+ \"/bwd_step_0_base_model.model.lm_head.go_0\"\n", + "ff_tensor = ff_path + \"/bwd_step_0_layers_11_embed_tokens_weight_lm_head_shard_0_output_0\"\n", + "compare_tensors(hf_tensor, ff_tensor, tolerance=1e-5)\n", + "hf_tensor = hf_path + \"/bwd_step_0_base_model.model.lm_head.gi_0\"\n", + "ff_tensor = ff_path + \"/bwd_step_0_layers_11_embed_tokens_weight_lm_head_shard_0_input_0\"\n", + "compare_tensors(hf_tensor, ff_tensor, tolerance=1e-5)\n", + "\n", + "hf_tensor1 = hf_path + \"/bwd_step_0_decoder.final_layer_norm.go_0\"\n", + "compare_hf_tensors(hf_tensor, hf_tensor1)\n", + "ff_tensor = ff_path + \"/bwd_step_0_layers_11_final_layer_norm_shard_0_output_0\"\n", + "compare_tensors(hf_tensor1, ff_tensor)\n", + "\n", + "hf_tensor = hf_path + \"/bwd_step_0_decoder.final_layer_norm.gi_0\"\n", + "ff_tensor = ff_path + \"/bwd_step_0_layers_11_final_layer_norm_shard_0_input_0\"\n", + "ff_tensor1 = ff_path + \"/bwd_step_0_layers_11_final_layer_norm_shard_0_input_1\"\n", + "compare_flexflow_tensors(ff_tensor, ff_tensor1)\n", + "compare_tensors(hf_tensor, ff_tensor) # fails" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/fwd_step_0_decoder.layers.0.fc1.input_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_decoding-step_0_layer-num_0_layer-name_layers_0_fc1_shard-id_0_input_0\n", + "HF: [ 0.0193019 -1.0467215 0.21579844 ... 0.04534929 -0.25642633\n", + " 0.10879952]\n", + "FF:[ 0.01458706 -1.02212262 0.20589906 ... 0.04446212 -0.25625792\n", + " 0.108039 ]\n", + "[ True False True ... True True True]\n", + "[ 1 3 7 ... 19170 19174 19188]\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[11], line 16\u001b[0m\n\u001b[1;32m 14\u001b[0m hf_fc1_in \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/fwd_step_0_decoder.layers.0.fc1.input_0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 15\u001b[0m ff_fc1_in \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_decoding-step_0_layer-num_0_layer-name_layers_0_fc1_shard-id_0_input_0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 16\u001b[0m \u001b[43mcompare_tensors\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhf_fc1_in\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mff_fc1_in\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# LORA input\u001b[39;00m\n\u001b[1;32m 20\u001b[0m hf_lora_A_in \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhf_weight_base_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/fwd_step_0_layers.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mlayer_num\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.mlp.down_proj.lora_A.default.input_0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Desktop/FlexFlow/tests/peft/align_test_utils.py:32\u001b[0m, in \u001b[0;36mcompare_tensors\u001b[0;34m(hf_tensor_filepath, ff_tensor_filepath, tolerance)\u001b[0m\n\u001b[1;32m 
27\u001b[0m \u001b[38;5;28mprint\u001b[39m(mismatches)\n\u001b[1;32m 28\u001b[0m \u001b[38;5;66;03m#print(np.nonzero(hf_tensor)[0])\u001b[39;00m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;66;03m# print(np.where(np.isclose(ff_tensor, hf_tensor, atol=tolerance) ==0)[0])\u001b[39;00m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# print(ff_tensor[36], hf_tensor[36])\u001b[39;00m\n\u001b[1;32m 31\u001b[0m \u001b[38;5;66;03m#assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\u001b[39;00m\n\u001b[0;32m---> 32\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m(\u001b[38;5;28mlen\u001b[39m(mismatches) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m.05\u001b[39m\u001b[38;5;241m*\u001b[39mlen_hf_tensor)\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOk!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "tot_num_layers = 12\n", + "for layer_num in range(tot_num_layers):\n", + " hf_input_ln_out = f\"{hf_path}/fwd_step_0_decoder.layers.{layer_num}.self_attn_layer_norm.output_0\"\n", + " ff_input_ln_out = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_layer_norm_shard-id_0_output_1\"\n", + " compare_tensors(hf_input_ln_out, ff_input_ln_out)\n", + " \n", + " hf_ffn_norm_in = f\"{hf_path}/fwd_step_0_decoder.layers.{layer_num}.final_layer_norm.input_0\"\n", + " ff_ffn_norm_in = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_AddBiasResidualLayerNorm_shard-id_0_output_0\"\n", + " # compare_tensors(hf_ffn_norm_in, ff_ffn_norm_in)\n", + " \n", + " hf_ffn_norm_out = f\"{hf_path}/fwd_step_0_decoder.layers.{layer_num}.final_layer_norm.output_0\"\n", + " ff_ffn_norm_out = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_AddBiasResidualLayerNorm_shard-id_0_output_1\"\n", + " # compare_tensors(hf_ffn_norm_out, ff_ffn_norm_out)\n", + " hf_fc1_in = \"/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/fwd_step_0_decoder.layers.0.fc1.input_0\"\n", + " ff_fc1_in = \"/usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_decoding-step_0_layer-num_0_layer-name_layers_0_fc1_shard-id_0_input_0\"\n", + " compare_tensors(hf_fc1_in, ff_fc1_in)\n", + "\n", + "\n", + " # LORA input\n", + " hf_lora_A_in = f\"{hf_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.input_0\"\n", + " ff_lora_A_in = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_input_0\"\n", + " compare_hf_tensors(hf_down_proj_in, hf_lora_A_in)\n", + " compare_tensors(hf_lora_A_in, ff_lora_A_in)\n", + " # LORA weights\n", + " hf_lora_A_weight_fp = f\"{hf_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_A.default.weight\"\n", + " ff_lora_A_weight_fp = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_A\"\n", + " compare_tensors(hf_lora_A_weight_fp, ff_lora_A_weight_fp)\n", + " hf_lora_B_weight_fp = f\"{hf_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_B.default.weight\"\n", + " ff_lora_B_weight_fp = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_B\"\n", + " compare_tensors(hf_lora_B_weight_fp, ff_lora_B_weight_fp)\n", + " # LORA intermediate hf\n", + " hf_lora_A_out = 
f\"{hf_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.output_0\"\n", + " hf_lora_B_in = f\"{hf_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.input_0\"\n", + " compare_hf_tensors(hf_lora_A_out, hf_lora_B_in)\n", + " # LORA output\n", + " hf_lora_out = f\"{hf_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.output_0\"\n", + " ff_lora_out = f\"{ff_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_output_0\"\n", + " # compare_tensors(hf_lora_out, ff_lora_out)\n", + " # compare_flexflow_tensors(ff_down_proj_out, ff_lora_out)\n", + " # compare_tensors(hf_down_proj_out, ff_lora_out)\n", + " compare_tensors_difference(hf_lora_out, ff_lora_out, ff_down_proj_out)\n", + " \n", + "\n", + "# After last layer only\n", + "hf_norm_out = f\"{hf_path}/fwd_step_0_norm.output_0\"\n", + "ff_norm_out = f\"{ff_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_norm_shard-id_0_output_1\"\n", + "compare_tensors(hf_norm_out, ff_norm_out)\n", + "hf_lm_head_out = f\"{hf_path}/fwd_step_0_base_model.model.lm_head.output_0\"\n", + "ff_lm_head_out = f\"{ff_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_output_shard-id_0_output_0\"\n", + "compare_tensors(hf_lm_head_out, ff_lm_head_out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ok!\n", + "Ok!\n", + "Ok!\n", + "Ok!\n", + "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/fwd_step_0_decoder.final_layer_norm.input_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_decoding-step_0_layer-num_11_layer-name_final_layer_norm_shard-id_0_output_0\n", + "HF: [-0.00542103 -1.781267 0.16552497 ... -0.77217525 -0.5760026\n", + " 0.04363118]\n", + "FF:[ 0.03817766 -1.5644939 0.22477378 ... -0.94569921 -0.43960798\n", + " -0.06447437]\n", + "[False False False ... False False False]\n", + "[ 0 1 2 ... 
19197 19198 19199]\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[10], line 22\u001b[0m\n\u001b[1;32m 20\u001b[0m ff_FWD_norm_in \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mff_weight_base_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/model_0_decoding-step_0_layer-num_11_layer-name_final_layer_norm_shard-id_0_output_0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 21\u001b[0m ff_FWD_norm_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mff_weight_base_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/model_0_decoding-step_0_layer-num_11_layer-name_final_layer_norm_shard-id_0_output_1\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 22\u001b[0m \u001b[43mcompare_tensors\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhf_FWD_norm_in\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mff_FWD_norm_in\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m compare_tensors(hf_FWD_norm_out, ff_FWD_norm_out)\n\u001b[1;32m 25\u001b[0m hf_BWD_norm_in \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhf_weight_base_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/bwd_step_0_decoder.final_layer_norm.gi_0\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Desktop/FlexFlow/tests/peft/align_test_utils.py:29\u001b[0m, in \u001b[0;36mcompare_tensors\u001b[0;34m(hf_tensor_filepath, ff_tensor_filepath, tolerance)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28mprint\u001b[39m(mismatches)\n\u001b[1;32m 25\u001b[0m \u001b[38;5;66;03m#print(np.nonzero(hf_tensor)[0])\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# print(np.where(np.isclose(ff_tensor, hf_tensor, atol=tolerance) ==0)[0])\u001b[39;00m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m# print(ff_tensor[36], hf_tensor[36])\u001b[39;00m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;66;03m#assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\u001b[39;00m\n\u001b[0;32m---> 29\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m(\u001b[38;5;28mlen\u001b[39m(mismatches) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m.05\u001b[39m\u001b[38;5;241m*\u001b[39mlen_hf_tensor)\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOk!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "tot_num_layers = 12\n", + "\n", + "ff_BWD_softmax_in = f\"{ff_path}/model_0_bwd-step_0_layer-num_100_layer-name_Softmax_shard-id_0_input_0\"\n", + "\n", + "hf_BWD_lm_head_out = f\"{hf_path}/bwd_step_0_base_model.model.lm_head.go_0\"\n", + "ff_BWD_lm_head_out = f\"{ff_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_embed_tokens_weight_lm_head_shard-id_0_output_0\"\n", + "compare_tensors(hf_BWD_lm_head_out, ff_BWD_lm_head_out, tolerance=1e-5)\n", + "hf_BWD_lm_head_in = f\"{hf_path}/bwd_step_0_base_model.model.lm_head.gi_0\"\n", + "ff_BWD_lm_head_in = f\"{ff_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_embed_tokens_weight_lm_head_shard-id_0_input_0\"\n", + "compare_tensors(hf_BWD_lm_head_in, ff_BWD_lm_head_in, 
tolerance=1e-5)\n", + "\n", + "hf_BWD_norm_out = f\"{hf_path}/bwd_step_0_decoder.final_layer_norm.go_0\"\n", + "ff_BWD_norm_out = f\"{ff_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_final_layer_norm_shard-id_0_output_0\"\n", + "compare_hf_tensors(hf_BWD_lm_head_in, hf_BWD_norm_out)\n", + "compare_tensors(hf_BWD_norm_out, ff_BWD_norm_out)\n", + "\n", + "# Compare fwd input/output of layernorm\n", + "hf_FWD_norm_in = f\"{hf_path}/fwd_step_0_decoder.final_layer_norm.input_0\"\n", + "hf_FWD_norm_out = f\"{hf_path}/fwd_step_0_decoder.final_layer_norm.output_0\"\n", + "ff_FWD_norm_in = f\"{ff_path}/model_0_decoding-step_0_layer-num_11_layer-name_final_layer_norm_shard-id_0_output_0\"\n", + "ff_FWD_norm_out = f\"{ff_path}/model_0_decoding-step_0_layer-num_11_layer-name_final_layer_norm_shard-id_0_output_1\"\n", + "compare_tensors(hf_FWD_norm_in, ff_FWD_norm_in)\n", + "compare_tensors(hf_FWD_norm_out, ff_FWD_norm_out)\n", + "\n", + "hf_BWD_norm_in = f\"{hf_path}/bwd_step_0_decoder.final_layer_norm.gi_0\"\n", + "ff_BWD_norm_in = f\"{ff_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_final_layer_norm_shard-id_0_input_1\"\n", + "compare_tensors(hf_BWD_norm_in, ff_BWD_norm_in, tolerance=1e-5)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/peft/alignment_tests.ipynb b/tests/peft/alignment_tests.ipynb deleted file mode 100644 index e2a8978ea3..0000000000 --- a/tests/peft/alignment_tests.ipynb +++ /dev/null @@ -1,1427 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import os, torch" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "hf_weight_base_path = \"/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors\"\n", - "ff_weight_base_path = \"/usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors\"\n", - "def compare_tensors(hf_tensor_filepath, ff_tensor_filepath, tolerance=1e-2):\n", - " assert(os.path.exists(hf_tensor_filepath) and os.path.exists(ff_tensor_filepath))\n", - " hf_tensor = torch.load(hf_tensor_filepath)\n", - " if type(hf_tensor) == tuple or type(hf_tensor) == list:\n", - " assert(len(hf_tensor) == 1)\n", - " hf_tensor = hf_tensor[0]\n", - " hf_tensor = torch.nan_to_num(hf_tensor)\n", - " hf_tensor = hf_tensor.flatten().detach().cpu().numpy()\n", - " ff_tensor = np.loadtxt(ff_tensor_filepath, delimiter=',')\n", - "\n", - " len_hf_tensor = hf_tensor.shape[0]\n", - " ff_tensor = ff_tensor[:len_hf_tensor]\n", - " \n", - " mismatches = []\n", - " if not np.allclose(ff_tensor, hf_tensor, atol=tolerance):\n", - " print(f\"mismatch between {hf_tensor_filepath} and {ff_tensor_filepath}\")\n", - " print(f\"HF: {hf_tensor}\\nFF:{ff_tensor}\")\n", - " print(np.isclose(ff_tensor, hf_tensor, atol=tolerance))\n", - " mismatches = np.where(~np.isclose(ff_tensor, hf_tensor, atol=tolerance))[0]\n", - " print(mismatches)\n", - " #print(np.nonzero(hf_tensor)[0])\n", - " # print(np.where(np.isclose(ff_tensor, hf_tensor, atol=tolerance) ==0)[0])\n", - " # print(ff_tensor[36], hf_tensor[36])\n", - " 
#assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\n", - " assert(len(mismatches) <= .05*len_hf_tensor)\n", - " print(\"Ok!\")\n", - "def compare_tensors_difference(hf_tensor_filepath, ff_tensor1_filepath, ff_tensor2_filepath, tolerance=1e-2):\n", - " assert(os.path.exists(hf_tensor_filepath))\n", - " assert(os.path.exists(ff_tensor1_filepath))\n", - " assert(os.path.exists(ff_tensor2_filepath))\n", - " hf_tensor = torch.load(hf_tensor_filepath)\n", - " if type(hf_tensor) == tuple or type(hf_tensor) == list:\n", - " assert(len(hf_tensor) == 1)\n", - " hf_tensor = hf_tensor[0]\n", - " hf_tensor = torch.nan_to_num(hf_tensor)\n", - " hf_tensor = hf_tensor.flatten().detach().cpu().numpy()\n", - " ff_tensor1 = np.loadtxt(ff_tensor1_filepath, delimiter=',')\n", - " ff_tensor2 = np.loadtxt(ff_tensor2_filepath, delimiter=',')\n", - "\n", - " len_hf_tensor = hf_tensor.shape[0]\n", - " ff_tensor1 = ff_tensor1[:len_hf_tensor]\n", - " ff_tensor2 = ff_tensor2[:len_hf_tensor]\n", - " ff_tensor = ff_tensor1 - ff_tensor2\n", - " \n", - " mismatches = []\n", - " if not np.allclose(ff_tensor, hf_tensor, atol=tolerance):\n", - " print(f\"mismatch between {hf_tensor_filepath} and {ff_tensor1_filepath} - {ff_tensor2_filepath}\")\n", - " print(f\"HF: {hf_tensor}\\nFF:{ff_tensor}\")\n", - " print(np.isclose(ff_tensor, hf_tensor, atol=tolerance))\n", - " mismatches = np.where(~np.isclose(ff_tensor, hf_tensor, atol=tolerance))[0]\n", - " print(mismatches)\n", - " #print(np.nonzero(hf_tensor)[0])\n", - " # print(np.where(np.isclose(ff_tensor, hf_tensor, atol=tolerance) ==0)[0])\n", - " # print(ff_tensor[36], hf_tensor[36])\n", - " #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\n", - " assert(len(mismatches) <= .05*len_hf_tensor)\n", - " print(\"Ok!\")\n", - "def compare_hf_tensors(tensor1_fp, tensor2_fp):\n", - " assert(os.path.exists(tensor1_fp) and os.path.exists(tensor2_fp))\n", - " hf_tensor1 = torch.load(tensor1_fp)\n", - " hf_tensor2 = torch.load(tensor2_fp)\n", - " if type(hf_tensor1) == tuple or type(hf_tensor1) == list:\n", - " assert(len(hf_tensor1) == 1)\n", - " hf_tensor1 = hf_tensor1[0]\n", - " if type(hf_tensor2) == tuple or type(hf_tensor2) == list:\n", - " assert(len(hf_tensor2) == 1)\n", - " hf_tensor2 = hf_tensor2[0]\n", - " assert(torch.squeeze(hf_tensor1).shape == torch.squeeze(hf_tensor2).shape)\n", - " hf_tensor1 = torch.nan_to_num(hf_tensor1)\n", - " hf_tensor2 = torch.nan_to_num(hf_tensor2)\n", - " if not (np.allclose(hf_tensor1.detach().cpu().numpy(), hf_tensor2.detach().cpu().numpy())):\n", - " print(f\"mismatch between {tensor1_fp} and {tensor2_fp}\")\n", - " print(hf_tensor1)\n", - " print(hf_tensor2)\n", - " print(np.isclose(hf_tensor1.detach().cpu().numpy(), hf_tensor2.detach().cpu().numpy()))\n", - " mismatches = np.where(~np.isclose(hf_tensor1.detach().cpu().numpy(), hf_tensor2.detach().cpu().numpy()))[0]\n", - " print(mismatches)\n", - " assert(False)\n", - " print(\"Ok!\")\n", - "\n", - "def check_hf_sum_tensors(tensor_sum_fp, tensor1_fp, tensor2_fp):\n", - " assert(os.path.exists(tensor_sum_fp) and os.path.exists(tensor1_fp) and os.path.exists(tensor2_fp))\n", - " hf_tensor_sum = torch.load(tensor_sum_fp)\n", - " hf_tensor1 = torch.load(tensor1_fp)\n", - " hf_tensor2 = torch.load(tensor2_fp)\n", - " if type(hf_tensor_sum) == tuple or type(hf_tensor_sum) == list:\n", - " assert(len(hf_tensor_sum) == 1)\n", - " hf_tensor_sum = hf_tensor_sum[0]\n", - " if type(hf_tensor1) == tuple or type(hf_tensor1) == list:\n", - " assert(len(hf_tensor1) == 1)\n", - " hf_tensor1 
= hf_tensor1[0]\n", - " if type(hf_tensor2) == tuple or type(hf_tensor2) == list:\n", - " assert(len(hf_tensor2) == 1)\n", - " hf_tensor2 = hf_tensor2[0]\n", - " assert(torch.squeeze(hf_tensor_sum).shape == torch.squeeze(hf_tensor1).shape)\n", - " assert(torch.squeeze(hf_tensor1).shape == torch.squeeze(hf_tensor2).shape)\n", - " hf_tensor1 = torch.nan_to_num(hf_tensor1)\n", - " hf_tensor2 = torch.nan_to_num(hf_tensor2)\n", - " hf_tensor_sum = torch.nan_to_num(hf_tensor_sum)\n", - " sum_check_tensor = hf_tensor1 + hf_tensor2\n", - " if not (np.allclose(sum_check_tensor.detach().cpu().numpy(), hf_tensor_sum.detach().cpu().numpy())):\n", - " print(f\"mismatch between {sum_check_tensor} and {tensor1_fp} + {tensor2_fp}\")\n", - " print(tensor_sum_fp)\n", - " print(sum_check_tensor)\n", - " print(hf_tensor1)\n", - " print(hf_tensor2)\n", - " print(np.isclose(sum_check_tensor.detach().cpu().numpy(), hf_tensor_sum.detach().cpu().numpy()))\n", - " mismatches = np.where(~np.isclose(sum_check_tensor.detach().cpu().numpy(), hf_tensor_sum.detach().cpu().numpy()))[0]\n", - " print(mismatches)\n", - " assert(False)\n", - " print(\"Ok!\")\n", - "def check_hf_zero_tensor(hf_tensor_fp):\n", - " assert(os.path.exists(hf_tensor_fp))\n", - " hf_tensor1 = torch.load(hf_tensor_fp)\n", - " if type(hf_tensor1) == tuple or type(hf_tensor1) == list:\n", - " assert(len(hf_tensor1) == 1)\n", - " hf_tensor1 = hf_tensor1[0]\n", - " assert(torch.count_nonzero(torch.nan_to_num(hf_tensor1)).sum() == 0)\n", - "def print_tensors(hf_tensor_filepath, ff_tensor_filepath, txt=\"\"):\n", - " assert(os.path.exists(hf_tensor_filepath) and os.path.exists(ff_tensor_filepath))\n", - " hf_tensor = torch.load(hf_tensor_filepath)\n", - " if type(hf_tensor) == tuple or type(hf_tensor) == list:\n", - " assert(len(hf_tensor) == 1)\n", - " hf_tensor = hf_tensor[0]\n", - " hf_tensor = torch.nan_to_num(hf_tensor)\n", - " hf_tensor = hf_tensor.flatten().detach().cpu().numpy()\n", - " ff_tensor = np.loadtxt(ff_tensor_filepath, delimiter=',')\n", - "\n", - " len_hf_tensor = hf_tensor.shape[0]\n", - " ff_tensor = ff_tensor[:len_hf_tensor]\n", - "\n", - " print(f\"{txt} - HF tensor:\")\n", - " print(hf_tensor)\n", - " print(f\"{txt} - FF tensor: \")\n", - " print(ff_tensor)\n", - "def compare_flexflow_tensors(ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5, max_len=-1):\n", - " assert(os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp))\n", - " ff_tensor1 = np.loadtxt(ff_tensor1_fp, delimiter=',')\n", - " ff_tensor2 = np.loadtxt(ff_tensor2_fp, delimiter=',')\n", - "\n", - " if (ff_tensor1.shape != ff_tensor2.shape):\n", - " print(ff_tensor1.shape, ff_tensor2.shape)\n", - " assert(ff_tensor1.shape == ff_tensor2.shape)\n", - "\n", - " if max_len > -1:\n", - " ff_tensor1 = ff_tensor1[:max_len]\n", - " ff_tensor2 = ff_tensor2[:max_len]\n", - " \n", - " mismatches = []\n", - " if not np.allclose(ff_tensor1, ff_tensor2, atol=tolerance):\n", - " print(f\"mismatch between {ff_tensor1_fp} and {ff_tensor2_fp}\")\n", - " print(f\"Tensor1: {ff_tensor1}\\nTensor2:{ff_tensor2}\")\n", - " print(np.isclose(ff_tensor1, ff_tensor2, atol=tolerance))\n", - " mismatches = np.where(~np.isclose(ff_tensor1, ff_tensor2, atol=tolerance))[0]\n", - " print(mismatches)\n", - " #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\n", - " assert(len(mismatches) <= .05*len(ff_tensor1))\n", - " print(\"Ok!\")\n", - "def compare_flexflow_tensors_shortest(ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5):\n", - " assert(os.path.exists(ff_tensor1_fp) and 
os.path.exists(ff_tensor2_fp))\n", - " ff_tensor1 = np.loadtxt(ff_tensor1_fp, delimiter=',')\n", - " ff_tensor2 = np.loadtxt(ff_tensor2_fp, delimiter=',')\n", - " minlen = min(ff_tensor1.shape[0], ff_tensor2.shape[0])\n", - " ff_tensor1 = ff_tensor1[:minlen]\n", - " ff_tensor2 = ff_tensor2[:minlen]\n", - " mismatches = []\n", - " if not np.allclose(ff_tensor1, ff_tensor2, atol=tolerance):\n", - " print(f\"mismatch between {ff_tensor1_fp} and {ff_tensor2_fp}\")\n", - " print(f\"Tensor1: {ff_tensor1}\\nTensor2:{ff_tensor2}\")\n", - " print(np.isclose(ff_tensor1, ff_tensor2, atol=tolerance))\n", - " mismatches = np.where(~np.isclose(ff_tensor1, ff_tensor2, atol=tolerance))[0]\n", - " print(mismatches)\n", - " #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\n", - " assert(len(mismatches) <= .05*len(ff_tensor1))\n", - " print(\"Ok!\")\n", - "def check_flexflow_tensors_sum(ff_tensor_sum_fp, ff_tensor1_fp, ff_tensor2_fp, tolerance=1e-5):\n", - " assert(os.path.exists(ff_tensor1_fp) and os.path.exists(ff_tensor2_fp))\n", - " ff_tensor1 = np.loadtxt(ff_tensor1_fp, delimiter=',')\n", - " ff_tensor2 = np.loadtxt(ff_tensor2_fp, delimiter=',')\n", - " ff_tensor_sum = np.loadtxt(ff_tensor_sum_fp, delimiter=',')\n", - " \n", - " ff_sum = ff_tensor1 + ff_tensor2\n", - " assert(ff_tensor1.shape == ff_tensor2.shape)\n", - " \n", - " mismatches = []\n", - " if not np.allclose(ff_tensor_sum, ff_sum, atol=tolerance):\n", - " print(f\"mismatch between {ff_tensor_sum_fp} and sum of {ff_tensor1_fp} + {ff_tensor2_fp}\")\n", - " print(f\"Tensor1: {ff_tensor1}\\nTensor2:{ff_tensor2}\")\n", - " print(f\"Sum Tensor: {ff_tensor_sum}\\nActual sum:{ff_sum}\")\n", - " print(np.isclose(ff_tensor_sum, ff_sum, atol=tolerance))\n", - " mismatches = np.where(~np.isclose(ff_tensor_sum, ff_sum, atol=tolerance))[0]\n", - " print(mismatches)\n", - " #assert(np.allclose(ff_tensor, hf_tensor, atol=tolerance))\n", - " assert(len(mismatches) <= .05*len(ff_tensor1))\n", - " print(\"Ok!\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - 
"Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n" - ] - } - ], - "source": [ - "tot_num_layers = 12\n", - "for layer_num in range(tot_num_layers):\n", - " hf_input_ln_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.input_layernorm.output_0\"\n", - " ff_input_ln_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_RMSNorm_shard-id_0_output_0\"\n", - " if layer_num > 0:\n", - " ff_input_ln_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_norm_shard-id_0_output_1\"\n", - " compare_tensors(hf_input_ln_out, ff_input_ln_out)\n", - " hf_attn_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.o_proj.output_0\"\n", - " ff_attn_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_output_0\"\n", - " compare_tensors(hf_attn_out, ff_attn_out)\n", - " hf_ffn_norm_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.post_attention_layernorm.output_0\"\n", - " ff_ffn_norm_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_ffn_norm_shard-id_0_output_1\"\n", - " compare_tensors(hf_ffn_norm_out, ff_ffn_norm_out)\n", - " # w1\n", - " hf_gate_proj_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.gate_proj.output_0\"\n", - " ff_gate_proj_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_output_0\"\n", - " compare_tensors(hf_gate_proj_out, ff_gate_proj_out)\n", - " # w3\n", - " hf_up_proj_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.up_proj.output_0\" \n", - " ff_up_proj_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_output_0\"\n", - " compare_tensors(hf_up_proj_out, ff_up_proj_out)\n", - " # w2\n", - " hf_down_proj_in = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.input_0\"\n", - " hf_down_proj_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.output_0\"\n", - " ff_down_proj_in = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_input_0\"\n", - " ff_down_proj_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_output_0\"\n", - " compare_tensors(hf_down_proj_in, ff_down_proj_in)\n", - " # compare_tensors(hf_down_proj_out, ff_down_proj_out)\n", - " # LORA input\n", - " hf_lora_A_in = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.input_0\"\n", - " ff_lora_A_in = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_input_0\"\n", - " compare_hf_tensors(hf_down_proj_in, hf_lora_A_in)\n", - " compare_tensors(hf_lora_A_in, ff_lora_A_in)\n", - " # LORA weights\n", - " hf_lora_A_weight_fp = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_A.default.weight\"\n", - " ff_lora_A_weight_fp = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_A\"\n", - " 
compare_tensors(hf_lora_A_weight_fp, ff_lora_A_weight_fp)\n", - " hf_lora_B_weight_fp = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_B.default.weight\"\n", - " ff_lora_B_weight_fp = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_B\"\n", - " compare_tensors(hf_lora_B_weight_fp, ff_lora_B_weight_fp)\n", - " # LORA intermediate hf\n", - " hf_lora_A_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.output_0\"\n", - " hf_lora_B_in = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.input_0\"\n", - " compare_hf_tensors(hf_lora_A_out, hf_lora_B_in)\n", - " # LORA output\n", - " hf_lora_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.output_0\"\n", - " ff_lora_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_output_0\"\n", - " # compare_tensors(hf_lora_out, ff_lora_out)\n", - " # compare_flexflow_tensors(ff_down_proj_out, ff_lora_out)\n", - " # compare_tensors(hf_down_proj_out, ff_lora_out)\n", - " compare_tensors_difference(hf_lora_out, ff_lora_out, ff_down_proj_out)\n", - " \n", - "\n", - "# After last layer only\n", - "hf_norm_out = f\"{hf_weight_base_path}/fwd_step_0_norm.output_0\"\n", - "ff_norm_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_norm_shard-id_0_output_1\"\n", - "compare_tensors(hf_norm_out, ff_norm_out)\n", - "hf_lm_head_out = f\"{hf_weight_base_path}/fwd_step_0_base_model.model.lm_head.output_0\"\n", - "ff_lm_head_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_output_shard-id_0_output_0\"\n", - "compare_tensors(hf_lm_head_out, ff_lm_head_out)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n" - ] - } - ], - "source": [ - "tot_num_layers = 12\n", - "\n", - "ff_BWD_softmax_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_100_layer-name_Softmax_shard-id_0_input_0\"\n", - "\n", - "hf_BWD_lm_head_out = f\"{hf_weight_base_path}/bwd_step_0_base_model.model.lm_head.go_0\"\n", - "ff_BWD_lm_head_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_output_shard-id_0_output_0\"\n", - "compare_tensors(hf_BWD_lm_head_out, ff_BWD_lm_head_out, tolerance=1e-5)\n", - "# compare weights\n", - "hf_lm_head_weight = f\"{hf_weight_base_path}/base_model.model.lm_head.weight\"\n", - "ff_lm_head_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_output_shard-id_0_weight_0\"\n", - "compare_tensors(hf_lm_head_weight, ff_lm_head_weight, tolerance=1e-5)\n", - "hf_BWD_lm_head_in = f\"{hf_weight_base_path}/bwd_step_0_base_model.model.lm_head.gi_0\"\n", - "ff_BWD_lm_head_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_output_shard-id_0_input_0\"\n", - "compare_tensors(hf_BWD_lm_head_in, ff_BWD_lm_head_in, tolerance=1e-5)\n", - "# # Manually check the matmul\n", - "# ff_tensor_out = np.loadtxt(ff_BWD_lm_head_out, delimiter=',')\n", - "# ff_weight = np.loadtxt(ff_lm_head_weight, delimiter=',').reshape((4096,32000), order='F')\n", - "# ff_tensor_out = 
ff_tensor_out[:32000*24].reshape((32000,24), order='F')\n", - "# print(ff_tensor_out.shape)\n", - "# print(ff_weight.shape)\n", - "# print(np.matmul(ff_weight, ff_tensor_out))\n", - "# compare_tensors(hf_BWD_lm_head_in, ff_BWD_lm_head_in)\n", - "# ff_tensor = np.loadtxt(ff_tensor_filepath, delimiter=',')\n", - "\n", - "hf_BWD_norm_out = f\"{hf_weight_base_path}/bwd_step_0_norm.go_0\"\n", - "ff_BWD_norm_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_norm_shard-id_0_output_0\"\n", - "compare_hf_tensors(hf_BWD_lm_head_in, hf_BWD_norm_out)\n", - "compare_tensors(hf_BWD_norm_out, ff_BWD_norm_out)\n", - "ff_BWD_norm_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{tot_num_layers-1}_layer-name_norm_shard-id_0_weight_0\"\n", - "hf_FWD_norm_weight = f\"{hf_weight_base_path}/base_model.model.model.norm.weight\"\n", - "compare_tensors(hf_FWD_norm_weight, ff_BWD_norm_weight, tolerance=1e-5)\n", - "hf_BWD_norm_in = f\"{hf_weight_base_path}/bwd_step_0_norm.gi_0\"\n", - "ff_BWD_norm_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{tot_num_layers-1}_layer-name_norm_shard-id_0_input_1\"\n", - "compare_tensors(hf_BWD_norm_in, ff_BWD_norm_in, tolerance=1e-5)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from torch import nn\n", - "class LlamaRotaryEmbedding(nn.Module):\n", - " def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):\n", - " super().__init__()\n", - "\n", - " self.dim = dim\n", - " self.max_position_embeddings = max_position_embeddings\n", - " self.base = base\n", - " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))\n", - " self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n", - "\n", - " # Build here to make `torch.jit.trace` work.\n", - " self._set_cos_sin_cache(\n", - " seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()\n", - " )\n", - "\n", - " def _set_cos_sin_cache(self, seq_len, device, dtype):\n", - " self.max_seq_len_cached = seq_len\n", - " t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)\n", - "\n", - " freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n", - " # Different from paper, but it uses a different permutation in order to obtain the same calculation\n", - " emb = torch.cat((freqs, freqs), dim=-1)\n", - " self.register_buffer(\"cos_cached\", emb.cos().to(dtype), persistent=False)\n", - " self.register_buffer(\"sin_cached\", emb.sin().to(dtype), persistent=False)\n", - "\n", - " def forward(self, x, seq_len=None):\n", - " # x: [bs, num_attention_heads, seq_len, head_size]\n", - " if seq_len > self.max_seq_len_cached:\n", - " self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)\n", - "\n", - " return (\n", - " self.cos_cached[:seq_len].to(dtype=x.dtype),\n", - " self.sin_cached[:seq_len].to(dtype=x.dtype),\n", - " )\n", - "def rotate_half(x):\n", - " \"\"\"Rotates half the hidden dims of the input.\"\"\"\n", - " x1 = x[..., : x.shape[-1] // 2] # first half\n", - " x2 = x[..., x.shape[-1] // 2 :] # second half\n", - " return torch.cat((x2, -x1), dim=-1)\n", - "def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):\n", - " \"\"\"Applies Rotary Position Embedding to the query and key tensors.\n", - "\n", - " Args:\n", - " q (`torch.Tensor`): The query tensor.\n", - " k (`torch.Tensor`): The key tensor.\n", - " cos (`torch.Tensor`): The cosine 
part of the rotary embedding.\n", - " sin (`torch.Tensor`): The sine part of the rotary embedding.\n", - " position_ids (`torch.Tensor`):\n", - " The position indices of the tokens corresponding to the query and key tensors. For example, this can be\n", - " used to pass offsetted position ids when working with a KV-cache.\n", - " unsqueeze_dim (`int`, *optional*, defaults to 1):\n", - " The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and\n", - " sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note\n", - " that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and\n", - " k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes\n", - " cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have\n", - " the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.\n", - " Returns:\n", - " `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.\n", - " \"\"\"\n", - " cos = cos[position_ids].unsqueeze(unsqueeze_dim)\n", - " sin = sin[position_ids].unsqueeze(unsqueeze_dim)\n", - " q_embed = (q * cos) + (rotate_half(q) * sin)\n", - " k_embed = (k * cos) + (rotate_half(k) * sin)\n", - " return q_embed, k_embed\n", - "head_dim = 64\n", - "max_position_embeddings = 2048\n", - "rope_theta=10_000\n", - "kv_seq_len = 24\n", - "rotary_emb = LlamaRotaryEmbedding(\n", - " head_dim,\n", - " max_position_embeddings=max_position_embeddings,\n", - " base=rope_theta,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Huggingface checks:\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "\n", - "FlexFlow checks:\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "\n", - "Huggingface-FlexFlow checks:\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_bwd-step_0_layer-num_11_layer-name_SigmoidSiluMulti_shard-id_0_output_0\n", - "HF: [ 6.4350547e+03 -6.4898600e+05 1.1761116e+05 ... 2.1410337e+01\n", - " 1.2096541e+01 3.6424692e+00]\n", - "FF:[ 6.43506250e+03 -6.48986000e+05 1.17611156e+05 ... 2.14103374e+01\n", - " 1.20965424e+01 3.64246750e+00]\n", - "[ True True True ... True True True]\n", - "[2394]\n", - "Ok!\n", - "mismatch between /usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/bwd_step_0_layers.11.mlp.down_proj.gi_0 and /usr0/home/goliaro/Desktop/FlexFlow/build/inference_tensors/model_0_bwd-step_0_layer-num_11_layer-name_layers_11_feed_forward_w2_shard-id_0_input_0\n", - "HF: [ 6.4350547e+03 -6.4898600e+05 1.1761116e+05 ... 2.1410337e+01\n", - " 1.2096541e+01 3.6424692e+00]\n", - "FF:[ 6.43506250e+03 -6.48986000e+05 1.17611156e+05 ... 2.14103374e+01\n", - " 1.20965424e+01 3.64246750e+00]\n", - "[ True True True ... 
True True True]\n", - "[2394]\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "4.383680555555555% mismatch in QK prods softmax out grad\n", - "3.9116753472222223% mismatch between HF and FF for kproj (before applying ROPE)\n", - "3.9008246527777777% mismatch between HF and FF for kproj (after applying ROPE)\n", - "4.817708333333334% mismatch in attention input grads\n" - ] - }, - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[11], line 353\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpct_mismatch\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m100\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m% mismatch in attention input grads\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m(pct_mismatch \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.05\u001b[39m)\n\u001b[0;32m--> 353\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - } - ], - "source": [ - "tot_num_layers = 12\n", - "for layer_num in range(tot_num_layers-1, -1, -1):\n", - " # HuggingFace filepaths\n", - " hf_BWD_norm_in = f\"{hf_weight_base_path}/bwd_step_0_norm.gi_0\"\n", - " hf_BWD_loraB_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.go_0\"\n", - " hf_BWD_loraB_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_B.default.gi_0\"\n", - " hf_BWD_loraA_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.go_0\"\n", - " hf_BWD_loraA_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_A.default.gi_0\"\n", - " hf_loraA_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_A.default.weight\"\n", - " hf_loraB_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_B.default.weight\"\n", - " hf_BWD_lora_dropout_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_dropout.default.go_0\"\n", - " hf_BWD_lora_dropout_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.lora_dropout.default.gi_0\"\n", - " hf_BWD_w2_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.go_0\"\n", - " hf_BWD_w2_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.down_proj.gi_0\"\n", - " hf_w2_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.weight\"\n", - " hf_BWD_w3_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.up_proj.go_0\"\n", - " hf_BWD_w3_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.up_proj.gi_0\"\n", - " hf_BWD_w1_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.gate_proj.go_0\"\n", - " hf_BWD_w1_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.gate_proj.gi_0\"\n", - " hf_BWD_act_fn_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.act_fn.gi_0\"\n", - " hf_BWD_act_fn_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.mlp.act_fn.go_0\"\n", - " 
hf_BWD_ffn_norm_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.post_attention_layernorm.go_0\"\n", - " hf_BWD_ffn_norm_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.post_attention_layernorm.gi_0\"\n", - " hf_BWD_attn_out_out = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.go_0\"\n", - " hf_BWD_attn_q_in = f\"{hf_weight_base_path}/bwd_step_0_layers.11.self_attn.q_proj.gi_0\"\n", - " hf_FWD_w1_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.gate_proj.output_0\"\n", - " hf_FWD_w3_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.up_proj.output_0\"\n", - " hf_FWD_act_fn_out = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.mlp.act_fn.output_0\"\n", - " hf_BWD_attn_oproj_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.gi_0\"\n", - " hf_attn_qproj_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.q_proj.weight\"\n", - " hf_attn_kproj_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.k_proj.weight\"\n", - " hf_attn_vproj_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.v_proj.weight\"\n", - " hf_attn_oproj_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.self_attn.o_proj.weight\"\n", - " # hf_BWD_attn_vproj_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.v_proj.gi_0\"\n", - " # FlexFlow filepaths\n", - " ff_BWD_w2_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_output_0\"\n", - " ff_BWD_w2_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_input_0\"\n", - " ff_BWD_w2_in_pre = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_pre_input_0\"\n", - " ff_w2_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_weight_0\"\n", - " ff_BWD_ssm_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_output_0\"\n", - " ff_BWD_ssm_in1 = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_input_0\"\n", - " ff_BWD_ssm_in2 = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_input_1\"\n", - " ff_BWD_w3_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_output_0\"\n", - " ff_BWD_w3_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_input_0\"\n", - " ff_BWD_lora_A_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_input_0\"\n", - " ff_BWD_lora_B_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_output_0\"\n", - " ff_lora_A_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_A\"\n", - " ff_lora_B_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_B\"\n", - " 
ff_BWD_w1_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_output_0\"\n", - " ff_BWD_w1_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_input_0\"\n", - " ff_BWD_w1_in_pre = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_pre_input_0\"\n", - " ff_w1_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_weight_0\"\n", - " ff_BWD_ffn_norm_in1 = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_ffn_norm_shard-id_0_input_0\"\n", - " ff_BWD_ffn_norm_in2 = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_ffn_norm_shard-id_0_input_1\"\n", - " ff_BWD_ffn_norm_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_ffn_norm_shard-id_0_output_0\"\n", - " ff_BWD_attn_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_output_0\"\n", - " ff_BWD_attn_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_input_0\"\n", - " ff_BWD_ssm_cached_w1_input = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_cached_w1_output\"\n", - " ff_BWD_ssm_cached_w3_input = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_cached_w3_output\"\n", - " ff_FWD_w1_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_output_0\"\n", - " ff_FWD_w3_out = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_output_0\"\n", - " ff_FWD_act_fnc_out = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_SigmoidSiluMulti_shard-id_0_act_fn_output\"\n", - " ff_BWD_attn_o_proj_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_o_proj_in_grad\"\n", - " # ff_BWD_attn_v_proj_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_v_proj_in_grad\"\n", - " ff_attn_oproj_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_11_layer-name_layers_11_attention_shard-id_0_weight_0\"\n", - " # ff_attn_qk_prods_softmax = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax\"\n", - "\n", - " # xxx = torch.load(hf_BWD_attn_out_out)\n", - " # xxx.detach().cpu().numpy().tofile(f\"{hf_BWD_attn_out_out}.flexflow\")\n", - " # print(f\"{hf_BWD_attn_out_out}.flexflow\")\n", - " \n", - " # HuggingFace checks\n", - " print(\"\\nHuggingface checks:\")\n", - " if layer_num == tot_num_layers-1:\n", - " compare_hf_tensors(hf_BWD_norm_in, hf_BWD_loraB_out)\n", - " compare_hf_tensors(hf_BWD_norm_in, hf_BWD_w2_out)\n", - " compare_hf_tensors(hf_BWD_loraB_out, hf_BWD_w2_out)\n", - " compare_hf_tensors(hf_BWD_loraB_in, hf_BWD_loraA_out)\n", - " # compare_hf_tensors(hf_BWD_w3_out, hf_BWD_w2_out)\n", - " compare_hf_tensors(hf_BWD_act_fn_in, hf_BWD_w1_out)\n", - " 
check_hf_sum_tensors(hf_BWD_ffn_norm_out, hf_BWD_w1_in, hf_BWD_w3_in)\n", - " check_hf_sum_tensors(hf_BWD_attn_out_out, hf_BWD_ffn_norm_in, hf_BWD_norm_in)\n", - "\n", - " # FlexFlow checks\n", - " print(\"\\nFlexFlow checks:\")\n", - " compare_flexflow_tensors(ff_BWD_w2_out, ff_BWD_lora_B_out)\n", - " compare_flexflow_tensors(ff_BWD_w2_in_pre, ff_BWD_lora_A_in)\n", - " compare_flexflow_tensors(ff_BWD_w2_in, ff_BWD_ssm_out)\n", - " compare_flexflow_tensors(ff_BWD_ssm_in2, ff_BWD_w3_out)\n", - " compare_flexflow_tensors(ff_BWD_ssm_in1, ff_BWD_w1_out)\n", - " compare_flexflow_tensors(ff_BWD_w1_in, ff_BWD_ffn_norm_out)\n", - " compare_flexflow_tensors(ff_BWD_w1_in_pre, ff_BWD_w3_in)\n", - " compare_flexflow_tensors(ff_BWD_ffn_norm_in1, ff_BWD_ffn_norm_in2, max_len=24*768)\n", - " #compare_flexflow_tensors(ff_BWD_ffn_norm_in2, ff_BWD_attn_out, max_len=24*768) # should fail\n", - "\n", - " # HF-FlexFlow checks\n", - " print(\"\\nHuggingface-FlexFlow checks:\")\n", - " compare_tensors(hf_BWD_w2_out, ff_BWD_w2_out, tolerance=1e-5)\n", - " compare_tensors(hf_w2_weight, ff_w2_weight, tolerance=1e-5)\n", - " #print(torch.load(hf_w2_weight).shape)\n", - " compare_tensors(hf_loraA_weight, ff_lora_A_weight, tolerance=1e-5)\n", - " compare_tensors(hf_loraB_weight, ff_lora_B_weight, tolerance=1e-5)\n", - "\n", - " compare_tensors(hf_BWD_loraB_out, ff_BWD_lora_B_out)\n", - " compare_tensors(hf_BWD_loraA_in, ff_BWD_lora_A_in)\n", - "\n", - " compare_tensors(hf_BWD_w2_in, ff_BWD_ssm_out)\n", - " compare_tensors(hf_BWD_w2_in, ff_BWD_w2_in)\n", - " compare_tensors(hf_BWD_w1_out, ff_BWD_w1_out)\n", - " compare_tensors_difference(hf_BWD_w1_in, ff_BWD_w1_in, ff_BWD_w1_in_pre)\n", - "\n", - " compare_tensors(hf_FWD_w1_out, ff_FWD_w1_out)\n", - " compare_tensors(hf_FWD_w3_out, ff_FWD_w3_out)\n", - " compare_tensors(hf_BWD_w3_out, ff_BWD_w3_out)\n", - " compare_tensors(hf_BWD_w3_in, ff_BWD_w3_in)\n", - " compare_tensors(hf_BWD_w1_out, ff_BWD_w1_out)\n", - " # compare_tensors(hf_BWD_ffn_norm_out, ff_BWD_ffn_norm_out)\n", - " # compare_tensors(hf_BWD_ffn_norm_in, ff_BWD_ffn_norm_in2)\n", - " # compare_tensors(hf_BWD_attn_out_out, ff_BWD_ffn_norm_in2)\n", - " compare_tensors(hf_BWD_attn_out_out, ff_BWD_attn_out)\n", - "\n", - " # compare attn weight tensors\n", - " hidden_size = 768\n", - " qProjSize = 64\n", - " num_heads = 12\n", - " num_new_tokens = num_tokens = 24\n", - " ff_attn_weight_tensor = np.loadtxt(ff_attn_oproj_weight, delimiter=',')\n", - " ff_attn_qproj_weight_tensor = ff_attn_weight_tensor[:hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')\n", - " ff_attn_kproj_weight_tensor = ff_attn_weight_tensor[hidden_size*qProjSize*num_heads:2*hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')\n", - " ff_attn_vproj_weight_tensor = ff_attn_weight_tensor[2*hidden_size*qProjSize*num_heads:3*hidden_size*qProjSize*num_heads].reshape((hidden_size,qProjSize*num_heads), order = 'F')\n", - " ff_attn_oproj_weight_tensor = ff_attn_weight_tensor[3*hidden_size*qProjSize*num_heads:].reshape((qProjSize*num_heads,hidden_size), order='F')\n", - " \n", - " hf_attn_qproj_weight_tensor = torch.load(hf_attn_qproj_weight).T.detach().cpu().numpy()\n", - " hf_attn_kproj_weight_tensor = torch.load(hf_attn_kproj_weight).T.detach().cpu().numpy()\n", - " hf_attn_vproj_weight_tensor = torch.load(hf_attn_vproj_weight).T.detach().cpu().numpy()\n", - " hf_attn_oproj_weight_tensor = torch.load(hf_attn_oproj_weight).T.detach().cpu().numpy()\n", - " \n", - " 
assert(np.allclose(ff_attn_qproj_weight_tensor, hf_attn_qproj_weight_tensor, atol=1e-5))\n", - " assert(np.allclose(ff_attn_kproj_weight_tensor, hf_attn_kproj_weight_tensor, atol=1e-5))\n", - " assert(np.allclose(ff_attn_vproj_weight_tensor, hf_attn_vproj_weight_tensor, atol=1e-5))\n", - " assert(np.allclose(ff_attn_oproj_weight_tensor, hf_attn_oproj_weight_tensor, atol=1e-5))\n", - " \n", - " # Compare attn outproj grad in tensors\n", - " compare_tensors(hf_BWD_attn_oproj_in, ff_BWD_attn_o_proj_in)\n", - " \n", - " ########### Compare value projs grads ######################\n", - " # 1. compare qk prods softmax\n", - " hf_qk_prods_softmax = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.qk_prods_softmax\"\n", - " ff_attn_qk_prods_softmax = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax\"\n", - " \n", - " hf_qk_prods_softmax = torch.load(hf_qk_prods_softmax)\n", - " ff_qk_prods_softmax = np.loadtxt(ff_attn_qk_prods_softmax, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", - "\n", - " for head_idx in range(num_heads):\n", - " hf_qkps = hf_qk_prods_softmax.squeeze()[head_idx, :, :].detach().cpu().numpy()\n", - " ff_qkps = ff_qk_prods_softmax[:,:,head_idx]\n", - " assert(np.allclose(ff_qkps, hf_qkps, atol=1e-5))\n", - " \n", - " # 2. compare attn heads grads\n", - " hf_attn_heads_grads = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.gi_0\"\n", - " ff_attn_heads_grads = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_o_proj_in_grad\"\n", - "\n", - " hf_attn_heads_grads = torch.load(hf_attn_heads_grads).T.squeeze().detach().cpu().numpy()\n", - " ff_attn_heads_grads = np.loadtxt(ff_attn_heads_grads, delimiter=',').reshape((qProjSize*num_heads, num_new_tokens), order = 'F')\n", - " assert(np.allclose(ff_attn_heads_grads, hf_attn_heads_grads, atol=1e-2))\n", - "\n", - " # 3. 
vproj grads\n", - " hf_vproj_grads = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.v_proj.go_0\"\n", - " ff_vproj_grads = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_v_proj_in_grad\"\n", - "\n", - " hf_vproj_grads = torch.load(hf_vproj_grads).squeeze().detach().cpu().numpy()\n", - " ff_vproj_grads = np.loadtxt(ff_vproj_grads, delimiter=',').reshape((num_tokens, qProjSize*num_heads), order='F')\n", - " assert(np.allclose(hf_vproj_grads, ff_vproj_grads, atol=1e-2))\n", - "\n", - " \n", - " \n", - " \n", - " ##############################\n", - " hf_value_states = f\"{hf_weight_base_path}/fwd_step_0_layers.11.self_attn.value_states\"\n", - " hf_value_states = torch.load(hf_value_states).squeeze().permute(2,0,1).detach().cpu().numpy()\n", - " # print(hf_value_states.shape)\n", - " ff_value_states = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_vcache\"\n", - " ff_value_states = np.loadtxt(ff_value_states, delimiter=',').reshape((qProjSize, num_heads, num_tokens), order='F')\n", - " # print(ff_value_states.shape)\n", - " assert(np.allclose(hf_value_states, ff_value_states, atol=1e-2))\n", - " \n", - " \n", - " \n", - " ########## Compare key and query projs grads ##################\n", - " ff_devQKVPRojArray = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devQKVPRojArray\"\n", - " ff_devQKVPRojArray = np.loadtxt(ff_devQKVPRojArray, delimiter=',').reshape((num_tokens, qProjSize*num_heads, 3), order = 'F')\n", - " ff_qProjGrads = ff_devQKVPRojArray[:,:,0]\n", - " ff_kProjGrads = ff_devQKVPRojArray[:,:,1]\n", - " ff_vProjGrads = ff_devQKVPRojArray[:,:,2]\n", - " assert(np.allclose(ff_vProjGrads, ff_vproj_grads, atol=1e-5))\n", - "\n", - " # simulate qk_prods_softmax\n", - " ff_attn_heads_grads = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_o_proj_in_grad\"\n", - " ff_attn_heads_grads = np.loadtxt(ff_attn_heads_grads, delimiter=',').reshape((qProjSize,num_heads, num_new_tokens), order = 'F')\n", - " ff_attn_heads_grads = torch.from_numpy(ff_attn_heads_grads)\n", - " ff_attn_heads_grads = ff_attn_heads_grads.permute(1,2,0)\n", - " ff_value_states = torch.from_numpy(ff_value_states)\n", - " ff_value_states = ff_value_states.permute(1,0,2)\n", - " # print(ff_attn_heads_grads.shape)\n", - " # print(ff_value_states.shape)\n", - " simulated_qk_prods_softmax_grads = torch.matmul(ff_attn_heads_grads, ff_value_states)\n", - " #simulated_qk_prods_softmax_grads = simulated_qk_prods_softmax_grads\n", - " #print(\"Simulated QK prods grads:\")\n", - " #print(simulated_qk_prods_softmax_grads[0,:,:])\n", - "\n", - " # qk prods softmax right before softmax\n", - " hf_qk_prods_softmax2 = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.softmax_op.go_0\"\n", - " hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)\n", - " ff_qk_prods_softmax2 = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax_grad\"\n", - " ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", - " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)\n", - " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()\n", - " # 
assert(np.allclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2, atol=1e-2))\n", - " mismatches = np.where(~np.isclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2))\n", - " mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]\n", - " pct_mismatch = len(mismatches) / (hf_qk_prods_softmax2.shape[0] * hf_qk_prods_softmax2.shape[1] * hf_qk_prods_softmax2.shape[2])\n", - " print(f\"{pct_mismatch*100}% mismatch in QK prods softmax out grad\")\n", - " assert(pct_mismatch <= 0.05)\n", - "\n", - " # qk prods softmax right after softmax\n", - " hf_qk_prods_softmax2 = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.softmax_op.gi_0\"\n", - " hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)\n", - " ff_qk_prods_softmax2 = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax_grad_in\"\n", - " ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", - " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)\n", - " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()\n", - " assert(np.allclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2, atol=1e-2))\n", - " \n", - " # qk prods softmax after mask\n", - " hf_qk_prods_softmax2 = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.matmul_op.go_0\"\n", - " hf_qk_prods_softmax2 = torch.load(hf_qk_prods_softmax2)\n", - " ff_qk_prods_softmax2 = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax_grad_in_masked\"\n", - " ff_qk_prods_softmax2 = np.loadtxt(ff_qk_prods_softmax2, delimiter=',').reshape((num_new_tokens, num_tokens, num_heads), order = 'F')\n", - " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.squeeze().permute(1,2,0)\n", - " hf_qk_prods_softmax2 = hf_qk_prods_softmax2.detach().cpu().numpy()\n", - " assert(np.allclose(ff_qk_prods_softmax2, hf_qk_prods_softmax2, atol=1e-2))\n", - "\n", - " # Compare query activation\n", - " hf_query_activation = f\"{hf_weight_base_path}/fwd_step_0_layers.{layer_num}.self_attn.query_activation\"\n", - " hf_query_activation = torch.load(hf_query_activation)\n", - " ff_query_activation = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_query_activation\"\n", - " ff_query_activation = np.loadtxt(ff_query_activation, delimiter=',').reshape((qProjSize, num_heads, num_new_tokens), order = 'F')\n", - " hf_query_activation = hf_query_activation.squeeze().permute(2,0,1).detach().cpu().numpy()\n", - " assert(np.allclose(ff_query_activation, hf_query_activation, atol=1e-2))\n", - " \n", - " ########################################## ROPE and Kproj ##########################################\n", - "\n", - " # Compare FF kproj with intermediate kproj data from HF\n", - " hf_kproj_grads_post_rotary = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.identity_kv_post_rotary.go_0\"\n", - " hf_kproj_grads_post_rotary = torch.load(hf_kproj_grads_post_rotary)\n", - " hf_kproj_grads_post_rotary_copy = hf_kproj_grads_post_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()\n", - " # print(\"hf_kproj_grads_post_rotary: \", hf_kproj_grads_post_rotary_copy.shape)\n", - " # print(hf_kproj_grads_post_rotary_copy[:,:,0])\n", - " # Check hf ROPE \n", - " cos, sin = rotary_emb(hf_kproj_grads_post_rotary, seq_len=24)\n", 
- " cos = cos.cuda()\n", - " sin = sin.cuda()\n", - " # query_states: torch.Size([1, 12, 24, 64])\n", - " # key_states: torch.Size([1, 12, 24, 64])\n", - " # position_ids: torch.Size([1, 24])\n", - " # tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", - " # 18, 19, 20, 21, 22, 23]], device='cuda:0')\n", - " query_states = torch.zeros([1, 12, 24, 64]).cuda()\n", - " position_ids = torch.arange(24).unsqueeze(0).cuda()\n", - " query_states, hf_kproj_grads_post_rotary = apply_rotary_pos_emb(query_states, hf_kproj_grads_post_rotary, cos, sin, position_ids)\n", - " hf_kproj_grads_post_rotary = hf_kproj_grads_post_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()\n", - " # print(\"hf_kproj_grads_post_rotary: \", hf_kproj_grads_post_rotary.shape)\n", - " # print(hf_kproj_grads_post_rotary[:,:,0])\n", - " \n", - " hf_kproj_grads_before_rotary = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.identity_kv_before_rotary.go_0\"\n", - " hf_kproj_grads_before_rotary = torch.load(hf_kproj_grads_before_rotary)\n", - " hf_kproj_grads_before_rotary = hf_kproj_grads_before_rotary.squeeze().permute(1,2,0).detach().cpu().numpy()\n", - " # print(\"hf_kproj_grads_before_rotary: \", hf_kproj_grads_before_rotary.shape)\n", - " # print(hf_kproj_grads_before_rotary[:,:,0])\n", - " # Compare HF rope with manual ROPE\n", - " assert(np.allclose(hf_kproj_grads_post_rotary, hf_kproj_grads_before_rotary, atol=1e-5))\n", - " # Compare HF Kproj with FF Kproj (before ROPE) \n", - " ff_kproj_pre = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devkproj_pre\"\n", - " ff_kproj_pre = np.loadtxt(ff_kproj_pre, delimiter=',').reshape((num_tokens, qProjSize, num_heads), order = 'F')\n", - " # print(\"ff_kproj_pre: \", ff_kproj_pre.shape)\n", - " #print(ff_kproj_pre[:,:,0])\n", - " mismatches = np.where(~np.isclose(ff_kproj_pre, hf_kproj_grads_post_rotary_copy, atol=1e-5))\n", - " mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]\n", - " pct_mismatch = len(mismatches) / (ff_kproj_pre.shape[0] * ff_kproj_pre.shape[1] * ff_kproj_pre.shape[2])\n", - " print(f\"{pct_mismatch*100}% mismatch between HF and FF for kproj (before applying ROPE)\")\n", - " assert(pct_mismatch <= 0.05)\n", - " #assert(np.allclose(ff_kproj_pre, hf_kproj_grads_post_rotary_copy, atol=1e-5))\n", - " \n", - " ff_kproj = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devkproj\"\n", - " ff_kproj = np.loadtxt(ff_kproj, delimiter=',').reshape((num_tokens, qProjSize, num_heads), order = 'F')\n", - " # print(\"ff_kproj: \", ff_kproj.shape)\n", - " #print(ff_kproj[:,:,0])\n", - " mismatches = np.where(~np.isclose(ff_kproj, hf_kproj_grads_before_rotary, atol=1e-5))\n", - " mismatches = [(mismatches[0][i],mismatches[1][i], mismatches[2][i]) for i in range(len(mismatches[0]))]\n", - " pct_mismatch = len(mismatches) / (ff_kproj.shape[0] * ff_kproj.shape[1] * ff_kproj.shape[2])\n", - " print(f\"{pct_mismatch*100}% mismatch between HF and FF for kproj (after applying ROPE)\")\n", - " assert(pct_mismatch <= 0.05)\n", - " #assert(np.allclose(ff_kproj, hf_kproj_grads_before_rotary, atol=1e-5))\n", - " \n", - " \n", - " #assert(np.allclose(hf_kproj_grads_post_rotary, hf_kproj_grads_before_rotary, atol=1e-2))\n", - " hf_kproj_grads = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.k_proj.go_0\"\n", - " hf_kproj_grads = 
torch.load(hf_kproj_grads).squeeze()\n", - " #print(\"hf_kproj_grads: \", hf_kproj_grads.shape)\n", - " #print(hf_kproj_grads[:,:64])\n", - " reshaped_tensor = hf_kproj_grads.view(24, 12, 64).transpose(1, 2).contiguous().detach().cpu().numpy()\n", - " #print(reshaped_tensor.shape)\n", - " assert(np.allclose(ff_kproj, reshaped_tensor, atol=1e-2))\n", - "\n", - " ########################################## Qproj (with ROPE) ##########################################\n", - "\n", - " # Compare QProj\n", - " hf_qproj_grads = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.q_proj.go_0\"\n", - " hf_qproj_grads = torch.load(hf_qproj_grads).squeeze()\n", - " # print(\"HF Qproj:\")\n", - " # print(hf_qproj_grads.shape)\n", - " reshaped_tensor = hf_qproj_grads.view(24, 12, 64).transpose(1, 2).contiguous().detach().cpu().numpy()\n", - " # print(\"\\t reshaped: \", reshaped_tensor.shape)\n", - " # print(reshaped_tensor[:,:,0])\n", - " ff_qproj = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_devQKVPRojArray\"\n", - " ff_qproj = np.loadtxt(ff_qproj, delimiter=',').reshape((num_tokens, qProjSize, num_heads, 3), order = 'F')[:,:,:,0]\n", - " # print(\"FF Qproj:\")\n", - " # print(ff_qproj.shape)\n", - " # print(ff_qproj[:,:,0])\n", - " assert(np.allclose(ff_qproj, reshaped_tensor, atol=1e-2))\n", - "\n", - " hf_attn_in = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.input_layernorm.go_0\"\n", - " hf_attn_in = torch.load(hf_attn_in)\n", - " # print(\"hf_attn_in: \", hf_attn_in.shape)\n", - " hf_attn_in = hf_attn_in.squeeze().T\n", - " hf_attn_in = hf_attn_in.detach().cpu().numpy()\n", - " # print(\"hf_attn_in: \", hf_attn_in.shape)\n", - " # print(hf_attn_in)\n", - "\n", - " ff_attn_in = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_attn_final_grad_in\"\n", - " ff_attn_in = np.loadtxt(ff_attn_in, delimiter=',').reshape((768,num_tokens), order = 'F')\n", - " # print(\"ff_attn_in: \", ff_attn_in.shape)\n", - " # print(ff_attn_in)\n", - " #assert(np.allclose(ff_attn_in, hf_attn_in, atol=1e-2))\n", - "\n", - " mismatches = np.where(~np.isclose(ff_attn_in, hf_attn_in))\n", - " mismatches = [(mismatches[0][i], mismatches[1][i]) for i in range(len(mismatches[0]))]\n", - " pct_mismatch = len(mismatches) / (hf_attn_in.shape[0] * hf_attn_in.shape[1])\n", - " print(f\"{pct_mismatch*100}% mismatch in attention input grads\")\n", - " assert(pct_mismatch <= 0.05)\n", - " \n", - "\n", - " assert False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([12, 24, 64])\n", - "tensor([[-1.5730e-02, -4.1161e-02, 3.0593e-02, ..., 3.8630e-01,\n", - " 3.2884e-01, 3.6067e-01],\n", - " [-2.8613e+01, -5.5872e+00, 2.9385e+01, ..., 3.8782e+01,\n", - " 9.6901e+01, 9.8470e+01],\n", - " [ 3.3027e+00, 1.8276e-01, -1.8497e+00, ..., -4.4052e+01,\n", - " -2.0010e+01, -2.9788e+01],\n", - " ...,\n", - " [-7.6471e-02, -1.8892e-01, 3.6430e-01, ..., -2.7493e-01,\n", - " 5.7017e-01, -1.5986e-01],\n", - " [ 2.5780e+00, -1.8153e+00, 2.5088e+00, ..., -1.0776e+01,\n", - " 6.2167e-01, 8.3755e-01],\n", - " [-6.8324e-02, 1.7568e-01, -3.2311e-01, ..., 3.1202e+00,\n", - " -2.6652e-01, -1.1917e+00]])\n", - "(24, 64, 12)\n", - "[[-1.5729919e-02 -4.1160699e-02 3.0592799e-02 ... 
3.8629669e-01\n", - " 3.2884139e-01 3.6066702e-01]\n", - " [-2.8613457e+01 -5.5871558e+00 2.9384506e+01 ... 3.8781765e+01\n", - " 9.6900581e+01 9.8469597e+01]\n", - " [ 3.3027239e+00 1.8275940e-01 -1.8496730e+00 ... -4.4052174e+01\n", - " -2.0009745e+01 -2.9787930e+01]\n", - " ...\n", - " [-7.6470733e-02 -1.8891659e-01 3.6430117e-01 ... -2.7492592e-01\n", - " 5.7017130e-01 -1.5985624e-01]\n", - " [ 2.5780225e+00 -1.8152566e+00 2.5087588e+00 ... -1.0776262e+01\n", - " 6.2166649e-01 8.3755457e-01]\n", - " [-6.8324409e-02 1.7568478e-01 -3.2310838e-01 ... 3.1202292e+00\n", - " -2.6652411e-01 -1.1917179e+00]]\n" - ] - } - ], - "source": [ - "# value states: torch.Size([1, 12, 24, 64])\n", - "value_states=torch.from_numpy(hf_kproj_grads_post_rotary).permute(2,0,1).unsqueeze(0)\n", - "key_states = value_states\n", - "cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)\n", - "# query_states: torch.Size([1, 12, 24, 64])\n", - "# key_states: torch.Size([1, 12, 24, 64])\n", - "# position_ids: torch.Size([1, 24])\n", - "# tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", - "# 18, 19, 20, 21, 22, 23]], device='cuda:0')\n", - "query_states = torch.zeros([1, 12, 24, 64])\n", - "position_ids = torch.arange(24).unsqueeze(0)\n", - "query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n", - "key_states = key_states.squeeze()\n", - "print(key_states.shape)\n", - "print(key_states[0,:,:])\n", - "print(hf_kproj_grads_before_rotary.shape)\n", - "print(hf_kproj_grads_before_rotary[:,:,0])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", - " 18, 19, 20, 21, 22, 23]], device='cuda:0')" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.arange(24).unsqueeze(0).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 12, 24, 24])\n" - ] - }, - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/alignment_tests.ipynb Cell 6\u001b[0m line \u001b[0;36m1\n\u001b[1;32m 17\u001b[0m ff_qkps \u001b[39m=\u001b[39m ff_qk_prods_softmax[:,:,head_idx]\n\u001b[1;32m 18\u001b[0m \u001b[39massert\u001b[39;00m(np\u001b[39m.\u001b[39mallclose(ff_qkps, hf_qkps, atol\u001b[39m=\u001b[39m\u001b[39m1e-5\u001b[39m))\n\u001b[0;32m---> 19\u001b[0m \u001b[39massert\u001b[39;00m(\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m 21\u001b[0m hf_value_states \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mload(hf_value_states)\u001b[39m#.squeeze().T.detach().cpu().numpy()\u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[39mprint\u001b[39m(hf_value_states\u001b[39m.\u001b[39mshape)\n", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - } - ], - "source": [ - "layer_num = 11\n", - "hf_qk_prods_softmax = f\"{hf_weight_base_path}/fwd_step_0_layers.11.self_attn.qk_prods_softmax\"\n", - "ff_qk_prods_softmax = f\"{ff_weight_base_path}/model_0_bwd-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_attention_shard-id_0_qk_prods_softmax\"\n", - "\n", - 
"hf_value_states = f\"{hf_weight_base_path}/fwd_step_0_layers.11.self_attn.value_states\"\n", - "\n", - "hf_qk_prods_softmax = torch.load(hf_qk_prods_softmax)#.squeeze().T.detach().cpu().numpy()\n", - "ff_qk_prods_softmax = np.loadtxt(ff_qk_prods_softmax, delimiter=',').reshape((24, 24, 12), order = 'F')\n", - "print(hf_qk_prods_softmax.shape)\n", - "#print(ff_qk_prods_softmax.shape)\n", - "#print(hf_qk_prods_softmax[:,:,0])\n", - "#print()\n", - "#print(ff_qk_prods_softmax[:,:,0])\n", - "\n", - "for head_idx in range(12):\n", - " hf_qkps = hf_qk_prods_softmax.squeeze()[head_idx, :, :].detach().cpu().numpy()\n", - " ff_qkps = ff_qk_prods_softmax[:,:,head_idx]\n", - " assert(np.allclose(ff_qkps, hf_qkps, atol=1e-5))\n", - "\n", - "\n", - "hf_value_states = torch.load(hf_value_states)#.squeeze().T.detach().cpu().numpy()\n", - "print(hf_value_states.shape)\n", - "attn_output = torch.matmul(hf_qk_prods_softmax, hf_value_states)\n", - "print()\n", - "print(attn_output.shape)\n", - "print(attn_output.transpose(1, 2).contiguous().shape)\n", - "print(\"Hf attn heads\")\n", - "print(torch.load(\"/usr0/home/goliaro/Desktop/FlexFlow/tests/peft/hf_peft_tensors/fwd_step_0_layers.11.self_attn.o_proj.input_0\").shape)\n", - "\n", - "print(\"Attn heads grads:\")\n", - "hf_attn_heads_grads = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.o_proj.gi_0\"\n", - "print(torch.load(hf_attn_heads_grads).shape)\n", - "print(\"HF value grads:\")\n", - "vproj_grads = f\"{hf_weight_base_path}/bwd_step_0_layers.{layer_num}.self_attn.v_proj.gi_0\"\n", - "print(torch.load(vproj_grads).shape)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([2, 3, 4])\n", - "torch.Size([4, 3, 2])\n" - ] - } - ], - "source": [ - "a = torch.randn(2,3,4)\n", - "print(a.shape)\n", - "print(a.T.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000,\n", - " 0.0000],\n", - " [ 27.8890, -21.5089, 45.8214, ..., 5.4010, -10.8787,\n", - " 39.7619],\n", - " [ 19.2197, 27.4681, -68.7141, ..., 102.3280, 66.7925,\n", - " -160.8711],\n", - " ...,\n", - " [ 63.9532, 17.4273, -29.4416, ..., 101.6105, 67.5937,\n", - " -198.4432],\n", - " [ 31.2799, 13.0724, -44.7179, ..., 132.4898, 42.3135,\n", - " -194.4037],\n", - " [ 42.3453, -16.2693, -55.7386, ..., 90.5921, 52.2032,\n", - " -124.1802]]], device='cuda:0')\n", - "tensor([[[-1.1845e+06, -6.7460e+05, 7.4494e+05, ..., -9.1441e+05,\n", - " -1.4912e+05, 3.5769e+06],\n", - " [-7.3920e+01, -7.9389e+01, 1.1027e+02, ..., -7.3020e+01,\n", - " -2.3540e+01, 3.4587e+02],\n", - " [-5.3885e+01, -1.7373e+01, -1.9780e+01, ..., 4.1291e+01,\n", - " 5.5099e+01, 5.5910e+01],\n", - " ...,\n", - " [-2.1948e+01, -3.2109e+01, 2.8364e+01, ..., 3.4321e+01,\n", - " 5.0713e+01, 5.6592e+01],\n", - " [-4.4339e+01, -2.8339e+01, 1.4070e+01, ..., 6.2797e+01,\n", - " 3.0760e+01, 6.1743e+01],\n", - " [-1.6287e+01, -5.0413e+01, -1.9940e+01, ..., 4.3766e+01,\n", - " 4.7833e+01, 4.7295e+01]]], device='cuda:0')\n" - ] - } - ], - "source": [ - "a = \"./hf_peft_tensors/bwd_step_0_layers.11.post_attention_layernorm.gi_0\"\n", - "b = \"./hf_peft_tensors/bwd_step_0_layers.11.self_attn.o_proj.go_0\"\n", - "a = torch.load(a)\n", - "b = torch.load(b)\n", - "print(a)\n", - "print(b)" - ] - }, - { - "cell_type": "code", - "execution_count": 
null, - "metadata": {}, - "outputs": [], - "source": [ - "# # Manual matmul checks\n", - "# ff_w2_grad_out_tensor = np.loadtxt(ff_BWD_w2_out, delimiter=',').reshape((768,128), order='F')\n", - "# ff_w2_weight_tensor = np.loadtxt(ff_w2_weight, delimiter=',').reshape((3072,768), order='F')\n", - "# ff_w2_gradin_tensor = np.matmul(ff_w2_weight_tensor, ff_w2_grad_out_tensor).reshape((3072,128), order='F')\n", - "\n", - "# ff_lora_gradout_tensor = np.loadtxt(ff_BWD_lora_B_out, delimiter=',').reshape((768,128), order='F')\n", - "# ff_lora_A_weight_tensor = np.loadtxt(ff_lora_A_weight, delimiter=',').reshape((3072,16), order='F')\n", - "# ff_lora_B_weight_tensor = np.loadtxt(ff_lora_B_weight, delimiter=',').reshape((16,768), order='F')\n", - "# ff_lora_int_grad_tensor = np.matmul(ff_lora_B_weight_tensor, ff_lora_gradout_tensor)\n", - "# ff_lora_gradint_tensor = np.matmul(ff_lora_A_weight_tensor, ff_lora_int_grad_tensor)\n", - "\n", - "# # ff_w2_gradin_tensor = ff_w2_gradin_tensor + ff_lora_gradint_tensor\n", - "# #print(ff_w2_gradin_tensor[:,:24])\n", - "# print(\"calculated LORA grad in\")\n", - "# print(ff_lora_gradint_tensor[:,:24])\n", - "# # ff_BWD_w2_in_pre_tensor = np.loadtxt(ff_BWD_w2_in_pre, delimiter=',').reshape((3072,128), order='F')\n", - "# ff_BWD_lora_A_in_tensor = np.loadtxt(ff_BWD_lora_A_in, delimiter=',').reshape((3072,128), order='F')\n", - "# print(\"FlexFlow LORA grad in\")\n", - "# print(ff_BWD_lora_A_in_tensor[:,:24])\n", - "# # print(ff_BWD_w2_in_pre_tensor[:,:24])\n", - "# print(\"HF lora grad in\")\n", - "# print(torch.load(hf_BWD_loraA_in).squeeze().T.detach().cpu().numpy())\n", - "# compare_tensors(hf_BWD_loraA_in, ff_BWD_lora_A_in)\n", - "\n", - "# simulate act_fn_grad\n", - "# ssm_out_grad_tensor = np.loadtxt(ff_BWD_ssm_out, delimiter=',').reshape((3072,128), order='F')\n", - "# w3_fwd_out_tensor = np.loadtxt(ff_FWD_w3_out, delimiter=',').reshape((3072,128), order='F')\n", - "# #print(ssm_out_grad_tensor.shape, w3_fwd_out_tensor.shape)\n", - "# act_fn_out_check = np.multiply(ssm_out_grad_tensor, w3_fwd_out_tensor)\n", - "# print(\"simulated act fn out - simulated\")\n", - "# print(act_fn_out_check[:,:24])\n", - "# print(\"simulated act fn out - HF\")\n", - "# print(torch.load(hf_BWD_act_fn_out).detach().cpu().numpy().squeeze().T)\n", - "\n", - "# Simulated w3_grad\n", - "# ssm_out_grad_tensor = np.loadtxt(ff_BWD_ssm_out, delimiter=',').reshape((3072,128), order='F')[:,:24]\n", - "# act_fnc_out_tensor = np.loadtxt(ff_FWD_act_fnc_out, delimiter=',').reshape((3072,24), order='F')\n", - "# w3_out_gard_check = np.multiply(ssm_out_grad_tensor, act_fnc_out_tensor)\n", - "# print(\"simulated w3 out - FF\")\n", - "# print(w3_out_gard_check)\n", - "# ff_BWD_w3_out_tensor = np.loadtxt(ff_BWD_w3_out, delimiter=',').reshape((3072,128), order='F')\n", - "# hf_BWD_w3_out_tensor = torch.load(hf_BWD_w3_out).detach().cpu().numpy().squeeze().T\n", - "# print(\"w3 out, FF\")\n", - "# print(ff_BWD_w3_out_tensor[:,:24])\n", - "# print(\"w3 out, HF\")\n", - "# print(hf_BWD_w3_out_tensor)\n", - "\n", - "# print_tensors(hf_BWD_w3_out, ff_BWD_w3_out, \"w3 out\")\n", - "# assert False\n", - "# print()\n", - "# print()\n", - "# print_tensors(hf_BWD_w3_out, ff_BWD_w3_out, \"w3 out\")\n", - "# print_tensors(hf_BWD_w3_in, ff_BWD_w3_in, \"w3 in\")\n", - "# print_tensors(hf_BWD_w1_out, ff_BWD_w1_out, \"w1 out\")\n", - "# print_tensors(hf_BWD_w1_in, ff_BWD_w1_in, \"w1 in\")\n", - "# print_tensors(hf_BWD_ffn_norm_out, ff_BWD_ffn_norm_out, \"ffn norm out\")\n", - "# print_tensors(hf_BWD_ffn_norm_in, 
ff_BWD_ffn_norm_in2, \"ffn norm in\")\n", - "# print()\n", - "# ff_w1_out_tensor = np.loadtxt(ff_BWD_w1_out, delimiter=',').reshape((3072,128), order='F')\n", - "# ff_w1_in_tensor = np.loadtxt(ff_BWD_w1_in, delimiter=',').reshape((768,128), order='F')\n", - "# ff_w1_in_pre_tensor = np.loadtxt(ff_BWD_w1_in_pre, delimiter=',').reshape((768,128), order='F')\n", - "# ff_w1_only_in_tensor = ff_w1_in_tensor - ff_w1_in_pre_tensor\n", - "# ff_w1_weight_tensor = np.loadtxt(ff_w1_weight, delimiter=',').reshape((768,3072), order='F')\n", - "# ff_w1_in_check_tensor = np.matmul(ff_w1_weight_tensor, ff_w1_out_tensor)\n", - "# print(\"W1 in (simulated):\")\n", - "# print(ff_w1_in_check_tensor[:,:24])\n", - "# print(\"W1 in (FF):\")\n", - "# print(ff_w1_only_in_tensor[:,:24])\n", - "# print(\"W1 in (HF):\")\n", - "# print(torch.load(hf_BWD_w1_in).squeeze().T.detach().cpu().numpy())\n", - "\n", - "# compare_tensors_difference(hf_BWD_w2_in, ff_BWD_w2_in, ff_BWD_lora_A_in)\n", - "# compare_tensors(hf_BWD_w3_out, ff_BWD_w3_out)\n", - "#compare_hf_tensors(hf_BWD_ffn_norm_in, hf_BWD_attn_out_out)\n", - "# print(\"\\nw1 out:\")\n", - "\n", - "# print_tensors(hf_BWD_w1_out, ff_BWD_w1_out)\n", - "# print(\"\\nW1 in\\n\")\n", - "# print_tensors(hf_BWD_w1_in, ff_BWD_w1_in)\n", - "# compare_tensors(hf_BWD_w1_in, ff_BWD_w1_in)\n", - "# print(\"\\nffn_norm\")\n", - "# compare_tensors(hf_BWD_ffn_norm_out, ff_BWD_ffn_norm_out)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n", - "Ok!\n" - ] - } - ], - "source": [ - "for layer_num in range(12):\n", - " hf_lora_A_weight_fp = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_A.default.weight\"\n", - " ff_lora_A_weight_fp = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_A\"\n", - " compare_tensors(hf_lora_A_weight_fp, ff_lora_A_weight_fp, tolerance=1e-5)\n", - " hf_lora_B_weight_fp = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.lora_B.default.weight\"\n", - " ff_lora_B_weight_fp = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_lora_shard-id_0_weight_B\"\n", - " compare_tensors(hf_lora_B_weight_fp, ff_lora_B_weight_fp, tolerance=1e-5)\n", - " hf_w1_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.gate_proj.weight\"\n", - " ff_w1_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w1_shard-id_0_weight_0\"\n", - " compare_tensors(hf_w1_weight, ff_w1_weight, tolerance=1e-5)\n", - " hf_w3_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.up_proj.weight\"\n", - " ff_w3_weight 
= f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w3_shard-id_0_weight_0\"\n", - " compare_tensors(hf_w3_weight, ff_w3_weight, tolerance=1e-5)\n", - " hf_w2_weight = f\"{hf_weight_base_path}/base_model.model.model.layers.{layer_num}.mlp.down_proj.weight\"\n", - " ff_w2_weight = f\"{ff_weight_base_path}/model_0_decoding-step_0_layer-num_{layer_num}_layer-name_layers_{layer_num}_feed_forward_w2_shard-id_0_weight_0\"\n", - " compare_tensors(hf_w2_weight, ff_w2_weight, tolerance=1e-5)\n", - " " - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/peft/hf_finetune.py b/tests/peft/hf_finetune.py index 7836633b30..1e0e0bd167 100644 --- a/tests/peft/hf_finetune.py +++ b/tests/peft/hf_finetune.py @@ -60,6 +60,7 @@ def lm_head_pre_backward_hook(module, grad_output): def peft_backward_hook(module, grad_input, grad_output): + assert(type(grad_input) == tuple and type(grad_output) == tuple) if len(grad_input) == 0 or len(grad_output) == 0: return assert module.name is not None and module.bwd_step is not None @@ -95,23 +96,53 @@ def peft_forward_hook(module, input, output): name = module.name.replace("base_model.model.model.", "") print(f"Forward Hook activated for module: {name}, fwd step: {module.fwd_step}") print("Input:") - for i, inp in enumerate(input): - if type(inp) == torch.Tensor: - print(inp.shape) - torch.save( - inp, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.input_{i}" - ) - else: - print(inp) + if type(input) == torch.Tensor: + print(input.shape) + torch.save( + input, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.input_0" + ) + elif type(input) == tuple: + for i, inp in enumerate(input): + if type(inp) == torch.Tensor: + print(inp.shape) + torch.save( + inp, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.input_{i}" + ) + else: + print(inp) + else: + assert False print("Output:") - for i, out in enumerate(output): - if type(out) == torch.Tensor: - print(out.shape) - torch.save( - out, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.output_{i}" - ) - else: - print(out) + if type(output) == torch.Tensor: + print(output.shape) + torch.save( + output, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.output_0" + ) + # if "layer_norm" in name: + # torch.save( + # output.grad_fn._saved_result1, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.saved_result_1" + # ) + # torch.save( + # output.grad_fn._saved_result2, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.saved_result_2" + # ) + elif type(output) == tuple: + for i, out in enumerate(output): + if type(out) == torch.Tensor: + print(out.shape) + torch.save( + out, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.output_{i}" + ) + # if "layer_norm" in name: + # torch.save( + # out.grad_fn._saved_result1, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.saved_result_1" + # ) + # torch.save( + # out.grad_fn._saved_result2, f"./hf_peft_tensors/fwd_step_{module.fwd_step}_{name}.saved_result_2" + # ) + else: + print(out) + else: + assert False # print("Forward Input/Output: ", input[0].shape, output[0].shape) print("===") module.fwd_step += 1 
@@ -221,10 +252,13 @@ def main(): layer.register_full_backward_pre_hook(lm_head_pre_backward_hook) # Save any weights of interest for name, params in model.named_parameters(): + simplified_name = name.replace("base_model.model.model.", "") if "lora" in name: - torch.save(params, f"./hf_peft_tensors/{name}") + torch.save(params, f"./hf_peft_tensors/{simplified_name}") if "lm_head" in name or "norm" in name: - torch.save(params, f"./hf_peft_tensors/{name}") + torch.save(params, f"./hf_peft_tensors/{simplified_name}") + if "down_proj" in name or "self_attn" in name: + torch.save(params, f"./hf_peft_tensors/{simplified_name}") # Load fine-tuning dataset data = load_dataset("Abirate/english_quotes") diff --git a/tests/peft/hf_serve.py b/tests/peft/hf_serve.py index ad1f903cfb..1fde4d5a50 100644 --- a/tests/peft/hf_serve.py +++ b/tests/peft/hf_serve.py @@ -92,14 +92,6 @@ def main(): model = PeftModel.from_pretrained(model, peft_model_id) print(model) - for name, params in model.named_parameters(): - print(name) - if ( - name - == "base_model.model.model.layers.11.mlp.down_proj.lora_B.default.weight" - ): - print(params) - assert False # Register hooks to save tensors, if needed if save_peft_tensors: diff --git a/tests/peft/qk_prods_alignment.ipynb b/tests/peft/qk_prods_alignment.ipynb deleted file mode 100644 index c2a3644b3d..0000000000 --- a/tests/peft/qk_prods_alignment.ipynb +++ /dev/null @@ -1,24 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10.4" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 32f0a15412eabdfb45bfce48cbd489a3e5ddbac5 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Fri, 26 Jan 2024 22:45:50 -0500 Subject: [PATCH 02/18] fix legion aliasing error --- .../ops/add_bias_residual_layer_norm.h | 1 + src/ops/add_bias_residual_layer_norm.cc | 122 ++++++++---------- src/ops/linear.cc | 2 - src/ops/residual_layer_norm.cc | 51 +++----- src/ops/residual_rms_norm.cc | 58 ++++----- 5 files changed, 98 insertions(+), 136 deletions(-) diff --git a/include/flexflow/ops/add_bias_residual_layer_norm.h b/include/flexflow/ops/add_bias_residual_layer_norm.h index 38bb825a4d..550d56c47c 100644 --- a/include/flexflow/ops/add_bias_residual_layer_norm.h +++ b/include/flexflow/ops/add_bias_residual_layer_norm.h @@ -26,6 +26,7 @@ class AddBiasResidualLayerNorm : public Op { float _eps, bool allocate_weights, char const *name); + void map_output_tensors(FFModel &ff) override; void init(FFModel const &) override; void init_inference(FFModel const &, std::vector const &, diff --git a/src/ops/add_bias_residual_layer_norm.cc b/src/ops/add_bias_residual_layer_norm.cc index a2b426ec0d..6b71279971 100644 --- a/src/ops/add_bias_residual_layer_norm.cc +++ b/src/ops/add_bias_residual_layer_norm.cc @@ -348,10 +348,13 @@ void AddBiasResidualLayerNorm::init_inference( false /*must*/, 0 /*mapper_id*/, machine_view_hash); + assert(batch_outputs[0]->part == batch_inputs[0]->part); + assert(batch_outputs[0]->region == batch_inputs[0]->region); // attn output + // added: attn_output + attn final bias + residual launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); @@ -362,34 +365,27 @@ void 
AddBiasResidualLayerNorm::init_inference( EXCLUSIVE, batch_inputs[1]->region)); launcher.add_field(1, FID_DATA); - // added: attn_output + attn final bias + residual - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); // layer norm output launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[1]->region)); - launcher.add_field(3, FID_DATA); + launcher.add_field(2, FID_DATA); // attn final bias launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, weights[0]->region)); - launcher.add_field(4, FID_DATA); + launcher.add_field(3, FID_DATA); if (elementwise_affine) { launcher.add_region_requirement(RegionRequirement(weights[1]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, weights[1]->region)); - launcher.add_field(5, FID_DATA); + launcher.add_field(4, FID_DATA); if (use_bias) { launcher.add_region_requirement(RegionRequirement(weights[2]->part, @@ -397,7 +393,7 @@ void AddBiasResidualLayerNorm::init_inference( READ_ONLY, EXCLUSIVE, weights[2]->region)); - launcher.add_field(6, FID_DATA); + launcher.add_field(5, FID_DATA); } } FutureMap fm = runtime->execute_index_space(ctx, launcher); @@ -420,10 +416,13 @@ void AddBiasResidualLayerNorm::init(FFModel const &ff) { false /*must*/, 0 /*mapper_id*/, outputs[0]->machine_view.hash()); - // attn output + assert(outputs[0]->part == inputs[0]->part); + assert(outputs[0]->region == inputs[0]->region); + // input: attn output + // added: attn_output + attn final bias + residual launcher.add_region_requirement(RegionRequirement(inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, inputs[0]->region)); launcher.add_field(0, FID_DATA); @@ -434,34 +433,27 @@ void AddBiasResidualLayerNorm::init(FFModel const &ff) { EXCLUSIVE, inputs[1]->region)); launcher.add_field(1, FID_DATA); - // added: attn_output + attn final bias + residual - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); // layer norm output launcher.add_region_requirement(RegionRequirement(outputs[1]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, outputs[1]->region)); - launcher.add_field(3, FID_DATA); + launcher.add_field(2, FID_DATA); // attn final bias launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, weights[0]->region)); - launcher.add_field(4, FID_DATA); + launcher.add_field(3, FID_DATA); if (elementwise_affine) { launcher.add_region_requirement(RegionRequirement(weights[1]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, weights[1]->region)); - launcher.add_field(5, FID_DATA); + launcher.add_field(4, FID_DATA); if (use_bias) { launcher.add_region_requirement(RegionRequirement(weights[2]->part, @@ -469,7 +461,7 @@ void AddBiasResidualLayerNorm::init(FFModel const &ff) { READ_ONLY, EXCLUSIVE, weights[2]->region)); - launcher.add_field(6, FID_DATA); + launcher.add_field(5, FID_DATA); } } FutureMap fm = runtime->execute_index_space(ctx, launcher); @@ -478,13 +470,11 @@ void AddBiasResidualLayerNorm::init(FFModel const &ff) { } /* - regions[0](I): attn output - regions[1](I): residual - regions[2](O): added output (attn output + final attn bias + residual) - regions[3](O): layer norm output - regions[4](I): final 
attn bias - regions[5](I): gamma - regions[6](I): beta + regions[0](I/O): attn output AND added output (attn output + final attn bias + + residual) regions[1](I): residual regions[2](O): layer norm output + regions[3](I): final attn bias + regions[4](I): gamma + regions[5](I): beta */ OpMeta *AddBiasResidualLayerNorm::init_task( Task const *task, @@ -545,10 +535,13 @@ FutureMap AddBiasResidualLayerNorm::inference( 0 /*mapper_id*/, machine_view_hash); launcher.add_future(bc); + assert(batch_outputs[0]->part == batch_inputs[0]->part); + assert(batch_outputs[0]->region == batch_inputs[0]->region); // input + // added_output: input + attn bias + residual launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); @@ -566,20 +559,13 @@ FutureMap AddBiasResidualLayerNorm::inference( EXCLUSIVE, batch_inputs[1]->region)); launcher.add_field(2, FID_DATA); - // added_output: input + attn bias + residual - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(3, FID_DATA); // output launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[1]->region)); - launcher.add_field(4, FID_DATA); + launcher.add_field(3, FID_DATA); if (elementwise_affine) { // gamma launcher.add_region_requirement(RegionRequirement(weights[1]->part, @@ -587,7 +573,7 @@ FutureMap AddBiasResidualLayerNorm::inference( READ_ONLY, EXCLUSIVE, weights[1]->region)); - launcher.add_field(5, FID_DATA); + launcher.add_field(4, FID_DATA); if (use_bias) { // beta launcher.add_region_requirement(RegionRequirement(weights[2]->part, @@ -595,20 +581,31 @@ FutureMap AddBiasResidualLayerNorm::inference( READ_ONLY, EXCLUSIVE, weights[2]->region)); - launcher.add_field(6, FID_DATA); + launcher.add_field(5, FID_DATA); } } return runtime->execute_index_space(ctx, launcher); } +void AddBiasResidualLayerNorm::map_output_tensors(FFModel &ff) { + assert(numOutputs == 2); + assert(outputs[0]->get_volume() == inputs[0]->get_volume()); + outputs[0]->parallel_is = inputs[0]->parallel_is; + outputs[0]->region = inputs[0]->region; + outputs[0]->part = inputs[0]->part; + outputs[0]->region_grad = inputs[0]->region_grad; + outputs[0]->part_grad = inputs[0]->part_grad; + // map output 1 to new region + ff.map_tensor(outputs[1], this); +} + /* - regions[0](I): input + regions[0](I): input / added output regions[1](I): attn bias regions[2](I): residual - regions[3](O): added output - regions[4](O): output - regions[5](I): gamma - regions[6](I): beta + regions[3](O): output + regions[4](I): gamma + regions[5](I): beta */ void AddBiasResidualLayerNorm::inference_task( Task const *task, @@ -626,7 +623,7 @@ void AddBiasResidualLayerNorm::inference_task( *((AddBiasResidualLayerNormMeta **)task->local_args); assert(regions.size() == - 5 + (m->elementwise_affine ? (m->use_bias ? 2 : 1) : 0)); + 4 + (m->elementwise_affine ? (m->use_bias ? 
2 : 1) : 0)); GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); @@ -635,9 +632,9 @@ void AddBiasResidualLayerNorm::inference_task( GenericTensorAccessorR residual = helperGetGenericTensorAccessorRO( m->input_type[1], regions[2], task->regions[2], FID_DATA, ctx, runtime); GenericTensorAccessorW added_output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[3], task->regions[3], FID_DATA, ctx, runtime); + m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[1], regions[4], task->regions[4], FID_DATA, ctx, runtime); + m->output_type[1], regions[3], task->regions[3], FID_DATA, ctx, runtime); GenericTensorAccessorR gamma, beta; @@ -648,9 +645,9 @@ void AddBiasResidualLayerNorm::inference_task( Domain residual_domain = runtime->get_index_space_domain( ctx, task->regions[2].region.get_index_space()); Domain added_out_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); + ctx, task->regions[0].region.get_index_space()); Domain out_domain = runtime->get_index_space_domain( - ctx, task->regions[4].region.get_index_space()); + ctx, task->regions[3].region.get_index_space()); Domain gamma_domain, beta_domain; @@ -675,23 +672,23 @@ void AddBiasResidualLayerNorm::inference_task( if (m->elementwise_affine) { gamma = helperGetGenericTensorAccessorRO(m->weight_type[1], - regions[5], - task->regions[5], + regions[4], + task->regions[4], FID_DATA, ctx, runtime); gamma_domain = runtime->get_index_space_domain( - ctx, task->regions[5].region.get_index_space()); + ctx, task->regions[4].region.get_index_space()); if (m->use_bias) { beta = helperGetGenericTensorAccessorRO(m->weight_type[2], - regions[6], - task->regions[6], + regions[5], + task->regions[5], FID_DATA, ctx, runtime); beta_domain = runtime->get_index_space_domain( - ctx, task->regions[6].region.get_index_space()); + ctx, task->regions[5].region.get_index_space()); assert(gamma_domain == beta_domain); } @@ -723,12 +720,7 @@ void AddBiasResidualLayerNorm::inference_task( } } AddBiasResidualLayerNorm::save_inference_tensors_to_file( - m, - shard_id, - bc, - {input, residual}, - weights_accessors, - {added_output, output}); + m, shard_id, bc, {residual}, weights_accessors, {added_output, output}); } } diff --git a/src/ops/linear.cc b/src/ops/linear.cc index 209f514f65..4563673385 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -621,8 +621,6 @@ void Linear::inference_task(Task const *task, ctx, task->regions[0].region.get_index_space()); LinearMeta *m = *((LinearMeta **)task->local_args); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); - std::string op_name_without_uid = Linear::get_op_name_without_uid(m); - printf("INF %s\n", op_name_without_uid.c_str()); if (bc->num_tokens == 0) { return; } diff --git a/src/ops/residual_layer_norm.cc b/src/ops/residual_layer_norm.cc index 8563c299ab..dc302ce19c 100644 --- a/src/ops/residual_layer_norm.cc +++ b/src/ops/residual_layer_norm.cc @@ -358,11 +358,14 @@ void ResidualLayerNorm::init_inference( false /*must*/, 0 /*mapper_id*/, machine_view_hash); + assert(batch_outputs[0]->part == batch_inputs[0]->part); + assert(batch_outputs[0]->region == batch_inputs[0]->region); int field_id = 0; // input + // added: input + residual(s) launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + 
READ_WRITE, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(field_id++, FID_DATA); @@ -382,13 +385,6 @@ void ResidualLayerNorm::init_inference( batch_inputs[2]->region)); launcher.add_field(field_id++, FID_DATA); } - // added: input + residual(s) - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(field_id++, FID_DATA); // layer norm output launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, 0 /*projection id*/, @@ -433,11 +429,14 @@ void ResidualLayerNorm::init(FFModel const &ff) { false /*must*/, 0 /*mapper_id*/, outputs[0]->machine_view.hash()); + assert(outputs[0]->part == inputs[0]->part); + assert(outputs[0]->region == inputs[0]->region); int field_id = 0; // input + // added: input + residual(s) launcher.add_region_requirement(RegionRequirement(inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, inputs[0]->region)); launcher.add_field(field_id++, FID_DATA); @@ -457,13 +456,6 @@ void ResidualLayerNorm::init(FFModel const &ff) { inputs[2]->region)); launcher.add_field(field_id++, FID_DATA); } - // added: input + residual(s) - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(field_id++, FID_DATA); // layer norm output launcher.add_region_requirement(RegionRequirement(outputs[1]->part, 0 /*projection id*/, @@ -884,11 +876,14 @@ FutureMap ResidualLayerNorm::inference( 0 /*mapper_id*/, machine_view_hash); launcher.add_future(bc); + assert(batch_outputs[0]->part == batch_inputs[0]->part); + assert(batch_outputs[0]->region == batch_inputs[0]->region); int field_id = 0; // input + // added: input + residual(s) launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(field_id++, FID_DATA); @@ -908,13 +903,6 @@ FutureMap ResidualLayerNorm::inference( batch_inputs[2]->region)); launcher.add_field(field_id++, FID_DATA); } - // added: input + residual(s) - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(field_id++, FID_DATA); // layer norm output launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, 0 /*projection id*/, @@ -956,7 +944,7 @@ void ResidualLayerNorm::inference_task( } assert(regions.size() == - 4 + m->use_two_residuals + + 3 + m->use_two_residuals + (m->elementwise_affine ? (m->use_bias ? 
2 : 1) : 0)); int region_idx = 0, task_region_idx = 0; @@ -984,13 +972,8 @@ void ResidualLayerNorm::inference_task( ctx, runtime); } - GenericTensorAccessorW added_output = - helperGetGenericTensorAccessorWO(m->output_type[0], - regions[region_idx++], - task->regions[task_region_idx++], - FID_DATA, - ctx, - runtime); + GenericTensorAccessorW added_output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO(m->output_type[1], regions[region_idx++], @@ -1029,7 +1012,7 @@ void ResidualLayerNorm::inference_task( assert(residual2_domain == in_domain); } Domain added_out_domain = runtime->get_index_space_domain( - ctx, task->regions[task_region_idx++].region.get_index_space()); + ctx, task->regions[0].region.get_index_space()); Domain out_domain = runtime->get_index_space_domain( ctx, task->regions[task_region_idx++].region.get_index_space()); Domain gamma_domain, beta_domain; @@ -1069,7 +1052,7 @@ void ResidualLayerNorm::inference_task( assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; std::vector input_accessors; - input_accessors.push_back(input); + // input_accessors.push_back(input); input_accessors.push_back(residual1); if (m->use_two_residuals) { input_accessors.push_back(residual2); diff --git a/src/ops/residual_rms_norm.cc b/src/ops/residual_rms_norm.cc index c2fbe11544..fb0944cece 100644 --- a/src/ops/residual_rms_norm.cc +++ b/src/ops/residual_rms_norm.cc @@ -261,6 +261,8 @@ void ResidualRMSNorm::init(FFModel const &ff) { false /*must*/, 0 /*mapper_id*/, outputs[0]->machine_view.hash()); + assert(outputs[0]->part == inputs[0]->part); + assert(outputs[0]->region == inputs[0]->region); launcher.add_region_requirement(RegionRequirement(inputs[0]->part, 0 /*projection id*/, READ_ONLY, @@ -273,24 +275,18 @@ void ResidualRMSNorm::init(FFModel const &ff) { EXCLUSIVE, inputs[1]->region)); launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(outputs[1]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, outputs[1]->region)); - launcher.add_field(3, FID_DATA); + launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, weights[0]->region)); - launcher.add_field(4, FID_DATA); + launcher.add_field(3, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap(ff, fm); @@ -318,9 +314,11 @@ void ResidualRMSNorm::init_inference( false /*must*/, 0 /*mapper_id*/, machine_view_hash); + assert(batch_outputs[0]->part == batch_inputs[0]->part); + assert(batch_outputs[0]->region == batch_inputs[0]->region); launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); @@ -330,24 +328,18 @@ void ResidualRMSNorm::init_inference( EXCLUSIVE, batch_inputs[1]->region)); launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, 0 /*projection 
id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[1]->region)); - launcher.add_field(3, FID_DATA); + launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_ONLY, EXCLUSIVE, weights[0]->region)); - launcher.add_field(4, FID_DATA); + launcher.add_field(3, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); @@ -398,6 +390,8 @@ FutureMap 0 /*mapper_id*/, machine_view_hash); launcher.add_future(bc); + assert(batch_outputs[0]->part == batch_inputs[0]->part); + assert(batch_outputs[0]->region == batch_inputs[0]->region); launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, READ_ONLY, @@ -410,40 +404,33 @@ FutureMap EXCLUSIVE, batch_inputs[1]->region)); launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, batch_outputs[1]->region)); - launcher.add_field(3, FID_DATA); + launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, READ_WRITE, EXCLUSIVE, weights[0]->region)); - launcher.add_field(4, FID_DATA); + launcher.add_field(3, FID_DATA); return runtime->execute_index_space(ctx, launcher); } /* - regions[0](I): input1 + regions[0](I/O): input1 / residual output regions[1](I): input2 - regions[2](O): residual output - regions[3](O): output - regions[4](I/O): weight + regions[2](O): output + regions[3](I/O): weight */ void ResidualRMSNorm::inference_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - assert(task->regions.size() == 5); - assert(regions.size() == 5); + assert(task->regions.size() == 4); + assert(regions.size() == 4); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); if (bc->num_tokens == 0) { return; @@ -453,19 +440,20 @@ void ResidualRMSNorm::inference_task(Task const *task, m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorR input2 = helperGetGenericTensorAccessorRO( m->input_type[1], regions[1], task->regions[1], FID_DATA, ctx, runtime); + // residual_output is mapped to the same region as the input GenericTensorAccessorW residual_output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); + m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[1], regions[3], task->regions[3], FID_DATA, ctx, runtime); + m->output_type[1], regions[2], task->regions[2], FID_DATA, ctx, runtime); GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[4], task->regions[4], FID_DATA, ctx, runtime); + m->weight_type[0], regions[3], task->regions[3], FID_DATA, ctx, runtime); inference_kernel_wrapper( m, bc, input1, input2, weight, residual_output, output); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; ResidualRMSNorm::save_inference_tensors_to_file( - m, shard_id, bc, {input1, input2}, {weight}, {residual_output, output}); + m, shard_id, bc, {input2}, {weight}, {residual_output, output}); } 
} From c97f63a368b22363b26667a6a963fee0170aea60 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Fri, 26 Jan 2024 23:45:04 -0500 Subject: [PATCH 03/18] fix warnings --- src/ops/lora_linear.cc | 14 ++++++++++++-- src/ops/residual_rms_norm.cc | 4 ++-- src/ops/rms_norm.cc | 2 +- src/ops/sigmoid_silu_multi.cc | 4 ++-- src/runtime/model.cc | 6 +++--- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index c02bddc5a6..409c814329 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -272,8 +272,6 @@ void load_peft_from_file(DT *ptr, size_t size, int shard_id, std::string filepath) { - std::cout << "Loading LORA weight " << filepath << ", size: " << size - << ", shard: " << shard_id << std::endl; std::ifstream in(filepath, std::ios::in | std::ios::binary); if (!in.good()) { printf("Could not open file: %s\n", filepath.c_str()); @@ -360,13 +358,25 @@ void LoraLinear::register_model_task(Task const *task, std::string w1_filepath = join_path({weights_folder_filepath, lora_layername_substr + "_B_weight"}); if (dt == DT_FLOAT) { + std::cout << "Loading LORA weight " << lora_layername_substr + "_A_weight" + << ", size: " << w0_num_elements << ", shard: " << shard_id + << std::endl; load_peft_from_file( (float *)weight.w0_ptr, w0_num_elements, shard_id, w0_filepath); + std::cout << "Loading LORA weight " << lora_layername_substr + "_B_weight" + << ", size: " << w1_num_elements << ", shard: " << shard_id + << std::endl; load_peft_from_file( (float *)weight.w1_ptr, w1_num_elements, shard_id, w1_filepath); } else if (dt == DT_HALF) { + std::cout << "Loading LORA weight " << lora_layername_substr + "_A_weight" + << ", size: " << w0_num_elements << ", shard: " << shard_id + << std::endl; load_peft_from_file( (half *)weight.w0_ptr, w0_num_elements, shard_id, w0_filepath); + std::cout << "Loading LORA weight " << lora_layername_substr + "_B_weight" + << ", size: " << w1_num_elements << ", shard: " << shard_id + << std::endl; load_peft_from_file( (half *)weight.w1_ptr, w1_num_elements, shard_id, w1_filepath); } else { diff --git a/src/ops/residual_rms_norm.cc b/src/ops/residual_rms_norm.cc index fb0944cece..e549e5f6da 100644 --- a/src/ops/residual_rms_norm.cc +++ b/src/ops/residual_rms_norm.cc @@ -639,7 +639,7 @@ Legion::FutureMap launcher.add_region_requirement( RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/, - READ_WRITE, + reset_input_grads[0] ? WRITE_ONLY : READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad)); launcher.add_field(1, FID_DATA); @@ -647,7 +647,7 @@ Legion::FutureMap launcher.add_region_requirement( RegionRequirement(batch_inputs[1]->part_grad, 0 /*projection id*/, - READ_WRITE, + reset_input_grads[1] ? WRITE_ONLY : READ_WRITE, EXCLUSIVE, batch_inputs[1]->region_grad)); launcher.add_field(2, FID_DATA); diff --git a/src/ops/rms_norm.cc b/src/ops/rms_norm.cc index a1749d66af..b9c9206a00 100644 --- a/src/ops/rms_norm.cc +++ b/src/ops/rms_norm.cc @@ -548,7 +548,7 @@ Legion::FutureMap launcher.add_region_requirement( RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/, - READ_WRITE, + reset_input_grads[0] ? 
WRITE_ONLY : READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad)); launcher.add_field(1, FID_DATA); diff --git a/src/ops/sigmoid_silu_multi.cc b/src/ops/sigmoid_silu_multi.cc index c01f47aa21..c9f86c42cb 100644 --- a/src/ops/sigmoid_silu_multi.cc +++ b/src/ops/sigmoid_silu_multi.cc @@ -384,7 +384,7 @@ FutureMap launcher.add_region_requirement( RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/, - READ_WRITE, + reset_input_grads[0] ? WRITE_ONLY : READ_WRITE, EXCLUSIVE, batch_inputs[0]->region_grad)); launcher.add_field(1, FID_DATA); @@ -392,7 +392,7 @@ FutureMap launcher.add_region_requirement( RegionRequirement(batch_inputs[1]->part_grad, 0 /*projection id*/, - READ_WRITE, + reset_input_grads[1] ? WRITE_ONLY : READ_WRITE, EXCLUSIVE, batch_inputs[1]->region_grad)); launcher.add_field(2, FID_DATA); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 812a432ef1..9512a0c21a 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -5546,7 +5546,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.set_leaf(); if (pre_register) { Runtime::preregister_task_variant( - registrar, "RMS Norm Inference Task"); + registrar, "Residual RMS Norm Inference Task"); } else { if (enable_control_replication) { registrar.global_registration = false; @@ -5562,7 +5562,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.set_leaf(); if (pre_register) { Runtime::preregister_task_variant( - registrar, "RMS Norm Backward Task"); + registrar, "Residual RMS Norm Backward Task"); } else { if (enable_control_replication) { registrar.global_registration = false; @@ -5577,7 +5577,7 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar.set_leaf(); if (pre_register) { Runtime::preregister_task_variant( - registrar, "RMS Norm PEFT Backward Task"); + registrar, "Residual RMS Norm PEFT Backward Task"); } else { if (enable_control_replication) { registrar.global_registration = false; From 3d5a37c70cfb76485b35a6669b4ee90f97476bb9 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sat, 27 Jan 2024 17:31:45 -0500 Subject: [PATCH 04/18] fix --- include/flexflow/ops/kernels/lora_linear_kernels.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/flexflow/ops/kernels/lora_linear_kernels.h b/include/flexflow/ops/kernels/lora_linear_kernels.h index cf03e518fa..739b94ed22 100644 --- a/include/flexflow/ops/kernels/lora_linear_kernels.h +++ b/include/flexflow/ops/kernels/lora_linear_kernels.h @@ -19,7 +19,6 @@ class LoraLinearMeta : public OpMeta { public: LoraLinearMeta(FFHandler handle, LoraLinear const *li); ~LoraLinearMeta(void); - char op_name[MAX_OPNAME]; // PEFT related fields void *low_rank_activation; void *input_activation; From 571f0d375a6fde72267a72ded40878706ab8ab17 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 29 Jan 2024 05:39:00 +0000 Subject: [PATCH 05/18] fix pipeline parallelism --- src/runtime/inference_manager.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index 4f7d0c9632..e82347c981 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -276,9 +276,9 @@ void InferenceManager::init_operators_inference(FFModel *model) { assert(op->outputs[i]->parallel_is != IndexSpace::NO_SPACE); assert(tensor_buffer[op->outputs[i]].size() > batch_index); outputs[i] = tensor_buffer[op->outputs[i]][batch_index]; - if (i > 0) { - assert(outputs[0]->machine_view == outputs[i]->machine_view); - } + // if (i > 
0) { + // assert(outputs[0]->machine_view == outputs[i]->machine_view); + // } assert(outputs[i]->parallel_is != IndexSpace::NO_SPACE); } if (op->is_parallel_op()) { From f4a10f3316d0d9f41f2b1dcad97a1618840cfc51 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 29 Jan 2024 06:57:15 +0000 Subject: [PATCH 06/18] fix tp issue in combine op --- src/parallel_ops/combine.cc | 3 --- src/runtime/model.cc | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index 3433e2f21b..8411b42602 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -84,9 +84,6 @@ Combine::Combine(FFModel &model, dims[i] = _input->dims[i]; } assert(combine_degree > 0 && "Must use combine_degree > 0"); - std::cout << "combine_dim : " << combine_dim - << ", dims[combine_dim].degree: " << dims[combine_dim].degree - << ", combine_degree: " << combine_degree << std::endl; assert(dims[combine_dim].degree % combine_degree == 0); dims[combine_dim].degree /= combine_degree; ParallelTensorBase::update_parallel_ids(numdim, dims); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 9512a0c21a..81cf3d966d 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -3283,11 +3283,11 @@ void FFModel::create_operators_from_layers() { inputs.push_back(tensors_to_parallel_tensors[l->inputs[i]]); } Op *op = nullptr; - // add a combine before arg_topk + // add a combine before arg_topk / argmax if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && - (l->op_type == OP_ARG_TOPK || l->op_type == OP_SOFTMAX || - l->op_type == OP_ARGMAX)) { + (layer_idx == layers.size() - 1 && + (l->op_type == OP_ARG_TOPK || l->op_type == OP_ARGMAX))) { std::vector partitioned_inputs; assert(inputs.size() == 1); Combine *comb = new Combine(*this, From ca683f7fca21997e9b3c61a9f331ed6ca1c4ec81 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 29 Jan 2024 07:33:22 +0000 Subject: [PATCH 07/18] fix lora weight loading with tensor parallelism --- src/ops/lora_linear.cc | 11 ++++++----- src/runtime/inference_manager.cc | 22 ++++++++++++++++------ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index 409c814329..81dc2292f6 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -270,6 +270,7 @@ void LoraLinear::register_peft_model( template void load_peft_from_file(DT *ptr, size_t size, + bool sharded, int shard_id, std::string filepath) { std::ifstream in(filepath, std::ios::in | std::ios::binary); @@ -279,7 +280,7 @@ void load_peft_from_file(DT *ptr, assert(in.good() && "incorrect weight file path"); std::vector
host_array(size); size_t target_data_size = sizeof(DT) * size; - in.seekg(shard_id * target_data_size, in.beg); + in.seekg(sharded * shard_id * target_data_size, in.beg); in.read((char *)host_array.data(), target_data_size); size_t in_get_size = in.gcount(); @@ -362,23 +363,23 @@ void LoraLinear::register_model_task(Task const *task, << ", size: " << w0_num_elements << ", shard: " << shard_id << std::endl; load_peft_from_file( - (float *)weight.w0_ptr, w0_num_elements, shard_id, w0_filepath); + (float *)weight.w0_ptr, w0_num_elements, true, shard_id, w0_filepath); std::cout << "Loading LORA weight " << lora_layername_substr + "_B_weight" << ", size: " << w1_num_elements << ", shard: " << shard_id << std::endl; load_peft_from_file( - (float *)weight.w1_ptr, w1_num_elements, shard_id, w1_filepath); + (float *)weight.w1_ptr, w1_num_elements, false, shard_id, w1_filepath); } else if (dt == DT_HALF) { std::cout << "Loading LORA weight " << lora_layername_substr + "_A_weight" << ", size: " << w0_num_elements << ", shard: " << shard_id << std::endl; load_peft_from_file( - (half *)weight.w0_ptr, w0_num_elements, shard_id, w0_filepath); + (half *)weight.w0_ptr, w0_num_elements, true, shard_id, w0_filepath); std::cout << "Loading LORA weight " << lora_layername_substr + "_B_weight" << ", size: " << w1_num_elements << ", shard: " << shard_id << std::endl; load_peft_from_file( - (half *)weight.w1_ptr, w1_num_elements, shard_id, w1_filepath); + (half *)weight.w1_ptr, w1_num_elements, false, shard_id, w1_filepath); } else { assert(false && "Data type not supported"); } diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index e82347c981..9fe9066d6c 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -390,12 +390,22 @@ void InferenceManager::peft_bwd(FFModel *model, while (model->operators[last_op]->op_type == OP_WEIGHT && last_op > 0) { last_op -= 1; } - // Assert that the previous operator must be softmax - assert(model->operators[last_op]->op_type == OP_SOFTMAX || - model->operators[last_op]->op_type == OP_FUSED); - if (model->operators[last_op]->op_type == OP_FUSED) { - FusedOp *fused_op = static_cast(model->operators[last_op]); - assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_SOFTMAX); + if (model->config.tensor_parallelism_degree > 1) { + if (model->operators[last_op]->op_type == OP_FUSED) { + FusedOp *fused_op = static_cast(model->operators[last_op]); + assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_COMBINE); + assert(fused_op->op_op_type[fused_op->numOperators - 2] == OP_SOFTMAX); + } else { + assert(model->operators[last_op]->op_type == OP_COMBINE) + assert(model->operators[last_op-1]->op_type == OP_SOFTMAX) + } + } else { + // Assert that the previous operator must be softmax + assert(model->operators[last_op]->op_type == OP_SOFTMAX || model->operators[last_op]->op_type == OP_FUSED); + if (model->operators[last_op]->op_type == OP_FUSED) { + FusedOp *fused_op = static_cast(model->operators[last_op]); + assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_SOFTMAX); + } } for (int o = last_op; o >= 0; o--) { Op *op = model->operators[o]; From 378bdb5ba157f18d528c65aa0c7a7dba2ec26c08 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 29 Jan 2024 07:45:10 +0000 Subject: [PATCH 08/18] fixes, implement Combine::peft_bwd_task --- include/flexflow/model.h | 1 + include/flexflow/parallel_ops/combine.h | 9 +++ src/ops/lora_linear.cc | 7 +-- src/parallel_ops/combine.cc | 76 
+++++++++++++++++++++++++ src/runtime/inference_manager.cc | 7 ++- src/runtime/model.cc | 15 +++++ 6 files changed, 107 insertions(+), 8 deletions(-) diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 851fac94d2..73c985f757 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -240,6 +240,7 @@ enum TaskIDs { COMBINE_INIT_TASK_ID, COMBINE_FWD_TASK_ID, COMBINE_BWD_TASK_ID, + COMBINE_PEFT_BWD_TASK_ID, REPLICATE_INIT_TASK_ID, REPLICATE_FWD_TASK_ID, REPLICATE_BWD_TASK_ID, diff --git a/include/flexflow/parallel_ops/combine.h b/include/flexflow/parallel_ops/combine.h index 2e4fdb86a9..cca34de119 100644 --- a/include/flexflow/parallel_ops/combine.h +++ b/include/flexflow/parallel_ops/combine.h @@ -40,6 +40,11 @@ class Combine : public ParallelOp { std::vector const &, std::vector const &, MachineView const *mv = nullptr) override; + Legion::FutureMap peft_bwd(FFModel const &, + BatchConfigFuture const &bc, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; void backward(FFModel const &) override; bool get_int_parameter(PMParameter, int *) const override; bool append_parallel_op_info( @@ -56,6 +61,10 @@ class Combine : public ParallelOp { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void peft_bwd_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); template static void forward_task_with_type(Legion::Task const *task, diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc index 81dc2292f6..366eca27b7 100644 --- a/src/ops/lora_linear.cc +++ b/src/ops/lora_linear.cc @@ -268,11 +268,8 @@ void LoraLinear::register_peft_model( } template -void load_peft_from_file(DT *ptr, - size_t size, - bool sharded, - int shard_id, - std::string filepath) { +void load_peft_from_file( + DT *ptr, size_t size, bool sharded, int shard_id, std::string filepath) { std::ifstream in(filepath, std::ios::in | std::ios::binary); if (!in.good()) { printf("Could not open file: %s\n", filepath.c_str()); diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index 8411b42602..7d56d7e46b 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -275,6 +275,47 @@ void Combine::forward(FFModel const &ff) { runtime->execute_index_space(ctx, launcher); } +FutureMap Combine::peft_bwd(FFModel const &ff, + BatchConfigFuture const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + assert(batch_inputs[0]->data_type == batch_outputs[0]->data_type); + DataType data_type = batch_inputs[0]->data_type; + parallel_is = batch_outputs[0]->parallel_is; + MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + size_t machine_view_hash = view->hash(); + IndexLauncher launcher(COMBINE_PEFT_BWD_TASK_ID, + parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[0]->part_grad, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_outputs[0]->region_grad)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_inputs[0]->part_grad, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region_grad)); + launcher.add_field(1, FID_DATA); + runtime->execute_index_space(ctx, launcher); +} + void Combine::backward(FFModel const &ff) { ArgumentMap argmap; Context ctx = ff.config.lg_ctx; @@ -397,6 +438,41 @@ void Combine::forward_task_with_type(Task const *task, forward_kernel
(input_ptr, output_ptr, output_domain.get_volume()); } +void Combine::peft_bwd_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + CombineMeta const *m = *((CombineMeta **)task->local_args); + GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( + m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( + m->input_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + DataType data_type = output_grad.data_type; + assert(input_grad.data_type == data_type); + assert(output_grad.domain == input_grad.domain); + if (data_type == DT_FLOAT) { + backward_kernel(output_grad.get_float_ptr(), + input_grad.get_float_ptr(), + output_grad.domain.get_volume()); + } else if (data_type == DT_DOUBLE) { + backward_kernel(output_grad.get_double_ptr(), + input_grad.get_double_ptr(), + output_grad.domain.get_volume()); + } else if (data_type == DT_INT32) { + backward_kernel(output_grad.get_int32_ptr(), + input_grad.get_int32_ptr(), + output_grad.domain.get_volume()); + } else if (data_type == DT_INT64) { + backward_kernel(output_grad.get_int64_ptr(), + input_grad.get_int64_ptr(), + output_grad.domain.get_volume()); + } else { + assert(false && "Unsupported data type in Combine backward"); + } +} + void Combine::backward_task(Task const *task, std::vector const ®ions, Context ctx, diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index 9fe9066d6c..ae3b7eaa14 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -396,12 +396,13 @@ void InferenceManager::peft_bwd(FFModel *model, assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_COMBINE); assert(fused_op->op_op_type[fused_op->numOperators - 2] == OP_SOFTMAX); } else { - assert(model->operators[last_op]->op_type == OP_COMBINE) - assert(model->operators[last_op-1]->op_type == OP_SOFTMAX) + assert(model->operators[last_op]->op_type == OP_COMBINE); + assert(model->operators[last_op - 1]->op_type == OP_SOFTMAX); } } else { // Assert that the previous operator must be softmax - assert(model->operators[last_op]->op_type == OP_SOFTMAX || model->operators[last_op]->op_type == OP_FUSED); + assert(model->operators[last_op]->op_type == OP_SOFTMAX || + model->operators[last_op]->op_type == OP_FUSED); if (model->operators[last_op]->op_type == OP_FUSED) { FusedOp *fused_op = static_cast(model->operators[last_op]); assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_SOFTMAX); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 81cf3d966d..42283f570e 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -6726,6 +6726,21 @@ void register_flexflow_internal_tasks(Runtime *runtime, runtime->register_task_variant(registrar); } } + { + TaskVariantRegistrar registrar(COMBINE_PEFT_BWD_TASK_ID, + "Combine PEFT Backward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "Combine PEFT Backward Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } // Replicate { TaskVariantRegistrar registrar(REPLICATE_INIT_TASK_ID, "Replicate Init"); From afdae452ad1502f4f1d4ad01ca2d19380ad0fc22 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 29 Jan 2024 
07:52:33 +0000 Subject: [PATCH 09/18] fix --- src/parallel_ops/combine.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index 7d56d7e46b..7260a2745e 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -313,7 +313,7 @@ FutureMap Combine::peft_bwd(FFModel const &ff, EXCLUSIVE, batch_inputs[0]->region_grad)); launcher.add_field(1, FID_DATA); - runtime->execute_index_space(ctx, launcher); + return runtime->execute_index_space(ctx, launcher); } void Combine::backward(FFModel const &ff) { From 5660f55d8e60ccebfb02a71255ede13e4e8fdf83 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 29 Jan 2024 08:50:55 +0000 Subject: [PATCH 10/18] replicate peft bwd --- include/flexflow/model.h | 1 + include/flexflow/parallel_ops/replicate.h | 9 ++++ src/parallel_ops/replicate.cc | 65 +++++++++++++++++++++++ src/runtime/model.cc | 14 +++++ 4 files changed, 89 insertions(+) diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 73c985f757..974a079ddb 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -244,6 +244,7 @@ enum TaskIDs { REPLICATE_INIT_TASK_ID, REPLICATE_FWD_TASK_ID, REPLICATE_BWD_TASK_ID, + REPLICATE_PEFT_BWD_TASK_ID, REDUCTION_INIT_TASK_ID, REDUCTION_FWD_TASK_ID, REDUCTION_BWD_TASK_ID, diff --git a/include/flexflow/parallel_ops/replicate.h b/include/flexflow/parallel_ops/replicate.h index 65d69d8564..c27616634f 100644 --- a/include/flexflow/parallel_ops/replicate.h +++ b/include/flexflow/parallel_ops/replicate.h @@ -54,10 +54,19 @@ class Replicate : public ParallelOp { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + Legion::FutureMap peft_bwd(FFModel const &, + BatchConfigFuture const &bc, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; static void backward_task(Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void peft_bwd_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static void forward_kernel_wrapper(ReplicateMeta const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, diff --git a/src/parallel_ops/replicate.cc b/src/parallel_ops/replicate.cc index 20face74e8..701db40b49 100644 --- a/src/parallel_ops/replicate.cc +++ b/src/parallel_ops/replicate.cc @@ -273,6 +273,45 @@ void Replicate::forward(FFModel const &ff) { runtime->execute_index_space(ctx, launcher); } +FutureMap Replicate::peft_bwd(FFModel const &ff, + BatchConfigFuture const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + assert(batch_inputs[0]->data_type == batch_outputs[0]->data_type); + DataType data_type = batch_inputs[0]->data_type; + parallel_is = batch_outputs[0]->parallel_is; + MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + size_t machine_view_hash = view->hash(); + IndexLauncher launcher(REPLICATE_PEFT_BWD_TASK_ID, + parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part_grad, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_outputs[0]->region_grad)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part_grad, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region_grad)); + launcher.add_field(1, FID_DATA); + return runtime->execute_index_space(ctx, launcher); +} + void Replicate::backward(FFModel const &ff) { ArgumentMap argmap; Context ctx = ff.config.lg_ctx; @@ -378,6 +417,32 @@ void Replicate::forward_task(Task const *task, } } +void Replicate::peft_bwd_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + Domain output_grad_domain = runtime->get_index_space_domain( + ctx, task->regions[0].region.get_index_space()); + Domain input_grad_domain = runtime->get_index_space_domain( + ctx, task->regions[1].region.get_index_space()); + // Currently only support the outter most dimension + for (int i = 0; i < output_grad_domain.get_dim() - 1; i++) { + assert(output_grad_domain.lo()[i] == input_grad_domain.lo()[i]); + assert(output_grad_domain.hi()[i] == input_grad_domain.hi()[i]); + } + size_t num_elements = input_grad_domain.get_volume(); + size_t num_replicas = output_grad_domain.get_volume() / num_elements; + float const *output_grad_ptr = helperGetTensorPointerRO( + regions[0], task->regions[0], FID_DATA, ctx, runtime); + float *input_grad_ptr = helperGetTensorPointerRW( + regions[1], task->regions[1], FID_DATA, ctx, runtime); + + backward_kernel( + output_grad_ptr, input_grad_ptr, num_elements, num_replicas); +} + void Replicate::backward_task(Task const *task, std::vector const ®ions, Context ctx, diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 42283f570e..11311053e9 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -6784,6 +6784,20 @@ void register_flexflow_internal_tasks(Runtime *runtime, runtime->register_task_variant(registrar); } } + { + TaskVariantRegistrar registrar(REPLICATE_PEFT_BWD_TASK_ID, "Replicate PEFT Backward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "Replicate PEFT Backward Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } // Reduction { TaskVariantRegistrar registrar(REDUCTION_INIT_TASK_ID, "Reduction Init"); From a9bacd31ab937a364ec926c9339f970c9e918b6c Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Tue, 30 Jan 2024 05:54:43 +0000 Subject: [PATCH 11/18] fixes --- .../flexflow/ops/kernels/softmax_kernels.h | 1 + include/flexflow/parallel_ops/parallel_op.h | 2 +- src/ops/fused.cu | 1 + src/ops/kernels/softmax.cu | 25 ++++++---- src/ops/softmax.cc | 46 +++++++++++++------ src/parallel_ops/combine.cc | 30 +++++++++++- src/parallel_ops/partition.cc | 5 ++ src/parallel_ops/reduction.cc | 7 +++ src/parallel_ops/replicate.cc | 36 +++++++++------ src/runtime/model.cc | 3 +- 10 files changed, 114 
insertions(+), 42 deletions(-) diff --git a/include/flexflow/ops/kernels/softmax_kernels.h b/include/flexflow/ops/kernels/softmax_kernels.h index b3dfe4f430..0b7f1090f6 100644 --- a/include/flexflow/ops/kernels/softmax_kernels.h +++ b/include/flexflow/ops/kernels/softmax_kernels.h @@ -38,6 +38,7 @@ void backward_kernel_wrapper(SoftmaxMeta const *m, void inference_kernel_wrapper(SoftmaxMeta const *m, BatchConfig const *bc, + bool is_last_op, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad); diff --git a/include/flexflow/parallel_ops/parallel_op.h b/include/flexflow/parallel_ops/parallel_op.h index 0bf573996c..39324c2a51 100644 --- a/include/flexflow/parallel_ops/parallel_op.h +++ b/include/flexflow/parallel_ops/parallel_op.h @@ -41,7 +41,7 @@ class ParallelOp : public Op { public: Legion::LogicalPartition input_lp, output_grad_lp; std::unordered_map - inference_input_lps; + inference_input_lps, inference_output_grad_lps; }; }; // namespace FlexFlow diff --git a/src/ops/fused.cu b/src/ops/fused.cu index f6bed71f6a..55892ab7e9 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -642,6 +642,7 @@ __host__ void Kernels::Softmax::inference_kernel_wrapper( m, bc, + (op == fused->numOperators - 1), my_input_accessor[0], my_output_accessor[0], output_accessor[fused->numOutputs]); diff --git a/src/ops/kernels/softmax.cu b/src/ops/kernels/softmax.cu index 271a291b09..c8bc242af0 100644 --- a/src/ops/kernels/softmax.cu +++ b/src/ops/kernels/softmax.cu @@ -120,6 +120,7 @@ void backward_kernel_wrapper(SoftmaxMeta const *m, void inference_kernel_wrapper(SoftmaxMeta const *m, BatchConfig const *bc, + bool is_last_op, GenericTensorAccessorR const &input, GenericTensorAccessorW const &output, GenericTensorAccessorW const &output_grad) { @@ -139,11 +140,13 @@ void inference_kernel_wrapper(SoftmaxMeta const *m, output.get_float_ptr(), num_classes, stream); - checkCUDA(cudaMemcpyAsync(output_grad.get_float_ptr(), - output.get_float_ptr(), - output.domain.get_volume() * sizeof(float), - cudaMemcpyDeviceToDevice, - stream)); + if (is_last_op) { + checkCUDA(cudaMemcpyAsync(output_grad.get_float_ptr(), + output.get_float_ptr(), + output.domain.get_volume() * sizeof(float), + cudaMemcpyDeviceToDevice, + stream)); + } } else if (m->output_type[0] == DT_HALF) { Internal::inference_kernel(m, bc, @@ -151,11 +154,13 @@ void inference_kernel_wrapper(SoftmaxMeta const *m, output.get_half_ptr(), num_classes, stream); - checkCUDA(cudaMemcpyAsync(output_grad.get_half_ptr(), - output.get_half_ptr(), - output.domain.get_volume() * sizeof(half), - cudaMemcpyDeviceToDevice, - stream)); + if (is_last_op) { + checkCUDA(cudaMemcpyAsync(output_grad.get_half_ptr(), + output.get_half_ptr(), + output.domain.get_volume() * sizeof(half), + cudaMemcpyDeviceToDevice, + stream)); + } } else { assert(false && "Unsupported data type"); } diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc index 1d062b552b..cfc3cf6e40 100644 --- a/src/ops/softmax.cc +++ b/src/ops/softmax.cc @@ -355,14 +355,25 @@ FutureMap Softmax::inference(FFModel const &ff, EXCLUSIVE, batch_outputs[0]->region)); launcher.add_field(1, FID_DATA); - // we add the region below in order to copy the output to the grad tensor - launcher.add_region_requirement( - RegionRequirement(batch_outputs[0]->part_grad, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region_grad)); - launcher.add_field(2, FID_DATA); + // if this is the last operator, we add the region below in order to copy the + // 
output to the grad tensor + assert(ff.config.computationMode == COMP_MODE_INFERENCE); + int last_op = ff.operators.size() - 1; + assert(ff.operators[last_op]->op_type == OP_ARGMAX || + ff.operators[last_op]->op_type == OP_SAMPLING); + last_op -= 1; + while (ff.operators[last_op]->op_type == OP_WEIGHT && last_op > 0) { + last_op -= 1; + } + if (ff.operators[last_op] == this) { + launcher.add_region_requirement( + RegionRequirement(batch_outputs[0]->part_grad, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region_grad)); + launcher.add_field(2, FID_DATA); + } return runtime->execute_index_space(ctx, launcher); } @@ -371,8 +382,8 @@ void Softmax::inference_task(Task const *task, Context ctx, Runtime *runtime) { assert(task->regions.size() == regions.size()); - assert(regions.size() == 3); - assert(task->regions.size() == 3); + assert(regions.size() == 3 || regions.size() == 2); + bool is_last_op = (regions.size() == 3); BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); if (bc->num_tokens == 0) { return; @@ -384,9 +395,16 @@ void Softmax::inference_task(Task const *task, m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorW output_grad = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - inference_kernel_wrapper(m, bc, input, output, output_grad); + GenericTensorAccessorW output_grad; + if (is_last_op) { + output_grad = helperGetGenericTensorAccessorWO(m->output_type[0], + regions[2], + task->regions[2], + FID_DATA, + ctx, + runtime); + } + inference_kernel_wrapper(m, bc, is_last_op, input, output, output_grad); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; @@ -428,7 +446,7 @@ FutureMap Softmax::peft_bwd(FFModel const &ff, launcher.add_region_requirement( RegionRequirement(batch_outputs[0]->part_grad, 0 /*projection id*/, - READ_WRITE, + READ_ONLY, EXCLUSIVE, batch_outputs[0]->region_grad)); launcher.add_field(1, FID_DATA); diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index 7260a2745e..737998b141 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -205,6 +205,11 @@ void Combine::create_input_partition_inference( batch_outputs[0]->parallel_is, batch_inputs[0]->region, inference_input_lps[batch_inputs[0]]); + ff.create_disjoint_partition(batch_inputs[0]->num_dims, + batch_inputs[0]->dims, + batch_inputs[0]->parallel_is, + batch_outputs[0]->region_grad, + inference_output_grad_lps[batch_outputs[0]]); } FutureMap Combine::inference(FFModel const &ff, @@ -244,6 +249,25 @@ FutureMap Combine::inference(FFModel const &ff, EXCLUSIVE, batch_outputs[0]->region)); launcher.add_field(1, FID_DATA); + // if this is the last operator, we add the region below in order to copy the + // output to the grad tensor + assert(ff.config.computationMode == COMP_MODE_INFERENCE); + int last_op = ff.operators.size() - 1; + assert(ff.operators[last_op]->op_type == OP_ARGMAX || + ff.operators[last_op]->op_type == OP_SAMPLING); + last_op -= 1; + while (ff.operators[last_op]->op_type == OP_WEIGHT && last_op > 0) { + last_op -= 1; + } + if (ff.operators[last_op] == this) { + launcher.add_region_requirement( + RegionRequirement(batch_outputs[0]->part_grad, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + 
batch_outputs[0]->region_grad)); + launcher.add_field(2, FID_DATA); + } return runtime->execute_index_space(ctx, launcher); } @@ -300,7 +324,7 @@ FutureMap Combine::peft_bwd(FFModel const &ff, 0 /*mapper_id*/, machine_view_hash); launcher.add_region_requirement( - RegionRequirement(batch_outputs[0]->part_grad, + RegionRequirement(inference_output_grad_lps[batch_outputs[0]], 0 /*projection id*/, READ_ONLY, EXCLUSIVE, @@ -309,7 +333,7 @@ FutureMap Combine::peft_bwd(FFModel const &ff, launcher.add_region_requirement( RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/, - READ_WRITE, + WRITE_ONLY, EXCLUSIVE, batch_inputs[0]->region_grad)); launcher.add_field(1, FID_DATA); @@ -400,6 +424,7 @@ void Combine::forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { + printf("INF combine\n"); assert(regions.size() == 2); assert(task->regions.size() == 2); CombineMeta const *m = *((CombineMeta **)task->local_args); @@ -442,6 +467,7 @@ void Combine::peft_bwd_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { + printf("BWD combine\n"); assert(regions.size() == 2); assert(task->regions.size() == 2); CombineMeta const *m = *((CombineMeta **)task->local_args); diff --git a/src/parallel_ops/partition.cc b/src/parallel_ops/partition.cc index 353b3ce398..df3c56346c 100644 --- a/src/parallel_ops/partition.cc +++ b/src/parallel_ops/partition.cc @@ -197,6 +197,11 @@ void Repartition::create_input_partition_inference( batch_outputs[0]->parallel_is, batch_inputs[0]->region, inference_input_lps[batch_inputs[0]]); + ff.create_disjoint_partition(batch_inputs[0]->num_dims, + batch_inputs[0]->dims, + batch_inputs[0]->parallel_is, + batch_outputs[0]->region_grad, + inference_output_grad_lps[batch_outputs[0]]); } FutureMap diff --git a/src/parallel_ops/reduction.cc b/src/parallel_ops/reduction.cc index 5dca591328..2e7b4b6723 100644 --- a/src/parallel_ops/reduction.cc +++ b/src/parallel_ops/reduction.cc @@ -122,6 +122,13 @@ void Reduction::create_input_partition_inference( batch_outputs[0]->parallel_is, batch_inputs[0]->region, inference_input_lps[batch_inputs[0]]); + // output_grad_lp is an aliased partitioning along the replica dim + ff.create_aliased_partition(batch_inputs[0]->num_dims, + batch_inputs[0]->dims, + reduction_dim, + batch_inputs[0]->parallel_is, + batch_outputs[0]->region_grad, + inference_output_grad_lps[batch_outputs[0]]); } OpMeta *Reduction::init_task(Task const *task, diff --git a/src/parallel_ops/replicate.cc b/src/parallel_ops/replicate.cc index 701db40b49..e4f19faa0a 100644 --- a/src/parallel_ops/replicate.cc +++ b/src/parallel_ops/replicate.cc @@ -122,6 +122,12 @@ void Replicate::create_input_partition_inference( batch_outputs[0]->parallel_is, batch_inputs[0]->region, inference_input_lps[batch_inputs[0]]); + // output_grad_lp is a disjoint partition + ff.create_disjoint_partition(batch_inputs[0]->num_dims, + batch_inputs[0]->dims, + batch_inputs[0]->parallel_is, + batch_outputs[0]->region_grad, + inference_output_grad_lps[batch_outputs[0]]); } OpMeta *Replicate::init_task(Task const *task, @@ -274,10 +280,10 @@ void Replicate::forward(FFModel const &ff) { } FutureMap Replicate::peft_bwd(FFModel const &ff, - BatchConfigFuture const &bc, - std::vector const &batch_inputs, - std::vector const &batch_outputs, - MachineView const *mv) { + BatchConfigFuture const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { ArgumentMap argmap; Context ctx = ff.config.lg_ctx; Runtime 
*runtime = ff.config.lg_hlr; @@ -297,17 +303,19 @@ FutureMap Replicate::peft_bwd(FFModel const &ff, false /*must*/, 0 /*mapper_id*/, machine_view_hash); - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part_grad, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - batch_outputs[0]->region_grad)); + launcher.add_region_requirement( + RegionRequirement(inference_output_grad_lps[batch_outputs[0]], + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_outputs[0]->region_grad)); launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part_grad, - 0 /*projection id*/, - READ_WRITE, - EXCLUSIVE, - batch_inputs[0]->region_grad)); + launcher.add_region_requirement( + RegionRequirement(batch_inputs[0]->part_grad, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + batch_inputs[0]->region_grad)); launcher.add_field(1, FID_DATA); return runtime->execute_index_space(ctx, launcher); } diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 11311053e9..eca8c31785 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -6785,7 +6785,8 @@ void register_flexflow_internal_tasks(Runtime *runtime, } } { - TaskVariantRegistrar registrar(REPLICATE_PEFT_BWD_TASK_ID, "Replicate PEFT Backward"); + TaskVariantRegistrar registrar(REPLICATE_PEFT_BWD_TASK_ID, + "Replicate PEFT Backward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); registrar.set_leaf(); if (pre_register) { From f3a97ff3832261393b4c7f0c6231fe292b4964c9 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Wed, 31 Jan 2024 02:45:49 +0000 Subject: [PATCH 12/18] fix --- src/parallel_ops/combine.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index 737998b141..354faa5e1a 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -261,7 +261,7 @@ FutureMap Combine::inference(FFModel const &ff, } if (ff.operators[last_op] == this) { launcher.add_region_requirement( - RegionRequirement(batch_outputs[0]->part_grad, + RegionRequirement(inference_output_grad_lps[batch_outputs[0]], 0 /*projection id*/, WRITE_ONLY, EXCLUSIVE, @@ -424,9 +424,8 @@ void Combine::forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - printf("INF combine\n"); - assert(regions.size() == 2); - assert(task->regions.size() == 2); + // assert(regions.size() == 2); + // assert(task->regions.size() == 2); CombineMeta const *m = *((CombineMeta **)task->local_args); DataType data_type = m->input_type[0]; if (data_type == DT_HALF) { From e0a58bb73364660be05aa8162e960399b5f9d557 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Wed, 31 Jan 2024 23:48:02 +0000 Subject: [PATCH 13/18] fix combine and fwd-bwd pass dependencies --- include/flexflow/model.h | 1 + src/ops/softmax.cc | 1 + src/parallel_ops/allreduce.cc | 2 +- src/parallel_ops/combine.cc | 60 +++++++++++++++----------------- src/runtime/inference_manager.cc | 24 ++++--------- src/runtime/model.cc | 36 ++++++++++++++++--- 6 files changed, 69 insertions(+), 55 deletions(-) diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 974a079ddb..b3a6a85808 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -1124,6 +1124,7 @@ class FFModel { Legion::IndexSpace get_task_is(Legion::Domain const &domain) const; Legion::IndexSpace get_task_is(ParallelConfig const &pc) const; Legion::IndexSpace get_task_is(MachineView const &view) const; + bool need_to_add_combine(int 
layer_idx) const; bool is_mlp_block(int layer_idx) const; void create_operators_from_layers(); Op *create_operator_from_layer(Layer *layer, diff --git a/src/ops/softmax.cc b/src/ops/softmax.cc index cfc3cf6e40..90f77ab760 100644 --- a/src/ops/softmax.cc +++ b/src/ops/softmax.cc @@ -360,6 +360,7 @@ FutureMap Softmax::inference(FFModel const &ff, assert(ff.config.computationMode == COMP_MODE_INFERENCE); int last_op = ff.operators.size() - 1; assert(ff.operators[last_op]->op_type == OP_ARGMAX || + ff.operators[last_op]->op_type == OP_ARG_TOPK || ff.operators[last_op]->op_type == OP_SAMPLING); last_op -= 1; while (ff.operators[last_op]->op_type == OP_WEIGHT && last_op > 0) { diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc index 4478a2aedc..05c2761e3b 100644 --- a/src/parallel_ops/allreduce.cc +++ b/src/parallel_ops/allreduce.cc @@ -365,7 +365,7 @@ FutureMap AllReduce::peft_bwd(FFModel const &ff, launcher.add_region_requirement( RegionRequirement(batch_inputs[0]->part_grad, 0 /*projection id*/, - READ_WRITE, + WRITE_ONLY, EXCLUSIVE, batch_inputs[0]->region_grad)); launcher.add_field(0, FID_DATA); diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index 354faa5e1a..a328ec7cac 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -199,12 +199,18 @@ void Combine::create_input_partition_inference( assert(ff.config.computationMode == COMP_MODE_INFERENCE); assert(batch_outputs[0]->part != LogicalPartition::NO_PART); assert(batch_inputs[0]->part != LogicalPartition::NO_PART); - // input_lp is a disjoint partition + // partition batch_inputs[0]->region into inference_input_lps[batch_inputs[0]] + // according to the partitioning of batch_outputs[0] (i.e. make the + // partitioned dimension whole again by combining the partitions) ff.create_disjoint_partition(batch_outputs[0]->num_dims, batch_outputs[0]->dims, batch_outputs[0]->parallel_is, batch_inputs[0]->region, inference_input_lps[batch_inputs[0]]); + // partition batch_outputs[0]->region_grad into + // inference_output_grad_lps[batch_outputs[0]] according to the partitioning + // of batch_inputs[0] (i.e. restore the partition in the dimension that was + // combined in the forward pass) ff.create_disjoint_partition(batch_inputs[0]->num_dims, batch_inputs[0]->dims, batch_inputs[0]->parallel_is, @@ -249,25 +255,6 @@ FutureMap Combine::inference(FFModel const &ff, EXCLUSIVE, batch_outputs[0]->region)); launcher.add_field(1, FID_DATA); - // if this is the last operator, we add the region below in order to copy the - // output to the grad tensor - assert(ff.config.computationMode == COMP_MODE_INFERENCE); - int last_op = ff.operators.size() - 1; - assert(ff.operators[last_op]->op_type == OP_ARGMAX || - ff.operators[last_op]->op_type == OP_SAMPLING); - last_op -= 1; - while (ff.operators[last_op]->op_type == OP_WEIGHT && last_op > 0) { - last_op -= 1; - } - if (ff.operators[last_op] == this) { - launcher.add_region_requirement( - RegionRequirement(inference_output_grad_lps[batch_outputs[0]], - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region_grad)); - launcher.add_field(2, FID_DATA); - } return runtime->execute_index_space(ctx, launcher); } @@ -310,23 +297,28 @@ FutureMap Combine::peft_bwd(FFModel const &ff, assert(numOutputs == 1); assert(numInputs == 1); assert(batch_inputs[0]->data_type == batch_outputs[0]->data_type); - DataType data_type = batch_inputs[0]->data_type; - parallel_is = batch_outputs[0]->parallel_is; - MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + DataType data_type = inputs[0]->data_type; + + // Warning: we need to use batch_inputs[0] here, instead of the usual + // batch_outputs[0] + parallel_is = batch_inputs[0]->parallel_is; + MachineView const *view = mv ? mv : &batch_inputs[0]->machine_view; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); size_t machine_view_hash = view->hash(); IndexLauncher launcher(COMBINE_PEFT_BWD_TASK_ID, parallel_is, - TaskArgument(NULL, 0), + TaskArgument(&data_type, sizeof(DataType)), argmap, Predicate::TRUE_PRED, false /*must*/, 0 /*mapper_id*/, machine_view_hash); + launcher.add_future(bc); launcher.add_region_requirement( RegionRequirement(inference_output_grad_lps[batch_outputs[0]], 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, batch_outputs[0]->region_grad)); launcher.add_field(0, FID_DATA); @@ -424,8 +416,8 @@ void Combine::forward_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - // assert(regions.size() == 2); - // assert(task->regions.size() == 2); + assert(regions.size() == 2); + assert(task->regions.size() == 2); CombineMeta const *m = *((CombineMeta **)task->local_args); DataType data_type = m->input_type[0]; if (data_type == DT_HALF) { @@ -466,15 +458,19 @@ void Combine::peft_bwd_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - printf("BWD combine\n"); assert(regions.size() == 2); assert(task->regions.size() == 2); - CombineMeta const *m = *((CombineMeta **)task->local_args); + // CombineMeta const *m = *((CombineMeta **)task->local_args); + BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); + if (bc->num_active_peft_tokens() == 0) { + return; + } + // TODO: figure out why m->output_type[0] or m->input_type[0] are not working + DataType data_type = *((DataType *)task->args); GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( - m->output_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + data_type, regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( - m->input_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - DataType data_type = output_grad.data_type; + data_type, regions[1], task->regions[1], FID_DATA, ctx, runtime); assert(input_grad.data_type == data_type); assert(output_grad.domain == input_grad.domain); if (data_type == DT_FLOAT) { diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index ae3b7eaa14..066701f65c 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -385,28 +385,18 @@ void InferenceManager::peft_bwd(FFModel *model, int last_op = model->operators.size() - 1; // Assert that the last operator must be argmax or sampling assert(model->operators[last_op]->op_type == OP_ARGMAX || + model->operators[last_op]->op_type == OP_ARG_TOPK || model->operators[last_op]->op_type == OP_SAMPLING); last_op -= 1; while (model->operators[last_op]->op_type == OP_WEIGHT && last_op > 0) { last_op -= 1; } - if (model->config.tensor_parallelism_degree > 1) { - if (model->operators[last_op]->op_type == OP_FUSED) { - FusedOp *fused_op = static_cast(model->operators[last_op]); - assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_COMBINE); - assert(fused_op->op_op_type[fused_op->numOperators - 2] == OP_SOFTMAX); - } else { - assert(model->operators[last_op]->op_type == OP_COMBINE); - assert(model->operators[last_op - 1]->op_type == OP_SOFTMAX); - } - } else 
{ - // Assert that the previous operator must be softmax - assert(model->operators[last_op]->op_type == OP_SOFTMAX || - model->operators[last_op]->op_type == OP_FUSED); - if (model->operators[last_op]->op_type == OP_FUSED) { - FusedOp *fused_op = static_cast(model->operators[last_op]); - assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_SOFTMAX); - } + // Assert that the previous operator must be softmax + assert(model->operators[last_op]->op_type == OP_SOFTMAX || + model->operators[last_op]->op_type == OP_FUSED); + if (model->operators[last_op]->op_type == OP_FUSED) { + FusedOp *fused_op = static_cast(model->operators[last_op]); + assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_SOFTMAX); } for (int o = last_op; o >= 0; o--) { Op *op = model->operators[o]; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index eca8c31785..6d77730e47 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -3270,6 +3270,34 @@ bool FFModel::is_mlp_block(int layer_idx) const { return false; } +bool FFModel::need_to_add_combine(int layer_idx) const { + if (config.computationMode != COMP_MODE_INFERENCE || + config.tensor_parallelism_degree == 1 || layers.size() <= 2) { + return false; + } + auto const &l = layers[layer_idx]; + // softmax followed by argmax/arg_topk: add combine before softmax + if (layer_idx == layers.size() - 2) { + auto const &l_next = layers[layer_idx + 1]; + if (l->op_type == OP_SOFTMAX && + (l_next->op_type == OP_ARG_TOPK || l_next->op_type == OP_ARGMAX)) { + return true; + } else { + return false; + } + } + // argmax/arg_topk not precedent by softmax: add combine before + // argmax/arg_topk + if (layer_idx == layers.size() - 1 && + (l->op_type == OP_ARG_TOPK || l->op_type == OP_ARGMAX)) { + auto const &l_prev = layers[layer_idx - 1]; + if (l_prev->op_type == OP_SOFTMAX) { + return false; + } + return true; + } + return false; +} void FFModel::create_operators_from_layers() { std::map tensors_to_parallel_tensors; // for (auto const &l : layers) { @@ -3283,11 +3311,9 @@ void FFModel::create_operators_from_layers() { inputs.push_back(tensors_to_parallel_tensors[l->inputs[i]]); } Op *op = nullptr; - // add a combine before arg_topk / argmax - if (config.computationMode == COMP_MODE_INFERENCE && - config.tensor_parallelism_degree > 1 && - (layer_idx == layers.size() - 1 && - (l->op_type == OP_ARG_TOPK || l->op_type == OP_ARGMAX))) { + // add a combine before last arg_max / arg_topk or before second-to-last + // softmax + if (need_to_add_combine(layer_idx)) { std::vector partitioned_inputs; assert(inputs.size() == 1); Combine *comb = new Combine(*this, From 50fc13d20a8175720e031b785ec21b9a3248722d Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Wed, 31 Jan 2024 23:53:23 +0000 Subject: [PATCH 14/18] fix replicate bwd --- src/parallel_ops/replicate.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/parallel_ops/replicate.cc b/src/parallel_ops/replicate.cc index e4f19faa0a..2a3818e212 100644 --- a/src/parallel_ops/replicate.cc +++ b/src/parallel_ops/replicate.cc @@ -291,8 +291,12 @@ FutureMap Replicate::peft_bwd(FFModel const &ff, assert(numInputs == 1); assert(batch_inputs[0]->data_type == batch_outputs[0]->data_type); DataType data_type = batch_inputs[0]->data_type; - parallel_is = batch_outputs[0]->parallel_is; - MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + + // Warning: we need to use batch_inputs[0] here, instead of the usual + // batch_outputs[0] + parallel_is = batch_inputs[0]->parallel_is; + MachineView const *view = mv ? mv : &batch_inputs[0]->machine_view; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); size_t machine_view_hash = view->hash(); IndexLauncher launcher(REPLICATE_PEFT_BWD_TASK_ID, From f2c9a052ddbf4c469f2755c224d0d2faaa1509c3 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Thu, 1 Feb 2024 04:58:32 +0000 Subject: [PATCH 15/18] fix --- src/runtime/inference_manager.cc | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index 066701f65c..66c47e6559 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -391,13 +391,6 @@ void InferenceManager::peft_bwd(FFModel *model, while (model->operators[last_op]->op_type == OP_WEIGHT && last_op > 0) { last_op -= 1; } - // Assert that the previous operator must be softmax - assert(model->operators[last_op]->op_type == OP_SOFTMAX || - model->operators[last_op]->op_type == OP_FUSED); - if (model->operators[last_op]->op_type == OP_FUSED) { - FusedOp *fused_op = static_cast(model->operators[last_op]); - assert(fused_op->op_op_type[fused_op->numOperators - 1] == OP_SOFTMAX); - } for (int o = last_op; o >= 0; o--) { Op *op = model->operators[o]; if (op->op_type == OP_WEIGHT) { From cd68f5d0cf6348410b718283517e2cfa947309ee Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sat, 3 Feb 2024 16:35:54 +0000 Subject: [PATCH 16/18] let user control amount of peft memory --- include/flexflow/config.h | 1 + inference/python/incr_decoding.py | 5 ++- inference/python/spec_infer.py | 5 ++- python/flexflow/core/__init__.py | 5 ++- python/flexflow/serve/__init__.py | 32 +++++++++++++++++-- src/runtime/model.cc | 19 +++++++++-- src/runtime/model.cu | 4 +++ .../python_test_configs/generate_configs.py | 5 ++- 8 files changed, 67 insertions(+), 9 deletions(-) diff --git a/include/flexflow/config.h b/include/flexflow/config.h index 2f112d4fc9..9bb230132a 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -154,6 +154,7 @@ class FFConfig { size_t offload_reserve_space_size; DataType quantization_type; // PEFT related fields + bool enable_peft; size_t peft_activation_reserve_space_size; size_t peft_weight_reserve_space_size; // Control parallelizable dimensions diff --git a/inference/python/incr_decoding.py b/inference/python/incr_decoding.py index 4a146ab503..ed57453762 100644 --- a/inference/python/incr_decoding.py +++ b/inference/python/incr_decoding.py @@ -51,9 +51,12 @@ def get_configs(): "tensor_parallelism_degree": 1, "pipeline_parallelism_degree": 4, "offload": False, - "offload_reserve_space_size": 1024**2, + "offload_reserve_space_size": 8 * 1024, # 8GB "use_4bit_quantization": False, "use_8bit_quantization": False, + "enable_peft": False, + "peft_activation_reserve_space_size": 1024, # 1GB + "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "inference_debugging": False, "fusion": True, diff --git a/inference/python/spec_infer.py b/inference/python/spec_infer.py index c9fb5cc7bb..b31ddf4604 100644 --- a/inference/python/spec_infer.py +++ b/inference/python/spec_infer.py @@ -51,9 +51,12 @@ def get_configs(): "tensor_parallelism_degree": 2, "pipeline_parallelism_degree": 2, "offload": False, - "offload_reserve_space_size": 1024**2, + "offload_reserve_space_size": 8 * 1024, # 8GB "use_4bit_quantization": 
False, "use_8bit_quantization": False, + "enable_peft": False, + "peft_activation_reserve_space_size": 1024, # 1GB + "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "inference_debugging": False, "fusion": True, diff --git a/python/flexflow/core/__init__.py b/python/flexflow/core/__init__.py index d7b1a595d2..2614518acf 100644 --- a/python/flexflow/core/__init__.py +++ b/python/flexflow/core/__init__.py @@ -87,7 +87,10 @@ "offload": "-offload", "offload_reserve_space_size": "-offload-reserve-space-size", "use_4bit_quantization": "--4bit-quantization", - "use_8bit_quantization": "--8bit-quantization" + "use_8bit_quantization": "--8bit-quantization", + "enable_peft": "", + "peft_activation_reserve_space_size": "-peft-activation-reserve-space-size", + "peft_weight_reserve_space_size": "-peft-weight-reserve-space-size", } diff --git a/python/flexflow/serve/__init__.py b/python/flexflow/serve/__init__.py index 274b431ad8..5805670ae0 100644 --- a/python/flexflow/serve/__init__.py +++ b/python/flexflow/serve/__init__.py @@ -44,6 +44,9 @@ def init( offload_reserve_space_size: Optional[int] = None, use_4bit_quantization: Optional[bool] = None, use_8bit_quantization: Optional[bool] = None, + enable_peft: Optional[bool] = None, + peft_activation_reserve_space_size: Optional[int] = None, + peft_weight_reserve_space_size: Optional[int] = None, profiling: Optional[bool] = None, inference_debugging: Optional[bool] = None, fusion: Optional[bool] = None, @@ -68,9 +71,12 @@ def init( - tensor_parallelism_degree: the degree of parallelization in the tensor parallel dimension (using the Megatron technique), defaults to 1 - pipeline_parallelism_degree: the degree of parallelization in the pipeline parallel dimension, defaults to 1 - offload: whether to enable offloading of the weights to CPU, defaults to False - - offload_reserve_space_size: the space (in MB) to reserve on CPU for offloading, default to 1024^2 + - offload_reserve_space_size: the space (in MB) to reserve on CPU for offloading, defaults to 8 GB - use_4bit_quantization: whether to use 4-bit quantization, defaults to False - use_8bit_quantization: whether to use 8-bit quantization, defaults to False + - enable_peft: whether to enable the use of PEFT, defaults to False + - peft_activation_reserve_space_size: the space (in MB) to reserve on GPU for PEFT activations, default to 1 GB + - peft_weight_reserve_space_size: the space (in MB) to reserve on GPU for PEFT weights, default to 1 GB - profiling: whether to enable the FlexFlow profiling mode, defaults to False - inference_debugging: whether to run inference in debugging mode, saving all inputs/outputs/weights to file, defaults to False - fusion: whether to enable the FlexFlow operator fusion optimization, defaults to True @@ -98,12 +104,18 @@ def init( :type pipeline_parallelism_degree: Optional[int], optional :param offload: whether to enable offloading of the weights to CPU, defaults to False :type offload: Optional[bool], optional - :param offload_reserve_space_size: the space (in MB) to reserve on CPU for offloading, default to 1024^2 + :param offload_reserve_space_size: the space (in MB) to reserve on CPU for offloading, defaults to 8 GB :type offload_reserve_space_size: Optional[int], optional :param use_4bit_quantization: whether to use 4-bit quantization, defaults to False :type use_4bit_quantization: Optional[bool], optional :param use_8bit_quantization: whether to use 8-bit quantization, defaults to False :type use_8bit_quantization: Optional[bool], optional + :param 
enable_peft: whether to enable the use of PEFT, defaults to False + :type enable_peft: Optional[bool], optional + :param peft_activation_reserve_space_size: the space (in MB) to reserve on GPU for PEFT activations, default to 1 GB + :type peft_activation_reserve_space_size: Optional[int], optional + :param peft_weight_reserve_space_size: the space (in MB) to reserve on GPU for PEFT weights, default to 1 GB + :type peft_weight_reserve_space_size: Optional[int], optional :param profiling: whether to enable the FlexFlow profiling mode, defaults to False :type profiling: Optional[bool], optional :param inference_debugging: whether to run inference in debugging mode, saving all inputs/outputs/weights to file, defaults to False @@ -131,6 +143,9 @@ def init( offload_reserve_space_size is not None, use_4bit_quantization is not None, use_8bit_quantization is not None, + enable_peft is not None, + peft_activation_reserve_space_size is not None, + peft_weight_reserve_space_size is not None, profiling is not None, inference_debugging is not None, fusion is not None, @@ -156,6 +171,9 @@ def init( "offload_reserve_space_size": offload_reserve_space_size, "use_4bit_quantization": use_4bit_quantization, "use_8bit_quantization": use_8bit_quantization, + "enable_peft": enable_peft, + "peft_activation_reserve_space_size": peft_activation_reserve_space_size, + "peft_weight_reserve_space_size": peft_weight_reserve_space_size, "profiling": profiling, "inference_debugging": inference_debugging, "fusion": fusion, @@ -176,6 +194,8 @@ def init( "tensor_parallelism_degree", "pipeline_parallelism_degree", "offload_reserve_space_size", + "peft_activation_reserve_space_size", + "peft_weight_reserve_space_size", ] for param in positive_int_params: __check_positive_int(configs_dict, param) @@ -194,11 +214,17 @@ def init( if configs_dict.get("offload", None) is None: configs_dict["offload"] = False if configs_dict.get("offload_reserve_space_size", None) is None: - configs_dict["offload_reserve_space_size"] = 1024**2 + configs_dict["offload_reserve_space_size"] = 8*1024**3 if configs_dict.get("use_4bit_quantization", None) is None: configs_dict["use_4bit_quantization"] = False if configs_dict.get("use_8bit_quantization", None) is None: configs_dict["use_8bit_quantization"] = False + if configs_dict.get("enable_peft", None) is None: + configs_dict["enable_peft"] = False + if configs_dict.get("peft_activation_reserve_space_size", None) is None: + configs_dict["peft_activation_reserve_space_size"] = 8*1024**3 + if configs_dict.get("peft_weight_reserve_space_size", None) is None: + configs_dict["peft_weight_reserve_space_size"] = 1024**3 if configs_dict.get("profiling", None) is None: configs_dict["profiling"] = False if configs_dict.get("inference_debugging", None) is None: diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 6d77730e47..e73415faaf 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1524,8 +1524,9 @@ FFRuntime::FFRuntime(FFConfig &config) { info.offload_reserve_space_size = config.cpu_offload ? config.offload_reserve_space_size : 0; info.peft_activation_reserve_space_size = - config.peft_activation_reserve_space_size; - info.peft_weight_reserve_space_size = config.peft_weight_reserve_space_size; + config.enable_peft ? config.peft_activation_reserve_space_size : 0; + info.peft_weight_reserve_space_size = + config.enable_peft ? 
config.peft_weight_reserve_space_size : 0; info.quantization_type = config.quantization_type; info.allowTensorOpMathConversion = config.allow_tensor_op_math_conversion; argmap.set_point(*it, TaskArgument(&info, sizeof(FFInitInfo))); @@ -4062,6 +4063,7 @@ struct DefaultConfig { const static size_t offloadReserveSpaceSize = (size_t)8 * 1024 * 1024 * 1024; // 8 GB // PEFT related fields + const static bool enablePeft = false; const static size_t peftActivationReserveSpaceSize = (size_t)1 * 1024 * 1024 * 1024; // 1GB const static size_t peftWeightReserveSpaceSize = @@ -4102,6 +4104,7 @@ FFConfig::FFConfig() { cpu_offload = DefaultConfig::cpuOffload; offload_reserve_space_size = DefaultConfig::offloadReserveSpaceSize; // PEFT related fields + enable_peft = DefaultConfig::enablePeft; peft_activation_reserve_space_size = DefaultConfig::peftActivationReserveSpaceSize; peft_weight_reserve_space_size = DefaultConfig::peftWeightReserveSpaceSize; @@ -4227,6 +4230,18 @@ void FFConfig::parse_args(char **argv, int argc) { quantization_type = DT_INT8; continue; } + if ((!strcmp(argv[i], "-enable-peft"))) { + enable_peft = true; + continue; + } + if (!strcmp(argv[i], "-peft-activation-reserve-space-size")) { + peft_activation_reserve_space_size = atoll(argv[++i]) * 1024 * 1024; + continue; + } + if (!strcmp(argv[i], "-peft-weight-reserve-space-size")) { + peft_weight_reserve_space_size = atoll(argv[++i]) * 1024 * 1024; + continue; + } if ((!strcmp(argv[i], "--only-data-parallel"))) { only_data_parallel = true; continue; diff --git a/src/runtime/model.cu b/src/runtime/model.cu index 754a6b18d7..80f4fdf143 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -164,6 +164,8 @@ FFHandler handle.peft_activation_allocator = new MemoryAllocator(gpu_mem); handle.peft_activation_allocator->create_legion_instance( workspaceInst, info->peft_activation_reserve_space_size); + } else { + handle.peft_activation_allocator = nullptr; } if (info->peft_weight_reserve_space_size > 0) { @@ -188,6 +190,8 @@ FFHandler void *ptr = workspaceInst.pointer_untyped(0, sizeof(char)); handle.peft_weight_allocator = new PEFTWeightAllocator(ptr, info->peft_weight_reserve_space_size); + } else { + handle.peft_weight_allocator = nullptr; } // checkCUDA(cudaMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL diff --git a/tests/inference/python_test_configs/generate_configs.py b/tests/inference/python_test_configs/generate_configs.py index ebaadade32..b5cad16c65 100644 --- a/tests/inference/python_test_configs/generate_configs.py +++ b/tests/inference/python_test_configs/generate_configs.py @@ -14,9 +14,12 @@ "tensor_parallelism_degree": 1, "pipeline_parallelism_degree": 4, "offload": False, - "offload_reserve_space_size": 1024**2, + "offload_reserve_space_size": 8 * 1024, # 8 GB "use_4bit_quantization": False, "use_8bit_quantization": False, + "enable_peft": False, + "peft_activation_reserve_space_size": 1024, # 1GB + "peft_weight_reserve_space_size": 1024, # 1GB "profiling": False, "inference_debugging": False, "fusion": True, From 64a59d891ae3db48c8234af9bf46fadf48c4bd9b Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sat, 3 Feb 2024 17:17:56 +0000 Subject: [PATCH 17/18] only run peft_bwd if peft is enabled --- src/runtime/request_manager.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index cbb21e03e0..2eebc070d6 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -2206,7 +2206,9 @@ 
GenerationResult RequestManager::generate_incr_decoding( BatchConfigFuture bcf = prepare_next_batch(next_batch.first, next_batch.second); FutureMap fm = im->inference(llm, 0, bcf); - im->peft_bwd(llm, 0, bcf); + if (llm->config.enable_peft) { + im->peft_bwd(llm, 0, bcf); + } assert(fm.get_future_map_domain().get_volume() == 1); InferenceResultFuture irf = fm.get_future(0); batch_pipeline.push(std::make_pair(bcf, irf)); From 32a07165cf1a68e8b15c8f591a66c397888712ec Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Tue, 6 Feb 2024 05:46:25 +0000 Subject: [PATCH 18/18] fix rms norm inference region reqs --- src/ops/residual_rms_norm.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ops/residual_rms_norm.cc b/src/ops/residual_rms_norm.cc index e549e5f6da..264c12f004 100644 --- a/src/ops/residual_rms_norm.cc +++ b/src/ops/residual_rms_norm.cc @@ -394,7 +394,7 @@ FutureMap assert(batch_outputs[0]->region == batch_inputs[0]->region); launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + READ_WRITE, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); @@ -412,7 +412,7 @@ FutureMap launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, - READ_WRITE, + READ_ONLY, EXCLUSIVE, weights[0]->region)); launcher.add_field(3, FID_DATA); @@ -423,7 +423,7 @@ FutureMap regions[0](I/O): input1 / residual output regions[1](I): input2 regions[2](O): output - regions[3](I/O): weight + regions[3](I): weight */ void ResidualRMSNorm::inference_task(Task const *task, std::vector const ®ions,