-
Notifications
You must be signed in to change notification settings - Fork 98
/
Copy pathmichi.py
executable file
·1083 lines (923 loc) · 39.1 KB
/
michi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env pypy
# -*- coding: utf-8 -*-
#
# (c) Petr Baudis <[email protected]> 2015
# MIT licence (i.e. almost public domain)
#
# A minimalistic Go-playing engine attempting to strike a balance between
# brevity, educational value and strength. It can beat GNUGo on 13x13 board
# on a modest 4-thread laptop.
#
# When benchmarking, note that at the beginning of the first move the program
# runs much slower because pypy is JIT compiling in the background!
#
# To start reading the code, begin either:
# * Bottom up, by looking at the goban implementation - starting with
# the 'empty' definition below and Position.move() method.
# * In the middle, by looking at the Monte Carlo playout implementation,
# starting with the mcplayout() function.
# * Top down, by looking at the MCTS implementation, starting with the
# tree_search() function. It can look a little confusing due to the
# parallelization, but really is just a loop of tree_descend(),
# mcplayout() and tree_update() round and round.
# It may be better to jump around a bit instead of just reading straight
# from start to end.
from __future__ import print_function
from collections import namedtuple
from itertools import chain, count
from joblib import Parallel, delayed
import math
import multiprocessing
from multiprocessing import Process, Queue
from multiprocessing.pool import Pool
import numpy as np
import random
import re
from six.moves import input
import sys
import time
# Given a board of size NxN (N=9, 19, ...), we represent the position
# as a string of W*W characters (W = N+2), with '.' (empty), 'X' (to-play
# player), 'x' (other player), and whitespace (off-board border to make rules
# implementation easier). Coordinates are just indices in this string.
# You can simply print(board) when debugging.
N = 19  # board side length
W = N + 2  # row stride of the board string (left border + N points + newline)
empty = "\n".join([(N+1)*' '] + N*[' '+N*'.'] + [(N+2)*' '])  # board string with no stones
colstr = 'ABCDEFGHJKLMNOPQRST'  # column letters used by parse_coord()/str_coord()
N_SIMS = 800  # number of tree descents per tree_search() call
PUCT_C = 0.1  # weight of the prior (exploration) term in the PUCT urgency
PROPORTIONAL_STAGE = 3  # while pos.n <= this, the move is sampled by visit counts
TEMPERATURE = 2  # exponent applied to visit fractions when sampling a move
P_ALLOW_RESIGN = 0.8  # probability that a self-play game may end by resignation
RAVE_EQUIV = 100  # constant in the beta weight of TreeNode.rave_urgency()
EXPAND_VISITS = 1  # a leaf is expanded once its visit count exceeds this
PRIOR_EVEN = 4 # should be even number; 0.5 prior
PRIOR_NET = 40  # pseudo-visits given to the network prior of a new child
REPORT_PERIOD = 200  # tree summary is printed every this many simulations
RESIGN_THRES = 0.025  # resign when w/v of the chosen node falls below this
#######################
# board string routines
def neighbors(c):
    """ list the coordinates of the four orthogonal neighbors of c
    (left, right, previous row, next row); on the edge of the board these
    land on the whitespace border """
    left, right = c - 1, c + 1
    up, down = c - W, c + W
    return [left, right, up, down]
def diag_neighbors(c):
    """ list the coordinates of the four diagonal neighbors of c
    (the two on the previous row first, then the two on the next row) """
    prev_row, next_row = c - W, c + W
    return [prev_row - 1, prev_row + 1, next_row - 1, next_row + 1]
def board_put(board, c, p):
    """ return a copy of the board string in which the character at
    index c is replaced by p """
    head, tail = board[:c], board[c+1:]
    return ''.join((head, p, tail))
def floodfill(board, c):
    """ return a copy of the board in which the connected same-color area
    containing c is repainted with the special color # """
    # Called very often, so work on a mutable bytearray rather than on
    # immutable strings.
    cells = bytearray(board, 'utf-8')
    marker = ord('#')
    color = cells[c]
    cells[c] = marker
    stack = [c]
    while stack:
        here = stack.pop()
        for nb in neighbors(here):
            if cells[nb] != color:
                continue
            # Paint on discovery so that no point is pushed twice
            cells[nb] = marker
            stack.append(nb)
    return cells.decode('utf-8')
# Regexes that match various kinds of points adjacent to '#' (floodfilled)
# points, one compiled pattern per point color p. Vertical adjacency is
# expressed as "exactly W-1 arbitrary characters in between", since W is the
# row stride of the board string; DOTALL lets '.' also span the newlines.
contact_res = dict()
for p in ['.', 'x', 'X']:
    rp = '\\.' if p == '.' else p  # escape '.' so it matches an empty point literally
    contact_res_src = ['#' + rp, # p at right
                       rp + '#', # p at left
                       '#' + '.'*(W-1) + rp, # p below
                       rp + '.'*(W-1) + '#'] # p above
    contact_res[p] = re.compile('|'.join(contact_res_src), flags=re.DOTALL)
def contact(board, p):
    """ find a point of color p that is adjacent to color # somewhere on
    the board and return its coordinate, or None when there is no such
    point; combine with floodfill() to test reachability """
    hit = contact_res[p].search(board)
    if hit is None:
        return None
    # The p point is either the first or the last character of the match,
    # depending on which alternative of the pattern fired.
    if hit.group(0)[0] == p:
        return hit.start()
    return hit.end() - 1
