fix: refactor module tests
SauravMaheshkar committed Oct 7, 2024
1 parent 2a7c32c commit b971321
Showing 2 changed files with 11 additions and 8 deletions.
tests/modules/test_layers.py (4 changes: 2 additions & 2 deletions)

@@ -3,17 +3,17 @@
 import torch
 from einops import rearrange
 from flax import nnx
+from flux.modules.layers import DoubleStreamBlock as TorchDoubleStreamBlock
 from flux.modules.layers import MLPEmbedder as TorchMLPEmbedder
 from flux.modules.layers import Modulation as TorchModulation
 from flux.modules.layers import QKNorm as TorchQKNorm
 from flux.modules.layers import RMSNorm as TorchRMSNorm
-from flux.modules.layers import DoubleStreamBlock as TorchDoubleStreamBlock

+from jflux.modules.layers import DoubleStreamBlock as JaxDoubleStreamBlock
 from jflux.modules.layers import MLPEmbedder as JaxMLPEmbedder
 from jflux.modules.layers import Modulation as JaxModulation
 from jflux.modules.layers import QKNorm as JaxQKNorm
 from jflux.modules.layers import RMSNorm as JaxRMSNorm
-from jflux.modules.layers import DoubleStreamBlock as JaxDoubleStreamBlock
 from tests.utils import torch2jax


tests/test_modules.py (15 changes: 9 additions & 6 deletions)

@@ -7,9 +7,9 @@
 from flux.modules.layers import Modulation as PytorchModulation
 from flux.modules.layers import SelfAttention as PytorchSelfAttention

-from jflux.modules import MLPEmbedder as JaxMLPEmbedder
-from jflux.modules import Modulation as JaxModulation
-from jflux.modules import SelfAttention as JaxSelfAttention
+from jflux.modules.layers import MLPEmbedder as JaxMLPEmbedder
+from jflux.modules.layers import Modulation as JaxModulation
+from jflux.modules.layers import SelfAttention as JaxSelfAttention
 from tests.utils import torch2jax


@@ -18,7 +18,10 @@ def test_mlp_embedder(self):
         # Initialize layers
         pytorch_mlp_embedder = MLPEmbedder(in_dim=512, hidden_dim=256)
         jax_mlp_embedder = JaxMLPEmbedder(
-            in_dim=512, hidden_dim=256, rngs=nnx.Rngs(default=42), dtype=jnp.float32
+            in_dim=512,
+            hidden_dim=256,
+            rngs=nnx.Rngs(default=42),
+            param_dtype=jnp.float32,
         )

         # Generate random inputs
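
Note: these tests now pass param_dtype rather than dtype. In Flax's nnx API, layers such as nnx.Linear accept both keywords: param_dtype sets the dtype the parameters are stored in, while dtype sets the computation dtype, so the rename matches the nnx convention. The sketch below is illustrative only and is not the jflux implementation; the SiLU activation and layer layout are assumptions.

    # Minimal sketch of an nnx embedder whose constructor mirrors the call
    # in test_mlp_embedder above (assumed structure, not the jflux code).
    import jax
    import jax.numpy as jnp
    from flax import nnx

    class SketchMLPEmbedder(nnx.Module):
        def __init__(self, in_dim: int, hidden_dim: int, *, rngs: nnx.Rngs, param_dtype=jnp.float32):
            # param_dtype is forwarded to the layers, fixing the dtype of the stored weights
            self.in_layer = nnx.Linear(in_dim, hidden_dim, param_dtype=param_dtype, rngs=rngs)
            self.out_layer = nnx.Linear(hidden_dim, hidden_dim, param_dtype=param_dtype, rngs=rngs)

        def __call__(self, x: jax.Array) -> jax.Array:
            return self.out_layer(jax.nn.silu(self.in_layer(x)))

    # Constructed the same way the updated test constructs JaxMLPEmbedder:
    embedder = SketchMLPEmbedder(
        in_dim=512,
        hidden_dim=256,
        rngs=nnx.Rngs(default=42),
        param_dtype=jnp.float32,
    )
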
@@ -37,7 +40,7 @@ def test_self_attention(self):
         # Initialize layers
         pytorch_self_attention = PytorchSelfAttention(dim=512)
         jax_self_attention = JaxSelfAttention(
-            dim=512, rngs=nnx.Rngs(default=42), dtype=jnp.float32
+            dim=512, rngs=nnx.Rngs(default=42), param_dtype=jnp.float32
         )

         # Generate random inputs
@@ -57,7 +60,7 @@ def test_modulation(self):
         # Initialize layers
         pytorch_modulation = PytorchModulation(dim=512, double=True)
         jax_modulation = JaxModulation(
-            dim=512, double=True, rngs=nnx.Rngs(default=42), dtype=jnp.float32
+            dim=512, double=True, rngs=nnx.Rngs(default=42), param_dtype=jnp.float32
         )

         # Generate random inputs

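For reference, each of these tests follows the same parity pattern: build the PyTorch reference module and the nnx port with matching hyperparameters, feed both the same input (converted with tests.utils.torch2jax), and compare the outputs. The sketch below illustrates that pattern; the input shape, tolerance, and weight-copying step are not visible in this commit and are assumptions, not the repository's actual code.

    # Sketch of the Torch/JAX parity check these tests perform (assumed details).
    import numpy as np
    import torch
    import jax.numpy as jnp
    from flax import nnx

    from flux.modules.layers import MLPEmbedder
    from jflux.modules.layers import MLPEmbedder as JaxMLPEmbedder
    from tests.utils import torch2jax  # assumed to convert a torch.Tensor into a jax.Array


    def test_mlp_embedder_parity():
        # Initialize layers with identical hyperparameters
        pytorch_mlp_embedder = MLPEmbedder(in_dim=512, hidden_dim=256)
        jax_mlp_embedder = JaxMLPEmbedder(
            in_dim=512,
            hidden_dim=256,
            rngs=nnx.Rngs(default=42),
            param_dtype=jnp.float32,
        )

        # Generate random inputs (shape is an arbitrary choice for illustration)
        torch_input = torch.randn(2, 32, 512)
        jax_input = torch2jax(torch_input)

        # For the outputs to match, the PyTorch weights would also have to be
        # copied into the nnx module; that porting step is omitted because it
        # is not shown in this commit.
        torch_output = pytorch_mlp_embedder(torch_input).detach().numpy()
        jax_output = np.asarray(jax_mlp_embedder(jax_input))

        assert np.allclose(torch_output, jax_output, atol=1e-4)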