diff --git a/source/tests/pt/test_loss.py b/source/tests/pt/test_loss.py index 0b1358150f..72ea961c37 100644 --- a/source/tests/pt/test_loss.py +++ b/source/tests/pt/test_loss.py @@ -88,6 +88,11 @@ def setUp(self): atom_pref = rng.random(size=[batch_size, nloc * 3]) drdq = rng.random(size=[batch_size, nloc * 2 * 3]) atom_ener_coeff = rng.random(size=[batch_size, nloc]) + # placeholders + l_force_real = l_force + l_force_mag = l_force + p_force_real = p_force + p_force_mag = p_force else: # data np_batch, pt_batch = get_batch( @@ -129,6 +134,9 @@ def setUp(self): drdq = rng.random(size=[batch_size, nloc * 2 * 3]) atom_ener_coeff = rng.random(size=[batch_size, nloc]) self.nloc_tf = nloc + natoms = natoms_tf + l_force = l_force_merge_tf + p_force = p_force_merge_tf # tf self.g = tf.Graph() @@ -181,13 +189,13 @@ def setUp(self): self.feed_dict = { t_cur_lr: self.cur_lr, - t_natoms: natoms if not self.spin else natoms_tf, + t_natoms: natoms, t_penergy: p_energy, - t_pforce: p_force if not self.spin else p_force_merge_tf, + t_pforce: p_force, t_pvirial: p_virial.reshape(-1, 9), t_patom_energy: p_atom_energy, t_lenergy: l_energy, - t_lforce: l_force if not self.spin else l_force_merge_tf, + t_lforce: l_force, t_lvirial: l_virial.reshape(-1, 9), t_latom_energy: l_atom_energy, t_atom_pref: atom_pref,