diff --git a/project/algorithms/jax_image_classifier_test.py b/project/algorithms/jax_image_classifier_test.py index 8f41c745..a1a2ab75 100644 --- a/project/algorithms/jax_image_classifier_test.py +++ b/project/algorithms/jax_image_classifier_test.py @@ -1,13 +1,10 @@ from pathlib import Path -from typing import Any import flax import flax.linen import pytest -from tensor_regression import TensorRegressionFixture from project.algorithms.jax_image_classifier import JaxImageClassifier -from project.algorithms.testsuites.lightning_module_tests import GetStuffFromFirstTrainingStep from project.conftest import fails_on_macOS_in_CI from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, @@ -17,6 +14,10 @@ from .testsuites.lightning_module_tests import LightningModuleTests +@pytest.mark.xfail( + IN_SELF_HOSTED_GITHUB_CI, + reason="TODO: Test appears to be flaky only when run on the self-hosted runner?.", +) @fails_on_macOS_in_CI @run_for_all_configs_of_type("algorithm", JaxImageClassifier) @run_for_all_configs_of_type("algorithm/network", flax.linen.Module) @@ -29,22 +30,6 @@ class TestJaxImageClassifier(LightningModuleTests[JaxImageClassifier]): `flax.linen.Module`. """ - @pytest.mark.xfail( - IN_SELF_HOSTED_GITHUB_CI, - reason="TODO: Test appears to be flaky only when run on the self-hosted runner?.", - ) - def test_initialization_is_reproducible( - self, - training_step_content: tuple[ - JaxImageClassifier, GetStuffFromFirstTrainingStep, list[Any], list[Any] - ], - tensor_regression: TensorRegressionFixture, - accelerator: str, - ): - return super().test_initialization_is_reproducible( - training_step_content, tensor_regression, accelerator - ) - @pytest.mark.slow def test_demo(tmp_path: Path):