From 73b4227790a9c38526f2a6d262f3f599a2339b65 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 12 Nov 2024 05:12:26 -0500 Subject: [PATCH] skip float32 tests for numpy Signed-off-by: Jinzhe Zeng --- .../tests/consistent/descriptor/test_dpa1.py | 6 +++ .../tests/consistent/descriptor/test_dpa2.py | 40 ++++++++++++++++++- .../consistent/descriptor/test_se_atten_v2.py | 6 +++ .../consistent/descriptor/test_se_e2_a.py | 6 +++ .../tests/consistent/descriptor/test_se_r.py | 6 +++ .../tests/consistent/descriptor/test_se_t.py | 17 +++++++- .../consistent/descriptor/test_se_t_tebd.py | 24 ++++++++++- .../tests/consistent/fitting/test_dipole.py | 25 +++++++++++- source/tests/consistent/fitting/test_dos.py | 26 ++++++++++++ source/tests/consistent/fitting/test_ener.py | 18 +++++++++ source/tests/consistent/fitting/test_polar.py | 25 +++++++++++- 11 files changed, 194 insertions(+), 5 deletions(-) diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 3d80e310d0..c789eec1d9 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -182,6 +182,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, temperature, @@ -238,6 +241,9 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 17c55db368..1a4d6d5ec0 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -245,6 +245,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -281,7 +284,42 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + + @property + def skip_array_api_strict(self) -> bool: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repinit_use_three_body, + repformer_update_g1_has_conv, + repformer_direct_dist, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + use_econf_tebd, + use_tebd_bias, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index f4a8119ca3..be9eaeb0d9 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -178,6 +178,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests( attn_layer, attn_dotr, @@ -238,6 +241,9 @@ def skip_array_api_strict(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return ( not INSTALLED_ARRAY_API_STRICT or self.is_meaningless_zero_attention_layer_tests( diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index 286703e21d..a1c26ef98c 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -98,6 +98,9 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -131,6 +134,9 @@ def skip_array_api_strict(self) -> bool: precision, env_protection, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeATF diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index e851106c44..aa352eba14 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -92,6 +92,9 @@ def skip_dp(self) -> bool: excluded_types, precision, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not type_one_side or CommonTest.skip_dp @property @@ -112,6 +115,9 @@ def skip_array_api_strict(self) -> bool: excluded_types, precision, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not type_one_side or not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeRTF diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 1e6110705a..7a66873af9 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -89,6 +89,9 @@ def skip_dp(self) -> bool: precision, env_protection, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -101,7 +104,19 @@ def skip_tf(self) -> bool: ) = self.param return env_protection != 0.0 or excluded_types - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT + skip_jax = not INSTALLED_JAX tf_class = DescrptSeTTF diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 4712c28e53..2e5f7fd2ad 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -127,6 +127,9 @@ def skip_dp(self) -> bool: use_econf_tebd, use_tebd_bias, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return CommonTest.skip_dp @property @@ -147,7 +150,26 @@ def skip_tf(self) -> bool: return True skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + + @property + def skip_array_api_strict(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + excluded_types, + env_protection, + set_davg_zero, + smooth, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 088cb30238..d77beab161 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -86,6 +86,30 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT + tf_class = DipoleFittingTF dp_class = DipoleFittingDP pt_class = DipoleFittingPT @@ -93,7 +117,6 @@ def skip_pt(self) -> bool: array_api_strict_class = DipoleFittingArrayAPIStrict args = fitting_dipole() skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index 0649681ccb..c1155a4190 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -97,12 +97,38 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + numb_dos, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + @property def skip_jax(self) -> bool: return not INSTALLED_JAX @property def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + numb_dos, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True return not INSTALLED_ARRAY_API_STRICT tf_class = DOSFittingTF diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index 7be0382b16..b987d16929 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -100,6 +100,21 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + skip_jax = not INSTALLED_JAX @property @@ -112,6 +127,9 @@ def skip_array_api_strict(self) -> bool: (numb_aparam, use_aparam_as_mask), atom_ener, ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True # TypeError: The array_api_strict namespace does not support the dtype 'bfloat16' return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16" diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 12f13d1e08..ed9e48ee7a 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -86,6 +86,30 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_dp(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return CommonTest.skip_dp + + @property + def skip_array_api_strict(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + ) = self.param + if precision == "float32": + # NumPy doesn't throw errors for float64 x float32 + return True + return not INSTALLED_ARRAY_API_STRICT + tf_class = PolarFittingTF dp_class = PolarFittingDP pt_class = PolarFittingPT @@ -93,7 +117,6 @@ def skip_pt(self) -> bool: array_api_strict_class = PolarFittingArrayAPIStrict args = fitting_polar() skip_jax = not INSTALLED_JAX - skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT def setUp(self): CommonTest.setUp(self)