Issue with torchinfo.summary() Failing on Models with torch.distributions.Categorical #329

Open
Naeemkh opened this issue Oct 30, 2024 · 0 comments

Thanks for building such an amazing package and maintaining it.

When torchinfo.summary() is called on a model whose forward pass returns a torch.distributions.Categorical, torchinfo fails to complete the summary and raises an error. The failure appears to stem from Categorical not exposing the tensor-like properties (such as .shape) that torchinfo relies on when recording output sizes.

This issue impacts users attempting to use Categorical within the forward pass of probabilistic models and limits torchinfo's effectiveness for such model architectures.

Reproducible example:

import torch
import torch.nn as nn
import torchinfo
from torch.distributions import Categorical

# Define a simple model with Categorical
class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ProbabilisticModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        logits = self.final(self.linear(x))
        return Categorical(logits=logits)

# Initialize the model and input
model = ProbabilisticModel(input_dim=10, output_dim=5)
input_data = torch.randn(1, 10)

# Attempt to run torchinfo summary
try:
    torchinfo.summary(model, input_size=(1, 10))
except Exception as e:
    print(f"Encountered an error: {e}")

If the model returns the logits tensor directly instead of wrapping it in Categorical, summary() completes without error.

Addressing this issue would make torchinfo usable for probabilistic models that return distribution objects.

I do not have a solution off the top of my head, but one possibility is to fall back gracefully when shape extraction fails — for example, check for a user-defined attribute such as self._output_shape on the returned object, so authors of custom models can supply the output shape themselves.
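The fallback idea could look something like the sketch below, written independently of torchinfo's internals (the function and attribute names here are hypothetical, not existing torchinfo API). It tries the usual .shape, then a user-supplied escape hatch, then the batch_shape that torch.distributions objects do expose, and returns None rather than raising:

```python
import torch
from torch.distributions import Categorical

def extract_output_shape(output):
    """Best-effort shape extraction for a forward() return value.

    Hypothetical helper illustrating the proposed fallback chain.
    """
    if hasattr(output, "shape"):          # plain tensors
        return tuple(output.shape)
    if hasattr(output, "_output_shape"):  # user-defined escape hatch
        return tuple(output._output_shape)
    if hasattr(output, "batch_shape"):    # torch.distributions objects
        return tuple(output.batch_shape)
    return None                           # unknown; report as shapeless

dist = Categorical(logits=torch.zeros(1, 5))
print(extract_output_shape(dist))  # the distribution's batch_shape: (1,)
```

Even without the user-defined attribute, falling back to batch_shape would already cover Categorical and the other torch.distributions classes.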
