Commit

test: adjust tests
theissenhelen committed Nov 27, 2024
1 parent a703688 commit f1be563
Showing 5 changed files with 46 additions and 9 deletions.
src/anemoi/models/layers/processor.py (2 changes: 1 addition & 1 deletion)
@@ -97,7 +97,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
attention_implementation: str = "Flex Attention",
attention_implementation: str = "flex attention",
softcap: float = 0.0,
use_alibi_slopes: bool = None,
**kwargs,
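The hunk above only lowercases the default value of the attention_implementation keyword; how that string is resolved to an attention backend lives elsewhere in the package and is not part of this commit. As a rough sketch of the kind of lookup such a keyword typically feeds into (function and dictionary names below are illustrative, not taken from anemoi-models):

# Hypothetical dispatch sketch, not the anemoi-models implementation.
from torch.nn.functional import scaled_dot_product_attention


def resolve_attention_backend(attention_implementation: str):
    """Map a human-readable implementation name onto a callable."""
    backends = {
        "scaled dot product attention": scaled_dot_product_attention,
        # "flex attention" would map to torch.nn.attention.flex_attention
        # on PyTorch builds that ship it.
    }
    if attention_implementation not in backends:
        raise ValueError(f"Unknown attention implementation: {attention_implementation!r}")
    return backends[attention_implementation]
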
tests/layers/block/test_block_transformer.py (18 changes: 16 additions & 2 deletions)
@@ -39,7 +39,14 @@ class TestTransformerProcessorBlock:
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p, softcap):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
- num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
+ num_channels,
+ hidden_dim,
+ num_heads,
+ activation,
+ window_size,
+ dropout_p=dropout_p,
+ attention_implementation="scaled dot product attention",
+ softcap=softcap,
)
assert isinstance(block, TransformerProcessorBlock)

@@ -74,7 +81,14 @@ def test_forward_output(
):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
- num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
+ num_channels,
+ hidden_dim,
+ num_heads,
+ activation,
+ window_size,
+ dropout_p=dropout_p,
+ attention_implementation="scaled dot product attention",
+ softcap=softcap,
)

x = torch.randn((batch_size, num_channels)) # .to(torch.float16, non_blocking=True)
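For readers unfamiliar with the block signature, the positional arguments in the two constructions above follow the fixture names num_channels, hidden_dim, num_heads, activation, window_size. A commented equivalent, with the import path and the concrete values assumed for illustration only:

# Sketch: argument roles inferred from the fixture names used in these tests.
from anemoi.models.layers.block import TransformerProcessorBlock  # assumed import path

num_heads = 16
factor_attention_heads = 4
block = TransformerProcessorBlock(
    num_heads * factor_attention_heads,  # num_channels
    128,                                 # hidden_dim
    num_heads,
    "GELU",                              # activation
    512,                                 # window_size
    dropout_p=0.0,
    attention_implementation="scaled dot product attention",
    softcap=0.0,
)
assert isinstance(block, TransformerProcessorBlock)
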
tests/layers/chunk/test_chunk_transformer.py (4 changes: 4 additions & 0 deletions)
@@ -24,6 +24,7 @@ def init(self):
activation: str = "GELU"
window_size: int = 13
dropout_p: float = 0.1
attention_implementation = "scaled dot product attention"

# num_heads must be evenly divisible by num_channels for MHSA
return (
@@ -34,6 +35,7 @@ def init(self)
activation,
window_size,
dropout_p,
+ attention_implementation,
)

@pytest.fixture
@@ -46,6 +48,7 @@ def processor_chunk(self, init):
activation,
window_size,
dropout_p,
+ attention_implementation,
) = init
return TransformerProcessorChunk(
num_channels=num_channels,
@@ -55,6 +58,7 @@ def processor_chunk(self, init):
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
+ attention_implementation=attention_implementation,
)

def test_all_blocks(self, processor_chunk):
tests/layers/processor/test_transformer_processor.py (6 changes: 6 additions & 0 deletions)
@@ -26,6 +26,7 @@ def transformer_processor_init():
mlp_hidden_ratio = 4
dropout_p = 0.1
softcap = 0.5
attention_implementation = "scaled dot product attention"
return (
num_layers,
window_size,
@@ -37,6 +38,7 @@ def transformer_processor_init():
mlp_hidden_ratio,
dropout_p,
softcap,
+ attention_implementation,
)


@@ -53,6 +55,7 @@ def transformer_processor(transformer_processor_init):
mlp_hidden_ratio,
dropout_p,
softcap,
+ attention_implementation,
) = transformer_processor_init
return TransformerProcessor(
num_layers=num_layers,
@@ -64,6 +67,7 @@ def transformer_processor(transformer_processor_init):
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
dropout_p=dropout_p,
+ attention_implementation=attention_implementation,
softcap=softcap,
)

@@ -79,6 +83,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
+ _attention_implementation,
_softcap,
) = transformer_processor_init
assert isinstance(transformer_processor, TransformerProcessor)
@@ -98,6 +103,7 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
+ _attention_implementation,
_softcap,
) = transformer_processor_init
gridsize = 100
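One detail worth noting in this file: the fixture returns the tuple in the order ..., dropout_p, softcap, attention_implementation, while the unpacking in test_transformer_processor_init and test_transformer_processor_forward lists _dropout_p, _attention_implementation, _softcap. Nothing breaks, because the underscore-prefixed names are unused, but the positional drift is easy to miss. A dict-based fixture sidesteps the ordering question entirely; a minimal sketch, not part of the commit, with placeholder values for the fields not visible in this diff:

# Sketch: dict fixture so consumers look fields up by name instead of position.
import pytest


@pytest.fixture
def transformer_processor_config():
    return {
        "num_layers": 2,       # placeholder
        "window_size": 10,     # placeholder
        "num_channels": 128,   # placeholder
        "num_heads": 16,       # placeholder
        "mlp_hidden_ratio": 4,
        "dropout_p": 0.1,
        "softcap": 0.5,
        "attention_implementation": "scaled dot product attention",
    }


def test_config_sketch(transformer_processor_config):
    assert transformer_processor_config["attention_implementation"] == "scaled dot product attention"
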
tests/layers/test_attention.py (25 changes: 19 additions & 6 deletions)
@@ -23,12 +23,15 @@
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
softcap=st.floats(min_value=0.0, max_value=1.0),
+ attention_implementation=st.sampled_from(["scaled dot product attention", "flex attention"]),
)
- def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p, softcap):
+ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p, softcap, attention_implementation):
embed_dim = (
num_heads * embed_dim_multiplier
) # TODO: Make assert in MHSA to check if embed_dim is divisible by num_heads
- mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, use_flash_attention=False, softcap=softcap)
+ mhsa = MultiHeadSelfAttention(
+ num_heads, embed_dim, dropout_p=dropout_p, attention_implementation=attention_implementation, softcap=softcap
+ )

assert isinstance(mhsa, nn.Module)
assert mhsa.num_heads == num_heads
@@ -44,11 +47,16 @@ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
softcap=st.floats(min_value=0.0, max_value=1.0),
+ attention_implementation=st.sampled_from(["scaled dot product attention"]),
)
@settings(deadline=None)
- def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap):
+ def test_multi_head_self_attention_forward(
+ batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap, attention_implementation
+ ):
embed_dim = num_heads * embed_dim_multiplier
- mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, use_flash_attention=False, softcap=softcap)
+ mhsa = MultiHeadSelfAttention(
+ num_heads, embed_dim, dropout_p=dropout_p, attention_implementation=attention_implementation, softcap=softcap
+ )

x = torch.randn(batch_size * 2, embed_dim)
shapes = [list(x.shape)]
@@ -64,10 +72,15 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
softcap=st.floats(min_value=0.0, max_value=1.0),
+ attention_implementation=st.sampled_from(["scaled dot product attention"]),
)
- def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap):
+ def test_multi_head_self_attention_backward(
+ batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap, attention_implementation
+ ):
embed_dim = num_heads * embed_dim_multiplier
- mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, use_flash_attention=False, softcap=softcap)
+ mhsa = MultiHeadSelfAttention(
+ num_heads, embed_dim, dropout_p=dropout_p, attention_implementation=attention_implementation, softcap=softcap
+ )

x = torch.randn(batch_size * 2, embed_dim, requires_grad=True)
shapes = [list(x.shape)]
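The init test samples both implementation strings, while the forward and backward tests restrict the strategy to "scaled dot product attention". If those were ever extended to "flex attention", an availability guard along these lines could keep the suite green on older PyTorch builds (a sketch only; the helper below is not part of the test suite, and the flex attention import path assumes a recent PyTorch):

# Hypothetical guard for extending the sampled implementations.
from hypothesis import given, strategies as st


def _flex_attention_available() -> bool:
    try:
        from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
    except ImportError:
        return False
    return True


_IMPLEMENTATIONS = ["scaled dot product attention"]
if _flex_attention_available():
    _IMPLEMENTATIONS.append("flex attention")


@given(attention_implementation=st.sampled_from(_IMPLEMENTATIONS))
def test_attention_implementation_sketch(attention_implementation):
    # Construct MultiHeadSelfAttention with attention_implementation here,
    # mirroring test_multi_head_self_attention_forward above.
    assert attention_implementation in _IMPLEMENTATIONS
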
