[BUG] Error when fitting a Normal distribution with a masked tensor #1113

Open
mattShorvon opened this issue Aug 9, 2024 · 0 comments

Describe the bug
pomegranate version: 1.1.0
PyTorch version: 2.3.0
macOS Sonoma 14.5

I ran into this error when attempting to fit a distribution with missing data. I converted the input into a masked tensor, following the instructions in the docs, and ran into trouble. The error occurs at the torch.sum call on line 262 of pomegranate/distributions/normal.py.

The full stack trace is at the bottom

I think the problem is that in the line 'self._w_sum += torch.sum(sample_weight, dim=0)', PyTorch attempts an in-place add of the masked result of torch.sum(sample_weight, dim=0) to self._w_sum, which is a standard, unmasked tensor of the same dimensions.

I went through the PyTorch code in the debugger, and it seems that just before the error is triggered, the 'args' of __torch_function__() get this extra unmasked tensor added to them.

In the series of calls that follows, _binary_helper() is eventually called, which attempts to call _set_data_mask() on the first element of args. As this is just a normal tensor, not a masked tensor with a _set_data_mask() method, the error is triggered.
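To check whether the failure depends on pomegranate at all, the same accumulation pattern can be tried with PyTorch alone. The following is only a sketch: the helper name try_inplace_add is hypothetical, the plain tensor stands in for self._w_sum, and the outcome may differ between PyTorch versions, so the function reports whatever happens instead of assuming the AttributeError seen above.

```python
import warnings

import torch
from torch.masked import masked_tensor


def try_inplace_add():
    """Accumulate a MaskedTensor into a plain tensor and report what happens.

    Hypothetical helper for illustration; not part of pomegranate or PyTorch.
    """
    with warnings.catch_warnings():
        # MaskedTensor is a prototype feature and warns when it is created.
        warnings.simplefilter("ignore")
        acc = torch.zeros(2)  # plain tensor, standing in for self._w_sum
        mt = masked_tensor(torch.tensor([1.0, 2.0]),
                           torch.tensor([True, False]))
        try:
            # Same pattern as `self._w_sum += torch.sum(sample_weight, dim=0)`.
            acc += mt
        except Exception as exc:
            # The report above saw AttributeError; other versions may differ.
            return f"{type(exc).__name__}: {exc}"
    return "no error"


outcome = try_inplace_add()
print(outcome)
```

If this prints the same AttributeError, the problem would seem to lie in how an in-place binary op treats a plain tensor on the left-hand side, not in anything specific to Normal.summarize().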

Have masked tensors in PyTorch been updated recently? Could this be down to pomegranate's code not being compatible with the latest API?

I can hack my way past the error by initializing self._w_sum as a masked tensor before the torch.sum line:

# can try putting these in-between line 261 and 262 in summarize() in normal.py
# n = sample_weight.size(1)
# self._w_sum = torch.masked.MaskedTensor(torch.zeros(n), torch.ones(n, dtype=torch.bool))

but then I run into a different issue on the next line, with axis=0 not being recognised as a valid keyword argument!
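As a possible stopgap, the two sums that summarize() accumulates can be computed from the raw data and the boolean mask with plain tensors, which avoids MaskedTensor arithmetic entirely. This is a sketch under assumptions, not pomegranate's implementation: masked_column_sums is a hypothetical helper, True in the mask is taken to mean "observed", and the fixed X below replaces the random data from the reproduction so the numbers are predictable.

```python
import torch


def masked_column_sums(X, mask, sample_weight=None):
    """Per-column weight sums and weighted-value sums that skip missing entries.

    X: (n, d) tensor. mask: (n, d) bool tensor, True where a value is observed.
    Hypothetical helper for illustration; not part of pomegranate.
    """
    X = X.to(torch.float32)
    if sample_weight is None:
        sample_weight = torch.ones_like(X)
    # Give every missing entry zero weight so it cannot contribute.
    w = sample_weight * mask.to(X.dtype)
    # Replace missing values with 0 so placeholder values never enter the sum.
    x = torch.where(mask, X, torch.zeros_like(X))
    w_sum = w.sum(dim=0)
    xw_sum = (x * w).sum(dim=0)
    return w_sum, xw_sum


X = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]])
mask = torch.tensor([[True, False],
                     [False, True],
                     [True, True],
                     [False, False]])

w_sum, xw_sum = masked_column_sums(X, mask)
means = xw_sum / w_sum  # per-column mean over observed entries only
print(w_sum, xw_sum, means)
```

The sums use dim=0 throughout; whether axis=0 is accepted for masked tensors is exactly the second problem described above, so the sketch sidesteps it.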

Thanks in advance for any help!

To Reproduce
This error can be reproduced when trying out the example on the Bayesian Networks page of the docs (https://pomegranate.readthedocs.io/en/latest/tutorials/B_Model_Tutorial_6_Bayesian_Networks.html):

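import numpy as np
import torch
from pomegranate.distributions import Normal

# Random binary data; only the first four rows are used below
X = np.random.randint(2, size=(10, 2))
X_torch = torch.tensor(X[:4])
# True marks an observed value, False a missing one
mask = torch.tensor([[True, False],
                     [False, True],
                     [True, True],
                     [False, False]])

X_masked = torch.masked.MaskedTensor(X_torch, mask=mask)
Normal().fit(X_masked)  # raises the AttributeError shown in the stack trace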

The full stack trace

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[30], line 13
      7 mask = torch.tensor([[True, False],
      8                      [False, True],
      9                      [True, True],
     10                      [False, False]])
     12 X_masked = torch.masked.MaskedTensor(X_torch, mask=mask)
---> 13 Normal().fit(X_masked)
     15 # can try putting these in-between line 261 and 262 in summarize() in normal.py
     16 # n = sample_weights.size(1)
     17 # self._w_sum = torch.masked.MaskedTensor(torch.zeros(n), torch.ones(n, dtype=torch.bool))

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/pomegranate/distributions/_distribution.py:75, in Distribution.fit(self, X, sample_weight)
     74 def fit(self, X, sample_weight=None):
---> 75 	self.summarize(X, sample_weight=sample_weight)
     76 	self.from_summaries()
     77 	return self

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/pomegranate/distributions/normal.py:262, in Normal.summarize(self, X, sample_weight)
    260 sample_weight = _cast_as_tensor(sample_weight, dtype=self.means.dtype)
    261 if self.covariance_type == 'full':
--> 262 	self._w_sum += torch.sum(sample_weight, dim=0)
    263 	self._xw_sum += torch.sum(X * sample_weight, axis=0)
    264 	self._xxw_sum += torch.matmul((X * sample_weight).T, X)

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/torch/masked/maskedtensor/core.py:276, in MaskedTensor.__torch_function__(cls, func, types, args, kwargs)
    274     return NotImplemented
    275 with torch._C.DisableTorchFunctionSubclass():
--> 276     ret = func(*args, **kwargs)
    277     if func in get_default_nowrap_functions():
    278         return ret

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/torch/masked/maskedtensor/core.py:292, in MaskedTensor.__torch_dispatch__(cls, func, types, args, kwargs)
    290 from ._ops_refs import _MASKEDTENSOR_DISPATCH_TABLE
    291 if func in _MASKEDTENSOR_DISPATCH_TABLE:
--> 292     return _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)
    294 msg = (
    295     f"{func.__name__} is not implemented in __torch_dispatch__ for MaskedTensor.\n"
    296     "If you would like this operator to be supported, please file an issue for a feature request at "
   (...)
    299     "to also include a proposal for the semantics."
    300 )
    301 warnings.warn(msg)

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/torch/masked/maskedtensor/_ops_refs.py:253, in _general_binary(func, *args, **kwargs)
    251 @register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
    252 def _general_binary(func, *args, **kwargs):
--> 253     return _apply_native_binary(func, *args, **kwargs)

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/torch/masked/maskedtensor/binary.py:193, in _apply_native_binary(fn, *args, **kwargs)
    191     return NATIVE_BINARY_MAP[fn](*args, **kwargs)
    192 if fn in NATIVE_INPLACE_BINARY_FNS:
--> 193     return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs)
    194 return NotImplemented

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/torch/masked/maskedtensor/binary.py:168, in _torch_inplace_binary.<locals>.binary_fn(*args, **kwargs)
    167 def binary_fn(*args, **kwargs):
--> 168     return _binary_helper(fn, args, kwargs, inplace=True)

File /opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/torch/masked/maskedtensor/binary.py:145, in _binary_helper(fn, args, kwargs, inplace)
    143     print(args)
    144     breakpoint()
--> 145     args[0]._set_data_mask(result_data, mask_args[0])
    146     return args[0]
    147 else:
    [147](https://file+.vscode-resource.vscode-cdn.net/opt/miniconda3/envs/adjudication_model/lib/python3.11/site-packages/torch/masked/maskedtensor/binary.py:147) else:

AttributeError: 'Tensor' object has no attribute '_set_data_mask'