Skip to content

Commit

Permalink
Revert "skip float32 tests for numpy"
Browse files Browse the repository at this point in the history
This reverts commit 73b4227.
  • Loading branch information
njzjz committed Nov 12, 2024
1 parent 5f52bb7 commit 918284a
Show file tree
Hide file tree
Showing 11 changed files with 5 additions and 194 deletions.
6 changes: 0 additions & 6 deletions source/tests/consistent/descriptor/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,6 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
temperature,
Expand Down Expand Up @@ -241,9 +238,6 @@ def skip_array_api_strict(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return (
not INSTALLED_ARRAY_API_STRICT
or self.is_meaningless_zero_attention_layer_tests(
Expand Down
40 changes: 1 addition & 39 deletions source/tests/consistent/descriptor/test_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,6 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand Down Expand Up @@ -284,42 +281,7 @@ def skip_tf(self) -> bool:
return True

skip_jax = not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
repinit_tebd_input_mode,
repinit_set_davg_zero,
repinit_type_one_side,
repinit_use_three_body,
repformer_update_g1_has_conv,
repformer_direct_dist,
repformer_update_g1_has_drrd,
repformer_update_g1_has_grrg,
repformer_update_g1_has_attn,
repformer_update_g2_has_g1g1,
repformer_update_g2_has_attn,
repformer_update_h2,
repformer_attn2_has_gate,
repformer_update_style,
repformer_update_residual_init,
repformer_set_davg_zero,
repformer_trainable_ln,
repformer_ln_eps,
repformer_use_sqrt_nnei,
repformer_g1_out_conv,
repformer_g1_out_mlp,
smooth,
exclude_types,
precision,
add_tebd_to_repinit_out,
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptDPA2TF
dp_class = DescrptDPA2DP
Expand Down
6 changes: 0 additions & 6 deletions source/tests/consistent/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp or self.is_meaningless_zero_attention_layer_tests(
attn_layer,
attn_dotr,
Expand Down Expand Up @@ -241,9 +238,6 @@ def skip_array_api_strict(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return (
not INSTALLED_ARRAY_API_STRICT
or self.is_meaningless_zero_attention_layer_tests(
Expand Down
6 changes: 0 additions & 6 deletions source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ def skip_dp(self) -> bool:
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand Down Expand Up @@ -134,9 +131,6 @@ def skip_array_api_strict(self) -> bool:
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not type_one_side or not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeATF
Expand Down
6 changes: 0 additions & 6 deletions source/tests/consistent/descriptor/test_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def skip_dp(self) -> bool:
excluded_types,
precision,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not type_one_side or CommonTest.skip_dp

@property
Expand All @@ -115,9 +112,6 @@ def skip_array_api_strict(self) -> bool:
excluded_types,
precision,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not type_one_side or not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeRTF
Expand Down
17 changes: 1 addition & 16 deletions source/tests/consistent/descriptor/test_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ def skip_dp(self) -> bool:
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand All @@ -104,19 +101,7 @@ def skip_tf(self) -> bool:
) = self.param
return env_protection != 0.0 or excluded_types

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
excluded_types,
precision,
env_protection,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
skip_jax = not INSTALLED_JAX

tf_class = DescrptSeTTF
Expand Down
24 changes: 1 addition & 23 deletions source/tests/consistent/descriptor/test_se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,6 @@ def skip_dp(self) -> bool:
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
Expand All @@ -150,26 +147,7 @@ def skip_tf(self) -> bool:
return True

skip_jax = not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
tebd_dim,
tebd_input_mode,
resnet_dt,
excluded_types,
env_protection,
set_davg_zero,
smooth,
concat_output_tebd,
precision,
use_econf_tebd,
use_tebd_bias,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

tf_class = DescrptSeTTebdTF
dp_class = DescrptSeTTebdDP
Expand Down
25 changes: 1 addition & 24 deletions source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,14 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = DipoleFittingTF
dp_class = DipoleFittingDP
pt_class = DipoleFittingPT
jax_class = DipoleFittingJAX
array_api_strict_class = DipoleFittingArrayAPIStrict
args = fitting_dipole()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down
26 changes: 0 additions & 26 deletions source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,38 +97,12 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
def skip_jax(self) -> bool:
return not INSTALLED_JAX

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = DOSFittingTF
Expand Down
18 changes: 0 additions & 18 deletions source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,21 +100,6 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
(numb_aparam, use_aparam_as_mask),
atom_ener,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

skip_jax = not INSTALLED_JAX

@property
Expand All @@ -127,9 +112,6 @@ def skip_array_api_strict(self) -> bool:
(numb_aparam, use_aparam_as_mask),
atom_ener,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
# TypeError: The array_api_strict namespace does not support the dtype 'bfloat16'
return not INSTALLED_ARRAY_API_STRICT or precision == "bfloat16"

Expand Down
25 changes: 1 addition & 24 deletions source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,14 @@ def skip_pt(self) -> bool:
) = self.param
return CommonTest.skip_pt

@property
def skip_dp(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return CommonTest.skip_dp

@property
def skip_array_api_strict(self) -> bool:
(
resnet_dt,
precision,
mixed_types,
) = self.param
if precision == "float32":
# NumPy doesn't throw errors for float64 x float32
return True
return not INSTALLED_ARRAY_API_STRICT

tf_class = PolarFittingTF
dp_class = PolarFittingDP
pt_class = PolarFittingPT
jax_class = PolarFittingJAX
array_api_strict_class = PolarFittingArrayAPIStrict
args = fitting_polar()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

def setUp(self):
CommonTest.setUp(self)
Expand Down

0 comments on commit 918284a

Please sign in to comment.