
_final_transformation of FullyCNN not set during initialization #63

Open
tztsai opened this issue Jul 18, 2023 · 2 comments

@tztsai

tztsai commented Jul 18, 2023

Passing a tensor directly to a FullyCNN instance raises AttributeError: 'FullyCNN' object has no attribute 'final_transformation'. In the FullyCNN test, the _final_transformation attribute is set manually. Perhaps that assignment should be moved into FullyCNN.__init__.
Besides, there is a Mixin class for the final transformation, but FullyCNN does not inherit from it; instead, it implements the final_transformation getter and setter itself.
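A minimal sketch of the proposed fix, in plain Python so it runs without PyTorch. `FinalTransformationMixin` and `TinyNet` are hypothetical stand-ins for the repository's Mixin class and FullyCNN, and defaulting to the identity is an assumption; the point is only that assigning the attribute in `__init__` lets a freshly built model be called directly, and that the getter/setter live in the Mixin rather than in the model class.

```python
from typing import Callable, Optional


def identity(x):
    """Default transformation: return the input unchanged."""
    return x


class FinalTransformationMixin:
    """Hypothetical stand-in for the repository's Mixin: keeps the getter and
    setter in one place so model classes don't reimplement them."""

    @property
    def final_transformation(self) -> Callable:
        return self._final_transformation

    @final_transformation.setter
    def final_transformation(self, transformation: Callable) -> None:
        self._final_transformation = transformation


class TinyNet(FinalTransformationMixin):
    """Hypothetical stand-in for FullyCNN; doubling replaces the conv layers."""

    def __init__(self, final_transformation: Optional[Callable] = None):
        # The fix: assign the attribute during construction, so calling a
        # fresh model no longer raises AttributeError.
        self.final_transformation = (
            identity if final_transformation is None else final_transformation
        )

    def __call__(self, x):
        y = [2 * v for v in x]  # placeholder for the real forward pass
        return self.final_transformation(y)
```

With this, `TinyNet()([1, 2])` returns `[2, 4]`, and `TinyNet(sum)([1, 2])` returns `6`. The transformation can still be replaced later through the setter, as the configuring code does.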

@raehik
Contributor

raehik commented Jul 19, 2023

I have a feeling the test is setting a placeholder value for _final_transformation. I think it should be set when configuring the model object. See:

```python
import importlib

# Excerpt from the configuring code: `models`, `datasets`, `criterion`,
# `model_module_name`, `model_cls_name` and `transformation_cls_name`
# are defined by the surrounding script.

# Recover the model's class, based on the corresponding CLI parameters
try:
    models_module = importlib.import_module(model_module_name)
    model_cls = getattr(models_module, model_cls_name)
except ModuleNotFoundError as e:
    raise type(e)("Could not find the specified module for : " + str(e))
except AttributeError as e:
    raise type(e)("Could not find the specified model class: " + str(e))
net = model_cls(datasets[0].n_features, criterion.n_required_channels)
try:
    # Look up the transformation class by name and attach it after construction
    transformation_cls = getattr(models.transforms, transformation_cls_name)
    transformation = transformation_cls()
    transformation.indices = criterion.precision_indices
    net.final_transformation = transformation
except AttributeError as e:
    raise type(e)("Could not find the specified transformation class: " + str(e))
```

It's a little clunky because of the dynamic module loading, which we could strip out and add back only if needed (it wasn't being used). Maybe final_transformation belongs as an __init__ parameter...?

I guess setting the identity function as the default may be sensible...? But making it a required parameter seems clearer. @tztsai, what do you think?
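A minimal sketch of the "required parameter" option, assuming a keyword-only argument; `RequiredTransformNet`, `n_in_channels` and `n_out_channels` are hypothetical names standing in for FullyCNN and its two positional arguments. Omitting the transformation then fails at construction time rather than at the first forward pass.

```python
from typing import Callable


class RequiredTransformNet:
    """Hypothetical stand-in for FullyCNN with a required final_transformation."""

    def __init__(self, n_in_channels: int, n_out_channels: int, *,
                 final_transformation: Callable):
        # Keyword-only with no default: omitting it raises TypeError when the
        # model is built, instead of AttributeError at the first forward pass.
        self.n_in_channels = n_in_channels
        self.n_out_channels = n_out_channels
        self.final_transformation = final_transformation

    def __call__(self, x):
        # Placeholder forward pass: only the final transformation is applied.
        return self.final_transformation(x)
```

In the configuring code, the post-hoc `net.final_transformation = transformation` would become a constructor argument, e.g. `model_cls(n_features, n_channels, final_transformation=transformation)`. The trade-off: an identity default is convenient in tests and interactive use, while a required argument forces every caller to state its choice.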

@tztsai
Author

tztsai commented Jul 19, 2023

I agree that adding final_transformation to the argument list would be clearer, but it needs to be a callable object, so it cannot be specified directly as a command-line argument. Perhaps we could add a map from names to callable transformations, so that the user passes a string and it is mapped to a transformation, e.g. {'identity': lambda x: x}?
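A minimal sketch of the name-to-callable map, assuming the training script parses its CLI with argparse. `TRANSFORMATIONS`, the `--final-transformation` flag, `make_transformation`, and the `double` entry are all hypothetical. The map stores factories rather than ready-made callables, so each lookup returns a fresh object and attributes set on it (as the configuring code does with `indices`) are not shared between models.

```python
import argparse
from typing import Callable, Dict, List, Optional

# Hypothetical registry: CLI name -> factory returning a new transformation.
TRANSFORMATIONS: Dict[str, Callable[[], Callable]] = {
    "identity": lambda: (lambda x: x),
    "double": lambda: (lambda x: 2 * x),  # hypothetical example entry
}


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--final-transformation",
        choices=sorted(TRANSFORMATIONS),  # argparse rejects unknown names
        default="identity",
        help="name of the transformation applied to the model output",
    )
    return parser.parse_args(argv)


def make_transformation(name: str) -> Callable:
    # Call the factory so every model gets its own transformation object.
    return TRANSFORMATIONS[name]()
```

Usage: `make_transformation(parse_args([]).final_transformation)` yields the identity; an unknown name makes argparse print the valid choices and exit with a usage error.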
