Describe the bug
I am trying to fit a GMM with a large number of populations (components), but I get the following error when the number of populations is too high.
I believe this happens because the fit_predict call used during initialization assigns zero samples to some of the distributions of the GMM.
The subsequent attempt by _initialize to fit X[idx], where idx selects no rows, then raises an error.
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[78], line 1
----> 1 model.fit(torch.from_numpy(x_hat[ind][:,np.newaxis]))
File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:245, in GeneralMixtureModel.fit(self, X, sample_weight, priors)
242 start_time = time.time()
244 last_logp = logp
--> 245 logp = self.summarize(X, sample_weight=sample_weight,
246 priors=priors)
248 if i > 0:
249 improvement = logp - last_logp
File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:308, in GeneralMixtureModel.summarize(self, X, sample_weight, priors)
306 X = _check_parameter(_cast_as_tensor(X), "X", ndim=2)
307 if not self._initialized:
--> 308 self._initialize(X, sample_weight=sample_weight)
310 sample_weight = _reshape_weights(X, _cast_as_tensor(sample_weight,
311 dtype=torch.float32), device=self.device)
313 e = self._emission_matrix(X, priors=priors)
File ~/.local/lib/python3.10/site-packages/pomegranate/gmm.py:162, in GeneralMixtureModel._initialize(self, X, sample_weight)
159 for i in range(self.k):
160 idx = y_hat == i
--> 162 self.distributions[i].fit(X[idx], sample_weight=sample_weight[idx])
163 self.priors[i] = idx.type(torch.float32).mean()
165 self._initialized = True
File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:67, in Distribution.fit(self, X, sample_weight)
66 def fit(self, X, sample_weight=None):
---> 67 self.summarize(X, sample_weight=sample_weight)
68 self.from_summaries()
69 return self
File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/normal.py:258, in Normal.summarize(self, X, sample_weight)
255 if self.frozen == True:
256 return
--> 258 X, sample_weight = super().summarize(X, sample_weight=sample_weight)
259 X = _cast_as_tensor(X, dtype=self.means.dtype)
261 if self.covariance_type == 'full':
File ~/.local/lib/python3.10/site-packages/pomegranate/distributions/_distribution.py:73, in Distribution.summarize(self, X, sample_weight)
71 def summarize(self, X, sample_weight=None):
72 if not self._initialized:
---> 73 self._initialize(len(X[0]))
75 X = _cast_as_tensor(X)
76 _check_parameter(X, "X", ndim=2, shape=(-1, self.d),
77 check_parameter=self.check_data)
IndexError: index 0 is out of bounds for dimension 0 with size 0
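One way to check this diagnosis is to count how many samples each component receives from the initial hard assignment. The sketch below is plain Python and not part of pomegranate: `empty_components` is a hypothetical helper, and `labels` stands in for the `y_hat` assignments seen in the traceback (for a tensor, something like `y_hat.tolist()` would be passed).

```python
from collections import Counter

def empty_components(labels, k):
    """Return the component indices in range(k) that received no samples."""
    counts = Counter(labels)  # a Counter returns 0 for labels it never saw
    return [i for i in range(k) if counts[i] == 0]

# Hypothetical assignment: 100 identical samples all land in component 0,
# so the other 39 of the 40 components are empty.
labels = [0] * 100
empty = empty_components(labels, 40)
print(len(empty))  # 39
```

Any index returned here corresponds to a distribution that would be fitted on an empty X[idx], which is where the IndexError above is raised.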
To Reproduce
A minimal reproducible example, although a rather uninteresting one:
import torch
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import Normal

# Build a mixture with 40 components
dl = []
for i in range(40):
    dl.append(Normal([0,1]).double())
model = GeneralMixtureModel(dl)

# torch.randint(1, ...) only draws zeros, so all 100 samples are identical
model.fit(torch.randint(1,[100,1]))