[BUG] GMM (and normal distribution) fitting doesn't respect frozen parameters #1054
I think the finest control is the distribution that you define. I managed to modify the `remove` [...]. Please point it out if I did it wrong.
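The comment above suggests that the finest-grained control comes from defining your own distribution. As a minimal sketch of that idea (the class and method names here are hypothetical, not pomegranate's API), a distribution can simply skip its mean update when a freeze flag is set:

```python
import numpy as np

class FrozenMeanNormal:
    """1-D normal whose mean can be frozen during fitting.

    Hypothetical sketch for illustration; this is NOT pomegranate's API.
    """
    def __init__(self, mean, var, freeze_mean=False):
        self.mean = float(mean)
        self.var = float(var)
        self.freeze_mean = freeze_mean

    def fit(self, x, weights=None):
        # Weighted maximum-likelihood update; skip the mean if frozen.
        w = np.ones_like(x) if weights is None else weights
        if not self.freeze_mean:
            self.mean = np.average(x, weights=w)
        self.var = np.average((x - self.mean) ** 2, weights=w)
        return self

rng = np.random.default_rng(0)
d = FrozenMeanNormal(mean=4.0, var=1.0, freeze_mean=True)
d.fit(rng.normal(4.0, 0.5, size=400))
print(d.mean)  # stays exactly 4.0: the update was skipped
print(d.var)   # variance is still re-estimated from the data
```

Because the mean is never written to, it stays at its initial value no matter how many times `fit` is called.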
Hi @NicholasClark. Sorry for the late reply. You are correct that you can freeze individual parameters, but you have to do it in a specific way to get it to stick. First, you added the [...]. Second, pomegranate does not allow you to incompletely specify distributions as starting points (this should probably raise a warning when it happens). So, what happened is that the distribution did not register as being initialized and was overwritten in the first step of fitting the GMM. You can get around this by putting some value into [...].

This code works for me and keeps the means frozen. I took out the plotting stuff just because it wasn't relevant for me.

```python
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import Normal
import numpy as np
import torch

### Generate data for mixture model
np.random.seed(0)
X = np.concatenate([np.random.normal(4, 0.5, size=400),
                    np.random.normal(1.5, 0.5, size=600)])
XX = np.array(X).reshape(-1, 1)
XX = torch.tensor(XX).float()

### Fit mixture model and freeze the mean of each distribution
m1 = torch.tensor([4.0])  ### mean = 4
m2 = torch.tensor([1.5])  ### mean = 1.5
d1 = Normal(means=m1, covs=[1], covariance_type='diag')
d2 = Normal(means=m2, covs=[1], covariance_type='diag')
d1.means.frozen = True
d2.means.frozen = True

model = GeneralMixtureModel([d1, d2], verbose=True).fit(XX)
print(model.distributions[0].means.frozen)
print("mean of Normal1: " + str(round(model.distributions[0].means.item(), 2)))
print("mean of Normal2: " + str(round(model.distributions[1].means.item(), 2)))
```

When I run this it gives me:
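For readers without pomegranate installed, the effect of freezing the means can be reproduced with a short hand-rolled EM loop in plain NumPy. This is a sketch of the underlying EM updates with the means held fixed, not pomegranate's implementation; all variable names are illustrative:

```python
import numpy as np

# Same synthetic data as the pomegranate example above.
np.random.seed(0)
X = np.concatenate([np.random.normal(4, 0.5, size=400),
                    np.random.normal(1.5, 0.5, size=600)])

means = np.array([4.0, 1.5])   # frozen: never written to below
vars_ = np.array([1.0, 1.0])   # fitted
pis = np.array([0.5, 0.5])     # mixture weights, fitted

for _ in range(50):
    # E-step: per-point responsibilities under the current parameters.
    dens = (pis / np.sqrt(2 * np.pi * vars_)
            * np.exp(-(X[:, None] - means) ** 2 / (2 * vars_)))
    resp = dens / dens.sum(axis=1, keepdims=True)
    # M-step: update weights and variances only; the means stay frozen.
    nk = resp.sum(axis=0)
    pis = nk / len(X)
    vars_ = (resp * (X[:, None] - means) ** 2).sum(axis=0) / nk

print(means)                  # unchanged, since the M-step never touches them
print(np.round(vars_, 2))     # variances re-estimated around the fixed means
```

Because the M-step simply omits the mean update, the means remain exactly at 4 and 1.5 while the variances converge near the true value of 0.25, which is the behavior the frozen flag is meant to produce.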
I am trying to fit a mixture model of two normal distributions where I freeze the means at 4 and 1.5 and fit only the variances.
When I use GeneralMixtureModel, it changes the means anyway when fitting to the data (fitted means are 3.98 and 1.47).
I notice the same issue if I try to fit just one Normal distribution and freeze the mean.
I may be doing something wrong, but I've tried it a number of different ways at this point.
Any help would be highly appreciated!
Here is code to reproduce the issue: