From 31461be4ad4f52edfb0e158e99127c2248b47910 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 5 Sep 2024 21:32:14 +0800 Subject: [PATCH] fix uts --- source/tests/pt/model/models/dpa2.json | 5 ++++- source/tests/pt/model/test_dpa2.py | 29 ++++++++++++++++++++++---- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/source/tests/pt/model/models/dpa2.json b/source/tests/pt/model/models/dpa2.json index ca1948492a..7495f5d78a 100644 --- a/source/tests/pt/model/models/dpa2.json +++ b/source/tests/pt/model/models/dpa2.json @@ -37,7 +37,10 @@ "update_g1_has_attn": true, "update_g2_has_g1g1": true, "update_g2_has_attn": true, - "attn2_has_gate": true + "attn2_has_gate": true, + "use_sqrt_nnei": false, + "g1_out_conv": false, + "g1_out_mlp": false }, "add_tebd_to_repinit_out": false }, diff --git a/source/tests/pt/model/test_dpa2.py b/source/tests/pt/model/test_dpa2.py index 6d3b6e182d..f11be532cb 100644 --- a/source/tests/pt/model/test_dpa2.py +++ b/source/tests/pt/model/test_dpa2.py @@ -62,6 +62,7 @@ def test_consistency( sm, prec, ect, + ns, ) in itertools.product( ["concat", "strip"], # repinit_tebd_input_mode [ @@ -70,8 +71,12 @@ def test_consistency( [True, False], # repformer_update_g1_has_conv [True, False], # repformer_update_g1_has_drrd [True, False], # repformer_update_g1_has_grrg - [True, False], # repformer_update_g1_has_attn - [True, False], # repformer_update_g2_has_g1g1 + [ + False, + ], # repformer_update_g1_has_attn + [ + False, + ], # repformer_update_g2_has_g1g1 [True, False], # repformer_update_g2_has_attn [ False, @@ -83,10 +88,18 @@ def test_consistency( [ True, ], # repformer_set_davg_zero - [True, False], # smooth + [ + True, + ], # smooth ["float64"], # precision [False, True], # use_econf_tebd + [ + False, + True, + ], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) ): + if ns and not rp1d and not rp1g: + continue dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) if prec == "float64": @@ -121,6 +134,9 @@ def test_consistency( attn2_has_gate=rp2gate, update_style=rus, set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, ) # dpa2 new impl @@ -174,7 +190,7 @@ def test_consistency( atol=atol, ) # old impl - if prec == "float64" and rus == "res_avg" and ect is False: + if prec == "float64" and rus == "res_avg" and ect is False and ns is False: dd3 = DescrptDPA2( self.nt, repinit=repinit, @@ -239,6 +255,7 @@ def test_jit( sm, prec, ect, + ns, ) in itertools.product( ["concat", "strip"], # repinit_tebd_input_mode [ @@ -277,6 +294,7 @@ def test_jit( ], # smooth ["float64"], # precision [False, True], # use_econf_tebd + [True], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp) ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -310,6 +328,9 @@ def test_jit( attn2_has_gate=rp2gate, update_style=rus, set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, ) # dpa2 new impl