diff --git a/README.md b/README.md
index 7da1f9a..0100cd6 100644
--- a/README.md
+++ b/README.md
@@ -87,12 +87,12 @@ By adhering to the above guidelines, you'll be well-prepared to contribute to or
-----
-
-
+
-Left to right: [@example](https://github.com/Jonrodtang) [@example](https://github.com/Jonrodtang) [@example](https://github.com/Jonrodtang) [@example](https://github.com/Jonrodtang)
+Back row left to right: [Nils Henrik Lund](https://github.com/Nilsthehacker), [Haagen Mæland Moe](https://github.com/Thesmund)
+
+Front row left to right: [Kristian Carlenius](https://github.com/kristiancarlenius), [Ludvig Øvrevik](https://github.com/ludvigovrevik), [Christian Fredrik Johnsen](https://github.com/ChristianFredrikJohnsen), [Brage Kvamme](https://github.com/BrageHK)
#### Leaders
diff --git a/docs/pictures/alphazero-group-image.jpg b/docs/pictures/alphazero-group-image.jpg
new file mode 100644
index 0000000..01bb30c
Binary files /dev/null and b/docs/pictures/alphazero-group-image.jpg differ
diff --git a/docs/pictures/sample_pic.jpg b/docs/pictures/sample_pic.jpg
deleted file mode 100644
index 18227f7..0000000
Binary files a/docs/pictures/sample_pic.jpg and /dev/null differ
diff --git a/main.py b/main.py
index 84cb0d9..de8f214 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,5 @@
-import pyspiel
-import torch
+from argparse import ArgumentParser, ArgumentTypeError
from torch import multiprocessing as mp
-from src.alphazero.agents.alphazero_training_agent import AlphaZero
from src.neuralnet.neural_network import NeuralNetwork
from src.neuralnet.neural_network_connect_four import NeuralNetworkConnectFour
@@ -24,8 +22,8 @@ def test_overfit(context: GameContext):
context=context,
num_games=3,
num_simulations=100,
- epochs=1,
- batch_size=64
+ epochs=1000,
+ batch_size=256
)
def train_tic_tac_toe(context: GameContext):
@@ -51,7 +49,7 @@ def train_connect_four(context: GameContext):
train_alphazero_model(
context=context,
num_games=20,
- num_simulations=200,
+ num_simulations=100,
epochs=3,
batch_size=256,
)
@@ -91,19 +89,57 @@ def play(context: GameContext, first: bool, mcts: bool = False):
save_path=connect4_path
)
+def str2bool(v: str) -> bool:
+    """Translate a textual yes/no token into a ``bool``.
+
+    Matching is case-insensitive. A value that is already a ``bool`` is
+    returned unchanged.
+
+    Args:
+        v: Token to interpret, e.g. ``'yes'``, ``'True'``, ``'0'``.
+
+    Returns:
+        ``True`` for 'yes', 'true', 't', 'y', '1'; ``False`` for
+        'no', 'false', 'f', 'n', '0'.
+
+    Raises:
+        ArgumentTypeError: If the token matches neither group.
+    """
+    # NOTE(review): looks intended as an argparse ``type=`` callable, but no
+    # argument visible in this change uses it -- confirm it is still needed.
+    if isinstance(v, bool):
+        return v
+
+    truthy = {'yes', 'true', 't', 'y', '1'}
+    falsy = {'no', 'false', 'f', 'n', '0'}
+
+    token = v.lower()
+    if token in truthy:
+        return True
+    if token in falsy:
+        return False
+    raise ArgumentTypeError('Boolean value expected.')
+
+# Command-line interface. Every option is an opt-in switch (store_true), so
+# running the script without flags performs no action.
+parser: ArgumentParser = ArgumentParser(description='Control the execution of the AlphaZero game playing system.')
+parser.add_argument('--test_overfit', action='store_true', help='Test overfitting on Connect Four.')
+parser.add_argument('--train_tic_tac_toe', action='store_true', help='Train AlphaZero on Tic Tac Toe.')
+parser.add_argument('--train_connect_four', action='store_true', help='Train AlphaZero on Connect Four for a long time.')
+
+parser.add_argument('--self_play_ttt', action='store_true', help='Run self-play on Tic Tac Toe.')
+parser.add_argument('--self_play_c4', action='store_true', help='Run self-play on Connect Four.')
+
+parser.add_argument('--play_ttt', action='store_true', help='Enable playing against AlphaZero on Tic Tac Toe.')
+parser.add_argument('--play_c4', action='store_true', help='Play against AlphaZero on Connect Four.')
+
+# Modifiers for the --play_* actions.
+parser.add_argument('-f', '--first', action='store_true', help='Play first in the game.')
+parser.add_argument('-m', '--mcts', action='store_true', help='Replace human player with MCTS.')
+
+
if __name__ == '__main__': # Needed for multiprocessing to work
+    # Parse inside the guard: at module level this would run on every import
+    # of this file (a test runner, or a spawned worker process re-importing
+    # the main module) and could exit on arguments meant for another program.
+    args = parser.parse_args()
+
+    if args.test_overfit:
+        test_overfit(overfit_context)
+
+    if args.train_tic_tac_toe:
+        train_tic_tac_toe(tic_tac_toe_context)
+
+    if args.train_connect_four:
+        train_connect_four(connect4_context)
+
+    if args.self_play_ttt:
+        self_play(tic_tac_toe_context)
+    if args.self_play_c4:
+        self_play(connect4_context)
- # test_overfit(overfit_context)
- # train_tic_tac_toe(tic_tac_toe_context)
- # train_connect_four(connect4_context)
- # self_play(tic_tac_toe_context)
- # self_play(connect4_context)
- # play(tic_tac_toe_context, first=False)
- play(connect4_context, first=False)
- # play(connect4_context, first=True, mcts=True)
+    # Both play modes honour -f/--first and -m/--mcts.
+    if args.play_ttt:
+        play(tic_tac_toe_context, first=args.first, mcts=args.mcts)
+
+    if args.play_c4:
+        play(connect4_context, first=args.first, mcts=args.mcts)
+
# create_tic_tac_toe_model("initial_test")
# create_connect_four_model("overfit_nn")