Skip to content

Commit

Permalink
v1.1.1
Browse files Browse the repository at this point in the history
  • Loading branch information
jmschrei committed Oct 18, 2024
1 parent 34a6e9a commit 8e1a71e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
10 changes: 10 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@
Release History
===============

Version 1.1.1
==============


Highlights
----------

- Fixed an issue with categorical distributions being used on the GPU


Version 1.0.4
==============

Expand Down
4 changes: 3 additions & 1 deletion pomegranate/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ def log_probability(self, X):
max_value=self.n_keys-1, ndim=2, shape=(-1, self.d),
check_parameter=self.check_data)

logps = torch.zeros(X.shape[0], dtype=self.probs.dtype)
logps = torch.zeros(X.shape[0], dtype=self.probs.dtype,
device=self.device)

for i in range(self.d):
if isinstance(X, torch.masked.MaskedTensor):
logp_ = self._log_probs[i][X[:, i]._masked_data]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='pomegranate',
version='1.1.0',
version='1.1.1',
author='Jacob Schreiber',
author_email='[email protected]',
packages=['pomegranate', 'pomegranate.distributions', 'pomegranate.hmm'],
Expand Down

0 comments on commit 8e1a71e

Please sign in to comment.