From 3b2aad8dd4ad0d8c6e8237ec0a048c8e85308e35 Mon Sep 17 00:00:00 2001 From: DiegoFH <34011351+DiegoFreitasH@users.noreply.github.com> Date: Mon, 18 Nov 2024 21:12:56 -0300 Subject: [PATCH] Fix bug #2606: Setting mixing_weights=False in SoftmaxLikelihood still adds the learnable parameter W (#2607) * Fix mixing_weights condition * Add test for mixing_weights=False --- gpytorch/likelihoods/softmax_likelihood.py | 2 +- test/likelihoods/test_softmax_likelihood.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/gpytorch/likelihoods/softmax_likelihood.py b/gpytorch/likelihoods/softmax_likelihood.py index fa16db253..65b7a60b3 100644 --- a/gpytorch/likelihoods/softmax_likelihood.py +++ b/gpytorch/likelihoods/softmax_likelihood.py @@ -41,7 +41,7 @@ def __init__( if num_classes is None: raise ValueError("num_classes is required") self.num_classes = num_classes - if mixing_weights is not None: + if mixing_weights: if num_features is None: raise ValueError("num_features is required with mixing weights") self.num_features: int = num_features diff --git a/test/likelihoods/test_softmax_likelihood.py b/test/likelihoods/test_softmax_likelihood.py index 15d729abb..bdd100c7f 100644 --- a/test/likelihoods/test_softmax_likelihood.py +++ b/test/likelihoods/test_softmax_likelihood.py @@ -61,3 +61,7 @@ class TestSoftmaxLikelihoodNoMixing(TestSoftmaxLikelihood): def create_likelihood(self): return SoftmaxLikelihood(num_features=6, num_classes=6, mixing_weights=False) + + def _test_learnable_parameters(self): + likelihood = self.create_likelihood() + self.assertEqual(len(list(likelihood.parameters())), 0)