-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRmCooDisEnv.py
82 lines (61 loc) · 2.17 KB
/
RmCooDisEnv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
import numpy as np
import gym
import random
import os
import time
class RmCooDisEnv():
    """Discrete-action wrapper around a Unity build driven through ``Unity3DEnv``.

    Observations and rewards are returned as plain lists (one entry per agent,
    in the order the wrapped env's dicts iterate). Each agent takes one flat
    discrete action in ``[0, act_size)``, which ``step`` splits into two
    sub-actions ``[a // 5, a % 5]`` before forwarding it to Unity.
    ``get_env_info`` reports sizes as a dict with keys ``n_actions``,
    ``n_agents_per_party``, ``state_shape``, ``obs_shape`` and ``episode_limit``.
    """

    # Size of the second action branch used to split a flat action a into
    # [a // _BRANCH_SIZE, a % _BRANCH_SIZE].
    # NOTE(review): act_size (10) == 2 * 5, so presumably the Unity behavior
    # expects a 2 x 5 multi-discrete action -- confirm against the build.
    _BRANCH_SIZE = 5

    def __init__(self, seed=1234, port=10000, no_graphics=True, time_scale=100):
        """Launch the Unity build and record the agent ids it reports.

        Args:
            seed: base seed. The port is added to it (presumably so instances
                on different ports get distinct seeds).
            port: port passed to ``Unity3DEnv``.
            no_graphics: forwarded to ``Unity3DEnv``.
            time_scale: forwarded to ``Unity3DEnv``.
        """
        self.port = port
        self.seed = seed + port
        self.env = Unity3DEnv('coo_c_vs_a/coo_c_vs_a.x86_64',
                              seed=self.seed, port=self.port,
                              no_graphics=no_graphics, time_scale=time_scale)
        self.obs_size = 21
        self.act_size = 10
        # self.group_name = list(self.env.unity_env.behavior_specs.keys())
        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self.obs_size,))
        self.action_space = gym.spaces.Discrete(self.act_size)
        self.state = None
        # Reset once so the agent ids are known before the first step(). Keep
        # the observations too, so get_obs()/get_state() work before reset().
        obs_dict = self.env.reset(None)
        self.agent_name = list(obs_dict.keys())
        self.obs = list(obs_dict.values())

    def reset(self, dy_para_dict=None):
        """Reset the wrapped env and return the per-agent observations.

        Args:
            dy_para_dict: forwarded verbatim to ``self.env.reset``; ``None`` is
                what ``__init__`` passes.
                NOTE(review): the wrapped ``reset`` must accept this positional
                argument (presumably a customized ``Unity3DEnv``) -- confirm.

        Returns:
            list of observations, in the iteration order of the env's dict.
        """
        obs_dict = self.env.reset(dy_para_dict)
        # Re-read the agent ids so step() addresses the agents present after
        # this reset, in the same order as the observations returned here.
        self.agent_name = list(obs_dict.keys())
        self.obs = list(obs_dict.values())
        return self.obs

    def step(self, action):
        """Send one flat discrete action per agent and advance the env.

        Args:
            action: indexable sequence; ``action[i]`` is the flat action for
                ``self.agent_name[i]`` and is sent as
                ``np.array([action[i] // 5, action[i] % 5])``.

        Returns:
            ``(obs, rewards, dones, infos)``. ``obs`` and ``rewards`` are lists
            built from the env's returned dicts in their iteration order;
            ``dones`` and ``infos`` are passed through unchanged.
            NOTE(review): reward order follows the env's dict, not
            ``self.agent_name`` -- verify they agree if callers rely on it.
        """
        branch = self._BRANCH_SIZE
        act_dict = {
            name: np.array([action[i] // branch, action[i] % branch])
            for i, name in enumerate(self.agent_name)
        }
        obs_dict, rewards_dict, dones, infos = self.env.step(act_dict)
        obs = list(obs_dict.values())
        rewards = list(rewards_dict.values())
        self.obs = obs
        return obs, rewards, dones, infos

    def get_obs(self):
        """Return the observation list from the latest reset/step."""
        return self.obs

    def get_state(self):
        """Return a one-element list holding the first agent's latest observation."""
        self.state = [self.obs[0]]
        return self.state

    def get_env_info(self):
        """Return the environment size description used by the training code."""
        return {
            "n_actions": self.act_size,
            "n_agents_per_party": 2,
            # get_state() returns a single agent observation, so the state
            # size equals the observation size.
            "state_shape": self.obs_size,
            "obs_shape": self.obs_size,
            "episode_limit": 200,
        }

    def close(self):
        """Close the wrapped Unity environment."""
        self.env.close()
# Example usage (needs the Unity build at coo_c_vs_a/coo_c_vs_a.x86_64):
# env = RmCooDisEnv()
# dy_para_dict = {"VK1": 0.0375, "level": 0}
# obs = env.reset(dy_para_dict)
# print(obs)