Skip to content

Commit

Permalink
Make step method state keep track of var_names
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Nov 20, 2024
1 parent 2ec8d27 commit 002f890
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
2 changes: 2 additions & 0 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import field
from enum import IntEnum, unique
from typing import Any

Expand Down Expand Up @@ -96,6 +97,7 @@ def infer_warn_stats_info(

@dataclass_state
class StepMethodState(DataClassState):
var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True})
rng: RandomGeneratorState


Expand Down
15 changes: 13 additions & 2 deletions pymc/step_methods/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,12 @@ def sampling_state(self) -> DataClassState:
state_class = self._state_class
kwargs = {}
for field in fields(state_class):
val = getattr(self, field.name, field.default)
is_tensor_name = field.metadata.get("tensor_name", False)
val: Any
if is_tensor_name:
val = [var.name for var in getattr(self, "vars")]
else:
val = getattr(self, field.name, field.default)
if val is MISSING:
raise AttributeError(
f"{type(self).__name__!r} object has no attribute {field.name!r}"
Expand All @@ -89,11 +94,17 @@ def sampling_state(self, state: DataClassState):
state, state_class
), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
for field in fields(state_class):
is_tensor_name = field.metadata.get("tensor_name", False)
state_val = deepcopy(getattr(state, field.name))
if isinstance(state_val, RandomGeneratorState):
state_val = random_generator_from_state(state_val)
self_val = getattr(self, field.name)
is_frozen = field.metadata.get("frozen", False)
self_val: Any
if is_tensor_name:
self_val = [var.name for var in getattr(self, "vars")]
assert is_frozen
else:
self_val = getattr(self, field.name, field.default)
if is_frozen:
if not equal_dataclass_values(state_val, self_val):
raise ValueError(
Expand Down

0 comments on commit 002f890

Please sign in to comment.