Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to support Gymnasium v1.0 #610

Merged
merged 10 commits into from
Aug 18, 2024
20 changes: 3 additions & 17 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:

strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand All @@ -25,20 +25,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
sudo pip install pygame
pip install -e .[deploy]
- name: Lint with flake8
run: |
pip install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
run: pip install .[testing]
- name: Test with pytest
run: |
pip install pytest
pip install pytest-cov
pytest --cov=./ --cov-report=xml

run: pytest --cov=./ --cov-report=xml
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ jobs:
uses: pypa/gh-action-pypi-publish@master
with:
user: __token__
password: ${{ secrets.pypi_password }}
password: ${{ secrets.pypi_password }}
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ repos:
hooks:
- id: isort
args: ["--profile", "black"]
exclude: "__init__.py"
- repo: https://github.com/python/black
rev: 23.3.0
hooks:
Expand Down
8 changes: 0 additions & 8 deletions codecov.yml

This file was deleted.

49 changes: 28 additions & 21 deletions highway_env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import sys

__version__ = "1.8.2"
from gymnasium.envs.registration import register

__version__ = "2.0.0"

try:
from farama_notifications import notifications
Expand All @@ -15,96 +17,101 @@
# Hide pygame support prompt
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"

from gymnasium.envs.registration import register
from highway_env.envs.common.abstract import MultiAgentWrapper


def register_highway_envs():
def _register_highway_envs():
"""Import the envs module so that envs register themselves."""

from highway_env.envs.common.abstract import MultiAgentWrapper

# exit_env.py
register(
id="exit-v0",
entry_point="highway_env.envs:ExitEnv",
entry_point="highway_env.envs.exit_env:ExitEnv",
)

# highway_env.py
register(
id="highway-v0",
entry_point="highway_env.envs:HighwayEnv",
entry_point="highway_env.envs.highway_env:HighwayEnv",
)

register(
id="highway-fast-v0",
entry_point="highway_env.envs:HighwayEnvFast",
entry_point="highway_env.envs.highway_env:HighwayEnvFast",
)

# intersection_env.py
register(
id="intersection-v0",
entry_point="highway_env.envs:IntersectionEnv",
entry_point="highway_env.envs.intersection_env:IntersectionEnv",
)

register(
id="intersection-v1",
entry_point="highway_env.envs:ContinuousIntersectionEnv",
entry_point="highway_env.envs.intersection_env:ContinuousIntersectionEnv",
)

register(
id="intersection-multi-agent-v0",
entry_point="highway_env.envs:MultiAgentIntersectionEnv",
entry_point="highway_env.envs.intersection_env:MultiAgentIntersectionEnv",
)

register(
id="intersection-multi-agent-v1",
entry_point="highway_env.envs:MultiAgentIntersectionEnv",
entry_point="highway_env.envs.intersection_env:MultiAgentIntersectionEnv",
additional_wrappers=(MultiAgentWrapper.wrapper_spec(),),
)

# lane_keeping_env.py
register(
id="lane-keeping-v0",
entry_point="highway_env.envs:LaneKeepingEnv",
entry_point="highway_env.envs.lane_keeping_env:LaneKeepingEnv",
max_episode_steps=200,
)

# merge_env.py
register(
id="merge-v0",
entry_point="highway_env.envs:MergeEnv",
entry_point="highway_env.envs.merge_env:MergeEnv",
)

# parking_env.py
register(
id="parking-v0",
entry_point="highway_env.envs:ParkingEnv",
entry_point="highway_env.envs.parking_env:ParkingEnv",
)

register(
id="parking-ActionRepeat-v0",
entry_point="highway_env.envs:ParkingEnvActionRepeat",
entry_point="highway_env.envs.parking_env:ParkingEnvActionRepeat",
)

register(
id="parking-parked-v0", entry_point="highway_env.envs:ParkingEnvParkedVehicles"
id="parking-parked-v0",
entry_point="highway_env.envs.parking_env:ParkingEnvParkedVehicles",
)

