diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml index 6c11e727..523261b5 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml @@ -75,20 +75,20 @@ grads.network.params.5: grads.network.params.6: device: cuda:0 max: '3.249e-02' - mean: '-7.451e-10' + mean: '-1.397e-09' min: '-2.593e-02' shape: - 10 - sum: '-7.451e-09' + sum: '-1.397e-08' grads.network.params.7: device: cuda:0 max: '3.762e-02' - mean: '-1.673e-10' + mean: '-2.430e-10' min: '-4.220e-02' shape: - 256 - 10 - sum: '-4.284e-07' + sum: '-6.221e-07' outputs.logits: device: cuda:0 max: '1.041e+00' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml index 9276335a..b5a4bcf4 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml @@ -37,20 +37,20 @@ grads.network.params.1: grads.network.params.2: device: cuda:0 max: '6.439e-02' - mean: '0.e+00' + mean: '-3.725e-10' min: '-3.123e-02' shape: - 10 - sum: '0.e+00' + sum: '-3.725e-09' grads.network.params.3: device: cuda:0 max: '1.444e-01' - mean: '-9.313e-11' + mean: '-1.048e-10' min: '-1.493e-01' shape: - 256 - 10 - sum: '-2.384e-07' + sum: '-2.682e-07' outputs.logits: device: cuda:0 max: '2.930e+00' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml index 4bfb9392..ec8098ad 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml @@ -75,20 +75,20 @@ grads.network.params.5: grads.network.params.6: device: cuda:0 max: '6.150e-02' - mean: '0.e+00' + mean: '-2.235e-09' min: '-6.966e-02' shape: - 10 - sum: '0.e+00' + sum: '-2.235e-08' grads.network.params.7: device: cuda:0 max: '1.175e-01' - mean: '-7.567e-11' + mean: '-3.201e-10' min: '-1.294e-01' shape: - 256 - 10 - sum: '-1.937e-07' + sum: '-8.196e-07' outputs.logits: device: cuda:0 max: '9.607e-01' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml index b38f5dbd..dc1cb82e 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml @@ -37,20 +37,20 @@ grads.network.params.1: grads.network.params.2: device: cuda:0 max: '1.382e-01' - mean: '-7.451e-10' + mean: '-2.235e-09' min: '-9.016e-02' shape: - 10 - sum: '-7.451e-09' + sum: '-2.235e-08' grads.network.params.3: device: cuda:0 max: '4.029e-01' - mean: '-6.170e-10' + mean: '-5.646e-10' min: '-2.145e-01' shape: - 256 - 10 - sum: '-1.58e-06' + sum: '-1.445e-06' outputs.logits: device: cuda:0 max: '2.481e+00' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml index e797effc..7ccd72a8 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml @@ -75,20 +75,20 @@ grads.network.params.5: grads.network.params.6: device: cuda:0 max: '6.867e-02' - mean: '-7.451e-10' + mean: '-1.490e-09' min: '-7.932e-02' shape: - 10 - sum: '-7.451e-09' + sum: '-1.490e-08' grads.network.params.7: device: cuda:0 max: '7.035e-02' - mean: '-1.193e-10' + mean: '-3.638e-11' min: '-7.68e-02' shape: - 256 - 10 - sum: '-3.055e-07' + sum: '-9.313e-08' outputs.logits: device: cuda:0 max: '8.371e-01' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml index fdf57a4b..df6a2bf4 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_backward_pass_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml @@ -37,20 +37,20 @@ grads.network.params.1: grads.network.params.2: device: cuda:0 max: '4.535e-02' - mean: '3.725e-10' + mean: '-1.118e-09' min: '-7.950e-02' shape: - 10 - sum: '3.725e-09' + sum: '-1.118e-08' grads.network.params.3: device: cuda:0 max: '8.090e-02' - mean: '-5.472e-10' + mean: '8.149e-11' min: '-1.129e-01' shape: - 256 - 10 - sum: '-1.401e-06' + sum: '2.086e-07' outputs.logits: device: cuda:0 max: '2.035e+00' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml index 5f76c79f..6d200efd 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_cnn_jax_image_classifier.yaml @@ -56,11 +56,11 @@ network.params.5: network.params.6: device: cpu max: '2.593e-05' - mean: '3.638e-13' + mean: '1.091e-12' min: '-3.249e-05' shape: - 10 - sum: '3.638e-12' + sum: '1.091e-11' network.params.7: device: cpu max: '1.421e-01' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml index a49a4abf..604f5ef1 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/cifar10_jax_fcnet_jax_image_classifier.yaml @@ -18,11 +18,11 @@ network.params.1: network.params.2: device: cpu max: '3.123e-05' - mean: '0.e+00' + mean: '3.638e-13' min: '-6.439e-05' shape: - 10 - sum: '0.e+00' + sum: '3.638e-12' network.params.3: device: cpu max: '1.421e-01' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml index 4ec020b1..9e75d24b 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_cnn_jax_image_classifier.yaml @@ -56,11 +56,11 @@ network.params.5: network.params.6: device: cpu max: '6.966e-05' - mean: '-5.457e-13' + mean: '1.637e-12' min: '-6.150e-05' shape: - 10 - sum: '-5.457e-12' + sum: '1.637e-11' network.params.7: device: cpu max: '1.421e-01' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml index d25ff948..72e68c1d 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/fashion_mnist_jax_fcnet_jax_image_classifier.yaml @@ -18,11 +18,11 @@ network.params.1: network.params.2: device: cpu max: '9.016e-05' - mean: '3.638e-13' + mean: '2.547e-12' min: '-1.382e-04' shape: - 10 - sum: '3.638e-12' + sum: '2.547e-11' network.params.3: device: cpu max: '1.421e-01' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml index 22cc8e47..e6df78a3 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_cnn_jax_image_classifier.yaml @@ -56,11 +56,11 @@ network.params.5: network.params.6: device: cpu max: '7.932e-05' - mean: '1.16e-12' + mean: '5.23e-13' min: '-6.867e-05' shape: - 10 - sum: '1.16e-11' + sum: '5.23e-12' network.params.7: device: cpu max: '1.421e-01' diff --git a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml index 755881f8..083756b8 100644 --- a/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml +++ b/.regression_files/project/algorithms/jax_image_classifier_test/test_initialization_is_reproducible/mnist_jax_fcnet_jax_image_classifier.yaml @@ -18,11 +18,11 @@ network.params.1: network.params.2: device: cpu max: '7.950e-05' - mean: '-4.832e-14' + mean: '1.123e-12' min: '-4.535e-05' shape: - 10 - sum: '-4.832e-13' + sum: '1.123e-11' network.params.3: device: cpu max: '1.421e-01' diff --git a/project/algorithms/jax_image_classifier_test.py b/project/algorithms/jax_image_classifier_test.py index a1a2ab75..9c0ebe07 100644 --- a/project/algorithms/jax_image_classifier_test.py +++ b/project/algorithms/jax_image_classifier_test.py @@ -9,15 +9,15 @@ from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) -from project.utils.testutils import IN_SELF_HOSTED_GITHUB_CI, run_for_all_configs_of_type +from project.utils.testutils import run_for_all_configs_of_type from .testsuites.lightning_module_tests import LightningModuleTests -@pytest.mark.xfail( - IN_SELF_HOSTED_GITHUB_CI, - reason="TODO: Test appears to be flaky only when run on the self-hosted runner?.", -) +# @pytest.mark.xfail( +# IN_SELF_HOSTED_GITHUB_CI, +# reason="TODO: Test appears to be flaky only when run on the self-hosted runner?.", +# ) @fails_on_macOS_in_CI @run_for_all_configs_of_type("algorithm", JaxImageClassifier) @run_for_all_configs_of_type("algorithm/network", flax.linen.Module)