From f9461b54e2a8e034c81d16a5dcac4ef3bbb790b5 Mon Sep 17 00:00:00 2001
From: Matthew Bradbury
Date: Mon, 28 Oct 2024 12:30:34 +0000
Subject: [PATCH] Avoid operations on uninitialised memory in HMMs

---
 pomegranate/hmm/dense_hmm.py  | 20 ++++++++++----------
 pomegranate/hmm/sparse_hmm.py | 10 +++++-----
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/pomegranate/hmm/dense_hmm.py b/pomegranate/hmm/dense_hmm.py
index 437f3134..6378de39 100644
--- a/pomegranate/hmm/dense_hmm.py
+++ b/pomegranate/hmm/dense_hmm.py
@@ -199,24 +199,24 @@ def add_edge(self, start, end, prob):
 
 		if start == self.start:
 			if self.starts is None:
-				self.starts = torch.empty(n, dtype=self.dtype,
-					device=self.device) - inf
+				self.starts = torch.full((n,), NEGINF, dtype=self.dtype,
+					device=self.device)
 
 			idx = self.distributions.index(end)
 			self.starts[idx] = math.log(prob)
 
 		elif end == self.end:
 			if self.ends is None:
-				self.ends = torch.empty(n, dtype=self.dtype,
-					device=self.device) - inf
+				self.ends = torch.full((n,), NEGINF, dtype=self.dtype,
+					device=self.device)
 
 			idx = self.distributions.index(start)
 			self.ends[idx] = math.log(prob)
 
 		else:
 			if self.edges is None:
-				self.edges = torch.empty((n, n), dtype=self.dtype,
-					device=self.device) - inf
+				self.edges = torch.full((n, n), NEGINF, dtype=self.dtype,
+					device=self.device)
 
 			idx1 = self.distributions.index(start)
 			idx2 = self.distributions.index(end)
@@ -250,8 +250,8 @@ def sample(self, n):
 				+ "end probabilities.")
 
 		if self.ends is None:
-			ends = torch.zeros(self.n_distributions, dtype=self.edges.dtype,
-				device=self.edges.device) + float("-inf")
+			ends = torch.full((self.n_distributions,), NEGINF, dtype=self.edges.dtype,
+				device=self.edges.device)
 		else:
 			ends = self.ends
 
@@ -454,8 +454,8 @@ def backward(self, X=None, emissions=None, priors=None):
 		emissions = _check_inputs(self, X, emissions, priors)
 		n, l, _ = emissions.shape
 
-		b = torch.zeros(l, n, self.n_distributions, dtype=self.dtype,
-			device=self.device) + float("-inf")
+		b = torch.full((l, n, self.n_distributions), NEGINF, dtype=self.dtype,
+			device=self.device)
 
 		b[-1] = self.ends
 
 		t_max = self.edges.max()
diff --git a/pomegranate/hmm/sparse_hmm.py b/pomegranate/hmm/sparse_hmm.py
index 91434f5f..8d89b462 100644
--- a/pomegranate/hmm/sparse_hmm.py
+++ b/pomegranate/hmm/sparse_hmm.py
@@ -60,7 +60,7 @@ def unpack_edges(self, edges, starts, ends):
 			self.starts = _cast_as_parameter(torch.log(starts))
 
 		if ends is None:
-			self.ends = torch.empty(n, dtype=self.dtype, device=self.device) - inf
+			self.ends = torch.full((n,), NEGINF, dtype=self.dtype, device=self.device)
 		else:
 			ends = _check_parameter(_cast_as_tensor(ends), "ends", ndim=1,
 				shape=(n,), min_value=0., max_value=1.)
@@ -93,8 +93,8 @@ def unpack_edges(self, edges, starts, ends):
 
 			if ni is self.start:
 				if self.starts is None:
-					self.starts = torch.zeros(n, dtype=self.dtype,
-						device=self.device) - inf
+					self.starts = torch.full((n,), NEGINF, dtype=self.dtype,
+						device=self.device)
 
 				j = self.distributions.index(nj)
 				self.starts[j] = math.log(probability)
@@ -302,9 +302,9 @@ def sample(self, n):
 				+ "end probabilities.")
 
 		if self.ends is None:
-			ends = torch.zeros(self.n_distributions,
+			ends = torch.full((self.n_distributions,), NEGINF,
 				dtype=self._edge_log_probs.dtype,
-				device=self._edge_log_probs.device) + float("-inf")
+				device=self._edge_log_probs.device)
 		else:
 			ends = self.ends