diff --git a/entry/cityflow_eval b/entry/cityflow_eval
index f86975b..360c505 100644
--- a/entry/cityflow_eval
+++ b/entry/cityflow_eval
@@ -14,14 +14,14 @@ from smartcross.utils.config_utils import read_ding_config
 from smartcross.policy import FixedPolicy
 
 
-def main(args, seed=0):
+def main(args, seed=None):
     ding_cfg = args.ding_cfg
     main_config, create_config = read_ding_config(ding_cfg)
     cityflow_env_config = {'config_path': args.env_cfg}
     if args.policy_type == 'fix':
         create_config.policy['type'] = 'smartcross_fix'
     main_config.env = deep_merge_dicts(main_config.env, cityflow_env_config)
-    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True)
+    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True, save_cfg=False)
     if args.env_num > 0:
         cfg.env.evaluator_env_num = args.env_num
         if cfg.env.n_evaluator_episode < args.env_num:
@@ -47,6 +47,7 @@ def main(args, seed=0):
         policy,
     )
     _, eval_reward = evaluator.eval(None, -1, -1, cfg.env.n_evaluator_episode)
+    eval_reward = [r['final_eval_reward'].item() for r in eval_reward]
     print('Eval is over! The performance is {}'.format(eval_reward))
     evaluator.close()
     return eval_reward
@@ -56,11 +57,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='DI-smartcross training script')
     parser.add_argument('-d', '--ding-cfg', default=None, help='DI-engine configuration path')
     parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path')
-    parser.add_argument('-s', '--seed', default=0, type=int, help='random seed for sumo')
+    parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo')
     parser.add_argument(
         '-p', '--policy-type', default='dqn', choices=['fix', 'dqn', 'ppo'], help='RL policy type'
     )
-    parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation')
+    parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation')
     parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path')
     args = parser.parse_args()
 
diff --git a/entry/cityflow_train b/entry/cityflow_train
index 454c7fa..8861032 100644
--- a/entry/cityflow_train
+++ b/entry/cityflow_train
@@ -13,6 +13,7 @@ from ding.utils.default_helper import set_pkg_seed
 from ding.utils import deep_merge_dicts
 from ding.rl_utils import get_epsilon_greedy_fn
 from smartcross.utils.config_utils import read_ding_config
+from smartcross.policy.default_policy import get_random_sample_func
 
 
 def main(args, seed=None):
@@ -63,7 +64,8 @@ def main(args, seed=None):
     # Accumulate plenty of data at the beginning of training.
     if cfg.policy.get('random_collect_size', 0) > 0:
         action_space = collector_env.action_space
-        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+        random_sample_func = get_random_sample_func(action_space)
+        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=random_sample_func)
         collector.reset_policy(random_policy)
         new_data = collector.collect(n_sample=cfg.policy.random_collect_size)
         replay_buffer.push(new_data, cur_collector_envstep=0)
@@ -115,8 +117,8 @@
     parser.add_argument('-d', '--ding-cfg', required=True, help='DI-engine configuration path')
     parser.add_argument('-e', '--env-cfg', required=True, help='cityflow json configuration path')
     parser.add_argument('-s', '--seed', default=None, type=int, help='random seed')
-    parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector env num for training')
-    parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator env num for training')
+    parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector env num for training')
+    parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator env num for training')
     parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt')
     args = parser.parse_args()
 
diff --git a/entry/sumo_eval b/entry/sumo_eval
index 40270a2..dfa3f34 100644
--- a/entry/sumo_eval
+++ b/entry/sumo_eval
@@ -15,7 +15,7 @@ from smartcross.utils.config_utils import get_sumo_config
 from smartcross.policy import RandomPolicy, FixedPolicy
 
 
-def main(args, seed=0):
+def main(args, seed=None):
     if args.policy_type in ['random', 'fix']:
         from entry.sumo_config.sumo_eval_default_config import main_config, create_config
         with open(args.env_cfg, 'r') as f:
@@ -27,7 +27,7 @@
         main_config, create_config = get_sumo_config(args)
     if args.gui:
         main_config.env.gui = True
-    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True)
+    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True, save_cfg=False)
     if args.env_num > 0:
         cfg.env.evaluator_env_num = args.env_num
         if cfg.env.n_evaluator_episode < args.env_num:
@@ -54,6 +54,7 @@ def main(args, seed=0):
         policy,
     )
     _, eval_reward = evaluator.eval(None, -1, -1, cfg.env.n_evaluator_episode)
+    eval_reward = [r['final_eval_reward'].item() for r in eval_reward]
     print('Eval is over! The performance is {}'.format(eval_reward))
     evaluator.close()
     return eval_reward
@@ -63,12 +64,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='DI-smartcross training script')
     parser.add_argument('-d', '--ding-cfg', default=None, help='DI-engine configuration path')
     parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path')
-    parser.add_argument('-s', '--seed', default=0, type=int, help='random seed for sumo')
+    parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo')
     parser.add_argument(
         '-p', '--policy-type', default='dqn', choices=['random', 'fix', 'dqn', 'rainbow', 'ppo'], help='RL policy type'
     )
     parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow")
-    parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation')
+    parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation')
     parser.add_argument('--gui', action='store_true', help="open gui for visualize")
     parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path')
     args = parser.parse_args()
diff --git a/entry/sumo_train b/entry/sumo_train
index a5556e6..b32eb96 100644
--- a/entry/sumo_train
+++ b/entry/sumo_train
@@ -12,6 +12,7 @@ from ding.worker import BaseLearner, InteractionSerialEvaluator, create_serial_c
 from ding.utils.default_helper import set_pkg_seed
 from ding.rl_utils import get_epsilon_greedy_fn
 from smartcross.utils.config_utils import get_sumo_config
+from smartcross.policy.default_policy import get_random_sample_func
 
 
 def main(args, seed=None):
@@ -60,7 +61,8 @@ def main(args, seed=None):
     # Accumulate plenty of data at the beginning of training.
     if cfg.policy.get('random_collect_size', 0) > 0:
         action_space = collector_env.action_space
-        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+        random_sample_func = get_random_sample_func(action_space)
+        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=random_sample_func)
         collector.reset_policy(random_policy)
         new_data = collector.collect(n_sample=cfg.policy.random_collect_size)
         replay_buffer.push(new_data, cur_collector_envstep=0)
@@ -113,8 +115,8 @@
     parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path')
     parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo')
     parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow")
-    parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector sumo env num for training')
-    parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator sumo env num for training')
+    parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector sumo env num for training')
+    parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator sumo env num for training')
     parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt')
     args = parser.parse_args()
 
diff --git a/smartcross/envs/cityflow_env.py b/smartcross/envs/cityflow_env.py
index 2936821..f9a22ec 100644
--- a/smartcross/envs/cityflow_env.py
+++ b/smartcross/envs/cityflow_env.py
@@ -215,6 +215,7 @@ def close(self) -> None:
     def seed(self, seed: int, dynamic_seed: bool = True) -> None:
         self._seed = seed
         self._dynamic_seed = dynamic_seed
+        self._eng.set_random_seed(seed)
 
     @property
     def observation_space(self) -> gym.spaces.Space:
diff --git a/smartcross/envs/crossing.py b/smartcross/envs/crossing.py
index 339f5fd..d0e4346 100644
--- a/smartcross/envs/crossing.py
+++ b/smartcross/envs/crossing.py
@@ -17,7 +17,7 @@ def __init__(self, tls_id: str, env: 'BaseEnv') -> None:
         self._lane_vehicle_dict = {}
         self._previous_lane_vehicle_dict = {}
 
-        signal_definition = traci.trafficlight.getCompleteRedYellowGreenDefinition(self._id)[0]
+        signal_definition = traci.trafficlight.getAllProgramLogics(self._id)[0]
         self._green_phases = []
         self._yellow_phases = []
         for idx, phase in enumerate(signal_definition.phases):
diff --git a/smartcross/envs/sumo_env.py b/smartcross/envs/sumo_env.py
index 6cf534b..edc2755 100644
--- a/smartcross/envs/sumo_env.py
+++ b/smartcross/envs/sumo_env.py
@@ -115,6 +115,8 @@ def reset(self) -> Any:
         self._action_runner.reset()
         self._obs_runner.reset()
         self._reward_runner.reset()
+        if self._launch_env_flag:
+            self.close()
         self._launch_env(self._gui)
         for tl in self._cfg.tls:
             self._crosses[tl] = Crossing(tl, self)
diff --git a/smartcross/policy/default_policy.py b/smartcross/policy/default_policy.py
index fa8bc73..c5c0e12 100644
--- a/smartcross/policy/default_policy.py
+++ b/smartcross/policy/default_policy.py
@@ -38,6 +38,19 @@ def default_config(cls: type) -> EasyDict:
         return cfg
 
 
+def get_random_sample_func(act_space):
+
+    def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
+        actions = {}
+        for env_id in data:
+            action = act_space.sample()
+            action = [torch.LongTensor([v]) for v in action]
+            actions[env_id] = {'action': action}
+        return actions
+
+    return _forward
+
+
 @POLICY_REGISTRY.register('smartcross_fix')
 class FixedPolicy():
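
Note on the get_random_sample_func hunks: the stock random-collect policy obtained via PolicyFactory.get_random_policy(..., action_space=action_space) samples raw actions from the space, which apparently does not match the action format the smartcross policies emit in collect mode, where each intersection's action is a one-element LongTensor in a list. The patch therefore supplies a custom forward_fn that performs the wrapping itself. The sketch below is illustrative only and not part of the patch: it assumes DI-smartcross is installed so the function can be imported, and the two-intersection MultiDiscrete space with four phases per light is a made-up example.

    # Illustrative sketch: exercising the forward function returned by
    # get_random_sample_func outside of the collector loop.
    import gym
    from smartcross.policy.default_policy import get_random_sample_func

    # Hypothetical action space: two traffic lights, four green phases each.
    action_space = gym.spaces.MultiDiscrete([4, 4])
    forward_fn = get_random_sample_func(action_space)

    # The serial collector calls the forward function with a dict keyed by
    # env_id; the observations are ignored since actions are sampled uniformly.
    actions = forward_fn({0: 'obs_from_env_0', 1: 'obs_from_env_1'})
    print(actions)
    # e.g. {0: {'action': [tensor([2]), tensor([1])]}, 1: {'action': [tensor([0]), tensor([3])]}}

Keeping the random policy on the same {'action': [...]} dict format as the learned policies means the transitions gathered during the initial random collection phase can presumably be pushed to the replay buffer and consumed by the DQN/PPO pipelines unchanged.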