Grok-1 optimization #164
Conversation
…iling code and fix the gemm shape not being dumped correctly in multi-GPU runs
…fused_moe accuracy check test file 4) sync optimization from branch MLPerf-4.1
* First version
* Revert error. While there, add missing finalize.
* Use the correct defaults for ROCm. Increase sampling area to capture crossover.
* Scope end_sync as well.
* Guard only volatile keyword for ifndef USE_ROCM
* Document crossover
* remove scoping
* while there fix a typo
* while there remove unused variable
…the different TP setting
Would you mind renaming this file to not use a '+' symbol?
@@ -328,7 +332,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
       DESTINATION vllm
       LANGUAGE ${VLLM_GPU_LANG}
       SOURCES ${VLLM_PUNICA_EXT_SRC}
-      COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
+      COMPILE_FLAGS ${VLLM_PUNICA_GPU_eLAGS}
typo?
#triton_col_output = fused_moe_col_major(a, w1, w2, score, topk, renormalize=False)
#print(f"triton_col_output: {triton_col_output}")

assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
Please use torch.testing.assert_close.
@@ -293,6 +293,8 @@ __device__ void paged_attention_kernel(
   // This includes a reduction across the threads in the same thread group.
   float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
       q_vecs[thread_group_offset], k_vecs);
+  float max_attn_val = 30.0; //hardcoded for grok
Does this affect any other model? Can this be moved to a #define or constexpr or something?
@@ -998,4 +1000,4 @@ void paged_attention_v2(
 #undef WARP_SIZE
 #undef MAX
 #undef MIN
 #undef DIVIDE_ROUND_UP
Does this affect anything else? Everything else still needs to work.
@@ -328,7 +332,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
       DESTINATION vllm
       LANGUAGE ${VLLM_GPU_LANG}
       SOURCES ${VLLM_PUNICA_EXT_SRC}
-      COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
+      COMPILE_FLAGS ${VLLM_PUNICA_GPU_eLAGS}
This looks to be a typo; it probably does not build.
Thanks for the feedback; we will correct this in our new PR #181, based on vLLM 0.6.0.
This PR adds GROK-FP8 support in vLLM.