From d543d59332a3f9715a2b02b9c02b963546c3a9e9 Mon Sep 17 00:00:00 2001 From: robinc94 Date: Thu, 21 Apr 2022 02:29:45 +0800 Subject: [PATCH 1/3] suit for 0.3.0 --- entry/cityflow_eval | 2 +- entry/cityflow_train | 4 +++- entry/sumo_eval | 2 +- entry/sumo_train | 4 +++- smartcross/policy/default_policy.py | 11 +++++++++++ 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/entry/cityflow_eval b/entry/cityflow_eval index f86975b..2c6cc93 100644 --- a/entry/cityflow_eval +++ b/entry/cityflow_eval @@ -21,7 +21,7 @@ def main(args, seed=0): if args.policy_type == 'fix': create_config.policy['type'] = 'smartcross_fix' main_config.env = deep_merge_dicts(main_config.env, cityflow_env_config) - cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True) + cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True, save_cfg=False) if args.env_num > 0: cfg.env.evaluator_env_num = args.env_num if cfg.env.n_evaluator_episode < args.env_num: diff --git a/entry/cityflow_train b/entry/cityflow_train index 454c7fa..1c6b266 100644 --- a/entry/cityflow_train +++ b/entry/cityflow_train @@ -13,6 +13,7 @@ from ding.utils.default_helper import set_pkg_seed from ding.utils import deep_merge_dicts from ding.rl_utils import get_epsilon_greedy_fn from smartcross.utils.config_utils import read_ding_config +from smartcross.policy.default_policy import get_random_sample_func def main(args, seed=None): @@ -63,7 +64,8 @@ def main(args, seed=None): # Accumulate plenty of data at the beginning of training. if cfg.policy.get('random_collect_size', 0) > 0: action_space = collector_env.action_space - random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space) + random_sample_func = get_random_sample_func(action_space) + random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=random_sample_func) collector.reset_policy(random_policy) new_data = collector.collect(n_sample=cfg.policy.random_collect_size) replay_buffer.push(new_data, cur_collector_envstep=0) diff --git a/entry/sumo_eval b/entry/sumo_eval index 40270a2..067fe4f 100644 --- a/entry/sumo_eval +++ b/entry/sumo_eval @@ -27,7 +27,7 @@ def main(args, seed=0): main_config, create_config = get_sumo_config(args) if args.gui: main_config.env.gui = True - cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True) + cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True, save_cfg=False) if args.env_num > 0: cfg.env.evaluator_env_num = args.env_num if cfg.env.n_evaluator_episode < args.env_num: diff --git a/entry/sumo_train b/entry/sumo_train index a5556e6..7f99b06 100644 --- a/entry/sumo_train +++ b/entry/sumo_train @@ -12,6 +12,7 @@ from ding.worker import BaseLearner, InteractionSerialEvaluator, create_serial_c from ding.utils.default_helper import set_pkg_seed from ding.rl_utils import get_epsilon_greedy_fn from smartcross.utils.config_utils import get_sumo_config +from smartcross.policy.default_policy import get_random_sample_func def main(args, seed=None): @@ -60,7 +61,8 @@ def main(args, seed=None): # Accumulate plenty of data at the beginning of training. 
if cfg.policy.get('random_collect_size', 0) > 0: action_space = collector_env.action_space - random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space) + random_sample_func = get_random_sample_func(action_space) + random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=random_sample_func) collector.reset_policy(random_policy) new_data = collector.collect(n_sample=cfg.policy.random_collect_size) replay_buffer.push(new_data, cur_collector_envstep=0) diff --git a/smartcross/policy/default_policy.py b/smartcross/policy/default_policy.py index fa8bc73..1c18280 100644 --- a/smartcross/policy/default_policy.py +++ b/smartcross/policy/default_policy.py @@ -38,6 +38,17 @@ def default_config(cls: type) -> EasyDict: return cfg +def get_random_sample_func(act_space): + def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]: + actions = {} + for env_id in data: + action = act_space.sample() + action = [torch.LongTensor([v]) for v in action] + actions[env_id] = {'action': action} + return actions + return _forward + + @POLICY_REGISTRY.register('smartcross_fix') class FixedPolicy(): From eec841d9c6bea380f94c0ac688e86f046e84e998 Mon Sep 17 00:00:00 2001 From: robinc94 Date: Sun, 24 Apr 2022 12:53:43 +0800 Subject: [PATCH 2/3] fix seed bug --- entry/cityflow_eval | 5 +++-- entry/sumo_eval | 5 +++-- smartcross/envs/cityflow_env.py | 1 + smartcross/envs/crossing.py | 2 +- smartcross/envs/sumo_env.py | 2 ++ 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/entry/cityflow_eval b/entry/cityflow_eval index 2c6cc93..865e8a3 100644 --- a/entry/cityflow_eval +++ b/entry/cityflow_eval @@ -14,7 +14,7 @@ from smartcross.utils.config_utils import read_ding_config from smartcross.policy import FixedPolicy -def main(args, seed=0): +def main(args, seed=None): ding_cfg = args.ding_cfg main_config, create_config = read_ding_config(ding_cfg) cityflow_env_config = {'config_path': args.env_cfg} @@ -47,6 +47,7 @@ def main(args, seed=0): policy, ) _, eval_reward = evaluator.eval(None, -1, -1, cfg.env.n_evaluator_episode) + eval_reward = [r['final_eval_reward'].item() for r in eval_reward] print('Eval is over! 
The performance is {}'.format(eval_reward)) evaluator.close() return eval_reward @@ -56,7 +57,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description='DI-smartcross training script') parser.add_argument('-d', '--ding-cfg', default=None, help='DI-engine configuration path') parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path') - parser.add_argument('-s', '--seed', default=0, type=int, help='random seed for sumo') + parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo') parser.add_argument( '-p', '--policy-type', default='dqn', choices=['fix', 'dqn', 'ppo'], help='RL policy type' ) diff --git a/entry/sumo_eval b/entry/sumo_eval index 067fe4f..0afa3a3 100644 --- a/entry/sumo_eval +++ b/entry/sumo_eval @@ -15,7 +15,7 @@ from smartcross.utils.config_utils import get_sumo_config from smartcross.policy import RandomPolicy, FixedPolicy -def main(args, seed=0): +def main(args, seed=None): if args.policy_type in ['random', 'fix']: from entry.sumo_config.sumo_eval_default_config import main_config, create_config with open(args.env_cfg, 'r') as f: @@ -54,6 +54,7 @@ def main(args, seed=0): policy, ) _, eval_reward = evaluator.eval(None, -1, -1, cfg.env.n_evaluator_episode) + eval_reward = [r['final_eval_reward'].item() for r in eval_reward] print('Eval is over! The performance is {}'.format(eval_reward)) evaluator.close() return eval_reward @@ -63,7 +64,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description='DI-smartcross training script') parser.add_argument('-d', '--ding-cfg', default=None, help='DI-engine configuration path') parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path') - parser.add_argument('-s', '--seed', default=0, type=int, help='random seed for sumo') + parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo') parser.add_argument( '-p', '--policy-type', default='dqn', choices=['random', 'fix', 'dqn', 'rainbow', 'ppo'], help='RL policy type' ) diff --git a/smartcross/envs/cityflow_env.py b/smartcross/envs/cityflow_env.py index 2936821..f9a22ec 100644 --- a/smartcross/envs/cityflow_env.py +++ b/smartcross/envs/cityflow_env.py @@ -215,6 +215,7 @@ def close(self) -> None: def seed(self, seed: int, dynamic_seed: bool = True) -> None: self._seed = seed self._dynamic_seed = dynamic_seed + self._eng.set_random_seed(seed) @property def observation_space(self) -> gym.spaces.Space: diff --git a/smartcross/envs/crossing.py b/smartcross/envs/crossing.py index 339f5fd..d0e4346 100644 --- a/smartcross/envs/crossing.py +++ b/smartcross/envs/crossing.py @@ -17,7 +17,7 @@ def __init__(self, tls_id: str, env: 'BaseEnv') -> None: self._lane_vehicle_dict = {} self._previous_lane_vehicle_dict = {} - signal_definition = traci.trafficlight.getCompleteRedYellowGreenDefinition(self._id)[0] + signal_definition = traci.trafficlight.getAllProgramLogics(self._id)[0] self._green_phases = [] self._yellow_phases = [] for idx, phase in enumerate(signal_definition.phases): diff --git a/smartcross/envs/sumo_env.py b/smartcross/envs/sumo_env.py index 6cf534b..edc2755 100644 --- a/smartcross/envs/sumo_env.py +++ b/smartcross/envs/sumo_env.py @@ -115,6 +115,8 @@ def reset(self) -> Any: self._action_runner.reset() self._obs_runner.reset() self._reward_runner.reset() + if self._launch_env_flag: + self.close() self._launch_env(self._gui) for tl in self._cfg.tls: self._crosses[tl] = Crossing(tl, self) From 
516ade0c47612674b04a846e88ab6d673c6f0a4c Mon Sep 17 00:00:00 2001 From: robinc94 Date: Sun, 24 Apr 2022 16:02:50 +0800 Subject: [PATCH 3/3] fix env num bug --- entry/cityflow_eval | 2 +- entry/cityflow_train | 4 ++-- entry/sumo_eval | 2 +- entry/sumo_train | 4 ++-- smartcross/policy/default_policy.py | 2 ++ 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/entry/cityflow_eval b/entry/cityflow_eval index 865e8a3..360c505 100644 --- a/entry/cityflow_eval +++ b/entry/cityflow_eval @@ -61,7 +61,7 @@ if __name__ == "__main__": parser.add_argument( '-p', '--policy-type', default='dqn', choices=['fix', 'dqn', 'ppo'], help='RL policy type' ) - parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation') + parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation') parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path') args = parser.parse_args() diff --git a/entry/cityflow_train b/entry/cityflow_train index 1c6b266..8861032 100644 --- a/entry/cityflow_train +++ b/entry/cityflow_train @@ -117,8 +117,8 @@ if __name__ == "__main__": parser.add_argument('-d', '--ding-cfg', required=True, help='DI-engine configuration path') parser.add_argument('-e', '--env-cfg', required=True, help='cityflow json configuration path') parser.add_argument('-s', '--seed', default=None, type=int, help='random seed') - parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector env num for training') - parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator env num for training') + parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector env num for training') + parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator env num for training') parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt') args = parser.parse_args() diff --git a/entry/sumo_eval b/entry/sumo_eval index 0afa3a3..dfa3f34 100644 --- a/entry/sumo_eval +++ b/entry/sumo_eval @@ -69,7 +69,7 @@ if __name__ == "__main__": '-p', '--policy-type', default='dqn', choices=['random', 'fix', 'dqn', 'rainbow', 'ppo'], help='RL policy type' ) parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow") - parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation') + parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation') parser.add_argument('--gui', action='store_true', help="open gui for visualize") parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path') diff --git a/entry/sumo_train b/entry/sumo_train index 7f99b06..b32eb96 100644 --- a/entry/sumo_train +++ b/entry/sumo_train @@ -115,8 +115,8 @@ if __name__ == "__main__": parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path') parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo') parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow") - parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector sumo env num for training') - parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator sumo env num for training') + parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector sumo env num for training') + 
parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator sumo env num for training') parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt') args = parser.parse_args() diff --git a/smartcross/policy/default_policy.py b/smartcross/policy/default_policy.py index 1c18280..c5c0e12 100644 --- a/smartcross/policy/default_policy.py +++ b/smartcross/policy/default_policy.py @@ -39,6 +39,7 @@ def default_config(cls: type) -> EasyDict: def get_random_sample_func(act_space): + def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]: actions = {} for env_id in data: @@ -46,6 +47,7 @@ def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]: action = [torch.LongTensor([v]) for v in action] actions[env_id] = {'action': action} return actions + return _forward
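
Reviewer note, not part of the patches above: a minimal usage sketch of the new get_random_sample_func helper introduced in PATCH 1/3 and touched again in PATCH 3/3. It assumes a gym MultiDiscrete action space; the [4, 4] space and the placeholder observations below are hypothetical stand-ins for the per-intersection phase choices.

    # Sketch only: shows the contract of the forward function returned by
    # get_random_sample_func, as consumed via
    # PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=...).
    import gym
    import torch

    from smartcross.policy.default_policy import get_random_sample_func

    # Hypothetical action space: two intersections, four green phases each.
    act_space = gym.spaces.MultiDiscrete([4, 4])
    random_fn = get_random_sample_func(act_space)

    # The collector calls the forward function with a dict keyed by env_id;
    # only the keys matter here, so the observations are placeholders.
    obs = {0: None, 1: None}
    actions = random_fn(obs)
    print(actions[0]['action'])  # e.g. [tensor([2]), tensor([0])]

The returned callable packs each sampled action as a list of LongTensors, matching the per-intersection action format used elsewhere in the repo, which is presumably why the patch passes forward_fn rather than the raw action_space during the random-collect phase.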