From 4794011b14c35a8240447e127ba9eae3c92f9151 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Thu, 5 Dec 2024 12:20:23 -0500 Subject: [PATCH] added scatter for dict and intervention function on dicts --- chirho/indexed/internals.py | 29 ++++++++++++++++++-- chirho/interventional/handlers.py | 17 ++++++++++-- tests/dynamical/test_static_interventions.py | 1 + 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/chirho/indexed/internals.py b/chirho/indexed/internals.py index d684b8ad6..19bbb2237 100644 --- a/chirho/indexed/internals.py +++ b/chirho/indexed/internals.py @@ -1,5 +1,5 @@ import numbers -from typing import Dict, Hashable, Optional, TypeVar, Union +from typing import Any, Dict, Hashable, Optional, TypeVar, Union import pyro import pyro.infer.reparam @@ -67,7 +67,7 @@ def _gather_tensor( @gather.register(dict) -def _gather_state( +def _gather_dict( value: Dict[K, T], indices: IndexSet, *, event_dim: int = 0, **kwargs ) -> Dict[K, T]: return { @@ -143,6 +143,31 @@ def _scatter_tensor( return result +@scatter.register(dict) +def _scatter_dict( + value: Dict[K, T], + indexset: IndexSet, + *, + result: Optional[Dict[K, Optional[T]]] = None, + event_dim: Optional[int] = None, + name_to_dim: Optional[Dict[Hashable, int]] = None, +) -> Dict[K, Any]: + + if result is None: + result = {k: None for k in value.keys()} + + for k in value.keys(): + result[k] = scatter( + value[k], + indexset, + result=result[k], + event_dim=event_dim, + name_to_dim=name_to_dim, + ) + + return result + + @indices_of.register def _indices_of_number(value: numbers.Number, **kwargs) -> IndexSet: return IndexSet() diff --git a/chirho/interventional/handlers.py b/chirho/interventional/handlers.py index 5e2d231bd..0a74307a9 100644 --- a/chirho/interventional/handlers.py +++ b/chirho/interventional/handlers.py @@ -2,7 +2,7 @@ import collections import functools -from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar +from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar, Union import pyro import torch @@ -60,14 +60,27 @@ def _intervene_atom_distribution( @intervene.register(dict) def _dict_intervene( - obs: Dict[K, T], act: Dict[K, AtomicIntervention[T]], **kwargs + obs: Dict[K, T], + act: Union[Dict[K, AtomicIntervention[T]], Callable[[Dict[K, T]], Dict[K, T]]], + **kwargs, ) -> Dict[K, T]: + + if callable(act): + return _dict_intervene_callable(obs, act, **kwargs) + result: Dict[K, T] = {} for k in obs.keys(): result[k] = intervene(obs[k], act[k] if k in act else None, **kwargs) return result +@pyro.poutine.runtime.effectful(type="intervene") +def _dict_intervene_callable( + obs: Dict[K, T], act: Callable[[Dict[K, T]], Dict[K, T]], **kwargs +) -> Dict[K, T]: + return act(obs) + + @intervene.register def _intervene_callable( obs: collections.abc.Callable, diff --git a/tests/dynamical/test_static_interventions.py b/tests/dynamical/test_static_interventions.py index dae83e22d..1ea022f4f 100644 --- a/tests/dynamical/test_static_interventions.py +++ b/tests/dynamical/test_static_interventions.py @@ -34,6 +34,7 @@ dict(I=torch.tensor(50.0)), dict(S=torch.tensor(50.0), R=torch.tensor(50.0)), dict(S=torch.tensor(50.0), I=torch.tensor(50.0), R=torch.tensor(50.0)), + lambda X: {k: v / 2 for k, v in X.items()}, ] # Define intervention times before all tspan values.