I ran into this error when attempting to fit a distribution with missing data. I converted the input into a masked tensor, following the instructions in the docs, and ran into trouble. The error occurs when calling `torch.sum` on line 262 of `pomegranate/distributions/normal.py`.
The full stack trace is at the bottom.
I think the problem is that on the line `self._w_sum += torch.sum(sample_weight, dim=0)`, PyTorch attempts to add the masked result of `torch.sum(sample_weight, dim=0)` in place to `self._w_sum`, which is a standard, unmasked tensor of the same shape.
I stepped through the PyTorch code in the debugger, and it seems that just before the error is triggered, this extra unmasked tensor is added to the `args` of `__torch_function__()`.
A few calls later, `_binary_helper()` is called, and it tries to call `_set_data_mask()` on the first element of `args`. Because that element is a plain tensor rather than a masked tensor, it has no `_set_data_mask()` attribute, and the error is raised.
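Below is a minimal sketch of what I believe is the failing pattern, reproduced outside pomegranate. It assumes PyTorch 2.3 with the prototype `torch.masked` API; the tensor values and the names `data`, `mask`, `w_sum` and `partial` are made up for illustration. A plain accumulator is updated in place with the `MaskedTensor` returned by `torch.sum(..., dim=0)`, which is the same kind of operation as `self._w_sum += torch.sum(sample_weight, dim=0)`.

```python
import warnings

import torch
from torch.masked import masked_tensor

# MaskedTensor is a prototype API and emits a UserWarning on construction.
warnings.filterwarnings("ignore", category=UserWarning)

# Made-up data: 3 samples x 2 features, with the (1, 1) entry missing.
data = torch.tensor([[1.0, 2.0], [3.0, 0.0], [5.0, 6.0]])
mask = torch.tensor([[True, True], [True, False], [True, True]])
sample_weight = masked_tensor(data, mask)

# Plain, unmasked accumulator, playing the role of self._w_sum.
w_sum = torch.zeros(2)

# Reducing a MaskedTensor works and returns another MaskedTensor.
partial = torch.sum(sample_weight, dim=0)

error = None
try:
    # In-place add of a MaskedTensor into a plain Tensor: the plain tensor
    # is the first argument of MaskedTensor's in-place binary-op handling.
    w_sum += partial
except Exception as exc:  # the exact exception type may vary between versions
    error = exc

print(type(partial).__name__)
print(type(error).__name__, error)
```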
Have masked tensors in PyTorch been updated recently? Or is pomegranate's code simply not compatible with the latest API?
I can hack my way past the error by initializing `self._w_sum` as a masked tensor before the `torch.sum` line:
```python
# can try putting these between lines 261 and 262 of summarize() in normal.py
n = sample_weight.size(1)
self._w_sum = torch.masked.MaskedTensor(torch.zeros(n), torch.ones(n, dtype=torch.bool))
```
but then I run into a different issue on the next line, where `axis=0` is not recognised as a valid keyword argument!
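One possible way around both errors, sketched under the assumption that masked-out entries should simply contribute zero to the sums, is to convert the `MaskedTensor` to an ordinary tensor with `get_data()` and `get_mask()` before any accumulation. The in-place `+=` and the later reductions then only ever see plain tensors. `to_plain` is a hypothetical helper (not part of pomegranate or PyTorch), and this is not necessarily how pomegranate intends missing data to be handled.

```python
import warnings

import torch
from torch.masked import masked_tensor

# MaskedTensor is a prototype API and emits a UserWarning on construction.
warnings.filterwarnings("ignore", category=UserWarning)


def to_plain(mt, fill=0.0):
    """Hypothetical helper: return an ordinary Tensor with the masked-out
    entries of the MaskedTensor `mt` replaced by `fill`."""
    return mt.get_data().masked_fill(~mt.get_mask(), fill)


# Made-up data; 99.0 sits under the mask and should not be counted.
data = torch.tensor([[1.0, 2.0], [3.0, 99.0], [5.0, 6.0]])
mask = torch.tensor([[True, True], [True, False], [True, True]])
sample_weight = masked_tensor(data, mask)

w_sum = torch.zeros(2)

# Convert once; every in-place update and reduction below uses plain tensors.
w = to_plain(sample_weight)
w_sum += torch.sum(w, dim=0)

print(w_sum)  # tensor([9., 8.])
```

The trade-off is that the mask is no longer attached to the data after conversion, so anything that needs per-feature counts of observed values would have to compute them separately, for example with `mask.sum(dim=0)`.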
Describe the bug
pomegranate version: 1.1.0
pytorch version: 2.3.0
macOS Sonoma 14.5
Thanks in advance for any help!
To Reproduce
This error can be reproduced when trying out the example on the Bayesian Networks page of the docs (https://pomegranate.readthedocs.io/en/latest/tutorials/B_Model_Tutorial_6_Bayesian_Networks.html):
The full stack trace