From 516ade0c47612674b04a846e88ab6d673c6f0a4c Mon Sep 17 00:00:00 2001 From: robinc94 Date: Sun, 24 Apr 2022 16:02:50 +0800 Subject: [PATCH] fix env num bug --- entry/cityflow_eval | 2 +- entry/cityflow_train | 4 ++-- entry/sumo_eval | 2 +- entry/sumo_train | 4 ++-- smartcross/policy/default_policy.py | 2 ++ 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/entry/cityflow_eval b/entry/cityflow_eval index 865e8a3..360c505 100644 --- a/entry/cityflow_eval +++ b/entry/cityflow_eval @@ -61,7 +61,7 @@ if __name__ == "__main__": parser.add_argument( '-p', '--policy-type', default='dqn', choices=['fix', 'dqn', 'ppo'], help='RL policy type' ) - parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation') + parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation') parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path') args = parser.parse_args() diff --git a/entry/cityflow_train b/entry/cityflow_train index 1c6b266..8861032 100644 --- a/entry/cityflow_train +++ b/entry/cityflow_train @@ -117,8 +117,8 @@ if __name__ == "__main__": parser.add_argument('-d', '--ding-cfg', required=True, help='DI-engine configuration path') parser.add_argument('-e', '--env-cfg', required=True, help='cityflow json configuration path') parser.add_argument('-s', '--seed', default=None, type=int, help='random seed') - parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector env num for training') - parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator env num for training') + parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector env num for training') + parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator env num for training') parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt') args = parser.parse_args() diff --git a/entry/sumo_eval b/entry/sumo_eval index 0afa3a3..dfa3f34 100644 --- a/entry/sumo_eval +++ b/entry/sumo_eval @@ -69,7 +69,7 @@ if __name__ == "__main__": '-p', '--policy-type', default='dqn', choices=['random', 'fix', 'dqn', 'rainbow', 'ppo'], help='RL policy type' ) parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow") - parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation') + parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation') parser.add_argument('--gui', action='store_true', help="open gui for visualize") parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path') diff --git a/entry/sumo_train b/entry/sumo_train index 7f99b06..b32eb96 100644 --- a/entry/sumo_train +++ b/entry/sumo_train @@ -115,8 +115,8 @@ if __name__ == "__main__": parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path') parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo') parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow") - parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector sumo env num for training') - parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator sumo env num for training') + parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector sumo env num for training') + parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator sumo env num for training') parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt') args = parser.parse_args() diff --git a/smartcross/policy/default_policy.py b/smartcross/policy/default_policy.py index 1c18280..c5c0e12 100644 --- a/smartcross/policy/default_policy.py +++ b/smartcross/policy/default_policy.py @@ -39,6 +39,7 @@ def default_config(cls: type) -> EasyDict: def get_random_sample_func(act_space): + def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]: actions = {} for env_id in data: @@ -46,6 +47,7 @@ def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]: action = [torch.LongTensor([v]) for v in action] actions[env_id] = {'action': action} return actions + return _forward