Skip to content

Commit

Permalink
tf.test.TestCase.test_session -> tf.test.TestCase.cached_session (#…
Browse files Browse the repository at this point in the history
…2816)

`tf.test.TestCase.test_session` is deprecated in TF 1.11. We used it
when we still tested TF 1.8, and now it is ok to replace it.
  • Loading branch information
njzjz authored Sep 14, 2023
1 parent 58dd3e2 commit 7da9aaf
Show file tree
Hide file tree
Showing 50 changed files with 90 additions and 90 deletions.
2 changes: 1 addition & 1 deletion source/tests/test_activation_fn_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class TestGelu(tf.test.TestCase):
def setUp(self):
self.places = 6
self.sess = self.test_session().__enter__()
self.sess = self.cached_session().__enter__()
self.inputs = tf.reshape(
tf.constant([0.0, 1.0, 2.0, 3.0], dtype=tf.float64), [-1, 1]
)
Expand Down
6 changes: 3 additions & 3 deletions source/tests/test_data_large_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_data_mixed_type(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)
# print(sess.run(model.type_embedding))
Expand Down Expand Up @@ -376,7 +376,7 @@ def test_stripped_data_mixed_type(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)
# print(sess.run(model.type_embedding))
Expand Down Expand Up @@ -572,7 +572,7 @@ def test_compressible_data_mixed_type(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)
# print(sess.run(model.type_embedding))
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_data_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _setUp(self):
model.build(data)

# freeze the graph
with self.test_session() as sess:
with self.cached_session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
graph = tf.get_default_graph()
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_data_modifier_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _setUp(self):
model.build(data)

# freeze the graph
with self.test_session() as sess:
with self.cached_session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
graph = tf.get_default_graph()
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_descrpt_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_descriptor_hybrid(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout], feed_dict=feed_dict_test)

Expand Down
10 changes: 5 additions & 5 deletions source/tests/test_descrpt_nonsmth.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class TestNonSmooth(Inter, tf.test.TestCase):
def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data, sess=self.test_session().__enter__())
Inter.setUp(self, data, sess=self.cached_session().__enter__())

def test_force(self):
force_test(self, self, suffix="_se")
Expand All @@ -180,8 +180,8 @@ def test_pbc(self):
data = Data()
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data, pbc=True, sess=self.test_session().__enter__())
inter1.setUp(data, pbc=False, sess=self.test_session().__enter__())
inter0.setUp(data, pbc=True, sess=self.cached_session().__enter__())
inter1.setUp(data, pbc=False, sess=self.cached_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down Expand Up @@ -233,8 +233,8 @@ def test_pbc_small_box(self):
data1 = Data(box_scale=2)
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data0, pbc=True, sess=self.test_session().__enter__())
inter1.setUp(data1, pbc=False, sess=self.test_session().__enter__())
inter0.setUp(data0, pbc=True, sess=self.cached_session().__enter__())
inter1.setUp(data1, pbc=False, sess=self.cached_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_descrpt_se_a_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_descriptor_se_a_mask(self):
t_aparam: test_data["aparam"][:numb_test, :],
is_training: False,
}
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[op_dout] = sess.run([dout], feed_dict=feed_dict_test)
op_dout = op_dout.reshape([-1])
Expand Down
4 changes: 2 additions & 2 deletions source/tests/test_descrpt_se_a_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_descriptor_two_sides(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout], feed_dict=feed_dict_test)
model_dout = model_dout.reshape([-1])
Expand Down Expand Up @@ -284,7 +284,7 @@ def test_descriptor_one_side(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout], feed_dict=feed_dict_test)
model_dout = model_dout.reshape([-1])
Expand Down
8 changes: 4 additions & 4 deletions source/tests/test_descrpt_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_descriptor_two_sides(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout], feed_dict=feed_dict_test)
model_dout = model_dout.reshape([-1])
Expand Down Expand Up @@ -318,7 +318,7 @@ def test_descriptor_one_side(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout], feed_dict=feed_dict_test)
model_dout = model_dout.reshape([-1])
Expand Down Expand Up @@ -488,7 +488,7 @@ def test_stripped_type_embedding_descriptor_two_sides(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout], feed_dict=feed_dict_test)
model_dout = model_dout.reshape([-1])
Expand Down Expand Up @@ -666,7 +666,7 @@ def test_compressible_descriptor_two_sides(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[model_dout] = sess.run([dout], feed_dict=feed_dict_test)
model_dout = model_dout.reshape([-1])
Expand Down
10 changes: 5 additions & 5 deletions source/tests/test_descrpt_se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class TestSmooth(Inter, tf.test.TestCase):
def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data, sess=self.test_session().__enter__())
Inter.setUp(self, data, sess=self.cached_session().__enter__())

def test_force(self):
force_test(self, self, suffix="_se_r")
Expand All @@ -155,8 +155,8 @@ def test_pbc(self):
data = Data()
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data, pbc=True, sess=self.test_session().__enter__())
inter1.setUp(data, pbc=False, sess=self.test_session().__enter__())
inter0.setUp(data, pbc=True, sess=self.cached_session().__enter__())
inter1.setUp(data, pbc=False, sess=self.cached_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down Expand Up @@ -208,8 +208,8 @@ def test_pbc_small_box(self):
data1 = Data(box_scale=2)
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data0, pbc=True, sess=self.test_session().__enter__())
inter1.setUp(data1, pbc=False, sess=self.test_session().__enter__())
inter0.setUp(data0, pbc=True, sess=self.cached_session().__enter__())
inter1.setUp(data1, pbc=False, sess=self.cached_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_descrpt_sea_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class TestSmooth(Inter, tf.test.TestCase):
def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data, sess=self.test_session().__enter__())
Inter.setUp(self, data, sess=self.cached_session().__enter__())

