robust_value_iteration.py (forked from eleurent/rl-agents)
import numpy as np

from rl_agents.agents.dynamic_programming.value_iteration import ValueIterationAgent


class RobustValueIterationAgent(ValueIterationAgent):
    def __init__(self, env, config=None):
        # Skip ValueIterationAgent.__init__ and call the base agent constructor directly;
        # this agent reads its ensemble of models from the configuration below.
        super(ValueIterationAgent, self).__init__(config)
        self.env = env
        self.mode = None
        self.transitions = np.array([])  # Dimension: M x S x A (x S)
        self.rewards = np.array([])  # Dimension: M x S x A
        self.models_from_config()

    @classmethod
    def default_config(cls):
        config = super(RobustValueIterationAgent, cls).default_config()
        config.update(dict(models=[]))
        return config

    def models_from_config(self):
        if not self.config.get("models", None):
            raise ValueError("No finite MDP model provided in agent configuration")
        self.mode = self.config["models"][0]["mode"]  # Assume all modes are the same
        self.transitions = np.array([mdp["transition"] for mdp in self.config["models"]])
        self.rewards = np.array([mdp["reward"] for mdp in self.config["models"]])
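
    # Sketch of one entry expected in config["models"], as implied by the code above
    # (illustrative, not documented in the original file):
    #   dict(mode="deterministic",   # or "stochastic"
    #        transition=...,         # S x A next-state indices, or S x A x S probabilities
    #        reward=...)             # S x A immediate rewards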

    def act(self, state):
        return np.argmax(self.get_state_action_value()[state, :])

    def get_state_value(self):
        return self.fixed_point_iteration(
            lambda v: RobustValueIterationAgent.best_action_value(
                RobustValueIterationAgent.worst_case(
                    self.bellman_expectation(v))),
            np.zeros((self.transitions.shape[1],)))

    def get_state_action_value(self):
        return self.fixed_point_iteration(
            lambda q: RobustValueIterationAgent.worst_case(
                self.bellman_expectation(
                    RobustValueIterationAgent.best_action_value(q))),
            np.zeros(self.transitions.shape[1:3]))
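
    # The operator iterated above is a robust Bellman backup (a sketch matching the
    # code, not a comment from the original file):
    #   Q(s, a) <- min over models m of [ R_m(s, a) + gamma * E_{s' ~ P_m(.|s, a)} max_a' Q(s', a') ]
    # so each sweep evaluates the backup under the worst of the M candidate models.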

    @staticmethod
    def worst_case(model_action_values):
        return np.min(model_action_values, axis=0)
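
    # Illustration (not in the original file): for two models with state-action values
    #   model 0 -> [[1.0, 2.0]]   and   model 1 -> [[0.5, 3.0]],
    # np.min(..., axis=0) keeps the smaller value per state-action pair, giving
    # [[0.5, 2.0]]: the pessimistic value across models.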

    def bellman_expectation(self, value):
        if self.mode == "deterministic":
            # Transitions store the index of the next state: look up its value directly.
            next_v = value[self.transitions]
        elif self.mode == "stochastic":
            # Transitions store a distribution over next states: take the expectation.
            v_shaped = value.reshape((1, 1, 1, np.size(value)))
            next_v = (self.transitions * v_shaped).sum(axis=-1)
        else:
            raise ValueError("Unknown mode")
        return self.rewards + self.config["gamma"] * next_v

    def record(self, state, action, reward, next_state, done, info):
        pass

    def reset(self):
        pass

    def seed(self, seed=None):
        pass

    def save(self, filename):
        return False

    def load(self, filename):
        return False
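

# A minimal usage sketch, not part of the original file. It assumes the parent
# ValueIterationAgent from the upstream eleurent/rl-agents repository provides
# fixed_point_iteration, best_action_value and defaults for the remaining config keys,
# and it builds a toy ensemble of two 2-state, 2-action deterministic MDPs. The env
# argument is not used by the planner itself, so None is passed here for illustration.
if __name__ == "__main__":
    toy_models = [
        dict(mode="deterministic",
             transition=[[1, 0],      # S x A: index of the next state
                         [0, 1]],
             reward=[[0.0, 1.0],      # S x A: immediate reward
                     [1.0, 0.0]]),
        dict(mode="deterministic",
             transition=[[0, 1],
                         [1, 0]],
             reward=[[1.0, 0.0],
                     [0.0, 1.0]]),
    ]
    agent = RobustValueIterationAgent(env=None, config=dict(models=toy_models, gamma=0.9))
    # The chosen action maximises the robust (worst-case over the two models) Q-value.
    print(agent.act(0))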