diff --git a/source/tests/test_adjust_sel.py b/source/tests/test_adjust_sel.py
index b1cbdc5afc..9bed3606fd 100644
--- a/source/tests/test_adjust_sel.py
+++ b/source/tests/test_adjust_sel.py
@@ -82,12 +82,10 @@ def _init_models():
     return INPUT, frozen_model, decreased_model, increased_model


-INPUT, FROZEN_MODEL, DECREASED_MODEL, INCREASED_MODEL = _init_models()
-
-
 class TestDeepPotAAdjustSel(unittest.TestCase):
     @classmethod
     def setUpClass(self):
+        INPUT, FROZEN_MODEL, DECREASED_MODEL, INCREASED_MODEL = _init_models()
         self.dp_original = DeepPot(FROZEN_MODEL)
         self.dp_decreased = DeepPot(DECREASED_MODEL)
         self.dp_increased = DeepPot(INCREASED_MODEL)
diff --git a/source/tests/test_finetune_se_atten.py b/source/tests/test_finetune_se_atten.py
index 3614fcb13a..47fedcf685 100644
--- a/source/tests/test_finetune_se_atten.py
+++ b/source/tests/test_finetune_se_atten.py
@@ -147,67 +147,77 @@ def _init_models(setup_model, i):
     )


-if not parse_version(tf.__version__) < parse_version("1.15"):
-
-    def previous_se_atten(jdata):
-        jdata["model"]["descriptor"]["stripped_type_embedding"] = False
-        jdata["model"]["descriptor"]["attn_layer"] = 2
-
-    def stripped_model(jdata):
-        jdata["model"]["descriptor"]["stripped_type_embedding"] = True
-        jdata["model"]["descriptor"]["attn_layer"] = 2
-
-    def compressible_model(jdata):
-        jdata["model"]["descriptor"]["stripped_type_embedding"] = True
-        jdata["model"]["descriptor"]["attn_layer"] = 0
-
-    models = [previous_se_atten, stripped_model, compressible_model]
-    INPUT_PRES = []
-    INPUT_FINETUNES = []
-    INPUT_FINETUNE_MIXS = []
-    PRE_MODELS = []
-    FINETUNED_MODELS = []
-    FINETUNED_MODEL_MIXS = []
-    PRE_MAPS = []
-    FINETUNED_MAPS = []
-    VALID_DATAS = []
-    for i, model in enumerate(models):
-        (
-            INPUT_PRE,
-            INPUT_FINETUNE,
-            INPUT_FINETUNE_MIX,
-            PRE_MODEL,
-            FINETUNED_MODEL,
-            FINETUNED_MODEL_MIX,
-            PRE_MAP,
-            FINETUNED_MAP,
-            VALID_DATA,
-        ) = _init_models(model, i)
-        INPUT_PRES.append(INPUT_PRE)
-        INPUT_FINETUNES.append(INPUT_FINETUNE)
-        INPUT_FINETUNE_MIXS.append(INPUT_FINETUNE_MIX)
-        PRE_MODELS.append(PRE_MODEL)
-        FINETUNED_MODELS.append(FINETUNED_MODEL)
-        FINETUNED_MODEL_MIXS.append(FINETUNED_MODEL_MIX)
-        PRE_MAPS.append(PRE_MAP)
-        FINETUNED_MAPS.append(FINETUNED_MAP)
-        VALID_DATAS.append(VALID_DATA)
-
-
 @unittest.skipIf(
     parse_version(tf.__version__) < parse_version("1.15"),
     f"The current tf version {tf.__version__} is too low to run the new testing model.",
 )
 class TestFinetuneSeAtten(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        if not parse_version(tf.__version__) < parse_version("1.15"):
+
+            def previous_se_atten(jdata):
+                jdata["model"]["descriptor"]["stripped_type_embedding"] = False
+                jdata["model"]["descriptor"]["attn_layer"] = 2
+
+            def stripped_model(jdata):
+                jdata["model"]["descriptor"]["stripped_type_embedding"] = True
+                jdata["model"]["descriptor"]["attn_layer"] = 2
+
+            def compressible_model(jdata):
+                jdata["model"]["descriptor"]["stripped_type_embedding"] = True
+                jdata["model"]["descriptor"]["attn_layer"] = 0
+
+            models = [previous_se_atten, stripped_model, compressible_model]
+            INPUT_PRES = []
+            INPUT_FINETUNES = []
+            INPUT_FINETUNE_MIXS = []
+            PRE_MODELS = []
+            FINETUNED_MODELS = []
+            FINETUNED_MODEL_MIXS = []
+            PRE_MAPS = []
+            FINETUNED_MAPS = []
+            VALID_DATAS = []
+            for i, model in enumerate(models):
+                (
+                    INPUT_PRE,
+                    INPUT_FINETUNE,
+                    INPUT_FINETUNE_MIX,
+                    PRE_MODEL,
+                    FINETUNED_MODEL,
+                    FINETUNED_MODEL_MIX,
+                    PRE_MAP,
+                    FINETUNED_MAP,
+                    VALID_DATA,
+                ) = _init_models(model, i)
+                INPUT_PRES.append(INPUT_PRE)
+                INPUT_FINETUNES.append(INPUT_FINETUNE)
+                INPUT_FINETUNE_MIXS.append(INPUT_FINETUNE_MIX)
+                PRE_MODELS.append(PRE_MODEL)
+                FINETUNED_MODELS.append(FINETUNED_MODEL)
+                FINETUNED_MODEL_MIXS.append(FINETUNED_MODEL_MIX)
+                PRE_MAPS.append(PRE_MAP)
+                FINETUNED_MAPS.append(FINETUNED_MAP)
+                VALID_DATAS.append(VALID_DATA)
+            cls.INPUT_PRES = INPUT_PRES
+            cls.INPUT_FINETUNES = INPUT_FINETUNES
+            cls.INPUT_FINETUNE_MIXS = INPUT_FINETUNE_MIXS
+            cls.PRE_MODELS = PRE_MODELS
+            cls.FINETUNED_MODELS = FINETUNED_MODELS
+            cls.FINETUNED_MODEL_MIXS = FINETUNED_MODEL_MIXS
+            cls.PRE_MAPS = PRE_MAPS
+            cls.FINETUNED_MAPS = FINETUNED_MAPS
+            cls.VALID_DATAS = VALID_DATAS
+
     @classmethod
     def tearDownClass(self):
-        for i in range(len(INPUT_PRES)):
-            _file_delete(INPUT_PRES[i])
-            _file_delete(INPUT_FINETUNES[i])
-            _file_delete(INPUT_FINETUNE_MIXS[i])
-            _file_delete(PRE_MODELS[i])
-            _file_delete(FINETUNED_MODELS[i])
-            _file_delete(FINETUNED_MODEL_MIXS[i])
+        for i in range(len(self.INPUT_PRES)):
+            _file_delete(self.INPUT_PRES[i])
+            _file_delete(self.INPUT_FINETUNES[i])
+            _file_delete(self.INPUT_FINETUNE_MIXS[i])
+            _file_delete(self.PRE_MODELS[i])
+            _file_delete(self.FINETUNED_MODELS[i])
+            _file_delete(self.FINETUNED_MODEL_MIXS[i])
         _file_delete("out.json")
         _file_delete("model.ckpt.meta")
         _file_delete("model.ckpt.index")
@@ -223,22 +233,22 @@ def tearDownClass(self):
         _file_delete("lcurve.out")

     def test_finetune_standard(self):
-        for i in range(len(INPUT_PRES)):
-            self.valid_data = VALID_DATAS[i]
+        for i in range(len(self.INPUT_PRES)):
+            self.valid_data = self.VALID_DATAS[i]
             pretrained_bias = get_tensor_by_name(
-                PRE_MODELS[i], "fitting_attr/t_bias_atom_e"
+                self.PRE_MODELS[i], "fitting_attr/t_bias_atom_e"
             )
             finetuned_bias = get_tensor_by_name(
-                FINETUNED_MODELS[i], "fitting_attr/t_bias_atom_e"
+                self.FINETUNED_MODELS[i], "fitting_attr/t_bias_atom_e"
             )
-            sorter = np.argsort(PRE_MAPS[i])
+            sorter = np.argsort(self.PRE_MAPS[i])
             idx_type_map = sorter[
-                np.searchsorted(PRE_MAPS[i], FINETUNED_MAPS[i], sorter=sorter)
+                np.searchsorted(self.PRE_MAPS[i], self.FINETUNED_MAPS[i], sorter=sorter)
             ]
             test_data = self.valid_data.get_test()
             atom_nums = np.tile(np.bincount(test_data["type"][0])[idx_type_map], (4, 1))
-            dp = DeepPotential(PRE_MODELS[i])
+            dp = DeepPotential(self.PRE_MODELS[i])
             energy = dp.eval(
                 test_data["coord"], test_data["box"], test_data["type"][0]
             )[0]
@@ -250,7 +260,7 @@ def test_finetune_standard(self):
                 0
             ].reshape(-1)