# racetrack_env.py
register(
id="racetrack-v0",
entry_point="highway_env.envs:RacetrackEnv",
entry_point="highway_env.envs.racetrack_env:RacetrackEnv",
)

# roundabout_env.py
register(
id="roundabout-v0",
entry_point="highway_env.envs:RoundaboutEnv",
entry_point="highway_env.envs.roundabout_env:RoundaboutEnv",
)

# two_way_env.py
register(
id="two-way-v0", entry_point="highway_env.envs:TwoWayEnv", max_episode_steps=15
id="two-way-v0",
entry_point="highway_env.envs.two_way_env:TwoWayEnv",
max_episode_steps=15,
)

# u_turn_env.py
register(id="u-turn-v0", entry_point="highway_env.envs:UTurnEnv")
register(id="u-turn-v0", entry_point="highway_env.envs.u_turn_env:UTurnEnv")


_register_highway_envs()
45 changes: 35 additions & 10 deletions highway_env/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
from highway_env.envs.highway_env import *
from highway_env.envs.merge_env import *
from highway_env.envs.parking_env import *
from highway_env.envs.roundabout_env import *
from highway_env.envs.two_way_env import *
from highway_env.envs.intersection_env import *
from highway_env.envs.lane_keeping_env import *
from highway_env.envs.u_turn_env import *
from highway_env.envs.exit_env import *
from highway_env.envs.racetrack_env import *
from highway_env.envs.exit_env import ExitEnv
from highway_env.envs.highway_env import HighwayEnv, HighwayEnvFast
from highway_env.envs.intersection_env import (
ContinuousIntersectionEnv,
IntersectionEnv,
MultiAgentIntersectionEnv,
)
from highway_env.envs.lane_keeping_env import LaneKeepingEnv
from highway_env.envs.merge_env import MergeEnv
from highway_env.envs.parking_env import (
ParkingEnv,
ParkingEnvActionRepeat,
ParkingEnvParkedVehicles,
)
from highway_env.envs.racetrack_env import RacetrackEnv
from highway_env.envs.roundabout_env import RoundaboutEnv
from highway_env.envs.two_way_env import TwoWayEnv
from highway_env.envs.u_turn_env import UTurnEnv

__all__ = [
"ExitEnv",
"HighwayEnv",
"HighwayEnvFast",
"IntersectionEnv",
"ContinuousIntersectionEnv",
"MultiAgentIntersectionEnv",
"LaneKeepingEnv",
"MergeEnv",
"ParkingEnv",
"ParkingEnvActionRepeat",
"RacetrackEnv",
"RoundaboutEnv",
"TwoWayEnv",
"UTurnEnv",
]
10 changes: 7 additions & 3 deletions highway_env/envs/common/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gymnasium as gym
import numpy as np
from gymnasium import Wrapper
from gymnasium.utils import RecordConstructorArgs
from gymnasium.wrappers import RecordVideo

from highway_env import utils
Expand Down Expand Up @@ -426,10 +427,13 @@ def __deepcopy__(self, memo):
return result


class MultiAgentWrapper(Wrapper):
class MultiAgentWrapper(Wrapper, RecordConstructorArgs):
def __init__(self, env):
Wrapper.__init__(self, env)
RecordConstructorArgs.__init__(self)

