From 3c6b01338a3cbaf882712390cc5feb9ac7b510f2 Mon Sep 17 00:00:00 2001 From: Andrea Ostuni Date: Tue, 2 Jan 2024 21:18:28 +0100 Subject: [PATCH] migrate to new render API --- examples/python/launch_cartpole.py | 7 ++++++- python/gym_gz/runtimes/gazebo_runtime.py | 17 +++++++++++++++-- python/gym_gz/runtimes/realtime_runtime.py | 11 ++++++++--- tests/.python/test_pendulum_wrt_ground_truth.py | 10 ++++++---- tests/test_gym_gz/test_reproducibility.py | 4 ++-- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/examples/python/launch_cartpole.py b/examples/python/launch_cartpole.py index 292b9743..27698157 100755 --- a/examples/python/launch_cartpole.py +++ b/examples/python/launch_cartpole.py @@ -48,6 +48,8 @@ def make_env_from_id(env_id: str, **kwargs) -> gym.Env: observation = env.reset(seed=42, options={}) # Initialize returned values + terminated = False + truncated = False done = False totalReward = 0 @@ -55,7 +57,10 @@ def make_env_from_id(env_id: str, **kwargs) -> gym.Env: # Execute a random action action = env.action_space.sample() - observation, reward, done, _ = env.step(action) + observation, reward, terminated, truncated, _ = env.step(action) + + # Check if the episode is terminated + done = terminated or truncated # Render the environment. # It is not required to call this in the loop if physics is not randomized. diff --git a/python/gym_gz/runtimes/gazebo_runtime.py b/python/gym_gz/runtimes/gazebo_runtime.py index ca04fcf1..2ef56500 100644 --- a/python/gym_gz/runtimes/gazebo_runtime.py +++ b/python/gym_gz/runtimes/gazebo_runtime.py @@ -46,6 +46,7 @@ def __init__( real_time_factor: float, physics_engine=scenario.PhysicsEngine_dart, world: str = None, + render_mode: Optional[str] = None, **kwargs, ): @@ -69,6 +70,9 @@ def __init__( self._world_sdf = world self._world_name = None + # Store the render mode + self.render_mode = render_mode + # Create the Task object task = task_cls(agent_rate=agent_rate, **kwargs) @@ -132,6 +136,10 @@ def step(self, action: Action) -> State: # Check truncation truncated = self.task.is_truncated() + # Render the environment + if self.render_mode == "human": + self.render() + # Get info info = self.task.get_info() @@ -158,10 +166,15 @@ def reset(self, seed: int = None, options : Dict = {}, **kwargs) -> ResetReturn: logger.warn("The observation does not belong to the observation space") # Get info info = self.task.get_info() - return ResetReturn((Observation(observation), Info(info))) - def render(self, mode: str = "human", **kwargs) -> None: + # Render the environment + if self.render_mode == "human": + self.render() + + return ResetReturn((Observation(observation), Info(info))) + def render(self, **kwargs) -> None: + mode = self.render_mode if mode != "human": raise ValueError(f"Render mode '{mode}' not supported") diff --git a/python/gym_gz/runtimes/realtime_runtime.py b/python/gym_gz/runtimes/realtime_runtime.py index f75fe343..0e407ace 100644 --- a/python/gym_gz/runtimes/realtime_runtime.py +++ b/python/gym_gz/runtimes/realtime_runtime.py @@ -4,7 +4,7 @@ from gym_gz.base import runtime, task from gym_gz.utils.typing import Action, Done, Info, Observation, State, Terminated, Truncated, ResetReturn, Dict - +from typing import Optional class RealTimeRuntime(runtime.Runtime): """ @@ -15,15 +15,19 @@ class RealTimeRuntime(runtime.Runtime): This class is not yet complete. """ - def __init__(self, task_cls: type, robot_cls: type, agent_rate: float, **kwargs): + def __init__(self, task_cls: type, robot_cls: type, agent_rate: float, render_mode: Optional[str] = None, **kwargs): # Build the environment task_object = task_cls(**kwargs) + # Check the task assert isinstance( task_object, task.Task ), "'task_cls' object must inherit from Task" + # Render mode + self.render_mode = render_mode + super().__init__(task=task_object, agent_rate=agent_rate) raise NotImplementedError @@ -80,7 +84,8 @@ def reset(self, seed: int = None, options : Dict = {}, **kwargs) -> ResetReturn: return ResetReturn((observation, Info({}))) - def render(self, mode: str = "human", **kwargs) -> None: + def render(self, **kwargs) -> None: + mode = self.render_mode raise NotImplementedError def close(self) -> None: diff --git a/tests/.python/test_pendulum_wrt_ground_truth.py b/tests/.python/test_pendulum_wrt_ground_truth.py index 073df274..7116d58d 100644 --- a/tests/.python/test_pendulum_wrt_ground_truth.py +++ b/tests/.python/test_pendulum_wrt_ground_truth.py @@ -10,8 +10,8 @@ from gym_gz.robots.sim import gazebo, pybullet from gym_gz.tasks.pendulum_swingup import PendulumSwingUp from gym_gz.utils import logger -from gym_gz.utils.typing import Observation, Reward, State, -from typing import Dict +from gym_gz.utils.typing import Observation, Reward, State +from typing import Dict, Optional # Set verbosity logger.set_level(gym.logger.DEBUG) @@ -25,7 +25,7 @@ class PendulumEnv(gym.Env): metadata = {"render.modes": []} - def __init__(self): + def __init__(self, render_mode: Optional[str] = None): super().__init__() # Check the xacro pendulum model @@ -37,6 +37,7 @@ def __init__(self): self.dt = None # self.force = None + self.render_mode = render_mode self.theta = None self.theta_dot = None @@ -73,7 +74,8 @@ def reset(self, seed: int = None, options: Dict ={}, **kwargs): # Use set_state_from_obs pass - def render(self, mode="human", **kwargs): + def render(self, **kwargs): + mode = self.render_mode raise Exception("This runtime does not support rendering") def seed(self, seed=None): diff --git a/tests/test_gym_gz/test_reproducibility.py b/tests/test_gym_gz/test_reproducibility.py index d9596ac8..9d220f2f 100644 --- a/tests/test_gym_gz/test_reproducibility.py +++ b/tests/test_gym_gz/test_reproducibility.py @@ -35,8 +35,8 @@ def test_reproducibility(num_physics_rollouts: int): assert env1 != env2 - env1.seed(42) - env2.seed(42) + env1.unwrapped.seed(42) + env2.unwrapped.seed(42) for _ in range(5):