Skip to content

Commit

Permalink
removed .jax module prefixes from imports in pymdp.envs.rollout
Browse files Browse the repository at this point in the history
  • Loading branch information
conorheins committed Oct 2, 2024
1 parent b4513b4 commit 99cf06e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pymdp/envs/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import jax.tree_util as jtu
import jax.lax

from pymdp.jax.agent import Agent
from pymdp.jax.envs.env import Env
from pymdp.agent import Agent
from pymdp.envs.env import Env


def rollout(agent: Agent, env: Env, num_timesteps: int, rng_key: jr.PRNGKey, policy_search=None):
Expand Down

0 comments on commit 99cf06e

Please sign in to comment.