From 2597105d7e89b23f46606823059c080682885c9a Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Mon, 23 Dec 2024 01:28:56 +0000 Subject: [PATCH] format and add cont batch unit tests (will need more cases) Signed-off-by: Yu Chin Fabian Lim --- tests/kernels/test_mamba_ssm_ssd.py | 177 ++++++++++++++++-- .../layers/mamba/mamba_mixer2.py | 5 +- .../layers/mamba/ops/ssd_chunk_scan.py | 119 ++++++------ .../layers/mamba/ops/ssd_chunk_state.py | 49 +++-- .../layers/mamba/ops/ssd_combined.py | 24 ++- .../layers/mamba/ops/ssd_state_passing.py | 2 +- 6 files changed, 264 insertions(+), 112 deletions(-) diff --git a/tests/kernels/test_mamba_ssm_ssd.py b/tests/kernels/test_mamba_ssm_ssd.py index 328a91459ff24..d9b1766f1f2f6 100644 --- a/tests/kernels/test_mamba_ssm_ssd.py +++ b/tests/kernels/test_mamba_ssm_ssd.py @@ -7,6 +7,8 @@ mamba_chunk_scan_combined) from vllm.platforms import current_platform +import numpy as np + # Added by the IBM Team, 2024 # Adapted from https://github.com/state-spaces/mamba/tree/main/mamba_ssm/ops/triton @@ -76,32 +78,118 @@ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): return Y, final_state -@pytest.mark.parametrize("itype", - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("n_heads", [4, 16, 32]) -@pytest.mark.parametrize("dim", [128, 512]) -def test_mamba_chunk_scan(dim, n_heads, itype): - device = "cuda" - # set seed - current_platform.seed_everything(0) - batch = 1 # batch_size - seqlen = 128 - chunk_size = 32 - d_head = dim // n_heads +def generate_random_inputs(batch_size, + seqlen, + n_heads, + d_head, + itype, + device='cuda'): + current_platform.seed_everything(0) A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device))) dt = F.softplus( - torch.randn(batch, seqlen, n_heads, dtype=itype, device=device) - 4) - X = torch.randn((batch, seqlen, n_heads, d_head), + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - + 4) + X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) - B = torch.randn((batch, seqlen, n_heads, d_head), + B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) - C = torch.randn((batch, seqlen, n_heads, d_head), + C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + return A, dt, X, B, C + + +def generate_continous_batched_examples(example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device='cuda'): + + # this function generates a random examples of certain length + # and then cut according to "example_lens_by_batch" and feed + # them in continuous batches to the kernels + + # generate the full-length example + A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads, + d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), + A * dt, + B, + C, + block_len=full_length // 4) + + # internal function to + def take(example_lens): + + indices = [] + for i, l in enumerate(example_lens): + c = last_taken.get(i, 0) + indices.append((c, c + l)) + last_taken[i] = (c + l) % full_length + exhausted[i] = last_taken[i] == 0 + + return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices) + ]).unsqueeze(0) for x in (dt, X, B, C)) + + def end_boundary(n): + return n - ((n - 1) // full_length) * full_length + + IND_E = None + for i, spec in enumerate(example_lens_by_batch): + + # get the (maybe partial) example seen in this cont batch + dt2, X2, B2, C2 = take(spec) + + # get the 
metadata + cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0) + sed_idx = torch.zeros(cu_seqlens[-1], + dtype=torch.int32, + device=cu_seqlens.device) + for i, (srt, end) in enumerate(zip( + cu_seqlens, + cu_seqlens[1:], + )): + sed_idx[srt:end] = i + + # for cont batch + # IND = np.insert(np.cumsum(spec), [0], [0]) # torch.cumsum + if IND_E is None: + IND_S = [0 for _ in range(len(spec))] + else: + IND_S = [x % full_length for x in IND_E] + IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] + + yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)], + cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2)) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [4, 16, 32]) +@pytest.mark.parametrize("dim", [128, 512]) +@pytest.mark.parametrize("seq_len_chunk_size", [(32, 128)]) +def test_mamba_chunk_scan_single_example(dim, n_heads, seq_len_chunk_size, + itype): + + # this tests the kernels on a single example (no batching) + + # set seed + batch_size = 1 # batch_size + seqlen, chunk_size = seq_len_chunk_size + d_head = dim // n_heads + + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, + d_head, itype) + Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt, B, C, chunk_size) @@ -123,3 +211,60 @@ def test_mamba_chunk_scan(dim, n_heads, itype): final_state_min[:, -1].to(torch.float32), atol=1e-2, rtol=1e1) + + +@pytest.mark.parametrize("itype", [torch.float16]) +@pytest.mark.parametrize("n_heads", [4]) +@pytest.mark.parametrize("dim", [64]) +@pytest.mark.parametrize("seq_len_chunk_size_cases", [ + 64, + 8, + 2, + [(32, 32), (32, 32)], +]) +def test_mamba_chunk_scan_batch(dim, n_heads, seq_len_chunk_size_cases, itype): + + # this test with multiple examples in a continuous batch + + seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases + d_head = dim // n_heads + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken = {} # map: eg -> pointer to last taken sample + exhausted = {} # map: eg -> boolean indicating example is exhausted + + states = None + for Y_min, cu_seqlens, sed_idx, (A, dt, X, B, + C) in generate_continous_batched_examples( + cases, num_examples, seqlen, + last_taken, exhausted, n_heads, + d_head, itype): + + Y, new_states = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=sed_idx, + return_varlen_states=True, + initial_states=states, + ) + + # just test the last in sequence + for i in range(num_examples): + + # just test one dim and dstate + Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] + Y_min_eg = Y_min[i][:, 0, 0] + torch.testing.assert_close(Y_eg, Y_min_eg, atol=1e-2, rtol=1e1) + + # update states + states = new_states + for i in [i for i, clear in exhausted.items() if clear]: + states[i].fill_(0.) 
+ exhausted = {} diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 0b3f9f1028753..e64f8fb2210b2 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -353,7 +353,8 @@ def forward_cuda( # - also need flags to indicate if there are initial states # - currently we really only support the FlashAttention backend has_initial_states = None - if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata)) + if (isinstance(attn_metadata, + (FlashAttentionMetadata, XFormersMetadata)) and attn_metadata.context_lens_tensor is not None): has_initial_states = attn_metadata.context_lens_tensor > 0 @@ -428,7 +429,7 @@ def forward_cuda( scan_output, varlen_state = mamba_chunk_scan_combined( hidden_states.view(1, seq_len, self.num_heads // self.tp_size, - self.head_dim), + self.head_dim), dt.unsqueeze(0), self.A, B.view(1, seq_len, self.n_groups // self.tp_size, -1), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index a548f11207baa..27b53f334336e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -210,7 +210,7 @@ def _chunk_scan_fwd_kernel( # - logic in next block may override these if there is an active offset offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head - prev_states_hdim = stride_states_hdim + prev_states_hdim = stride_states_hdim prev_states_dstate = stride_states_dstate chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) @@ -219,8 +219,8 @@ def _chunk_scan_fwd_kernel( # - seq_idx_prev points to be previous (possibly logical) chunk. 
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=pid_c>= 1, - other=0) + mask=pid_c >= 1, + other=0) if HAS_INITSTATES: # if there are init states, we only need seq_idx_m to point @@ -229,20 +229,21 @@ def _chunk_scan_fwd_kernel( # get current seq idx if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: seq_idx_m = tl.load( - seq_idx_ptr + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, - ) + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) # - recall that in ssd_state_passing, for the case c_off == 0 - # i.e., the very first sequence, we made states_ptr hold its inital state + # i.e., the very first sequence, we made states_ptr hold its initial state # so this edge case is taken care of - if ( - (c_off == 0) and (seq_idx_prev != seq_idx_m) # if a seq is changed exactly on boundary - or (c_off > 0) # implies a new example (pseudo chunk) - ): + if ((c_off == 0) and + (seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): # - replace prev_states_ptr with init_states prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides + prev_states_hdim = stride_init_states_hdim # override strides prev_states_dstate = stride_init_states_dstate offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -258,15 +259,16 @@ def _chunk_scan_fwd_kernel( # get the c_idx for the next (logica) chunk c_idx_n = tl.load( - chunk_indices_ptr + (pid_c+1), - mask=pid_c > -1 and (pid_c+1) < chunk_meta_num, other=-1 # to trigger different chunk + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1 # to trigger different chunk ) # - there are things to consider - # A. if c_off > 0 then we need to move the dA_cs bounary to ensure correct + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct # contribution of past states - # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to - # encroach into the next sequence, where c_off_n is the offset of the next + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next # (logical) chunk. # An equivalent check for B is c_idx == c_idx_n, where there is repetition in # (logical) chunk indices. 
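# For intuition, a small worked sketch (illustrative values, not part of the
# kernel): with chunk_size = 4 and seq_idx = [0,0,0,0, 0,0,1,1, 1,1], the
# host-side loop in _chunk_scan_fwd further below (run when initial_states is
# given) yields chunk_indices = [0, 1, 1, 2] and chunk_offsets = [0, 0, 2, 0],
# i.e. physical chunk 1 is split into two logical chunks: (c_idx=1, c_off=0)
# for the tail of sequence 0 and (c_idx=1, c_off=2) for the head of sequence 1.
# For the first of these, c_idx == c_idx_n with c_off_n = 2, so
# chunk_size_limit is clipped down to 2 and the block never reads tokens that
# belong to sequence 1.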
@@ -274,10 +276,9 @@ def _chunk_scan_fwd_kernel( if (c_idx == c_idx_n) or c_off > 0: # get the next offset - c_off_n = tl.load( - chunk_offsets_ptr + (pid_c+1), - mask=pid_c > -1 and (pid_c+1) < chunk_meta_num, other=chunk_size - ) + c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size) # in this case, adjust down the chunk_size_limit if c_idx == c_idx_n: @@ -286,8 +287,9 @@ def _chunk_scan_fwd_kernel( # get the cs at the offset boundary # - c_off == 0 is a passthrough dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + (pid_m * BLOCK_SIZE_M + c_off -1) * stride_dA_cs_csize, - mask=(pid_m * BLOCK_SIZE_M + c_off -1) > -1, + dA_cumsum_ptr + + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, + mask=(pid_m * BLOCK_SIZE_M + c_off - 1) > -1, other=0.0).to(tl.float32) if HAS_SEQ_IDX: @@ -297,7 +299,6 @@ def _chunk_scan_fwd_kernel( mask=offs_m < chunk_size_limit, other=-1) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Without the if (pid_c > -1), with Triton 2.1.0, I get @@ -309,18 +310,19 @@ def _chunk_scan_fwd_kernel( 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - + prev_states_ptrs = prev_states_ptr + ( offs_n[None, :] * prev_states_hdim + offs_k_dstate[:, None] * prev_states_dstate) if HAS_SEQ_IDX: if not HAS_INITSTATES: - # - this is for continous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + 0.0) else: # - if there is initstates, we will rely on prev_states, no zeroing - # reqiured. + # required. 
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) else: scale_m = tl.exp(dA_cs_m) @@ -329,7 +331,7 @@ def _chunk_scan_fwd_kernel( mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), @@ -435,17 +437,18 @@ def _chunk_scan_fwd_kernel( (offs_out_n[None, :] < hdim)) -def _chunk_scan_fwd(cb, - x, - dt, - dA_cumsum, - C, - states, - D=None, - z=None, - seq_idx=None, - initial_states=None, - ): +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + initial_states=None, +): batch, seqlen, nheads, headdim = x.shape _, _, nchunks, chunk_size = dt.shape _, _, ngroups, dstate = C.shape @@ -465,10 +468,11 @@ def _chunk_scan_fwd(cb, assert seq_idx.shape == (batch, seqlen) if initial_states is not None: - # with initial states, we need to take care of how + # with initial states, we need to take care of how # seq_idx crosses the boundaries assert batch == 1, "chunk scan only supports initial states with batch 1" - assert initial_states.shape == (seq_idx[0].max()+1, nheads, headdim, dstate) + assert initial_states.shape == (seq_idx[0].max() + 1, nheads, + headdim, dstate) if initial_states.shape[0] == 1: # no in this case no point to use initial states @@ -480,16 +484,20 @@ def _chunk_scan_fwd(cb, o = i % chunk_size c = idx > p if o == 0 or c: - # this means we have a change in sequence + # this means we have a change in sequence # - that does not accur on the chunk boundary chunk_indices.append(i // chunk_size) chunk_offsets.append(o) if c: - p = idx # new sequence + p = idx # new sequence - chunk_indices = torch.tensor(chunk_indices, dtype=torch.int, device=seq_idx.device) - chunk_offsets = torch.tensor(chunk_offsets, dtype=torch.int, device=seq_idx.device) + chunk_indices = torch.tensor(chunk_indices, + dtype=torch.int, + device=seq_idx.device) + chunk_offsets = torch.tensor(chunk_offsets, + dtype=torch.int, + device=seq_idx.device) # Allocates output. 
out = torch.empty(batch, @@ -509,13 +517,10 @@ def _chunk_scan_fwd(cb, else: out_x = None - - grid = lambda META: (triton.cdiv( - chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - headdim, META['BLOCK_SIZE_N']), - batch * nchunks if chunk_offsets is None else len(chunk_offsets), - nheads - ) + grid = lambda META: ( + triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + headdim, META['BLOCK_SIZE_N']), batch * nchunks + if chunk_offsets is None else len(chunk_offsets), nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel[grid]( @@ -576,12 +581,10 @@ def _chunk_scan_fwd(cb, states.stride(2), states.stride(3), states.stride(4), - *( - ( - initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2), initial_states.stride(3) - ) if initial_states is not None else (0, 0, 0, 0) - ), + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), D.stride(0) if D is not None else 0, True, D is not None, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 731e350399b59..59bb852e4b54d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -502,15 +502,10 @@ def _chunk_state_varlen_kernel( # If HAS_INITSTATES==True need to consider two possiblties # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs # - if state_idx >= pid * chunk_size, then we need to insert initstates - if ( - (start_idx < pid_c * chunk_size) # first chunk - or - ( - HAS_INITSTATES - ) - ): + if ((start_idx < pid_c * chunk_size) # first chunk + or (HAS_INITSTATES)): - dA_cs_boundary = 0.0 # default + dA_cs_boundary = 0.0 # default if not HAS_INITSTATES: past_states_ptrs = chunk_states_ptr + ( @@ -525,20 +520,21 @@ def _chunk_state_varlen_kernel( offs_n[None, :] * stride_chunk_states_dstate) else: past_states_ptrs = initstates_ptr + ( - pid_b * stride_init_states_batch + + pid_b * stride_init_states_batch + offs_m[:, None] * stride_init_states_hdim + offs_n[None, :] * stride_init_states_dstate) # need to adjust the boundary - if start_idx > pid_c * chunk_size: - dA_cs_boundary = tl.load( - dA_cumsum_ptr + (start_idx - pid_c * chunk_size - 1) * - stride_dA_cs_csize).to(tl.float32) + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load(dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - + 1) * stride_dA_cs_csize).to( + tl.float32) past_states = tl.load(past_states_ptrs, - mask=(offs_m[:, None] < hdim) & - (offs_n[None, :] < dstate), - other=0.0).to(tl.float32) + mask=(offs_m[:, None] < hdim) & + (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) scale = tl.exp(dA_cs_last - dA_cs_boundary) acc += past_states * scale @@ -680,7 +676,13 @@ def _chunk_state_fwd(B, return states -def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None): +def chunk_state_varlen(B, + x, + dt, + dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None): total_seqlen, nheads, headdim = x.shape _, nchunks, chunk_size = dt.shape _, ngroups, dstate = B.shape @@ -738,12 +740,9 @@ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_st states.stride(1), states.stride(2), states.stride(3), - *( - ( - initial_states.stride(0), initial_states.stride(1), - initial_states.stride(2), 
initial_states.stride(3) - ) if initial_states is not None else (0, 0, 0, 0) - ), - HAS_INITSTATES=initial_states is not None - ) + *((initial_states.stride(0), initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3)) if initial_states is not None else + (0, 0, 0, 0)), + HAS_INITSTATES=initial_states is not None) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 361190a6ed409..1f10e86cddd92 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -95,9 +95,9 @@ def _mamba_chunk_scan_combined_fwd(x, # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states # ii) seq_idx and iii) has_cu_seqlens to be all specified. - # - When a new seq_idx is detected, we will stopp passing the prev_state + # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. - # - this will ensure that states will be updated with the righmost flushed seq_idx + # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. states, final_states = _state_passing_fwd( @@ -121,10 +121,10 @@ def _mamba_chunk_scan_combined_fwd(x, # 5. Scan and compute the diagonal blocks, taking into # account past causal states. - # - if initial states are provided, then states information will be + # - if initial states are provided, then states information will be # augmented with initial_states. # - to do this properly, we need to account for example changes in - # the continous batch, therefore we introduce pseudo chunks, which is + # the continuous batch, therefore we introduce pseudo chunks, which is # a chunk that is split up each time an example changes. # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had # a seq_idx change, in which case we take states information from @@ -140,16 +140,20 @@ def _mamba_chunk_scan_combined_fwd(x, z=z, seq_idx=seq_idx, initial_states=initial_states, - ) + ) if cu_seqlens is None: return out, out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), - dt.squeeze(0), dA_cumsum.squeeze(0), - cu_seqlens, states.squeeze(0), - initial_states=initial_states, - ) + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) return out, out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index c4e6cd2f961f4..f7d94f8da4ac2 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -108,7 +108,7 @@ def _state_passing_fwd_kernel( if HAS_INITSTATES: if IS_CONT_BATCHED and seq_idx != seq_idx_new: # this means in the current chunk the rightmost flushed seq - # has changed. + # has changed. 
# - so we do not propagate the state from the previous chunk
 # - but rather we load that sequence's init state
 initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
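
For reference, a minimal usage sketch of the continuous-batching path exercised by the new test. Shapes and keyword arguments mirror test_mamba_chunk_scan_batch; the import path and device placement are assumptions on my part rather than something this patch pins down.

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined)

# two sequences of lengths 24 and 40 packed into one continuous batch
seq_lens = (24, 40)
n_heads, d_head, chunk_size = 4, 16, 32
total = sum(seq_lens)

# metadata: cumulative sequence boundaries and per-token sequence index
cu_seqlens = torch.tensor((0, ) + seq_lens, device='cuda').cumsum(dim=0)
seq_idx = torch.zeros(total, dtype=torch.int32, device='cuda')
for i, (srt, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
    seq_idx[srt:end] = i
seq_idx = seq_idx.unsqueeze(0)

# random inputs shaped like generate_random_inputs above (batch dim must be 1)
A = -torch.exp(torch.rand(n_heads, device='cuda'))
dt = F.softplus(torch.randn(1, total, n_heads, device='cuda') - 4)
X = torch.randn(1, total, n_heads, d_head, device='cuda')
B = torch.randn(1, total, n_heads, d_head, device='cuda')
C = torch.randn(1, total, n_heads, d_head, device='cuda')

Y, varlen_states = mamba_chunk_scan_combined(X,
                                             dt,
                                             A,
                                             B,
                                             C,
                                             chunk_size,
                                             D=None,
                                             cu_seqlens=cu_seqlens,
                                             seq_idx=seq_idx,
                                             return_varlen_states=True)
# Y[0, cu_seqlens[i]:cu_seqlens[i + 1]] is the output of packed sequence i;
# varlen_states[i] is its final SSM state, reusable as initial_states for
# that sequence's next chunk of tokens.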