DQN.py
from typing import Dict, Tuple

import numpy as np
import torch
from torch.optim import RMSprop

from models import DuelingNoisyQNetwork
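
# Note (not in the original file): `DuelingNoisyQNetwork` from `models` is
# assumed to map a (batch, 1, H, W) float tensor to (batch, n_actions)
# Q-values and to expose a `reset_noise()` method; both are inferred from how
# the class is used below.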


class ReplayBuffer:
    """Fixed-size circular buffer of (state, action, reward, next_state, done) transitions."""

    def __init__(self, buffer_size: int, batch_size: int,
                 state_shape: Tuple[int, int]) -> None:
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        # States are stored channel-first: (buffer_size, 1, H, W).
        self.state_memory = np.empty((buffer_size, 1, *state_shape),
                                     dtype=np.float32)
        self.action_memory = np.empty(buffer_size, dtype=np.int64)
        self.reward_memory = np.empty(buffer_size, dtype=np.float32)
        self.next_state_memory = np.empty((buffer_size, 1, *state_shape),
                                          dtype=np.float32)
        self.is_done_memory = np.empty(buffer_size, dtype=np.bool_)
        self.memory_ptr = 0
        self.cur_size = 0

    def store(self, state: np.ndarray, action: int, reward: float,
              next_state: np.ndarray, is_done: bool) -> None:
        # Write at the current pointer, overwriting the oldest transition
        # once the buffer is full.
        self.state_memory[self.memory_ptr] = np.expand_dims(state, 0)
        self.action_memory[self.memory_ptr] = action
        self.reward_memory[self.memory_ptr] = reward
        self.next_state_memory[self.memory_ptr] = np.expand_dims(next_state, 0)
        self.is_done_memory[self.memory_ptr] = is_done
        self.memory_ptr = (self.memory_ptr + 1) % self.buffer_size
        self.cur_size = min(self.cur_size + 1, self.buffer_size)

    def sample(self) -> Dict[str, torch.Tensor]:
        # Sample a batch of distinct indices uniformly from the filled part
        # of the buffer.
        selected_idxs = np.random.choice(self.cur_size, self.batch_size,
                                         replace=False)
        return {
            "states": torch.from_numpy(self.state_memory[selected_idxs]),
            "actions": torch.from_numpy(self.action_memory[selected_idxs]),
            "rewards": torch.from_numpy(self.reward_memory[selected_idxs]),
            "next_states": torch.from_numpy(
                self.next_state_memory[selected_idxs]),
            "is_dones": torch.from_numpy(self.is_done_memory[selected_idxs])
        }
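
# A minimal usage sketch of the buffer on its own (not part of the original
# file), assuming a 4x4 float32 state grid; the sizes below are illustrative:
#
#     buffer = ReplayBuffer(buffer_size=10_000, batch_size=32,
#                           state_shape=(4, 4))
#     s = np.zeros((4, 4), dtype=np.float32)
#     buffer.store(s, action=0, reward=1.0, next_state=s, is_done=False)
#     if buffer.cur_size >= buffer.batch_size:
#         batch = buffer.sample()  # dict of torch tensors, e.g. batch["states"]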


class Agent:
    """Double DQN agent with a dueling noisy Q-network, masked epsilon-greedy
    action selection, and soft (Polyak) target-network updates."""

    def __init__(self, state_shape: Tuple[int, int], n_actions: int,
                 buffer_size: int, batch_size: int, alpha: float, gamma: float,
                 max_epsilon: float, min_epsilon: float, epsilon_step: int,
                 tau: float) -> None:
        assert 0. < alpha < 1.
        assert 0. < gamma < 1.
        assert 0. <= max_epsilon <= 1.
        assert 0. <= min_epsilon <= 1.
        assert 0. < tau < 1.
        assert 0 < epsilon_step
        # Online network and a target network initialised to the same weights.
        self.network = DuelingNoisyQNetwork(n_actions)
        self.target_network = DuelingNoisyQNetwork(n_actions)
        self.target_network.load_state_dict(self.network.state_dict())
        # alpha is used as the RMSprop learning rate.
        self.network_optimizer = RMSprop(self.network.parameters(), alpha)
        self.replay_buffer = ReplayBuffer(buffer_size, batch_size, state_shape)
        self.gamma = gamma
        self.epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        # Linear decay from max_epsilon to min_epsilon over epsilon_step calls.
        self.epsilon_decay_rate = (max_epsilon - min_epsilon) / epsilon_step
        self.tau = tau

    def choose_action(self, state: np.ndarray, action_mask: np.ndarray) -> int:
        # Epsilon-greedy selection restricted to the actions allowed by
        # action_mask: greedy over masked Q-values, otherwise a uniformly
        # random valid action.
        valid_actions = []
        state_action_values = self.network(torch.from_numpy(np.expand_dims(
            np.expand_dims(state, 0), 0)))[0].detach()
        max_ = -np.inf
        best_action = 0
        for i, state_action_value in enumerate(state_action_values):
            if (state_action_value > max_) and action_mask[i]:
                max_ = state_action_value
                best_action = i
            if action_mask[i]:
                valid_actions.append(i)
        return np.random.choice(valid_actions) if \
            np.random.uniform(0., 1.) <= self.epsilon else best_action
    def update(self, state: np.ndarray, action: int, reward: float,
               next_state: np.ndarray, is_done: bool) -> float:
        self.replay_buffer.store(state, action, reward, next_state, is_done)
        if self.replay_buffer.cur_size >= self.replay_buffer.batch_size:
            self.network_optimizer.zero_grad()
            data = self.replay_buffer.sample()
            states, actions, rewards, next_states, is_dones = (
                data["states"], data["actions"], data["rewards"],
                data["next_states"], data["is_dones"])
            # Q(s, a) of the online network for the actions actually taken.
            states_action_values = self.network(states).gather(
                1, actions.unsqueeze(0).T)
            # Double DQN target: the online network picks the next action,
            # the target network evaluates it; terminal transitions contribute
            # no bootstrap term.
            loss = torch.nn.functional.mse_loss(
                states_action_values, rewards.view(-1, 1) + self.gamma *
                self.target_network(next_states).gather(
                    1, self.network(next_states).argmax(1).unsqueeze(0).T
                ).detach()
                * ~is_dones.view(-1, 1))
            loss.backward()
            self.network_optimizer.step()
            # Resample the noisy layers and soft-update the target network.
            self.network.reset_noise()
            self.target_network.reset_noise()
            self.update_target_network()
            return loss.item()
        return 0.
    def decay_epsilon(self) -> None:
        self.epsilon = max(self.epsilon - self.epsilon_decay_rate,
                           self.min_epsilon)

    def update_target_network(self) -> None:
        # Polyak (soft) update: target <- tau * online + (1 - tau) * target.
        for target_param, param in zip(self.target_network.parameters(),
                                       self.network.parameters()):
            target_param.data.copy_(self.tau * param + (1. - self.tau) *
                                    target_param)

    def save(self, path_model1: str, path_model2: str) -> None:
        with open(path_model1, "wb") as f:
            torch.save(self.network.state_dict(), f)
        with open(path_model2, "wb") as f:
            torch.save(self.target_network.state_dict(), f)

    def load(self, path_model1: str, path_model2: str) -> None:
        with open(path_model1, "rb") as f:
            self.network.load_state_dict(torch.load(f))
        with open(path_model2, "rb") as f:
            self.target_network.load_state_dict(torch.load(f))
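
# The sketch below is not part of the original file: it only illustrates how
# the Agent is typically driven, assuming a Gym-style environment exposing
# `reset()`/`step()` and a boolean "action_mask" in its info dict. Every name
# in it (`make_env`, the shapes, the hyperparameters, the file names) is
# hypothetical.
#
#     env = make_env()  # hypothetical environment constructor
#     agent = Agent(state_shape=(8, 8), n_actions=env.action_space.n,
#                   buffer_size=100_000, batch_size=64, alpha=1e-4,
#                   gamma=0.99, max_epsilon=1.0, min_epsilon=0.05,
#                   epsilon_step=50_000, tau=0.005)
#     state, info = env.reset()
#     for _ in range(200_000):
#         action = agent.choose_action(state, info["action_mask"])
#         next_state, reward, terminated, truncated, info = env.step(action)
#         agent.update(state, action, reward, next_state, terminated)
#         agent.decay_epsilon()
#         if terminated or truncated:
#             state, info = env.reset()
#         else:
#             state = next_state
#     agent.save("dqn_online.pt", "dqn_target.pt")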