diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index d5b46b6d4b..4a748eebdc 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -147,9 +147,9 @@ def test_forward(self) -> None: aparam = None fparam = None if self.module.get_dim_aparam() > 0: - aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) + aparam = rng.random([1, natoms, self.module.get_dim_aparam()]).repeat(nf, 0) if self.module.get_dim_fparam() > 0: - fparam = rng.random([nf, self.module.get_dim_fparam()]) + fparam = rng.random([1, self.module.get_dim_fparam()]).repeat(nf, 0) ret = [] ret_lower = [] for module in self.modules_to_test: