# ResidualDynamics_DL_minibatch.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
# Use double precision throughout, to match the float64 data loaded below
torch.set_default_dtype(torch.float64)

# Define the neural network architectures
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(6, 20)  # 6 input features -> 20 hidden units
        self.fc2 = nn.Linear(20, 20)  # 20 hidden units -> 20 hidden units
        self.fc3 = nn.Linear(20, 20)  # 20 hidden units -> 20 hidden units
        self.fc4 = nn.Linear(20, 3)  # 20 hidden units -> 3 output units

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # ReLU activation for the first layer
        x = torch.relu(self.fc2(x))  # ReLU activation for the second layer
        x = torch.relu(self.fc3(x))  # ReLU activation for the third layer
        x = self.fc4(x)  # Final output layer, no activation
        return x
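
# Quick shape sanity check (a sketch on dummy data, not part of training):
# a batch of 4 six-dimensional inputs should map to 4 three-dimensional outputs.
with torch.no_grad():
    assert SimpleNN()(torch.zeros(4, 6)).shape == (4, 3)
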
class SimpleNNWithLSTM(nn.Module):
    def __init__(self):
        super(SimpleNNWithLSTM, self).__init__()
        self.fc1 = nn.Linear(6, 10)  # 6 input features -> 10 hidden units
        self.lstm = nn.LSTM(
            10, 10, batch_first=True
        )  # LSTM layer with input size 10 and hidden size 10
        self.fc2 = nn.Linear(10, 10)  # 10 hidden units -> 10 hidden units
        self.fc3 = nn.Linear(10, 3)  # 10 hidden units -> 3 output units

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # ReLU activation for the first layer
        # Treat each sample as a length-1 sequence: (batch, 10) -> (batch, 1, 10),
        # as expected with batch_first=True; then drop the sequence dimension again
        x, _ = self.lstm(x.unsqueeze(1))
        x = torch.relu(x.squeeze(1))  # ReLU on the LSTM output
        x = torch.relu(self.fc2(x))  # ReLU activation for the second-to-last layer
        x = self.fc3(x)  # Final output layer, no activation
        return x
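
# Same sanity check for the LSTM variant (sketch): the per-sample length-1
# sequence handling keeps input/output shapes identical to SimpleNN's.
with torch.no_grad():
    assert SimpleNNWithLSTM()(torch.zeros(4, 6)).shape == (4, 3)
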
# Create an instance of the neural network
model = SimpleNN()
# model = SimpleNNWithLSTM() #TODO: not working!
# Define loss function and optimizer
criterion = nn.MSELoss() # Mean Squared Error loss
optimizer = optim.Adam(model.parameters(), lr=0.0001) # Adam optimizer
# Load the filtered EKF state data for the CR3BP
# (rows 0-5: state features, remaining rows: residual targets; columns: samples)
data = np.load("Project/filtered_state_EKF_CR3BP.npy")
# Shuffle the sample order in place before splitting: shuffling the rows of the
# transposed view permutes the columns of data, keeping each sample intact
np.random.shuffle(data.T)
inputs = torch.tensor(data[:6, :]).t()
targets = torch.tensor(data[6:, :]).t()
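
# Sanity check on the split (sketch; assumes 6 state features and 3 residual
# targets per sample, matching the network's input and output sizes):
assert inputs.shape[1] == 6 and targets.shape[1] == 3
assert inputs.shape[0] == targets.shape[0]
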
# Lists to store training loss for plotting
train_loss_history = []
# Lists to store prediction errors for verification plot
prediction_errors = []
# Convert inputs and targets to PyTorch Dataset
dataset = TensorDataset(inputs, targets)
# Create DataLoader for batch processing
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
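# Note: with shuffle=True, each full pass over the DataLoader visits every sample
# once in a fresh random order; the loop below draws one mini-batch per optimizer
# step and restarts the loader whenever it is exhausted.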
# Training loop: one optimizer step per iteration, each on a single random mini-batch
num_epochs = 50000
batch_iterator = iter(dataloader)
for epoch in range(num_epochs):
    # Draw the next mini-batch; once the DataLoader is exhausted,
    # restart it so the data are reshuffled for the next pass
    try:
        batch_inputs, batch_targets = next(batch_iterator)
    except StopIteration:
        batch_iterator = iter(dataloader)
        batch_inputs, batch_targets = next(batch_iterator)
    # Forward pass
    outputs = model(batch_inputs)
    loss = criterion(outputs, batch_targets)
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Calculate loss and prediction error over the entire dataset after each step
    with torch.no_grad():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        train_loss_history.append(loss.item())
        prediction_error = torch.abs(outputs - targets).mean().item()
        prediction_errors.append(prediction_error)
    if (epoch + 1) % 100 == 0:
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.12f}, Prediction Error: {prediction_error:.12f}"
        )
# Optionally, save the trained model
torch.save(model.state_dict(), "Project/simple_nn_model.pth")
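
# To reuse the trained weights later (sketch; assumes the same SimpleNN
# architecture is defined at load time):
# restored = SimpleNN()
# restored.load_state_dict(torch.load("Project/simple_nn_model.pth"))
# restored.eval()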
# Plot the training loss
plt.figure()
# plt.rc("text", usetex=True)
plt.semilogy(train_loss_history, color="blue")
plt.xlabel(r"Training Epoch [-]")
plt.ylabel(r"Loss [-]")
plt.grid(True, which="both", linestyle="--")
# plt.savefig("Project/TrainingLoss.pdf", format="pdf")
plt.show()
# Plot the prediction errors for verification
plt.figure()
plt.semilogy(prediction_errors, color="red")
plt.xlabel(r"Training Epoch [-]")
plt.ylabel(r"Mean Prediction Error")
plt.grid(True, which="both", linestyle="--")
# plt.savefig("Project/PredictionError.pdf", format="pdf")
plt.show()