Skip to content

Commit

Permalink
fix uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Sep 5, 2024
1 parent a314016 commit 31461be
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
5 changes: 4 additions & 1 deletion source/tests/pt/model/models/dpa2.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
"update_g1_has_attn": true,
"update_g2_has_g1g1": true,
"update_g2_has_attn": true,
"attn2_has_gate": true
"attn2_has_gate": true,
"use_sqrt_nnei": false,
"g1_out_conv": false,
"g1_out_mlp": false
},
"add_tebd_to_repinit_out": false
},
Expand Down
29 changes: 25 additions & 4 deletions source/tests/pt/model/test_dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_consistency(
sm,
prec,
ect,
ns,
) in itertools.product(
["concat", "strip"], # repinit_tebd_input_mode
[
Expand All @@ -70,8 +71,12 @@ def test_consistency(
[True, False], # repformer_update_g1_has_conv
[True, False], # repformer_update_g1_has_drrd
[True, False], # repformer_update_g1_has_grrg
[True, False], # repformer_update_g1_has_attn
[True, False], # repformer_update_g2_has_g1g1
[
False,
], # repformer_update_g1_has_attn
[
False,
], # repformer_update_g2_has_g1g1
[True, False], # repformer_update_g2_has_attn
[
False,
Expand All @@ -83,10 +88,18 @@ def test_consistency(
[
True,
], # repformer_set_davg_zero
[True, False], # smooth
[
True,
], # smooth
["float64"], # precision
[False, True], # use_econf_tebd
[
False,
True,
], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp)
):
if ns and not rp1d and not rp1g:
continue
dtype = PRECISION_DICT[prec]
rtol, atol = get_tols(prec)
if prec == "float64":
Expand Down Expand Up @@ -121,6 +134,9 @@ def test_consistency(
attn2_has_gate=rp2gate,
update_style=rus,
set_davg_zero=rpz,
use_sqrt_nnei=ns,
g1_out_conv=ns,
g1_out_mlp=ns,
)

# dpa2 new impl
Expand Down Expand Up @@ -174,7 +190,7 @@ def test_consistency(
atol=atol,
)
# old impl
if prec == "float64" and rus == "res_avg" and ect is False:
if prec == "float64" and rus == "res_avg" and ect is False and ns is False:
dd3 = DescrptDPA2(
self.nt,
repinit=repinit,
Expand Down Expand Up @@ -239,6 +255,7 @@ def test_jit(
sm,
prec,
ect,
ns,
) in itertools.product(
["concat", "strip"], # repinit_tebd_input_mode
[
Expand Down Expand Up @@ -277,6 +294,7 @@ def test_jit(
], # smooth
["float64"], # precision
[False, True], # use_econf_tebd
[True], # new sub-structures (use_sqrt_nnei, g1_out_conv, g1_out_mlp)
):
dtype = PRECISION_DICT[prec]
rtol, atol = get_tols(prec)
Expand Down Expand Up @@ -310,6 +328,9 @@ def test_jit(
attn2_has_gate=rp2gate,
update_style=rus,
set_davg_zero=rpz,
use_sqrt_nnei=ns,
g1_out_conv=ns,
g1_out_mlp=ns,
)

# dpa2 new impl
Expand Down

0 comments on commit 31461be

Please sign in to comment.