diff --git a/pymdp/jax/agent.py b/pymdp/jax/agent.py index 79aac96d..0d42510a 100644 --- a/pymdp/jax/agent.py +++ b/pymdp/jax/agent.py @@ -432,7 +432,6 @@ def multiaction_probabilities(self, q_pi: Array): ) marginals = jnp.where(locs, q_pi, 0.).sum(-1) - # assert jnp.isclose(jnp.sum(marginals), 1.) # this fails inside scan return marginals @vmap diff --git a/pymdp/jax/control.py b/pymdp/jax/control.py index b02cafe0..177a5a41 100644 --- a/pymdp/jax/control.py +++ b/pymdp/jax/control.py @@ -342,8 +342,11 @@ def scan_body(carry, t): inductive_value = calc_inductive_value_t(qs_init, qs_next, I, epsilon=inductive_epsilon) if use_inductive else 0. - param_info_gain = calc_pA_info_gain(pA, qo, qs_next, A_dependencies) if use_param_info_gain else 0. - param_info_gain += calc_pB_info_gain(pB, qs_next, qs, B_dependencies, policy_i[t]) if use_param_info_gain else 0. + param_info_gain = 0. + if pA is not None: + param_info_gain += calc_pA_info_gain(pA, qo, qs_next, A_dependencies) if use_param_info_gain else 0. + if pB is not None: + param_info_gain += calc_pB_info_gain(pB, qs_next, qs, B_dependencies, policy_i[t]) if use_param_info_gain else 0. neg_G += info_gain + utility - param_info_gain + inductive_value