def is_eyeish(board, c):
    """ check whether all on-board neighbors of c are stones of one single
    color; return that color, or None otherwise. Such a point could be an
    eye, but also a false one """
    owner = None
    enemy = None
    for stone in (board[d] for d in neighbors(c)):
        if stone.isspace():
            continue  # off-board border does not count
        if stone == '.':
            return None  # an empty neighbor - not enclosed
        if owner is None:
            owner, enemy = stone, stone.swapcase()
        elif stone == enemy:
            return None  # surrounded by both colors
    return owner
def is_eye(board, c):
    """ return the color owning the eye at c, or None if c is not an eye """
    color = is_eyeish(board, c)
    if color is None:
        return None
    # Eye-like shape; now count what could falsify it on the diagonals.
    # Any off-board diagonal counts as a single extra falsifier.
    enemy = color.swapcase()
    corners = [board[d] for d in diag_neighbors(c)]
    falsifiers = sum(1 for stone in corners if stone == enemy)
    if any(stone.isspace() for stone in corners):
        falsifiers += 1
    return None if falsifiers >= 2 else color
class Position(namedtuple('Position', 'board cap n ko last last2 komi data')):
    """ Implementation of simple Chinese Go rules.

    board is the board string from the point of view of the player to move
    ('X' to-play, 'x' the other player); cap is a pair of capture counts
    (to-play player first); n is how many moves were played so far; ko is
    the set of boards already seen, used for the superko test; last and
    last2 are the coordinates of the last two moves (None for a pass);
    komi is the komi; data is a free-form dict attached to the position """

    def move(self, c):
        """ play as player X at the given coord c, return the new position,
        or None if the move is a suicide or repeats an earlier board """
        board = board_put(self.board, c, 'X')
        # Test for captures
        capX = self.cap[0]
        for d in neighbors(c):
            if board[d] != 'x':
                continue
            # XXX: The following is an extremely naive and SLOW approach
            # at things - to do it properly, we should maintain some per-group
            # data structures tracking liberties.
            fboard = floodfill(board, d)  # get a board with the adjacent group replaced by '#'
            if contact(fboard, '.') is not None:
                continue  # some liberties left
            # no liberties left for this group, remove the stones!
            capX += fboard.count('#')
            board = fboard.replace('#', '.')  # capture the group
        # Test for suicide
        if contact(floodfill(board, c), '.') is None:
            return None
        # Test for (positional super)ko
        if board in self.ko or board.swapcase() in self.ko:
            return None
        # Update the position and return; the board and the capture counts
        # are swapped so that the new to-play player is 'X' again
        return Position(board=board.swapcase(), cap=(self.cap[1], capX),
                        n=self.n + 1, ko=self.ko | {board}, last=c,
                        last2=self.last, komi=self.komi, data=dict())

    def pass_move(self):
        """ pass - i.e. return simply a flipped position """
        return Position(board=self.board.swapcase(), cap=(self.cap[1], self.cap[0]),
                        n=self.n + 1, ko=self.ko, last=None, last2=self.last,
                        komi=self.komi, data=dict())

    def moves(self, i0):
        """ Generate a list of moves (includes false positives - suicide moves;
        does not include true-eye-filling moves), starting from a given board
        index (that can be used for randomization) """
        i = i0-1
        passes = 0
        while True:
            i = self.board.find('.', i+1)
            if passes > 0 and (i == -1 or i >= i0):
                break  # we have looked through the whole board
            elif i == -1:
                i = 0
                passes += 1
                continue  # go back and start from the beginning
            # Test for to-play player's one-point eye
            if is_eye(self.board, i) == 'X':
                continue
            yield i

    def last_moves_neighbors(self):
        """ generate a list of points including and surrounding the last
        two moves; each move's group of nine points is shuffled, and the
        points around the last move come before those around the move
        preceding it """
        clist = []
        for c in self.last, self.last2:
            if c is None:
                continue
            dlist = [c] + list(neighbors(c) + diag_neighbors(c))
            random.shuffle(dlist)
            clist += [d for d in dlist if d not in clist]
        return clist

    def score(self, owner_map=None):
        """ compute score for to-play player; this assumes a final position
        with all dead stones captured; if owner_map is passed, it is assumed
        to be an array of statistics with average owner at the end of the game
        (+1 black, -1 white) """
        board = self.board
        i = 0
        while True:
            # Search the working board, not self.board: points of a region
            # that has already been assigned are no longer '.', so each
            # empty region is flood-filled exactly once. (Re-filling from a
            # point that has since become 'X'/'x' would flood the whole
            # owning group and could hand it to the other color.)
            i = board.find('.', i+1)
            if i == -1:
                break
            fboard = floodfill(board, i)
            # fboard is board with some continuous area of empty space replaced by #
            touches_X = contact(fboard, 'X') is not None
            touches_x = contact(fboard, 'x') is not None
            if touches_X and not touches_x:
                board = fboard.replace('#', 'X')
            elif touches_x and not touches_X:
                board = fboard.replace('#', 'x')
            else:
                board = fboard.replace('#', ':')  # seki, rare
            # now that area is replaced either by X, x or :
        # komi is added for the to-play player on odd move numbers and
        # subtracted on even ones
        komi = self.komi if self.n % 2 == 1 else -self.komi
        if owner_map is not None:
            for c in range(W*W):
                n = 1 if board[c] == 'X' else -1 if board[c] == 'x' else 0
                owner_map[c] += n * (1 if self.n % 2 == 0 else -1)
        return board.count('X') - board.count('x') + komi

    def flip_vert(self):
        """ return the position mirrored top-to-bottom (ko history is dropped) """
        # The last border line is one character longer than the first one;
        # trim it before reversing the rows and pad the new last line again.
        board = '\n'.join(reversed(self.board[:-1].split('\n'))) + ' '

        def coord_flip_vert(c):
            if c is None:
                return None
            return (W-1 - c // W) * W + c % W
        # XXX: Doesn't update ko properly
        return Position(board=board, cap=self.cap, n=self.n, ko=set(),
                        last=coord_flip_vert(self.last), last2=coord_flip_vert(self.last2),
                        komi=self.komi, data=self.data)

    def flip_horiz(self):
        """ return the position mirrored left-to-right (ko history is dropped) """
        # Keep the leading border character of each line, reverse the rest
        board = '\n'.join([' ' + l[1:][::-1] for l in self.board.split('\n')])

        def coord_flip_horiz(c):
            if c is None:
                return None
            return c // W * W + (W-1 - c % W)
        # XXX: Doesn't update ko properly
        return Position(board=board, cap=self.cap, n=self.n, ko=set(),
                        last=coord_flip_horiz(self.last), last2=coord_flip_horiz(self.last2),
                        komi=self.komi, data=self.data)

    def flip_both(self):
        """ return the position mirrored along both axes (ko history is dropped) """
        board = '\n'.join(reversed([' ' + l[1:][::-1] for l in self.board[:-1].split('\n')])) + ' '

        def coord_flip_both(c):
            if c is None:
                return None
            return (W-1 - c // W) * W + (W-1 - c % W)
        # XXX: Doesn't update ko properly
        return Position(board=board, cap=self.cap, n=self.n, ko=set(),
                        last=coord_flip_both(self.last), last2=coord_flip_both(self.last2),
                        komi=self.komi, data=self.data)

    def flip_random(self):
        """ return the position with a vertical and a horizontal flip each
        applied independently with probability 0.5 """
        pos = self
        if random.random() < 0.5:
            pos = pos.flip_vert()
        if random.random() < 0.5:
            pos = pos.flip_horiz()
        return pos
def empty_position():
    """ Build the starting position: empty board, no captures, no moves
    played, empty ko history and a komi of 7.5 """
    fields = dict(board=empty, cap=(0, 0), n=0, ko=set(),
                  last=None, last2=None, komi=7.5, data=dict())
    return Position(**fields)
########################
# fork safe model wrapper
def encode_position(position, board_transform=None):
    """ encode a position as an (N, N, 6) numpy array of 0/1 feature planes:
    to-play stones, other player's stones, board edge, last move, the move
    before it, and a constant plane that is all ones when position.n is odd.

    board_transform, if given, is the name of a Position method (e.g.
    'flip_vert') that is applied to the position before encoding """
    if board_transform:
        # Plain attribute lookup instead of eval(): the name is resolved on
        # the class and is never executed as code.
        position = getattr(Position, board_transform)(position)
    my_stones, their_stones, edge, last, last2, to_play = (np.zeros((N, N)) for _ in range(6))
    odd_move = position.n % 2 == 1
    for c, p in enumerate(position.board):
        x, y = c % W - 1, c // W - 1
        if not (0 <= x < N and 0 <= y < N):
            continue  # border whitespace, nothing to encode
        if p == 'X':
            my_stones[y, x] = 1
        elif p == 'x':
            their_stones[y, x] = 1
        if x == 0 or x == N-1 or y == 0 or y == N-1:
            edge[y, x] = 1
        if position.last == c:
            last[y, x] = 1
        if position.last2 == c:
            last2[y, x] = 1
        if odd_move:
            to_play[y, x] = 1
    return np.stack((my_stones, their_stones, edge, last, last2, to_play), axis=-1)
class ModelServer(Process):
    """ Process that owns the network model and serves requests for it.

    Requests arrive on cmd_queue as (command, args dict, result queue
    index) tuples; answers are put on res_queues[index]. If load_snapshot
    is not None, it is loaded into the model when the process starts """

    def __init__(self, cmd_queue, res_queues, load_snapshot=None):
        super(ModelServer, self).__init__()
        self.load_snapshot = load_snapshot
        self.res_queues = res_queues
        self.cmd_queue = cmd_queue

    def run(self):
        """ create the model and process commands until an error occurs;
        any exception is printed and ends the process """
        try:
            from michi.net import AGZeroModel
            net = AGZeroModel(N)
            net.create()
            if self.load_snapshot is not None:
                net.load(self.load_snapshot)

            class PredictStash(object):
                """ prediction batcher """

                def __init__(self, trigger, res_queues):
                    self.stash = []
                    self.trigger = trigger  # XXX must not be higher than #workers
                    self.res_queues = res_queues

                def add(self, kind, X_pos, ri):
                    """ queue one request; run the batch once it is full """
                    self.stash.append((kind, X_pos, ri))
                    if len(self.stash) >= self.trigger:
                        self.process()

                def process(self):
                    """ predict everything stashed in one batch and send
                    each requester the output it asked for (kind 0 gets
                    the move distribution, anything else the result) """
                    if not self.stash:
                        return
                    batch = np.array([X_pos for _, X_pos, _ in self.stash])
                    dist, res = net.predict(batch)
                    for d, r, (kind, _, ri) in zip(dist, res, self.stash):
                        self.res_queues[ri].put(d if kind == 0 else r)
                    self.stash = []

            stash = PredictStash(1, self.res_queues)
            fit_counter = 0
            while True:
                cmd, args, ri = self.cmd_queue.get()
                # These commands first flush the predictions still pending
                if cmd in ('stash_size', 'fit_game', 'save'):
                    stash.process()
                if cmd == 'stash_size':
                    stash.trigger = args['stash_size']
                elif cmd == 'fit_game':
                    print('\rFit %d...' % (fit_counter,), end='')
                    sys.stdout.flush()
                    fit_counter += 1
                    net.fit_game(**args)
                elif cmd == 'predict_distribution':
                    stash.add(0, args['X_position'], ri)
                elif cmd == 'predict_winrate':
                    stash.add(1, args['X_position'], ri)
                elif cmd == 'model_name':
                    self.res_queues[ri].put(net.model_name)
                elif cmd == 'save':
                    net.save(args['snapshot_id'])
        except:
            import traceback
            traceback.print_exc()
class GoModel(object):
    """ Client-side handle of the model: starts a ModelServer process and
    talks to it through a command queue and a set of result queues """

    def __init__(self, load_snapshot=None):
        self.cmd_queue = Queue()
        self.res_queues = [Queue() for i in range(128)]
        self.server = ModelServer(self.cmd_queue, self.res_queues, load_snapshot=load_snapshot)
        self.server.start()
        self.ri = 0  # id of process in case of multiple processes, to prevent mixups

    def _send(self, cmd, args):
        """ post a command to the server without waiting for an answer """
        self.cmd_queue.put((cmd, args, self.ri))

    def _ask(self, cmd, args):
        """ post a command and wait for the answer on our result queue """
        self._send(cmd, args)
        return self.res_queues[self.ri].get()

    def stash_size(self, stash_size):
        """ set how many predict requests the server batches together """
        self._send('stash_size', {'stash_size': stash_size})

    def fit_game(self, positions, result, board_transform=None):
        """ train on a game given as (position, distribution) pairs and its
        result, optionally encoding the positions with a board transform """
        X_positions = []
        for pos, dist in positions:
            X_positions.append((encode_position(pos, board_transform=board_transform), dist))
        self._send('fit_game', {'X_positions': X_positions, 'result': result})

    def predict_distribution(self, position):
        """ ask the server for the move distribution of the position """
        return self._ask('predict_distribution', {'X_position': encode_position(position)})

    def predict_winrate(self, position):
        """ ask the server for the winrate prediction of the position """
        return self._ask('predict_winrate', {'X_position': encode_position(position)})

    def model_name(self):
        """ ask the server for the name of the model """
        return self._ask('model_name', {})

    def save(self, snapshot_id):
        """ tell the server to save the model under snapshot_id """
        self._send('save', {'snapshot_id': snapshot_id})
########################
# montecarlo tree search
class TreeNode():
    """ Monte-Carlo tree node;
    v is #visits, w is #wins for to-play (expected reward is w/v)
    pv, pw are prior values (node value = w/v + pw/pv)
    av, aw are amaf values ("all moves as first", used for the RAVE tree policy)
    children is None for leaf nodes;
    net is the model used for predictions, pos the position of the node """

    def __init__(self, net, pos):
        self.net = net
        self.pos = pos
        self.v = 0
        self.w = 0
        self.pv = 0
        self.pw = 0
        self.av = 0
        self.aw = 0
        self.children = None

    def expand(self):
        """ add and initialize children to a leaf node """
        distribution = self.net.predict_distribution(self.pos)
        self.children = []
        for c in self.pos.moves(0):
            pos2 = self.pos.move(c)
            if pos2 is None:
                continue  # illegal move (suicide or ko)
            node = TreeNode(self.net, pos2)
            self.children.append(node)
            # The distribution is indexed row by row over the NxN points
            x, y = c % W - 1, c // W - 1
            value = distribution[y * N + x]
            node.pv = PRIOR_NET
            node.pw = PRIOR_NET * value
        # Add also a pass move - but only if this doesn't trigger a losing
        # scoring (or we have no other option)
        if not self.children:
            can_pass = True
        else:
            can_pass = self.pos.score() >= 0
        if can_pass:
            node = TreeNode(self.net, self.pos.pass_move())
            self.children.append(node)
            node.pv = PRIOR_NET
            node.pw = PRIOR_NET * distribution[-1]  # pass is the last entry

    def puct_urgency(self, n0):
        """ PUCT urgency of this node given n0 visits of its parent """
        # XXX: This is substituted by global_puct_urgency()
        expectation = float(self.w + PRIOR_EVEN/2) / (self.v + PRIOR_EVEN)
        try:
            prior = float(self.pw) / self.pv
        except ZeroDivisionError:
            # no prior was assigned to this node (pv == 0)
            prior = 0.1  # XXX
        return expectation + PUCT_C * prior * math.sqrt(n0) / (1 + self.v)

    def rave_urgency(self):
        """ urgency mixing the (prior-augmented) winrate with the AMAF
        winrate; the AMAF part is skipped while av is zero """
        v = self.v + self.pv
        expectation = float(self.w+self.pw) / v
        if self.av == 0:
            return expectation
        rave_expectation = float(self.aw) / self.av
        beta = self.av / (self.av + v + float(v) * self.av / RAVE_EQUIV)
        return beta * rave_expectation + (1-beta) * expectation

    def winrate(self):
        """ w/v, or nan for an unvisited node """
        return float(self.w) / self.v if self.v > 0 else float('nan')

    def prior(self):
        """ pw/pv, or nan when no prior was assigned """
        return float(self.pw) / self.pv if self.pv > 0 else float('nan')

    def best_move(self, proportional=False):
        """ best move is the most simulated one; with proportional=True the
        child is instead sampled with probability proportional to its share
        of visits raised to TEMPERATURE. Returns None for a leaf node """
        if self.children is None:
            return None
        if proportional:
            probs = [(float(node.v) / self.v) ** TEMPERATURE for node in self.children]
            probs_tot = sum(probs)
            probs = [p / probs_tot for p in probs]
            # print([(str_coord(n.pos.last), p, p * probs_tot) for n, p in zip(self.children, probs)])
            i = np.random.choice(len(self.children), p=probs)
            return self.children[i]
        else:
            return max(self.children, key=lambda node: node.v)

    def distribution(self):
        """ visit fractions of the children as an array of N*N + 1 entries
        (points row by row, pass as the last entry) """
        distribution = np.zeros(N * N + 1)
        for child in self.children:
            p = float(child.v) / self.v
            c = child.pos.last
            if c is not None:
                x, y = c % W - 1, c // W - 1
                distribution[y * N + x] = p
            else:
                distribution[-1] = p
        return distribution
def puct_urgency_input(nodes):
    """ collect the statistics of the given nodes into four float numpy
    arrays (w, v, pw, pv) for global_puct_urgency(); a node without a prior
    (pv not positive) is represented with pw = 1, pv = 10 """
    wins, visits, prior_wins, prior_visits = [], [], [], []
    for node in nodes:
        has_prior = node.pv > 0
        wins.append(float(node.w))
        visits.append(float(node.v))
        prior_wins.append(float(node.pw) if has_prior else 1.)
        prior_visits.append(float(node.pv) if has_prior else 10.)
    return (np.array(wins), np.array(visits),
            np.array(prior_wins), np.array(prior_visits))
def global_puct_urgency(n0, w, v, pw, pv):
    """ PUCT urgency of all children at once: same formula as
    TreeNode.puct_urgency(), evaluated on numpy arrays (n0, the parent's
    visit count, is a scalar) """
    exploit = (w + PRIOR_EVEN/2) / (v + PRIOR_EVEN)
    explore = PUCT_C * (pw / pv) * math.sqrt(n0) / (1 + v)
    return exploit + explore
def tree_descend(tree, amaf_map, disp=False):
    """ Descend through the tree to a leaf; return the list of nodes on
    the path, root first.

    amaf_map is a list indexed by board coordinate; it is updated in place
    to record which color first played each point during this descent.
    disp enables debug output on stderr """
    tree.v += 1
    nodes = [tree]
    passes = 0
    root = True
    # Stop at a leaf, or after two consecutive passes
    while nodes[-1].children is not None and passes < 2:
        if disp: print_pos(nodes[-1].pos)
        # Pick the most urgent child
        children = list(nodes[-1].children)
        if disp:
            for c in children:
                dump_subtree(c, recurse=False)
        random.shuffle(children)  # randomize the max in case of equal urgency
        urgencies = global_puct_urgency(nodes[-1].v, *puct_urgency_input(children))
        if root:
            # At the root only, blend random noise into the urgencies: one
            # two-component Dirichlet sample is drawn per child and its
            # first component is used
            dirichlet = np.random.dirichlet((0.03,1), len(children))
            urgencies = urgencies*0.75 + dirichlet[:,0]*0.25
            root = False
        node = max(zip(children, urgencies), key=lambda t: t[1])[0]
        nodes.append(node)
        if disp: print('chosen %s' % (str_coord(node.pos.last),), file=sys.stderr)
        if node.pos.last is None:
            passes += 1
        else:
            passes = 0
            if amaf_map[node.pos.last] == 0:  # Mark the coordinate with 1 for black
                # the mover is black when the parent position has an even n
                amaf_map[node.pos.last] = 1 if nodes[-2].pos.n % 2 == 0 else -1
        # updating visits on the way *down* represents "virtual loss", relevant for parallelization
        node.v += 1
        if node.children is None and node.v > EXPAND_VISITS:
            node.expand()
    return nodes
def tree_update(nodes, amaf_map, score, disp=False):
    """ Propagate a simulation result up the tree path @nodes (deepest
    node first), updating win counts and the AMAF statistics of the
    children; the sign of the score is flipped at every level """
    for node in nodes[::-1]:
        if disp: print('updating', str_coord(node.pos.last), score < 0, file=sys.stderr)
        node.w += score < 0  # score is for to-play, node statistics for just-played
        # AMAF: credit those children whose move was played later in the
        # descent by the color that is to play in this node
        mover = -1 if node.pos.n % 2 else 1
        for child in node.children or ():
            coord = child.pos.last
            if coord is None:
                continue
            if amaf_map[coord] != mover:
                continue
            if disp: print(' AMAF updating', str_coord(coord), score > 0, file=sys.stderr)
            child.aw += score > 0  # reversed perspective
            child.av += 1
        score = -score
def tree_search(tree, n, owner_map, disp=False, debug_disp=False):
    """ Run n MCTS iterations from the given root node and return the
    chosen child; owner_map is accepted but not used here. disp prints a
    periodic tree summary, debug_disp traces every descent and update """
    # Initialize root node
    if tree.children is None:
        tree.expand()
    done = 0
    while done < n:
        amaf_map = [0] * (W*W)
        path = tree_descend(tree, amaf_map, disp=debug_disp)
        done += 1
        if disp and done % REPORT_PERIOD == 0:
            print_tree_summary(tree, done, f=sys.stderr)
        leaf_pos = path[-1].pos
        ended_by_passes = leaf_pos.last is None and leaf_pos.last2 is None
        if ended_by_passes:
            # Game over - use the real outcome instead of a prediction
            score = 1 if leaf_pos.score() > 0 else -1
        else:
            score = tree.net.predict_winrate(leaf_pos)
        tree_update(path, amaf_map, score, disp=debug_disp)
    if debug_disp:
        dump_subtree(tree)
    if disp and done % REPORT_PERIOD != 0:
        print_tree_summary(tree, done, f=sys.stderr)
    sample_move = tree.pos.n <= PROPORTIONAL_STAGE
    return tree.best_move(sample_move)
###################
# user interface(s)
# utility routines
def print_pos(pos, f=sys.stderr, owner_map=None):
    """ print visualization of the given board position, optionally also
    including an owner map statistic (probability of that area of board
    eventually becoming black/white); f is the file to print to """
    # Convert from the to-play/other encoding to fixed colors:
    # black is shown as 'X', white as 'O'
    if pos.n % 2 == 0:  # to-play is black
        board = pos.board.replace('x', 'O')
        Xcap, Ocap = pos.cap
    else:  # to-play is white
        board = pos.board.replace('X', 'O').replace('x', 'X')
        Ocap, Xcap = pos.cap
    print('Move: %-3d Black: %d caps White: %d caps Komi: %.1f' % (pos.n, Xcap, Ocap, pos.komi), file=f)
    # Interleave the board characters with spaces, so that board index c
    # ends up at index 2*c of pretty_board
    pretty_board = ' '.join(board.rstrip()) + ' '
    if pos.last is not None:
        # Wrap the last move in parentheses, overwriting the spaces around it
        pretty_board = pretty_board[:pos.last*2-1] + '(' + board[pos.last] + ')' + pretty_board[pos.last*2+2:]
    rowcounter = count()
    # Drop the top border line and prefix every row with its number
    pretty_board = [' %-02d%s' % (N-i, row[2:]) for row, i in zip(pretty_board.split("\n")[1:], rowcounter)]
    if owner_map is not None:
        pretty_ownermap = ''
        for c in range(W*W):
            if board[c].isspace():
                pretty_ownermap += board[c]
            elif owner_map[c] > 0.6:
                pretty_ownermap += 'X'
            elif owner_map[c] > 0.3:
                pretty_ownermap += 'x'
            elif owner_map[c] < -0.6:
                pretty_ownermap += 'O'
            elif owner_map[c] < -0.3:
                pretty_ownermap += 'o'
            else:
                pretty_ownermap += '.'
        pretty_ownermap = ' '.join(pretty_ownermap.rstrip())
        # Show the owner map to the right of the board, row by row
        pretty_board = ['%s %s' % (brow, orow[2:]) for brow, orow in zip(pretty_board, pretty_ownermap.split("\n")[1:])]
    print("\n".join(pretty_board), file=f)
    print(' ' + ' '.join(colstr[:N]), file=f)
    print('', file=f)
def dump_subtree(node, thres=N_SIMS/50, indent=0, f=sys.stderr, recurse=True):
    """ print the statistics of this node and, unless recurse is False,
    of all its descendants with v >= thres, most visited first """
    pad = indent*' '
    label = str_coord(node.pos.last)
    rave_rate = float(node.aw)/node.av if node.av > 0 else float('nan')
    predicted = float(-node.net.predict_winrate(node.pos) + 1) / 2
    print("%s+- %s %.3f (%d/%d, prior %d/%d, rave %d/%d=%.3f, pred %.3f)" %
          (pad, label, node.winrate(),
           node.w, node.v, node.pw, node.pv, node.aw, node.av,
           rave_rate, predicted), file=f)
    if recurse and node.children:
        ranked = sorted(node.children, key=lambda ch: ch.v, reverse=True)
        for child in ranked:
            if child.v < thres:
                continue
            dump_subtree(child, thres=thres, indent=indent+3, f=f)
def print_tree_summary(tree, sims, f=sys.stderr):
    """ print a one-line summary of the search after sims simulations:
    winrate of the most visited child (tree statistics / net prediction),
    the most visited sequence, and the five most visited candidates """
    candidates = sorted(tree.children, key=lambda ch: ch.v, reverse=True)[:5]
    top = candidates[0]
    # Follow the most visited child all the way down to a leaf
    sequence = []
    cursor = tree
    while cursor is not None:
        sequence.append(cursor.pos.last)
        cursor = cursor.best_move()
    predicted = float(-tree.net.predict_winrate(top.pos) + 1) / 2
    seq_text = ' '.join([str_coord(c) for c in sequence[1:6]])
    can_text = ' '.join(['%s(%.3f|%d/%.3f)' % (str_coord(ch.pos.last), ch.winrate(), ch.v, ch.prior())
                         for ch in candidates])
    print('[%4d] winrate %.3f/%.3f | seq %s | can %s' %
          (sims, top.winrate(), predicted, seq_text, can_text), file=f)
def parse_coord(s):
    """ convert a coordinate like 'D4' to a board index; 'pass' gives None """
    if s == 'pass':
        return None
    row = N - int(s[1:])  # rows are numbered from the bottom of the board
    col = colstr.index(s[0].upper())
    return (row + 1) * W + col + 1
def str_coord(c):
    """ convert a board index to a coordinate like 'D4'; None gives 'pass' """
    if c is None:
        return 'pass'
    offset = c - (W+1)  # index relative to the top left point of the board
    row = offset // W
    col = offset % W
    return '%c%d' % (colstr[col], N - row)
# various main programs
def play_and_train(net, i, batches_per_game=2, disp=False):
    """ play one game of self-play with net and then fit the net on it.

    i is the number of the game (resignation is only ever allowed when
    i > 10); every position of the game is fitted batches_per_game times
    as it is and in each of the three flipped variants """
    def for_black(value, pos):
        # values are computed for the player to move; flip on odd moves
        return -value if pos.n % 2 else value

    def print_counted(pos):
        print('Counted score: B%+.1f' % (for_black(pos.score(), pos),))

    positions = []
    allow_resign = i > 10 and np.random.rand() < P_ALLOW_RESIGN
    tree = TreeNode(net=net, pos=empty_position())
    tree.expand()
    while True:
        owner_map = [0] * (W*W)
        chosen = tree_search(tree, N_SIMS, owner_map, disp=disp)
        positions.append((tree.pos, tree.distribution()))
        tree = chosen
        pos = tree.pos
        if disp:
            print_pos(pos, sys.stdout, owner_map)

        if pos.last is None and pos.last2 is None:
            score = for_black(1 if pos.score() > 0 else -1, pos)
            if disp:
                print('Two passes, score: B%+.1f' % (score,))
                print_counted(pos)
            break

        if allow_resign and float(tree.w)/tree.v < RESIGN_THRES and tree.v > N_SIMS / 10 and pos.n > 10:
            # win for player to-play from this position
            score = for_black(1, pos)
            if disp:
                print('Resign (%d), score: B%+.1f' % (pos.n % 2, score))
                print_counted(pos)
            break

        if pos.n > N*N*2:
            if disp:
                print('Stopping too long a game.')
            score = 0
            break

    # score here is for black to play (player-to-play from empty_position)
    if disp:
        print(score)
        dump_subtree(tree)

    # fit the game as played, then its flipped variants
    # TODO 90\deg rot
    for transform in (None, 'flip_vert', 'flip_horiz', 'flip_both'):
        for _ in range(batches_per_game):
            net.fit_game(positions, score, board_transform=transform)
def selfplay_singlethread(net, worker_id, disp=False, snapshot_interval=25):
    """ play and train on self-play games forever as worker worker_id;
    the model is saved after every snapshot_interval games (never, if
    snapshot_interval is None or 0) """
    net.ri = worker_id
    for played in count(1):
        game_no = played - 1
        print('[%d %d] Self-play of game #%d ...' % (worker_id, time.time(), game_no,))
        play_and_train(net, game_no, disp=disp)
        if not snapshot_interval or played % snapshot_interval != 0:
            continue
        snapshot_id = '%s_%09d' % (net.model_name(), played)
        print(snapshot_id)
        net.save(snapshot_id)
def selfplay(net, disp=True):
    """ run self-play in six worker processes per CPU and wait for them """
    n_cpus = multiprocessing.cpu_count()
    n_workers = n_cpus * 6
    # group up parallel predict requests
    net.stash_size(max(n_cpus, 1))

    def worker(worker_id, **extra):
        return Process(target=selfplay_singlethread,
                       kwargs=dict(net=net, worker_id=worker_id, **extra))

    # First process is verbose and snapshots the model, the rest work silently
    processes = [worker(0, disp=disp)]
    processes += [worker(wid, snapshot_interval=None) for wid in range(1, n_workers)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
def gather_positions(filename, subsample=16):
    """ replay the game in the given SGF file and return a pair of (list of
    subsample positions sampled from it, randomly flipped; 1 if black won
    the game, else -1). Each position has data['next'] set to the move that
    followed it. Raises ValueError on a board size mismatch, a handicap
    game or an invalid move """
    from gomill import sgf
    with open(filename) as sgf_file:
        game = sgf.Sgf_game.from_string(sgf_file.read())
    if game.get_size() != N:
        raise ValueError('size mismatch')
    if game.get_handicap() is not None:
        raise ValueError('handicap game')
    score = 1 if game.get_winner() == 'B' else -1

    by_to_play = ([], [])  # black-to-play, white-to-play
    pos = empty_position()
    for node in game.get_main_sequence()[1:]:
        _, move = node.get_move()
        if move is None:
            pos.data['next'] = None
            pos = pos.pass_move()
        else:
            c = (move[0]+1) * W + move[1]+1
            pos.data['next'] = c
            pos = pos.move(c)
        if pos is None:
            raise ValueError('invalid move %s' % (move,))
        by_to_play[pos.n % 2].append(pos)
    pos.data['next'] = None

    # subsample positions, half of them for each color to play
    half = subsample // 2
    black_sample = random.sample(by_to_play[0], half)
    white_sample = random.sample(by_to_play[1], half)
    # alternate positions and randomly rotate
    positions = [p for pair in zip(black_sample, white_sample) for p in pair]
    flipped = [p.flip_random() for p in positions]
    return (flipped, score)
def position_dist(net, worker_id, pos, disp=False):
    """ search the position with N_SIMS simulations as worker worker_id
    and return the visit distribution of the root """
    net.ri = worker_id
    root = TreeNode(net=net, pos=pos)
    root.expand()
    tree_search(root, N_SIMS, [0] * (W*W), disp=disp)
    return root.distribution()
def position_distnext(pos):
    """ return a one-hot distribution (N*N + 1 entries, pass last) marking
    the move stored in pos.data['next'] """
    target = np.zeros(N * N + 1)
    nxt = pos.data['next']
    if nxt is None:
        target[-1] = 1
        return target
    x, y = nxt % W - 1, nxt // W - 1
    target[y * N + x] = 1
    return target
def replay_train(net, snapshot_interval=500, continuous_predict=False, disp=True):
    """ train net on recorded games: SGF file names are read from stdin,
    one per line; files that gather_positions() rejects with ValueError are
    skipped.

    With continuous_predict the training target of each sampled position is
    the distribution found by a tree search, otherwise it is the move
    actually played. The model is saved every snapshot_interval games
    (if set) and once more when the input ends """
    n_workers = multiprocessing.cpu_count()
    # group up parallel predict requests
    # net.stash_size(max(2, 1))  # XXX not all workers will always be busy
    for i, f in enumerate(sys.stdin):
        f = f.rstrip()
        print('[%d] %s' % (i, f))
        try:
            positions, score = gather_positions(f, subsample=16)
        except ValueError:
            print('SKIP')
            import traceback
            traceback.print_exc()
            continue
        if continuous_predict:
            # position_dist() takes (net, worker_id, pos, disp); the index
            # of the position doubles as the worker / result queue id.
            # NOTE(review): net has to be usable from the joblib workers -
            # confirm with the joblib backend in use.
            dist = Parallel(n_jobs=n_workers, verbose=100)(
                delayed(position_dist)(net, worker_id, pos, disp)
                for worker_id, pos in enumerate(positions))
        else:
            dist = [position_distnext(pos) for pos in positions]
        X_positions = list(zip(positions, dist))
        if disp:
            print_pos(X_positions[0][0], sys.stdout, None)
        net.fit_game(X_positions, score)
        if snapshot_interval and i > 0 and i % snapshot_interval == 0:
            snapshot_id = '%s_R%09d' % (net.model_name(), i)
            print(snapshot_id)
            net.save(snapshot_id)
    snapshot_id = '%s_Rfinal' % (net.model_name(),)
    print(snapshot_id)
    net.save(snapshot_id)
def game_io(net, computer_black=False):
    """ A simple minimalistic text mode UI: the user and the engine
    alternate moves until two passes in a row or until the engine resigns;
    with computer_black the engine makes the first move """
    tree = TreeNode(net=net, pos=empty_position())
    tree.expand()
    owner_map = W*W*[0]
    while True:
        if not (tree.pos.n == 0 and computer_black):
            print_pos(tree.pos, sys.stdout, owner_map)

            sc = input("Your move: ")
            try:
                c = parse_coord(sc)
            except (ValueError, IndexError):
                # not of the form <column letter><row number>
                print('Bad move (cannot parse coordinate)')
                continue
            # The children are needed to find the node of the user's move
            if tree.children is None:
                tree.expand()
            if c is not None:
                # Not a pass
                if not 0 <= c < len(tree.pos.board) or tree.pos.board[c] != '.':
                    print('Bad move (not empty point)')
                    continue

                # Find the next node in the game tree and proceed there
                nodes = [n for n in tree.children if n.pos.last == c]
                if not nodes:
                    print('Bad move (rule violation)')
                    continue
                tree = nodes[0]

            else:
                # Pass move; reuse the pass child wherever it is among the
                # children (expand() adds it last, and only sometimes)
                pass_nodes = [n for n in tree.children if n.pos.last is None]
                if pass_nodes:
                    tree = pass_nodes[0]
                else:
                    tree = TreeNode(net=net, pos=tree.pos.pass_move())

            print_pos(tree.pos)

        owner_map = W*W*[0]
        tree = tree_search(tree, N_SIMS, owner_map, disp=True)
        if tree.pos.last is None and tree.pos.last2 is None:
            score = tree.pos.score()
            if tree.pos.n % 2:
                score = -score
            print('Game over, score: B%+.1f' % (score,))
            break
        if float(tree.w)/tree.v < RESIGN_THRES and tree.pos.n > 10:
            print('I resign.')
            break
    print('Thank you for the game!')
def gtp_io(net):
""" GTP interface for our program. We can play only on the board size
which is configured (N), and we ignore color information and assume
alternating play! """
known_commands = ['boardsize', 'clear_board', 'komi', 'play', 'genmove',
'final_score', 'quit', 'name', 'version', 'known_command',
'list_commands', 'protocol_version', 'tsdebug']
tree = TreeNode(net=net, pos=empty_position())
tree.expand()
while True:
try:
line = input().strip()
except EOFError:
break
if line == '':
continue
command = [s.lower() for s in line.split()]
if re.match('\d+', command[0]):
cmdid = command[0]
command = command[1:]
else:
cmdid = ''
owner_map = W*W*[0]
ret = ''
if command[0] == "boardsize":
if int(command[1]) != N:
print("Warning: Trying to set incompatible boardsize %s (!= %d)" % (command[1], N), file=sys.stderr)
ret = None
elif command[0] == "clear_board":
tree = TreeNode(net=net, pos=empty_position())
tree.expand()
elif command[0] == "komi":
# XXX: can we do this nicer
tree.pos = Position(board=tree.pos.board, cap=(tree.pos.cap[0], tree.pos.cap[1]),
n=tree.pos.n, ko=tree.pos.ko, last=tree.pos.last, last2=tree.pos.last2,
komi=float(command[1]), data=dict())
elif command[0] == "play":
c = parse_coord(command[2])
if c is not None:
# Find the next node in the game tree and proceed there
if tree.children is not None and filter(lambda n: n.pos.last == c, tree.children):
tree = list(filter(lambda n: n.pos.last == c, tree.children))[0]
else:
# Several play commands in row, eye-filling move, etc.
tree = TreeNode(net=net, pos=tree.pos.move(c))