diff --git a/README.md b/README.md index 7da1f9a..0100cd6 100644 --- a/README.md +++ b/README.md @@ -87,12 +87,12 @@ By adhering to the above guidelines, you'll be well-prepared to contribute to or ----- -
- +
-Left to right: [@example](https://github.com/Jonrodtang) [@example](https://github.com/Jonrodtang) [@example](https://github.com/Jonrodtang) [@example](https://github.com/Jonrodtang) +Back row left to right: [Nils Henrik Lund](https://github.com/Nilsthehacker), [Haagen Mæland Moe](https://github.com/Thesmund) +Front row left to right: [Kristian Carlenius](https://github.com/kristiancarlenius), [Ludvig Øvrevik](https://github.com/ludvigovrevik), [Christian Fredrik Johnsen](https://github.com/ChristianFredrikJohnsen), [Brage Kvamme](https://github.com/BrageHK) #### Leaders diff --git a/docs/pictures/alphazero-group-image.jpg b/docs/pictures/alphazero-group-image.jpg new file mode 100644 index 0000000..01bb30c Binary files /dev/null and b/docs/pictures/alphazero-group-image.jpg differ diff --git a/docs/pictures/sample_pic.jpg b/docs/pictures/sample_pic.jpg deleted file mode 100644 index 18227f7..0000000 Binary files a/docs/pictures/sample_pic.jpg and /dev/null differ diff --git a/main.py b/main.py index 84cb0d9..de8f214 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,5 @@ -import pyspiel -import torch +from argparse import ArgumentParser, ArgumentTypeError from torch import multiprocessing as mp -from src.alphazero.agents.alphazero_training_agent import AlphaZero from src.neuralnet.neural_network import NeuralNetwork from src.neuralnet.neural_network_connect_four import NeuralNetworkConnectFour @@ -24,8 +22,8 @@ def test_overfit(context: GameContext): context=context, num_games=3, num_simulations=100, - epochs=1, - batch_size=64 + epochs=1000, + batch_size=256 ) def train_tic_tac_toe(context: GameContext): @@ -51,7 +49,7 @@ def train_connect_four(context: GameContext): train_alphazero_model( context=context, num_games=20, - num_simulations=200, + num_simulations=100, epochs=3, batch_size=256, ) @@ -91,19 +89,57 @@ def play(context: GameContext, first: bool, mcts: bool = False): save_path=connect4_path ) +def str2bool(v: str) -> bool: + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise ArgumentTypeError('Boolean value expected.') + +parser: ArgumentParser = ArgumentParser(description='Control the execution of the AlphaZero game playing system.') +parser.add_argument('--test_overfit', action='store_true', help='Test overfitting on Connect Four.') +parser.add_argument('--train_tic_tac_toe', action='store_true', help='Train AlphaZero on Tic Tac Toe.') +parser.add_argument('--train_connect_four', action='store_true', help='Train AlphaZero on Connect Four for a long time.') + +parser.add_argument('--self_play_ttt', action='store_true', help='Run self-play on Tic Tac Toe.') +parser.add_argument('--self_play_c4', action='store_true', help='Run self-play on Connect Four.') + +parser.add_argument('--play_ttt', action='store_true', help='Enable playing against AlphaZero on Tic Tac Toe.') +parser.add_argument('--play_c4', action='store_true', help='Play against AlphaZero on Connect Four.') + +parser.add_argument('-f', '--first', action='store_true', help='Play first in the game.') +parser.add_argument('-m', '--mcts', action='store_true', help='Replace human player with MCTS.') + +args = parser.parse_args() + + if __name__ == '__main__': # Needed for multiprocessing to work + if args.test_overfit: + test_overfit(overfit_context) + + if args.train_tic_tac_toe: + train_tic_tac_toe(tic_tac_toe_context) + + if args.train_connect_four: + train_connect_four(connect4_context) + + if args.self_play_ttt: + self_play(tic_tac_toe_context) + if args.self_play_c4: + self_play(connect4_context) - # test_overfit(overfit_context) - # train_tic_tac_toe(tic_tac_toe_context) - # train_connect_four(connect4_context) - # self_play(tic_tac_toe_context) - # self_play(connect4_context) - # play(tic_tac_toe_context, first=False) - play(connect4_context, first=False) - # play(connect4_context, first=True, mcts=True) + if args.play_ttt: + play(tic_tac_toe_context, first=args.first) + + if args.play_c4: + play(connect4_context, first=args.first, mcts=args.mcts) + # create_tic_tac_toe_model("initial_test") # create_connect_four_model("overfit_nn")