Add gradient clipping (pytorch#2452)
As titled.
chocjy authored Mar 27, 2018
1 parent c4e5001 commit 8fa38f8
Showing 6 changed files with 317 additions and 8 deletions.
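
In short, the commit clips each dense gradient by its norm: when the measured norm exceeds a threshold, the gradient is rescaled by threshold / norm (the threshold can optionally be scaled by the parameter norm). A minimal sketch of that rule, mirroring the reference implementation in the operator test further down; the helper name clip_by_norm and the sample values are illustrative, not part of the commit:

import numpy as np

def clip_by_norm(grad, threshold, norm):
    # Rescale only when the measured norm exceeds the threshold,
    # so small gradients pass through unchanged.
    if norm > threshold:
        grad = grad * (threshold / float(norm))
    return grad

g = np.array([3.0, 4.0], dtype=np.float32)      # L2 norm = 5.0
clipped = clip_by_norm(g, threshold=1.0, norm=np.linalg.norm(g))
# clipped is now [0.6, 0.8], with L2 norm 1.0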
5 changes: 4 additions & 1 deletion caffe2/python/layer_model_helper.py
@@ -530,8 +530,11 @@ def apply_post_grad_net_modifiers(
         grad_map,
         blob_to_device=None,
     ):
+        param_grad_map = {param: grad_map[param]
+            for param in self.param_to_optim.keys() if param in grad_map}
+
         for modifier in self._post_grad_net_modifiers:
-            modifier(trainer_net, trainer_init_net, grad_map,
+            modifier(trainer_net, trainer_init_net, param_grad_map,
                      blob_to_device=blob_to_device)
 
     def apply_final_net_modifiers(
117 changes: 117 additions & 0 deletions caffe2/python/modeling/gradient_clipping.py
@@ -0,0 +1,117 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.proto import caffe2_pb2
from caffe2.python.optimizer import get_param_device
from caffe2.python.modeling.net_modifier import NetModifier

import logging

logger = logging.getLogger(__name__)


class GradientClipping(NetModifier):

    L1_NORM = 'l1_norm'
    L2_NORM = 'l2_norm'

    BY_NORM = 'by_norm'

    GRAD_CLIP_METHODS = [BY_NORM]
    CLIP_GRADIENT_NORM_TYPES = [L2_NORM, L1_NORM]

    def __init__(self, grad_clip_method, clip_norm_type, clip_threshold,
                 use_parameter_norm=False, compute_norm_ratio=False):
        """
        Clips gradients to avoid gradient magnitude explosion or vanishing gradients.

        Args:
            grad_clip_method: way to clip the gradients
            clip_norm_type: type of norm used in the necessary computation
            clip_threshold: threshold used to determine whether to clip
            use_parameter_norm: a boolean to indicate whether to incorporate
                the norm of the parameter
            compute_norm_ratio: a boolean to compute the ratio between gradient
                norm and parameter norm explicitly, for debugging purposes
        """

        assert grad_clip_method in self.GRAD_CLIP_METHODS, (
            "This method of clipping, {}, has not been implemented.".format(
                grad_clip_method))

        assert clip_norm_type in self.CLIP_GRADIENT_NORM_TYPES, (
            "This norm type, {}, has not been implemented.".format(
                clip_norm_type))

        self.grad_clip_method = grad_clip_method
        self.clip_norm_type = clip_norm_type
        self.clip_threshold = float(clip_threshold)
        self.use_parameter_norm = use_parameter_norm
        self.compute_norm_ratio = compute_norm_ratio

    def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):

        assert grad_map is not None

        CPU = core.DeviceOption(caffe2_pb2.CPU)

        for param, grad in grad_map.items():

            # Sparse gradients are currently not clipped; further
            # implementation is needed to enable that.
            if isinstance(grad, core.GradientSlice):
                continue

            device = get_param_device(
                param,
                grad_map[str(param)],
                param_to_device=blob_to_device,
                default_device=CPU,
            )

            with core.DeviceScope(device):
                if self.grad_clip_method == self.BY_NORM:
                    if self.clip_norm_type == self.L2_NORM:
                        p = 2
                    elif self.clip_norm_type == self.L1_NORM:
                        p = 1

                    grad_norm = net.LpNorm(
                        [grad],
                        net.NextScopedBlob(prefix=str(grad) + '_l{}_norm'.format(p)),
                        p=p,
                    )

                    # LpNorm yields the sum of |x|^p, so take the square root
                    # to obtain the actual L2 norm.
                    if p == 2:
                        grad_norm = net.Pow([grad_norm], exponent=0.5)

                    op_inputs = [grad, grad_norm]

                    if self.use_parameter_norm:
                        param_norm = net.LpNorm(
                            [param],
                            net.NextScopedBlob(
                                prefix=str(param) + '_l{}_norm'.format(p)),
                            p=p,
                        )

                        if p == 2:
                            param_norm = net.Pow([param_norm], exponent=0.5)

                        op_inputs.append(param_norm)

                        if self.compute_norm_ratio:
                            net.Div(
                                [grad_norm, param_norm],
                                [net.NextScopedBlob(
                                    prefix=str(param) + '_norm_ratio')]
                            )

                    net.ClipTensorByScaling(
                        op_inputs,
                        [grad],
                        threshold=self.clip_threshold,
                    )
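
For orientation, this is how the new modifier is meant to be used: build it with a clipping configuration and call it on a net together with a parameter-to-gradient map. The snippet below is a condensed sketch following the pattern of the unit tests in the next file; the model and blob names (clip_demo, fc, label) are illustrative, not part of the commit.

import numpy as np

from caffe2.python import brew, model_helper, workspace
from caffe2.python.modeling.gradient_clipping import GradientClipping

model = model_helper.ModelHelper(name="clip_demo")
data = model.net.AddExternalInput("data")
fc = brew.fc(model, data, "fc", dim_in=4, dim_out=1)
sq = model.net.SquaredL2Distance([fc, 'label'], 'sq')
loss = model.net.SumElements(sq, 'loss')

param_grads = model.AddGradientOperators([loss])

# Clip the weight gradient to an L2 norm of at most 0.1.
clipper = GradientClipping(
    grad_clip_method='by_norm',
    clip_norm_type='l2_norm',
    clip_threshold=0.1,
)
clipper(model.net, grad_map={'fc_w': param_grads['fc_w']})

workspace.FeedBlob('data', np.random.rand(8, 4).astype(np.float32))
workspace.FeedBlob('label', np.random.rand(8, 1).astype(np.float32))
workspace.RunNetOnce(model.param_init_net)
workspace.RunNetOnce(model.net)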
162 changes: 162 additions & 0 deletions caffe2/python/modeling/gradient_clipping_test.py
@@ -0,0 +1,162 @@
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest
from caffe2.python import workspace, brew, model_helper
from caffe2.python.modeling.gradient_clipping import GradientClipping

import numpy as np


class GradientClippingTest(unittest.TestCase):
    def test_gradient_clipping(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l2_norm',
            clip_threshold=0.1,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (3 gradient clipping ops:
        # LpNorm, Pow, ClipTensorByScaling per clipped gradient)
        self.assertEqual(len(model.net.Proto().op), 17)

    def test_gradient_clipping_l1_norm(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l1_norm',
            clip_threshold=0.1,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (2 gradient clipping ops:
        # LpNorm, ClipTensorByScaling per clipped gradient)
        self.assertEqual(len(model.net.Proto().op), 15)

    def test_gradient_clipping_using_param_norm(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l2_norm',
            clip_threshold=0.1,
            use_parameter_norm=True,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (5 gradient clipping ops:
        # two LpNorm, two Pow, ClipTensorByScaling per clipped gradient)
        self.assertEqual(len(model.net.Proto().op), 21)

    def test_gradient_clipping_compute_norm_ratio(self):
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=4, dim_out=2)

        # no operator name set, will use default
        fc2 = brew.fc(model, fc1, "fc2", dim_in=2, dim_out=1)

        sigm = model.net.Sigmoid(fc2, 'sigm')
        sq = model.net.SquaredL2Distance([sigm, 'label'], 'sq')
        loss = model.net.SumElements(sq, 'loss')

        grad_map = model.AddGradientOperators([loss])

        grad_map_for_param = {key: grad_map[key] for key in ['fc1_w', 'fc2_w']}

        net_modifier = GradientClipping(
            grad_clip_method='by_norm',
            clip_norm_type='l2_norm',
            clip_threshold=0.1,
            use_parameter_norm=True,
            compute_norm_ratio=True,
        )

        net_modifier(model.net, grad_map=grad_map_for_param)

        workspace.FeedBlob('data', np.random.rand(10, 4).astype(np.float32))
        workspace.FeedBlob('label', np.random.rand(10, 1).astype(np.float32))

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # 5 forward ops + 6 backward ops + 2 * (6 gradient clipping ops:
        # two LpNorm, two Pow, Div, ClipTensorByScaling per clipped gradient)
        self.assertEqual(len(model.net.Proto().op), 23)
25 changes: 19 additions & 6 deletions caffe2/python/operator_test/clip_tensor_op_test.py
@@ -14,31 +14,44 @@ class TestClipTensorByScalingOp(hu.HypothesisTestCase):

     @given(n=st.integers(5, 8), d=st.integers(2, 4),
            threshold=st.floats(0.1, 10),
+           additional_threshold=st.floats(0.1, 10),
+           use_additional_threshold=st.booleans(),
            inplace=st.booleans(),
            **hu.gcs_cpu_only)
-    def test_clip_tensor_by_scaling(self, n, d, threshold, inplace, gc, dc):
+    def test_clip_tensor_by_scaling(self, n, d, threshold, additional_threshold,
+                                    use_additional_threshold, inplace, gc, dc):
 
         tensor = np.random.rand(n, d).astype(np.float32)
         val = np.array(np.linalg.norm(tensor))
+        additional_threshold = np.array([additional_threshold]).astype(np.float32)
 
-        def clip_tensor_by_scaling_ref(tensor_data, val_data):
-            if val_data > threshold:
-                ratio = threshold / float(val_data)
+        def clip_tensor_by_scaling_ref(tensor_data, val_data,
+                                       additional_threshold=None):
+
+            if additional_threshold is not None:
+                final_threshold = threshold * additional_threshold
+            else:
+                final_threshold = threshold
+
+            if val_data > final_threshold:
+                ratio = final_threshold / float(val_data)
                 tensor_data = tensor_data * ratio
 
             return [tensor_data]
 
         op = core.CreateOperator(
             "ClipTensorByScaling",
-            ["tensor", "val"],
+            ["tensor", "val"] if not use_additional_threshold else (
+                ["tensor", "val", "additional_threshold"]),
             ['Y'] if not inplace else ["tensor"],
             threshold=threshold,
         )
 
         self.assertReferenceChecks(
             device_option=gc,
             op=op,
-            inputs=[tensor, val],
+            inputs=[tensor, val] if not use_additional_threshold else (
+                [tensor, val, additional_threshold]),
             reference=clip_tensor_by_scaling_ref,
         )

9 changes: 8 additions & 1 deletion caffe2/sgd/clip_tensor_op.cc
@@ -4,7 +4,7 @@ namespace caffe2 {

 REGISTER_CPU_OPERATOR(ClipTensorByScaling, ClipTensorByScalingOp<CPUContext>);
 OPERATOR_SCHEMA(ClipTensorByScaling)
-    .NumInputs(2)
+    .NumInputs(2, 3)
     .NumOutputs(1)
     .AllowInplace({{0, 0}})
     .SetDoc(R"DOC(
@@ -14,10 +14,17 @@ OPERATOR_SCHEMA(ClipTensorByScaling)
 tensor *= (threshold / value).
+An optional input called additional_threshold can be provided which
+will scale the original threshold before it is used. That is,
+the final threshold will become threshold * additional_threshold.
 This op could be used for gradient clipping.
 )DOC")
     .Input(0, "input_tensor", "Tensor of floats to be clipped.")
     .Input(1, "val", "Value to be compared against the threshold")
+    .Input(
+        2,
+        "additional_threshold",
+        "An optional additional threshold to scale the original threshold")
     .Arg("threshold", "Threshold to determine whether to scale down the tensor")
     .Output(
         0,
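
To make the extended schema concrete, here is a small worked example of the clipping rule with the optional third input; this is a plain NumPy sketch of the documented semantics, not code from the commit. With threshold = 2.0 and additional_threshold = 0.5, the effective threshold is 1.0, so a tensor whose reported norm val is 4.0 is scaled by 1.0 / 4.0 = 0.25.

import numpy as np

tensor = np.array([[2.0, 2.0], [2.0, 2.0]], dtype=np.float32)
val = float(np.linalg.norm(tensor))                  # 4.0
threshold, additional_threshold = 2.0, 0.5
final_threshold = threshold * additional_threshold   # 1.0

if val > final_threshold:
    tensor = tensor * (final_threshold / val)
# tensor is now [[0.5, 0.5], [0.5, 0.5]], i.e. scaled by 0.25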
7 changes: 7 additions & 0 deletions caffe2/sgd/clip_tensor_op.h
@@ -33,6 +33,13 @@ class ClipTensorByScalingOp final : public Operator<Context> {
     clipped->ResizeLike(input_tensor);
     float* clipped_tensor_data = clipped->template mutable_data<float>();
 
+    if (InputSize() > 2) {
+      const auto& additional_threshold = Input(2);
+      CAFFE_ENFORCE_EQ(additional_threshold.size(), 1);
+
+      threshold_ *= *(additional_threshold.template data<float>());
+    }
+
     if (*val_data > threshold_) {
       float ratio = threshold_ / *val_data;
 
