Skip to content

Commit

Permalink
fix state initialization on reset
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Verbelen committed Oct 3, 2024
1 parent 0e644f1 commit d081c23
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions pymdp/envs/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,24 @@ class Env(Module):
current_obs: List[Array]
dependencies: Dict = field(static=True)

def __init__(self, params: Dict, dependencies: Dict, init_state: List[Array] = None):
def __init__(self, params: Dict, dependencies: Dict):
self.params = params
self.dependencies = dependencies

if init_state is None:
init_state = jtu.tree_map(lambda x: jnp.argmax(x, -1), self.params["D"])

self.state = init_state
self.state = jtu.tree_map(lambda x: jnp.zeros([x.shape[0]]), self.params["D"])
self.current_obs = jtu.tree_map(lambda x: jnp.zeros([x.shape[0], x.shape[1]]), self.params["A"])

@vmap
def reset(self, key: Optional[PRNGKeyArray], state: Optional[List[Array]] = None):
if state is not None:
state = self.state
else:
if state is None:
probs = self.params["D"]
keys = list(jr.split(key, len(probs) + 1))
key = keys[0]
state = jtu.tree_map(cat_sample, keys[1:], probs)

new_obs = self._sample_obs(key, state)

env = tree_at(lambda x: x.state, self, state)

new_obs = self._sample_obs(key, state)
env = tree_at(lambda x: x.current_obs, env, new_obs)
return new_obs, env

Expand Down

0 comments on commit d081c23

Please sign in to comment.