Skip to content

Commit

Permalink
Revert "fix NoneType self.model in UTs"
Browse files Browse the repository at this point in the history
This reverts commit 8082b49.
  • Loading branch information
iProzd committed Jan 26, 2024
1 parent b7712cd commit 1f09437
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 117 deletions.
24 changes: 0 additions & 24 deletions source/tests/pt/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ def stretch_box(old_coord, old_box, new_box):


class TestForce:
def __init__(self):
self.model = None

def test(
self,
):
Expand Down Expand Up @@ -88,9 +85,6 @@ def ff(_coord):


class TestVirial:
def __init__(self):
self.model = None

def test(
self,
):
Expand Down Expand Up @@ -135,9 +129,6 @@ def ff(bb):


class TestEnergyModelSeAForce(unittest.TestCase, TestForce):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_se_e2_a)
sampled = make_sample(model_params)
Expand All @@ -146,9 +137,6 @@ def setUp(self):


class TestEnergyModelSeAVirial(unittest.TestCase, TestVirial):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_se_e2_a)
sampled = make_sample(model_params)
Expand All @@ -157,9 +145,6 @@ def setUp(self):


class TestEnergyModelDPA1Force(unittest.TestCase, TestForce):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_dpa1)
sampled = make_sample(model_params)
Expand All @@ -168,9 +153,6 @@ def setUp(self):


class TestEnergyModelDPA1Virial(unittest.TestCase, TestVirial):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_dpa1)
sampled = make_sample(model_params)
Expand All @@ -179,9 +161,6 @@ def setUp(self):


class TestEnergyModelDPA2Force(unittest.TestCase, TestForce):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand All @@ -197,9 +176,6 @@ def setUp(self):


class TestEnergyModelDPAUniVirial(unittest.TestCase, TestVirial):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand Down
22 changes: 1 addition & 21 deletions source/tests/pt/test_permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,10 @@ def make_sample(model_params):


class TestPermutation:
def __init__(self):
self.model = None

def test(
self,
):
self.model = None
natoms = 5
cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE)
cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE)
Expand Down Expand Up @@ -238,9 +236,6 @@ def test(


class TestEnergyModelSeA(unittest.TestCase, TestPermutation):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_se_e2_a)
sampled = make_sample(model_params)
Expand All @@ -249,9 +244,6 @@ def setUp(self):


class TestEnergyModelDPA1(unittest.TestCase, TestPermutation):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_dpa1)
sampled = make_sample(model_params)
Expand All @@ -260,9 +252,6 @@ def setUp(self):


class TestEnergyModelDPA2(unittest.TestCase, TestPermutation):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand All @@ -278,9 +267,6 @@ def setUp(self):


class TestForceModelDPA2(unittest.TestCase, TestPermutation):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand All @@ -299,9 +285,6 @@ def setUp(self):

@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, TestPermutation):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_hybrid)
sampled = make_sample(model_params)
Expand All @@ -311,9 +294,6 @@ def setUp(self):

@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, TestPermutation):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_hybrid)
model_params["fitting_net"]["type"] = "direct_force_ener"
Expand Down
10 changes: 1 addition & 9 deletions source/tests/pt/test_permutation_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@


class TestPermutationDenoise:
def __init__(self):
self.model = None

def test(
self,
):
self.model = None
natoms = 5
cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE)
cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE)
Expand Down Expand Up @@ -70,9 +68,6 @@ def test(


class TestDenoiseModelDPA1(unittest.TestCase, TestPermutationDenoise):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_dpa1)
sampled = make_sample(model_params)
Expand All @@ -81,9 +76,6 @@ def setUp(self):


class TestDenoiseModelDPA2(unittest.TestCase, TestPermutationDenoise):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand Down
22 changes: 1 addition & 21 deletions source/tests/pt/test_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@


class TestRot:
def __init__(self):
self.model = None

def test(
self,
):
self.model = None
prec = 1e-10
natoms = 5
cell = 10.0 * torch.eye(3, dtype=dtype).to(env.DEVICE)
Expand Down Expand Up @@ -113,9 +111,6 @@ def test(


class TestEnergyModelSeA(unittest.TestCase, TestRot):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_se_e2_a)
sampled = make_sample(model_params)
Expand All @@ -124,9 +119,6 @@ def setUp(self):


class TestEnergyModelDPA1(unittest.TestCase, TestRot):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_dpa1)
sampled = make_sample(model_params)
Expand All @@ -135,9 +127,6 @@ def setUp(self):


class TestEnergyModelDPA2(unittest.TestCase, TestRot):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand All @@ -153,9 +142,6 @@ def setUp(self):


class TestForceModelDPA2(unittest.TestCase, TestRot):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand All @@ -174,9 +160,6 @@ def setUp(self):

@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, TestRot):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_hybrid)
sampled = make_sample(model_params)
Expand All @@ -186,9 +169,6 @@ def setUp(self):

@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, TestRot):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_hybrid)
model_params["fitting_net"]["type"] = "direct_force_ener"
Expand Down
10 changes: 1 addition & 9 deletions source/tests/pt/test_rot_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@


class TestRotDenoise:
def __init__(self):
self.model = None

def test(
self,
):
self.model = None
prec = 1e-10
natoms = 5
cell = 10.0 * torch.eye(3, dtype=dtype).to(env.DEVICE)
Expand Down Expand Up @@ -101,9 +99,6 @@ def test(


class TestDenoiseModelDPA1(unittest.TestCase, TestRotDenoise):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_dpa1)
sampled = make_sample(model_params)
Expand All @@ -112,9 +107,6 @@ def setUp(self):


class TestDenoiseModelDPA2(unittest.TestCase, TestRotDenoise):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand Down
22 changes: 1 addition & 21 deletions source/tests/pt/test_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@


class TestTrans:
def __init__(self):
self.model = None

def test(
self,
):
self.model = None
natoms = 5
cell = torch.rand([3, 3], dtype=dtype).to(env.DEVICE)
cell = (cell + cell.T) + 5.0 * torch.eye(3).to(env.DEVICE)
Expand Down Expand Up @@ -69,9 +67,6 @@ def test(


class TestEnergyModelSeA(unittest.TestCase, TestTrans):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_se_e2_a)
sampled = make_sample(model_params)
Expand All @@ -80,9 +75,6 @@ def setUp(self):


class TestEnergyModelDPA1(unittest.TestCase, TestTrans):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_dpa1)
sampled = make_sample(model_params)
Expand All @@ -91,9 +83,6 @@ def setUp(self):


class TestEnergyModelDPA2(unittest.TestCase, TestTrans):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand All @@ -109,9 +98,6 @@ def setUp(self):


class TestForceModelDPA2(unittest.TestCase, TestTrans):
def __init__(self):
super().__init__()

def setUp(self):
model_params_sample = copy.deepcopy(model_dpa2)
model_params_sample["descriptor"]["rcut"] = model_params_sample["descriptor"][
Expand All @@ -130,9 +116,6 @@ def setUp(self):

@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, TestTrans):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_hybrid)
sampled = make_sample(model_params)
Expand All @@ -142,9 +125,6 @@ def setUp(self):

@unittest.skip("hybrid not supported at the moment")
class TestForceModelHybrid(unittest.TestCase, TestTrans):
def __init__(self):
super().__init__()

def setUp(self):
model_params = copy.deepcopy(model_hybrid)
model_params["fitting_net"]["type"] = "direct_force_ener"
Expand Down
Loading

0 comments on commit 1f09437

Please sign in to comment.