Skip to content

Commit

Permalink
Extend tests to work in 1d, 2d, and 3d
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Apr 5, 2024
1 parent 82b40fb commit 63b5693
Showing 1 changed file with 73 additions and 40 deletions.
113 changes: 73 additions & 40 deletions tests/test_with_dummy_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@


@pytest.mark.parametrize(
"block",
"block,num_spatial_dims",
[
pdeqx.blocks.ClassicDoubleConvBlock,
pdeqx.blocks.ClassicResBlock,
pdeqx.blocks.DilatedResBlock,
pdeqx.blocks.ModernResBlock,
(block, D)
for block in [
pdeqx.blocks.ClassicDoubleConvBlock,
pdeqx.blocks.ClassicResBlock,
pdeqx.blocks.DilatedResBlock,
pdeqx.blocks.ModernResBlock,
]
for D in [1, 2, 3]
],
)
def test_block_with_dummy_input(block):
def test_block_with_dummy_input(block, num_spatial_dims):
instantiated_block = block(
num_spatial_dims=1,
num_spatial_dims=num_spatial_dims,
in_channels=1,
out_channels=1,
hidden_channels=32,
Expand All @@ -24,43 +28,54 @@ def test_block_with_dummy_input(block):
boundary_mode="periodic",
)

dummy_input = jax.random.normal(jax.random.PRNGKey(0), (1, 32))
shape = (1,) + (32,) * num_spatial_dims
dummy_input = jax.random.normal(jax.random.PRNGKey(0), shape)

instantiated_block(dummy_input)


def test_classic_spectral_block_with_dummy_input():
@pytest.mark.parametrize(
"num_spatial_dims",
[1, 2, 3],
)
def test_classic_spectral_block_with_dummy_input(num_spatial_dims):
instantiated_block = pdeqx.blocks.ClassicSpectralBlock(
num_spatial_dims=1,
num_spatial_dims=num_spatial_dims,
in_channels=1,
out_channels=1,
activation=jax.nn.relu,
key=jax.random.PRNGKey(0),
)

dummy_input = jax.random.normal(jax.random.PRNGKey(0), (1, 32))
shape = (1,) + (32,) * num_spatial_dims

dummy_input = jax.random.normal(jax.random.PRNGKey(0), shape)

instantiated_block(dummy_input)


@pytest.mark.parametrize(
"block_factory",
"block_factory,num_spatial_dims",
[
pdeqx.blocks.ClassicDoubleConvBlockFactory,
pdeqx.blocks.ClassicResBlockFactory,
pdeqx.blocks.ClassicSpectralBlockFactory,
pdeqx.blocks.LinearChannelAdjustBlockFactory,
pdeqx.blocks.LinearConvBlockFactory,
pdeqx.blocks.LinearConvDownBlockFactory,
pdeqx.blocks.LinearConvUpBlockFactory,
pdeqx.blocks.ModernResBlockFactory,
(block_factory, D)
for block_factory in [
pdeqx.blocks.ClassicDoubleConvBlockFactory,
pdeqx.blocks.ClassicResBlockFactory,
pdeqx.blocks.ClassicSpectralBlockFactory,
pdeqx.blocks.LinearChannelAdjustBlockFactory,
pdeqx.blocks.LinearConvBlockFactory,
pdeqx.blocks.LinearConvDownBlockFactory,
pdeqx.blocks.LinearConvUpBlockFactory,
pdeqx.blocks.ModernResBlockFactory,
]
for D in [1, 2, 3]
],
)
def test_block_factory_with_dummy_input(block_factory):
def test_block_factory_with_dummy_input(block_factory, num_spatial_dims):
factory = block_factory()

instantiated_block = factory(
num_spatial_dims=1,
num_spatial_dims=num_spatial_dims,
in_channels=1,
out_channels=1,
hidden_channels=32,
Expand All @@ -69,59 +84,77 @@ def test_block_factory_with_dummy_input(block_factory):
boundary_mode="periodic",
)

dummy_input = jax.random.normal(jax.random.PRNGKey(0), (1, 32))
shape = (1,) + (32,) * num_spatial_dims

dummy_input = jax.random.normal(jax.random.PRNGKey(0), shape)

instantiated_block(dummy_input)


@pytest.mark.parametrize(
"arch",
"arch,num_spatial_dims",
[
# pdeqx.arch.ClassicFNO,
pdeqx.arch.ClassicResNet,
pdeqx.arch.ClassicUNet,
pdeqx.arch.ConvNet,
pdeqx.arch.DilatedResNet,
# pdeqx.arch.MLP,
pdeqx.arch.ModernResNet,
(arch, D)
for arch in [
pdeqx.arch.ClassicResNet,
pdeqx.arch.ClassicUNet,
pdeqx.arch.ConvNet,
pdeqx.arch.DilatedResNet,
pdeqx.arch.ModernResNet,
]
for D in [1, 2, 3]
],
)
def test_arch_with_dummy_input(arch):
def test_arch_with_dummy_input(arch, num_spatial_dims):
instantiated_arch = arch(
num_spatial_dims=1,
num_spatial_dims=num_spatial_dims,
in_channels=1,
out_channels=1,
key=jax.random.PRNGKey(0),
boundary_mode="periodic",
)

dummy_input = jax.random.normal(jax.random.PRNGKey(0), (1, 32))
shape = (1,) + (32,) * num_spatial_dims

dummy_input = jax.random.normal(jax.random.PRNGKey(0), shape)

instantiated_arch(dummy_input)


def test_fno_with_dummy_input():
@pytest.mark.parametrize(
"num_spatial_dims",
[1, 2, 3],
)
def test_fno_with_dummy_input(num_spatial_dims):
instantiated_arch = pdeqx.arch.ClassicFNO(
num_spatial_dims=1,
num_spatial_dims=num_spatial_dims,
in_channels=1,
out_channels=1,
key=jax.random.PRNGKey(0),
)

dummy_input = jax.random.normal(jax.random.PRNGKey(0), (1, 32))
shape = (1,) + (32,) * num_spatial_dims

dummy_input = jax.random.normal(jax.random.PRNGKey(0), shape)

instantiated_arch(dummy_input)


def test_mlp_with_dummy_input():
@pytest.mark.parametrize(
"num_spatial_dims",
[1, 2, 3],
)
def test_mlp_with_dummy_input(num_spatial_dims):
instantiated_arch = pdeqx.arch.MLP(
num_spatial_dims=1,
num_spatial_dims=num_spatial_dims,
in_channels=1,
out_channels=1,
num_points=32,
key=jax.random.PRNGKey(0),
)

dummy_input = jax.random.normal(jax.random.PRNGKey(0), (1, 32))
shape = (1,) + (32,) * num_spatial_dims

dummy_input = jax.random.normal(jax.random.PRNGKey(0), shape)

instantiated_arch(dummy_input)

0 comments on commit 63b5693

Please sign in to comment.