Skip to content

Commit

Permalink
skip float32 tests for numpy
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 12, 2024
1 parent 76b5ff6 commit 73b4227
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 5 deletions.
6 changes: 6 additions & 0 deletions source/tests/consistent/descriptor/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
temperature,
Expand Down Expand Up @@ -238,6 +241,9 @@ def skip_array_api_strict(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return (
not INSTALLED_ARRAY_API_STRICT
or self.is_meaningless_zero_attention_layer_tests(
Expand Down
40 changes: 39 additions & 1 deletion source/tests/consistent/descriptor/test_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand Down Expand Up @@ -281,7 +284,42 @@ def skip_tf(self) -> bool:
return True

skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

@property
def skip_array_api_strict(self) -> bool:
(
repinit_tebd_input_mode,
repinit_set_davg_zero,
repinit_type_one_side,
repinit_use_three_body,
repformer_update_g1_has_conv,
repformer_direct_dist,
repformer_update_g1_has_drrd,
repformer_update_g1_has_grrg,
repformer_update_g1_has_attn,
repformer_update_g2_has_g1g1,
repformer_update_g2_has_attn,
repformer_update_h2,
repformer_attn2_has_gate,
repformer_update_style,
repformer_update_residual_init,
repformer_set_davg_zero,
repformer_trainable_ln,
repformer_ln_eps,
repformer_use_sqrt_nnei,
repformer_g1_out_conv,
repformer_g1_out_mlp,
smooth,
exclude_types,
precision,
add_tebd_to_repinit_out,
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptDPA2TF
dp_class = DescrptDPA2DP
Expand Down
6 changes: 6 additions & 0 deletions source/tests/consistent/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
Expand Down Expand Up @@ -238,6 +241,9 @@ def skip_array_api_strict(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return (
not INSTALLED_ARRAY_API_STRICT
or self.is_meaningless_zero_attention_layer_tests(
Expand Down
6 changes: 6 additions & 0 deletions source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def skip_dp(self) -> bool:
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand Down Expand Up @@ -131,6 +134,9 @@ def skip_array_api_strict(self) -> bool:
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not type_one_side or not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeATF
Expand Down
6 changes: 6 additions & 0 deletions source/tests/consistent/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def skip_dp(self) -> bool:
excluded_types,
precision,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not type_one_side or CommonTest.skip_dp

@property
Expand All @@ -112,6 +115,9 @@ def skip_array_api_strict(self) -> bool:
excluded_types,
precision,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not type_one_side or not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeRTF
Expand Down
17 changes: 16 additions & 1 deletion source/tests/consistent/descriptor/test_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ def skip_dp(self) -> bool:
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand All @@ -101,7 +104,19 @@ def skip_tf(self) -> bool:
) = self.param
return env_protection != 0.0 or excluded_types

skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
excluded_types,
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

skip_jax = not INSTALLED_JAX

tf_class = DescrptSeTTF
Expand Down
24 changes: 23 additions & 1 deletion source/tests/consistent/descriptor/test_se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand All @@ -147,7 +150,26 @@ def skip_tf(self) -> bool:
return True

skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

@property
def skip_array_api_strict(self) -> bool:
(
tebd_dim,
tebd_input_mode,
resnet_dt,
excluded_types,
env_protection,
set_davg_zero,
smooth,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeTTebdTF
dp_class = DescrptSeTTebdDP
Expand Down
25 changes: 24 additions & 1 deletion source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,37 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = DipoleFittingTF
dp_class = DipoleFittingDP
pt_class = DipoleFittingPT
jax_class = DipoleFittingJAX
array_api_strict_class = DipoleFittingArrayAPIStrict
args = fitting_dipole()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down
26 changes: 26 additions & 0 deletions source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,38 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
def skip_jax(self) -> bool:
return not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = DOSFittingTF
Expand Down
18 changes: 18 additions & 0 deletions source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
(numb_aparam, use_aparam_as_mask),
atom_ener,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

skip_jax = not INSTALLED_JAX

@property
Expand All @@ -112,6 +127,9 @@ def skip_array_api_strict(self) -> bool:
(numb_aparam, use_aparam_as_mask),
atom_ener,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
# TypeError: The array_api_strict namespace does not support the dtype 'bfloat16'
return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16"

Expand Down
25 changes: 24 additions & 1 deletion source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,37 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = PolarFittingTF
dp_class = PolarFittingDP
pt_class = PolarFittingPT
jax_class = PolarFittingJAX
array_api_strict_class = PolarFittingArrayAPIStrict
args = fitting_polar()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down

0 comments on commit 73b4227

Please sign in to comment.