format and add cont batch unit tests (will need more cases)
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Dec 23, 2024
1 parent dcbae7b commit 2597105
Showing 6 changed files with 264 additions and 112 deletions.
177 changes: 161 additions & 16 deletions tests/kernels/test_mamba_ssm_ssd.py
@@ -7,6 +7,8 @@
    mamba_chunk_scan_combined)
from vllm.platforms import current_platform

+import numpy as np

Check failure on line 10 (GitHub Actions / ruff (3.12)): tests/kernels/test_mamba_ssm_ssd.py:10:17: F401 `numpy` imported but unused

# Added by the IBM Team, 2024

# Adapted from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton
@@ -76,32 +78,118 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    return Y, final_state


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [4, 16, 32])
@pytest.mark.parametrize("dim", [128, 512])
def test_mamba_chunk_scan(dim, n_heads, itype):
device = "cuda"
# set seed
current_platform.seed_everything(0)
batch = 1 # batch_size
seqlen = 128
chunk_size = 32
d_head = dim // n_heads
def generate_random_inputs(batch_size,
seqlen,
n_heads,
d_head,
itype,
device='cuda'):

current_platform.seed_everything(0)
A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
dt = F.softplus(
torch.randn(batch, seqlen, n_heads, dtype=itype, device=device) - 4)
X = torch.randn((batch, seqlen, n_heads, d_head),
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
4)
X = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
B = torch.randn((batch, seqlen, n_heads, d_head),
B = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)
C = torch.randn((batch, seqlen, n_heads, d_head),
C = torch.randn((batch_size, seqlen, n_heads, d_head),
dtype=itype,
device=device)

return A, dt, X, B, C


+def generate_continous_batched_examples(example_lens_by_batch,
+                                        num_examples,
+                                        full_length,
+                                        last_taken,
+                                        exhausted,
+                                        n_heads,
+                                        d_head,
+                                        itype,
+                                        device='cuda'):
+
+    # this function generates random examples of a certain length,
+    # cuts them according to "example_lens_by_batch", and feeds
+    # them in continuous batches to the kernels
+
+    # generate the full-length example
+    A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
+                                            d_head, itype)

+    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
+                                                  A * dt,
+                                                  B,
+                                                  C,
+                                                  block_len=full_length // 4)
+
+    # internal function to take the next slice from each example
+    def take(example_lens):
+
+        indices = []
+        for i, l in enumerate(example_lens):

Check failure on line 134 (GitHub Actions / ruff (3.12)): tests/kernels/test_mamba_ssm_ssd.py:134:16: E741 Ambiguous variable name: `l`
+            c = last_taken.get(i, 0)
+            indices.append((c, c + l))
+            last_taken[i] = (c + l) % full_length
+            exhausted[i] = last_taken[i] == 0
+
+        return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
+                              ]).unsqueeze(0) for x in (dt, X, B, C))
+
+    def end_boundary(n):
+        return n - ((n - 1) // full_length) * full_length

+    IND_E = None
+    for i, spec in enumerate(example_lens_by_batch):
+
+        # get the (maybe partial) example seen in this cont batch
+        dt2, X2, B2, C2 = take(spec)
+
+        # get the metadata
+        cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
+        sed_idx = torch.zeros(cu_seqlens[-1],
+                              dtype=torch.int32,
+                              device=cu_seqlens.device)
+        for i, (srt, end) in enumerate(zip(
+                cu_seqlens,
+                cu_seqlens[1:],
+        )):
+            sed_idx[srt:end] = i
+
+        # for cont batch
+        # IND = np.insert(np.cumsum(spec), [0], [0])  # torch.cumsum
+        if IND_E is None:
+            IND_S = [0 for _ in range(len(spec))]
+        else:
+            IND_S = [x % full_length for x in IND_E]
+        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
+
+        yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
+               cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [4, 16, 32])
@pytest.mark.parametrize("dim", [128, 512])
@pytest.mark.parametrize("seq_len_chunk_size", [(32, 128)])
def test_mamba_chunk_scan_single_example(dim, n_heads, seq_len_chunk_size,
itype):

# this tests the kernels on a single example (no batching)

# set seed
batch_size = 1 # batch_size
seqlen, chunk_size = seq_len_chunk_size
d_head = dim // n_heads

A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
d_head, itype)

Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
B, C, chunk_size)

@@ -123,3 +211,60 @@ def test_mamba_chunk_scan(dim, n_heads, itype):
                               final_state_min[:, -1].to(torch.float32),
                               atol=1e-2,
                               rtol=1e1)


@pytest.mark.parametrize("itype", [torch.float16])
@pytest.mark.parametrize("n_heads", [4])
@pytest.mark.parametrize("dim", [64])
@pytest.mark.parametrize("seq_len_chunk_size_cases", [
64,
8,
2,
[(32, 32), (32, 32)],
])
def test_mamba_chunk_scan_batch(dim, n_heads, seq_len_chunk_size_cases, itype):

# this test with multiple examples in a continuous batch

seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
d_head = dim // n_heads

# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken = {} # map: eg -> pointer to last taken sample

Check failure on line 234 (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): Need type annotation for "last_taken" (hint: "last_taken: dict[<type>, <type>] = ...") [var-annotated]
+    exhausted = {}  # map: eg -> boolean indicating example is exhausted

Check failure on line 235 (GitHub Actions / mypy 3.9, 3.10, 3.11, 3.12): Need type annotation for "exhausted" (hint: "exhausted: dict[<type>, <type>] = ...") [var-annotated]

+    states = None
+    for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
+                                     C) in generate_continous_batched_examples(
+                                         cases, num_examples, seqlen,
+                                         last_taken, exhausted, n_heads,
+                                         d_head, itype):
+
+        Y, new_states = mamba_chunk_scan_combined(
+            X,
+            dt,
+            A,
+            B,
+            C,
+            chunk_size,
+            D=None,
+            cu_seqlens=cu_seqlens,
+            seq_idx=sed_idx,
+            return_varlen_states=True,
+            initial_states=states,
+        )
+
+        # just test the last in sequence
+        for i in range(num_examples):
+
+            # just test one dim and dstate
+            Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
+            Y_min_eg = Y_min[i][:, 0, 0]
+            torch.testing.assert_close(Y_eg, Y_min_eg, atol=1e-2, rtol=1e1)
+
+        # update states
+        states = new_states
+        for i in [i for i, clear in exhausted.items() if clear]:
+            states[i].fill_(0.)
+        exhausted = {}
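As background for the generator above, here is a minimal standalone sketch of the cutting metadata, assuming full_length = 64 and a two-example spec of (32, 32) (values chosen purely for illustration; this snippet is not part of the commit):

import torch

full_length = 64  # total length of each full example
spec = (32, 32)  # lengths each example contributes to this continuous batch

# cumulative boundaries of the flattened continuous batch
cu_seqlens = torch.tensor((0, ) + spec).cumsum(dim=0)  # tensor([ 0, 32, 64])

# per-token sequence index: tells the kernel which example owns each token
seq_idx = torch.zeros(int(cu_seqlens[-1]), dtype=torch.int32)
for i, (srt, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
    seq_idx[srt:end] = i  # -> [0]*32 + [1]*32

def end_boundary(n):
    # map a running offset into (0, full_length], so an exactly-exhausted
    # example reports full_length rather than wrapping to 0
    return n - ((n - 1) // full_length) * full_length

assert end_boundary(32) == 32  # partway through an example
assert end_boundary(64) == 64  # exactly exhausted: full_length, not 0
assert end_boundary(96) == 32  # wrapped into the next cycle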
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -353,7 +353,8 @@ def forward_cuda(
        # - also need flags to indicate if there are initial states
        # - currently we really only support the FlashAttention backend
        has_initial_states = None
-        if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata))
+        if (isinstance(attn_metadata,
+                       (FlashAttentionMetadata, XFormersMetadata))
                and attn_metadata.context_lens_tensor is not None):
            has_initial_states = attn_metadata.context_lens_tensor > 0

@@ -428,7 +429,7 @@ def forward_cuda(

        scan_output, varlen_state = mamba_chunk_scan_combined(
            hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
-                              self.head_dim),
+                               self.head_dim),
            dt.unsqueeze(0),
            self.A,
            B.view(1, seq_len, self.n_groups // self.tp_size, -1),
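The has_initial_states flag computed above gates, per sequence, whether cached SSM state is carried into the kernel. A rough sketch of that gating with made-up shapes (hypothetical tensors, not the module's actual code; it mirrors how the unit test zeroes states for exhausted examples):

import torch

num_seqs, n_heads, d_head, d_state = 2, 4, 16, 16

# prior context length per sequence, as an attention backend would report it
context_lens_tensor = torch.tensor([0, 48])
has_initial_states = context_lens_tensor > 0  # tensor([False,  True])

# cached per-sequence SSM states from a previous step
cached_states = torch.randn(num_seqs, n_heads, d_head, d_state)

# sequences starting fresh get zero initial states
initial_states = torch.where(has_initial_states[:, None, None, None],
                             cached_states,
                             torch.zeros_like(cached_states))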
