Skip to content

Commit

Permalink
migrate to new render API
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaostuni committed Jan 2, 2024
1 parent 094b815 commit 3c6b013
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 12 deletions.
7 changes: 6 additions & 1 deletion examples/python/launch_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,19 @@ def make_env_from_id(env_id: str, **kwargs) -> gym.Env:
observation = env.reset(seed=42, options={})

# Initialize returned values
terminated = False
truncated = False
done = False
totalReward = 0

while not done:

# Execute a random action
action = env.action_space.sample()
observation, reward, done, _ = env.step(action)
observation, reward, terminated, truncated, _ = env.step(action)

# Check if the episode is terminated
done = terminated or truncated

# Render the environment.
# It is not required to call this in the loop if physics is not randomized.
Expand Down
17 changes: 15 additions & 2 deletions python/gym_gz/runtimes/gazebo_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
real_time_factor: float,
physics_engine=scenario.PhysicsEngine_dart,
world: str = None,
render_mode: Optional[str] = None,
**kwargs,
):

Expand All @@ -69,6 +70,9 @@ def __init__(
self._world_sdf = world
self._world_name = None

# Store the render mode
self.render_mode = render_mode

# Create the Task object
task = task_cls(agent_rate=agent_rate, **kwargs)

Expand Down Expand Up @@ -132,6 +136,10 @@ def step(self, action: Action) -> State:
# Check truncation
truncated = self.task.is_truncated()

# Render the environment
if self.render_mode == "human":
self.render()

# Get info
info = self.task.get_info()

Expand All @@ -158,10 +166,15 @@ def reset(self, seed: int = None, options : Dict = {}, **kwargs) -> ResetReturn:
logger.warn("The observation does not belong to the observation space")
# Get info
info = self.task.get_info()
return ResetReturn((Observation(observation), Info(info)))

def render(self, mode: str = "human", **kwargs) -> None:
# Render the environment
if self.render_mode == "human":
self.render()

return ResetReturn((Observation(observation), Info(info)))

def render(self, **kwargs) -> None:
mode = self.render_mode
if mode != "human":
raise ValueError(f"Render mode '{mode}' not supported")

Expand Down
11 changes: 8 additions & 3 deletions python/gym_gz/runtimes/realtime_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from gym_gz.base import runtime, task
from gym_gz.utils.typing import Action, Done, Info, Observation, State, Terminated, Truncated, ResetReturn, Dict

from typing import Optional

class RealTimeRuntime(runtime.Runtime):
"""
Expand All @@ -15,15 +15,19 @@ class RealTimeRuntime(runtime.Runtime):
This class is not yet complete.
"""

def __init__(self, task_cls: type, robot_cls: type, agent_rate: float, **kwargs):
def __init__(self, task_cls: type, robot_cls: type, agent_rate: float, render_mode: Optional[str] = None, **kwargs):

# Build the environment
task_object = task_cls(**kwargs)

# Check the task
assert isinstance(
task_object, task.Task
), "'task_cls' object must inherit from Task"

# Render mode
self.render_mode = render_mode

super().__init__(task=task_object, agent_rate=agent_rate)

raise NotImplementedError
Expand Down Expand Up @@ -80,7 +84,8 @@ def reset(self, seed: int = None, options : Dict = {}, **kwargs) -> ResetReturn:

return ResetReturn((observation, Info({})))

def render(self, mode: str = "human", **kwargs) -> None:
def render(self, **kwargs) -> None:
mode = self.render_mode
raise NotImplementedError

def close(self) -> None:
Expand Down
10 changes: 6 additions & 4 deletions tests/.python/test_pendulum_wrt_ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from gym_gz.robots.sim import gazebo, pybullet
from gym_gz.tasks.pendulum_swingup import PendulumSwingUp
from gym_gz.utils import logger
from gym_gz.utils.typing import Observation, Reward, State,
from typing import Dict
from gym_gz.utils.typing import Observation, Reward, State
from typing import Dict, Optional

# Set verbosity
logger.set_level(gym.logger.DEBUG)
Expand All @@ -25,7 +25,7 @@ class PendulumEnv(gym.Env):

metadata = {"render.modes": []}

def __init__(self):
def __init__(self, render_mode: Optional[str] = None):
super().__init__()

# Check the xacro pendulum model
Expand All @@ -37,6 +37,7 @@ def __init__(self):

self.dt = None
# self.force = None
self.render_mode = render_mode
self.theta = None
self.theta_dot = None

Expand Down Expand Up @@ -73,7 +74,8 @@ def reset(self, seed: int = None, options: Dict ={}, **kwargs):
# Use set_state_from_obs
pass

def render(self, mode="human", **kwargs):
def render(self, **kwargs):
mode = self.render_mode
raise Exception("This runtime does not support rendering")

def seed(self, seed=None):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_gym_gz/test_reproducibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_reproducibility(num_physics_rollouts: int):

assert env1 != env2

env1.seed(42)
env2.seed(42)
env1.unwrapped.seed(42)
env2.unwrapped.seed(42)

for _ in range(5):

Expand Down

0 comments on commit 3c6b013

Please sign in to comment.