From d081c23f097311c9e746b8be99127e169b2e8a07 Mon Sep 17 00:00:00 2001 From: Tim Verbelen Date: Thu, 3 Oct 2024 10:28:37 +0200 Subject: [PATCH] fix state initialization on reset --- pymdp/envs/env.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pymdp/envs/env.py b/pymdp/envs/env.py index d1cd769e..4354b55b 100644 --- a/pymdp/envs/env.py +++ b/pymdp/envs/env.py @@ -31,29 +31,24 @@ class Env(Module): current_obs: List[Array] dependencies: Dict = field(static=True) - def __init__(self, params: Dict, dependencies: Dict, init_state: List[Array] = None): + def __init__(self, params: Dict, dependencies: Dict): self.params = params self.dependencies = dependencies - if init_state is None: - init_state = jtu.tree_map(lambda x: jnp.argmax(x, -1), self.params["D"]) - - self.state = init_state + self.state = jtu.tree_map(lambda x: jnp.zeros([x.shape[0]]), self.params["D"]) self.current_obs = jtu.tree_map(lambda x: jnp.zeros([x.shape[0], x.shape[1]]), self.params["A"]) @vmap def reset(self, key: Optional[PRNGKeyArray], state: Optional[List[Array]] = None): - if state is not None: - state = self.state - else: + if state is None: probs = self.params["D"] keys = list(jr.split(key, len(probs) + 1)) key = keys[0] state = jtu.tree_map(cat_sample, keys[1:], probs) - new_obs = self._sample_obs(key, state) - env = tree_at(lambda x: x.state, self, state) + + new_obs = self._sample_obs(key, state) env = tree_at(lambda x: x.current_obs, env, new_obs) return new_obs, env