Skip to content

Commit

Permalink
Update test_loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed May 15, 2024
1 parent 6f08450 commit 984eff0
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions source/tests/pt/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def setUp(self):
atom_pref = rng.random(size=[batch_size, nloc * 3])
drdq = rng.random(size=[batch_size, nloc * 2 * 3])
atom_ener_coeff = rng.random(size=[batch_size, nloc])
# placeholders
l_force_real = l_force
l_force_mag = l_force
p_force_real = p_force
p_force_mag = p_force
else:
# data
np_batch, pt_batch = get_batch(
Expand Down Expand Up @@ -129,6 +134,9 @@ def setUp(self):
drdq = rng.random(size=[batch_size, nloc * 2 * 3])
atom_ener_coeff = rng.random(size=[batch_size, nloc])
self.nloc_tf = nloc
natoms = natoms_tf
l_force = l_force_merge_tf
p_force = p_force_merge_tf

# tf
self.g = tf.Graph()
Expand Down Expand Up @@ -181,13 +189,13 @@ def setUp(self):

self.feed_dict = {
t_cur_lr: self.cur_lr,
t_natoms: natoms if not self.spin else natoms_tf,
t_natoms: natoms,
t_penergy: p_energy,
t_pforce: p_force if not self.spin else p_force_merge_tf,
t_pforce: p_force,
t_pvirial: p_virial.reshape(-1, 9),
t_patom_energy: p_atom_energy,
t_lenergy: l_energy,
t_lforce: l_force if not self.spin else l_force_merge_tf,
t_lforce: l_force,
t_lvirial: l_virial.reshape(-1, 9),
t_latom_energy: l_atom_energy,
t_atom_pref: atom_pref,
Expand Down

0 comments on commit 984eff0

Please sign in to comment.