Merge pull request #4 from opendilab/dev-0.3.0
suit for 0.3.0
RobinC94 authored Apr 24, 2022
2 parents 360be75 + 516ade0 commit 6706b8d
Showing 8 changed files with 37 additions and 15 deletions.
entry/cityflow_eval (9 changes: 5 additions & 4 deletions)

@@ -14,14 +14,14 @@ from smartcross.utils.config_utils import read_ding_config
 from smartcross.policy import FixedPolicy


-def main(args, seed=0):
+def main(args, seed=None):
     ding_cfg = args.ding_cfg
     main_config, create_config = read_ding_config(ding_cfg)
     cityflow_env_config = {'config_path': args.env_cfg}
     if args.policy_type == 'fix':
         create_config.policy['type'] = 'smartcross_fix'
     main_config.env = deep_merge_dicts(main_config.env, cityflow_env_config)
-    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True)
+    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True, save_cfg=False)
     if args.env_num > 0:
         cfg.env.evaluator_env_num = args.env_num
         if cfg.env.n_evaluator_episode < args.env_num:
@@ -47,6 +47,7 @@ def main(args, seed=0):
         policy,
     )
     _, eval_reward = evaluator.eval(None, -1, -1, cfg.env.n_evaluator_episode)
+    eval_reward = [r['final_eval_reward'].item() for r in eval_reward]
     print('Eval is over! The performance is {}'.format(eval_reward))
     evaluator.close()
     return eval_reward
@@ -56,11 +57,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='DI-smartcross training script')
     parser.add_argument('-d', '--ding-cfg', default=None, help='DI-engine configuration path')
     parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path')
-    parser.add_argument('-s', '--seed', default=0, type=int, help='random seed for sumo')
+    parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo')
     parser.add_argument(
         '-p', '--policy-type', default='dqn', choices=['fix', 'dqn', 'ppo'], help='RL policy type'
     )
-    parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation')
+    parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation')
     parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path')

     args = parser.parse_args()
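In this evaluation entry, the seed default moves from 0 to None (a run is only seeded when a seed is passed explicitly), compile_config is told not to write the merged config back to disk (save_cfg=False), and the evaluator output is post-processed: as the new list comprehension implies, evaluator.eval here returns a list of per-episode info dicts holding a 0-dim tensor under 'final_eval_reward', which the script converts to plain Python numbers. A minimal sketch of that extraction (the info list and reward values are purely illustrative, not part of the PR):

    import torch

    # Hypothetical return value from evaluator.eval(): one info dict per evaluated episode.
    episode_infos = [{'final_eval_reward': torch.tensor(-152.0)}, {'final_eval_reward': torch.tensor(-147.5)}]
    eval_reward = [r['final_eval_reward'].item() for r in episode_infos]
    print(eval_reward)  # [-152.0, -147.5]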
entry/cityflow_train (8 changes: 5 additions & 3 deletions)

@@ -13,6 +13,7 @@ from ding.utils.default_helper import set_pkg_seed
 from ding.utils import deep_merge_dicts
 from ding.rl_utils import get_epsilon_greedy_fn
 from smartcross.utils.config_utils import read_ding_config
+from smartcross.policy.default_policy import get_random_sample_func


 def main(args, seed=None):
@@ -63,7 +64,8 @@ def main(args, seed=None):
     # Accumulate plenty of data at the beginning of training.
     if cfg.policy.get('random_collect_size', 0) > 0:
         action_space = collector_env.action_space
-        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+        random_sample_func = get_random_sample_func(action_space)
+        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=random_sample_func)
         collector.reset_policy(random_policy)
         new_data = collector.collect(n_sample=cfg.policy.random_collect_size)
         replay_buffer.push(new_data, cur_collector_envstep=0)
@@ -115,8 +117,8 @@ if __name__ == "__main__":
     parser.add_argument('-d', '--ding-cfg', required=True, help='DI-engine configuration path')
     parser.add_argument('-e', '--env-cfg', required=True, help='cityflow json configuration path')
     parser.add_argument('-s', '--seed', default=None, type=int, help='random seed')
-    parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector env num for training')
-    parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator env num for training')
+    parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector env num for training')
+    parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator env num for training')
    parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt')

     args = parser.parse_args()
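The training entries adapt the warm-up ("random collect") phase to DI-engine 0.3.0: instead of handing the environment's action space to PolicyFactory.get_random_policy, they build a sampling function with the new get_random_sample_func helper (defined in smartcross/policy/default_policy.py, shown further below) and pass it as forward_fn. A hedged sketch of the resulting wiring, assuming the usual DI-engine serial-pipeline objects (cfg, policy, collector, collector_env, replay_buffer) are already constructed:

    from ding.policy import PolicyFactory
    from smartcross.policy.default_policy import get_random_sample_func

    if cfg.policy.get('random_collect_size', 0) > 0:
        # Draw random actions shaped like the env's own action space instead of
        # letting the generic random policy guess the format.
        random_sample_func = get_random_sample_func(collector_env.action_space)
        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=random_sample_func)
        collector.reset_policy(random_policy)
        new_data = collector.collect(n_sample=cfg.policy.random_collect_size)
        replay_buffer.push(new_data, cur_collector_envstep=0)
        collector.reset_policy(policy.collect_mode)  # assumption: hand control back to the learned policy afterwards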
entry/sumo_eval (9 changes: 5 additions & 4 deletions)

@@ -15,7 +15,7 @@ from smartcross.utils.config_utils import get_sumo_config
 from smartcross.policy import RandomPolicy, FixedPolicy


-def main(args, seed=0):
+def main(args, seed=None):
     if args.policy_type in ['random', 'fix']:
         from entry.sumo_config.sumo_eval_default_config import main_config, create_config
         with open(args.env_cfg, 'r') as f:
@@ -27,7 +27,7 @@ def main(args, seed=0):
         main_config, create_config = get_sumo_config(args)
     if args.gui:
         main_config.env.gui = True
-    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True)
+    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True, save_cfg=False)
     if args.env_num > 0:
         cfg.env.evaluator_env_num = args.env_num
         if cfg.env.n_evaluator_episode < args.env_num:
@@ -54,6 +54,7 @@ def main(args, seed=0):
         policy,
     )
     _, eval_reward = evaluator.eval(None, -1, -1, cfg.env.n_evaluator_episode)
+    eval_reward = [r['final_eval_reward'].item() for r in eval_reward]
     print('Eval is over! The performance is {}'.format(eval_reward))
     evaluator.close()
     return eval_reward
@@ -63,12 +64,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='DI-smartcross training script')
     parser.add_argument('-d', '--ding-cfg', default=None, help='DI-engine configuration path')
     parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path')
-    parser.add_argument('-s', '--seed', default=0, type=int, help='random seed for sumo')
+    parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo')
     parser.add_argument(
         '-p', '--policy-type', default='dqn', choices=['random', 'fix', 'dqn', 'rainbow', 'ppo'], help='RL policy type'
     )
     parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow")
-    parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation')
+    parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation')
     parser.add_argument('--gui', action='store_true', help="open gui for visualize")
     parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path')
entry/sumo_train (8 changes: 5 additions & 3 deletions)

@@ -12,6 +12,7 @@ from ding.worker import BaseLearner, InteractionSerialEvaluator, create_serial_c
 from ding.utils.default_helper import set_pkg_seed
 from ding.rl_utils import get_epsilon_greedy_fn
 from smartcross.utils.config_utils import get_sumo_config
+from smartcross.policy.default_policy import get_random_sample_func


 def main(args, seed=None):
@@ -60,7 +61,8 @@ def main(args, seed=None):
     # Accumulate plenty of data at the beginning of training.
     if cfg.policy.get('random_collect_size', 0) > 0:
         action_space = collector_env.action_space
-        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+        random_sample_func = get_random_sample_func(action_space)
+        random_policy = PolicyFactory.get_random_policy(policy.collect_mode, forward_fn=random_sample_func)
         collector.reset_policy(random_policy)
         new_data = collector.collect(n_sample=cfg.policy.random_collect_size)
         replay_buffer.push(new_data, cur_collector_envstep=0)
@@ -113,8 +115,8 @@ if __name__ == "__main__":
     parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path')
     parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo')
     parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow")
-    parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector sumo env num for training')
-    parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator sumo env num for training')
+    parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector sumo env num for training')
+    parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator sumo env num for training')
     parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt')

     args = parser.parse_args()
smartcross/envs/cityflow_env.py (1 change: 1 addition & 0 deletions)

@@ -215,6 +215,7 @@ def close(self) -> None:
     def seed(self, seed: int, dynamic_seed: bool = True) -> None:
         self._seed = seed
         self._dynamic_seed = dynamic_seed
+        self._eng.set_random_seed(seed)

     @property
     def observation_space(self) -> gym.spaces.Space:
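With this addition, the environment seed is also forwarded to the CityFlow engine itself, so the simulator's stochastic traffic generation follows the env seed rather than only the Python-side RNGs. A small sketch of the underlying engine call, assuming self._eng is a cityflow.Engine (the config path and seed value are placeholders):

    import cityflow

    eng = cityflow.Engine('cityflow_config.json', thread_num=1)  # hypothetical config path
    eng.set_random_seed(0)  # reproduce the same vehicle flows across runs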
smartcross/envs/crossing.py (2 changes: 1 addition & 1 deletion)

@@ -17,7 +17,7 @@ def __init__(self, tls_id: str, env: 'BaseEnv') -> None:

         self._lane_vehicle_dict = {}
         self._previous_lane_vehicle_dict = {}
-        signal_definition = traci.trafficlight.getCompleteRedYellowGreenDefinition(self._id)[0]
+        signal_definition = traci.trafficlight.getAllProgramLogics(self._id)[0]
         self._green_phases = []
         self._yellow_phases = []
         for idx, phase in enumerate(signal_definition.phases):
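traci.trafficlight.getAllProgramLogics is the current TraCI name for this query; getCompleteRedYellowGreenDefinition is the older alias, which recent SUMO releases flag as deprecated. Both return the traffic-light program logics whose phases are enumerated above to separate green and yellow phases. A hedged sketch of reading a program's phases, assuming a live TraCI connection and a hypothetical traffic-light id 'tls0':

    import traci

    logic = traci.trafficlight.getAllProgramLogics('tls0')[0]  # first (default) program
    for idx, phase in enumerate(logic.phases):
        # phase.state is the signal string (e.g. 'GGrr'), phase.duration its length in seconds
        print(idx, phase.state, phase.duration)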
smartcross/envs/sumo_env.py (2 changes: 2 additions & 0 deletions)

@@ -115,6 +115,8 @@ def reset(self) -> Any:
         self._action_runner.reset()
         self._obs_runner.reset()
         self._reward_runner.reset()
+        if self._launch_env_flag:
+            self.close()
         self._launch_env(self._gui)
         for tl in self._cfg.tls:
             self._crosses[tl] = Crossing(tl, self)
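The added guard makes reset() tear down any SUMO/TraCI session that is still running before launching a fresh one, so repeated resets (for example between training or evaluation episodes) presumably no longer accumulate simulator processes or stale TraCI connections.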
smartcross/policy/default_policy.py (13 changes: 13 additions & 0 deletions)

@@ -38,6 +38,19 @@ def default_config(cls: type) -> EasyDict:
         return cfg


+def get_random_sample_func(act_space):
+
+    def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
+        actions = {}
+        for env_id in data:
+            action = act_space.sample()
+            action = [torch.LongTensor([v]) for v in action]
+            actions[env_id] = {'action': action}
+        return actions
+
+    return _forward
+
+
 @POLICY_REGISTRY.register('smartcross_fix')
 class FixedPolicy():
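get_random_sample_func closes over an action space and returns a forward function in the shape PolicyFactory.get_random_policy expects: a callable mapping {env_id: obs} to {env_id: {'action': ...}}, with each sampled action wrapped as a list of LongTensors. A minimal standalone sketch, assuming a MultiDiscrete action space of the kind the multi-intersection envs expose (the space, env ids, and dummy observations are illustrative):

    import gym
    from smartcross.policy.default_policy import get_random_sample_func

    act_space = gym.spaces.MultiDiscrete([4, 4])  # hypothetical: two crossings, four phases each
    sample_fn = get_random_sample_func(act_space)

    # The collector calls the forward function with a dict keyed by env_id; the
    # observations are ignored because actions are drawn uniformly at random.
    out = sample_fn({0: None, 1: None})
    print(out[0]['action'])  # e.g. [tensor([2]), tensor([0])]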
