From f602e47db8bbc033de3ece52bcdeb15ee0361914 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 19 Jul 2017 18:02:25 +0800 Subject: [PATCH] fix leak reuse --- digits/tools/tensorflow/model.py | 83 ++++++++++++++++---------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/digits/tools/tensorflow/model.py b/digits/tools/tensorflow/model.py index 310e300c..eee4e3af 100644 --- a/digits/tools/tensorflow/model.py +++ b/digits/tools/tensorflow/model.py @@ -132,49 +132,50 @@ def create_model(self, obj_UserModel, stage_scope, batch_x=None): # Run the user model through the build_model function that should be filled in grad_towers = [] + with tf.variable_scope(tf.get_variable_scope()): for dev_i, dev_name in enumerate(available_devices): with tf.device(dev_name): - current_scope = stage_scope if len(available_devices) == 1 else ('tower_%d' % dev_i) - with tf.name_scope(current_scope) as scope_tower: - - if self.stage != digits.STAGE_INF: - tower_model = self.add_tower(obj_tower=obj_UserModel, - x=batch_x_split[dev_i], - y=batch_y_split[dev_i]) - else: - tower_model = self.add_tower(obj_tower=obj_UserModel, - x=batch_x_split[dev_i], - y=None) - - with tf.variable_scope(digits.GraphKeys.MODEL, reuse=dev_i > 0): - tower_model.inference # touch to initialize - - if self.stage == digits.STAGE_INF: - # For inferencing we will only use the inference part of the graph - continue - - with tf.name_scope(digits.GraphKeys.LOSS): - for loss in self.get_tower_losses(tower_model): - tf.add_to_collection(digits.GraphKeys.LOSSES, loss['loss']) - - # Assemble all made within this scope so far. The user can add custom - # losses to the digits.GraphKeys.LOSSES collection - losses = tf.get_collection(digits.GraphKeys.LOSSES, scope=scope_tower) - losses += ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope=None) - tower_loss = tf.add_n(losses, name='loss') - - self.summaries.append(tf.summary.scalar(tower_loss.op.name, tower_loss)) - - # Reuse the variables in this scope for the next tower/device - tf.get_variable_scope().reuse_variables() - - if self.stage == digits.STAGE_TRAIN: - grad_tower_losses = [] - for loss in self.get_tower_losses(tower_model): - grad_tower_loss = self.optimizer.compute_gradients(loss['loss'], loss['vars']) - grad_tower_loss = tower_model.gradientUpdate(grad_tower_loss) - grad_tower_losses.append(grad_tower_loss) - grad_towers.append(grad_tower_losses) + current_scope = stage_scope if len(available_devices) == 1 else ('tower_%d' % dev_i) + with tf.name_scope(current_scope) as scope_tower: + + if self.stage != digits.STAGE_INF: + tower_model = self.add_tower(obj_tower=obj_UserModel, + x=batch_x_split[dev_i], + y=batch_y_split[dev_i]) + else: + tower_model = self.add_tower(obj_tower=obj_UserModel, + x=batch_x_split[dev_i], + y=None) + + with tf.variable_scope(digits.GraphKeys.MODEL, reuse=dev_i > 0): + tower_model.inference # touch to initialize + + if self.stage == digits.STAGE_INF: + # For inferencing we will only use the inference part of the graph + continue + + with tf.name_scope(digits.GraphKeys.LOSS): + for loss in self.get_tower_losses(tower_model): + tf.add_to_collection(digits.GraphKeys.LOSSES, loss['loss']) + + # Assemble all made within this scope so far. The user can add custom + # losses to the digits.GraphKeys.LOSSES collection + losses = tf.get_collection(digits.GraphKeys.LOSSES, scope=scope_tower) + losses += ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope=None) + tower_loss = tf.add_n(losses, name='loss') + + self.summaries.append(tf.summary.scalar(tower_loss.op.name, tower_loss)) + + # Reuse the variables in this scope for the next tower/device + tf.get_variable_scope().reuse_variables() + + if self.stage == digits.STAGE_TRAIN: + grad_tower_losses = [] + for loss in self.get_tower_losses(tower_model): + grad_tower_loss = self.optimizer.compute_gradients(loss['loss'], loss['vars']) + grad_tower_loss = tower_model.gradientUpdate(grad_tower_loss) + grad_tower_losses.append(grad_tower_loss) + grad_towers.append(grad_tower_losses) # Assemble and average the gradients from all towers if self.stage == digits.STAGE_TRAIN: