link_batch_normalization.py (forked from HirokiNakahara/GUINNESS)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy
from chainer.functions.normalization import batch_normalization
from chainer import initializers
from chainer import link
from chainer import variable
import function_batch_normalization


class BatchNormalization(link.Link):

    """Batch normalization layer on outputs of linear or convolution functions.

    This link wraps the :func:`~chainer.functions.batch_normalization` and
    :func:`~chainer.functions.fixed_batch_normalization` functions.

    It runs in three modes: training mode, fine-tuning mode, and testing mode.

    In training mode, it normalizes the input by *batch statistics*. It also
    maintains approximated population statistics by moving averages, which can
    be used for instant evaluation in testing mode.

    In fine-tuning mode, it accumulates the input to compute *population
    statistics*. To compute the population statistics correctly, the user must
    use this mode to feed mini-batches that run through the whole training
    dataset.

    In testing mode, it uses pre-computed population statistics to normalize
    the input variable. The population statistics are approximate if they were
    computed in training mode, and accurate if they were correctly computed in
    fine-tuning mode.

    Args:
        size (int or tuple of ints): Size (or shape) of the channel
            dimensions.
        decay (float): Decay rate of the moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability.
        dtype (numpy.dtype): Data type used in the computation.
        use_gamma (bool): If ``True``, use the scaling parameter. Otherwise,
            a constant value of 1 is used, which has no effect.
        use_beta (bool): If ``True``, use the shifting parameter. Otherwise,
            a constant value of 0 is used, which has no effect.

    See: `Batch Normalization: Accelerating Deep Network Training by Reducing\
          Internal Covariate Shift <http://arxiv.org/abs/1502.03167>`_

    .. seealso::
       :func:`~chainer.functions.batch_normalization`,
       :func:`~chainer.functions.fixed_batch_normalization`

    Attributes:
        gamma (~chainer.Variable): Scaling parameter.
        beta (~chainer.Variable): Shifting parameter.
        avg_mean (~chainer.Variable): Population mean.
        avg_var (~chainer.Variable): Population variance.
        N (int): Count of batches given for fine-tuning.
        decay (float): Decay rate of the moving average. It is used during
            training.
        eps (float): Epsilon value for numerical stability. This value is
            added to the batch variances.

    """

    def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
                 use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None):
        super(BatchNormalization, self).__init__()
        if use_gamma:
            self.add_param('gamma', size, dtype=dtype)
            if initial_gamma is None:
                initial_gamma = initializers.One()
            initializers.init_weight(self.gamma.data, initial_gamma)
        if use_beta:
            self.add_param('beta', size, dtype=dtype)
            if initial_beta is None:
                initial_beta = initializers.Zero()
            initializers.init_weight(self.beta.data, initial_beta)
        self.add_persistent('avg_mean', numpy.zeros(size, dtype=dtype))
        self.add_persistent('avg_var', numpy.zeros(size, dtype=dtype))
        self.add_persistent('N', 0)
        self.decay = decay
        self.eps = eps

    def __call__(self, x, test=False, finetune=False):
        """Invokes the forward propagation of BatchNormalization.

        BatchNormalization accepts additional arguments, which control the
        three running modes.

        Args:
            x (Variable): An input variable.
            test (bool): If ``True``, BatchNormalization runs in testing mode;
                it normalizes the input using pre-computed statistics.
            finetune (bool): If ``True``, BatchNormalization runs in
                fine-tuning mode; it accumulates the input array to compute
                population statistics for normalization, and normalizes the
                input using batch statistics.

        If ``test`` and ``finetune`` are both ``False``, then
        BatchNormalization runs in training mode; it computes moving averages
        of the mean and variance for evaluation during training, and
        normalizes the input using batch statistics.

        """
        # The original condition ``use_batch_mean = not test or finetune`` has
        # been commented out in this fork; ``use_batch_mean`` is hard-coded to
        # ``False``, so the fixed-statistics branch below is always taken and
        # the input is always normalized with the stored population statistics.
        # use_batch_mean = not test or finetune
        use_batch_mean = False

        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            gamma = variable.Variable(self.xp.ones(
                self.avg_mean.shape, dtype=x.dtype), volatile='auto')
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            beta = variable.Variable(self.xp.zeros(
                self.avg_mean.shape, dtype=x.dtype), volatile='auto')

        if use_batch_mean:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            func = function_batch_normalization.BatchNormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, True, decay)
            ret = func(x, gamma, beta)

            self.avg_mean = func.running_mean
            self.avg_var = func.running_var
        else:
            # Use the running-average statistics or fine-tuned statistics.
            mean = variable.Variable(self.avg_mean, volatile='auto')
            var = variable.Variable(self.avg_var, volatile='auto')
            ret = batch_normalization.fixed_batch_normalization(
                x, gamma, beta, mean, var, self.eps)
        return ret

    def start_finetuning(self):
        """Resets the population count for collecting population statistics.

        This method can be skipped if it is the first time that the
        fine-tuning mode is used. Otherwise, this method should be called
        before starting the fine-tuning mode again.

        """
        self.N = 0
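

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file). It assumes a
    # Chainer v1-era environment in which this module and the accompanying
    # function_batch_normalization module are importable; the layer size (3)
    # and the input shape (8, 3) are illustrative only.
    bn = BatchNormalization(size=3)
    x = variable.Variable(numpy.random.randn(8, 3).astype(numpy.float32))

    # Training-style call. In this fork use_batch_mean is hard-coded to
    # False, so even this call normalizes with the stored population
    # statistics instead of the batch statistics.
    y_train = bn(x)

    # Testing-mode call: explicitly uses the pre-computed statistics.
    y_test = bn(x, test=True)

    # Fine-tuning-mode call. In stock Chainer this would accumulate
    # population statistics; here it also falls through to the
    # fixed-statistics branch because of the hard-coded flag.
    bn.start_finetuning()
    y_finetune = bn(x, finetune=True)

    print(y_train.data.shape, y_test.data.shape, y_finetune.data.shape)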