diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 09e0b23ca..491b0b01a 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -280,10 +280,31 @@ jobs:
run: |
sudo apt-get -y install xvfb
sudo /usr/bin/Xvfb :0 -screen 0 1280x1024x24 &
+ - name: Setup FFmpeg
+ uses: FedericoCarboni/setup-ffmpeg@v3
+ id: setup-ffmpeg
+ with:
+ # A specific version to download, may also be "release" or a specific version
+ # like "6.1.0". At the moment semver specifiers (i.e. >=6.1.0) are supported
+ # only on Windows, on other platforms they are allowed but version is matched
+ # exactly regardless.
+ ffmpeg-version: release
+ # Target architecture of the ffmpeg executable to install. Defaults to the
+ # system architecture. Only x64 and arm64 are supported (arm64 only on Linux).
+ architecture: ''
+ # Linking type of the binaries. Use "shared" to download shared binaries and
+ # "static" for statically linked ones. Shared builds are currently only available
+ # for windows releases. Defaults to "static"
+ linking-type: static
+ # As of version 3 of this action, builds are no longer downloaded from GitHub
+ # except on Windows: https://github.com/GyanD/codexffmpeg/releases.
+ github-token: ${{ github.server_url == 'https://github.com' && github.token || '' }}
- name: Blackbox tests
run: |
pip install cython
pip install numpy
+ pip install mediapy
+ conda install ffmpeg
pip install -e .
pip install -e .[gym]
python -m metadrive.pull_asset
diff --git a/documentation/source/index.rst b/documentation/source/index.rst
index c5c618c3c..69e48e956 100644
--- a/documentation/source/index.rst
+++ b/documentation/source/index.rst
@@ -47,6 +47,7 @@ Please feel free to contact us if you have any suggestions or ideas!
action.ipynb
reward_cost_done.ipynb
training.ipynb
+ multigoal_intersection.ipynb
.. toctree::
:hidden:
diff --git a/documentation/source/multigoal_intersection.ipynb b/documentation/source/multigoal_intersection.ipynb
new file mode 100644
index 000000000..e0224cb8b
--- /dev/null
+++ b/documentation/source/multigoal_intersection.ipynb
@@ -0,0 +1,272 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "2832faf1-1bd3-4a95-8b0d-b3289e74d4d0",
+ "metadata": {},
+ "source": [
+ "# Demonstration on MultigoalIntersection\n",
+ "\n",
+ "In this notebook, we demonstrate how to setup a multigoal intersection environment where you can access relevant stats (e.g. route completion, reward, success rate) for all four possible goals (right turn, left turn, move forward, U turn) simultaneously.\n",
+ "\n",
+ "We demonstrate how to build the environment, in which we have successfully trained a SAC expert that achieves 99% success rate, and how to access those stats in the info dict returned each step.\n",
+ "\n",
+ "*Note: We pretrain the SAC expert with `use_multigoal_intersection=False` and then finetune it with `use_multigoal_intersection=True`.*"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "b9733eac-9d07-47cf-bda7-4dbb8d5f2412",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "from metadrive.envs.gym_wrapper import create_gym_wrapper\n",
+ "from metadrive.envs.multigoal_intersection import MultiGoalIntersectionEnv\n",
+ "import mediapy as media\n",
+ "\n",
+ "render = False\n",
+ "num_scenarios = 1000\n",
+ "start_seed = 100\n",
+ "goal_probabilities = {\n",
+ " \"right_turn\": 0.25,\n",
+ " \"left_turn\": 0.25,\n",
+ " \"go_straight\": 0.25,\n",
+ " \"u_turn\": 0.25\n",
+ "}\n",
+ "\n",
+ "\n",
+ "class MultiGoalWrapped(MultiGoalIntersectionEnv):\n",
+ " current_goal = None\n",
+ "\n",
+ " def step(self, actions):\n",
+ " o, r, tm, tc, i = super().step(actions)\n",
+ "\n",
+ " o = i['obs/goals/{}'.format(self.current_goal)]\n",
+ " r = i['reward/goals/{}'.format(self.current_goal)]\n",
+ " i['route_completion'] = i['route_completion/goals/{}'.format(self.current_goal)]\n",
+ " i['arrive_dest'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n",
+ " i['reward/goals/default'] = i['reward/goals/{}'.format(self.current_goal)]\n",
+ " i['route_completion/goals/default'] = i['route_completion/goals/{}'.format(self.current_goal)]\n",
+ " i['arrive_dest/goals/default'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n",
+ " i[\"current_goal\"] = self.current_goal\n",
+ " return o, r, tm, tc, i\n",
+ "\n",
+ " def reset(self, *args, **kwargs):\n",
+ " o, i = super().reset(*args, **kwargs)\n",
+ "\n",
+ " # Sample a goal from the goal set\n",
+ " if self.config[\"use_multigoal_intersection\"]:\n",
+ " p = goal_probabilities\n",
+ " self.current_goal = np.random.choice(list(p.keys()), p=list(p.values()))\n",
+ "\n",
+ " else:\n",
+ " self.current_goal = \"default\"\n",
+ "\n",
+ " o = i['obs/goals/{}'.format(self.current_goal)]\n",
+ " i['route_completion'] = i['route_completion/goals/{}'.format(self.current_goal)]\n",
+ " i['arrive_dest'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n",
+ " i['reward/goals/default'] = i['reward/goals/{}'.format(self.current_goal)]\n",
+ " i['route_completion/goals/default'] = i['route_completion/goals/{}'.format(self.current_goal)]\n",
+ " i['arrive_dest/goals/default'] = i['arrive_dest/goals/{}'.format(self.current_goal)]\n",
+ " i[\"current_goal\"] = self.current_goal\n",
+ "\n",
+ " return o, i"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "f5b6f059-52f8-46ee-bcfe-dee6f4d2e2e6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[38;20m[INFO] Environment: MultiGoalWrapped\u001b[0m\n",
+ "\u001b[38;20m[INFO] MetaDrive version: 0.4.2.3\u001b[0m\n",
+ "\u001b[38;20m[INFO] Sensors: [lidar: Lidar(), side_detector: SideDetector(), lane_line_detector: LaneLineDetector()]\u001b[0m\n",
+ "\u001b[38;20m[INFO] Render Mode: none\u001b[0m\n",
+ "\u001b[38;20m[INFO] Horizon (Max steps per agent): 500\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "env_config = dict(\n",
+ " use_render=render,\n",
+ " manual_control=False,\n",
+ " vehicle_config=dict(show_lidar=False, show_navi_mark=True, show_line_to_navi_mark=True,\n",
+ " show_line_to_dest=True, show_dest_mark=True),\n",
+ " horizon=500, # to speed up training\n",
+ "\n",
+ " traffic_density=0.06,\n",
+ " \n",
+ " use_multigoal_intersection=True, # Set to False if want to use the same observation but with original PG scenarios.\n",
+ " out_of_route_done=False,\n",
+ "\n",
+ " num_scenarios=num_scenarios,\n",
+ " start_seed=start_seed,\n",
+ " accident_prob=0.8,\n",
+ " crash_vehicle_done=False,\n",
+ " crash_object_done=False,\n",
+ ")\n",
+ "\n",
+ "wrapped = create_gym_wrapper(MultiGoalWrapped)\n",
+ "\n",
+ "env = wrapped(env_config)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ae2abe78-f3e3-40b9-88dd-a958fc932363",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[38;20m[INFO] Assets version: 0.4.2.3\u001b[0m\n",
+ "\u001b[38;20m[INFO] Known Pipes: glxGraphicsPipe\u001b[0m\n",
+ "\u001b[38;20m[INFO] Start Scenario Index: 100, Num Scenarios : 1000\u001b[0m\n",
+ "\u001b[33;20m[WARNING] env.vehicle will be deprecated soon. Use env.agent instead (base_env.py:731)\u001b[0m\n",
+ "\u001b[38;20m[INFO] Episode ended! Scenario Index: 606 Reason: arrive_dest.\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "frames = []\n",
+ "\n",
+ "env.reset()\n",
+ "while True:\n",
+ " action = [0, 1]\n",
+ " o, r, d, i = env.step(action)\n",
+ " frame = env.render(mode=\"topdown\")\n",
+ " frames.append(frame)\n",
+ " if d:\n",
+ " break"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "40ac0392-67e3-4d2d-a9bd-2065831e43ca",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Output at final step:\n",
+ "\tacceleration: 1.000\n",
+ "\tarrive_dest: 1.000\n",
+ "\tarrive_dest/goals/default: 1.000\n",
+ "\tarrive_dest/goals/go_straight: 1.000\n",
+ "\tarrive_dest/goals/left_turn: 0.000\n",
+ "\tarrive_dest/goals/right_turn: 0.000\n",
+ "\tarrive_dest/goals/u_turn: 0.000\n",
+ "\tcost: 0.000\n",
+ "\tcrash: 0.000\n",
+ "\tcrash_building: 0.000\n",
+ "\tcrash_human: 0.000\n",
+ "\tcrash_object: 0.000\n",
+ "\tcrash_sidewalk: 0.000\n",
+ "\tcrash_vehicle: 0.000\n",
+ "\tcurrent_goal: go_straight\n",
+ "\tenv_seed: 606.000\n",
+ "\tepisode_energy: 6.986\n",
+ "\tepisode_length: 88.000\n",
+ "\tepisode_reward: 35.834\n",
+ "\tmax_step: 0.000\n",
+ "\tnavigation_command: right\n",
+ "\tnavigation_forward: 0.000\n",
+ "\tnavigation_left: 0.000\n",
+ "\tnavigation_right: 1.000\n",
+ "\tout_of_road: 0.000\n",
+ "\tovertake_vehicle_num: 0.000\n",
+ "\tpolicy: EnvInputPolicy\n",
+ "\treward/default_reward: -10.000\n",
+ "\treward/goals/default: 12.335\n",
+ "\treward/goals/go_straight: 12.335\n",
+ "\treward/goals/left_turn: -10.000\n",
+ "\treward/goals/right_turn: -10.000\n",
+ "\treward/goals/u_turn: -10.000\n",
+ "\troute_completion: 0.969\n",
+ "\troute_completion/goals/default: 0.969\n",
+ "\troute_completion/goals/go_straight: 0.969\n",
+ "\troute_completion/goals/left_turn: 0.632\n",
+ "\troute_completion/goals/right_turn: 0.643\n",
+ "\troute_completion/goals/u_turn: 0.552\n",
+ "\tsteering: 0.000\n",
+ "\tstep_energy: 0.162\n",
+ "\tvelocity: 22.313\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Output at final step:\")\n",
+ "\n",
+ "i = {k: i[k] for k in sorted(i.keys())}\n",
+ "for k, v in i.items():\n",
+ " if isinstance(v, str):\n",
+ " s = v\n",
+ " elif np.iterable(v):\n",
+ " continue\n",
+ " else:\n",
+ " s = \"{:.3f}\".format(v)\n",
+ " print(\"\\t{}: {}\".format(k, s))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "dc986e4e-f81c-4882-88b2-9eb306552fb3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
|
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "media.show_video(frames)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/metadrive/component/algorithm/BIG.py b/metadrive/component/algorithm/BIG.py
index c322c472d..3384b3631 100644
--- a/metadrive/component/algorithm/BIG.py
+++ b/metadrive/component/algorithm/BIG.py
@@ -73,9 +73,17 @@ def generate(self, generate_method: str, parameter: Union[str, int]):
assert isinstance(parameter, int), "When generating map by assigning block num, the parameter should be int"
self.block_num = parameter + 1
elif generate_method == BigGenerateMethod.BLOCK_SEQUENCE:
- assert isinstance(parameter, str), "When generating map from block sequence, the parameter should be a str"
- self.block_num = len(parameter) + 1
- self._block_sequence = FirstPGBlock.ID + parameter
+ if isinstance(parameter, list):
+ self.block_num = len(parameter) + 1
+ self._block_sequence = [FirstPGBlock] + parameter
+ else:
+ assert isinstance(
+ parameter, str
+ ), "When generating map from block sequence, the parameter should be a str. But got {}".format(
+ type(parameter)
+ )
+ self.block_num = len(parameter) + 1
+ self._block_sequence = FirstPGBlock.ID + parameter
while True:
if self.big_helper_func():
break
@@ -104,8 +112,11 @@ def sample_block(self) -> PGBlock:
block_type = self.np_random.choice(block_types, p=block_probabilities)
block_type = get_metadrive_class(block_type)
else:
- type_id = self._block_sequence[len(self.blocks)]
- block_type = self.block_dist_config.get_block(type_id)
+ if isinstance(self._block_sequence[0], str):
+ type_id = self._block_sequence[len(self.blocks)]
+ block_type = self.block_dist_config.get_block(type_id)
+ else:
+ block_type = self._block_sequence[len(self.blocks)]
socket = self.np_random.choice(self.blocks[-1].get_socket_indices())
block = block_type(
diff --git a/metadrive/component/algorithm/blocks_prob_dist.py b/metadrive/component/algorithm/blocks_prob_dist.py
index bbf3e1b3a..86f4e179d 100644
--- a/metadrive/component/algorithm/blocks_prob_dist.py
+++ b/metadrive/component/algorithm/blocks_prob_dist.py
@@ -37,7 +37,8 @@ class PGBlockDistConfig:
"Split": 0.00,
"ParkingLot": 0.00,
"TollGate": 0.00,
- "Bidirection": 0.00
+ "Bidirection": 0.00,
+ "StdInterSectionWithUTurn": 0.00
}
@classmethod
diff --git a/metadrive/component/map/pg_map.py b/metadrive/component/map/pg_map.py
index 98ad32c36..b99940c36 100644
--- a/metadrive/component/map/pg_map.py
+++ b/metadrive/component/map/pg_map.py
@@ -20,7 +20,7 @@ def parse_map_config(easy_map_config, new_map_config, default_config):
assert isinstance(default_config, Config)
# Return the user specified config if overwritten
- if not default_config["map_config"].is_identical(new_map_config):
+ if easy_map_config is None or not default_config["map_config"].is_identical(new_map_config):
new_map_config = default_config["map_config"].copy(unchangeable=False).update(new_map_config)
assert default_config["map"] == easy_map_config
return new_map_config
diff --git a/metadrive/component/navigation_module/base_navigation.py b/metadrive/component/navigation_module/base_navigation.py
index 24c04daa9..9f09643cc 100644
--- a/metadrive/component/navigation_module/base_navigation.py
+++ b/metadrive/component/navigation_module/base_navigation.py
@@ -68,6 +68,7 @@ def __init__(
self.navi_arrow_dir = [0, 0]
self._dest_node_path = None
self._goal_node_path = None
+ self._goal_node_path2 = None
self._node_path_list = []
@@ -78,15 +79,20 @@ def __init__(
# nodepath
self._line_to_dest = self.origin.attachNewNode("line")
self._goal_node_path = self.origin.attachNewNode("target")
+ self._goal_node_path2 = self.origin.attachNewNode("target2")
self._dest_node_path = self.origin.attachNewNode("dest")
self._node_path_list.append(self._line_to_dest)
self._node_path_list.append(self._goal_node_path)
+ self._node_path_list.append(self._goal_node_path2)
self._node_path_list.append(self._dest_node_path)
if show_navi_mark:
navi_point_model = AssetLoader.loader.loadModel(AssetLoader.file_path("models", "box.bam"))
navi_point_model.reparentTo(self._goal_node_path)
+
+ navi_point_model2 = AssetLoader.loader.loadModel(AssetLoader.file_path("models", "box.bam"))
+ navi_point_model2.reparentTo(self._goal_node_path2)
if show_dest_mark:
dest_point_model = AssetLoader.loader.loadModel(AssetLoader.file_path("models", "box.bam"))
dest_point_model.reparentTo(self._dest_node_path)
@@ -108,18 +114,20 @@ def __init__(
line_seg.setColor(self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 1.0)
line_seg.setThickness(4)
self._dynamic_line_np_2 = NodePath(line_seg.create(True))
-
self._node_path_list.append(self._dynamic_line_np_2)
-
self._dynamic_line_np_2.reparentTo(self.origin)
self._line_to_navi = line_seg
self._goal_node_path.setTransparency(TransparencyAttrib.M_alpha)
+ self._goal_node_path2.setTransparency(TransparencyAttrib.M_alpha)
self._dest_node_path.setTransparency(TransparencyAttrib.M_alpha)
self._goal_node_path.setColor(
self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 0.7
)
+ self._goal_node_path2.setColor(
+ self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 0.5
+ )
self._dest_node_path.setColor(
self.navi_mark_color[0], self.navi_mark_color[1], self.navi_mark_color[2], 0.7
)
@@ -180,6 +188,7 @@ def destroy(self):
pass
self._dest_node_path.removeNode()
self._goal_node_path.removeNode()
+ self._goal_node_path2.removeNode()
for np in self._node_path_list:
np.detachNode()
@@ -234,12 +243,16 @@ def _draw_line_to_dest(self, start_position, end_position):
self._dynamic_line_np.hide(CamMask.Shadow | CamMask.RgbCam)
self._dynamic_line_np.reparentTo(self.origin)
- def _draw_line_to_navi(self, start_position, end_position):
+ def _draw_line_to_navi(self, start_position, end_position, next_checkpoint=None):
if not self._show_line_to_navi_mark:
return
line_seg = self._line_to_navi
line_seg.moveTo(panda_vector(start_position, self.LINE_TO_DEST_HEIGHT))
line_seg.drawTo(panda_vector(end_position, self.LINE_TO_DEST_HEIGHT))
+
+ if next_checkpoint is not None:
+ line_seg.drawTo(panda_vector(next_checkpoint, self.LINE_TO_DEST_HEIGHT))
+
self._dynamic_line_np_2.removeNode()
self._dynamic_line_np_2 = NodePath(line_seg.create(False))
diff --git a/metadrive/component/navigation_module/edge_network_navigation.py b/metadrive/component/navigation_module/edge_network_navigation.py
index 9082bd84e..80bd992a8 100644
--- a/metadrive/component/navigation_module/edge_network_navigation.py
+++ b/metadrive/component/navigation_module/edge_network_navigation.py
@@ -100,27 +100,37 @@ def update_localization(self, ego_vehicle):
self._navi_info.fill(0.0)
half = self.CHECK_POINT_INFO_DIM
- self._navi_info[:half], lanes_heading1, checkpoint = self._get_info_for_checkpoint(
+ self._navi_info[:half], lanes_heading1, next_checkpoint = self._get_info_for_checkpoint(
lanes_id=0,
ref_lane=self.map.road_network.get_lane(self.current_checkpoint_lane_index),
ego_vehicle=ego_vehicle
)
- self._navi_info[half:], lanes_heading2, _ = self._get_info_for_checkpoint(
+ self._navi_info[half:], lanes_heading2, next_next_checkpoint = self._get_info_for_checkpoint(
lanes_id=1,
ref_lane=self.map.road_network.get_lane(self.next_checkpoint_lane_index),
ego_vehicle=ego_vehicle
)
if self._show_navi_info: # Whether to visualize little boxes in the scene denoting the checkpoints
- pos_of_goal = checkpoint
+ pos_of_goal = next_checkpoint
self._goal_node_path.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT))
self._goal_node_path.setH(self._goal_node_path.getH() + 3)
+
+ pos_of_goal = next_next_checkpoint
+ self._goal_node_path2.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT))
+ self._goal_node_path2.setH(self._goal_node_path2.getH() + 3)
+
self.navi_arrow_dir = [lanes_heading1, lanes_heading2]
dest_pos = self._dest_node_path.getPos()
self._draw_line_to_dest(start_position=ego_vehicle.position, end_position=(dest_pos[0], dest_pos[1]))
navi_pos = self._goal_node_path.getPos()
- self._draw_line_to_navi(start_position=ego_vehicle.position, end_position=(navi_pos[0], navi_pos[1]))
+ next_navi_pos = self._goal_node_path2.getPos()
+ self._draw_line_to_navi(
+ start_position=ego_vehicle.position,
+ end_position=(navi_pos[0], navi_pos[1]),
+ next_checkpoint=(next_navi_pos[0], next_navi_pos[1])
+ )
def _update_target_checkpoints(self, ego_lane_index) -> bool:
"""
diff --git a/metadrive/component/navigation_module/node_network_navigation.py b/metadrive/component/navigation_module/node_network_navigation.py
index 596d08358..9b2fa81f4 100644
--- a/metadrive/component/navigation_module/node_network_navigation.py
+++ b/metadrive/component/navigation_module/node_network_navigation.py
@@ -40,7 +40,7 @@ def __init__(
self.current_road = None
self.next_road = None
- def reset(self, vehicle):
+ def reset(self, vehicle, dest=None, random_seed=None):
possible_lanes = ray_localization(vehicle.heading, vehicle.spawn_place, self.engine, use_heading_filter=False)
possible_lane_indexes = [lane_index for lane, lane_index, dist in possible_lanes]
@@ -56,11 +56,12 @@ def reset(self, vehicle):
assert len(possible_lanes) > 0
lane, new_l_index = possible_lanes[0][:-1]
- dest = vehicle.config["destination"]
+ if dest is None:
+ dest = vehicle.config["destination"]
current_lane = lane
destination = dest if dest is not None else None
- random_seed = self.engine.global_random_seed
+ random_seed = self.engine.global_random_seed if random_seed is None else random_seed
assert current_lane is not None, "spawn place is not on road!"
super(NodeNetworkNavigation, self).reset(current_lane)
assert self.map.road_network_type == NodeRoadNetwork, "This Navigation module only support NodeRoadNetwork type"
@@ -188,12 +189,12 @@ def update_localization(self, ego_vehicle):
self._navi_info.fill(0.0)
half = self.CHECK_POINT_INFO_DIM
# Put the next checkpoint's information into the first half of the navi_info
- self._navi_info[:half], lanes_heading1, checkpoint = self._get_info_for_checkpoint(
+ self._navi_info[:half], lanes_heading1, next_checkpoint = self._get_info_for_checkpoint(
lanes_id=0, ref_lane=self.current_ref_lanes[0], ego_vehicle=ego_vehicle
)
# Put the next of the next checkpoint's information into the first half of the navi_info
- self._navi_info[half:], lanes_heading2, _ = self._get_info_for_checkpoint(
+ self._navi_info[half:], lanes_heading2, next_next_checkpoint = self._get_info_for_checkpoint(
lanes_id=1,
ref_lane=self.next_ref_lanes[0] if self.next_ref_lanes is not None else self.current_ref_lanes[0],
ego_vehicle=ego_vehicle
@@ -202,13 +203,23 @@ def update_localization(self, ego_vehicle):
self.navi_arrow_dir = [lanes_heading1, lanes_heading2]
if self._show_navi_info:
# Whether to visualize little boxes in the scene denoting the checkpoints
- pos_of_goal = checkpoint
+ pos_of_goal = next_checkpoint
self._goal_node_path.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT))
self._goal_node_path.setH(self._goal_node_path.getH() + 3)
+
+ pos_of_goal = next_next_checkpoint
+ self._goal_node_path2.setPos(panda_vector(pos_of_goal[0], pos_of_goal[1], self.MARK_HEIGHT))
+ self._goal_node_path2.setH(self._goal_node_path2.getH() + 3)
+
dest_pos = self._dest_node_path.getPos()
self._draw_line_to_dest(start_position=ego_vehicle.position, end_position=(dest_pos[0], dest_pos[1]))
navi_pos = self._goal_node_path.getPos()
- self._draw_line_to_navi(start_position=ego_vehicle.position, end_position=(navi_pos[0], navi_pos[1]))
+ next_navi_pos = self._goal_node_path2.getPos()
+ self._draw_line_to_navi(
+ start_position=ego_vehicle.position,
+ end_position=(navi_pos[0], navi_pos[1]),
+ next_checkpoint=(next_navi_pos[0], next_navi_pos[1])
+ )
def _update_target_checkpoints(self, ego_lane_index, ego_lane_longitude) -> bool:
"""
@@ -250,7 +261,9 @@ def get_current_lateral_range(self, current_position, engine) -> float:
def _get_current_lane(self, ego_vehicle):
"""
- Called in update_localization to find current lane information
+ Called in update_localization to find current lane information. If the vehicle is in the current reference lane,
+ meaning it is not yet moving to the next road segment, then return the current reference lane. Otherwise, return
+ the next reference lane. If the vehicle is not in any of the reference lanes, then return the closest lane.
"""
possible_lanes, on_lane = ray_localization(
ego_vehicle.heading, ego_vehicle.position, ego_vehicle.engine, return_on_lane=True
diff --git a/metadrive/component/pgblock/intersection.py b/metadrive/component/pgblock/intersection.py
index 8d244d12b..5e91407eb 100644
--- a/metadrive/component/pgblock/intersection.py
+++ b/metadrive/component/pgblock/intersection.py
@@ -110,7 +110,6 @@ def _create_part(self, attach_lanes, attach_road: Road, radius: float, intersect
# u-turn
if self._enable_u_turn_flag:
- adverse_road = -attach_road
self._create_u_turn(attach_road, part_idx)
# go forward part
@@ -221,6 +220,17 @@ def _create_left_turn(self, radius, lane_num, attach_left_lane, attach_road, int
)
def _create_u_turn(self, attach_road, part_idx):
+ """
+ Create a U turn.
+
+ Args:
+ attach_road: the road where the U turn starts.
+ part_idx: in [0, 1, 2, 3]. When part_idx!=0, we grab the lanes from road network. Otherwise we use the
+ initial lanes (positive_lanes).
+
+ Returns:
+ None.
+ """
# set to CONTINUOUS to debug
line_type = PGLineType.NONE
lanes = attach_road.get_lanes(self.block_network) if part_idx != 0 else self.positive_lanes
@@ -253,3 +263,9 @@ def get_intermediate_spawn_lanes(self):
"""Override this function for intersection so that we won't spawn vehicles in the center of intersection."""
respawn_lanes = self.get_respawn_lanes()
return respawn_lanes
+
+
+class InterSectionWithUTurn(InterSection):
+ ID = "U"
+ _enable_u_turn_flag = True
+ SOCKET_NUM = 4
diff --git a/metadrive/component/pgblock/pg_block.py b/metadrive/component/pgblock/pg_block.py
index c717d7073..c2f1af836 100644
--- a/metadrive/component/pgblock/pg_block.py
+++ b/metadrive/component/pgblock/pg_block.py
@@ -252,10 +252,15 @@ def create_in_world(self):
for _id, lane in enumerate(lanes):
self._construct_lane(lane, (_from, _to, _id))
+
+ # choose_side is a two-elemental list, the first element is for left side,
+ # the second element is for right side. If False, then the left/right side line (broken line or
+ # continuous line) will not be constructed.
+
choose_side = [True, True] if _id == len(lanes) - 1 else [True, False]
- if Road(_from, _to).is_negative_road() and _id == 0:
- # draw center line with positive road
- choose_side = [False, False]
+ # if Road(_from, _to).is_negative_road() and _id == 0:
+ # # draw center line with positive road
+ # choose_side = [False, False]
self._construct_lane_line_in_block(lane, choose_side)
self._construct_sidewalk()
self._construct_crosswalk()
diff --git a/metadrive/component/pgblock/std_intersection.py b/metadrive/component/pgblock/std_intersection.py
index 2e87792fb..4aac8171d 100644
--- a/metadrive/component/pgblock/std_intersection.py
+++ b/metadrive/component/pgblock/std_intersection.py
@@ -1,5 +1,5 @@
-from metadrive.component.pgblock.intersection import InterSection
from metadrive.component.pg_space import Parameter
+from metadrive.component.pgblock.intersection import InterSection, InterSectionWithUTurn
class StdInterSection(InterSection):
@@ -7,3 +7,10 @@ def _try_plug_into_previous_block(self) -> bool:
self._config[Parameter.change_lane_num] = 0
success = super(StdInterSection, self)._try_plug_into_previous_block()
return success
+
+
+class StdInterSectionWithUTurn(InterSectionWithUTurn):
+ def _try_plug_into_previous_block(self) -> bool:
+ self._config[Parameter.change_lane_num] = 0
+ success = super(StdInterSectionWithUTurn, self)._try_plug_into_previous_block()
+ return success
diff --git a/metadrive/component/road_network/node_road_network.py b/metadrive/component/road_network/node_road_network.py
index 7ecf8b6a3..85f02c642 100644
--- a/metadrive/component/road_network/node_road_network.py
+++ b/metadrive/component/road_network/node_road_network.py
@@ -272,6 +272,7 @@ def shortest_path(self, start: str, goal: str) -> List[str]:
Returns:
The shortest checkpoints from start to goal.
"""
+ assert isinstance(goal, str)
start_road_node = start[0]
assert start != goal
return next(self.bfs_paths(start_road_node, goal), [])
diff --git a/metadrive/component/sensors/distance_detector.py b/metadrive/component/sensors/distance_detector.py
index 8261496f8..eae9cce4f 100644
--- a/metadrive/component/sensors/distance_detector.py
+++ b/metadrive/component/sensors/distance_detector.py
@@ -89,7 +89,6 @@ class DistanceDetector(BaseSensor):
"""
It is a module like lidar, used to detect sidewalk/center line or other static things
"""
- Lidar_point_cloud_obs_dim = 240
DEFAULT_HEIGHT = 0.2
# for vis debug
@@ -196,7 +195,7 @@ def __init__(self, engine):
super(SideDetector, self).__init__(engine)
self.set_start_phase_offset(90)
self.origin.hide(CamMask.RgbCam | CamMask.Shadow | CamMask.Shadow | CamMask.DepthCam | CamMask.SemanticCam)
- self.mask = CollisionGroup.ContinuousLaneLine
+ self.mask = CollisionGroup.ContinuousLaneLine | CollisionGroup.Sidewalk
class LaneLineDetector(SideDetector):
@@ -206,4 +205,4 @@ def __init__(self, engine):
super(SideDetector, self).__init__(engine)
self.set_start_phase_offset(90)
self.origin.hide(CamMask.RgbCam | CamMask.Shadow | CamMask.Shadow | CamMask.DepthCam | CamMask.SemanticCam)
- self.mask = CollisionGroup.ContinuousLaneLine | CollisionGroup.BrokenLaneLine
+ self.mask = CollisionGroup.ContinuousLaneLine | CollisionGroup.BrokenLaneLine | CollisionGroup.Sidewalk
diff --git a/metadrive/component/sensors/lidar.py b/metadrive/component/sensors/lidar.py
index 6722eaa61..a68d49f90 100644
--- a/metadrive/component/sensors/lidar.py
+++ b/metadrive/component/sensors/lidar.py
@@ -15,7 +15,6 @@
class Lidar(DistanceDetector):
ANGLE_FACTOR = True
- Lidar_point_cloud_obs_dim = 240
DEFAULT_HEIGHT = 1.2
BROAD_PHASE_EXTRA_DIST = 0
diff --git a/metadrive/component/vehicle/base_vehicle.py b/metadrive/component/vehicle/base_vehicle.py
index f6890b89a..071a69d18 100644
--- a/metadrive/component/vehicle/base_vehicle.py
+++ b/metadrive/component/vehicle/base_vehicle.py
@@ -232,9 +232,26 @@ def before_step(self, action=None):
return step_info
def after_step(self):
- step_info = {}
if self.navigation and self.config["navigation_module"]:
self.navigation.update_localization(self)
+ self._state_check()
+ self.update_dist_to_left_right()
+ step_energy, episode_energy = self._update_energy_consumption()
+ self.out_of_route = self._out_of_route()
+ step_info = self._update_overtake_stat()
+ my_policy = self.engine.get_policy(self.name)
+ step_info.update(
+ {
+ "velocity": float(self.speed),
+ "steering": float(self.steering),
+ "acceleration": float(self.throttle_brake),
+ "step_energy": step_energy,
+ "episode_energy": episode_energy,
+ "policy": my_policy.name if my_policy is not None else my_policy
+ }
+ )
+
+ if self.navigation is not None and hasattr(self.navigation, "navi_arrow_dir"):
lanes_heading = self.navigation.navi_arrow_dir
lane_0_heading = lanes_heading[0]
lane_1_heading = lanes_heading[1]
@@ -258,22 +275,7 @@ def after_step(self):
"navigation_right": navigation_turn_right
}
)
- self._state_check()
- self.update_dist_to_left_right()
- step_energy, episode_energy = self._update_energy_consumption()
- self.out_of_route = self._out_of_route()
- step_info.update(self._update_overtake_stat())
- my_policy = self.engine.get_policy(self.name)
- step_info.update(
- {
- "velocity": float(self.speed),
- "steering": float(self.steering),
- "acceleration": float(self.throttle_brake),
- "step_energy": step_energy,
- "episode_energy": episode_energy,
- "policy": my_policy.name if my_policy is not None else my_policy
- }
- )
+
return step_info
def _out_of_route(self):
@@ -512,14 +514,15 @@ def _apply_throttle_brake(self, throttle_brake):
def update_dist_to_left_right(self):
self.dist_to_left_side, self.dist_to_right_side = self._dist_to_route_left_right()
- def _dist_to_route_left_right(self):
- # TODO
- if self.navigation is None or self.navigation.current_ref_lanes is None:
+ def _dist_to_route_left_right(self, navigation=None):
+ if navigation is None:
+ navigation = self.navigation
+ if navigation is None or navigation.current_ref_lanes is None:
return 0, 0
- current_reference_lane = self.navigation.current_ref_lanes[0]
+ current_reference_lane = navigation.current_ref_lanes[0]
_, lateral_to_reference = current_reference_lane.local_coordinates(self.position)
- lateral_to_left = lateral_to_reference + self.navigation.get_current_lane_width() / 2
- lateral_to_right = self.navigation.get_current_lateral_range(self.position, self.engine) - lateral_to_left
+ lateral_to_left = lateral_to_reference + navigation.get_current_lane_width() / 2
+ lateral_to_right = navigation.get_current_lateral_range(self.position, self.engine) - lateral_to_left
return lateral_to_left, lateral_to_right
# @property
diff --git a/metadrive/engine/interface.py b/metadrive/engine/interface.py
index 3961f0b8d..9c8842786 100644
--- a/metadrive/engine/interface.py
+++ b/metadrive/engine/interface.py
@@ -188,6 +188,8 @@ def destroy(self):
self.left_panel.destroy()
def _update_navi_arrow(self, lanes_heading):
+ if not self.engine.global_config["vehicle_config"]["show_navigation_arrow"]:
+ return
lane_0_heading = lanes_heading[0]
lane_1_heading = lanes_heading[1]
if abs(lane_0_heading - lane_1_heading) < 0.01:
diff --git a/metadrive/envs/base_env.py b/metadrive/envs/base_env.py
index 8846238d9..4102fe914 100644
--- a/metadrive/envs/base_env.py
+++ b/metadrive/envs/base_env.py
@@ -120,6 +120,8 @@
show_line_to_dest=False,
# Whether to draw a line from current vehicle position to the next navigation point
show_line_to_navi_mark=False,
+ # Whether to draw left / right arrow in the interface to denote the navigation direction
+ show_navigation_arrow=True,
# If set to True, the vehicle will be in color green in top-down renderer or MARL setting
use_special_color=False,
# Clear wheel friction, so it can not move by setting steering and throttle/brake. Used for ReplayPolicy
diff --git a/metadrive/envs/metadrive_env.py b/metadrive/envs/metadrive_env.py
index 29c445306..fea4c2ec6 100644
--- a/metadrive/envs/metadrive_env.py
+++ b/metadrive/envs/metadrive_env.py
@@ -72,6 +72,7 @@
out_of_road_penalty=5.0,
crash_vehicle_penalty=5.0,
crash_object_penalty=5.0,
+ crash_sidewalk_penalty=0.0,
driving_reward=1.0,
speed_reward=0.1,
use_lateral_reward=False,
@@ -83,6 +84,7 @@
# ===== Termination Scheme =====
out_of_route_done=False,
+ out_of_road_done=True,
on_continuous_line_done=True,
on_broken_line_done=False,
crash_vehicle_done=True,
@@ -160,7 +162,7 @@ def done_function(self, vehicle_id: str):
"Episode ended! Scenario Index: {} Reason: arrive_dest.".format(self.current_seed),
extra={"log_once": True}
)
- if done_info[TerminationState.OUT_OF_ROAD]:
+ if done_info[TerminationState.OUT_OF_ROAD] and self.config["out_of_road_done"]:
done = True
self.logger.info(
"Episode ended! Scenario Index: {} Reason: out_of_road.".format(self.current_seed),
@@ -280,7 +282,8 @@ def reward_function(self, vehicle_id: str):
reward = -self.config["crash_vehicle_penalty"]
elif vehicle.crash_object:
reward = -self.config["crash_object_penalty"]
-
+ elif vehicle.crash_sidewalk:
+ reward = -self.config["crash_sidewalk_penalty"]
step_info["route_completion"] = vehicle.navigation.route_completion
return reward, step_info
diff --git a/metadrive/envs/multigoal_intersection.py b/metadrive/envs/multigoal_intersection.py
new file mode 100644
index 000000000..5ca0c655e
--- /dev/null
+++ b/metadrive/envs/multigoal_intersection.py
@@ -0,0 +1,616 @@
+"""
+This file provides a multi-goal environment based on the intersection environment. The environment fully support
+conventional MetaDrive PG maps, where there is a special config['use_pg_map'] to enable the PG maps and all config are
+the same as MetaDriveEnv.
+If config['use_pg_map'] is False, the environment will use an intersection map and the goals information for all
+possible destinations will be provided.
+"""
+from collections import defaultdict
+
+import gymnasium as gym
+import numpy as np
+import seaborn as sns
+
+from metadrive.component.navigation_module.node_network_navigation import NodeNetworkNavigation
+from metadrive.component.pg_space import ParameterSpace, Parameter, DiscreteSpace, BoxSpace
+from metadrive.component.pgblock.first_block import FirstPGBlock
+from metadrive.component.pgblock.intersection import InterSectionWithUTurn
+from metadrive.component.road_network import Road
+from metadrive.constants import DEFAULT_AGENT
+from metadrive.engine.logger import get_logger
+from metadrive.envs.metadrive_env import MetaDriveEnv
+from metadrive.manager.base_manager import BaseManager
+from metadrive.obs.state_obs import BaseObservation, StateObservation
+from metadrive.utils.math import clip, norm
+
+logger = get_logger()
+
+EGO_STATE_DIM = 5
+NAVI_DIM = 10
+GOAL_DEPENDENT_STATE_DIM = 3
+
+
+class CustomizedObservation(BaseObservation):
+ def __init__(self, config):
+ self.state_obs = StateObservation(config)
+ super(CustomizedObservation, self).__init__(config)
+ self.latest_observation = {}
+
+ self.lane_detect_dim = self.config['vehicle_config']['lane_line_detector']['num_lasers']
+ self.side_detect_dim = self.config['vehicle_config']['side_detector']['num_lasers']
+ self.vehicle_detect_dim = self.config['vehicle_config']['lidar']['num_lasers']
+
+ @property
+ def observation_space(self):
+ shape = (
+ EGO_STATE_DIM + self.side_detect_dim + self.lane_detect_dim + self.vehicle_detect_dim + NAVI_DIM +
+ GOAL_DEPENDENT_STATE_DIM,
+ )
+ return gym.spaces.Box(-1.0, 1.0, shape=shape, dtype=np.float32)
+
+ def observe(self, vehicle, navigation=None):
+ ego = self.state_observe(vehicle)
+ assert ego.shape[0] == EGO_STATE_DIM
+
+ obs = [ego]
+
+ if vehicle.config["side_detector"]["num_lasers"] > 0:
+ side = self.side_detector_observe(vehicle)
+ assert side.shape[0] == self.side_detect_dim
+ obs.append(side)
+ self.latest_observation["side_detect"] = side
+
+ if vehicle.config["lane_line_detector"]["num_lasers"] > 0:
+ lane = self.lane_line_detector_observe(vehicle)
+ assert lane.shape[0] == self.lane_detect_dim
+ obs.append(lane)
+ self.latest_observation["lane_detect"] = lane
+
+ if vehicle.config["lidar"]["num_lasers"] > 0:
+ veh = self.vehicle_detector_observe(vehicle)
+ assert veh.shape[0] == self.vehicle_detect_dim
+ obs.append(veh)
+ self.latest_observation["vehicle_detect"] = veh
+ if navigation is None:
+ navigation = vehicle.navigation
+ navi = navigation.get_navi_info()
+ assert len(navi) == NAVI_DIM
+ obs.append(navi)
+
+ # Goal-dependent infos
+ goal_dependent_info = []
+ lateral_to_left, lateral_to_right = vehicle._dist_to_route_left_right(navigation=navigation)
+ if self.engine.current_map:
+ total_width = float((self.engine.current_map.MAX_LANE_NUM + 1) * self.engine.current_map.MAX_LANE_WIDTH)
+ else:
+ total_width = 100
+ lateral_to_left /= total_width
+ lateral_to_right /= total_width
+ goal_dependent_info += [clip(lateral_to_left, 0.0, 1.0), clip(lateral_to_right, 0.0, 1.0)]
+ current_reference_lane = navigation.current_ref_lanes[-1]
+ goal_dependent_info += [
+ # The angular difference between vehicle's heading and the lane heading at this location.
+ vehicle.heading_diff(current_reference_lane),
+ ]
+ goal_dependent_info = np.asarray(goal_dependent_info)
+ assert goal_dependent_info.shape[0] == GOAL_DEPENDENT_STATE_DIM
+ obs.append(goal_dependent_info)
+
+ obs = np.concatenate(obs)
+
+ self.latest_observation["state"] = ego
+ self.latest_observation["raw_navi"] = navi
+
+ return obs
+
+ def state_observe(self, vehicle):
+ # update out of road
+ info = np.zeros([
+ EGO_STATE_DIM,
+ ])
+
+ # The velocity of target vehicle
+ info[0] = clip((vehicle.speed_km_h + 1) / (vehicle.max_speed_km_h + 1), 0.0, 1.0)
+
+ # Current steering
+ info[1] = clip((vehicle.steering / vehicle.MAX_STEERING + 1) / 2, 0.0, 1.0)
+
+ # The normalized actions at last steps
+ info[2] = clip((vehicle.last_current_action[1][0] + 1) / 2, 0.0, 1.0)
+ info[3] = clip((vehicle.last_current_action[1][1] + 1) / 2, 0.0, 1.0)
+
+ # Current angular acceleration (yaw rate)
+ heading_dir_last = vehicle.last_heading_dir
+ heading_dir_now = vehicle.heading
+ cos_beta = heading_dir_now.dot(heading_dir_last) / (norm(*heading_dir_now) * norm(*heading_dir_last))
+ beta_diff = np.arccos(clip(cos_beta, 0.0, 1.0))
+ yaw_rate = beta_diff / 0.1
+ info[4] = clip(yaw_rate, 0.0, 1.0)
+
+ return info
+
+ def side_detector_observe(self, vehicle):
+ return np.asarray(
+ self.engine.get_sensor("side_detector").perceive(
+ vehicle,
+ num_lasers=vehicle.config["side_detector"]["num_lasers"],
+ distance=vehicle.config["side_detector"]["distance"],
+ physics_world=vehicle.engine.physics_world.static_world,
+ show=vehicle.config["show_side_detector"],
+ ).cloud_points
+ )
+
+ def lane_line_detector_observe(self, vehicle):
+ return np.asarray(
+ self.engine.get_sensor("lane_line_detector").perceive(
+ vehicle,
+ vehicle.engine.physics_world.static_world,
+ num_lasers=vehicle.config["lane_line_detector"]["num_lasers"],
+ distance=vehicle.config["lane_line_detector"]["distance"],
+ show=vehicle.config["show_lane_line_detector"],
+ ).cloud_points
+ )
+
+ def vehicle_detector_observe(self, vehicle):
+ cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive(
+ vehicle,
+ physics_world=self.engine.physics_world.dynamic_world,
+ num_lasers=vehicle.config["lidar"]["num_lasers"],
+ distance=vehicle.config["lidar"]["distance"],
+ show=vehicle.config["show_lidar"],
+ )
+ return np.asarray(cloud_points)
+
+ def destroy(self):
+ """
+ Clear allocated memory
+ """
+ self.state_obs.destroy()
+ super(CustomizedObservation, self).destroy()
+ self.cloud_points = None
+ self.detected_objects = None
+
+
+class CustomizedIntersection(InterSectionWithUTurn):
+ PARAMETER_SPACE = ParameterSpace(
+ {
+ Parameter.radius: BoxSpace(min=9, max=20.0),
+ Parameter.change_lane_num: DiscreteSpace(min=0, max=2),
+ Parameter.decrease_increase: DiscreteSpace(min=0, max=0)
+ }
+ )
+
+
+class MultiGoalIntersectionNavigationManager(BaseManager):
+ """
+ This manager is responsible for managing multiple navigation modules, each of which is responsible for guiding the
+ agent to a specific goal.
+ """
+ GOALS = {
+ "u_turn": (-Road(FirstPGBlock.NODE_2, FirstPGBlock.NODE_3)).end_node,
+ "right_turn": Road(
+ CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=0),
+ CustomizedIntersection.node(block_idx=1, part_idx=0, road_idx=1)
+ ).end_node,
+ "go_straight": Road(
+ CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=0),
+ CustomizedIntersection.node(block_idx=1, part_idx=1, road_idx=1)
+ ).end_node,
+ "left_turn": Road(
+ CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=0),
+ CustomizedIntersection.node(block_idx=1, part_idx=2, road_idx=1)
+ ).end_node,
+ }
+
+ def __init__(self):
+ super().__init__()
+ config = self.engine.global_config
+ vehicle_config = config["vehicle_config"]
+ self.navigations = {}
+ navi = NodeNetworkNavigation
+ colors = sns.color_palette("colorblind")
+ for c, (dest_name, road) in enumerate(self.GOALS.items()):
+ self.navigations[dest_name] = navi(
+ # self.engine,
+ show_navi_mark=vehicle_config["show_navi_mark"],
+ show_dest_mark=vehicle_config["show_dest_mark"],
+ show_line_to_dest=vehicle_config["show_line_to_dest"],
+ panda_color=colors[c], # color for navigation marker
+ name=dest_name,
+ vehicle_config=vehicle_config
+ )
+
+ @property
+ def agent(self):
+ return self.engine.agents[DEFAULT_AGENT]
+
+ @property
+ def goals(self):
+ return self.GOALS
+
+ def after_reset(self):
+ """Reset all navigation modules."""
+ # print("[DEBUG]: after_reset in MultiGoalIntersectionNavigationManager")
+ for name, navi in self.navigations.items():
+ navi.reset(self.agent, dest=self.goals[name])
+ navi.update_localization(self.agent)
+
+ def after_step(self):
+ """Update all navigation modules."""
+ # print("[DEBUG]: after_step in MultiGoalIntersectionNavigationManager")
+ for name, navi in self.navigations.items():
+ navi.update_localization(self.agent)
+ # print("Navigation {} next checkpoint: {}".format(name, navi.get_checkpoints()))
+
+ def get_navigation(self, goal_name):
+ """Return the navigation module for the given goal."""
+ assert goal_name in self.goals, "Invalid goal name!"
+ return self.navigations[goal_name]
+
+
+class MultiGoalIntersectionEnv(MetaDriveEnv):
+ """
+ This environment is an intersection with multiple goals. We provide the reward function, observation, termination
+ conditions for each goal in the info dict returned by env.reset and env.step, with prefix "goals/{goal_name}/".
+ """
+ @classmethod
+ def default_config(cls):
+ config = MetaDriveEnv.default_config()
+ # config.update(VaryingDynamicsConfig)
+ config.update(
+ {
+ "use_multigoal_intersection": True,
+
+ # Set the map to an Intersection
+ "start_seed": 0,
+
+ # Even though the map will not change, the traffic flow will change.
+ "num_scenarios": 1000,
+
+ # Remove all traffic vehicles for now.
+ # "traffic_density": 0.2,
+
+ # If the vehicle does not reach the default destination, it will receive a penalty.
+ "wrong_way_penalty": 10.0,
+ # "crash_sidewalk_penalty": 10.0,
+ # "crash_vehicle_penalty": 10.0,
+ # "crash_object_penalty": 10.0,
+ # "out_of_road_penalty": 10.0,
+ "out_of_route_penalty": 0.0,
+ # "success_reward": 10.0,
+ # "driving_reward": 1.0,
+ # "on_continuous_line_done": True,
+ # "out_of_road_done": True,
+ "vehicle_config": {
+
+ # Remove navigation arrows in the window as we are in multi-goal environment.
+ "show_navigation_arrow": False,
+
+ # Turn off vehicle's own navigation module.
+ "side_detector": dict(num_lasers=120, distance=50), # laser num, distance
+ "lidar": dict(num_lasers=120, distance=50),
+
+ # To avoid goal-dependent lane detection, we use Lidar to detect distance to nearby lane lines.
+ # Otherwise, we will ask the navigation module to provide current lane and extract the lateral
+ # distance directly on this lane.
+ "lane_line_detector": dict(num_lasers=0, distance=20)
+ }
+ }
+ )
+ return config
+
+ def _post_process_config(self, config):
+ config = super()._post_process_config(config)
+ if config["use_multigoal_intersection"]:
+ config['map'] = None
+ config['map_config'] = dict(
+ type="block_sequence", config=[
+ CustomizedIntersection,
+ ], lane_num=2, lane_width=3.5
+ )
+ return config
+
+ # def _get_agent_manager(self):
+ # return VaryingDynamicsAgentManager(init_observations=self._get_observations())
+
+ def get_single_observation(self):
+ return CustomizedObservation(self.config)
+
+ # else:
+ # return super().get_single_observation()
+ # img_obs = self.config["image_observation"]
+ # o = ImageStateObservation(self.config) if img_obs else LidarStateObservation(self.config)
+
+ def setup_engine(self):
+ super().setup_engine()
+
+ # Introducing a new navigation manager
+ if self.config["use_multigoal_intersection"]:
+ self.engine.register_manager("goal_manager", MultiGoalIntersectionNavigationManager())
+
+ def _get_step_return(self, actions, engine_info):
+ """Add goal-dependent observation to the info dict."""
+ o, r, tm, tc, i = super(MultiGoalIntersectionEnv, self)._get_step_return(actions, engine_info)
+
+ if self.config["use_multigoal_intersection"]:
+ for goal_name in self.engine.goal_manager.goals.keys():
+ navi = self.engine.goal_manager.get_navigation(goal_name)
+ goal_obs = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT], navi)
+ i["obs/goals/{}".format(goal_name)] = goal_obs
+ assert r == i["reward/default_reward"]
+ assert i["route_completion"] == i["route_completion/goals/default"]
+
+ else:
+ i["obs/goals/default"] = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT])
+ return o, r, tm, tc, i
+
+ def _get_reset_return(self, reset_info):
+ """Add goal-dependent observation to the info dict."""
+ o, i = super(MultiGoalIntersectionEnv, self)._get_reset_return(reset_info)
+
+ if self.config["use_multigoal_intersection"]:
+ for goal_name in self.engine.goal_manager.goals.keys():
+ navi = self.engine.goal_manager.get_navigation(goal_name)
+ goal_obs = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT], navi)
+ i["obs/goals/{}".format(goal_name)] = goal_obs
+
+ else:
+ i["obs/goals/default"] = self.observations["default_agent"].observe(self.agents[DEFAULT_AGENT])
+
+ return o, i
+
+ def _reward_per_navigation(self, vehicle, navi, goal_name):
+ """Compute the reward for the given goal. goal_name='default' means we use the vehicle's own navigation."""
+ reward = 0.0
+
+ # Get goal-dependent information
+ if navi.current_lane in navi.current_ref_lanes:
+ current_lane = navi.current_lane
+ positive_road = 1
+ else:
+ current_lane = navi.current_ref_lanes[0]
+ current_road = navi.current_road
+ positive_road = 1 if not current_road.is_negative_road() else -1
+ long_last, _ = current_lane.local_coordinates(vehicle.last_position)
+ long_now, lateral_now = current_lane.local_coordinates(vehicle.position)
+
+ # Reward for moving forward in current lane
+ reward += self.config["driving_reward"] * (long_now - long_last) * positive_road
+
+ left, right = vehicle._dist_to_route_left_right(navigation=navi)
+ out_of_route = (right < 0) or (left < 0)
+
+ # Reward for speed, sign determined by whether in the correct lanes (instead of driving in the wrong
+ # direction).
+ reward += self.config["speed_reward"] * (vehicle.speed_km_h / vehicle.max_speed_km_h) * positive_road
+ if self._is_arrive_destination(vehicle):
+ if self._is_arrive_destination(vehicle, goal_name):
+ reward += self.config["success_reward"]
+ else:
+ # if goal_name == "default":
+ # print("WRONG WAY")
+ reward = -self.config["wrong_way_penalty"]
+ else:
+ if self._is_out_of_road(vehicle):
+ reward = -self.config["out_of_road_penalty"]
+ elif vehicle.crash_vehicle:
+ reward = -self.config["crash_vehicle_penalty"]
+ elif vehicle.crash_object:
+ reward = -self.config["crash_object_penalty"]
+ elif vehicle.crash_sidewalk:
+ reward = -self.config["crash_sidewalk_penalty"]
+ elif out_of_route:
+ # if goal_name == "default":
+ # print("OUT OF ROUTE")
+ reward = -self.config["out_of_route_penalty"]
+
+ return reward, navi.route_completion
+
+ def reward_function(self, vehicle_id: str):
+ """
+ Compared to the original reward_function, we add goal-dependent reward to info dict.
+ """
+ vehicle = self.agents[vehicle_id]
+ step_info = dict()
+
+ # Compute goal-dependent reward and saved to step_info
+ if self.config["use_multigoal_intersection"]:
+ for goal_name in self.engine.goal_manager.goals.keys():
+ navi = self.engine.goal_manager.get_navigation(goal_name)
+ prefix = goal_name
+ reward, route_completion = self._reward_per_navigation(vehicle, navi, goal_name)
+ step_info[f"reward/goals/{prefix}"] = reward
+ step_info[f"route_completion/goals/{prefix}"] = route_completion
+
+ else:
+ navi = vehicle.navigation
+ goal_name = "default"
+ reward, route_completion = self._reward_per_navigation(vehicle, navi, goal_name)
+ step_info[f"reward/goals/{goal_name}"] = reward
+ step_info[f"route_completion/goals/{goal_name}"] = route_completion
+
+ default_reward, default_rc = self._reward_per_navigation(vehicle, vehicle.navigation, "default")
+ step_info[f"reward/goals/default"] = default_reward
+ step_info[f"route_completion/goals/default"] = default_rc
+ step_info[f"reward/default_reward"] = default_reward
+ step_info[f"route_completion"] = vehicle.navigation.route_completion
+
+ return default_reward, step_info
+
+ def _is_arrive_destination(self, vehicle, goal_name=None):
+ """
+ Compared to the original function, here we look up the navigation from goal_manager.
+
+ Args:
+ vehicle: The BaseVehicle instance.
+ goal_name: The name of the goal. If None, return True if any goal is arrived.
+
+ Returns:
+ flag: Whether this vehicle arrives its destination.
+ """
+
+ if self.config["use_multigoal_intersection"]:
+ if goal_name is None:
+ ret = False
+ for name in self.engine.goal_manager.goals.keys():
+ ret = ret or self._is_arrive_destination(vehicle, name)
+ return ret
+
+ if goal_name == "default":
+ navi = self.vehicle.navigation
+ else:
+ navi = self.engine.goal_manager.get_navigation(goal_name)
+
+ else:
+ navi = vehicle.navigation
+
+ long, lat = navi.final_lane.local_coordinates(vehicle.position)
+ flag = (navi.final_lane.length - 5 < long < navi.final_lane.length + 5) and (
+ navi.get_current_lane_width() / 2 >= lat >=
+ (0.5 - navi.get_current_lane_num()) * navi.get_current_lane_width()
+ )
+ return flag
+
+ def done_function(self, vehicle_id: str):
+ """
+ Compared to MetaDriveEnv's done_function, we add more stats here to record which goal is arrived.
+ """
+ done, done_info = super(MultiGoalIntersectionEnv, self).done_function(vehicle_id)
+ vehicle = self.agents[vehicle_id]
+
+ if self.config["use_multigoal_intersection"]:
+ for goal_name in self.engine.goal_manager.goals.keys():
+ done_info[f"arrive_dest/goals/{goal_name}"] = self._is_arrive_destination(vehicle, goal_name)
+
+ else:
+ done_info[f"arrive_dest/goals/default"] = done
+
+ done_info["arrive_dest/goals/default"] = self._is_arrive_destination(vehicle, "default")
+
+ return done, done_info
+
+
+if __name__ == "__main__":
+ config = dict(
+ use_render=True,
+ manual_control=True,
+ vehicle_config=dict(
+ show_navi_mark=True,
+ show_line_to_navi_mark=True,
+ show_lidar=False,
+ show_side_detector=True,
+ show_lane_line_detector=True,
+ ),
+
+ # ********************************************
+ use_multigoal_intersection=False
+ # ********************************************
+
+ # **{
+ # "map_config": dict(
+ # lane_num=5,
+ # lane_width=3.5
+ # ),
+ # }
+ )
+ env = MultiGoalIntersectionEnv(config)
+ episode_rewards = defaultdict(float)
+ try:
+ o, info = env.reset()
+
+ # default_ckpt = env.vehicle.navigation.checkpoints[-1]
+ # for goal, navi in env.engine.goal_manager.navigations.items():
+ # if navi.checkpoints[-1] == default_ckpt:
+ # break
+ # assert np.all(o == info["obs/goals/{}".format(goal)])
+
+ goal = "default"
+
+ print('=======================')
+ print("Full observation shape:\n\t", o.shape)
+ print("Goal-agnostic observation shape:\n\t", {k: v.shape for k, v in info.items() if k.startswith("obs/ego")})
+ print("Observation shape for each goals: ")
+ for k in sorted(info.keys()):
+ if k.startswith("obs/goals/"):
+ print(f"\t{k}: {info[k].shape}")
+ print('=======================')
+
+ obs_recorder = defaultdict(list)
+
+ s = 0
+ for i in range(1, 1000000000):
+ o, r, tm, tc, info = env.step([0, 1])
+
+ assert np.all(o == info["obs/goals/{}".format(goal)])
+ assert np.all(r == info["reward/goals/{}".format(goal)])
+
+ done = tm or tc
+ s += 1
+ # env.render()
+ env.render(mode="topdown")
+
+ for k in info.keys():
+ if k.startswith("obs/goals"):
+ obs_recorder[k].append(info[k])
+
+ for k, v in info.items():
+ if k.startswith("reward/goals"):
+ episode_rewards[k] += v
+
+ if s % 20 == 0:
+ print('\n===== timestep {} ====='.format(s))
+ print('goal: ', goal)
+ print('route completion:')
+ for k in sorted(info.keys()):
+ if k.startswith("route_completion/goals/"):
+ print(f"\t{k}: {info[k]:.2f}")
+
+ print('\nreward:')
+ for k in sorted(info.keys()):
+ if k.startswith("reward/"):
+ print(f"\t{k}: {info[k]:.2f}")
+ print('=======================')
+
+ if done:
+ print('\n===== timestep {} ====='.format(s))
+ print("EPISODE DONE\n")
+ print('route completion:')
+ for k in sorted(info.keys()):
+ # kk = k.replace("/route_completion", "")
+ if k.startswith("route_completion/goals/"):
+ print(f"\t{k}: {info[k]:.2f}")
+
+ print('\narrive destination (success):')
+ for k in sorted(info.keys()):
+ # kk = k.replace("/arrive_dest", "")
+ if k.startswith("arrive_dest/goals/"):
+ print(f"\t{k}: {info[k]:.2f}")
+
+ print('\nepisode_rewards:')
+ for k in sorted(episode_rewards.keys()):
+ # kk = k.replace("/step_reward", "")
+ print(f"\t{k}: {episode_rewards[k]:.2f}")
+ episode_rewards.clear()
+ print('=======================')
+
+ if done:
+
+ import numpy as np
+
+ # for t in range(i):
+ # # avg = [v[t] for k, v in obs_recorder.items()]
+ # v = np.stack([v[0] for k, v in obs_recorder.items()])
+
+ print('\n\n\n')
+ o, info = env.reset()
+
+ default_ckpt = env.vehicle.navigation.checkpoints[-1]
+ # for goal, navi in env.engine.goal_manager.navigations.items():
+ # if navi.checkpoints[-1] == default_ckpt:
+ # break
+ #
+ # assert np.all(o == info["obs/goals/{}".format(goal)])
+
+ s = 0
+ finally:
+ env.close()
diff --git a/metadrive/examples/train_generalization_experiment.py b/metadrive/examples/train_generalization_experiment.py
index f5e31c596..ce65b7040 100755
--- a/metadrive/examples/train_generalization_experiment.py
+++ b/metadrive/examples/train_generalization_experiment.py
@@ -3,7 +3,12 @@
in the same test set using rllib.
We verified this script with ray==2.2.0. Please report to use if you find newer version of ray is not compatible with
-this script.
+this script. Installation guide:
+
+ pip install ray[rllib]==2.2.0
+ pip install tensorflow_probability==0.24.0
+ pip install torch
+
"""
import argparse
import copy
diff --git a/metadrive/manager/traffic_manager.py b/metadrive/manager/traffic_manager.py
index e6bd4f1c4..20bc49b10 100644
--- a/metadrive/manager/traffic_manager.py
+++ b/metadrive/manager/traffic_manager.py
@@ -80,12 +80,12 @@ def before_step(self):
engine = self.engine
if self.mode != TrafficMode.Respawn:
for v in engine.agent_manager.active_agents.values():
- ego_lane_idx = v.lane_index[:-1]
- ego_road = Road(ego_lane_idx[0], ego_lane_idx[1])
- if len(self.block_triggered_vehicles) > 0 and \
- ego_road == self.block_triggered_vehicles[-1].trigger_road:
- block_vehicles = self.block_triggered_vehicles.pop()
- self._traffic_vehicles += list(self.get_objects(block_vehicles.vehicles).values())
+ if len(self.block_triggered_vehicles) > 0:
+ ego_lane_idx = v.lane_index[:-1]
+ ego_road = Road(ego_lane_idx[0], ego_lane_idx[1])
+ if ego_road == self.block_triggered_vehicles[-1].trigger_road:
+ block_vehicles = self.block_triggered_vehicles.pop()
+ self._traffic_vehicles += list(self.get_objects(block_vehicles.vehicles).values())
for v in self._traffic_vehicles:
p = self.engine.get_policy(v.name)
v.before_step(p.act())
@@ -266,7 +266,8 @@ def _create_vehicles_once(self, map: BaseMap, traffic_density: float) -> None:
vehicle_type = self.random_vehicle_type()
v_config.update(self.engine.global_config["traffic_vehicle_config"])
random_v = self.spawn_object(vehicle_type, vehicle_config=v_config)
- self.add_policy(random_v.id, IDMPolicy, random_v, self.generate_seed())
+ seed = self.generate_seed()
+ self.add_policy(random_v.id, IDMPolicy, random_v, seed)
vehicles_on_block.append(random_v.name)
trigger_road = block.pre_block_socket.positive_road
@@ -310,8 +311,7 @@ def destroy(self) -> None:
# current map
# traffic vehicle list
- self._traffic_vehicles = None
- self.block_triggered_vehicles = None
+ self.block_triggered_vehicles = []
# traffic property
self.mode = None
diff --git a/metadrive/obs/state_obs.py b/metadrive/obs/state_obs.py
index 69463a8e9..0c71e4e64 100644
--- a/metadrive/obs/state_obs.py
+++ b/metadrive/obs/state_obs.py
@@ -15,7 +15,7 @@ def __init__(self, config):
if config["vehicle_config"]["navigation_module"]:
navi_dim = config["vehicle_config"]["navigation_module"].get_navigation_info_dim()
else:
- navi_dim = NodeNetworkNavigation.get_navigation_info_dim()
+ navi_dim = 0
self.navi_dim = navi_dim
super(StateObservation, self).__init__(config)
@@ -56,9 +56,12 @@ def observe(self, vehicle):
:param vehicle: BaseVehicle
:return: Vehicle State + Navigation information
"""
- navi_info = vehicle.navigation.get_navi_info()
ego_state = self.vehicle_state(vehicle)
- ret = np.concatenate([ego_state, navi_info])
+ if self.navi_dim > 0:
+ navi_info = vehicle.navigation.get_navi_info()
+ ret = np.concatenate([ego_state, navi_info])
+ else:
+ ret = np.asarray(ego_state)
return ret.astype(np.float32)
def vehicle_state(self, vehicle):
@@ -89,8 +92,8 @@ def vehicle_state(self, vehicle):
# If the side detector is turn off, then add the distance to left and right road borders as state.
lateral_to_left, lateral_to_right, = vehicle.dist_to_left_side, vehicle.dist_to_right_side
- if vehicle.navigation.map:
- total_width = float((vehicle.navigation.map.MAX_LANE_NUM + 1) * vehicle.navigation.map.MAX_LANE_WIDTH)
+ if self.engine.current_map:
+ total_width = float((self.engine.current_map.MAX_LANE_NUM + 1) * self.engine.current_map.MAX_LANE_WIDTH)
else:
total_width = 100
lateral_to_left /= total_width
diff --git a/metadrive/tests/test_component/test_lane_line_detector.py b/metadrive/tests/test_component/test_lane_line_detector.py
index 09d0211de..3fa091eea 100644
--- a/metadrive/tests/test_component/test_lane_line_detector.py
+++ b/metadrive/tests/test_component/test_lane_line_detector.py
@@ -264,132 +264,28 @@
]
pg_gt_3 = [
- 0.17000000178813934,
- 0.18000000715255737,
- 0.18000000715255737,
- 0.1899999976158142,
- 0.20000000298023224,
- 0.2199999988079071,
- 0.23999999463558197,
- 0.27000001072883606,
- 0.33000001311302185,
- 0.4099999964237213,
- 0.5600000023841858,
- 1.0,
- 1.0,
- 0.550000011920929,
- 0.18000000715255737,
- 0.10999999940395355,
- 0.07999999821186066,
- 0.05999999865889549,
- 0.05000000074505806,
- 0.05000000074505806,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.029999999329447746,
- 0.029999999329447746,
- 0.029999999329447746,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.05000000074505806,
- 0.05000000074505806,
- 0.05999999865889549,
- 0.07999999821186066,
- 0.10999999940395355,
- 0.18000000715255737,
- 0.47999998927116394,
- 0.7099999785423279,
- 0.8799999952316284,
- 1.0,
- 0.4099999964237213,
- 0.33000001311302185,
- 0.27000001072883606,
- 0.23999999463558197,
- 0.2199999988079071,
- 0.20000000298023224,
- 0.1899999976158142,
- 0.18000000715255737,
- 0.18000000715255737,
- 0.5,
- 0.009999999776482582,
- 0.5,
- 0.5,
- 0.5,
- 0.0,
- 0.17000000178813934,
- 0.029999999329447746,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.05000000074505806,
- 0.27000001072883606,
- 0.1899999976158142,
- 0.4099999964237213,
- 0.10999999940395355,
- 0.18000000715255737,
- 0.5600000023841858,
- 0.550000011920929,
- 0.18000000715255737,
- 0.10999999940395355,
- 0.07999999821186066,
- 0.05999999865889549,
- 0.05000000074505806,
- 0.05000000074505806,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.029999999329447746,
- 0.029999999329447746,
- 0.029999999329447746,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.03999999910593033,
- 0.05000000074505806,
- 0.05000000074505806,
- 0.05999999865889549,
- 0.07999999821186066,
- 0.10999999940395355,
- 0.18000000715255737,
- 0.47999998927116394,
- 0.7099999785423279,
- 0.7400000095367432,
- 0.7799999713897705,
- 0.07999999821186066,
- 0.05999999865889549,
- 0.05000000074505806,
- 0.23999999463558197,
- 0.12999999523162842,
- 0.11999999731779099,
- 0.10999999940395355,
- 0.18000000715255737,
- 0.18000000715255737,
- 0.30000001192092896,
- 0.4300000071525574,
- 0.0,
- 0.5,
- 0.5,
- 0.699999988079071,
- 0.4300000071525574,
- 0.0,
- 0.5,
- 0.5,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
- 1.0,
+ 0.17424996197223663, 0.17563384771347046, 0.179901123046875, 0.18740931153297424, 0.19884595274925232,
+ 0.21538487076759338, 0.23903638124465942, 0.273365318775177, 0.32519838213920593, 0.40924975275993347,
+ 0.5638853311538696, 1.0, 1.0, 0.5454577803611755, 0.18278196454048157, 0.11083516478538513, 0.08044067770242691,
+ 0.06391915678977966, 0.05373150855302811, 0.04698386788368225, 0.04233507066965103, 0.03907269984483719,
+ 0.03683317080140114, 0.03535793721675873, 0.034397244453430176, 0.03412747383117676, 0.034522198140621185,
+ 0.0353609174489975, 0.036836810410022736, 0.03908421844244003, 0.04233531653881073, 0.04698418080806732,
+ 0.05373169109225273, 0.06391977518796921, 0.08044077455997467, 0.11082261800765991, 0.18277062475681305,
+ 0.45523008704185486, 0.7089459300041199, 0.8811144232749939, 1.0, 0.40924930572509766, 0.325198233127594,
+ 0.2732689678668976, 0.23903624713420868, 0.21538473665714264, 0.19884586334228516, 0.1874106079339981,
+ 0.17990022897720337, 0.17563492059707642, 0.49999988079071045, 0.012532129883766174, 0.5, 0.5, 0.5,
+ 1.3244441561255371e-06, 0.17424996197223663, 0.03450844809412956, 0.035360947251319885, 0.03683680295944214,
+ 0.03908339887857437, 0.04233534261584282, 0.04698362201452255, 0.273365318775177, 0.19455832242965698,
+ 0.40924975275993347, 0.1108340248465538, 0.18278250098228455, 0.5611181855201721, 0.5454577803611755,
+ 0.18278196454048157, 0.11083516478538513, 0.08044067770242691, 0.06391915678977966, 0.05373150855302811,
+ 0.04698386788368225, 0.04233507066965103, 0.03907269984483719, 0.03683317080140114, 0.03535793721675873,
+ 0.034397244453430176, 0.03412747383117676, 0.034522198140621185, 0.0353609174489975, 0.036836810410022736,
+ 0.03908421844244003, 0.04233531653881073, 0.04698418080806732, 0.05373169109225273, 0.06391977518796921,
+ 0.08044077455997467, 0.11082261800765991, 0.18277062475681305, 0.45523008704185486, 0.7089459300041199,
+ 0.7421090602874756, 0.7847582697868347, 0.08044064044952393, 0.0639198049902916, 0.053731828927993774,
+ 0.23903624713420868, 0.1288066953420639, 0.11896516382694244, 0.11209990084171295, 0.17990022897720337,
+ 0.17563492059707642, 0.3000001609325409, 0.43000006675720215, 0.0, 0.5, 0.5, 0.699988842010498, 0.4299999475479126,
+ 0.0, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0
]
@@ -412,7 +308,7 @@ def test_pg_map(render=False):
)
try:
env.reset()
- env.vehicle.set_position([73, 12])
+ env.agent.set_position([73, 12])
for s in range(1, 5):
o, r, tm, tc, info = env.step([0, 0])
@@ -422,7 +318,7 @@ def test_pg_map(render=False):
print("]")
np.testing.assert_almost_equal(pg_gt_1, np.round(o, 2), decimal=2)
- env.vehicle.set_position([30, 3.5])
+ env.agent.set_position([30, 3.5])
for s in range(1, 5):
o, r, tm, tc, info = env.step([0, 0])
print("[")
@@ -431,13 +327,16 @@ def test_pg_map(render=False):
print("]")
np.testing.assert_almost_equal(np.array(pg_gt_2), o, decimal=2)
- env.vehicle.set_position([30, 10.5])
+ env.agent.set_position([30, 10.5])
for s in range(1, 5):
o, r, tm, tc, info = env.step([0, 0])
print("[")
- for _o in o:
- print("{},".format(round(_o, 2)))
+ for ind, _o in enumerate(o):
+ print("{:.2f}, GT: {:.2f}".format(_o, pg_gt_3[ind]))
print("]")
+
+ print(o.tolist())
+
np.testing.assert_almost_equal(np.array(pg_gt_3), o, decimal=2)
finally:
env.close()
@@ -755,7 +654,7 @@ def test_nuscenes(render=False):
)
try:
env.reset(seed=0)
- env.vehicle.set_position([-9.4, -27.2])
+ env.agent.set_position([-9.4, -27.2])
for s in range(1, 5):
o, r, tm, tc, info = env.step([0, 0])
@@ -766,7 +665,7 @@ def test_nuscenes(render=False):
np.testing.assert_almost_equal(nuscenes_gt_1, o, decimal=3)
env.reset(seed=1)
- env.vehicle.set_position([79.96, -6.2])
+ env.agent.set_position([79.96, -6.2])
for s in range(1, 5):
o, r, tm, tc, info = env.step([0, 0])
@@ -781,4 +680,4 @@ def test_nuscenes(render=False):
if __name__ == '__main__':
# test_nuscenes(True)
- test_pg_map(True)
+ test_pg_map(False)
diff --git a/metadrive/utils/registry.py b/metadrive/utils/registry.py
index 4411b2fce..b0f1dc87f 100644
--- a/metadrive/utils/registry.py
+++ b/metadrive/utils/registry.py
@@ -13,7 +13,7 @@ def _initialize_registry():
from metadrive.component.pgblock.parking_lot import ParkingLot
from metadrive.component.pgblock.ramp import InRampOnStraight, OutRampOnStraight
from metadrive.component.pgblock.roundabout import Roundabout
- from metadrive.component.pgblock.std_intersection import StdInterSection
+ from metadrive.component.pgblock.std_intersection import StdInterSection, StdInterSectionWithUTurn
from metadrive.component.pgblock.std_t_intersection import StdTInterSection
from metadrive.component.pgblock.straight import Straight
from metadrive.component.pgblock.tollgate import TollGate
@@ -21,7 +21,7 @@ def _initialize_registry():
_metadrive_class_list.extend(
[
Merge, Split, Curve, InFork, OutFork, ParkingLot, InRampOnStraight, OutRampOnStraight, Roundabout,
- StdInterSection, StdTInterSection, Straight, TollGate, Bidirection
+ StdInterSection, StdTInterSection, StdInterSectionWithUTurn, Straight, TollGate, Bidirection
]
)
diff --git a/setup.py b/setup.py
index da5e82984..affbbbbb1 100644
--- a/setup.py
+++ b/setup.py
@@ -61,6 +61,7 @@ def is_win():
"shapely",
"filelock",
"Pygments",
+ "mediapy"
]