Merge branch 'main' into charlifu/MLPerf_opt
charlifu authored Sep 12, 2024
2 parents 11df7bd + b1c3273 commit f724002
Showing 16 changed files with 899 additions and 135 deletions.
38 changes: 38 additions & 0 deletions .buildkite/test-template.j2
@@ -0,0 +1,38 @@
{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %}
{% set docker_image_amd = "rocm/vllm-ci:$BUILDKITE_COMMIT" %}
{% set default_working_dir = "vllm/tests" %}
{% set hf_home = "/root/.cache/huggingface" %}

steps:
- label: ":docker: build image"
  depends_on: ~
  commands:
    - "docker build --build-arg max_jobs=16 --tag {{ docker_image_amd }} -f Dockerfile.rocm --progress plain ."
    - "docker push {{ docker_image_amd }}"
  key: "amd-build"
  env:
    DOCKER_BUILDKIT: "1"
  retry:
    automatic:
      - exit_status: -1  # Agent was lost
        limit: 5
      - exit_status: -10  # Agent was lost
        limit: 5
  agents:
    queue: amd

{% for step in steps %}
{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
- label: "AMD: {{ step.label }}"
  depends_on:
    - "amd-build"
  agents:
    queue: amd
  commands:
    - bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" && ")) | safe }}"
  env:
    DOCKER_BUILDKIT: "1"
  priority: 100
  soft_fail: true
{% endif %}
{% endfor %}
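The template above only emits an AMD mirror step for pipeline steps whose `mirror_hardwares` list contains `"amd"`. As a quick way to preview what it generates, the following sketch renders it locally with Jinja2; the `steps` list here is a made-up example, not the real pipeline definition.

```python
# Hypothetical local preview of the template; the step definitions below are
# illustrative only. The real pipeline supplies its own step list.
from jinja2 import Template

with open(".buildkite/test-template.j2") as f:
    template = Template(f.read())

steps = [
    {"label": "Basic Correctness Test", "mirror_hardwares": ["amd"],
     "command": "pytest -v -s basic_correctness"},
    {"label": "Docs Build", "commands": ["make html"]},  # no AMD mirror, so skipped
]
print(template.render(steps=steps))
```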
3 changes: 2 additions & 1 deletion csrc/custom/custom_ops.h
@@ -18,4 +18,5 @@ void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);
const std::string& kv_cache_dtype, double k_scale,
double v_scale);
253 changes: 160 additions & 93 deletions csrc/custom/paged_attention/attention_ll4mi.cu

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion csrc/custom/torch_bindings.cpp
@@ -19,7 +19,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, custom_ops) {
" Tensor context_lens, int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype) -> ()");
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
custom_ops.impl("paged_attention_custom", torch::kCUDA,
&paged_attention_custom);
custom_ops.def(
92 changes: 90 additions & 2 deletions csrc/layernorm_kernels.cu
@@ -21,8 +21,52 @@ using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

namespace vllm {

template <typename scalar_t>
struct __align__(16) vec8_t {
  scalar_t x, y, z, w, u, v, s, t;

  __device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {}
  __device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u,
                    scalar_t v, scalar_t s, scalar_t t)
      : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {}

  __device__ vec8_t operator*(const vec8_t& other) const {
    return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w,
                  u * other.u, v * other.v, s * other.s, t * other.t);
  }

  __device__ vec8_t operator*(const float& scale) const {
    return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale,
                  v * scale, s * scale, t * scale);
  }

  __device__ vec8_t operator+(const vec8_t& other) const {
    return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w,
                  u + other.u, v + other.v, s + other.s, t + other.t);
  }

  __device__ void operator+=(const vec8_t& other) {
    x += other.x;
    y += other.y;
    z += other.z;
    w += other.w;
    u += other.u;
    v += other.v;
    s += other.s;
    t += other.t;
  }

  __device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
@@ -31,6 +75,49 @@ __global__ void rms_norm_kernel(
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;

  vec8_t<scalar_t> v8_variance = {0, 0, 0, 0, 0, 0, 0, 0};

  vec8_t<scalar_t>* vectorized_out = reinterpret_cast<vec8_t<scalar_t>*>(out);
  vec8_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec8_t<scalar_t> const*>(input);
  vec8_t<scalar_t> const* vectorized_weight =
      reinterpret_cast<vec8_t<scalar_t> const*>(weight);
  const int vec_hidden_size = hidden_size >> 3;

  // Compute variance. Note: hidden_size must be a multiple of 8 so the
  // vec8_t path covers every element.
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    vec8_t<scalar_t> x = vectorized_in[blockIdx.x * vec_hidden_size + idx];
    v8_variance += x * x;
  }
  float v8_variance_sum = v8_variance.sum();

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float variance =
      BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    vec8_t<scalar_t> v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx];
    vec8_t<scalar_t> v8_w = vectorized_weight[idx];
    vectorized_out[blockIdx.x * vec_hidden_size + idx] =
        v8_in * s_variance * v8_w;
  }
}
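For a quick numerical check of the vectorized kernel above, a plain PyTorch reference of the same computation can be useful. The sketch below is not part of the commit; it simply mirrors the RMSNorm math and assumes hidden_size is a multiple of 8 so the vec8_t path covers the whole row.

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: [num_tokens, hidden_size], weight: [hidden_size]
    assert x.shape[-1] % 8 == 0, "vectorized kernel assumes hidden_size % 8 == 0"
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
```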

template <typename scalar_t>
__global__ void scaled_rms_norm_kernel(
    hip_fp8* __restrict__ out,            // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* scale, const float epsilon, const int num_tokens,
    const int hidden_size, const int hidden_size_padded) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
@@ -49,8 +136,9 @@ __global__ void rms_norm_kernel(

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
    x = (x * s_variance) * (float)weight[idx] / (*scale);

    out[blockIdx.x * hidden_size_padded + idx] = hip_fp8(x);
  }
}
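Read together with its hunk header, this second loop fuses RMSNorm with division by a quantization scale and a cast to FP8, writing into a padded output row. The PyTorch sketch below is an editorial reference, not code from the commit; `torch.float8_e4m3fnuz` is assumed here as the ROCm-side FP8 format corresponding to `hip_fp8`.

```python
import torch

def scaled_rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor,
                        eps: float, hidden_size_padded: int) -> torch.Tensor:
    # x: [num_tokens, hidden_size]; the output is padded to hidden_size_padded columns.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    y = x.float() * torch.rsqrt(variance + eps) * weight.float() / scale
    out = torch.zeros(*x.shape[:-1], hidden_size_padded, dtype=torch.float32)
    out[..., : x.shape[-1]] = y
    return out.to(torch.float8_e4m3fnuz)  # assumed FP8 dtype; padding stays zero
```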

2 changes: 0 additions & 2 deletions csrc/quantization/fp8/amd/quant_utils.cuh
@@ -378,8 +378,6 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
return __float2bfloat16(f * scale);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
12 changes: 6 additions & 6 deletions vllm/_custom_ops.py
@@ -702,13 +702,13 @@ def paged_attention_custom(
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: int,
    v_scale: int,
) -> None:
    torch.ops._custom_C.paged_attention_custom(out, exp_sum, max_logits,
                                               tmp_out, query, key_cache,
                                               value_cache, num_kv_heads,
                                               scale, block_tables, seq_lens,
                                               block_size, max_seq_len,
                                               alibi_slopes, kv_cache_dtype)
    torch.ops._custom_C.paged_attention_custom(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale)
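The wrapper now forwards two KV-cache dequantization scales to the custom op. As a hedged usage sketch (not from this commit): with an unquantized cache (`kv_cache_dtype == "auto"`) the scales are effectively no-ops and a caller would pass 1.0; with an FP8 cache they would carry the per-tensor scales for keys and values.

```python
# Illustrative call only; argument values are placeholders.
k_scale = v_scale = 1.0  # typical when kv_cache_dtype == "auto"
ops.paged_attention_custom(out, exp_sum, max_logits, tmp_out, query, key_cache,
                           value_cache, num_kv_heads, scale, block_tables,
                           seq_lens, block_size, max_seq_len, alibi_slopes,
                           "auto", k_scale, v_scale)
```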


def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int,
32 changes: 9 additions & 23 deletions vllm/attention/ops/paged_attn.py
@@ -16,7 +16,7 @@

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256
_PARTITION_SIZE_CUSTOM = 512


@dataclass
@@ -122,11 +122,8 @@ def forward_decode(
        num_seqs, num_heads, head_size = query.shape
        gqa_ratio = num_heads // num_kv_heads
        use_custom = (custom_attn_available
                      and (query.dtype == torch.half
                           or query.dtype == torch.bfloat16)
                      and (head_size == 128 or head_size == 64)
                      and (block_size == 16 or block_size == 32)
                      and kv_cache_dtype == "auto"
                      and query.dtype in (torch.half, torch.bfloat16)
                      and head_size in (64, 128) and block_size in (16, 32)
                      and (gqa_ratio >= 1 and gqa_ratio <= 16)
                      and max_seq_len <= 32768)
        if not use_custom:
@@ -208,23 +205,12 @@ def forward_decode(
                blocksparse_head_sliding_step,
            )
        else:
            ops.paged_attention_custom(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
            )
            ops.paged_attention_custom(output, exp_sums, max_logits,
                                       tmp_output, query, key_cache,
                                       value_cache, num_kv_heads, scale,
                                       block_tables, seq_lens, block_size,
                                       max_seq_len, alibi_slopes,
                                       kv_cache_dtype, k_scale, v_scale)
        return output

    @staticmethod
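Earlier in this file the custom-kernel partition size moves from 256 to 512. Below is a rough sketch of how that constant typically sizes the decode workspace (the `exp_sums`, `max_logits`, and `tmp_output` buffers passed above); it mirrors the usual paged-attention v2 layout and is an editorial illustration, not code from the commit.

```python
import torch

_PARTITION_SIZE_CUSTOM = 512  # value set by this diff

def decode_workspace(num_seqs: int, num_heads: int, head_size: int,
                     max_seq_len: int, device: str = "cuda",
                     dtype: torch.dtype = torch.float16):
    # One slot per (sequence, head, KV-sequence partition).
    max_num_partitions = -(-max_seq_len // _PARTITION_SIZE_CUSTOM)  # ceiling division
    exp_sums = torch.empty(num_seqs, num_heads, max_num_partitions,
                           dtype=torch.float32, device=device)
    max_logits = torch.empty_like(exp_sums)
    tmp_output = torch.empty(num_seqs, num_heads, max_num_partitions, head_size,
                             dtype=dtype, device=device)
    return exp_sums, max_logits, tmp_output
```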
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
@@ -1443,7 +1443,7 @@ def _process_model_outputs(self,
            return None

    def _advance_to_next_step(
            self, output: List[SamplerOutput],
            self, output: SamplerOutput,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
        """Given model output from a single run, append the tokens to the
13 changes: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
{
  "8": {
    "BLOCK_SIZE_M": 16,
    "BLOCK_SIZE_N": 128,
    "BLOCK_SIZE_K": 256,
    "GROUP_SIZE_M": 1,
    "num_warps": 8,
    "num_stages": 0,
    "waves_per_eu": 0,
    "matrix_instr_nonkdim": 16,
    "kpack": 2
  }
}
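The new JSON above reads like a Triton kernel tuning table: the key is a candidate problem size and the value is a set of launch parameters (block sizes, warps, and ROCm-specific knobs such as waves_per_eu and kpack). A hedged sketch of how such a table is typically consumed follows; the file name and nearest-key selection rule are illustrative assumptions, not taken from this commit.

```python
import json

def load_kernel_config(path: str, num_tokens: int) -> dict:
    # Keys are candidate sizes; pick the tuned entry closest to the runtime size.
    with open(path) as f:
        table = {int(k): v for k, v in json.load(f).items()}
    best_key = min(table, key=lambda k: abs(k - num_tokens))
    return table[best_key]

# Example (hypothetical file name):
# config = load_kernel_config("tuned_kernel_config.json", num_tokens=8)
```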
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -39,6 +39,7 @@
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
"Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
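The registry entry added here maps the Grok1 architecture name to a module and class name. As a hedged illustration (not code from this commit) of how such a tuple is typically resolved at load time:

```python
import importlib

def resolve_model_cls(module_name: str, class_name: str):
    # e.g. resolve_model_cls("grok1", "Grok1ForCausalLM")
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)
```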
1 change: 0 additions & 1 deletion vllm/model_executor/models/commandr.py
@@ -50,7 +50,6 @@
from .interfaces import SupportsLoRA


@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)