def step(self, action):
obs, reward, terminated, truncated, info = super().step(action)
obs, _, _, truncated, info = super().step(action)
reward = info["agents_rewards"]
terminated = info["agents_terminated"]
truncated = info["agents_truncated"]
return obs, reward, terminated, truncated, info
6 changes: 3 additions & 3 deletions highway_env/envs/common/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
lateral: bool = True,
dynamical: bool = False,
clip: bool = True,
**kwargs
**kwargs,
) -> None:
"""
Create a continuous action space.
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(
dynamical: bool = False,
clip: bool = True,
actions_per_axis: int = 3,
**kwargs
**kwargs,
) -> None:
super().__init__(
env,
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(
longitudinal: bool = True,
lateral: bool = True,
target_speeds: Optional[Vector] = None,
**kwargs
**kwargs,
) -> None:
"""
Create a discrete action space of meta-actions.
Expand Down
10 changes: 5 additions & 5 deletions highway_env/envs/common/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
weights: List[float],
scaling: Optional[float] = None,
centering_position: Optional[List[float]] = None,
**kwargs
**kwargs,
) -> None:
super().__init__(env)
self.observation_shape = observation_shape
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(
see_behind: bool = False,
observe_intentions: bool = False,
include_obstacles: bool = True,
**kwargs: dict
**kwargs: dict,
) -> None:
"""
:param env: The environment to observe
Expand Down Expand Up @@ -293,7 +293,7 @@ def __init__(
align_to_vehicle_axes: bool = False,
clip: bool = True,
as_image: bool = False,
**kwargs: dict
**kwargs: dict,
) -> None:
"""
:param env: The environment to observe
Expand Down Expand Up @@ -674,7 +674,7 @@ def observe(self) -> np.ndarray:
if self.order == "shuffled":
self.env.np_random.shuffle(obs[1:])
# Flatten
return obs
return obs.astype(self.space().dtype)


class LidarObservation(ObservationType):
Expand All @@ -687,7 +687,7 @@ def __init__(
cells: int = 16,
maximum_range: float = 60,
normalize: bool = True,
**kwargs
**kwargs,
):
super().__init__(env, **kwargs)
self.cells = cells
Expand Down
14 changes: 8 additions & 6 deletions highway_env/envs/exit_env.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Dict, Text, Tuple
from __future__ import annotations

import numpy as np

from highway_env import utils
from highway_env.envs import CircularLane, HighwayEnv, Vehicle
from highway_env.envs.common.action import Action
from highway_env.envs.highway_env import HighwayEnv
from highway_env.road.lane import CircularLane
from highway_env.road.road import Road, RoadNetwork
from highway_env.vehicle.controller import ControlledVehicle
from highway_env.vehicle.kinematics import Vehicle


class ExitEnv(HighwayEnv):
Expand Down Expand Up @@ -44,10 +46,10 @@ def _reset(self) -> None:
self._create_road()
self._create_vehicles()

def step(self, action) -> Tuple[np.ndarray, float, bool, dict]:
obs, reward, terminal, info = super().step(action)
def step(self, action) -> tuple[np.ndarray, float, bool, bool, dict]:
obs, reward, terminated, truncated, info = super().step(action)
info.update({"is_success": self._is_success()})
return obs, reward, terminal, info
return obs, reward, terminated, truncated, info

def _create_road(
self, road_length=1000, exit_position=400, exit_length=100
Expand Down Expand Up @@ -154,7 +156,7 @@ def _reward(self, action: Action) -> float:
reward = np.clip(reward, 0, 1)
return reward

def _rewards(self, action: Action) -> Dict[Text, float]:
def _rewards(self, action: Action) -> dict[str, float]:
lane_index = (
self.vehicle.target_lane_index
if isinstance(self.vehicle, ControlledVehicle)
Expand Down
2 changes: 1 addition & 1 deletion highway_env/envs/intersection_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _info(self, obs: np.ndarray, action: int) -> dict:
info["agents_rewards"] = tuple(
self._agent_reward(action, vehicle) for vehicle in self.controlled_vehicles
)
info["agents_dones"] = tuple(
info["agents_terminated"] = tuple(
self._agent_is_terminal(vehicle) for vehicle in self.controlled_vehicles
)
return info
Expand Down
2 changes: 1 addition & 1 deletion highway_env/road/road.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def straight_road_network(
*nodes_str,
StraightLane(
origin, end, line_types=line_types, speed_limit=speed_limit
)
),
)
return net

Expand Down
Loading
Loading