Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change method to get module name #335

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

sup3rgiu
Copy link

Right now, torchinfo is using module.__class__.__name__ to retrieve the nn.Module name which will be shown in the summary. However, every PyTorch module exposes a method _get_name(), which in the default implementation simply returns self.__class__.__name__. However, a custom layer could overwrite this method to return a custom module name, avoiding to directly overwrite self.__class__.__name__ which is not a good idea in general.

Demo:

class CustomLayer(torch.nn.Module):
    def __init__(
        self,
    ) -> None:
        super().__init__()
        #self.__class__.__name__ = self._get_name()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input
    
    def _get_name(self) -> str:
        return "CustomFancyName"

img = torch.randn(1, 3, 224, 224)
model = CustomLayer()
summary(model, input_size=[img.shape], dtypes=[torch.float32], depth=2)

Output:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
CustomFancyName                         [1, 3, 224, 224]          --
==========================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.60
==========================================================================================

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant