Skip to content

Commit

Permalink
change subclass check for AbstractModelContainer
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram committed Oct 20, 2023
1 parent 548b92e commit 39a5dc1
Showing 1 changed file with 48 additions and 24 deletions.
72 changes: 48 additions & 24 deletions src/stpipe/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,35 @@ class without requiring that they inherit this class.
def __subclasshook__(cls, C):
if cls is not AbstractModelContainer:
return False
mro = C.__mro__

# must provide "crds_observatory"
if not any(hasattr(CC, "crds_observatory") for CC in mro):
if not hasattr(C, "crds_observatory"):
return False

# and be iterable
if not any(hasattr(CC, "__iter__") for CC in mro):
if not hasattr(C, "__iter__"):
return False

# and implement "save"
if not any(hasattr(CC, "save") for CC in mro):
return False

# and that "save" is a function
save_function = getattr(C, "save")
if not callable(save_function):
return False

# and that "save" accepts the required arguments/parameter
save_signature = inspect.signature(C.save)
target_signature = inspect.signature(cls.save)
for parameter in target_signature.parameters:
if parameter == "self":
pass
# check that this parameter/argument exists in the save function
if parameter not in save_signature.parameters:
return False
save_parameter = save_signature.parameters[parameter]
# make sure the required parameter is not positional only
if save_parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
# and implements the functions defined below
for function_name in ["save", "read_asn", "from_asn"]:
function = getattr(C, function_name, None)
if function is None or not callable(function):
return False

# check that functionaccepts the required arguments/parameter
signature = inspect.signature(function)
target_signature = inspect.signature(getattr(cls, function_name))
for parameter_name in target_signature.parameters:
if parameter_name == "self":
pass
# check that this required parameter/argument exists in the function signature
if parameter_name not in signature.parameters:
return False
parameter = signature.parameters[parameter_name]
# make sure the required parameter is not positional only
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
return False

return True

@property
Expand Down Expand Up @@ -83,3 +79,31 @@ def save(self, path, save_model_func):
List of output file paths of where the models were saved.
"""
pass

@staticmethod
@abc.abstractmethod
def read_asn(filepath):
"""
Load from an association file.
Parameters
----------
filepath : str
The path to an association file.
"""
pass

@abc.abstractmethod
def from_asn(self, asn_data, asn_file_path=None):
"""
Load files from an association file.
Parameters
----------
asn_data: dictionary
An association dictionary.
asn_file_path : str or None
Filepath of the association, if known.
"""
pass

0 comments on commit 39a5dc1

Please sign in to comment.