Skip to content

Commit

Permalink
fix env num bug
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinC94 committed Apr 24, 2022
1 parent eec841d commit 516ade0
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion entry/cityflow_eval
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ if __name__ == "__main__":
parser.add_argument(
'-p', '--policy-type', default='dqn', choices=['fix', 'dqn', 'ppo'], help='RL policy type'
)
parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation')
parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation')
parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path')

args = parser.parse_args()
Expand Down
4 changes: 2 additions & 2 deletions entry/cityflow_train
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ if __name__ == "__main__":
parser.add_argument('-d', '--ding-cfg', required=True, help='DI-engine configuration path')
parser.add_argument('-e', '--env-cfg', required=True, help='cityflow json configuration path')
parser.add_argument('-s', '--seed', default=None, type=int, help='random seed')
parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector env num for training')
parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator env num for training')
parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector env num for training')
parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator env num for training')
parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt')

args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion entry/sumo_eval
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ if __name__ == "__main__":
'-p', '--policy-type', default='dqn', choices=['random', 'fix', 'dqn', 'rainbow', 'ppo'], help='RL policy type'
)
parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow")
parser.add_argument('-n', '--env-num', type=int, default=-1, help='sumo env num for evaluation')
parser.add_argument('-n', '--env-num', type=int, default=1, help='sumo env num for evaluation')
parser.add_argument('--gui', action='store_true', help="open gui for visualize")
parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path')

Expand Down
4 changes: 2 additions & 2 deletions entry/sumo_train
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ if __name__ == "__main__":
parser.add_argument('-e', '--env-cfg', required=True, help='sumo environment configuration path')
parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for sumo')
parser.add_argument('--dynamic-flow', action='store_true', help="use dynamic route flow")
parser.add_argument('-cn', '--collect-env-num', type=int, default=-1, help='collector sumo env num for training')
parser.add_argument('-en', '--evaluate-env-num', type=int, default=-1, help='evaluator sumo env num for training')
parser.add_argument('-cn', '--collect-env-num', type=int, default=1, help='collector sumo env num for training')
parser.add_argument('-en', '--evaluate-env-num', type=int, default=1, help='evaluator sumo env num for training')
parser.add_argument('--exp-name', type=str, default=None, help='experiment name to save log and ckpt')

args = parser.parse_args()
Expand Down
2 changes: 2 additions & 0 deletions smartcross/policy/default_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ def default_config(cls: type) -> EasyDict:


def get_random_sample_func(act_space):

def _forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
actions = {}
for env_id in data:
action = act_space.sample()
action = [torch.LongTensor([v]) for v in action]
actions[env_id] = {'action': action}
return actions

return _forward


Expand Down

0 comments on commit 516ade0

Please sign in to comment.