-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a170935
commit b4f03c5
Showing
9 changed files
with
98 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,86 +1,47 @@ | ||
abstract type AbstractIndexer{N} <: AbstractArray{Bool, N} end | ||
|
||
struct Indexer{T <: AbstractMask, N, D<:Base.Dims{N}, Ns<:NamedTuple} <: AbstractIndexer{N} | ||
__fields::Ns | ||
dest_size::D | ||
function Indexer{T}(x::NamedTuple, dest_size::Base.Dims) where T | ||
N = length(dest_size) | ||
return new{T, N, typeof(dest_size), typeof(x)}(x, dest_size) | ||
struct Indexer{M <: AbstractMask, N} <: AbstractIndexer{N} | ||
mask::M | ||
destsize::Dims{N} | ||
function Indexer(mask::AbstractMask, destsize::Dims{N}) where N | ||
m = adapt(Indexer, mask) | ||
return new{typeof(m), N}(m, destsize) | ||
end | ||
end | ||
|
||
function Indexer(m::AbstractMask, dest_size::Base.Dims) | ||
if @generated | ||
ex = Expr(:tuple) | ||
for i = 1:fieldcount(m) | ||
fn = fieldname(m, i) | ||
ft = fieldtype(m, i) | ||
expr = :(getfield(m, $(QuoteNode(fn)))) | ||
if ft <: AbstractMask | ||
expr = :(Indexer($expr, dest_size)) | ||
elseif ft <: Tuple{Vararg{AbstractMask}} | ||
expr = :(map(Base.Fix2(Indexer, dest_size), $expr)) | ||
end | ||
push!(ex.args, Expr(:(=), fn, expr)) | ||
end | ||
if isempty(ex.args) | ||
ex = :(NamedTuple()) | ||
end | ||
T = Base.typename(m).wrapper | ||
ret = quote | ||
vs = $ex | ||
return Indexer{$T}(vs, dest_size) | ||
end | ||
return ret | ||
else | ||
T = typeof(m) | ||
vs = NamedTuple{fieldnames(T)}( | ||
ntuple(fieldcount(T)) do i | ||
v = getfield(m, i) | ||
if v isa AbstractMask | ||
v = Indexer(v, dest_size) | ||
elseif v isa Tuple{Vararg{AbstractMask}} | ||
v = map(Base.Fix2(Indexer, dest_size), v) | ||
end | ||
return v | ||
end | ||
) | ||
return Indexer{Base.typename(T).wrapper}(vs, dest_size) | ||
end | ||
function GetIndexer(mask::AbstractMask, destsize::Dims) | ||
check_constraint(AxesConstraint(mask), destsize) | ||
return Indexer(mask, destsize) | ||
end | ||
|
||
Base.length(I::Indexer) = prod(size(I)) | ||
Base.size(I::Indexer) = getfield(I, :dest_size) | ||
|
||
IndexedType(::Indexer{T}) where T = T | ||
|
||
set_dest_size(x, dest_size::Base.Dims) = x | ||
set_dest_size(t::Tuple{Vararg{Indexer}}, dest_size::Base.Dims) = map(Base.Fix2(set_dest_size, dest_size), t) | ||
set_dest_size(I::Indexer, dest_size::Base.Dims) = | ||
Indexer{IndexedType(I)}(map(Base.Fix2(set_dest_size, dest_size), getfield(I, :__fields)), dest_size) | ||
|
||
function Base.getproperty(I::Indexer, x::Symbol) | ||
fs = getfield(I, :__fields) | ||
haskey(fs, x) && return fs[x] | ||
x == :__fields && return fs | ||
x == :dest_size && return getfield(I, :dest_size) | ||
error("type Indexer{$(IndexedType(I))} has no field $x") | ||
end | ||
|
||
const MaskIndexer = Indexer{<:AbstractMask} | ||
Base.@propagate_inbounds Broadcast.newindex(arg::MaskIndexer, I::CartesianIndex) = I | ||
Base.@propagate_inbounds Broadcast.newindex(arg::MaskIndexer, I::Integer) = I | ||
Base.eltype(::MaskIndexer) = Bool | ||
Base.@propagate_inbounds Base.getindex(m::MaskIndexer, i::CartesianIndex) = m[Tuple(i)] | ||
Base.@propagate_inbounds Base.getindex(m::MaskIndexer, I::Tuple) = m[I...] | ||
|
||
function GetIndexer(m::AbstractMask, dest_size::Base.Dims) | ||
check_constraint(AxesConstraint(m), dest_size) | ||
return Indexer(m, dest_size) | ||
end | ||
Base.size(I::Indexer) = getfield(I, :destsize) | ||
Base.eltype(::Indexer) = Bool | ||
|
||
@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Integer...) = __maskgetindex__(m.destsize, m.mask, I...) | ||
@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Tuple) = __maskgetindex__(m.destsize, m.mask, I...) | ||
|
||
using Adapt | ||
import Adapt: adapt_structure | ||
adapt_structure(to, x::Indexer) = Indexer{IndexedType(x)}(adapt(to, getfield(x, :__fields)), getfield(x, :dest_size)) | ||
Base.print_array(io::IO, I::Indexer) = invoke(Base.print_array, Tuple{IO, AbstractArray{Bool, ndims(I)}}, io, Adapt.adapt(Array, I)) | ||
adapt_structure(to, m::Indexer) = Indexer(adapt(to, m.mask), m.destsize) | ||
Base.print_array(io::IO, m::Indexer) = invoke(Base.print_array, Tuple{IO, AbstractArray{Bool, ndims(m)}}, io, Adapt.adapt(Array, m)) | ||
|
||
using FuncTransforms: FuncTransforms, FuncTransform, FA, VA | ||
function _maskgetindex_generator(world, source, self, destsize, mask, I) | ||
caller = Core.Compiler.specialize_method( | ||
FuncTransforms.method_by_ftype(Tuple{self, destsize, mask, I...}, nothing, world)) | ||
sig = Base.to_tuple_type((typeof(maskgetindex), destsize, mask, I...)) | ||
ft = FuncTransform(sig, world, [FA(:maskgetindex, 1), FA(:destsize, 2), FA(:mask, 3), VA(:I, 3)]; caller) | ||
for (ssavalue, code) in FuncTransforms.FuncInfoIter(ft.fi) | ||
stmt = code.stmt | ||
newstmt = FuncTransforms.walk(stmt) do x | ||
FuncTransforms.resolve(x) isa typeof(maskgetindex) ? FuncTransforms.getparg(ft.fi, 1) : x | ||
end | ||
FuncTransforms.inlineflag!(code) | ||
code.stmt = newstmt | ||
end | ||
ci = FuncTransforms.toCodeInfo(ft; inline = true, propagate_inbounds = true) | ||
return ci | ||
end | ||
@eval function __maskgetindex__(destsize::Dims, mask::AbstractMask, I::Integer...) | ||
$(Expr(:meta, :generated, _maskgetindex_generator)) | ||
$(Expr(:meta, :generated_only)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.