diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 5d1ad9ad47..5e9ea01602 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -2,6 +2,7 @@ import unittest from typing import ( Any, + Optional, Tuple, ) @@ -111,6 +112,15 @@ def data(self) -> dict: "seed": 1145141919810, } + def is_meaningless_zero_attention_layer_tests( + self, + attn_layer: int, + attn_dotr: bool, + normalize: bool, + temperature: Optional[float], + ) -> bool: + return attn_layer == 0 and (attn_dotr or normalize or temperature is not None) + @property def skip_pt(self) -> bool: ( @@ -133,7 +143,12 @@ def skip_pt(self) -> bool: precision, use_econf_tebd, ) = self.param - return CommonTest.skip_pt + return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) @property def skip_dp(self) -> bool: @@ -157,7 +172,12 @@ def skip_dp(self) -> bool: precision, use_econf_tebd, ) = self.param - return CommonTest.skip_pt + return CommonTest.skip_pt or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) @property def skip_tf(self) -> bool: @@ -183,12 +203,21 @@ def skip_tf(self) -> bool: ) = self.param # TODO (excluded_types != [] and attn_layer > 0) need fix return ( - env_protection != 0.0 - or smooth_type_embedding - or not normalize - or temperature != 1.0 - or (excluded_types != [] and attn_layer > 0) - or (type_one_side and tebd_input_mode == "strip") # not consistent yet + CommonTest.skip_tf + or ( + env_protection != 0.0 + or smooth_type_embedding + or not normalize + or temperature != 1.0 + or (excluded_types != [] and attn_layer > 0) + or (type_one_side and tebd_input_mode == "strip") # not consistent yet + ) + or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) ) tf_class = DescrptDPA1TF