-            dp_finetuned = DeepPotential(FINETUNED_MODELS[i])
+            dp_finetuned = DeepPotential(self.FINETUNED_MODELS[i])
             energy_finetuned = dp_finetuned.eval(
                 test_data["coord"], test_data["box"], test_data["type"][0]
             )[0]
@@ -266,22 +276,22 @@ def test_finetune_standard(self):
             np.testing.assert_almost_equal(finetune_results, 0.0, default_places)

     def test_finetune_mixed_type(self):
-        for i in range(len(INPUT_PRES)):
-            self.valid_data = VALID_DATAS[i]
+        for i in range(len(self.INPUT_PRES)):
+            self.valid_data = self.VALID_DATAS[i]
             pretrained_bias = get_tensor_by_name(
-                PRE_MODELS[i], "fitting_attr/t_bias_atom_e"
+                self.PRE_MODELS[i], "fitting_attr/t_bias_atom_e"
             )
             finetuned_bias_mixed_type = get_tensor_by_name(
-                FINETUNED_MODEL_MIXS[i], "fitting_attr/t_bias_atom_e"
+                self.FINETUNED_MODEL_MIXS[i], "fitting_attr/t_bias_atom_e"
             )
-            sorter = np.argsort(PRE_MAPS[i])
+            sorter = np.argsort(self.PRE_MAPS[i])
             idx_type_map = sorter[
-                np.searchsorted(PRE_MAPS[i], FINETUNED_MAPS[i], sorter=sorter)
+                np.searchsorted(self.PRE_MAPS[i], self.FINETUNED_MAPS[i], sorter=sorter)
             ]
             test_data = self.valid_data.get_test()
             atom_nums = np.tile(np.bincount(test_data["type"][0])[idx_type_map], (4, 1))
