From 8d41b9a0456760f79735ff148112da2041b5e1ff Mon Sep 17 00:00:00 2001 From: Pierre Schumacher Date: Wed, 16 Oct 2024 15:15:15 +0200 Subject: [PATCH] set upper limit on gymnasium version --- deprl/custom_distributed.py | 10 +++++----- deprl/env_wrappers/dm_wrapper.py | 2 +- deprl/env_wrappers/gym_wrapper.py | 2 +- deprl/vendor/tonic/environments/builders.py | 6 ++---- pyproject.toml | 2 +- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/deprl/custom_distributed.py b/deprl/custom_distributed.py index c489a9f..0126266 100644 --- a/deprl/custom_distributed.py +++ b/deprl/custom_distributed.py @@ -52,7 +52,7 @@ def __init__( if env_args is not None: [x.merge_args(env_args) for x in self.environments] [x.apply_args() for x in self.environments] - self._max_episode_steps = max_episode_steps + self.max_episode_steps = max_episode_steps self.observation_space = self.environments[0].observation_space self.action_space = self.environments[0].action_space self.name = self.environments[0].name @@ -85,7 +85,7 @@ def step(self, actions): muscle = self.environments[i].muscle_states self.lengths[i] += 1 # Timeouts trigger resets but are not true terminations. - reset = term or self.lengths[i] == self._max_episode_steps + reset = term or self.lengths[i] == self.max_episode_steps next_observations.append(ob) rewards.append(rew) resets.append(reset) @@ -138,7 +138,7 @@ def __init__( self.build_dict = build_dict self.worker_groups = worker_groups self.workers_per_group = workers_per_group - self._max_episode_steps = max_episode_steps + self.max_episode_steps = max_episode_steps self.env_args = env_args self.header = header @@ -171,7 +171,7 @@ def initialize(self, seed): "output_queue": self.output_queue, "group_seed": group_seed, "build_dict": self.build_dict, - "max_episode_steps": self._max_episode_steps, + "max_episode_steps": self.max_episode_steps, "index": i, "workers": self.workers_per_group, "env_args": self.env_args @@ -267,7 +267,7 @@ def distribute( if "header" in tonic_conf: exec(tonic_conf["header"]) dummy_environment = build_env_from_dict(build_dict) - max_episode_steps = dummy_environment._max_episode_steps + max_episode_steps = dummy_environment.max_episode_steps del dummy_environment if parallel < 2: diff --git a/deprl/env_wrappers/dm_wrapper.py b/deprl/env_wrappers/dm_wrapper.py index 5716801..d2f8140 100644 --- a/deprl/env_wrappers/dm_wrapper.py +++ b/deprl/env_wrappers/dm_wrapper.py @@ -26,7 +26,7 @@ def muscle_activity(self): return self.unwrapped.environment.physics.data.act @property - def _max_episode_steps(self): + def max_episode_steps(self): return self.unwrapped.max_episode_steps diff --git a/deprl/env_wrappers/gym_wrapper.py b/deprl/env_wrappers/gym_wrapper.py index fd1eade..eeea3b5 100644 --- a/deprl/env_wrappers/gym_wrapper.py +++ b/deprl/env_wrappers/gym_wrapper.py @@ -56,7 +56,7 @@ def muscle_activity(self): return self.unwrapped.sim.data.act @property - def _max_episode_steps(self): + def max_episode_steps(self): if hasattr(self.unwrapped, "max_episode_steps"): return self.unwrapped.max_episode_steps else: diff --git a/deprl/vendor/tonic/environments/builders.py b/deprl/vendor/tonic/environments/builders.py index ca343f6..ab56660 100644 --- a/deprl/vendor/tonic/environments/builders.py +++ b/deprl/vendor/tonic/environments/builders.py @@ -84,12 +84,10 @@ def build_environment( # Get the default time limit. if max_episode_steps == "default": - if hasattr(environment, "_max_episode_steps"): - max_episode_steps = environment._max_episode_steps + if hasattr(environment, "max_episode_steps"): + max_episode_steps = environment.max_episode_steps elif hasattr(environment, "horizon"): max_episode_steps = environment.horizon - elif hasattr(environment, "max_episode_steps"): - max_episode_steps = environment.max_episode_steps else: logger.log("No max episode steps found, setting them to 1000") diff --git a/pyproject.toml b/pyproject.toml index f7b5fcd..99621f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ numpy = "^1.22.4" termcolor = "^2.2.0" pandas = "^2.0.1" gdown = "^5.1.0" -gymnasium = "*" +gymnasium = "<=0.30" wandb = "^0.15.4" # torch = {version="2.1.0", source="pytorch-cpu"} torch = ">=2.1.0"