Skip to content

Commit

Permalink
Add parameter counting test for pw linear
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 18, 2024
1 parent 765793f commit 303811d
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/test_count_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,41 @@ def test_spectral_conv(
) * 2 ** (num_spatial_dims - 1)

assert num_parameters == num_parameters_expected


@pytest.mark.parametrize(
"num_spatial_dims,in_channels,out_channels,use_bias,zero_bias_init",
[
(num_spatial_dims, in_channels, out_channels, use_bias, zero_bias_init)
for num_spatial_dims in [1, 2, 3]
for in_channels in [1, 2, 5]
for out_channels in [1, 2, 5]
for use_bias in [True, False]
for zero_bias_init in [True, False]
],
)
def test_pointwise_linear_conv(
num_spatial_dims: int,
in_channels: int,
out_channels: int,
use_bias: bool,
zero_bias_init: bool,
):
net = pdeqx.conv.PointwiseLinearConv(
num_spatial_dims=num_spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
use_bias=use_bias,
key=jax.random.PRNGKey(0),
zero_bias_init=zero_bias_init,
)

num_parameters = pdeqx.count_parameters(net)

# Compute the expected number of parameters
num_parameters_expected = in_channels * out_channels

if use_bias:
num_parameters_expected += out_channels

assert num_parameters == num_parameters_expected

0 comments on commit 303811d

Please sign in to comment.