Skip to content

Commit

Permalink
Avoid operations on uninitialised memory in HMMs
Browse files Browse the repository at this point in the history
  • Loading branch information
MBradbury committed Oct 28, 2024
1 parent 56ab929 commit cb73603
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
20 changes: 10 additions & 10 deletions pomegranate/hmm/dense_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,24 +199,24 @@ def add_edge(self, start, end, prob):

if start == self.start:
if self.starts is None:
self.starts = torch.empty(n, dtype=self.dtype,
device=self.device) - inf
self.starts = torch.full((n,), NEGINF, dtype=self.dtype,
device=self.device)

idx = self.distributions.index(end)
self.starts[idx] = math.log(prob)

elif end == self.end:
if self.ends is None:
self.ends = torch.empty(n, dtype=self.dtype,
device=self.device) - inf
self.ends = torch.full((n,), NEGINF, dtype=self.dtype,
device=self.device)

idx = self.distributions.index(start)
self.ends[idx] = math.log(prob)

else:
if self.edges is None:
self.edges = torch.empty((n, n), dtype=self.dtype,
device=self.device) - inf
self.edges = torch.full((n, n), NEGINF, dtype=self.dtype,
device=self.device)

idx1 = self.distributions.index(start)
idx2 = self.distributions.index(end)
Expand Down Expand Up @@ -250,8 +250,8 @@ def sample(self, n):
+ "end probabilities.")

if self.ends is None:
ends = torch.zeros(self.n_distributions, dtype=self.edges.dtype,
device=self.edges.device) + float("-inf")
ends = torch.full((self.n_distributions,), NEGINF, dtype=self.edges.dtype,
device=self.edges.device)
else:
ends = self.ends

Expand Down Expand Up @@ -454,8 +454,8 @@ def backward(self, X=None, emissions=None, priors=None):
emissions = _check_inputs(self, X, emissions, priors)
n, l, _ = emissions.shape

b = torch.zeros(l, n, self.n_distributions, dtype=self.dtype,
device=self.device) + float("-inf")
b = torch.full((l, n, self.n_distributions), NEGINF, dtype=self.dtype,
device=self.device)
b[-1] = self.ends

t_max = self.edges.max()
Expand Down
10 changes: 5 additions & 5 deletions pomegranate/hmm/sparse_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def unpack_edges(self, edges, starts, ends):
self.starts = _cast_as_parameter(torch.log(starts))

if ends is None:
self.ends = torch.empty(n, dtype=self.dtype, device=self.device) - inf
self.ends = torch.full((n,), NEGINF, dtype=self.dtype, device=self.device)
else:
ends = _check_parameter(_cast_as_tensor(ends), "ends", ndim=1,
shape=(n,), min_value=0., max_value=1.)
Expand Down Expand Up @@ -93,8 +93,8 @@ def unpack_edges(self, edges, starts, ends):

if ni is self.start:
if self.starts is None:
self.starts = torch.zeros(n, dtype=self.dtype,
device=self.device) - inf
self.starts = torch.full((n,), NEGINF, dtype=self.dtype,
device=self.device)

j = self.distributions.index(nj)
self.starts[j] = math.log(probability)
Expand Down Expand Up @@ -302,9 +302,9 @@ def sample(self, n):
+ "end probabilities.")

if self.ends is None:
ends = torch.zeros(self.n_distributions,
ends = torch.full((self.n_distributions,), NEGINF,
dtype=self._edge_log_probs.dtype,
device=self._edge_log_probs.device) + float("-inf")
device=self._edge_log_probs.device)
else:
ends = self.ends

Expand Down

0 comments on commit cb73603

Please sign in to comment.