Add gradient clipping (pytorch#2452)
As titled.
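For reference, a minimal usage sketch of the new modifier, mirroring the unit tests added in this commit (the toy model and blob names below are illustrative):

from caffe2.python import brew, model_helper, workspace
from caffe2.python.modeling.gradient_clipping import GradientClipping
import numpy as np

# Build a small two-layer model with a scalar loss.
model = model_helper.ModelHelper(name="example")
data = model.net.AddExternalInput("data")
fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)
fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)
sigm = model.net.Sigmoid(fc2, 'sigm')
sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
loss = model.net.SumElements(sq, 'loss')
grad_map = model.AddGradientOperators([loss])

# Clip the dense weight gradients by their L2 norm.
net_modifier = GradientClipping(
    grad_clip_method='by_norm',
    clip_norm_type='l2_norm',
    clip_threshold=0.1,
)
net_modifier(model.net, grad_map={k: grad_map[k] for k in ['fc1_w', 'fc2_w']})

workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))
workspace.RunNetOnce(model.param_init_net)
workspace.RunNetOnce(model.net)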
Showing 6 changed files with 317 additions and 8 deletions.
@@ -0,0 +1,117 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.proto import caffe2_pb2
from caffe2.python.optimizer import get_param_device
from caffe2.python.modeling.net_modifier import NetModifier

import logging

logger = logging.getLogger(__name__)


class GradientClipping(NetModifier):

    L1_NORM = 'l1_norm'
    L2_NORM = 'l2_norm'

    BY_NORM = 'by_norm'

    GRAD_CLIP_METHODS = [BY_NORM]
    CLIP_GRADIENT_NORM_TYPES = [L2_NORM, L1_NORM]

    def __init__(self, grad_clip_method, clip_norm_type, clip_threshold,
                 use_parameter_norm=False, compute_norm_ratio=False):
        """
        Clips gradients to avoid gradient magnitude explosion or vanishing
        gradients.

        Args:
            grad_clip_method: method used to clip the gradients
            clip_norm_type: type of norm used in the clipping computation
            clip_threshold: threshold used to determine whether to clip
            use_parameter_norm: a boolean indicating whether to incorporate
                the norm of the parameter into the clipping threshold
            compute_norm_ratio: a boolean to explicitly compute the ratio
                between gradient norm and parameter norm for debugging
        """

        assert grad_clip_method in self.GRAD_CLIP_METHODS, (
            "This method of clipping, {}, has not been implemented.".format(
                grad_clip_method))

        assert clip_norm_type in self.CLIP_GRADIENT_NORM_TYPES, (
            "This type of norm, {}, has not been implemented.".format(
                clip_norm_type))

        self.grad_clip_method = grad_clip_method
        self.clip_norm_type = clip_norm_type
        self.clip_threshold = float(clip_threshold)
        self.use_parameter_norm = use_parameter_norm
        self.compute_norm_ratio = compute_norm_ratio

    def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):

        assert grad_map is not None

        CPU = core.DeviceOption(caffe2_pb2.CPU)

        for param, grad in grad_map.items():

            # currently sparse gradients won't be clipped;
            # further implementation is needed to enable it
            if isinstance(grad, core.GradientSlice):
                continue

            device = get_param_device(
                param,
                grad_map[str(param)],
                param_to_device=blob_to_device,
                default_device=CPU,
            )

            with core.DeviceScope(device):
                if self.grad_clip_method == self.BY_NORM:
                    if self.clip_norm_type == self.L2_NORM:
                        p = 2
                    elif self.clip_norm_type == self.L1_NORM:
                        p = 1

                    grad_norm = net.LpNorm(
                        [grad],
                        net.NextScopedBlob(prefix=str(grad) + '_l{}_norm'.format(p)),
                        p=p,
                    )

                    if p == 2:
                        grad_norm = net.Pow([grad_norm], exponent=0.5)

                    op_inputs = [grad, grad_norm]

                    if self.use_parameter_norm:
                        param_norm = net.LpNorm(
                            [param],
                            net.NextScopedBlob(
                                prefix=str(param) + '_l{}_norm'.format(p)),
                            p=p,
                        )

                        if p == 2:
                            param_norm = net.Pow([param_norm], exponent=0.5)

                        op_inputs.append(param_norm)

                        if self.compute_norm_ratio:
                            net.Div(
                                [grad_norm, param_norm],
                                [net.NextScopedBlob(
                                    prefix=str(param) + '_norm_ratio')]
                            )

                    net.ClipTensorByScaling(
                        op_inputs,
                        [grad],
                        threshold=self.clip_threshold,
                    )
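For intuition, by-norm clipping rescales a gradient whose norm exceeds the threshold (and, when use_parameter_norm is set, the threshold is scaled by the parameter norm). The NumPy sketch below only illustrates the intended math under that assumed reading of ClipTensorByScaling, not the operator's exact implementation:

import numpy as np

def clip_by_norm(grad, threshold, param=None, p=2):
    # Assumed by-norm rule: if the gradient norm exceeds the (optionally
    # parameter-scaled) threshold, rescale the gradient down to it.
    grad_norm = np.linalg.norm(grad.ravel(), ord=p)
    limit = threshold
    if param is not None:
        limit = threshold * np.linalg.norm(param.ravel(), ord=p)
    if grad_norm > limit:
        return grad * (limit / grad_norm)
    return grad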
@@ -0,0 +1,162 @@
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
from caffe2.python import workspace, brew, model_helper
from caffe2.python.modeling.gradient_clipping import GradientClipping

import numpy as np


class GradientClippingTest(unittest.TestCase):
    def test_gradient_clipping(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l2_norm',
            clip_threshold=0.1,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (3 gradient clipping ops)
        self.assertEqual(len(model.net.Proto().op), 17)

    def test_gradient_clipping_l1_norm(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l1_norm',
            clip_threshold=0.1,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (2 gradient clipping ops)
        self.assertEqual(len(model.net.Proto().op), 15)

    def test_gradient_clipping_using_param_norm(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l2_norm',
            clip_threshold=0.1,
            use_parameter_norm=True,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (5 gradient clipping ops)
        self.assertEqual(len(model.net.Proto().op), 21)

    def test_gradient_clipping_compute_norm_ratio(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l2_norm',
            clip_threshold=0.1,
            use_parameter_norm=True,
            compute_norm_ratio=True,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (6 gradient clipping ops)
        self.assertEqual(len(model.net.Proto().op), 23)
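The expected op counts in these tests follow from modify_net above: for each of the two clipped parameters, an L2-norm clip adds LpNorm, Pow, and ClipTensorByScaling (3 ops); an L1-norm clip skips the Pow (2 ops); adding the parameter norm contributes another LpNorm and Pow (5 ops); and computing the norm ratio adds a Div (6 ops), all on top of the 5 forward and 6 backward ops of the base model.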