Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for intervene operations on dictionaries where act is a Callable #574

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions chirho/indexed/internals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numbers
from typing import Dict, Hashable, Optional, TypeVar, Union
from typing import Any, Dict, Hashable, Optional, TypeVar, Union

import pyro
import pyro.infer.reparam
Expand Down Expand Up @@ -67,7 +67,7 @@ def _gather_tensor(


@gather.register(dict)
def _gather_state(
def _gather_dict(
value: Dict[K, T], indices: IndexSet, *, event_dim: int = 0, **kwargs
) -> Dict[K, T]:
return {
Expand Down Expand Up @@ -143,6 +143,31 @@ def _scatter_tensor(
return result


@scatter.register(dict)
def _scatter_dict(
value: Dict[K, T],
indexset: IndexSet,
*,
result: Optional[Dict[K, Optional[T]]] = None,
event_dim: Optional[int] = None,
name_to_dim: Optional[Dict[Hashable, int]] = None,
) -> Dict[K, Any]:

if result is None:
result = {k: None for k in value.keys()}

for k in value.keys():
result[k] = scatter(
value[k],
indexset,
result=result[k],
event_dim=event_dim,
name_to_dim=name_to_dim,
)

return result


@indices_of.register
def _indices_of_number(value: numbers.Number, **kwargs) -> IndexSet:
return IndexSet()
Expand Down
17 changes: 15 additions & 2 deletions chirho/interventional/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import collections
import functools
from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar
from typing import Callable, Dict, Generic, Hashable, Mapping, Optional, TypeVar, Union

import pyro
import torch
Expand Down Expand Up @@ -60,14 +60,27 @@ def _intervene_atom_distribution(

@intervene.register(dict)
def _dict_intervene(
obs: Dict[K, T], act: Dict[K, AtomicIntervention[T]], **kwargs
obs: Dict[K, T],
act: Union[Dict[K, AtomicIntervention[T]], Callable[[Dict[K, T]], Dict[K, T]]],
**kwargs,
) -> Dict[K, T]:

if callable(act):
return _dict_intervene_callable(obs, act, **kwargs)

result: Dict[K, T] = {}
for k in obs.keys():
result[k] = intervene(obs[k], act[k] if k in act else None, **kwargs)
return result


@pyro.poutine.runtime.effectful(type="intervene")
def _dict_intervene_callable(
obs: Dict[K, T], act: Callable[[Dict[K, T]], Dict[K, T]], **kwargs
) -> Dict[K, T]:
return act(obs)


@intervene.register
def _intervene_callable(
obs: collections.abc.Callable,
Expand Down
1 change: 1 addition & 0 deletions tests/dynamical/test_static_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
dict(I=torch.tensor(50.0)),
dict(S=torch.tensor(50.0), R=torch.tensor(50.0)),
dict(S=torch.tensor(50.0), I=torch.tensor(50.0), R=torch.tensor(50.0)),
lambda X: {k: v / 2 for k, v in X.items()},
]

# Define intervention times before all tspan values.
Expand Down
Loading