diff --git a/tf/net.py b/tf/net.py
index de779ca0..7f4d9787 100755
--- a/tf/net.py
+++ b/tf/net.py
@@ -49,6 +49,7 @@ def __init__(self,
         self.set_policyformat(policy)
         self.set_valueformat(value)
         self.set_movesleftformat(moves_left)
+        self.set_uncformat(pb.NetworkFormat.POLICY_UNC_CONVOLUTION)
 
     def set_networkformat(self, net):
         self.pb.format.network_format.network = net
@@ -68,6 +69,9 @@ def set_valueformat(self, value):
     def set_movesleftformat(self, moves_left):
         self.pb.format.network_format.moves_left = moves_left
 
+    def set_uncformat(self, unc):
+        self.pb.format.network_format.unc = unc
+
     def set_input(self, input_format):
         self.pb.format.network_format.input = input_format
         if input_format == pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2 or input_format == pb.NetworkFormat.INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON:
@@ -309,6 +313,10 @@ def moves_left_to_bp(l, w):
             pb_name = moves_left_to_bp(layers[1], weights_name)
         else:
             pb_name = 'moves_left.' + convblock_to_bp(weights_name)
+    elif base_layer == 'uncertainty1':
+        pb_name = 'unc1.' + convblock_to_bp(weights_name)
+    elif base_layer == 'uncertainty':
+        pb_name = 'unc.' + convblock_to_bp(weights_name)
     elif base_layer.startswith('residual'):
         block = int(base_layer.split('_')[1]) - 1  # 1 indexed
         if layers[1] == '1':
diff --git a/tf/tfprocess.py b/tf/tfprocess.py
index 2c41a4d4..950b5779 100644
--- a/tf/tfprocess.py
+++ b/tf/tfprocess.py
@@ -211,11 +211,11 @@ def init_net_v2(self):
         self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))
         input_var = tf.keras.Input(shape=(112, 8 * 8))
         x_planes = tf.keras.layers.Reshape([112, 8, 8])(input_var)
-        policy, value, moves_left = self.construct_net_v2(x_planes)
+        policy, value, moves_left, unc = self.construct_net_v2(x_planes)
         if self.moves_left:
-            outputs = [policy, value, moves_left]
+            outputs = [policy, value, moves_left, unc]
         else:
-            outputs = [policy, value]
+            outputs = [policy, value, unc]
         self.model = tf.keras.Model(inputs=input_var, outputs=outputs)
 
         # swa_count initialized reguardless to make checkpoint code simpler.
@@ -355,6 +355,52 @@ def moves_left_loss(target, output):
 
         self.moves_left_loss_fn = moves_left_loss
 
+        def correct_unc(target, output):
+            output = tf.cast(output, tf.float32)
+            counts = tf.ones_like(output) * 1858.
+            # Mask the uncertainty outputs the same way correct_policy masks the policy head.
+            if self.cfg['training'].get('mask_legal_moves'):
+                # extract mask for legal moves from target policy
+                move_is_legal = tf.greater_equal(target, 0)
+                counts = tf.reduce_sum(tf.cast(move_is_legal, tf.float32), axis=-1, keepdims=True)
+                # zero out the outputs for illegal moves (so they do not contribute to the loss), without gradient
+                illegal_filler = tf.zeros_like(output)
+                output = tf.where(move_is_legal, output, illegal_filler)
+            return output, counts
+
+        def unc_loss(target, output, unc_output):
+            unc_output, counts = correct_unc(target, unc_output)
+            target, output = correct_policy(target, output)
+            softmax_policy = tf.nn.softmax(tf.stop_gradient(output))
+            calculated_target = tf.stop_gradient(target) - softmax_policy
+            # Loss is scaled up by 29, the average number of legal moves, so it is of a similar order to the mean total loss per position rather than per legal output.
+            # It is then scaled down by 4 because it otherwise generates too much gradient.
+            loss = 29 * tf.math.squared_difference(calculated_target, unc_output) / 4
+            return tf.reduce_sum(loss) / tf.reduce_sum(tf.stop_gradient(counts))
+
+        self.unc_loss_fn = unc_loss
+
+        def unc_mae(target, output, unc_output):
+            unc_output, counts = correct_unc(target, unc_output)
+            target, output = correct_policy(target, output)
+            softmax_policy = tf.nn.softmax(tf.stop_gradient(output))
+            calculated_target = tf.stop_gradient(target) - softmax_policy
+            loss = tf.math.abs(calculated_target - unc_output)
+            return tf.reduce_sum(loss) / tf.reduce_sum(tf.stop_gradient(counts))
+
+        self.unc_mae_fn = unc_mae
+
+        def unc_mpi(target, output, unc_output):
+            unc_output, counts = correct_unc(target, unc_output)
+            target, output = correct_policy(target, output)
+            softmax_policy = tf.nn.softmax(tf.stop_gradient(output))
+            calculated_target = tf.stop_gradient(target) - softmax_policy
+            # The maximum is reached when unc_output == calculated_target, where the value is abs(calculated_target); it is 0 when unc_output is 0. Hence abs(calculated_target) - abs(calculated_target - unc_output).
+            loss = tf.math.abs(calculated_target) - tf.math.abs(calculated_target - unc_output)
+            return tf.reduce_sum(loss) / tf.reduce_sum(tf.stop_gradient(counts))
+
+        self.unc_mpi_fn = unc_mpi
+
         pol_loss_w = self.cfg['training']['policy_loss_weight']
         val_loss_w = self.cfg['training']['value_loss_weight']
 
@@ -381,6 +427,7 @@ def accuracy(target, output):
         self.avg_value_loss = []
         self.avg_moves_left_loss = []
         self.avg_mse_loss = []
+        self.avg_unc_loss = []
         self.avg_reg_term = []
         self.time_start = None
         self.last_steps = None
@@ -559,20 +606,21 @@ def process_inner_loop(self, x, y, z, q, m):
                 moves_left_loss = self.moves_left_loss_fn(m, moves_left)
             else:
                 moves_left_loss = tf.constant(0.)
+            unc_loss = self.unc_loss_fn(y, policy, outputs[3])
             total_loss = self.lossMix(policy_loss, value_loss,
-                                      moves_left_loss) + reg_term
+                                      moves_left_loss) + reg_term + unc_loss
             if self.loss_scale != 1:
                 total_loss = self.optimizer.get_scaled_loss(total_loss)
         if self.wdl:
             mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
         else:
             value_loss = self.value_loss_fn(self.qMix(z, q), value)
-        return policy_loss, value_loss, mse_loss, moves_left_loss, reg_term, tape.gradient(
+        return policy_loss, value_loss, mse_loss, moves_left_loss, unc_loss, reg_term, tape.gradient(
             total_loss, self.model.trainable_weights)
 
     @tf.function()
     def strategy_process_inner_loop(self, x, y, z, q, m):
-        policy_loss, value_loss, mse_loss, moves_left_loss, reg_term, new_grads = self.strategy.run(
+        policy_loss, value_loss, mse_loss, moves_left_loss, unc_loss, reg_term, new_grads = self.strategy.run(
             self.process_inner_loop, args=(x, y, z, q, m))
         policy_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                            policy_loss,
@@ -586,10 +634,13 @@ def strategy_process_inner_loop(self, x, y, z, q, m):
         moves_left_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                                moves_left_loss,
                                                axis=None)
+        unc_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
+                                        unc_loss,
+                                        axis=None)
         reg_term = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                         reg_term,
                                         axis=None)
-        return policy_loss, value_loss, mse_loss, moves_left_loss, reg_term, new_grads
+        return policy_loss, value_loss, mse_loss, moves_left_loss, unc_loss, reg_term, new_grads
 
     def apply_grads(self, grads, effective_batch_splits):
         grads = [
@@ -635,10 +686,10 @@ def train_step(self, steps, batch_size, batch_splits):
         for _ in range(batch_splits):
             x, y, z, q, m = next(self.train_iter)
             if self.strategy is not None:
-                policy_loss, value_loss, mse_loss, moves_left_loss, reg_term, new_grads = self.strategy_process_inner_loop(
+                policy_loss, value_loss, mse_loss, moves_left_loss, unc_loss, reg_term, new_grads = self.strategy_process_inner_loop(
                     x, y, z, q, m)
             else:
-                policy_loss, value_loss, mse_loss, moves_left_loss, reg_term, new_grads = self.process_inner_loop(
+                policy_loss, value_loss, mse_loss, moves_left_loss, unc_loss, reg_term, new_grads = self.process_inner_loop(
                     x, y, z, q, m)
             if not grads:
                 grads = new_grads
@@ -656,6 +707,7 @@ def train_step(self, steps, batch_size, batch_splits):
             self.avg_value_loss.append(value_loss)
             if self.moves_left:
                 self.avg_moves_left_loss.append(moves_left_loss)
+            self.avg_unc_loss.append(unc_loss)
             self.avg_mse_loss.append(mse_loss)
             self.avg_reg_term.append(reg_term)
         # Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.
@@ -694,15 +746,16 @@ def train_step(self, steps, batch_size, batch_splits):
             avg_moves_left_loss = np.mean(self.avg_moves_left_loss or [0])
             avg_value_loss = np.mean(self.avg_value_loss or [0])
             avg_mse_loss = np.mean(self.avg_mse_loss or [0])
+            avg_unc_loss = np.mean(self.avg_unc_loss or [0])
             avg_reg_term = np.mean(self.avg_reg_term or [0])
             print(
-                "step {}, lr={:g} policy={:g} value={:g} mse={:g} moves={:g} reg={:g} total={:g} ({:g} pos/s)"
+                "step {}, lr={:g} policy={:g} value={:g} mse={:g} moves={:g} unc={:g} reg={:g} total={:g} ({:g} pos/s)"
                 .format(
                     steps, self.lr, avg_policy_loss, avg_value_loss,
-                    avg_mse_loss, avg_moves_left_loss, avg_reg_term,
+                    avg_mse_loss, avg_moves_left_loss, avg_unc_loss, avg_reg_term,
                     pol_loss_w * avg_policy_loss +
                     val_loss_w * avg_value_loss + avg_reg_term +
-                    moves_loss_w * avg_moves_left_loss, speed))
+                    moves_loss_w * avg_moves_left_loss + avg_unc_loss, speed))
 
             after_weights = self.read_weights()
             with self.train_writer.as_default():
@@ -712,6 +765,7 @@
                 tf.summary.scalar("Moves Left Loss",
                                   avg_moves_left_loss,
                                   step=steps)
+                tf.summary.scalar("Uncertainty Loss", avg_unc_loss, step=steps)
                 tf.summary.scalar("Reg term", avg_reg_term, step=steps)
                 tf.summary.scalar("LR", self.lr, step=steps)
                 tf.summary.scalar("Gradient norm",
@@ -727,6 +781,7 @@ def train_step(self, steps, batch_size, batch_splits):
             self.avg_moves_left_loss = []
             self.avg_value_loss = []
             self.avg_mse_loss = []
+            self.avg_unc_loss = []
             self.avg_reg_term = []
 
         return steps
@@ -843,11 +898,14 @@ def calculate_test_summaries_inner_loop(self, x, y, z, q, m):
         else:
             moves_left_loss = tf.constant(0.)
             moves_left_mean_error = tf.constant(0.)
-        return policy_loss, value_loss, moves_left_loss, mse_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul
+        unc_loss = self.unc_loss_fn(y, policy, outputs[3])
+        unc_mae = self.unc_mae_fn(y, policy, outputs[3])
+        unc_mpi = self.unc_mpi_fn(y, policy, outputs[3])
+        return policy_loss, value_loss, moves_left_loss, mse_loss, unc_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul, unc_mae, unc_mpi
 
     @tf.function()
     def strategy_calculate_test_summaries_inner_loop(self, x, y, z, q, m):
-        policy_loss, value_loss, moves_left_loss, mse_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul = self.strategy.run(
+        policy_loss, value_loss, moves_left_loss, mse_loss, unc_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul, unc_mae, unc_mpi = self.strategy.run(
            self.calculate_test_summaries_inner_loop, args=(x, y, z, q, m))
         policy_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                            policy_loss,
@@ -858,6 +916,9 @@ def strategy_calculate_test_summaries_inner_loop(self, x, y, z, q, m):
         mse_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                         mse_loss,
                                         axis=None)
+        unc_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
+                                        unc_loss,
+                                        axis=None)
         policy_accuracy = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                                policy_accuracy,
                                                axis=None)
@@ -875,7 +936,13 @@
         policy_ul = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                          policy_ul,
                                          axis=None)
-        return policy_loss, value_loss, moves_left_loss, mse_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul
+        unc_mae = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
+                                       unc_mae,
+                                       axis=None)
+        unc_mpi = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
+                                       unc_mpi,
+                                       axis=None)
+        return policy_loss, value_loss, moves_left_loss, mse_loss, unc_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul, unc_mae, unc_mpi
 
     def calculate_test_summaries_v2(self, test_batches, steps):
         sum_policy_accuracy = 0
@@ -883,17 +950,20 @@
         sum_moves_left = 0
         sum_moves_left_mean_error = 0
         sum_mse = 0
+        sum_unc = 0
         sum_policy = 0
         sum_value = 0
         sum_policy_entropy = 0
         sum_policy_ul = 0
+        sum_unc_mae = 0
+        sum_unc_mpi = 0
         for _ in range(0, test_batches):
             x, y, z, q, m = next(self.test_iter)
             if self.strategy is not None:
-                policy_loss, value_loss, moves_left_loss, mse_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul = self.strategy_calculate_test_summaries_inner_loop(
+                policy_loss, value_loss, moves_left_loss, mse_loss, unc_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul, unc_mae, unc_mpi = self.strategy_calculate_test_summaries_inner_loop(
                     x, y, z, q, m)
             else:
-                policy_loss, value_loss, moves_left_loss, mse_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul = self.calculate_test_summaries_inner_loop(
+                policy_loss, value_loss, moves_left_loss, mse_loss, unc_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul, unc_mae, unc_mpi = self.calculate_test_summaries_inner_loop(
                     x, y, z, q, m)
             sum_policy_accuracy += policy_accuracy
             sum_policy_entropy += policy_entropy
@@ -906,6 +976,9 @@
             if self.moves_left:
                 sum_moves_left += moves_left_loss
                 sum_moves_left_mean_error += moves_left_mean_error
+            sum_unc += unc_loss
+            sum_unc_mae += unc_mae
+            sum_unc_mpi += unc_mpi
         sum_policy_accuracy /= test_batches
         sum_policy_accuracy *= 100
         sum_policy /= test_batches
@@ -920,6 +993,9 @@
         if self.moves_left:
             sum_moves_left /= test_batches
             sum_moves_left_mean_error /= test_batches
+        sum_unc /= test_batches
+        sum_unc_mae /= test_batches
+        sum_unc_mpi /= test_batches
         self.net.pb.training_params.learning_rate = self.lr
         self.net.pb.training_params.mse_loss = sum_mse
         self.net.pb.training_params.policy_loss = sum_policy
@@ -929,6 +1005,9 @@
             tf.summary.scalar("Policy Loss", sum_policy, step=steps)
             tf.summary.scalar("Value Loss", sum_value, step=steps)
             tf.summary.scalar("MSE Loss", sum_mse, step=steps)
+            tf.summary.scalar("Uncertainty Loss", sum_unc, step=steps)
+            tf.summary.scalar("Uncertainty Mean Error", sum_unc_mae, step=steps)
+            tf.summary.scalar("Uncertainty Mean Policy Improvement", sum_unc_mpi, step=steps)
             tf.summary.scalar("Policy Accuracy",
                               sum_policy_accuracy,
                               step=steps)
@@ -949,8 +1028,8 @@
                 tf.summary.histogram(w.name, w, step=steps)
         self.test_writer.flush()
 
-        print("step {}, policy={:g} value={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g} policy entropy={:g} policy ul={:g}".\
-            format(steps, sum_policy, sum_value, sum_policy_accuracy, sum_value_accuracy, sum_mse, sum_policy_entropy, sum_policy_ul), end = '')
+        print("step {}, policy={:g} value={:g} unc={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g} policy entropy={:g} policy ul={:g} unc_mae={:g} unc_mpi={:g}".\
+            format(steps, sum_policy, sum_value, sum_unc, sum_policy_accuracy, sum_value_accuracy, sum_mse, sum_policy_entropy, sum_policy_ul, sum_unc_mae, sum_unc_mpi), end = '')
 
         if self.moves_left:
             print(" moves={:g} moves mean={:g}".format(
@@ -977,15 +1056,18 @@ def calculate_test_validations_v2(self, steps):
         sum_mse = 0
         sum_policy = 0
         sum_value = 0
+        sum_unc = 0
         sum_policy_entropy = 0
         sum_policy_ul = 0
+        sum_unc_mae = 0
+        sum_unc_mpi = 0
         counter = 0
         for (x, y, z, q, m) in self.validation_dataset:
             if self.strategy is not None:
-                policy_loss, value_loss, moves_left_loss, mse_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul = self.strategy_calculate_test_summaries_inner_loop(
+                policy_loss, value_loss, moves_left_loss, mse_loss, unc_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul, unc_mae, unc_mpi = self.strategy_calculate_test_summaries_inner_loop(
                     x, y, z, q, m)
             else:
-                policy_loss, value_loss, moves_left_loss, mse_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul = self.calculate_test_summaries_inner_loop(
+                policy_loss, value_loss, moves_left_loss, mse_loss, unc_loss, policy_accuracy, value_accuracy, moves_left_mean_error, policy_entropy, policy_ul, unc_mae, unc_mpi = self.calculate_test_summaries_inner_loop(
                     x, y, z, q, m)
             sum_policy_accuracy += policy_accuracy
             sum_policy_entropy += policy_entropy
@@ -995,6 +1077,9 @@
             if self.moves_left:
                 sum_moves_left += moves_left_loss
                 sum_moves_left_mean_error += moves_left_mean_error
+            sum_unc += unc_loss
+            sum_unc_mae += unc_mae
+            sum_unc_mpi += unc_mpi
             counter += 1
             if self.wdl:
                 sum_value_accuracy += value_accuracy
@@ -1011,12 +1096,18 @@
         if self.moves_left:
             sum_moves_left /= counter
             sum_moves_left_mean_error /= counter
+        sum_unc /= counter
+        sum_unc_mae /= counter
+        sum_unc_mpi /= counter
         # Additionally rescale to [0, 1] so divide by 4
         sum_mse /= (4.0 * counter)
         with self.validation_writer.as_default():
             tf.summary.scalar("Policy Loss", sum_policy, step=steps)
             tf.summary.scalar("Value Loss", sum_value, step=steps)
             tf.summary.scalar("MSE Loss", sum_mse, step=steps)
+            tf.summary.scalar("Uncertainty Loss", sum_unc, step=steps)
+            tf.summary.scalar("Uncertainty Mean Error", sum_unc_mae, step=steps)
+            tf.summary.scalar("Uncertainty Mean Policy Improvement", sum_unc_mpi, step=steps)
             tf.summary.scalar("Policy Accuracy",
                               sum_policy_accuracy,
                               step=steps)
@@ -1035,8 +1126,8 @@
                               step=steps)
         self.validation_writer.flush()
 
-        print("step {}, validation: policy={:g} value={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g} policy entropy={:g} policy ul={:g}".\
-            format(steps, sum_policy, sum_value, sum_policy_accuracy, sum_value_accuracy, sum_mse, sum_policy_entropy, sum_policy_ul), end='')
+        print("step {}, validation: policy={:g} value={:g} unc={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g} policy entropy={:g} policy ul={:g}".\
+            format(steps, sum_policy, sum_value, sum_unc, sum_policy_accuracy, sum_value_accuracy, sum_mse, sum_policy_entropy, sum_policy_ul), end='')
 
         if self.moves_left:
             print(" moves={:g} moves mean={:g}".format(
@@ -1272,4 +1363,22 @@ def construct_net_v2(self, inputs):
         else:
             h_fc5 = None
 
-        return h_fc1, h_fc3, h_fc5
+        # Policy Uncertainty head
+        conv_unc = self.conv_block_v2(
+            flow,
+            filter_size=3,
+            output_channels=self.RESIDUAL_FILTERS,
+            name='uncertainty1')
+        conv_unc2 = tf.keras.layers.Conv2D(
+            80,
+            3,
+            use_bias=True,
+            padding='same',
+            kernel_initializer='glorot_normal',
+            kernel_regularizer=self.l2reg,
+            bias_regularizer=self.l2reg,
+            data_format='channels_first',
+            name='uncertainty')(conv_unc)
+        h_fc7 = tf.keras.activations.tanh(ApplyPolicyMap()(conv_unc2))
+
+        return h_fc1, h_fc3, h_fc5, h_fc7
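
Note on the uncertainty metrics introduced above: the head is trained to predict calculated_target = target - softmax(policy_logits), i.e. how far the policy head's softmax output falls short of the training target for each move. The following is a minimal standalone sketch (not part of the patch) of how unc_loss, unc_mae and unc_mpi behave for a single position; the numbers and the 4-move position are made up, and the legal-move masking done by correct_unc/correct_policy is omitted for brevity.

import tensorflow as tf

# Hypothetical visit-count target, policy logits and uncertainty-head output for 4 legal moves.
target = tf.constant([[0.7, 0.2, 0.1, 0.0]])
logits = tf.constant([[2.0, 1.0, 0.5, -1.0]])
unc_output = tf.constant([[0.10, -0.05, 0.00, -0.02]])

softmax_policy = tf.nn.softmax(logits)
calculated_target = target - softmax_policy  # what the uncertainty head is trained to predict
counts = tf.reduce_sum(tf.ones_like(target), axis=-1, keepdims=True)  # 4 legal moves, no masking here

# unc_loss: squared error, scaled up by 29 (average legal-move count) and down by 4.
unc_loss = tf.reduce_sum(29 * tf.math.squared_difference(calculated_target, unc_output) / 4) / tf.reduce_sum(counts)
# unc_mae: mean absolute error of the prediction per legal output.
unc_mae = tf.reduce_sum(tf.math.abs(calculated_target - unc_output)) / tf.reduce_sum(counts)
# unc_mpi: how much of |calculated_target| the prediction recovers; 0 when unc_output is all
# zeros, maximal when unc_output equals calculated_target.
unc_mpi = tf.reduce_sum(tf.math.abs(calculated_target) - tf.math.abs(calculated_target - unc_output)) / tf.reduce_sum(counts)

print(unc_loss.numpy(), unc_mae.numpy(), unc_mpi.numpy())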