You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Minimal reproduction: run a Hydra layer forward on a random input batch.
# NOTE(review): `torch` and `Hydra` are not imported in this snippet; the
# traceback's `File "<stdin>"` frame suggests they were imported earlier in
# the same interactive session — confirm.
batch, length, dim = 2, 64, 32
# Random input of shape (batch, length, dim) = (2, 64, 32).
# NOTE(review): no `device=` is passed, so this tensor is created on torch's
# default device (CPU unless reconfigured). The call below reaches mamba_ssm's
# Triton path (ops/triton/ssd_combined.py), which presumably expects CUDA
# tensors — consider moving x and model to the GPU once the shape assertion
# is resolved.
x = torch.randn(batch, length, dim)
# d_model is set to the last axis of x so the layer width matches the input.
model = Hydra(d_model=dim,d_state=16,d_conv=4,expand=2,dt_min=0.001,dt_max=0.1,dt_init_floor=1e-4,conv_bias=True,bias=False)
# Reported failure: this forward pass raises AssertionError at
# `assert dt.shape == (batch, seqlen, nheads)` inside mamba_ssm's
# _mamba_chunk_scan_combined_fwd (called from hydra.py's forward).
# NOTE(review): the installed mamba_chunk_scan_combined forwards
# cu_seqlens/return_varlen_states, so it may be a newer mamba_ssm than
# hydra.py was written against — check version compatibility of the dt
# tensor Hydra passes in.
y = model(x)
I get this error message:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/splice/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/splice/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/splice/hydra/modules/hydra.py", line 183, in forward
y = mamba_chunk_scan_combined(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/splice/lib/python3.12/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 582, in mamba_chunk_scan_combined
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/splice/lib/python3.12/site-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/splice/lib/python3.12/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 541, in forward
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/splice/lib/python3.12/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 288, in _mamba_chunk_scan_combined_fwd
assert dt.shape == (batch, seqlen, nheads)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
The text was updated successfully, but these errors were encountered:
When I execute these lines of code:
I get this error message:
The text was updated successfully, but these errors were encountered: