Skip to content

Commit

Permalink
[MXNET-792] Fix for issue apache#9816 with dropout operator and RNG (a…
Browse files Browse the repository at this point in the history
…pache#12091)

* added mshadow op for threshold_eq (theshold currently does <, this will do <=)

modified dropout operator to use threshold_eq instead of theshold this will ensure equivalent behavior for the random numbers generated on CPU [0, 1) and GPU (0, 1]

removed fixed seed for test_dropout

* removed comment about flaky test
  • Loading branch information
samskalicky authored and anirudh2290 committed Aug 20, 2018
1 parent 13a875d commit a110ca9
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ MXNET_UNARY_MATH_OP(square_grad, 2.0f * math::id(a));

/*! \brief used for generate Bernoulli mask */
MXNET_BINARY_MATH_OP_NC(threshold, a < b ? DType(1) : DType(0));
MXNET_BINARY_MATH_OP_NC(threshold_eq, a <= b ? DType(1) : DType(0));

/*! \brief used for generate element of abs */
MXNET_UNARY_MATH_OP(abs, math::fabs(a)); // NOLINT(*)
Expand Down
3 changes: 2 additions & 1 deletion src/operator/nn/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class DropoutOp {
const real_t pkeep) {
RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
const real_t rand_num = static_cast<real_t>(genImpl.uniform());
mask_out[i] = mshadow_op::threshold::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
mask_out[i] = mshadow_op::threshold_eq::Map<real_t>(rand_num, pkeep) * (1.0f / pkeep);
dropout_out[i] = input_data[i] * mask_out[i];
});
}
Expand Down Expand Up @@ -258,6 +258,7 @@ class DropoutOp {
this->pkeep_);
return;
}

// initialize the mask
LaunchRNG<BernoulliKernel, xpu>(s, pgen, mask.Size(),
mask.dptr<DType>(),
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5722,8 +5722,7 @@ def test_stack():
check_numeric_gradient(out, inputs)


# test fails with seed 990952066: 0 output seen with dropout ratio=0. See issue #9816
@with_seed(1234)
@with_seed()
def test_dropout():
def zero_count(array, ratio):
zeros = 0
Expand Down Expand Up @@ -5775,6 +5774,7 @@ def check_dropout_ratio(ratio, shape):

exe.arg_arrays[0][:] = 1
exe.forward(is_train=True)

if not math.isnan(max_value):
assert exe.outputs[0].asnumpy().max() > 0
else:
Expand Down

0 comments on commit a110ca9

Please sign in to comment.