From 984eff0d217bccf023e19a2a62107f16beab23b7 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 15 May 2024 10:39:07 +0800 Subject: [PATCH] Update test_loss.py --- source/tests/pt/test_loss.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 0b1358150f..72ea961c37 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -88,6 +88,11 @@ def setUp(self): atom_pref = rng.random(size=[batch_size, nloc * 3]) drdq = rng.random(size=[batch_size, nloc * 2 * 3]) atom_ener_coeff = rng.random(size=[batch_size, nloc]) + # placeholders + l_force_real = l_force + l_force_mag = l_force + p_force_real = p_force + p_force_mag = p_force else: # data np_batch, pt_batch = get_batch( @@ -129,6 +134,9 @@ def setUp(self): drdq = rng.random(size=[batch_size, nloc * 2 * 3]) atom_ener_coeff = rng.random(size=[batch_size, nloc]) self.nloc_tf = nloc + natoms = natoms_tf + l_force = l_force_merge_tf + p_force = p_force_merge_tf # tf self.g = tf.Graph() @@ -181,13 +189,13 @@ def setUp(self): self.feed_dict = { t_cur_lr: self.cur_lr, - t_natoms: natoms if not self.spin else natoms_tf, + t_natoms: natoms, t_penergy: p_energy, - t_pforce: p_force if not self.spin else p_force_merge_tf, + t_pforce: p_force, t_pvirial: p_virial.reshape(-1, 9), t_patom_energy: p_atom_energy, t_lenergy: l_energy, - t_lforce: l_force if not self.spin else l_force_merge_tf, + t_lforce: l_force, t_lvirial: l_virial.reshape(-1, 9), t_latom_energy: l_atom_energy, t_atom_pref: atom_pref,