diff --git a/digits/tools/tensorflow/model.py b/digits/tools/tensorflow/model.py index a19df6309..6743f5291 100644 --- a/digits/tools/tensorflow/model.py +++ b/digits/tools/tensorflow/model.py @@ -144,16 +144,17 @@ def create_model(self, obj_UserModel, stage_scope, batch_x=None): current_scope = stage_scope if len(available_devices) == 1 else ('tower_%d' % dev_i) with tf.name_scope(current_scope) as scope_tower: - if self.stage != digits.STAGE_INF: - tower_model = self.add_tower(obj_tower=obj_UserModel, - x=batch_x_split[dev_i], - y=batch_y_split[dev_i]) - else: - tower_model = self.add_tower(obj_tower=obj_UserModel, - x=batch_x_split[dev_i], - y=None) - with tf.variable_scope(digits.GraphKeys.MODEL, reuse=dev_i > 0 or self._reuse): + + if self.stage != digits.STAGE_INF: + tower_model = self.add_tower(obj_tower=obj_UserModel, + x=batch_x_split[dev_i], + y=batch_y_split[dev_i]) + else: + tower_model = self.add_tower(obj_tower=obj_UserModel, + x=batch_x_split[dev_i], + y=None) + tower_model.inference # touch to initialize # Reuse the variables in this scope for the next tower/device