Skip to content

Commit

Permalink
[Dper2] Add NetModifier abstraction and support for plotting the norm…
Browse files Browse the repository at this point in the history
… of blobs (pytorch#2201)
  • Loading branch information
chocjy authored Mar 8, 2018
1 parent d90cd73 commit f4b1e8b
Show file tree
Hide file tree
Showing 7 changed files with 425 additions and 0 deletions.
41 changes: 41 additions & 0 deletions caffe2/python/layer_model_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from caffe2.python.modeling.parameter_sharing import (
parameter_sharing_context,
)
from caffe2.python.modeling.net_modifier import NetModifier

from caffe2.python.optimizer import get_param_device
from caffe2.python.regularizer import Regularizer
from caffe2.python.layers import layers
Expand Down Expand Up @@ -74,6 +76,9 @@ def __init__(self, name, input_feature_schema, trainer_extra_schema,
self._loss = None
self._output_schema = None

self._post_grad_net_modifiers = []
self._final_net_modifiers = []

# breakdown map; breakdown features are categorical (like dense) but not
# necessarily used to represent data for training
self._breakdown_map = None
Expand Down Expand Up @@ -326,6 +331,20 @@ def get_parameter_blobs(self):

return param_blobs

def add_post_grad_net_modifiers(self, modifier):
assert modifier not in self._post_grad_net_modifiers,\
"{0} is already in {1}".format(modifier, self._post_grad_net_modifiers)
assert isinstance(modifier, NetModifier),\
"{} has to be a NetModifier instance".format(modifier)
self._post_grad_net_modifiers.append(modifier)

def add_final_net_modifiers(self, modifier):
assert modifier not in self._final_net_modifiers,\
"{0} is already in {1}".format(modifier, self._final_net_modifiers)
assert isinstance(modifier, NetModifier),\
"{} has to be a NetModifier instance".format(modifier)
self._final_net_modifiers.append(modifier)

@property
def seed(self):
return self._seed
Expand Down Expand Up @@ -488,6 +507,28 @@ def apply_regularizers_after_optimizer(
regularizer(
train_net, train_init_net, param, grad_map.get(str(param)))

def apply_post_grad_net_modifiers(
self,
trainer_net,
trainer_init_net,
grad_map,
blob_to_device=None,
):
for modifier in self._post_grad_net_modifiers:
modifier(trainer_net, trainer_init_net, grad_map,
blob_to_device=blob_to_device)

def apply_final_net_modifiers(
self,
trainer_net,
trainer_init_net,
grad_map,
blob_to_device=None,
):
for modifier in self._final_net_modifiers:
modifier(trainer_net, trainer_init_net, grad_map,
blob_to_device=blob_to_device)

def apply_optimizers(
self,
train_net,
Expand Down
3 changes: 3 additions & 0 deletions caffe2/python/layer_model_instantiator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def generate_training_nets(model, include_tags=None):
model.apply_regularizers_on_loss(train_net, train_init_net)
loss = model.loss
grad_map = train_net.AddGradientOperators(loss.field_blobs())
model.apply_post_grad_net_modifiers(train_net, train_init_net, grad_map)
model.apply_optimizers(train_net, train_init_net, grad_map)
model.apply_regularizers_after_optimizer(train_net, train_init_net, grad_map)
model.apply_final_net_modifiers(train_net, train_init_net, grad_map)

return train_init_net, train_net
69 changes: 69 additions & 0 deletions caffe2/python/modeling/compute_norm_for_blobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.modeling.net_modifier import NetModifier

import numpy as np


class ComputeNormForBlobs(NetModifier):
"""
This class modifies the net passed in by adding ops to compute norms for
certain blobs.
Args:
blobs: list of blobs to compute norm for
logging_frequency: frequency for printing norms to logs
p: type of norm. Currently it supports p=1 or p=2
"""

def __init__(self, blobs, logging_frequency, p=2):
self._blobs = blobs
self._logging_frequency = logging_frequency
self._p = p

def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):

p = self._p

for blob_name in self._blobs:
blob = core.BlobReference(blob_name)
if not net.BlobIsDefined(blob):
raise Exception('blob {0} is not defined in net {1}'.format(
blob, net.Name()))

norm_name = net.NextScopedBlob(prefix=blob + '_l{}_norm'.format(p))
norm = net.LpNorm(blob, norm_name, p=p)

if self._logging_frequency >= 1:
net.Print(norm, [], every_n=self._logging_frequency)

output_field_name = str(blob) + '_l{}_norm'.format(p)
output_scalar = schema.Scalar((np.float, (1,)), norm)

if net.output_record() is None:
net.set_output_record(
schema.Struct((output_field_name, output_scalar))
)
else:
net.AppendOutputRecordField(
output_field_name,
output_scalar)
133 changes: 133 additions & 0 deletions caffe2/python/modeling/compute_norm_for_blobs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
from caffe2.python import workspace, brew, model_helper
from caffe2.python.modeling.compute_norm_for_blobs import ComputeNormForBlobs

import numpy as np


class ComputeNormForBlobsTest(unittest.TestCase):
def test_compute_norm_for_blobs(self):
model = model_helper.ModelHelper(name="test")
data = model.net.AddExternalInput("data")
fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

# no operator name set, will use default
brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

net_modifier = ComputeNormForBlobs(
blobs=['fc1_w', 'fc2_w'],
logging_frequency=10,
)

net_modifier(model.net)

workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))

workspace.RunNetOnce(model.param_init_net)
workspace.RunNetOnce(model.net)

fc1_w = workspace.FetchBlob('fc1_w')
fc1_w_l2_norm = workspace.FetchBlob('fc1_w_l2_norm')

self.assertEqual(fc1_w_l2_norm.size, 1)
self.assertAlmostEqual(fc1_w_l2_norm[0],
np.linalg.norm(fc1_w)**2,
delta=1e-5)

self.assertEqual(len(model.net.Proto().op), 6)

assert 'fc1_w_l2_norm' in model.net.output_record().field_blobs(),\
model.net.output_record().field_blobs()
assert 'fc2_w_l2_norm' in model.net.output_record().field_blobs(),\
model.net.output_record().field_blobs()

def test_compute_norm_for_blobs_no_print(self):
model = model_helper.ModelHelper(name="test")
data = model.net.AddExternalInput("data")
fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

# no operator name set, will use default
brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

net_modifier = ComputeNormForBlobs(
blobs=['fc1_w', 'fc2_w'],
logging_frequency=-1,
)

net_modifier(model.net)

workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))

