-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodels.py
40 lines (28 loc) · 1.07 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch as T
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
import os
class Network(nn.Module):
    """Fully connected network with one shared hidden layer and two heads.

    ``V`` is a 1-unit linear head and ``A`` is an ``n_actions``-unit linear
    head, both fed by the same ReLU hidden layer. The names suggest the
    value / advantage streams of a dueling DQN; this class returns them
    separately, so combining them is presumably the caller's job
    (TODO confirm).

    Args:
        lr: Learning rate for the Adam optimizer stored as ``self.optimizer``.
        n_actions: Output width of the ``A`` head.
        name: Checkpoint file name inside ``chkpt_dir``.
        input_dims: Sequence unpacked into the first ``nn.Linear``. It should
            be a 1-element sequence such as ``(obs_size,)``; longer sequences
            would be passed as extra ``nn.Linear`` arguments.
        chkpt_dir: Directory checkpoints are written to and read from.
        fc1_dims: Width of the shared hidden layer. Defaults to 512, the
            previously hard-coded value.
    """

    def __init__(self, lr, n_actions, name, input_dims, chkpt_dir,
                 fc1_dims=512):
        super().__init__()
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)

        self.fc1 = nn.Linear(*input_dims, fc1_dims)
        self.V = nn.Linear(fc1_dims, 1)
        self.A = nn.Linear(fc1_dims, n_actions)

        # The optimizer covers the parameters registered above.
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        # Not used inside this class; exposed for the training code.
        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        """Return the ``(V, A)`` head outputs for ``state``.

        ``state`` is assumed to be a float tensor whose last dimension is
        ``input_dims[0]`` and which lives on ``self.device``. TODO: confirm
        with the caller.
        """
        # The file-level ``F`` is ``torch.functional``, which has no ``relu``;
        # use ``torch.nn.functional.relu`` via the already-imported ``nn``.
        hidden = nn.functional.relu(self.fc1(state))
        return self.V(hidden), self.A(hidden)

    def save_checkpoint(self):
        """Save this module's ``state_dict`` to ``self.checkpoint_file``."""
        print('... saving checkpoint ...')
        # T.save does not create missing directories. An empty
        # checkpoint_dir means "current directory", so skip makedirs then.
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load a ``state_dict`` from ``self.checkpoint_file`` into this module."""
        print('... loading checkpoint ...')
        # map_location remaps tensors onto this network's device, so a
        # checkpoint saved on CUDA can still be loaded on a CPU-only host.
        state = T.load(self.checkpoint_file, map_location=self.device)
        self.load_state_dict(state)