generated from CogitoNTNU/README-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
129 lines (100 loc) · 4.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from argparse import ArgumentParser
from torch import multiprocessing as mp
from src.neuralnet.neural_network import NeuralNetwork
from src.neuralnet.neural_network_connect_four import NeuralNetworkConnectFour
from src.neuralnet.create_neural_network import create_tic_tac_toe_model, create_connect_four_model
from src.alphazero.agents.alphazero_play_agent import alphazero_self_play
from src.alphazero.alphazero_train_model import train_alphazero_model
from src.play.play_vs_alphazero import main as play_vs_alphazero
from src.utils.game_context import GameContext
def test_overfit(context: GameContext) -> None:
    """Run one long training job on very few games to check overfitting.

    Calls ``train_alphazero_model`` once with ``num_games=3``,
    ``num_simulations=100``, ``epochs=1000`` and ``batch_size=256``.

    Args:
        context: Game/model bundle forwarded unchanged to
            ``train_alphazero_model``.
    """
    # set_start_method() raises RuntimeError once a start method has been
    # fixed, e.g. when another command already ran in this process. Only
    # set it when 'spawn' is not yet the active method.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)

    train_alphazero_model(
        context=context,
        num_games=3,
        num_simulations=100,
        epochs=1000,
        batch_size=256,
    )
def train_tic_tac_toe(context: GameContext) -> None:
    """Train on Tic Tac Toe in repeated sessions until interrupted.

    Runs up to 1,000,000 consecutive ``train_alphazero_model`` sessions
    (``num_games=48``, ``num_simulations=100``, ``epochs=3``,
    ``batch_size=32``) and prints a progress line after each one.
    A ``KeyboardInterrupt`` (Ctrl+C) ends the loop cleanly.

    Args:
        context: Game/model bundle forwarded unchanged to
            ``train_alphazero_model``.
    """
    # set_start_method() raises RuntimeError once a start method has been
    # fixed, e.g. when another command already ran in this process. Only
    # set it when 'spawn' is not yet the active method.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)

    try:
        for i in range(int(1e6)):
            train_alphazero_model(
                context=context,
                num_games=48,
                num_simulations=100,
                epochs=3,
                batch_size=32,
            )
            print(f'Training session {i + 1} finished!')
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop this effectively endless loop.
        print('Training interrupted!')
def train_connect_four(context: GameContext) -> None:
    """Train on Connect Four in repeated sessions until interrupted.

    Runs up to 1,000,000 consecutive ``train_alphazero_model`` sessions
    (``num_games=20``, ``num_simulations=100``, ``epochs=3``,
    ``batch_size=256``) and prints a progress line after each one.
    A ``KeyboardInterrupt`` (Ctrl+C) ends the loop cleanly.

    Args:
        context: Game/model bundle forwarded unchanged to
            ``train_alphazero_model``.
    """
    # set_start_method() raises RuntimeError once a start method has been
    # fixed, e.g. when another command already ran in this process. Only
    # set it when 'spawn' is not yet the active method.
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)

    try:
        for i in range(int(1e6)):
            train_alphazero_model(
                context=context,
                num_games=20,
                num_simulations=100,
                epochs=3,
                batch_size=256,
            )
            print(f'Training session {i + 1} finished!')
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop this effectively endless loop.
        print('Training interrupted!')
def self_play(context: GameContext) -> None:
    """Hand ``context`` to ``alphazero_self_play`` and discard its result.

    Args:
        context: Game/model bundle passed positionally to
            ``alphazero_self_play``.
    """
    alphazero_self_play(context)
def play(context: GameContext, first: bool, mcts: bool = False):
    """Start a game against AlphaZero via ``play_vs_alphazero``.

    All three arguments are forwarded by keyword, unchanged.

    Args:
        context: Game/model bundle to play with.
        first: Wired to the ``--first`` CLI flag ("Play first in the game.").
        mcts: Wired to the ``--mcts`` CLI flag ("Replace human player with
            MCTS."). Defaults to ``False``.
    """
    play_vs_alphazero(context=context, first=first, mcts=mcts)
# Locations of the stored networks that the contexts below load from.
overfit_path = "./models/overfit/connect4_nn.nn"
tic_tac_toe_path = "./models/tic_tac_toe/good_nn.nn"
connect4_path = "./models/connect_four/good_nn.nn"

# NOTE: all three .load(...) calls run at import time, whichever command-line
# flag is later selected.

# Overfit experiment: loads from overfit_path but uses a different save_path
# (presumably so the loaded file is not overwritten - confirm in GameContext).
overfit_context = GameContext(
    game_name="connect_four",
    nn=NeuralNetworkConnectFour().load(overfit_path),
    save_path="./models/overfit/connect4_overfit_waste.nn",
)

# Tic Tac Toe: save_path is the same file the network was loaded from.
tic_tac_toe_context = GameContext(
    game_name="tic_tac_toe",
    nn=NeuralNetwork().load(tic_tac_toe_path),
    save_path=tic_tac_toe_path,
)

# Connect Four: save_path is the same file the network was loaded from.
connect4_context = GameContext(
    game_name="connect_four",
    nn=NeuralNetworkConnectFour().load(connect4_path),
    save_path=connect4_path,
)
# Command-line interface: every option is an independent boolean switch
# (store_true), so several commands can be requested in one invocation.
parser: ArgumentParser = ArgumentParser(
    description='Control the execution of the AlphaZero game playing system.'
)
for _flags, _help in (
    (('--test_overfit',), 'Test overfitting on Connect Four.'),
    (('--train_ttt',), 'Train AlphaZero on Tic Tac Toe.'),
    (('--train_c4',), 'Train AlphaZero on Connect Four for a long time.'),
    (('--self_play_ttt',), 'Run self-play on Tic Tac Toe.'),
    (('--self_play_c4',), 'Run self-play on Connect Four.'),
    (('--play_ttt',), 'Enable playing against AlphaZero on Tic Tac Toe.'),
    (('--play_c4',), 'Play against AlphaZero on Connect Four.'),
    (('-f', '--first'), 'Play first in the game.'),
    (('-m', '--mcts'), 'Replace human player with MCTS.'),
):
    parser.add_argument(*_flags, action='store_true', help=_help)

# Parsed at import time; the __main__ block reads the result from `args`.
args = parser.parse_args()
if __name__ == '__main__':  # Needed for multiprocessing to work
    # Ordered (flag, action) pairs. Every enabled flag runs, in this order;
    # the flags are not mutually exclusive.
    _commands = (
        (args.test_overfit, lambda: test_overfit(overfit_context)),
        (args.train_ttt, lambda: train_tic_tac_toe(tic_tac_toe_context)),
        (args.train_c4, lambda: train_connect_four(connect4_context)),
        (args.self_play_ttt, lambda: self_play(tic_tac_toe_context)),
        (args.self_play_c4, lambda: self_play(connect4_context)),
        (
            args.play_ttt,
            lambda: play(tic_tac_toe_context, first=args.first, mcts=args.mcts),
        ),
        (
            args.play_c4,
            lambda: play(connect4_context, first=args.first, mcts=args.mcts),
        ),
    )
    for _enabled, _run in _commands:
        if _enabled:
            _run()

    # Disabled one-off model-creation calls, kept for reference:
    # create_tic_tac_toe_model("initial_test")
    # create_connect_four_model("overfit_nn")