Skip to content

Commit

Permalink
Add further type annotations to categorical glm hmm
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Oct 6, 2024
1 parent 076350a commit 36bbc68
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions dynamax/hidden_markov_model/models/categorical_glm_hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMParameterSet, HMMPropertySet
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState, ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions, ParamsStandardHMMTransitions
from dynamax.types import Scalar
from dynamax.types import IntScalar, Scalar
import optax
from typing import NamedTuple, Optional, Tuple, Union

Expand All @@ -24,9 +24,9 @@ class ParamsCategoricalRegressionHMM(NamedTuple):
class CategoricalRegressionHMMEmissions(HMMEmissions):

def __init__(self,
num_states,
num_classes,
input_dim,
num_states: int,
num_classes: int,
input_dim: int,
m_step_optimizer=optax.adam(1e-2),
m_step_num_iters=50):
"""_summary_
Expand All @@ -50,7 +50,13 @@ def inputs_shape(self):
def log_prior(self, params) -> float:
    """Return the log prior probability of the emission parameters.

    No prior is placed on the GLM weights or biases, so this is
    identically zero (i.e. maximum-likelihood estimation in the M-step).
    """
    return 0.0

def initialize(self, key=jr.PRNGKey(0), method="prior", emission_weights=None, emission_biases=None):
def initialize(
self,
key: Array=jr.PRNGKey(0),
method: str="prior",
emission_weights: Optional[Float[Array, "num_states num_classes input_dim"]]=None,
emission_biases: Optional[Float[Array, "num_states num_classes"]]=None,
):
"""Initialize the model parameters and their corresponding properties.
You can either specify parameters manually via the keyword arguments, or you can have
Expand Down Expand Up @@ -88,7 +94,11 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_weights=None, e
biases=ParameterProperties())
return params, props

def distribution(self, params, state, inputs=None):
def distribution(
    self,
    params: ParamsCategoricalRegressionHMMEmissions,
    state: IntScalar,
    inputs: Float[Array, " input_dim"]) -> tfd.Categorical:
    """Return the emission distribution for the given discrete state.

    The class logits are an affine function of the inputs:
    ``logits = weights[state] @ inputs + biases[state]``.

    Args:
        params: emission parameters holding per-state ``weights`` and ``biases``.
        state: current discrete latent state, used to index the GLM parameters.
        inputs: covariate vector of shape ``(input_dim,)``.

    Returns:
        A ``tfd.Categorical`` distribution over the emission classes.
    """
    # Per-state affine map from covariates to class logits.
    logits = params.weights[state] @ inputs + params.biases[state]
    return tfd.Categorical(logits=logits)

Expand Down

0 comments on commit 36bbc68

Please sign in to comment.