From f212e3d2afeb4998798dfd27885155f8f44b07f8 Mon Sep 17 00:00:00 2001 From: kaminow Date: Thu, 9 Nov 2023 16:51:01 -0500 Subject: [PATCH] First pass on unification. --- mtenn/conversion_utils/e3nn.py | 6 +++--- mtenn/conversion_utils/schnet.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py index 188473d..7f303fe 100644 --- a/mtenn/conversion_utils/e3nn.py +++ b/mtenn/conversion_utils/e3nn.py @@ -10,12 +10,12 @@ class E3NN(Network): - def __init__(self, model=None, model_kwargs=None): + def __init__(self, *args, model=None, **kwargs): ## If no model is passed, construct E3NN model with model_kwargs, ## otherwise copy all parameters and weights over if model is None: - super(E3NN, self).__init__(**model_kwargs) - self.model_parameters = model_kwargs + super(E3NN, self).__init__(*args, **kwargs) + self.model_parameters = kwargs else: # this will need changing to include model features of e3nn atomref = model.atomref.weight.detach().clone() diff --git a/mtenn/conversion_utils/schnet.py b/mtenn/conversion_utils/schnet.py index c696915..8beb476 100644 --- a/mtenn/conversion_utils/schnet.py +++ b/mtenn/conversion_utils/schnet.py @@ -10,11 +10,11 @@ class SchNet(PygSchNet): - def __init__(self, model=None): + def __init__(self, *args, model=None, **kwargs): ## If no model is passed, construct default SchNet model, otherwise copy ## all parameters and weights over if model is None: - super(SchNet, self).__init__() + super(SchNet, self).__init__(*args, **kwargs) else: try: atomref = model.atomref.weight.detach().clone()