-
Notifications
You must be signed in to change notification settings - Fork 0
/
policy.py
110 lines (86 loc) · 3.44 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from dsl import *
from env_settings import *
from lbforaging.foraging.environment import Action as ForagingActions
import numpy as np
class StateActionProgram(object):
    """
    A callable object with input (state, action, position) and Boolean output.

    Wraps a program string — a Python expression over ``s``, ``a`` and
    ``pos`` — so that it has a readable str/repr, pickles cleanly, and is
    compiled by ``eval`` only once per instance instead of on every call.
    """
    def __init__(self, program):
        # program: source text of a Python expression over (s, a, pos).
        self.program = program
        # Lazily-compiled lambda; deliberately excluded from pickling.
        self.wrapped = None
    def __call__(self, *args, **kwargs):
        if self.wrapped is None:
            # SECURITY NOTE: eval() executes arbitrary code. Program strings
            # must come from a trusted source (the DSL generator), never
            # from untrusted input.
            self.wrapped = eval('lambda s, a, pos: ' + self.program)
        return self.wrapped(*args, **kwargs)
    def __repr__(self):
        return self.program
    def __str__(self):
        return self.program
    def __getstate__(self):
        # Pickle only the source text; the compiled lambda cannot be
        # pickled and is rebuilt lazily after unpickling.
        return self.program
    def __setstate__(self, program):
        self.program = program
        self.wrapped = None
    def __add__(self, s):
        """Concatenate program text on the right; accepts str or StateActionProgram."""
        if isinstance(s, str):
            return StateActionProgram(self.program + s)
        elif isinstance(s, StateActionProgram):
            return StateActionProgram(self.program + s.program)
        # Was a bare ``raise Exception()``. TypeError is the conventional
        # error for unsupported operand types and is a subclass of
        # Exception, so existing handlers still catch it.
        raise TypeError('cannot concatenate StateActionProgram with %r' % type(s))
    def __radd__(self, s):
        """Concatenate program text on the left, e.g. ``'not ' + prog``."""
        if isinstance(s, str):
            return StateActionProgram(s + self.program)
        elif isinstance(s, StateActionProgram):
            return StateActionProgram(s.program + self.program)
        raise TypeError('cannot concatenate %r with StateActionProgram' % type(s))
class PLPPolicy(object):
    """
    A policy defined by a weighted mixture of PLPs (programmatic policies).

    Each PLP is a callable ``plp(obs, action_value, pos) -> bool``; its
    probability mass is spread over every action it approves for the
    current observation. Per-observation action distributions are cached.
    """
    def __init__(self, plps, probs, seed=0, map_choices=True):
        # NOTE(review): probs are assumed to (approximately) sum to 1;
        # the original assertion was left disabled:
        #assert abs(np.sum(probs) - 1.) < 1e-5
        self.plps = plps
        self.probs = probs
        # If True, act greedily (argmax); otherwise sample from the mixture.
        self.map_choices = map_choices
        self.rng = np.random.RandomState(seed)
        # Memoizes get_action_probs keyed on the hashed observation.
        # NOTE(review): unbounded; fine for episodic evaluation, but a
        # long-running process may want an LRU bound here.
        self._action_prob_cache = {}
    def __call__(self, obs, pos, vis=False):
        """Return the index of the chosen action for this observation."""
        action_probs = self.get_action_probs(obs, pos, vis).flatten()
        if self.map_choices:
            # Greedy (MAP) choice; ties break toward the lowest index.
            return np.argmax(action_probs)
        # Sample an action according to the mixture distribution.
        return self.rng.choice(len(action_probs), p=action_probs)
    def hash_obs(self, full_obs):
        """Convert a nested (iterable-of-iterables) observation into a hashable tuple."""
        return tuple(tuple(tuple(l) for l in obs) for obs in full_obs)
    def get_action_probs(self, obs, pos, vis):
        """Return a normalized probability vector over all foraging actions."""
        hashed_obs = self.hash_obs(obs)
        if hashed_obs in self._action_prob_cache:
            return self._action_prob_cache[hashed_obs]
        action_probs = np.zeros(len(ForagingActions), dtype=np.float32)
        for plp, prob in zip(self.plps, self.probs):
            for action in self.get_plp_suggestions(plp, obs, pos):
                action_probs[action] += prob
        denom = np.sum(action_probs)
        if denom == 0.:
            # No PLP fired: fall back to a uniform distribution.
            action_probs += 1. / len(action_probs)
        else:
            action_probs = action_probs / denom
        self._action_prob_cache[hashed_obs] = action_probs
        return action_probs
    def get_plp_suggestions(self, plp, obs, pos):
        """Return the list of action values this PLP approves of."""
        return [action.value for action in ForagingActions
                if plp(obs, action.value, pos)]