-            dp = DeepPotential(PRE_MODELS[i])
+            dp = DeepPotential(self.PRE_MODELS[i])
             energy = dp.eval(
                 test_data["coord"], test_data["box"], test_data["type"][0]
             )[0]
@@ -293,7 +303,7 @@ def test_finetune_mixed_type(self):
                 0
             ].reshape(-1)

-            dp_finetuned_mixed_type = DeepPotential(FINETUNED_MODEL_MIXS[i])
+            dp_finetuned_mixed_type = DeepPotential(self.FINETUNED_MODEL_MIXS[i])
             energy_finetuned = dp_finetuned_mixed_type.eval(
                 test_data["coord"], test_data["box"], test_data["type"][0]
             )[0]
diff --git a/source/tests/test_init_frz_model_multi.py b/source/tests/test_init_frz_model_multi.py
index e5e5733c7d..fc37d82397 100644
--- a/source/tests/test_init_frz_model_multi.py
+++ b/source/tests/test_init_frz_model_multi.py
@@ -180,20 +180,19 @@ def _init_models():
     return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch


-(
-    INPUT,
-    CKPT,
-    FROZEN_MODEL,
-    CKPT_TRAINER,
-    FRZ_TRAINER,
-    VALID_DATA,
-    STOP_BATCH,
-) = _init_models()
-
-
 class TestInitFrzModelMulti(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
+        (
+            cls.INPUT,
+            cls.CKPT,
+            cls.FROZEN_MODEL,
+            CKPT_TRAINER,
+            FRZ_TRAINER,
+            VALID_DATA,
+            STOP_BATCH,
+        ) = _init_models()
+
         cls.dp_ckpt = CKPT_TRAINER
         cls.dp_frz = FRZ_TRAINER
         cls.valid_data_dict = {"water_ener": VALID_DATA}
@@ -205,19 +204,19 @@ def setUpClass(cls):

     @classmethod
     def tearDownClass(cls):
-        _file_delete(INPUT)
-        _file_delete(FROZEN_MODEL)
+        _file_delete(cls.INPUT)
+        _file_delete(cls.FROZEN_MODEL)
         _file_delete("out.json")
         _file_delete(str(tests_path / "checkpoint"))
-        _file_delete(CKPT + ".meta")
-        _file_delete(CKPT + ".index")
-        _file_delete(CKPT + ".data-00000-of-00001")
-        _file_delete(CKPT + "-0.meta")
-        _file_delete(CKPT + "-0.index")
-        _file_delete(CKPT + "-0.data-00000-of-00001")
-        _file_delete(CKPT + "-1.meta")
-        _file_delete(CKPT + "-1.index")
-        _file_delete(CKPT + "-1.data-00000-of-00001")
+        _file_delete(cls.CKPT + ".meta")
+        _file_delete(cls.CKPT + ".index")
+        _file_delete(cls.CKPT + ".data-00000-of-00001")
+        _file_delete(cls.CKPT + "-0.meta")
+        _file_delete(cls.CKPT + "-0.index")
+        _file_delete(cls.CKPT + "-0.data-00000-of-00001")
+        _file_delete(cls.CKPT + "-1.meta")
+        _file_delete(cls.CKPT + "-1.index")
+        _file_delete(cls.CKPT + "-1.data-00000-of-00001")
         _file_delete("input_v2_compat.json")
         _file_delete("lcurve.out")
diff --git a/source/tests/test_init_frz_model_se_a.py b/source/tests/test_init_frz_model_se_a.py
index d98c2bc14f..7545e3aae9 100644
--- a/source/tests/test_init_frz_model_se_a.py
+++ b/source/tests/test_init_frz_model_se_a.py
@@ -128,20 +128,18 @@ def _init_models():
     return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch


-(
-    INPUT,
-    CKPT,
-    FROZEN_MODEL,
-    CKPT_TRAINER,
-    FRZ_TRAINER,
-    VALID_DATA,
-    STOP_BATCH,
-) = _init_models()
-
-
 class TestInitFrzModelA(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
+        (
+            cls.INPUT,
+            cls.CKPT,
+            cls.FROZEN_MODEL,
+            CKPT_TRAINER,
+            FRZ_TRAINER,
+            VALID_DATA,
+            STOP_BATCH,
+        ) = _init_models()
         cls.dp_ckpt = CKPT_TRAINER
         cls.dp_frz = FRZ_TRAINER
         cls.valid_data = VALID_DATA
@@ -149,19 +147,19 @@ def setUpClass(cls):

     @classmethod
     def tearDownClass(cls):
-        _file_delete(INPUT)
-        _file_delete(FROZEN_MODEL)
+        _file_delete(cls.INPUT)
+        _file_delete(cls.FROZEN_MODEL)
         _file_delete("out.json")
         _file_delete(str(tests_path / "checkpoint"))
-        _file_delete(CKPT + ".meta")
-        _file_delete(CKPT + ".index")
-        _file_delete(CKPT + ".data-00000-of-00001")
-        _file_delete(CKPT + "-0.meta")
-        _file_delete(CKPT + "-0.index")
_file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_se_a_tebd.py b/source/tests/test_init_frz_model_se_a_tebd.py index 594bf83085..1b282c00d5 100644 --- a/source/tests/test_init_frz_model_se_a_tebd.py +++ b/source/tests/test_init_frz_model_se_a_tebd.py @@ -129,20 +129,19 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelA(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() + cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -150,19 +149,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_se_a_type.py b/source/tests/test_init_frz_model_se_a_type.py index 3221245065..b356dbf6d0 100644 --- a/source/tests/test_init_frz_model_se_a_type.py +++ b/source/tests/test_init_frz_model_se_a_type.py @@ -132,20 +132,18 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelAType(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -153,19 +151,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - 
_file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_se_atten.py b/source/tests/test_init_frz_model_se_atten.py index 5554ae415c..85a3e8457c 100644 --- a/source/tests/test_init_frz_model_se_atten.py +++ b/source/tests/test_init_frz_model_se_atten.py @@ -146,32 +146,6 @@ def compressible_model(jdata): jdata["model"]["descriptor"]["stripped_type_embedding"] = True jdata["model"]["descriptor"]["attn_layer"] = 0 - models = [previous_se_atten, stripped_model, compressible_model] - INPUTS = [] - CKPTS = [] - FROZEN_MODELS = [] - CKPT_TRAINERS = [] - FRZ_TRAINERS = [] - VALID_DATAS = [] - STOP_BATCHS = [] - for i, model in enumerate(models): - ( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, - ) = _init_models(model, i) - INPUTS.append(INPUT) - CKPTS.append(CKPT) - FROZEN_MODELS.append(FROZEN_MODEL) - CKPT_TRAINERS.append(CKPT_TRAINER) - FRZ_TRAINERS.append(FRZ_TRAINER) - VALID_DATAS.append(VALID_DATA) - STOP_BATCHS.append(STOP_BATCH) - @unittest.skipIf( parse_version(tf.__version__) < parse_version("1.15"), @@ -180,6 +154,38 @@ def compressible_model(jdata): class TestInitFrzModelAtten(unittest.TestCase): @classmethod def setUpClass(cls): + models = [previous_se_atten, stripped_model, compressible_model] + INPUTS = [] + CKPTS = [] + FROZEN_MODELS = [] + CKPT_TRAINERS = [] + FRZ_TRAINERS = [] + VALID_DATAS = [] + STOP_BATCHS = [] + for i, model in enumerate(models): + ( + INPUT, + CKPT, + FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models(model, i) + INPUTS.append(INPUT) + CKPTS.append(CKPT) + FROZEN_MODELS.append(FROZEN_MODEL) + CKPT_TRAINERS.append(CKPT_TRAINER) + FRZ_TRAINERS.append(FRZ_TRAINER) + VALID_DATAS.append(VALID_DATA) + STOP_BATCHS.append(STOP_BATCH) + cls.INPUTS = INPUTS + cls.CKPTS = CKPTS + cls.FROZEN_MODELS = FROZEN_MODELS + cls.CKPT_TRAINERS = CKPT_TRAINERS + cls.FRZ_TRAINERS = FRZ_TRAINERS + cls.VALID_DATAS = VALID_DATAS + cls.STOP_BATCHS = STOP_BATCHS cls.dp_ckpts = CKPT_TRAINERS cls.dp_frzs = FRZ_TRAINERS cls.valid_datas = VALID_DATAS @@ -188,28 +194,28 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): for i in range(len(cls.dp_ckpts)): - _file_delete(INPUTS[i]) - _file_delete(FROZEN_MODELS[i]) + _file_delete(cls.INPUTS[i]) + _file_delete(cls.FROZEN_MODELS[i]) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT[i] + ".meta") - _file_delete(CKPT[i] + ".index") - _file_delete(CKPT[i] + ".data-00000-of-00001") - _file_delete(CKPT[i] + "-0.meta") - _file_delete(CKPT[i] + "-0.index") - _file_delete(CKPT[i] + "-0.data-00000-of-00001") - _file_delete(CKPT[i] + "-1.meta") - _file_delete(CKPT[i] + "-1.index") - _file_delete(CKPT[i] + "-1.data-00000-of-00001") + _file_delete(cls.CKPT[i] + ".meta") + _file_delete(cls.CKPT[i] + ".index") + _file_delete(cls.CKPT[i] + ".data-00000-of-00001") + 
_file_delete(cls.CKPT[i] + "-0.meta") + _file_delete(cls.CKPT[i] + "-0.index") + _file_delete(cls.CKPT[i] + "-0.data-00000-of-00001") + _file_delete(cls.CKPT[i] + "-1.meta") + _file_delete(cls.CKPT[i] + "-1.index") + _file_delete(cls.CKPT[i] + "-1.data-00000-of-00001") _file_delete(f"input_v2_compat{i}.json") _file_delete("lcurve.out") def test_single_frame(self): for i in range(len(self.dp_ckpts)): - self.dp_ckpt = CKPT_TRAINERS[i] - self.dp_frz = FRZ_TRAINERS[i] - self.valid_data = VALID_DATAS[i] - self.stop_batch = STOP_BATCHS[i] + self.dp_ckpt = self.CKPT_TRAINERS[i] + self.dp_frz = self.FRZ_TRAINERS[i] + self.valid_data = self.VALID_DATAS[i] + self.stop_batch = self.STOP_BATCHS[i] valid_batch = self.valid_data.get_batch() natoms = valid_batch["natoms_vec"] diff --git a/source/tests/test_init_frz_model_se_r.py b/source/tests/test_init_frz_model_se_r.py index 84d109bcfd..fd916b3fdc 100644 --- a/source/tests/test_init_frz_model_se_r.py +++ b/source/tests/test_init_frz_model_se_r.py @@ -136,20 +136,19 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelR(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() + cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -157,19 +156,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path / "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_init_frz_model_spin.py b/source/tests/test_init_frz_model_spin.py index 7aa3d514dc..b5c480c2ba 100644 --- a/source/tests/test_init_frz_model_spin.py +++ b/source/tests/test_init_frz_model_spin.py @@ -140,20 +140,19 @@ def _init_models(): return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch -( - INPUT, - CKPT, - FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, -) = _init_models() - - class TestInitFrzModelR(unittest.TestCase): @classmethod def setUpClass(cls): + ( + cls.INPUT, + cls.CKPT, + cls.FROZEN_MODEL, + CKPT_TRAINER, + FRZ_TRAINER, + VALID_DATA, + STOP_BATCH, + ) = _init_models() + cls.dp_ckpt = CKPT_TRAINER cls.dp_frz = FRZ_TRAINER cls.valid_data = VALID_DATA @@ -161,19 +160,19 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - _file_delete(INPUT) - _file_delete(FROZEN_MODEL) + _file_delete(cls.INPUT) + _file_delete(cls.FROZEN_MODEL) _file_delete("out.json") _file_delete(str(tests_path 
/ "checkpoint")) - _file_delete(CKPT + ".meta") - _file_delete(CKPT + ".index") - _file_delete(CKPT + ".data-00000-of-00001") - _file_delete(CKPT + "-0.meta") - _file_delete(CKPT + "-0.index") - _file_delete(CKPT + "-0.data-00000-of-00001") - _file_delete(CKPT + "-1.meta") - _file_delete(CKPT + "-1.index") - _file_delete(CKPT + "-1.data-00000-of-00001") + _file_delete(cls.CKPT + ".meta") + _file_delete(cls.CKPT + ".index") + _file_delete(cls.CKPT + ".data-00000-of-00001") + _file_delete(cls.CKPT + "-0.meta") + _file_delete(cls.CKPT + "-0.index") + _file_delete(cls.CKPT + "-0.data-00000-of-00001") + _file_delete(cls.CKPT + "-1.meta") + _file_delete(cls.CKPT + "-1.index") + _file_delete(cls.CKPT + "-1.data-00000-of-00001") _file_delete("input_v2_compat.json") _file_delete("lcurve.out") diff --git a/source/tests/test_model_compression_se_a_ebd_type_one_side.py b/source/tests/test_model_compression_se_a_ebd_type_one_side.py index 9ad1970e9b..741c95b26e 100644 --- a/source/tests/test_model_compression_se_a_ebd_type_one_side.py +++ b/source/tests/test_model_compression_se_a_ebd_type_one_side.py @@ -98,7 +98,6 @@ def _init_models_exclude_types(): INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() -INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() class TestDeepPotAPBC(unittest.TestCase): @@ -444,8 +443,13 @@ def test_ase(self): class TestDeepPotAPBCExcludeTypes(unittest.TestCase): @classmethod def setUpClass(self): - self.dp_original = DeepPot(FROZEN_MODEL_ET) - self.dp_compressed = DeepPot(COMPRESSED_MODEL_ET) + ( + self.INPUT_ET, + self.FROZEN_MODEL_ET, + self.COMPRESSED_MODEL_ET, + ) = _init_models_exclude_types() + self.dp_original = DeepPot(self.FROZEN_MODEL_ET) + self.dp_compressed = DeepPot(self.COMPRESSED_MODEL_ET) self.coords = np.array( [ 12.83, @@ -473,9 +477,9 @@ def setUpClass(self): @classmethod def tearDownClass(self): - _file_delete(INPUT_ET) - _file_delete(FROZEN_MODEL_ET) - _file_delete(COMPRESSED_MODEL_ET) + _file_delete(self.INPUT_ET) + _file_delete(self.FROZEN_MODEL_ET) + _file_delete(self.COMPRESSED_MODEL_ET) _file_delete("out.json") _file_delete("compress.json") _file_delete("checkpoint") diff --git a/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py b/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py index 5b6ac4e13e..bdf09cf3e8 100644 --- a/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py +++ b/source/tests/test_model_compression_se_a_type_one_side_exclude_types.py @@ -66,12 +66,11 @@ def _init_models(): return INPUT, frozen_model, compressed_model -INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() - - class TestDeepPotAPBCTypeOneSideExcludeTypes(unittest.TestCase): @classmethod def setUpClass(self): + INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() + self.dp_original = DeepPot(FROZEN_MODEL) self.dp_compressed = DeepPot(COMPRESSED_MODEL) self.coords = np.array(