Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve broadcast_dims to work directly with Dimensions #775

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/src/broadcast_dims.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,16 @@ We can see the means of each month are scaled by the broadcast :
mean(eachslice(data; dims=(X, Y)))
mean(eachslice(scaled; dims=(X, Y)))
````

Broadcasting also works directly over `Dimension`s (or references to them).
For example, a new `DimArray` can be constructed by broadcasting a function over a set of dimensions:

````@ansi bd
broadcast_dims(*, x, y)
````

Existing dimensions can be referenced by name, in which case its lookup values are used:

````@ansi bd
broadcast_dims(*, data, X) # or `:X` or `X(:)`
````
102 changes: 94 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,111 @@ function modify(f, index::AbstractArray)
end

"""
broadcast_dims(f, sources::AbstractDimArray...) => AbstractDimArray
broadcast_dims(f, sources::Union{AbstractDimArray, Dimension, Symbol}...) => AbstractDimArray

Broadcast function `f` over the `AbstractDimArray`s in `sources`, permuting and reshaping
dimensions to match where required. The result will contain all the dimensions in
all passed in arrays in the order in which they are found.
Broadcast function `f` over the `AbstractDimArray`s, and/or `Dimension`s in `sources`, permuting and reshaping
dimensions to match where required. The result will contain all the dimensions in all passed in arrays in the
order in which they are found.

## Arguments
Existing dimensions can be referenced by e.g. `X`, `:X`, `X(:)`, `X(1.0:0.5:10.0)`.
New dimensions can be passed, but must have an explicit lookup, e.g. `X(1.0:0.5:10.0)`.

- `sources`: `AbstractDimArrays` to broadcast over with `f`.
# Arguments

- `sources`: `AbstractDimArrays`, `Dimension`s, `Symbol`s, to broadcast over with `f`.

This is like broadcasting over every slice of `A` if it is sliced by the dimensions of `B`.

# Throws
- `ArgumentError` if a `Dimension` without explicit lookup values is passed and it is not found among the passed in `DimArray`s.

# Extended help

## Examples

In the simplest use case, `broadcast_dims` can be used to construct a `DimArray` from multiple `Dimension`s:
```julia
julia> x, y, z = X(1:2:6), Y(10.5:1.0:13.5), Z(-0.5:0.5:0.5)
↓ X 1:2:5,
→ Y 10.5:1.0:13.5,
↗ Z -0.5:0.5:0.5

julia> A = broadcast_dims(*, x, y)
╭─────────────────────────╮
│ 3×4 DimArray{Float64,2} │
├─────────────────────────┴────────────────────────────────── dims ┐
↓ X Sampled{Int64} 1:2:5 ForwardOrdered Regular Points,
→ Y Sampled{Float64} 10.5:1.0:13.5 ForwardOrdered Regular Points
└──────────────────────────────────────────────────────────────────┘
↓ → 10.5 11.5 12.5 13.5
1 10.5 11.5 12.5 13.5
3 31.5 34.5 37.5 40.5
5 52.5 57.5 62.5 67.5
```

This is like broadcasting over every slice of `A` if it is
sliced by the dimensions of `B`.
We can also implicitly refer to existing dimensions in `DimArray`s:
```julia
julia> B = ones(x, y);

julia> broadcast_dims(+, B, Y) # also `Y(:)`, or `:Y` works
╭─────────────────────────╮
│ 3×4 DimArray{Float64,2} │
├─────────────────────────┴────────────────────────────────── dims ┐
↓ X Sampled{Int64} 1:2:5 ForwardOrdered Regular Points,
→ Y Sampled{Float64} 10.5:1.0:13.5 ForwardOrdered Regular Points
└──────────────────────────────────────────────────────────────────┘
↓ → 10.5 11.5 12.5 13.5
1 11.5 12.5 13.5 14.5
3 11.5 12.5 13.5 14.5
5 11.5 12.5 13.5 14.5
```

Finally, we can mix and match `DimArray`s and `Dimension`s:
```julia
julia> broadcast_dims(+, A, B, z)
╭───────────────────────────╮
│ 3×4×3 DimArray{Float64,3} │
├───────────────────────────┴───────────────────────────────── dims ┐
↓ X Sampled{Int64} 1:2:5 ForwardOrdered Regular Points,
→ Y Sampled{Float64} 10.5:1.0:13.5 ForwardOrdered Regular Points,
↗ Z Sampled{Float64} -0.5:0.5:0.5 ForwardOrdered Regular Points
└───────────────────────────────────────────────────────────────────┘
[:, :, 1]
↓ → 10.5 11.5 12.5 13.5
1 11.0 12.0 13.0 14.0
3 32.0 35.0 38.0 41.0
5 53.0 58.0 63.0 68.0
```
"""
function broadcast_dims(f, As::AbstractBasicDimArray...)
dims = combinedims(As...)
T = Base.Broadcast.combine_eltypes(f, As)
broadcast_dims!(f, similar(first(As), T, dims), As...)
end

