Thanks for building such an amazing package and maintaining it.
When I call `torchinfo.summary()` on a model whose forward pass returns a `torch.distributions.Categorical`, torchinfo fails to complete the summary and raises an error. The cause appears to be that `Categorical` is a distribution object rather than a tensor, so it lacks the tensor-like properties (such as a shape) that torchinfo expects when it records each layer's output.
This affects anyone who returns a `Categorical` from the forward pass of a probabilistic model, and it means torchinfo cannot summarize such architectures.
Reproducible example:
```python
import torch
import torch.nn as nn
import torchinfo
from torch.distributions import Categorical

# Define a simple model with Categorical
class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ProbabilisticModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        logits = self.final(self.linear(x))
        return Categorical(logits=logits)

# Initialize the model and input
model = ProbabilisticModel(input_dim=10, output_dim=5)
input_data = torch.randn(1, 10)

# Attempt to run torchinfo summary
try:
    torchinfo.summary(model, input_size=(1, 10))
except Exception as e:
    print(f"Encountered an error: {e}")
```
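If `forward()` returns `logits` directly instead of wrapping them in `Categorical`, the summary completes without errors. Resolving this issue would let torchinfo report model information for models that return distributions as well.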
I don't have a concrete solution in mind, but one option might be a fallback for when torchinfo cannot extract an output shape: check for a user-defined attribute such as `self._output_shape`, which authors of such custom models could set themselves.
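To make the proposal more concrete, here is a rough sketch of such a fallback, written outside torchinfo. It assumes a hypothetical helper `output_shape()` and a hypothetical `_output_shape` convention (with `-1` standing in for the batch dimension); neither exists in torchinfo today, and the sketch does not reflect torchinfo's internals. It uses standard PyTorch forward hooks to observe each module's output.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


def output_shape(module: nn.Module, output) -> list:
    """Hypothetical fallback for shape extraction (not part of torchinfo's API)."""
    if isinstance(output, torch.Tensor):
        return list(output.shape)
    # Fall back to a shape declared by the model author, as proposed above.
    declared = getattr(module, "_output_shape", None)
    if declared is not None:
        return list(declared)
    raise TypeError(f"unsupported output type: {type(output).__name__}")


class ProbabilisticModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.final = nn.LogSoftmax(dim=-1)
        # Hypothetical convention: declared shape of the Categorical output,
        # with -1 marking the batch dimension.
        self._output_shape = [-1, output_dim]

    def forward(self, x):
        return Categorical(logits=self.final(self.linear(x)))


shapes = {}


def record_shape(module, inputs, output):
    # Forward hook: runs after module.forward(); returning None keeps the output unchanged.
    shapes[type(module).__name__] = output_shape(module, output)


model = ProbabilisticModel(input_dim=10, output_dim=5)
for m in model.modules():
    m.register_forward_hook(record_shape)

model(torch.randn(1, 10))
print(shapes)
```

Modules without `_output_shape` that return a non-tensor still raise a `TypeError`, so the fallback is opt-in per model.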