From 8e1a71e63812a2b9ff5afd0a081236c4318a7cb2 Mon Sep 17 00:00:00 2001 From: Jacob Schreiber Date: Fri, 18 Oct 2024 19:26:50 +0000 Subject: [PATCH] v1.1.1 --- docs/whats_new.rst | 10 ++++++++++ pomegranate/distributions/categorical.py | 4 +++- setup.py | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/whats_new.rst b/docs/whats_new.rst index 90658f63..4febaf23 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -5,6 +5,16 @@ Release History =============== +Version 1.1.1 +============== + + +Highlights +---------- + + - Fixed an issue with categorical distributions being used on the GPU + + Version 1.0.4 ============== diff --git a/pomegranate/distributions/categorical.py b/pomegranate/distributions/categorical.py index 132bee1b..28bb5210 100644 --- a/pomegranate/distributions/categorical.py +++ b/pomegranate/distributions/categorical.py @@ -182,7 +182,9 @@ def log_probability(self, X): max_value=self.n_keys-1, ndim=2, shape=(-1, self.d), check_parameter=self.check_data) - logps = torch.zeros(X.shape[0], dtype=self.probs.dtype) + logps = torch.zeros(X.shape[0], dtype=self.probs.dtype, + device=self.device) + for i in range(self.d): if isinstance(X, torch.masked.MaskedTensor): logp_ = self._log_probs[i][X[:, i]._masked_data] diff --git a/setup.py b/setup.py index 22bf909e..85be9956 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='pomegranate', - version='1.1.0', + version='1.1.1', author='Jacob Schreiber', author_email='jmschreiber91@gmail.com', packages=['pomegranate', 'pomegranate.distributions', 'pomegranate.hmm'],