workspace.RunNetOnce(model.param_init_net)
workspace.RunNetOnce(model.net)

fc1_w = workspace.FetchBlob('fc1_w')
fc1_w_l2_norm = workspace.FetchBlob('fc1_w_l2_norm')

self.assertEqual(fc1_w_l2_norm.size, 1)
self.assertAlmostEqual(fc1_w_l2_norm[0],
np.linalg.norm(fc1_w)**2,
delta=1e-5)

self.assertEqual(len(model.net.Proto().op), 4)

assert 'fc1_w_l2_norm' in model.net.output_record().field_blobs(),\
model.net.output_record().field_blobs()
assert 'fc2_w_l2_norm' in model.net.output_record().field_blobs(),\
model.net.output_record().field_blobs()

def test_compute_l1_norm_for_blobs(self):
model = model_helper.ModelHelper(name="test")
data = model.net.AddExternalInput("data")
fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

# no operator name set, will use default
brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

net_modifier = ComputeNormForBlobs(
blobs=['fc1_w', 'fc2_w'],
logging_frequency=10,
p=1,
)

net_modifier(model.net)

workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))

workspace.RunNetOnce(model.param_init_net)
workspace.RunNetOnce(model.net)

fc1_w = workspace.FetchBlob('fc1_w')
fc1_w_l1_norm = workspace.FetchBlob('fc1_w_l1_norm')

self.assertEqual(fc1_w_l1_norm.size, 1)
self.assertAlmostEqual(fc1_w_l1_norm[0],
np.sum(np.abs(fc1_w)),
delta=1e-5)

self.assertEqual(len(model.net.Proto().op), 6)

assert 'fc1_w_l1_norm' in model.net.output_record().field_blobs(),\
model.net.output_record().field_blobs()
assert 'fc2_w_l1_norm' in model.net.output_record().field_blobs(),\
model.net.output_record().field_blobs()
65 changes: 65 additions & 0 deletions caffe2/python/modeling/compute_statistics_for_blobs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.modeling.net_modifier import NetModifier

import numpy as np


class ComputeStatisticsForBlobs(NetModifier):
"""
This class modifies the net passed in by adding ops to compute statistics
for certain blobs. For each blob in the list, its min, max, mean and standard
deviation will be computed.
Args:
blobs: list of blobs to compute norm for
logging_frequency: frequency for printing norms to logs
"""

def __init__(self, blobs, logging_frequency):
self._blobs = blobs
self._logging_frequency = logging_frequency

def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):

for blob_name in self._blobs:
blob = core.BlobReference(blob_name)
if not net.BlobIsDefined(blob):
raise Exception('blob {0} is not defined in net {1}'.format(
blob, net.Name()))

cast_blob = net.Cast(blob, to=core.DataType.FLOAT)
stats_name = net.NextScopedBlob(prefix=blob + '_summary')
stats = net.Summarize(cast_blob, stats_name, to_file=0)
net.Print(stats, [], every_n=self._logging_frequency)

output_field_name = str(blob) + '_summary'
output_scalar = schema.Scalar((np.float, (1,)), stats)

if net.output_record() is None:
net.set_output_record(
schema.Struct((output_field_name, output_scalar))
)
else:
net.AppendOutputRecordField(
output_field_name,
output_scalar)
Loading

0 comments on commit f4b1e8b

Please sign in to comment.