def test_force(self):
force_test(self, self, suffix="_sea_ef")
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_descrpt_sea_ef_para.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class TestSmooth(Inter, tf.test.TestCase):
def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data, sess=self.test_session().__enter__())
Inter.setUp(self, data, sess=self.cached_session().__enter__())

def test_force(self):
force_test(self, self, suffix="_sea_ef_para")
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_descrpt_sea_ef_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class TestEfRot(tf.test.TestCase):
def setUp(self):
self.sess = self.test_session().__enter__()
self.sess = self.cached_session().__enter__()
self.natoms = [5, 5, 2, 3]
self.ntypes = 2
self.sel_a = [12, 24]
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_descrpt_sea_ef_vert.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class TestSmooth(Inter, tf.test.TestCase):
def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data, sess=self.test_session().__enter__())
Inter.setUp(self, data, sess=self.cached_session().__enter__())

def test_force(self):
force_test(self, self, suffix="_sea_ef_vert")
Expand Down
10 changes: 5 additions & 5 deletions source/tests/test_descrpt_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class TestSmooth(Inter, tf.test.TestCase):
def setUp(self):
self.places = 5
data = Data()
Inter.setUp(self, data, sess=self.test_session().__enter__())
Inter.setUp(self, data, sess=self.cached_session().__enter__())

def test_force(self):
force_test(self, self, suffix="_smth")
Expand All @@ -173,8 +173,8 @@ def test_pbc(self):
data = Data()
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data, pbc=True, sess=self.test_session().__enter__())
inter1.setUp(data, pbc=False, sess=self.test_session().__enter__())
inter0.setUp(data, pbc=True, sess=self.cached_session().__enter__())
inter1.setUp(data, pbc=False, sess=self.cached_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down Expand Up @@ -226,8 +226,8 @@ def test_pbc_small_box(self):
data1 = Data(box_scale=2)
inter0 = Inter()
inter1 = Inter()
inter0.setUp(data0, pbc=True, sess=self.test_session().__enter__())
inter1.setUp(data1, pbc=False, sess=self.test_session().__enter__())
inter0.setUp(data0, pbc=True, sess=self.cached_session().__enter__())
inter1.setUp(data1, pbc=False, sess=self.cached_session().__enter__())
inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt))
inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt))

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_dipole_se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_model(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[p, gp] = sess.run([dipole, gdipole], feed_dict=feed_dict_test)

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_dipole_se_a_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_model(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[p, gp] = sess.run([dipole, gdipole], feed_dict=feed_dict_test)

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_embedding_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class Inter(tf.test.TestCase):
def setUp(self):
self.sess = self.test_session().__enter__()
self.sess = self.cached_session().__enter__()
self.inputs = tf.constant([0.0, 1.0, 2.0], dtype=tf.float64)
self.ndata = 3
self.inputs = tf.reshape(self.inputs, [-1, 1])
Expand Down
6 changes: 3 additions & 3 deletions source/tests/test_ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def setUp(self):
def test_py_interface(self):
hh = 1e-4
places = 4
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
t_energy, t_force, t_virial = op_module.ewald_recp(
self.coord,
self.charge,
Expand All @@ -91,7 +91,7 @@ def test_py_interface(self):
def test_force(self):
hh = 1e-4
places = 6
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
t_energy, t_force, t_virial = op_module.ewald_recp(
self.coord,
self.charge,
Expand Down Expand Up @@ -144,7 +144,7 @@ def test_force(self):
def test_virial(self):
hh = 1e-4
places = 6
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
t_energy, t_force, t_virial = op_module.ewald_recp(
self.coord,
self.charge,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_fitting_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_fitting(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[pred_atom_dos] = sess.run([atom_dos], feed_dict=feed_dict_test)

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_fitting_ener_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_fitting(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[pred_atom_ener] = sess.run([atom_ener], feed_dict=feed_dict_test)

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_layer_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_model(self):
is_training: False,
}

with self.test_session() as sess:
with self.cached_session() as sess:
sess.run(tf.global_variables_initializer())
[e1, f1, v1, e2, f2, v2] = sess.run(
[e_energy1, e_force1, e_virial1, e_energy2, e_force2, e_virial2],
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_linear_ener_model(self):
t_mesh: test_data["default_mesh"],
is_training: False,
}
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)
e = np.reshape(e, [1, -1])
Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_model_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_model(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[pred_dos, pred_atom_dos] = sess.run([dos, atom_dos], feed_dict=feed_dict_test)

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_model_loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_model(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)

Expand Down
2 changes: 1 addition & 1 deletion source/tests/test_model_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_model(self):
t_mesh: test_data["default_mesh"],
is_training: False,
}
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()

# test water energy
sess.run(tf.global_variables_initializer())
Expand Down
6 changes: 3 additions & 3 deletions source/tests/test_model_se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_model_atom_ener(self):
t_mesh: test_data["default_mesh"],
is_training: False,
}
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)
self.assertAlmostEqual(e[0], set_atom_ener[0], places=10)
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_model(self):
is_training: False,
}

sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)

Expand Down Expand Up @@ -347,7 +347,7 @@ def test_model_atom_ener_type_embedding(self):
t_mesh: test_data["default_mesh"],
is_training: False,
}
sess = self.test_session().__enter__()
sess = self.cached_session().__enter__()
sess.run(tf.global_variables_initializer())
[e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test)
self.assertAlmostEqual(e[0], set_atom_ener[0], places=10)
Expand Down
Loading

0 comments on commit 7da9aaf

Please sign in to comment.