From 233fcfc5a89d69037290964c38066bbe0bda6a87 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Sat, 9 Dec 2017 14:26:02 +0800 Subject: [PATCH] ndarray: change internal api of plus to help autograd (#364) address https://github.com/dmlc/MXNet.jl/pull/274#issuecomment-349951876 --- src/ndarray.jl | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/ndarray.jl b/src/ndarray.jl index 37894882b..c9b17924f 100644 --- a/src/ndarray.jl +++ b/src/ndarray.jl @@ -587,8 +587,10 @@ Summation. Multiple arguments of either scalar or `NDArray` could be added together. Note at least the first or second argument needs to be an `NDArray` to avoid ambiguity of built-in summation. """ -+(x::NDArray, ys::NDArrayOrReal...) = add_to!(copy(x, context(x)), ys...) -+(x::Real, y::NDArray, zs::NDArrayOrReal...) = add_to!(copy(y, context(y)), x, zs...) ++(x::NDArray) = x ++(x::NDArray, y::NDArray) = _plus(x, y) ++(x::NDArray, y::Real) = _plus_scalar(x, scalar = y) ++(y::Real, x::NDArray) = _plus_scalar(x, scalar = y) broadcast_(::typeof(+), x::NDArray, y::NDArrayOrReal) = x + y broadcast_(::typeof(+), x::Real, y::NDArray) = x + y @@ -1205,20 +1207,16 @@ function _get_ndarray_function_def(name :: String) args = MX_handle[] end - if length(output_vars) > 0 - output_handles = map((x) -> Base.cconvert(MX_handle, x), output_vars) - # XXX: Julia 0.4 has bug: [Array{MX_handle}] == Array{MX_handle} - output_handles_pp = Array{Array{MX_handle}}(1) - output_handles_pp[1] = Base.cconvert(Ptr{MX_handle}, output_handles) + output_handles_pp = if length(output_vars) > 0 + [map(x -> x.handle, output_vars)] else - output_handles_pp = [Base.convert(Ptr{MX_handle}, 0)] + [Ptr{MX_handle}(C_NULL)] end num_outputs_p = [convert(Cint, num_outputs)] kw_keys_str = String[string(x[1]) for x in kwargs] kw_vals_str = String[dump_mx_param(x[2]) for x in kwargs] - #op_handle = _get_cached_libmx_op_handle($(QuoteNode(name))) op_handle = _get_cached_libmx_op_handle($(name)) @mxcall(:MXImperativeInvoke, (MX_handle, Cint, Ptr{MX_handle}, @@ -1229,13 +1227,13 @@ function _get_ndarray_function_def(name :: String) length(kwargs), kw_keys_str, kw_vals_str) if out == nothing - handle_array = unsafe_wrap(Array, output_handles_pp[], num_outputs_p[]) - handle_array = [MX_NDArrayHandle(x) for x in handle_array] - arrays = [NDArray(hdr) for hdr in handle_array] - if length(arrays) == 1 - return arrays[1] + n = num_outputs_p[] + hdls = unsafe_wrap(Array{MX_handle}, output_handles_pp[], n) + xs = NDArray[NDArray(MX_NDArrayHandle(x)) for x in hdls] + if n == 1 + return xs[] else - return arrays + return xs end else return out