Commit

fix
christophmluscher committed Jan 17, 2025
1 parent c4fa60e commit 9da3a3c
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions i6_models/parts/lstm.py
@@ -13,7 +13,7 @@ class LstmBlockV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input dimension size
hidden_dim: hidden dimension of one direction of LSTM, the total output size is twice of this
hidden_dim: hidden dimension of one direction of LSTM
num_layers: number of uni-directional LSTM layers, minimum 2
bias: add a bias term to the LSTM layer
dropout: nn.LSTM supports internal Dropout applied between each layer of LSTM (but not on input/output)
@@ -47,15 +47,14 @@ def from_dict(cls, model_cfg_dict: Dict[str, Any]):


 class LstmBlockV1(nn.Module):
-    def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict[str, Any]], **kwargs):
+    def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict[str, Any]]):
         """
         Model definition of LSTM block. Contains single lstm stack and padding sequence in forward call. Including
         dropout, batch-first variant, hardcoded to use B,T,F input.
         Supports: TorchScript, ONNX-export.
         :param model_cfg: holds model configuration as dataclass or dict instance.
-        :param kwargs:
         """
         super().__init__()

@@ -78,7 +77,7 @@ def forward(self, x: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor,
         :param x: [B, T, input_dim]
         :param seq_len:[B], should be on CPU for Script/Trace mode
-        :return: [B, T, 2 * hidden_dim]
+        :return: [B, T, hidden_dim]
         """
         if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             if seq_len.get_device() >= 0:
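
For reference, a minimal usage sketch of the block after this change. It is not part of the commit: the config keys follow the attributes listed in the docstring above (the real LstmBlockV1Config may require further fields not visible in this diff), the concrete dimension values are made up, and forward is assumed to return the output tensor together with the sequence lengths.

import torch

from i6_models.parts.lstm import LstmBlockV1

# Config passed as a plain dict, which __init__ accepts via
# Union[LstmBlockV1Config, Dict[str, Any]]; keys follow the docstring above.
model_cfg = {
    "input_dim": 80,    # feature dimension F
    "hidden_dim": 512,  # hidden dimension of one direction of the LSTM
    "num_layers": 2,    # minimum 2 per the docstring
    "bias": True,
    "dropout": 0.1,
}
lstm = LstmBlockV1(model_cfg)

x = torch.randn(4, 100, 80)                # [B, T, input_dim]
seq_len = torch.tensor([100, 90, 80, 70])  # [B], kept on CPU for Script/Trace mode
out, out_seq_len = lstm(x, seq_len)        # return-tuple layout assumed

# Per the corrected docstring the output is [B, T, hidden_dim], not [B, T, 2 * hidden_dim].
print(out.shape)  # expected torch.Size([4, 100, 512])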
