From 918284a05c54bd3e7d367657c6ed19bfaa2d024b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 06:40:07 -0500 Subject: [PATCH] Revert "skip float32 tests for numpy" This reverts commit 73b4227790a9c38526f2a6d262f3f599a2339b65. --- .../tests/consistent/descriptor/test_dpa1.py | 6 --- .../tests/consistent/descriptor/test_dpa2.py | 40 +------------------ .../consistent/descriptor/test_se_atten_v2.py | 6 --- .../consistent/descriptor/test_se_e2_a.py | 6 --- .../tests/consistent/descriptor/test_se_r.py | 6 --- .../tests/consistent/descriptor/test_se_t.py | 17 +------- .../consistent/descriptor/test_se_t_tebd.py | 24 +---------- .../tests/consistent/fitting/test_dipole.py | 25 +----------- source/tests/consistent/fitting/test_dos.py | 26 ------------ source/tests/consistent/fitting/test_ener.py | 18 --------- source/tests/consistent/fitting/test_polar.py | 25 +----------- 11 files changed, 5 insertions(+), 194 deletions(-) diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index c789eec1d9..3d80e310d0 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -182,9 +182,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, temperature, @@ -241,9 +238,6 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 1a4d6d5ec0..17c55db368 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -245,9 +245,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -284,42 +281,7 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - - @property - def skip_array_api_strict(self) -> bool: - ( - repinit_tebd_input_mode, - repinit_set_davg_zero, - repinit_type_one_side, - repinit_use_three_body, - repformer_update_g1_has_conv, - repformer_direct_dist, - repformer_update_g1_has_drrd, - repformer_update_g1_has_grrg, - repformer_update_g1_has_attn, - repformer_update_g2_has_g1g1, - repformer_update_g2_has_attn, - repformer_update_h2, - repformer_attn2_has_gate, - repformer_update_style, - repformer_update_residual_init, - repformer_set_davg_zero, - repformer_trainable_ln, - repformer_ln_eps, - repformer_use_sqrt_nnei, - repformer_g1_out_conv, - repformer_g1_out_mlp, - smooth, - exclude_types, - precision, - add_tebd_to_repinit_out, - use_econf_tebd, - use_tebd_bias, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index be9eaeb0d9..f4a8119ca3 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -178,9 +178,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, attn_dotr, @@ -241,9 +238,6 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index a1c26ef98c..286703e21d 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -98,9 +98,6 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -134,9 +131,6 @@ def skip_array_api_strict(self) -> bool: precision, env_protection, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeATF diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index aa352eba14..e851106c44 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -92,9 +92,6 @@ def skip_dp(self) -> bool: excluded_types, precision, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not type_one_side or CommonTest.skip_dp @property @@ -115,9 +112,6 @@ def skip_array_api_strict(self) -> bool: excluded_types, precision, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeRTF diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 7a66873af9..1e6110705a 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -89,9 +89,6 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -104,19 +101,7 @@ def skip_tf(self) -> bool: ) = self.param return env_protection != 0.0 or excluded_types - @property - def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - excluded_types, - precision, - env_protection, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT - + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT skip_jax = not INSTALLED_JAX tf_class = DescrptSeTTF diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 2e5f7fd2ad..4712c28e53 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -127,9 +127,6 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return CommonTest.skip_dp @property @@ -150,26 +147,7 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - - @property - def skip_array_api_strict(self) -> bool: - ( - tebd_dim, - tebd_input_mode, - resnet_dt, - excluded_types, - env_protection, - set_davg_zero, - smooth, - concat_output_tebd, - precision, - use_econf_tebd, - use_tebd_bias, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index d77beab161..088cb30238 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -86,30 +86,6 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - - @property - def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT - tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT @@ -117,6 +93,7 @@ def skip_array_api_strict(self) -> bool: array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index c1155a4190..0649681ccb 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -97,38 +97,12 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - numb_fparam, - numb_aparam, - numb_dos, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - @property def skip_jax(self) -> bool: return not INSTALLED_JAX @property def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - numb_fparam, - numb_aparam, - numb_dos, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True return not INSTALLED_ARRAY_API_STRICT tf_class = DOSFittingTF diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index b987d16929..7be0382b16 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -100,21 +100,6 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - numb_fparam, - (numb_aparam, use_aparam_as_mask), - atom_ener, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - skip_jax = not INSTALLED_JAX @property @@ -127,9 +112,6 @@ def skip_array_api_strict(self) -> bool: (numb_aparam, use_aparam_as_mask), atom_ener, ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True # TypeError: The array_api_strict namespace does not support the dtype 'bfloat16' return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16" diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index ed9e48ee7a..12f13d1e08 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -86,30 +86,6 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt - @property - def skip_dp(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return CommonTest.skip_dp - - @property - def skip_array_api_strict(self) -> bool: - ( - resnet_dt, - precision, - mixed_types, - ) = self.param - if precision == "float32": - # NumPy doesn't throw errors for float64 x float32 - return True - return not INSTALLED_ARRAY_API_STRICT - tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT @@ -117,6 +93,7 @@ def skip_array_api_strict(self) -> bool: array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self)