function broadcast_dims(f, As::Union{AbstractBasicDimArray, Dimensions.Dimension, Type{<:Dimension}, Symbol}...)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function broadcast_dims(f, As::Union{AbstractBasicDimArray, Dimensions.Dimension, Type{<:Dimension}, Symbol}...)
function broadcast_dims(f, As::Union{AbstractBasicDimArray,Dimension,Type{<:Dimension},Symbol}...)

# We have to look up dims for any actual DimArrays first if support for `X`, `Ti`, `:X`, etc, as input should work,
# because we need the lookup array
existing_dims = combinedims(filter(Base.Fix2(isa, AbstractBasicDimArray), As)...)
Bs = map(As) do A
if A isa Dimension && !(parent(A) isa Colon)
# A dimension is explicitly passed, so use it
DimArray(parent(A), A)
elseif A isa Dimension || A isa Type{<:Dimension} || A isa Symbol
# If a reference to a dimension, e.g. `X(:)`, `X` or `:X` is passed, look up values from `existing_dims`
dim = dims(existing_dims, A)
# If `A` isn't among the existing dimensions, and since we don't have its lookup values, we can't proceed
isnothing(dim) && throw(ArgumentError("Dimension $A not found among the passed in `DimArray`s"))
# otherwise, construct a `DimArray` with the looked up values
DimArray(parent(dim), dim)
else
# finally, if it's actually a `DimArray`, just pass it through
A
end
end # map(As)
broadcast_dims(f, Bs...)
end

function broadcast_dims(f, As::Union{AbstractDimStack,AbstractBasicDimArray}...)
st = _firststack(As...)
nts = _as_extended_nts(NamedTuple(st), As...)
Expand Down
24 changes: 23 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ end
@test dc1 == [2, 4, 6]
dc2 = broadcast_dims(+, da2, db1)
@test dc2 == [2 4 6; 5 7 9]
dc4 = broadcast_dims(+, da2, db1)

A3 = cat([1 2 3; 4 5 6], [11 12 13; 14 15 16]; dims=3)
da3 = DimArray(A3, (X, Y, Z))
Expand All @@ -157,6 +156,29 @@ end
dc3 = broadcast_dims(+, da3, db1)
@test dc3 == cat([2 4 6; 5 7 9], [12 14 16; 15 17 19]; dims=3)

@testset "works directly with Dimensions" begin
x, y, z = X([1, 2, 3]), Y([1, 2]), Z([0.1])

# construct a DimArray from dimensions, using `broadcast_dims`
da_from_dims = broadcast_dims(+, x, y)
@test da_from_dims == [2 3; 3 4; 4 5]

# different ways to refer to existing dimensions
da_with_reference_da = broadcast_dims(+, da_from_dims, DimArray(parent(y), y)) # reference computation
da_and_existing_dims = broadcast_dims(+, da_from_dims, Y)
da_and_existing_dims2 = broadcast_dims(+, da_from_dims, Y(:))
da_and_existing_dims3 = broadcast_dims(+, da_from_dims, :Y)
da_and_existing_dims4 = broadcast_dims(+, da_from_dims, y)
@test da_and_existing_dims == [3 5; 4 6; 5 7]
@test da_and_existing_dims == da_with_reference_da
@test da_and_existing_dims == da_and_existing_dims2
@test da_and_existing_dims == da_and_existing_dims3
@test da_and_existing_dims == da_and_existing_dims4

# combine `DimArray` and `Dimension`
da_and_new_dims = broadcast_dims(+, da_from_dims, z)
@test da_and_new_dims == [2.1 3.1; 3.1 4.1; 4.1 5.1;;;]
end
@testset "works with permuted dims" begin
db2p = permutedims(da2)
dc3p = broadcast_dims(+, da3, db2p)
Expand Down
Loading