diff --git a/src/ndarray.jl b/src/ndarray.jl
index eba7e2169..6f2a4972a 100644
--- a/src/ndarray.jl
+++ b/src/ndarray.jl
@@ -638,6 +638,10 @@ function mul_to!(dst :: NDArray, arg :: Union{Real, NDArray})
   return dst
 end
 
+mul(x::NDArray, y::NDArray) = _mul(x, y)
+mul(x::NDArray, s::Real) = _mul_scalar(x, scalar=s)
+mul(s::Real, x::NDArray) = _mul_scalar(x, scalar=s)
+
 import Base: *
 
 """
@@ -645,27 +649,22 @@ import Base: *
 
 Elementwise multiplication of `arg0` and `arg`, could be either scalar or `NDArray`.
 """
-@compatdot function Base.broadcast(::typeof(*), arg0 :: NDArray, arg :: Union{Real, NDArray})
-  ret = copy(arg0, context(arg0))
-  mul_to!(ret, arg)
+@compatdot function Base.broadcast(::typeof(*), arg0 :: NDArray,
+                                   arg :: Union{Real, NDArray})
+  mul(arg0, arg)
 end
 @compatdot function Base.broadcast(::typeof(*), arg0 :: Real, arg :: NDArray)
-  arg .* arg0
+  mul(arg, arg0)
 end
 
 """
     *(arg0, arg1)
 
-Currently only multiplication a scalar with an `NDArray` is implemented. Matrix multiplication
-is to be added soon.
+Currently only multiplication of a scalar with an `NDArray` is implemented.
+Matrix multiplication is to be added soon.
 """
-function *(arg0 :: NDArray, arg :: Real)
-  ret = copy(arg0, context(arg0))
-  mul_to!(ret, arg)
-end
-function *(arg0 :: Real, arg :: NDArray)
-  *(arg, arg0)
-end
+*(arg0 :: NDArray, arg :: Real) = mul(arg0, arg)
+*(arg0 :: Real, arg :: NDArray) = mul(arg0, arg)
 
 """
     div_from!(dst :: NDArray, arg :: Union{Real, NDArray})
diff --git a/test/unittest/autograd.jl b/test/unittest/autograd.jl
index 5c7c5dec1..728c96e11 100644
--- a/test/unittest/autograd.jl
+++ b/test/unittest/autograd.jl
@@ -107,11 +107,45 @@ function test_getsymbol()
 end
 
 
+function test_mul()
+  info("AutoGrad::mul")
+
+  let x = mx.NDArray([1 2; 3 4])
+    g = mx.attach_grad(x)
+    y = mx.record() do
+      2x .* x
+    end
+
+    @test copy(g) == [0 0; 0 0]
+    @test copy(y) == [2 8; 18 32]
+
+    mx.backward(y)
+    # gradient is 4x
+    @test copy(g) == [4 8; 12 16]
+  end
+
+  let x = mx.NDArray([1 2; 3 4])
+    g = mx.attach_grad(x)
+    y = mx.record() do
+      x * 2 .* x
+    end
+
+    @test copy(g) == [0 0; 0 0]
+    @test copy(y) == [2 8; 18 32]
+
+    mx.backward(y)
+    # gradient is 4x
+    @test copy(g) == [4 8; 12 16]
+  end
+end
+
+
 @testset "AutoGrad Test" begin
   test_getgrad()
   test_mark_variables()
   test_record()
   test_getsymbol()
+  test_mul()
 end
diff --git a/test/unittest/ndarray.jl b/test/unittest/ndarray.jl
index 2185d920c..359b70cd1 100644
--- a/test/unittest/ndarray.jl
+++ b/test/unittest/ndarray.jl
@@ -223,6 +223,12 @@ function test_mul()
   t6, a6 = rand_tensors(Float16, dims)
   scalar_small = Float16(1e-5)
   @test reldiff(t6 * scalar_small, copy(a6 .* scalar_small)) < 1e-1
+
+  let x = mx.NDArray([1 2; 3 4])
+    @test eltype(x) == Int
+    @test copy(1.5x) == [1 2; 3 4]
+    @test copy(1.9x) == [1 2; 3 4]
+  end
 end
 
 function test_div()
@@ -392,25 +398,25 @@ function test_eltype()
 end
 
 function test_reshape()
-    info("NDArray::reshape")
-    A = rand(2, 3, 4)
+  info("NDArray::reshape")
+  A = rand(2, 3, 4)
 
-    B = reshape(mx.NDArray(A), 4, 3, 2)
-    @test size(B) == (4, 3, 2)
-    @test copy(B)[3, 1, 1] == A[1, 2, 1]
+  B = reshape(mx.NDArray(A), 4, 3, 2)
+  @test size(B) == (4, 3, 2)
+  @test copy(B)[3, 1, 1] == A[1, 2, 1]
 
-    C = reshape(mx.NDArray(A), (4, 3, 2))
-    @test size(C) == (4, 3, 2)
-    @test copy(C)[3, 1, 1] == A[1, 2, 1]
+  C = reshape(mx.NDArray(A), (4, 3, 2))
+  @test size(C) == (4, 3, 2)
+  @test copy(C)[3, 1, 1] == A[1, 2, 1]
 
-    info("NDArray::reshape::reverse")
-    A = mx.zeros(10, 5, 4)
+  info("NDArray::reshape::reverse")
+  A = mx.zeros(10, 5, 4)
 
-    B = reshape(A, -1, 0)
-    @test size(B) == (40, 5)
+  B = reshape(A, -1, 0)
+  @test size(B) == (40, 5)
 
-    C = reshape(A, -1, 0, reverse=true)
-    @test size(C) == (50, 4)
+  C = reshape(A, -1, 0, reverse=true)
+  @test size(C) == (50, 4)
 end
 
 function test_kwargs()
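
For reference, a rough usage sketch (not part of the patch, assuming the usual `using MXNet` setup) of the reworked scalar and elementwise multiplication together with the autograd calls exercised by the new tests; the expected values mirror the test cases above:

    using MXNet

    x = mx.NDArray([1 2; 3 4])   # Int-valued NDArray
    g = mx.attach_grad(x)        # gradient buffer, initially all zeros

    y = mx.record() do
      2x .* x                    # scalar * NDArray, then elementwise .*
    end
    mx.backward(y)

    copy(y)      # [2 8; 18 32]
    copy(g)      # [4 8; 12 16], i.e. the gradient 4x

    # As the new ndarray.jl test shows, multiplying an Int NDArray by a
    # non-integer scalar leaves the values unchanged (the scalar is
    # effectively truncated):
    copy(1.5x)   # [1 2; 3 4]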