Skip to content

Commit

Permalink
ndarray: cat, hcat, vcat (#380)
Browse files Browse the repository at this point in the history
e.g. `hcat`
```julia
julia> x
4 mx.NDArray{Float64,1} @ CPU0:
 1.0
 2.0
 3.0
 4.0

julia> y
4 mx.NDArray{Float64,1} @ CPU0:
 2.0
 4.0
 6.0
 8.0

julia> [x y]
4×2 mx.NDArray{Float64,2} @ CPU0:
 1.0  2.0
 2.0  4.0
 3.0  6.0
 4.0  8.0
```
  • Loading branch information
iblislin authored and pluskid committed Dec 16, 2017
1 parent d921225 commit 4507598
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 1 deletion.
28 changes: 27 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,33 @@
x .% 2
2 .% x
```


* `cat`, `vcat`, `hcat` is implemented. (#TBD)

E.g. `hcat`
```julia
julia> x
4 mx.NDArray{Float64,1} @ CPU0:
1.0
2.0
3.0
4.0

julia> y
4 mx.NDArray{Float64,1} @ CPU0:
2.0
4.0
6.0
8.0

julia> [x y]
4×2 mx.NDArray{Float64,2} @ CPU0:
1.0 2.0
2.0 4.0
3.0 6.0
4.0 8.0
```

* Transposing a column `NDArray` to a row `NDArray` is supported now. (#TBD)

```julia
Expand Down
26 changes: 26 additions & 0 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,32 @@ function deepcopy(arr::NDArray)
NDArray(MX_NDArrayHandle(out_ref[]))
end

"""
hcat(x::NDArray...)
"""
Base.hcat(xs::NDArray{T}...) where T = cat(2, xs...)

"""
vcat(x::NDArray...)
"""
Base.vcat(xs::NDArray{T}...) where T = cat(1, xs...)

"""
cat(dim, xs::NDArray...)
Concate the `NDArray`s which have the same element type along the `dim`.
Building a diagonal matrix is not supported yet.
"""
function Base.cat(dim::Int, xs::NDArray{T}...) where T
ns = ndims.(xs)
d = Base.max(dim, maximum(ns))
xs′ = map(zip(ns, xs)) do i
n, x = i
(d > n) ? reshape(x, -2, Base.ones(Int, d - n)...) : x
end
concat(xs′..., dim = d - dim)
end

"""
@inplace
Expand Down
62 changes: 62 additions & 0 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,67 @@ function test_endof()
end
end # function test_endof

function test_cat()
function check_cat(f, A, B = 2A)
C = [A B]
D = [A; B]
x = NDArray(A)
y = NDArray(B)
z = NDArray(C)
d = NDArray(D)

if f == :hcat
@test copy([x y]) == [A B]
@test copy([x y 3y x]) == [A B 3B A]
@test copy([z y x]) == [C B A]
elseif f == :vcat
@test copy([x; y]) == [A; B]
@test copy([x; y; 3y; x]) == [A; B; 3B; A]
@test copy([x; d]) == [A; D]
@test copy([d; x]) == [D; A]
else
@assert false
end
end

let A = [1, 2, 3, 4]
info("NDArray::hcat::1D")
check_cat(:hcat, A)

info("NDArray::vcat::1D")
check_cat(:vcat, A)
end

let A = [1 2; 3 4]
info("NDArray::hcat::2D")
check_cat(:hcat, A)

info("NDArray::vcat::2D")
check_cat(:vcat, A)
end

let A = rand(4, 3, 2)
info("NDArray::hcat::3D")
check_cat(:hcat, A)

info("NDArray::vcat::3D")
check_cat(:vcat, A)
end

let A = rand(4, 3, 2, 2)
info("NDArray::hcat::4D")
check_cat(:hcat, A)

info("NDArray::vcat::4D")
check_cat(:vcat, A)
end

let A = [1, 2, 3, 4]
info("NDArray::cat::3D/1D")
check_cat(:vcat, reshape(A, 4, 1, 1), 2A)
end
end # function test_cat

function test_plus()
dims = rand_dims()
t1, a1 = rand_tensors(dims)
Expand Down Expand Up @@ -927,6 +988,7 @@ end # function test_hyperbolic
test_linear_idx()
test_first()
test_endof()
test_cat()
test_plus()
test_minus()
test_mul()
Expand Down

0 comments on commit 4507598

Please sign in to comment.