Q_learning.py
# Imports:
# --------
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


# Function 1: Train Q-learning agent
# -----------
def train_q_learning(env,
                     no_episodes,
                     epsilon,
                     epsilon_min,
                     epsilon_decay,
                     alpha,
                     gamma,
                     q_table_save_path="q_table.npy"):

    # Initialize the Q-table:
    # -----------------------
    q_table = np.zeros((env.grid_size, env.grid_size, env.action_space.n))
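
    # NOTE (assumption): this file only defines the agent-side functions; `env`
    # is expected to be a custom grid-world defined elsewhere that exposes
    # `grid_size`, a Gym-style discrete `action_space`, and
    # `reset`/`step`/`render`/`close` methods, as used below.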

    # Q-learning algorithm:
    # ---------------------
    #! Step 1: Run the algorithm for a fixed number of episodes
    #! -------
    for episode in range(no_episodes):
        state, _ = env.reset()
        state = tuple(state)
        total_reward = 0

        #! Step 2: Take actions in the environment until the "done" flag is raised
        #! -------
        while True:
            #! Step 3: Exploration vs. exploitation (epsilon-greedy)
            #! -------
            if np.random.rand() < epsilon:
                action = env.action_space.sample()  # Explore
            else:
                action = np.argmax(q_table[state])  # Exploit

            next_state, reward, done, _ = env.step(action)
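            # NOTE (assumption): `env.step` is taken to follow the classic 4-tuple
            # Gym API (obs, reward, done, info); a Gymnasium-style environment
            # would return (obs, reward, terminated, truncated, info) instead.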
            env.render()

            next_state = tuple(next_state)
            total_reward += reward

            #! Step 4: Update the Q-values using the Q-value update rule
            #! -------
            q_table[state][action] = q_table[state][action] + alpha * \
                (reward + gamma * np.max(q_table[next_state]) - q_table[state][action])
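            # Standard temporal-difference (Q-learning) update:
            #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
            # i.e. nudge Q(s, a) by a step size alpha toward the bootstrapped
            # target r + gamma * max_a' Q(s', a').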

            state = next_state

            #! Step 5: End the episode once the agent reaches the goal or a hell state
            #! -------
            if done:
                break

        #! Step 6: Perform epsilon decay
        #! -------
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
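        # Multiplicative decay: exploration shrinks geometrically per episode,
        # clipped so it never falls below epsilon_min.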
print(f"Episode {episode + 1}: Total Reward: {total_reward}")
#! Step 7: Close the environment window
#! -------
env.close()
print("Training finished.\n")
#! Step 8: Save the trained Q-table
#! -------
np.save(q_table_save_path, q_table)
print("Saved the Q-table.")


# Function 2: Visualize the Q-table
# -----------
def visualize_q_table(hell_state_coordinates,
                      goal_coordinates,
                      actions=["Up", "Down", "Right", "Left"],
                      q_values_path="q_table.npy"):

    # Load the Q-table:
    # -----------------
    try:
        q_table = np.load(q_values_path)

        # Create subplots for each action:
        # --------------------------------
        _, axes = plt.subplots(1, 4, figsize=(20, 5))

        for i, action in enumerate(actions):
            ax = axes[i]
            heatmap_data = q_table[:, :, i].copy()

            # Mask the goal and hell states' Q-values for visualization:
            # -----------------------------------------------------------
            mask = np.zeros_like(heatmap_data, dtype=bool)
            mask[goal_coordinates] = True
            for hell_state in hell_state_coordinates:
                mask[hell_state] = True

            sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap="viridis",
                        ax=ax, cbar=False, mask=mask, annot_kws={"size": 9})

            # Denote the goal and hell states:
            # --------------------------------
            ax.text(goal_coordinates[1] + 0.5, goal_coordinates[0] + 0.5, 'G', color='green',
                    ha='center', va='center', weight='bold', fontsize=14)
            for hell_state in hell_state_coordinates:
                ax.text(hell_state[1] + 0.5, hell_state[0] + 0.5, 'H', color='red',
                        ha='center', va='center', weight='bold', fontsize=14)

            ax.set_title(f'Action: {action}')

        plt.tight_layout()
        plt.show()

    except FileNotFoundError:
        print("No saved Q-table was found. Please train the Q-learning agent first or check your path.")