Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix leak reuse #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 42 additions & 41 deletions digits/tools/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,49 +132,50 @@ def create_model(self, obj_UserModel, stage_scope, batch_x=None):

# Run the user model through the build_model function that should be filled in
grad_towers = []
with tf.variable_scope(tf.get_variable_scope()):
for dev_i, dev_name in enumerate(available_devices):
with tf.device(dev_name):
current_scope = stage_scope if len(available_devices) == 1 else ('tower_%d' % dev_i)
with tf.name_scope(current_scope) as scope_tower:

if self.stage != digits.STAGE_INF:
tower_model = self.add_tower(obj_tower=obj_UserModel,
x=batch_x_split[dev_i],
y=batch_y_split[dev_i])
else:
tower_model = self.add_tower(obj_tower=obj_UserModel,
x=batch_x_split[dev_i],
y=None)

with tf.variable_scope(digits.GraphKeys.MODEL, reuse=dev_i > 0):
tower_model.inference # touch to initialize

if self.stage == digits.STAGE_INF:
# For inferencing we will only use the inference part of the graph
continue

with tf.name_scope(digits.GraphKeys.LOSS):
for loss in self.get_tower_losses(tower_model):
tf.add_to_collection(digits.GraphKeys.LOSSES, loss['loss'])

# Assemble all made within this scope so far. The user can add custom
# losses to the digits.GraphKeys.LOSSES collection
losses = tf.get_collection(digits.GraphKeys.LOSSES, scope=scope_tower)
losses += ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope=None)
tower_loss = tf.add_n(losses, name='loss')

self.summaries.append(tf.summary.scalar(tower_loss.op.name, tower_loss))

# Reuse the variables in this scope for the next tower/device
tf.get_variable_scope().reuse_variables()

if self.stage == digits.STAGE_TRAIN:
grad_tower_losses = []
for loss in self.get_tower_losses(tower_model):
grad_tower_loss = self.optimizer.compute_gradients(loss['loss'], loss['vars'])
grad_tower_loss = tower_model.gradientUpdate(grad_tower_loss)
grad_tower_losses.append(grad_tower_loss)
grad_towers.append(grad_tower_losses)
current_scope = stage_scope if len(available_devices) == 1 else ('tower_%d' % dev_i)
with tf.name_scope(current_scope) as scope_tower:

if self.stage != digits.STAGE_INF:
tower_model = self.add_tower(obj_tower=obj_UserModel,
x=batch_x_split[dev_i],
y=batch_y_split[dev_i])
else:
tower_model = self.add_tower(obj_tower=obj_UserModel,
x=batch_x_split[dev_i],
y=None)

with tf.variable_scope(digits.GraphKeys.MODEL, reuse=dev_i > 0):
tower_model.inference # touch to initialize

if self.stage == digits.STAGE_INF:
# For inferencing we will only use the inference part of the graph
continue

with tf.name_scope(digits.GraphKeys.LOSS):
for loss in self.get_tower_losses(tower_model):
tf.add_to_collection(digits.GraphKeys.LOSSES, loss['loss'])

# Assemble all made within this scope so far. The user can add custom
# losses to the digits.GraphKeys.LOSSES collection
losses = tf.get_collection(digits.GraphKeys.LOSSES, scope=scope_tower)
losses += ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope=None)
tower_loss = tf.add_n(losses, name='loss')

self.summaries.append(tf.summary.scalar(tower_loss.op.name, tower_loss))

# Reuse the variables in this scope for the next tower/device
tf.get_variable_scope().reuse_variables()

if self.stage == digits.STAGE_TRAIN:
grad_tower_losses = []
for loss in self.get_tower_losses(tower_model):
grad_tower_loss = self.optimizer.compute_gradients(loss['loss'], loss['vars'])
grad_tower_loss = tower_model.gradientUpdate(grad_tower_loss)
grad_tower_losses.append(grad_tower_loss)
grad_towers.append(grad_tower_losses)

# Assemble and average the gradients from all towers
if self.stage == digits.STAGE_TRAIN:
Expand Down