-
Notifications
You must be signed in to change notification settings - Fork 6
/
train.py
55 lines (41 loc) · 1.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging
import torch
import yaml
import click
import azalea as az
def _load_yaml(path):
    """Parse and return the YAML document stored at *path*."""
    # yaml.load() without an explicit Loader is unsafe on PyYAML < 5.1,
    # emits a warning on 5.x and raises TypeError on >= 6.0; safe_load works
    # on all versions and never constructs arbitrary Python objects.
    # The context manager closes the file handle promptly.
    with open(path, encoding='utf-8') as f:
        return yaml.safe_load(f)


@click.command()
@click.option('--config', type=click.Path(exists=True), required=True,
              help='YAML configuration file')
@click.option('--rundir', type=click.Path(), required=True,
              help='Directory to save results from training run')
@click.option('--startpos', type=click.Path(),
              help='YAML file of board starting positions')
@click.option('--model', type=click.Path(),
              help='Warm start training from model checkpoint')
@click.option('--replaybuf', type=click.Path(),
              help='Warm start training from replay buffer checkpoint')
def main(config, rundir, model, replaybuf, startpos):
    """Train a chess AI model.

    Loads the YAML training configuration, optionally overrides its board
    starting positions, optionally warm-starts from a model and/or replay
    buffer checkpoint, runs training and saves a checkpoint named 'final'.

    Args:
        config: Path to the YAML configuration file. A ``device`` value of
            ``'auto'`` is resolved to ``'cuda'`` when torch reports CUDA as
            available, otherwise ``'cpu'``.
        rundir: Directory for the run's results; all output is passed to
            ``az.redirect_all_output`` targeting ``<rundir>/train.log``.
        model: Optional model checkpoint path to warm start from.
        replaybuf: Optional replay-buffer checkpoint path to warm start from.
        startpos: Optional YAML file holding a list of entries, each with a
            ``'fen'`` key; the FENs replace ``config['start_positions']``.
    """
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    az.redirect_all_output(f'{rundir}/train.log')
    logging.info(f'azalea {az.__version__}')

    expt = az.Experiment(rundir)
    config = _load_yaml(config)
    if config['device'] == 'auto':
        config['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
    if startpos:
        config['start_positions'] = [s['fen'] for s in _load_yaml(startpos)]
    expt.restart(config)

    # Ensure the experiment is closed even if warm-starting or training
    # fails (e.g. KeyboardInterrupt during a long run).
    # NOTE(review): assumes close() is safe to call after restart() without
    # a completed train(); confirm against azalea.Experiment.
    try:
        if model:
            expt.load_checkpoint(model)
        if replaybuf:
            expt.load_replaybuf(replaybuf)
        expt.train()
        expt.save_checkpoint('final')
    finally:
        expt.close()


if __name__ == '__main__':
    main()