diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
index 498d069c05f0d..dd1e6de2e0180 100644
--- a/csrc/mamba/causal_conv1d/causal_conv1d.cu
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -424,7 +424,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
     // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
     // (which occurs when `final_state_position` is a non-positivie index)
     // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
-    if (final_state_position < 0 && seqlen > kWidth){
+    if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
         input_t vals_load[kNElts] = {0};
         if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
             // chunk = n_chunks - 2, a segment of the final state sits in the last index
diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py
index f9b11018288be..51be2425d7dd7 100644
--- a/tests/kernels/test_causal_conv1d.py
+++ b/tests/kernels/test_causal_conv1d.py
@@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
 @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
 @pytest.mark.parametrize("silu_activation", [True])
 @pytest.mark.parametrize("has_bias", [True])
+@pytest.mark.parametrize("has_initial_state", [True, False])
 @pytest.mark.parametrize("width", [4])
 @pytest.mark.parametrize(
     'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
 @pytest.mark.parametrize('dim', [64])
 @pytest.mark.parametrize('batch', [1])
 def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
-                       itype):
+                       has_initial_state, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
     if itype == torch.bfloat16:
@@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
     weight = torch.randn(dim, width, device=device, dtype=itype)
 
     bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
-    initial_states = torch.randn(batch,
-                                 dim,
-                                 width - 1,
-                                 device=device,
-                                 dtype=itype)
+    if has_initial_state:
+        initial_states = torch.randn(batch,
+                                     dim,
+                                     width - 1,
+                                     device=device,
+                                     dtype=itype)
+        has_initial_state_tensor = torch.ones(batch,
+                                              dtype=torch.bool,
+                                              device=x.device)
+    else:
+        initial_states = None
+        has_initial_state_tensor = None
     x_ref = x.clone()
     weight_ref = weight.clone()
     bias_ref = bias.clone() if bias is not None else None
@@ -183,9 +191,7 @@
                            bias,
                            activation=activation,
                            conv_states=initial_states,
-                           has_initial_state=torch.ones(batch,
-                                                        dtype=torch.bool,
-                                                        device=x.device))
+                           has_initial_state=has_initial_state_tensor)
     out_ref, final_states_ref = causal_conv1d_ref(
         x_ref,
         weight_ref,
@@ -193,11 +199,12 @@
         initial_states=initial_states_ref,
         return_final_states=True,
         activation=activation)
-    assert initial_states is not None and final_states_ref is not None
-    assert torch.allclose(initial_states,
-                          final_states_ref,
-                          rtol=rtol,
-                          atol=atol)
+    if has_initial_state:
+        assert initial_states is not None and final_states_ref is not None
+        assert torch.allclose(initial_states,
+                              final_states_ref,
+                              rtol=rtol,
+                              atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
 
     causal_conv1d_opcheck_fn(x,
@@ -205,9 +212,7 @@
                              bias,
                              activation=activation,
                              conv_states=initial_states,
-                             has_initial_state=torch.ones(batch,
-                                                          dtype=torch.bool,
-                                                          device=x.device))
+                             has_initial_state=has_initial_state_tensor)
 
 
 @pytest.mark.parametrize("itype", [torch.bfloat16])
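
The call path this PR guards against is the no-initial-state one: with `conv_states=None`, the kernel's final-state write-back branch (`final_state_position < 0 && seqlen > kWidth`) could still execute and dereference a null `conv_states` pointer; the added `conv_states != nullptr` check short-circuits it. Below is a minimal sketch of that call path, mirroring the new `has_initial_state=False` test branch. The keyword arguments are taken from the test above; the `causal_conv1d_fn` import path is an assumption based on vLLM's Mamba ops layout and may differ between versions, and `seqlen=1025` is simply one of the lengths the test sweeps.

import torch

# Assumed import path; adjust to your vLLM version's layout.
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn)

batch, dim, seqlen, width = 1, 64, 1025, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(dim, device="cuda", dtype=torch.bfloat16)

# No state buffer and no per-sequence initial-state flags: before the
# nullptr guard, the kernel's final-state write-back could still run
# here and dereference conv_states.
out = causal_conv1d_fn(x,
                       weight,
                       bias,
                       activation="silu",
                       conv_states=None,
                       has_initial_state=None)

With the guard in place, the write-back only runs when there is a state buffer to write into, which is also why the test now asserts the final state against the reference only in the `has_initial_state` branch.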