From 39a5dc1c1d2dad6b582dd09b6844a49e7600640e Mon Sep 17 00:00:00 2001 From: Brett Date: Fri, 20 Oct 2023 16:23:28 -0400 Subject: [PATCH] change subclass check for AbstractModelContainer --- src/stpipe/container.py | 72 +++++++++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/src/stpipe/container.py b/src/stpipe/container.py index 92852015..10269c8e 100644 --- a/src/stpipe/container.py +++ b/src/stpipe/container.py @@ -16,39 +16,35 @@ class without requiring that they inherit this class. def __subclasshook__(cls, C): if cls is not AbstractModelContainer: return False - mro = C.__mro__ # must provide "crds_observatory" - if not any(hasattr(CC, "crds_observatory") for CC in mro): + if not hasattr(C, "crds_observatory"): return False # and be iterable - if not any(hasattr(CC, "__iter__") for CC in mro): + if not hasattr(C, "__iter__"): return False - # and implement "save" - if not any(hasattr(CC, "save") for CC in mro): - return False - - # and that "save" is a function - save_function = getattr(C, "save") - if not callable(save_function): - return False - - # and that "save" accepts the required arguments/parameter - save_signature = inspect.signature(C.save) - target_signature = inspect.signature(cls.save) - for parameter in target_signature.parameters: - if parameter == "self": - pass - # check that this parameter/argument exists in the save function - if parameter not in save_signature.parameters: - return False - save_parameter = save_signature.parameters[parameter] - # make sure the required parameter is not positional only - if save_parameter.kind == inspect.Parameter.POSITIONAL_ONLY: + # and implements the functions defined below + for function_name in ["save", "read_asn", "from_asn"]: + function = getattr(C, function_name, None) + if function is None or not callable(function): return False + # check that functionaccepts the required arguments/parameter + signature = inspect.signature(function) + target_signature = inspect.signature(getattr(cls, function_name)) + for parameter_name in target_signature.parameters: + if parameter_name == "self": + pass + # check that this required parameter/argument exists in the function signature + if parameter_name not in signature.parameters: + return False + parameter = signature.parameters[parameter_name] + # make sure the required parameter is not positional only + if parameter.kind == inspect.Parameter.POSITIONAL_ONLY: + return False + return True @property @@ -83,3 +79,31 @@ def save(self, path, save_model_func): List of output file paths of where the models were saved. """ pass + + @staticmethod + @abc.abstractmethod + def read_asn(filepath): + """ + Load from an association file. + + Parameters + ---------- + filepath : str + The path to an association file. + """ + pass + + @abc.abstractmethod + def from_asn(self, asn_data, asn_file_path=None): + """ + Load files from an association file. + + Parameters + ---------- + asn_data: dictionary + An association dictionary. + + asn_file_path : str or None + Filepath of the association, if known. + """ + pass