diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml new file mode 100644 index 0000000..b0254dc --- /dev/null +++ b/.github/workflows/doc.yml @@ -0,0 +1,34 @@ +name: doc + +on: + push: + branches: [main, 'doc*'] + pull_request: + branches: [main, 'doc*'] + +jobs: + doc: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.7] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Generate + run: | + python -m pip install .[doc] + make -C ./docs html + mv docs/build/html public + rm -rf docs/build + - name: Deploy + uses: JamesIves/github-pages-deploy-action@3.7.1 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + BRANCH: gh-pages # The branch the action should deploy to. + FOLDER: public # The folder the action should deploy. + CLEAN: true # Automatically remove deleted files from the deploy branch \ No newline at end of file diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 232cf87..c4462f2 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -6,7 +6,7 @@ name: unit_test on: [push, pull_request] jobs: - test_test: + test_unitest: runs-on: ubuntu-latest if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: diff --git a/README.md b/README.md index c3947cb..fd49d35 100644 --- a/README.md +++ b/README.md @@ -21,59 +21,13 @@ DI-smartcross supports: ## Installation -DI-smartcross supports SUMO version >= 1.6.0. Here we show an easy guide of installation with SUMO 1.8.0 on Linux. +DI-smartcross supports SUMO version >= 1.6.0. You can refer to +[SUMO documentation](https://sumo.dlr.de/docs/Installing/index.html) or follow our installation guidance in +[documents](https://opendilab.github.io/DI-smartcross/installation.html). -### Install sumo - -1. install required libraries and dependencies - -```bash -sudo apt-get install cmake python g++ libxerces-c-dev libfox-1.6-dev libgdal-dev libproj-dev libgl2ps-dev swig -``` - -2. download and unzip the installation package - -```bash -tar xzf sumo-src-1.8.0.tar.gz -cd sumo-1.8.0 -pwd -``` - -3. compile sumo - -```bash -mkdir build/cmake-build -cd build/cmake-build -cmake ../.. -make -j $(nproc) -``` - -4. environment variables - -```bash -echo 'export PATH=$HOME/sumo-1.8.0/bin:$PATH -export SUMO_HOME=$HOME/sumo-1.8.0' | tee -a $HOME/.bashrc -source ~/.bashrc -``` - -5. check install - -```bash -sumo -``` -If success, the following message will be shown in the shell. - -``` -Eclipse SUMO sumo Version 1.8.0 - Build features: Linux-3.10.0-957.el7.x86_64 x86_64 GNU 5.3.1 Release Proj GUI SWIG GDAL GL2PS - Copyright (C) 2001-2020 German Aerospace Center (DLR) and others; https://sumo.dlr.de - License EPL-2.0: Eclipse Public License Version 2 - Use --help to get the list of options. -``` - -### Install DI-smartcross - -To install DI-smartcross, simply run `pip install` in the root folder of this repository. This will automatically insall [DI-engine](https://github.com/opendilab/DI-engine) as well. +Then, DI-smartcross is able to be installed from source code. +Simply run `pip install` in the root folder of this repository. +This will automatically insall [DI-engine](https://github.com/opendilab/DI-engine) as well. ```bash pip install -e . --user @@ -81,126 +35,29 @@ pip install -e . 
--user ## Quick Start -### Run training and evaluation - -DI-smartcross supports DQN, Off-policy PPO and Rainbow DQN RL methods with multi-discrete actions for each crossing. A set of default DI-engine configs is provided for each policy. You can check the document of DI-engine to get detail instructions of these configs. +DI-smartcross provides a simple entry for RL training and evaluation. DI-smartcross supports DQN, Off-policy PPO +and Rainbow DQN RL methods with multi-discrete actions for each crossing, as well as multi-agent RL policies +in which each crossing is handled by an individual agent. A set of default DI-engine configs is provided for +each policy. You can check the DI-engine documentation for detailed instructions on these configs. - train RL policies -``` -usage: sumo_train [-h] -d DING_CFG -e ENV_CFG [-s SEED] [--dynamic-flow] - [-cn COLLECT_ENV_NUM] [-en EVALUATE_ENV_NUM] - [--exp-name EXP_NAME] - -DI-smartcross training script - -optional arguments: - -h, --help show this help message and exit - -d DING_CFG, --ding-cfg DING_CFG - DI-engine configuration path - -e ENV_CFG, --env-cfg ENV_CFG - sumo environment configuration path - -s SEED, --seed SEED random seed for sumo - --dynamic-flow use dynamic route flow - -cn COLLECT_ENV_NUM, --collect-env-num COLLECT_ENV_NUM - collector sumo env num for training - -en EVALUATE_ENV_NUM, --evaluate-env-num EVALUATE_ENV_NUM - evaluator sumo env num for training - --exp-name EXP_NAME experiment name to save log and ckpt -``` - Example of running DQN in wj3 env with default config. ```bash -sumo_train -e smartcross/envs/sumo_arterial_wj3_default_config.yaml -d entry/config/sumo_wj3_dqn_default_config.py +sumo_train -e smartcross/envs/sumo_wj3_default_config.yaml -d entry/config/sumo_wj3_dqn_default_config.py ``` - evaluate existing policies -``` -usage: sumo_eval [-h] [-d DING_CFG] -e ENV_CFG [-s SEED] - [-p {random,fix,dqn,rainbow,ppo}] [--dynamic-flow] - [-n ENV_NUM] [--gui] [-c CKPT_PATH] - -DI-smartcross training script - -optional arguments: - -h, --help show this help message and exit - -d DING_CFG, --ding-cfg DING_CFG - DI-engine configuration path - -e ENV_CFG, --env-cfg ENV_CFG - sumo environment configuration path - -s SEED, --seed SEED random seed for sumo - -p {random,fix,dqn,rainbow,ppo}, --policy-type {random,fix,dqn,rainbow,ppo} - RL policy type - --dynamic-flow use dynamic route flow - -n ENV_NUM, --env-num ENV_NUM - sumo env num for evaluation - --gui open gui for visualize - -c CKPT_PATH, --ckpt-path CKPT_PATH - model ckpt path -``` - Example of running random policy in wj3 env. ```bash -sumo_eval -p random -e smartcross/envs/sumo_arterial_wj3_default_config.yaml +sumo_eval -p random -e smartcross/envs/sumo_wj3_default_config.yaml ``` -## Environments - -### sumo env configuration - -The configuration of sumo env is stored in a config *.yaml* file. You can take a look at the default config file to see how to modify env settings. - -``` python -import yaml -from easy_dict import EasyDict -from smartcross.env import SumoEnv - -with open('smartcross/envs/sumo_arterial_wj3_default_config.yaml') as f: - cfg = yaml.safe_load(f) -cfg = EasyDict(cfg) -env = SumoEnv(config=cfg.env) -``` - -The env configuration consists of basic definition and observation\action\reward settings. The basic definition includes the cumo config file, episode length and light duration. The obs\action\reward define the detail setting of each contains.
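For reference, the same command pattern covers the multi-agent entry as well, and a trained checkpoint can be evaluated with `sumo_eval`. The MAPPO config below ships in `entry/config`; the checkpoint path is only an illustrative example derived from the experiment name, not a file tracked in the repository.

```bash
# Train the multi-agent MAPPO policy on the wj3 road net (config provided in entry/config).
sumo_train -e smartcross/envs/sumo_wj3_default_config.yaml -d entry/config/sumo_wj3_mappo_default_config.py

# Evaluate a trained DQN policy with the SUMO GUI; the checkpoint path is illustrative only.
sumo_eval -p dqn -e smartcross/envs/sumo_wj3_default_config.yaml -c ./sumo_wj3_md_dqn/ckpt/ckpt_best.pth.tar --gui
```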
- -```yaml -env: - sumocfg_path: 'arterial_wj3/rl_wj.sumocfg' - max_episode_steps: 1500 - green_duration: 10 - yellow_duration: 3 - obs: - ... - action: - ... - reward: - ... -``` - -### Observation - -We provide several types of observations of a traffic cross. If `use_centrolized_obs` is set `True`, the observation of each cross will be concatenated into one vector. The contents of observation can me modified by setting `obs_type`. The following observation is supported now. - -- phase: One-hot phase vector of current cross signal -- lane_pos_vec: Lane occupancy in each grid position. The grid num can be set with `lane_grid_num` -- traffic_volumn: Traffic volumn of each lane. Vehicle num / lane length * volumn ratio -- queue_len: Vehicle waiting queue length of each lane. Waiting num / lane length * volumn ratio - -### Action - -Sumo environment supports changing cross signal to target phase. The action space is set to multi-discrete for each cross to reduce action num. - -### Reward - -Reward can be set with `reward_type`. Reward is calculated cross by cross. If `use_centrolized_obs` is set True, the reward of each cross will be summed up. - -- queue_len: Vehicle waiting queue num of each lane -- wait_time: Wait time increment of vehicles in each lane -- delay_time: Delay time of all vahicles in incomming and outgoing lanes -- pressure: Pressure of a cross +It is recommended to refer to the [documentation](https://opendilab.github.io/DI-smartcross/index.html) +for detailed information. ## Contributing diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..4f33ed6 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,3 @@ +make.bat +static/ +_build/ diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/figs/di-smartcross_logo.png b/docs/figs/di-smartcross_logo.png similarity index 100% rename from figs/di-smartcross_logo.png rename to docs/figs/di-smartcross_logo.png diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..ac177cf --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,89 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here.
+# +import os +import sys +sys.path.insert(0, os.path.abspath('../..')) +print(sys.path) + + +# -- Project information ----------------------------------------------------- + +project = 'DI-smartcross' +copyright = '2021, OpenDILab' +author = 'OpenDILab' + +# The full version, including alpha/beta/rc tags +version = '0.1.0' +release = '0.1.0' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.mathjax', + 'sphinx.ext.ifconfig', + 'sphinx.ext.viewcode', + 'sphinx.ext.githubpages', + #'myst_parser', +] + + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = ['.rst', '.md'] +# source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + + + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +#html_theme = 'alabaster' +html_theme = 'sphinx_rtd_theme' + +# Output file base name for HTML help builder. +htmlhelp_basename = 'dismartcrossdoc' + + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ['_static'] + + +# from recommonmark.parser import CommonMarkParser +# source_parsers = { +# '.md': CommonMarkParser, +# } +# source_suffix = ['.rst', '.md'] \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..e915703 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,59 @@ +.. DI-drive documentation master file, created by + sphinx-quickstart on Mon Jan 25 13:49:15 2021. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +DI-smartcross Documentation +############################## + +.. toctree:: + :maxdepth: 2 + :hidden: + :caption: First steps + + installation + quick_start + rl_environments + + +.. figure:: ../figs/di-smartcross_logo.png + :alt: DI-smartcross + :width: 500px + +Decision Intelligence Platform for Traffic Crossing Signal Control. + +Last updated on + +----- + +**DI-smartcross** is an open-source application platform under **OpenDILab**. +**DI-smartcross** uses Reinforcement Learning in precise control of traffic crossing signals in order +to optimize transportation time cost by coordinating vehicles' movements at crosses. +**DI-smartcross** applies training & evaluation for various RL policies using `DI-engine `_ +in provided road nets. +**DI-smartcross** supports `SUMO `_ and +`CityFlow `_ simulators to enable +traffic flow simulation with different granularity. + + +Main Features +================= + +- Design easy-to-use crossing signal control environments, with various State, Action, and Reward options. 
+ +- Build a variety of road networks of different scales, ideal or from the real world. + +- Adapt several Reinforcement Learning strategies using **DI-engine**, including discrete or continuous action space, multi-agent RL, etc. + + +Content +============== + +`Installation <installation.html>`_ +-------------------------------------- + +`Quick Start <quick_start.html>`_ +----------------------------- + +`RL Environments <rl_environments.html>`_ +---------------------------------------- diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 0000000..13a3e7a --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,110 @@ +Installation +################# + +.. toctree:: + :maxdepth: 2 + +SUMO installation +===================== + +**DI-smartcross** supports SUMO version >= 1.6.0. Here we show two easy guides +to SUMO installation on Linux and macOS. + +Install SUMO via apt-get or homebrew +-------------------------------------- + +On Debian or Ubuntu, SUMO can be directly installed using ``apt``: + +.. code:: bash + + sudo add-apt-repository ppa:sumo/stable + sudo apt-get update + sudo apt-get install sumo sumo-tools sumo-doc + +On macOS, SUMO can be installed using ``homebrew``. + +.. code:: bash + + brew update + brew tap dlr-ts/sumo + brew install sumo + brew install --cask sumo-gui + +After that, you need to set the ``SUMO_HOME`` environment variable pointing to the directory +of your SUMO installation. Just insert the following line at the end of ``.bashrc``: + +.. code:: bash + + export SUMO_HOME=/your/path/to/sumo + +Some trouble may arise when installing with the method above. In that case, it is recommended +to build and install SUMO from source code as follows. + +Install SUMO from source code +--------------------------------- + +Here we show step-by-step guidance of installation with SUMO 1.8.0 on Linux. + +1. install required libraries and dependencies + +.. code:: bash + + sudo apt-get install cmake python g++ libxerces-c-dev libfox-1.6-dev libgdal-dev libproj-dev libgl2ps-dev swig + +2. download and unzip the installation package + +.. code:: bash + + tar xzf sumo-src-1.8.0.tar.gz + cd sumo-1.8.0 + pwd + +3. compile SUMO + +.. code:: bash + + mkdir build/cmake-build + cd build/cmake-build + cmake ../.. + make -j $(nproc) + +4. environment variables + +.. code:: bash + + echo 'export PATH=$HOME/sumo-1.8.0/bin:$PATH + export SUMO_HOME=$HOME/sumo-1.8.0' | tee -a $HOME/.bashrc + source ~/.bashrc + +5. check install + +.. code:: bash + + sumo + +If successful, the following message will be shown in the shell. + +.. code:: + + Eclipse SUMO sumo Version 1.8.0 + Build features: Linux-3.10.0-957.el7.x86_64 x86_64 GNU 5.3.1 Release Proj GUI SWIG GDAL GL2PS + Copyright (C) 2001-2020 German Aerospace Center (DLR) and others; https://sumo.dlr.de + License EPL-2.0: Eclipse Public License Version 2 + Use --help to get the list of options. + +Install DI-smartcross +========================== + + +Simply run `pip install` in the root folder of this repository. This will automatically +install `DI-engine <https://github.com/opendilab/DI-engine>`_ as well. + +.. code:: bash + + pip install -e . --user + +You can check the installation by running the following command. + +.. code:: bash + + python -c 'import ding; import smartcross' diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst new file mode 100644 index 0000000..ffb72a1 --- /dev/null +++ b/docs/source/quick_start.rst @@ -0,0 +1,79 @@ +Quick Start +############### + +.. 
toctree:: + :maxdepth: 2 + +SUMO entries +================ + +**DI-smartcross** supports DQN, Off-policy PPO and Rainbow DQN RL methods +with multi-discrete actions for each crossing. A set of default **DI-engine** +configs is provided for each policy. You can check the DI-engine documentation +to get detailed instructions on these configs. + +train RL policies +-------------------- + +.. code:: + + usage: sumo_train [-h] -d DING_CFG -e ENV_CFG [-s SEED] [--dynamic-flow] + [-cn COLLECT_ENV_NUM] [-en EVALUATE_ENV_NUM] + [--exp-name EXP_NAME] + + DI-smartcross training script + + optional arguments: + -h, --help show this help message and exit + -d DING_CFG, --ding-cfg DING_CFG + DI-engine configuration path + -e ENV_CFG, --env-cfg ENV_CFG + sumo environment configuration path + -s SEED, --seed SEED random seed for sumo + --dynamic-flow use dynamic route flow + -cn COLLECT_ENV_NUM, --collect-env-num COLLECT_ENV_NUM + collector sumo env num for training + -en EVALUATE_ENV_NUM, --evaluate-env-num EVALUATE_ENV_NUM + evaluator sumo env num for training + --exp-name EXP_NAME experiment name to save log and ckpt + + +Example of running DQN in wj3 env with default config. + +.. code:: bash + + sumo_train -e smartcross/envs/sumo_wj3_default_config.yaml -d entry/config/sumo_wj3_dqn_default_config.py + +evaluate existing policies +-------------------------------- + +.. code:: + + usage: sumo_eval [-h] [-d DING_CFG] -e ENV_CFG [-s SEED] + [-p {random,fix,dqn,rainbow,ppo}] [--dynamic-flow] + [-n ENV_NUM] [--gui] [-c CKPT_PATH] + + DI-smartcross testing script + + optional arguments: + -h, --help show this help message and exit + -d DING_CFG, --ding-cfg DING_CFG + DI-engine configuration path + -e ENV_CFG, --env-cfg ENV_CFG + sumo environment configuration path + -s SEED, --seed SEED random seed for sumo + -p {random,fix,dqn,rainbow,ppo}, --policy-type {random,fix,dqn,rainbow,ppo} + RL policy type + --dynamic-flow use dynamic route flow + -n ENV_NUM, --env-num ENV_NUM + sumo env num for evaluation + --gui open gui for visualize + -c CKPT_PATH, --ckpt-path CKPT_PATH + model ckpt path + + +Example of running random policy in wj3 env. + +.. code:: bash + + sumo_eval -p random -e smartcross/envs/sumo_wj3_default_config.yaml \ No newline at end of file diff --git a/docs/source/rl_environments.rst b/docs/source/rl_environments.rst new file mode 100644 index 0000000..d04494c --- /dev/null +++ b/docs/source/rl_environments.rst @@ -0,0 +1,71 @@ +Reinforcement Learning Environments +######################################## + + +SUMO environments +==================== + +configuration +----------------- + +The configuration of sumo env is stored in a config ``.yaml`` file. You can take a look at the default config file to see how to modify env settings. + +.. code:: python + + import yaml + from easydict import EasyDict + from smartcross.env import SumoEnv + + with open('smartcross/envs/sumo_wj3_default_config.yaml') as f: + cfg = yaml.safe_load(f) + cfg = EasyDict(cfg) + env = SumoEnv(config=cfg.env) + +The env configuration consists of the basic definition and observation\\action\\reward settings. The basic definition includes the sumo config file, episode length and light duration. The obs\\action\\reward sections define the detailed settings of each part. + +.. code:: yaml + + env: + sumocfg_path: 'wj3/rl_wj.sumocfg' + max_episode_steps: 1500 + green_duration: 10 + yellow_duration: 3 + obs: + ... + action: + ... + reward: + ... + +Observation +---------------- + +We provide several types of observations of a traffic cross. 
If `use_centralized_obs` is set `True`, the observation of each cross will be concatenated into one vector. The contents of observation can be modified by setting `obs_type`. The following observation types are supported now. + +- phase: One-hot phase vector of current cross signal +- lane_pos_vec: Lane occupancy in each grid position. The grid num can be set with `lane_grid_num` +- traffic_volumn: Traffic volume of each lane. Vehicle num / lane length * volume ratio +- queue_len: Vehicle waiting queue length of each lane. Waiting num / lane length * volume ratio + +Action +------------- + +Sumo environment supports changing the cross signal to a target phase. The action space is set to multi-discrete for each cross to reduce action num. + +Reward +------------- + +Reward can be set with `reward_type`. Reward is calculated cross by cross. If `use_centralized_obs` is set True, the reward of each cross will be summed up. + +- queue_len: Vehicle waiting queue num of each lane +- wait_time: Wait time increment of vehicles in each lane +- delay_time: Delay time of all vehicles in incoming and outgoing lanes +- pressure: Pressure of a cross + +Multi-agent +--------------- + +**DI-smartcross** supports a one-step configurable multi-agent RL training. +It is only necessary to add ``multi_agent`` in the **DI-engine** config file to convert common PPO into MAPPO, +and change ``use_centralized_obs`` in the environment config to ``True``. The policy and observations are then +automatically changed to run an individual agent for each cross. diff --git a/entry/config/sumo_arterial7_dqn_default_config.py b/entry/config/sumo_arterial7_dqn_default_config.py index dd982ea..f3d980b 100644 --- a/entry/config/sumo_arterial7_dqn_default_config.py +++ b/entry/config/sumo_arterial7_dqn_default_config.py @@ -2,7 +2,7 @@ from torch import nn nstep = 1 -sumo_dqn_default_config = dict( +sumo_mddqn_default_config = dict( exp_name='sumo_arterial7_md_dqn', env=dict( manager=dict( @@ -21,8 +21,6 @@ policy=dict( # Whether to use cuda for network. cuda=True, - # Whether the RL algorithm is on-policy or off-policy. - on_policy=False, # Whether use priority priority=True, priority_IS_weight=True, @@ -112,5 +110,5 @@ ) create_config = EasyDict(create_config) -sumo_dqn_default_config = EasyDict(sumo_dqn_default_config) -main_config = sumo_dqn_default_config +sumo_mddqn_default_config = EasyDict(sumo_mddqn_default_config) +main_config = sumo_mddqn_default_config diff --git a/entry/config/sumo_arterial7_mappo_default_config.py b/entry/config/sumo_arterial7_mappo_default_config.py new file mode 100644 index 0000000..9388074 --- /dev/null +++ b/entry/config/sumo_arterial7_mappo_default_config.py @@ -0,0 +1,81 @@ +from easydict import EasyDict +from torch import nn + +sumo_mappo_default_config = dict( + exp_name='sumo_arterial7_mappo', + env=dict( + manager=dict( + # Whether to use shared memory. Only effective if manager type is 'subprocess' + shared_memory=False, + context='spawn', + retry_type='renew', + max_retry=2, + ), + # Episode number for evaluation. + n_evaluator_episode=1, + # Once evaluation reward reaches "stop_value", which means the policy converges, the training can end. + stop_value=0, + collector_env_num=15, + evaluator_env_num=1, + ), + policy=dict( + # (bool) Whether to use cuda for network. 
+ cuda=True, + # (bool) Whether to use priority(priority sample, IS weight, update priority) + priority=False, + # () + multi_agent=True, + action_space='discrete', + model=dict( + agent_obs_shape=32, + global_obs_shape=224, + action_shape=4, + agent_num=7, + activation=nn.Tanh(), + ), + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=1e-4, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + learner=dict( + hook=dict( + save_ckpt_after_iter=1000, + log_show_after_iter=1000, + ), + ), + ), + collect=dict( + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + n_sample=600, + collector=dict( + transform_obs=True, + collect_print_freq=1000, + ) + ), + eval=dict(evaluator=dict(eval_freq=1000, )), + other=dict() + ), +) + +create_config = dict( + env_manager=dict( + type='subprocess', + ), + env=dict( + import_names=['smartcross.envs.sumo_env'], + type='sumo_env', + ), + policy=dict( + import_names=['ding.policy.ppo'], + type='ppo', + ), +) + +create_config = EasyDict(create_config) +sumo_mappo_default_config = EasyDict(sumo_mappo_default_config) +main_config = sumo_mappo_default_config diff --git a/entry/config/sumo_arterial7_offppo_default_config.py b/entry/config/sumo_arterial7_offppo_default_config.py new file mode 100644 index 0000000..1f40d51 --- /dev/null +++ b/entry/config/sumo_arterial7_offppo_default_config.py @@ -0,0 +1,96 @@ +from easydict import EasyDict +from torch import nn + +nstep = 1 +sumo_off_mdppo_default_config = dict( + exp_name='sumo_arterial7_off_md_ppo', + env=dict( + manager=dict( + shared_memory=False, + context='spawn', + retry_type='renew', + max_retry=5, + ), + # Episode number for evaluation. + n_evaluator_episode=1, + # Once evaluation reward reaches "stop_value", which means the policy converges, the training can end. + stop_value=0, + collector_env_num=15, + evaluator_env_num=1, + ), + policy=dict( + # Whether to use cuda for network. + cuda=True, + # (bool) Whether to use priority(priority sample, IS weight, update priority) + priority=False, + # () + continuous=False, + model=dict( + obs_shape=224, + action_shape=[4] * 7, + activation=nn.Tanh(), + ), + # learn_mode config + learn=dict( + update_per_collect=100, + batch_size=64, + learning_rate=1e-4, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + learner=dict( + hook=dict( + save_ckpt_after_iter=1000, + log_show_after_iter=1000, + ), + ), + ignore_done=True, + ), + # collect_mode config + collect=dict( + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + n_sample=600, + collector=dict( + transform_obs=True, + collect_print_freq=1000, + ) + ), + eval=dict( + evaluator=dict( + # Evaluate every "eval_freq" training steps. + eval_freq=1000, + ) + ), + # command_mode config + other=dict( + replay_buffer=dict( + replay_buffer_size=400000, + max_use=10000, + monitor=dict( + sampled_data_attr=dict(print_freq=300, ), + periodic_thruput=dict(seconds=300, ), + ), + ), + ), + ), +) + +create_config = dict( + env_manager=dict(type='subprocess', ), + env=dict( + # Must use the absolute path. All the following "import_names" should obey this too. + import_names=['smartcross.envs.sumo_env'], + type='sumo_env', + ), + # RL policy register name (refer to function "register_policy"). 
+ policy=dict( + import_names=['dizoo.common.policy.md_ppo'], + type='md_ppo_offpolicy', + ), +) + +create_config = EasyDict(create_config) +sumo_off_mdppo_default_config = EasyDict(sumo_off_mdppo_default_config) +main_config = sumo_off_mdppo_default_config diff --git a/entry/config/sumo_arterial7_ppo_default_config.py b/entry/config/sumo_arterial7_ppo_default_config.py index 1652a15..2fab4d5 100644 --- a/entry/config/sumo_arterial7_ppo_default_config.py +++ b/entry/config/sumo_arterial7_ppo_default_config.py @@ -2,7 +2,7 @@ from torch import nn nstep = 1 -sumo_ppo_default_config = dict( +sumo_mdppo_default_config = dict( exp_name='sumo_arterial7_md_ppo', env=dict( manager=dict( @@ -21,8 +21,6 @@ policy=dict( # Whether to use cuda for network. cuda=True, - # Whether the RL algorithm is on-policy or off-policy. - on_policy=False, # (bool) Whether to use priority(priority sample, IS weight, update priority) priority=False, # () @@ -66,16 +64,7 @@ ) ), # command_mode config - other=dict( - replay_buffer=dict( - replay_buffer_size=400000, - max_use=10000, - monitor=dict( - sampled_data_attr=dict(print_freq=300, ), - periodic_thruput=dict(seconds=300, ), - ), - ), - ), + other=dict(), ), ) @@ -89,10 +78,10 @@ # RL policy register name (refer to function "register_policy"). policy=dict( import_names=['dizoo.common.policy.md_ppo'], - type='md_ppo_offpolicy', + type='md_ppo', ), ) create_config = EasyDict(create_config) -sumo_ppo_default_config = EasyDict(sumo_ppo_default_config) -main_config = sumo_ppo_default_config +sumo_mdppo_default_config = EasyDict(sumo_mdppo_default_config) +main_config = sumo_mdppo_default_config diff --git a/entry/config/sumo_wj3_dqn_default_config.py b/entry/config/sumo_wj3_dqn_default_config.py index c854374..6be4626 100644 --- a/entry/config/sumo_wj3_dqn_default_config.py +++ b/entry/config/sumo_wj3_dqn_default_config.py @@ -2,7 +2,7 @@ from torch import nn nstep = 1 -sumo_dqn_default_config = dict( +sumo_mddqn_default_config = dict( exp_name='sumo_wj3_md_dqn', env=dict( manager=dict( @@ -22,8 +22,6 @@ policy=dict( # Whether to use cuda for network. cuda=True, - # Whether the RL algorithm is on-policy or off-policy. - on_policy=False, # Whether use priority priority=True, priority_IS_weight=True, @@ -114,5 +112,5 @@ ) create_config = EasyDict(create_config) -sumo_dqn_default_config = EasyDict(sumo_dqn_default_config) -main_config = sumo_dqn_default_config +sumo_mddqn_default_config = EasyDict(sumo_mddqn_default_config) +main_config = sumo_mddqn_default_config diff --git a/entry/config/sumo_wj3_mappo_default_config.py b/entry/config/sumo_wj3_mappo_default_config.py index aa00bb7..7d4211b 100644 --- a/entry/config/sumo_wj3_mappo_default_config.py +++ b/entry/config/sumo_wj3_mappo_default_config.py @@ -21,11 +21,10 @@ policy=dict( # (bool) Whether to use cuda for network. cuda=True, - # (bool) Whether the RL algorithm is on-policy or off-policy. 
(Note: in practice PPO can be off-policy used) - on_policy=True, # (bool) Whether to use priority(priority sample, IS weight, update priority) priority=False, # () + multi_agent=True, action_space='discrete', model=dict( agent_obs_shape=174, @@ -59,16 +58,7 @@ ) ), eval=dict(evaluator=dict(eval_freq=1000, )), - other=dict( - # replay_buffer=dict( - # replay_buffer_size=400000, - # max_use=10000, - # monitor=dict( - # sampled_data_attr=dict(print_freq=300, ), - # periodic_thruput=dict(seconds=300, ), - # ), - # ), - ) + other=dict() ), ) diff --git a/entry/config/sumo_wj3_offppo_default_config.py b/entry/config/sumo_wj3_offppo_default_config.py new file mode 100644 index 0000000..e22a941 --- /dev/null +++ b/entry/config/sumo_wj3_offppo_default_config.py @@ -0,0 +1,87 @@ +from easydict import EasyDict +from torch import nn + +sumo_off_mdppo_default_config = dict( + exp_name='sumo_wj3_off_md_ppo', + env=dict( + manager=dict( + # Whether to use shared memory. Only effective if manager type is 'subprocess' + shared_memory=False, + context='spawn', + retry_type='renew', + max_retry=2, + ), + # Episode number for evaluation. + n_evaluator_episode=1, + # Once evaluation reward reaches "stop_value", which means the policy converges, the training can end. + stop_value=0, + collector_env_num=15, + evaluator_env_num=1, + ), + policy=dict( + # (bool) Whether to use cuda for network. + cuda=True, + # (bool) Whether to use priority(priority sample, IS weight, update priority) + priority=False, + # () + continuous=False, + model=dict( + obs_shape=442, + action_shape=[4, 4, 4], + activation=nn.Tanh(), + ), + learn=dict( + update_per_collect=100, + batch_size=64, + learning_rate=1e-4, + value_weight=0.5, + entropy_weight=0.01, + clip_ratio=0.2, + learner=dict( + hook=dict( + save_ckpt_after_iter=1000, + log_show_after_iter=1000, + ), + ), + ), + collect=dict( + unroll_len=1, + discount_factor=0.99, + gae_lambda=0.95, + n_sample=600, + collector=dict( + transform_obs=True, + collect_print_freq=1000, + ) + ), + eval=dict(evaluator=dict(eval_freq=1000, )), + other=dict( + replay_buffer=dict( + replay_buffer_size=400000, + max_use=10000, + monitor=dict( + sampled_data_attr=dict(print_freq=300, ), + periodic_thruput=dict(seconds=300, ), + ), + ), + ) + ), +) + +create_config = dict( + env_manager=dict( + type='subprocess', + ), + env=dict( + import_names=['smartcross.envs.sumo_env'], + type='sumo_env', + ), + policy=dict( + import_names=['dizoo.common.policy.md_ppo'], + type='md_ppo_offpolicy', + ), +) + +create_config = EasyDict(create_config) +sumo_off_mdppo_default_config = EasyDict(sumo_off_mdppo_default_config) +main_config = sumo_off_mdppo_default_config diff --git a/entry/config/sumo_wj3_ppo_default_config.py b/entry/config/sumo_wj3_ppo_default_config.py index 07823de..9cc0922 100644 --- a/entry/config/sumo_wj3_ppo_default_config.py +++ b/entry/config/sumo_wj3_ppo_default_config.py @@ -1,7 +1,7 @@ from easydict import EasyDict from torch import nn -sumo_ppo_default_config = dict( +sumo_mdppo_default_config = dict( exp_name='sumo_wj3_md_ppo', env=dict( manager=dict( @@ -21,8 +21,6 @@ policy=dict( # (bool) Whether to use cuda for network. cuda=True, - # (bool) Whether the RL algorithm is on-policy or off-policy. 
(Note: in practice PPO can be off-policy used) - on_policy=False, # (bool) Whether to use priority(priority sample, IS weight, update priority) priority=False, # () @@ -33,7 +31,7 @@ activation=nn.Tanh(), ), learn=dict( - update_per_collect=100, + epoch_per_collect=10, batch_size=64, learning_rate=1e-4, value_weight=0.5, @@ -57,16 +55,7 @@ ) ), eval=dict(evaluator=dict(eval_freq=1000, )), - other=dict( - replay_buffer=dict( - replay_buffer_size=400000, - max_use=10000, - monitor=dict( - sampled_data_attr=dict(print_freq=300, ), - periodic_thruput=dict(seconds=300, ), - ), - ), - ) + other=dict() ), ) @@ -80,10 +69,10 @@ ), policy=dict( import_names=['dizoo.common.policy.md_ppo'], - type='md_ppo_offpolicy', + type='md_ppo', ), ) create_config = EasyDict(create_config) -sumo_ppo_default_config = EasyDict(sumo_ppo_default_config) -main_config = sumo_ppo_default_config +sumo_mdppo_default_config = EasyDict(sumo_mdppo_default_config) +main_config = sumo_mdppo_default_config diff --git a/entry/config/sumo_wj3_rainbow_dqn_default_config.py b/entry/config/sumo_wj3_rainbow_default_config.py similarity index 93% rename from entry/config/sumo_wj3_rainbow_dqn_default_config.py rename to entry/config/sumo_wj3_rainbow_default_config.py index 87eb148..0894743 100644 --- a/entry/config/sumo_wj3_rainbow_dqn_default_config.py +++ b/entry/config/sumo_wj3_rainbow_default_config.py @@ -2,7 +2,7 @@ from torch import nn nstep = 3 -sumo_rainbow_dqn_default_config = dict( +sumo_rainbow_default_config = dict( exp_name='sumo_wj3_md_rainbow_dqn', env=dict( manager=dict( @@ -22,8 +22,6 @@ policy=dict( # Whether to use cuda for network. cuda=True, - # Whether the RL algorithm is on-policy or off-policy. - on_policy=False, # Whether use priority priority=True, priority_IS_weight=True, @@ -114,5 +112,5 @@ ) create_config = EasyDict(create_config) -sumo_rainbow_dqn_default_config = EasyDict(sumo_rainbow_dqn_default_config) -main_config = sumo_rainbow_dqn_default_config +sumo_rainbow_default_config = EasyDict(sumo_rainbow_default_config) +main_config = sumo_rainbow_default_config diff --git a/setup.py b/setup.py index e649f00..862c575 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ "di-engine>=0.2", "sumolib", "traci", + "MarkupSafe<=2.0.1'", ], extras_require={ 'doc': [ diff --git a/smartcross/envs/action/__init__.py b/smartcross/envs/action/__init__.py new file mode 100644 index 0000000..bcb4b2c --- /dev/null +++ b/smartcross/envs/action/__init__.py @@ -0,0 +1,2 @@ +from .sumo_action import SumoAction +from .sumo_action_runner import SumoActionRunner diff --git a/smartcross/envs/action/sumo_action.py b/smartcross/envs/action/sumo_action.py new file mode 100644 index 0000000..790d992 --- /dev/null +++ b/smartcross/envs/action/sumo_action.py @@ -0,0 +1,66 @@ +from typing import Dict + +from ding.envs import BaseEnv +from ding.envs.common import EnvElement + +ALL_ACTION_TYPE = set(['change']) + + +class SumoAction(EnvElement): + r""" + Overview: + the action element of Sumo enviroment + + Interface: + _init, _from_agent_processor + """ + _name = "SumoAction" + + def _init(self, env: BaseEnv, cfg: Dict) -> None: + r""" + Overview: + init the sumo action environment with the given config file + Arguments: + - cfg(:obj:`EasyDict`): config, you can refer to `envs/sumo/sumo_env_default_config.yaml` + """ + self._env = env + self._cfg = cfg + action_shape = [] + self._action_type = cfg.action_type + assert self._action_type in ALL_ACTION_TYPE + self._use_multi_discrete = cfg.use_multi_discrete + for tl, cross in 
self._env.crosses.items(): + if self._action_type == 'change': + action_shape.append(cross.phase_num) + else: + # TODO: add switch action + raise NotImplementedError + if self._use_multi_discrete: + self._shape = len(action_shape) + else: + # TODO: add naive discrete action + raise NotImplementedError + self._value = { + 'min': 0, + 'max': action_shape[0], + 'dtype': int, + } + + def _from_agent_processor(self, data: Dict) -> Dict: + r""" + """ + # TODO: add switch action + action = {k: {} for k in data.keys()} + for k, v in data.items(): + act, last_act = v['action'], v['last_action'] + if last_act is not None and act != last_act: + yellow_phase = self._env.crosses[k].get_yellow_phase_index(last_act) + else: + yellow_phase = None + action[k]['yellow'] = yellow_phase + action[k]['green'] = self._env.crosses[k].get_green_phase_index(act) + return action + + # override + def _details(self): + return 'action dim: {}'.format(self._shape) diff --git a/smartcross/envs/action/sumo_action_runner.py b/smartcross/envs/action/sumo_action_runner.py new file mode 100644 index 0000000..cc062ad --- /dev/null +++ b/smartcross/envs/action/sumo_action_runner.py @@ -0,0 +1,36 @@ +from typing import Dict, Any +import numpy as np +import torch + +from ding.envs.common import EnvElementRunner +from ding.envs.env.base_env import BaseEnv +from ding.torch_utils import to_ndarray +from .sumo_action import SumoAction + + +class SumoActionRunner(EnvElementRunner): + + def _init(self, engine: BaseEnv, cfg: Dict) -> None: + r""" + Overview: + init the sumo observation helper with the given config file + Arguments: + - cfg(:obj:`EasyDict`): config, you can refer to `envs/sumo/sumo_env_default_config.yaml` + """ + # set self._core and other state variable + self._engine = engine + self._core = SumoAction(engine, cfg) + self._last_action = None + + def get(self, raw_action: Any) -> Dict: + raw_action = np.squeeze(raw_action) + if self._last_action is None: + self._last_action = [None for _ in range(len(raw_action))] + data = {} + for tl, act, last_act in zip(self._engine.crosses.keys(), raw_action, self._last_action): + data[tl] = {'action': act, 'last_action': last_act} + action = self._core._from_agent_processor(data) + return action + + def reset(self) -> None: + self._last_action = None diff --git a/smartcross/envs/obs/__init__.py b/smartcross/envs/obs/__init__.py index e69de29..b1dd8bb 100644 --- a/smartcross/envs/obs/__init__.py +++ b/smartcross/envs/obs/__init__.py @@ -0,0 +1,2 @@ +from .sumo_obs import SumoObs +from .sumo_obs_runner import SumoObsRunner diff --git a/smartcross/envs/obs/sumo_obs_helper.py b/smartcross/envs/obs/sumo_obs.py similarity index 88% rename from smartcross/envs/obs/sumo_obs_helper.py rename to smartcross/envs/obs/sumo_obs.py index ef4af91..3fd9aac 100644 --- a/smartcross/envs/obs/sumo_obs_helper.py +++ b/smartcross/envs/obs/sumo_obs.py @@ -3,14 +3,24 @@ from ding.envs import BaseEnv from ding.envs.common.env_element import EnvElementInfo +from ding.envs.common import EnvElement ALL_OBS_TPYE = set(['phase', 'lane_pos_vec', 'traffic_volumn', 'queue_len']) -class SumoObsHelper(): +class SumoObs(EnvElement): + r""" + Overview: + the observation element of Sumo enviroment - def __init__(self, core: BaseEnv, cfg: Dict) -> None: - self._core = core + Interface: + _init, to_agent_processor + """ + + _name = "SumoObs" + + def _init(self, env: BaseEnv, cfg: Dict) -> None: + self._core = env self._cfg = cfg self._tl_num = len(self._core.crosses) self._obs_type = self._cfg.obs_type @@ -18,7 +28,6 @@ 
def __init__(self, core: BaseEnv, cfg: Dict) -> None: self._use_centralized_obs = self._cfg.use_centralized_obs self._padding = self._cfg.padding - def init_info(self) -> None: obs_shape = [] tl_obs_max_dict = None for tl, cross in self._core.crosses.items(): @@ -41,7 +50,7 @@ def init_info(self) -> None: tl_obs_max_dict = max_dict(tl_obs_max_dict, tl_obs_shape_map) if self._use_centralized_obs: - self._obs_shape = sum(obs_shape) + self._shape = sum(obs_shape) else: global_state_shape = sum(obs_shape) if self._padding: @@ -49,12 +58,12 @@ def init_info(self) -> None: agent_state_shape = sum(self._tl_feature_shape.values()) else: agent_state_shape = max(obs_shape) - self._obs_shape = { + self._shape = { 'agent_state': agent_state_shape, 'global_state': global_state_shape, 'action_mask': self._tl_num } - self._obs_value = { + self._value = { 'min': 0, 'max': 1, 'dtype': float, @@ -75,7 +84,7 @@ def _get_tls_feature(self, tl_id: int) -> Dict: tl_obs['queue_len'] = list(cross.get_lane_queue_len(self._queue_len_ratio).values()) return tl_obs - def get_observation(self) -> Dict[str, np.ndarray]: + def _to_agent_processor(self) -> Dict: obs = {} tl_num = len(self._core.crosses) for tl in self._core.crosses.keys(): @@ -99,8 +108,11 @@ def get_observation(self) -> Dict[str, np.ndarray]: 'action_mask': np.array([action_mask] * tl_num) } - def info(self): - return EnvElementInfo(self._obs_shape, self._obs_value) + def __repr__(self) -> str: + return '{}: {}'.format(self._name, self._details()) + + def _details(self) -> str: + return '{}'.format(self._shape) def max_dict(dict1: Dict, dict2: Dict) -> Dict: diff --git a/smartcross/envs/obs/sumo_obs_runner.py b/smartcross/envs/obs/sumo_obs_runner.py new file mode 100644 index 0000000..e67f41f --- /dev/null +++ b/smartcross/envs/obs/sumo_obs_runner.py @@ -0,0 +1,47 @@ +import numpy as np +from typing import Dict + +from ding.envs.common import EnvElementRunner +from ding.envs.env.base_env import BaseEnv +from ding.torch_utils import to_ndarray +from .sumo_obs import SumoObs + + +class SumoObsRunner(EnvElementRunner): + r""" + Overview: + runner that help to get the observation space + Interface: + _init, get, reset + """ + + def _init(self, engine: BaseEnv, cfg: dict) -> None: + r""" + Overview: + init the sumo observation helper with the given config file + Arguments: + - cfg(:obj:`EasyDict`): config, you can refer to `envs/sumo/sumo_env_default_config.yaml` + """ + # set self._core and other state variable + self._engine = engine + self._core = SumoObs(engine, cfg) + self._obs = None + + def get(self): + """ + Overview: + return the formated observation + Returns: + - obs (:obj:`torch.Tensor` or :obj:`dict`): the returned observation,\ + :obj:`torch.Tensor` if used centerlized_obs, else :obj:`dict` with format {traffic_light: reward} + """ + self._obs = self._core._to_agent_processor() + return to_ndarray(self._obs, dtype=np.float32) + + # override + def reset(self) -> None: + r""" + Overview: + reset obs runner, and return the initial obs + """ + return to_ndarray(self._obs, dtype=np.float32) diff --git a/smartcross/envs/reward/__init__.py b/smartcross/envs/reward/__init__.py new file mode 100644 index 0000000..127cd01 --- /dev/null +++ b/smartcross/envs/reward/__init__.py @@ -0,0 +1,2 @@ +from .sumo_reward import SumoReward +from .sumo_reward_runner import SumoRewardRunner diff --git a/smartcross/envs/reward/sumo_reward.py b/smartcross/envs/reward/sumo_reward.py new file mode 100644 index 0000000..82fa44e --- /dev/null +++ 
b/smartcross/envs/reward/sumo_reward.py @@ -0,0 +1,60 @@ +from typing import Dict +import numpy as np + +from ding.envs import BaseEnv +from ding.envs.common import EnvElement + +ALL_REWARD_TYPE = set(['queue_len', 'wait_time', 'delay_time', 'pressure']) + + +class SumoReward(EnvElement): + r""" + Overview: + the reward element of Sumo enviroment + + Interface: + _init, to_agent_processor + """ + _name = "SumoReward" + + def _init(self, env: BaseEnv, cfg: Dict) -> None: + r""" + Overview: + init the sumo reward environment with the given config file + Arguments: + - cfg(:obj:`EasyDict`): config, you can refer to `envs/sumo/sumo_env_default_config.yaml` + """ + self._env = env + self._cfg = cfg + self._reward_type = cfg.reward_type + assert set(self._reward_type.keys()).issubset(ALL_REWARD_TYPE) + self._use_centralized_reward = cfg.use_centralized_reward + if self._use_centralized_reward: + self._shape = (1, ) + else: + raise NotImplementedError + self._value = {'min': '-inf', 'max': 'inf', 'dtype': float} + + def _to_agent_processor(self): + reward = {k: 0 for k in self._env.crosses.keys()} + for k in self._env.crosses.keys(): + cross = self._env.crosses[k] + if 'queue_len' in self._reward_type: + queue_len = np.average(list(cross.get_lane_queue_len().values())) + reward[k] += self._reward_type['queue_len'] * -queue_len + if 'wait_time' in self._reward_type: + wait_time = np.average(list(cross.get_lane_wait_time().values())) + reward[k] += self._reward_type['wait_time'] * -wait_time + if 'delay_time' in self._reward_type: + delay_time = np.average(list(cross.get_lane_delay_time().values())) + reward[k] += self._reward_type['delay_time'] * -delay_time + if 'pressure' in self._reward_type: + pressure = cross.get_pressure() + reward[k] += self._reward_type['pressure'] * -pressure + if self._use_centralized_reward: + reward = sum(reward.values()) + return reward + + # override + def _details(self): + return '{}'.format(self._shape) diff --git a/smartcross/envs/reward/sumo_reward_runner.py b/smartcross/envs/reward/sumo_reward_runner.py new file mode 100644 index 0000000..bbce310 --- /dev/null +++ b/smartcross/envs/reward/sumo_reward_runner.py @@ -0,0 +1,36 @@ +import numpy as np +from typing import Dict, Any + +from ding.envs.common import EnvElementRunner +from ding.envs.env.base_env import BaseEnv +from ding.torch_utils import to_ndarray +from .sumo_reward import SumoReward + + +class SumoRewardRunner(EnvElementRunner): + r""" + Overview: + the reward element of Sumo enviroment + + Interface: + _init, to_agent_processor + """ + + def _init(self, engine: BaseEnv, cfg: dict) -> None: + r""" + Overview: + init the sumo reward environment with the given config file + Arguments: + - cfg(:obj:`EasyDict`): config, you can refer to `envs/sumo/sumo_env_default_config.yaml` + """ + self._engine = engine + self._core = SumoReward(engine, cfg) + self._final_eval_reward = 0 + + def get(self) -> Any: + reward = self._core._to_agent_processor() + self._final_eval_reward += reward + return reward + + def reset(self) -> None: + self._final_eval_reward = 0 diff --git a/smartcross/envs/rl_arterial_7roads_scene/route/1100/2_eg.rou.xml b/smartcross/envs/rl_arterial_7roads/route/1100/2_eg.rou.xml similarity index 100% rename from smartcross/envs/rl_arterial_7roads_scene/route/1100/2_eg.rou.xml rename to smartcross/envs/rl_arterial_7roads/route/1100/2_eg.rou.xml diff --git a/smartcross/envs/rl_arterial_7roads_scene/standard.net.xml b/smartcross/envs/rl_arterial_7roads/standard.net.xml similarity index 100% rename 
from smartcross/envs/rl_arterial_7roads_scene/standard.net.xml rename to smartcross/envs/rl_arterial_7roads/standard.net.xml diff --git a/smartcross/envs/rl_arterial_7roads_scene/standard.sumocfg b/smartcross/envs/rl_arterial_7roads/standard.sumocfg similarity index 85% rename from smartcross/envs/rl_arterial_7roads_scene/standard.sumocfg rename to smartcross/envs/rl_arterial_7roads/standard.sumocfg index 3b4da91..1229100 100644 --- a/smartcross/envs/rl_arterial_7roads_scene/standard.sumocfg +++ b/smartcross/envs/rl_arterial_7roads/standard.sumocfg @@ -2,7 +2,7 @@ - +
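As a reading aid for the env refactor in this diff, the sketch below shows how the observation, action, and reward runners are intended to cooperate within a single environment step. It is a simplified outline, not code from the repository: `env.apply_action` stands in for the actual SUMO/TraCI phase-setting call, and construction of the runners (each is initialized with the env and its config section) is omitted.

```python
# Simplified sketch of one SumoEnv step using the element runners added in this diff.
# `env.apply_action` is a hypothetical placeholder for the actual SUMO/TraCI simulation call.
from smartcross.envs.obs import SumoObsRunner
from smartcross.envs.action import SumoActionRunner
from smartcross.envs.reward import SumoRewardRunner


def step_sketch(env, obs_runner, action_runner, reward_runner, raw_action):
    # Map the raw multi-discrete policy output to per-crossing yellow/green phase commands;
    # the runner remembers the last action so a yellow phase is inserted on phase changes.
    action = action_runner.get(raw_action)
    env.apply_action(action)  # hypothetical: advance the simulation with the chosen phases
    # Gather the per-crossing (or centralized) observation and reward after the step.
    obs = obs_runner.get()
    reward = reward_runner.get()
    return obs, reward
```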