diff --git a/mtenn/conversion_utils/e3nn.py b/mtenn/conversion_utils/e3nn.py index 188473d..7f303fe 100644 --- a/mtenn/conversion_utils/e3nn.py +++ b/mtenn/conversion_utils/e3nn.py @@ -10,12 +10,12 @@ class E3NN(Network): - def __init__(self, model=None, model_kwargs=None): + def __init__(self, *args, model=None, **kwargs): ## If no model is passed, construct E3NN model with model_kwargs, ## otherwise copy all parameters and weights over if model is None: - super(E3NN, self).__init__(**model_kwargs) - self.model_parameters = model_kwargs + super(E3NN, self).__init__(*args, **kwargs) + self.model_parameters = kwargs else: # this will need changing to include model features of e3nn atomref = model.atomref.weight.detach().clone() diff --git a/mtenn/conversion_utils/schnet.py b/mtenn/conversion_utils/schnet.py index c696915..8beb476 100644 --- a/mtenn/conversion_utils/schnet.py +++ b/mtenn/conversion_utils/schnet.py @@ -10,11 +10,11 @@ class SchNet(PygSchNet): - def __init__(self, model=None): + def __init__(self, *args, model=None, **kwargs): ## If no model is passed, construct default SchNet model, otherwise copy ## all parameters and weights over if model is None: - super(SchNet, self).__init__() + super(SchNet, self).__init__(*args, **kwargs) else: try: atomref = model.atomref.weight.detach().clone()