From 22880173e64e38c5f1343bd27ade3ea401424b90 Mon Sep 17 00:00:00 2001 From: kaminow Date: Wed, 18 Oct 2023 09:14:36 -0400 Subject: [PATCH] Add abc inheritance/decorators for propper polymorphism. --- mtenn/combination.py | 4 +++- mtenn/readout.py | 3 ++- mtenn/representation.py | 3 ++- mtenn/strategy.py | 3 ++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mtenn/combination.py b/mtenn/combination.py index 38f7bbb..2cb4552 100644 --- a/mtenn/combination.py +++ b/mtenn/combination.py @@ -1,7 +1,9 @@ +import abc import torch -class Combination(torch.nn.Module): +class Combination(torch.nn.Module, abc.ABC): + @abc.abstractmethod def forward(self, pred_list, grad_dict, param_names, *model_params): """ This function signature should be the same for any Combination subclass diff --git a/mtenn/readout.py b/mtenn/readout.py index 691ba5f..41dbf2c 100644 --- a/mtenn/readout.py +++ b/mtenn/readout.py @@ -1,8 +1,9 @@ +import abc import torch from typing import Optional -class Readout(torch.nn.Module): +class Readout(torch.nn.Module, abc.ABC): pass diff --git a/mtenn/representation.py b/mtenn/representation.py index 3784381..46b84c0 100644 --- a/mtenn/representation.py +++ b/mtenn/representation.py @@ -1,5 +1,6 @@ +import abc import torch -class Representation(torch.nn.Module): +class Representation(torch.nn.Module, abc.ABC): pass diff --git a/mtenn/strategy.py b/mtenn/strategy.py index 46db842..42e4d09 100644 --- a/mtenn/strategy.py +++ b/mtenn/strategy.py @@ -1,8 +1,9 @@ +import abc from itertools import permutations import torch -class Strategy(torch.nn.Module): +class Strategy(torch.nn.Module, abc.ABC): pass