[BUG] Fit of GMM returns initialization error because some distributions are not returned by fit_predict #1043

Open
Mriv31 opened this issue Jun 10, 2023 · 0 comments

Describe the bug
Fitting a GMM with many components fails with the error below once the number of components is too high.
I believe this happens because the fit_predict call used during initialization assigns zero points to some of the GMM's distributions.
The _initialize function then tries to fit each distribution on X[idx], and when idx selects no rows the fit raises an IndexError.
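
The mechanism described above can be illustrated without pomegranate. The following is a minimal sketch assuming only PyTorch; `X` and `y_hat` are hypothetical stand-ins for the data and the cluster labels, not values taken from the library.

```python
import torch

# Hypothetical stand-ins: 5 one-dimensional points, all labelled as cluster 0.
X = torch.zeros(5, 1, dtype=torch.float64)
y_hat = torch.zeros(5, dtype=torch.int64)

idx = y_hat == 1      # boolean mask for a cluster that received no points
subset = X[idx]       # empty tensor of shape (0, 1)

try:
    len(subset[0])    # same kind of access as len(X[0]) in the traceback
    raised = False
except IndexError as err:
    raised = True
    message = str(err)
```

Indexing row 0 of the empty selection is what fails, which matches the last frame of the traceback.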

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[78], line 1
----> 1 model.fit(torch.from_numpy(x_hat[ind][:,np.newaxis]))

File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:245, in GeneralMixtureModel.fit(self, X, sample_weight, priors)
    242 start_time = time.time()
    244 last_logp = logp
--> 245 logp = self.summarize(X, sample_weight=sample_weight, 
    246 	priors=priors)
    248 if i > 0:
    249 	improvement = logp - last_logp

File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:308, in GeneralMixtureModel.summarize(self, X, sample_weight, priors)
    306 X = _check_parameter(_cast_as_tensor(X), "X", ndim=2)
    307 if not self._initialized:
--> 308 	self._initialize(X, sample_weight=sample_weight)
    310 sample_weight = _reshape_weights(X, _cast_as_tensor(sample_weight, 
    311 	dtype=torch.float32), device=self.device)
    313 e = self._emission_matrix(X, priors=priors)

File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:162, in GeneralMixtureModel._initialize(self, X, sample_weight)
    159 for i in range(self.k):
    160 	idx = y_hat == i
--> 162 	self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx])
    163 	self.priors[i] = idx.type(torch.float32).mean()
    165 self._initialized = True

File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:67, in Distribution.fit(self, X, sample_weight)
     66 def fit(self, X, sample_weight=None):
---> 67 	self.summarize(X, sample_weight=sample_weight)
     68 	self.from_summaries()
     69 	return self

File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/normal.py:258, in Normal.summarize(self, X, sample_weight)
    255 if self.frozen == True:
    256 	return
--> 258 X, sample_weight = super().summarize(X, sample_weight=sample_weight)
    259 X = _cast_as_tensor(X, dtype=self.means.dtype)
    261 if self.covariance_type == 'full':

File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:73, in Distribution.summarize(self, X, sample_weight)
     71 def summarize(self, X, sample_weight=None):
     72 	if not self._initialized:
---> 73 		self._initialize(len(X[0]))
     75 	X = _cast_as_tensor(X)
     76 	_check_parameter(X, "X", ndim=2, shape=(-1, self.d), 
     77 		check_parameter=self.check_data)

IndexError: index 0 is out of bounds for dimension 0 with size 0

To Reproduce

A minimal, though rather uninteresting, reproducible example:

import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import Normal


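# Build 40 Normal components, far more than the data can support.
dl = []
for i in range(40):
    dl.append(Normal([0,1]).double())

model = GeneralMixtureModel(dl)
# torch.randint(1, [100, 1]) yields 100 identical points (all zeros).
model.fit(torch.randint(1,[100,1]))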
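
As a possible diagnostic, one could count how many points each component receives before fitting. This is only a sketch: `empty_components` is a hypothetical helper, not part of pomegranate, and the labels below are assumed rather than produced by the library's clustering step.

```python
import torch

def empty_components(y_hat, k):
    """Return indices of components with no assigned points (hypothetical helper)."""
    counts = torch.bincount(y_hat, minlength=k)
    return torch.nonzero(counts == 0).flatten().tolist()

# Assumed labels: pretend the clustering step put all 100 points in cluster 0.
y_hat = torch.zeros(100, dtype=torch.int64)
empty = empty_components(y_hat, k=40)
```

Any index in `empty` corresponds to a distribution that would be fitted on zero rows.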