Skip to content

Commit

Permalink
rework Indexer
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed May 28, 2024
1 parent a170935 commit b4f03c5
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 161 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.2.13"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FuncTransforms = "79abecb7-a74d-442d-bb0e-6136fbda6b73"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand All @@ -31,6 +32,7 @@ Adapt = "3.3, 4"
CUDA = "5"
ChainRulesCore = "1.3"
FiniteDifferences = "0.12"
FuncTransforms = "0.1"
GPUArrays = "10"
GPUArraysCore = "0.1"
NNlib = "0.9"
Expand Down
16 changes: 9 additions & 7 deletions src/mask/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Base.ndims(::GenericAttenMask{N}) where N = N

adapt_structure(to, x::GenericAttenMask) = GenericAttenMask(adapt(to, x.mask))

Base.@propagate_inbounds Base.getindex(m::Indexer{<:GenericAttenMask}, I::Integer...) = m.mask[I...]
Base.@propagate_inbounds maskgetindex(::Dims, m::GenericAttenMask, I::Integer...) = m.mask[I...]

AxesConstraint(m::GenericAttenMask) = (NDimConstraint(ndims(m)), ntuple(i->DimConstraint(i, size(m.mask, i), i <= 2), ndims(m))...)

Expand All @@ -24,7 +24,7 @@ SymLengthMask(len::AbstractArray) = SymLengthMask(convert(AbstractArray{Int32},

adapt_structure(to, x::SymLengthMask) = SymLengthMask(adapt(to, x.len))

Base.@propagate_inbounds function Base.getindex(m::Indexer{<:SymLengthMask}, i::Integer, j::Integer, J::Integer...)
Base.@propagate_inbounds function maskgetindex(::Dims, m::SymLengthMask, i::Integer, j::Integer, J::Integer...)
l = m.len[J...]
return i <= l && j <= l
end
Expand Down Expand Up @@ -53,7 +53,7 @@ end

adapt_structure(to, x::BiLengthMask) = BiLengthMask(adapt(to, x.q_len), adapt(to, x.k_len))

Base.@propagate_inbounds function Base.getindex(m::Indexer{<:BiLengthMask}, i::Integer, j::Integer, J::Integer...)
Base.@propagate_inbounds function maskgetindex(::Dims, m::BiLengthMask, i::Integer, j::Integer, J::Integer...)
ql = m.q_len[J...]
kl = m.k_len[J...]
return i <= kl && j <= ql
Expand All @@ -80,8 +80,9 @@ RevSymLengthMask(len::AbstractArray) = RevSymLengthMask(convert(AbstractArray{In

adapt_structure(to, x::RevSymLengthMask) = RevSymLengthMask(adapt(to, x.len))

Base.@propagate_inbounds function Base.getindex(m::Indexer{<:RevSymLengthMask}, i::Integer, j::Integer, J::Integer...)
rl, cl = m.dest_size
Base.@propagate_inbounds function maskgetindex(destsize::Dims, m::RevSymLengthMask, i::Integer, j::Integer, J::Integer...)
rl = destsize[1]
cl = destsize[2]
l = m.len[J...]
return rl - l < i && cl - l < j
end
Expand Down Expand Up @@ -110,8 +111,9 @@ end

adapt_structure(to, x::RevBiLengthMask) = RevBiLengthMask(adapt(to, x.q_len), adapt(to, x.k_len))

Base.@propagate_inbounds function Base.getindex(m::Indexer{<:RevBiLengthMask}, i::Integer, j::Integer, J::Integer...)
rl, cl = m.dest_size
Base.@propagate_inbounds function maskgetindex(destsize::Dims, m::RevBiLengthMask, i::Integer, j::Integer, J::Integer...)
rl = destsize[1]
cl = destsize[2]
ql = m.q_len[J...]
kl = m.k_len[J...]
return rl - kl < i && cl - ql < j
Expand Down
8 changes: 4 additions & 4 deletions src/mask/dataless.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ AxesConstraint(::AbstractAttenMask{DATALESS}) = (NDimConstraint(2, true),)

struct CausalMask <: AbstractAttenMask{DATALESS} end

Base.@propagate_inbounds Base.getindex(::Indexer{CausalMask}, i::Integer, j::Integer, _::Integer...) = j >= i
Base.@propagate_inbounds maskgetindex(::Dims, ::CausalMask, i::Integer, j::Integer, _::Integer...) = j >= i

struct LocalMask <: AbstractAttenMask{DATALESS}
width::Int
end

Base.@propagate_inbounds Base.getindex(m::Indexer{LocalMask}, i::Integer, j::Integer, _::Integer...) = j - m.width < i < j + m.width
Base.@propagate_inbounds maskgetindex(::Dims, m::LocalMask, i::Integer, j::Integer, _::Integer...) = j - m.width < i < j + m.width

struct RandomMask <: AbstractAttenMask{DATALESS}
p::Float64
end

Base.@propagate_inbounds Base.getindex(m::Indexer{RandomMask}, _::Integer...) = rand() > m.p
Base.@propagate_inbounds maskgetindex(::Dims, m::RandomMask, _::Integer...) = rand() > m.p

AxesConstraint(m::RandomMask) = ()
randomness(::RandomMask) = static(true)
Expand All @@ -26,4 +26,4 @@ struct BandPartMask <: AbstractAttenMask{DATALESS}
u::Int
end

Base.@propagate_inbounds Base.getindex(m::Indexer{BandPartMask}, i::Integer, j::Integer, _::Integer...) = (m.l < 0 || i <= j + m.l) && (m.u < 0 || i >= j - m.u)
Base.@propagate_inbounds maskgetindex(::Dims, m::BandPartMask, i::Integer, j::Integer, _::Integer...) = (m.l < 0 || i <= j + m.l) && (m.u < 0 || i >= j - m.u)
6 changes: 2 additions & 4 deletions src/mask/grad.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using Base.Broadcast: BroadcastFunction, broadcasted, materialize

ChainRulesCore.@non_differentiable Base.getindex(m::AbstractMask, I::Integer...)
ChainRulesCore.@non_differentiable Base.getindex(m::MaskIndexer, I::Integer...)
ChainRulesCore.@non_differentiable Base.getindex(m::AbstractMask, I::Tuple)
ChainRulesCore.@non_differentiable Base.getindex(m::MaskIndexer, I::Tuple)
ChainRulesCore.@non_differentiable Base.getindex(m::Indexer, I...)
ChainRulesCore.@non_differentiable maskgetindex(::Dims, ::AbstractMask, I::Integer...)
ChainRulesCore.@non_differentiable (::Type{<:AbstractMask})(args...)
ChainRulesCore.@non_differentiable (::Type{<:AbstractMaskOp})(args...)
ChainRulesCore.@non_differentiable getmask(arg...)
Expand Down
113 changes: 37 additions & 76 deletions src/mask/indexer.jl
Original file line number Diff line number Diff line change
@@ -1,86 +1,47 @@
abstract type AbstractIndexer{N} <: AbstractArray{Bool, N} end

struct Indexer{T <: AbstractMask, N, D<:Base.Dims{N}, Ns<:NamedTuple} <: AbstractIndexer{N}
__fields::Ns
dest_size::D
function Indexer{T}(x::NamedTuple, dest_size::Base.Dims) where T
N = length(dest_size)
return new{T, N, typeof(dest_size), typeof(x)}(x, dest_size)
struct Indexer{M <: AbstractMask, N} <: AbstractIndexer{N}
mask::M
destsize::Dims{N}
function Indexer(mask::AbstractMask, destsize::Dims{N}) where N
m = adapt(Indexer, mask)
return new{typeof(m), N}(m, destsize)
end
end

function Indexer(m::AbstractMask, dest_size::Base.Dims)
if @generated
ex = Expr(:tuple)
for i = 1:fieldcount(m)
fn = fieldname(m, i)
ft = fieldtype(m, i)
expr = :(getfield(m, $(QuoteNode(fn))))
if ft <: AbstractMask
expr = :(Indexer($expr, dest_size))
elseif ft <: Tuple{Vararg{AbstractMask}}
expr = :(map(Base.Fix2(Indexer, dest_size), $expr))
end
push!(ex.args, Expr(:(=), fn, expr))
end
if isempty(ex.args)
ex = :(NamedTuple())
end
T = Base.typename(m).wrapper
ret = quote
vs = $ex
return Indexer{$T}(vs, dest_size)
end
return ret
else
T = typeof(m)
vs = NamedTuple{fieldnames(T)}(
ntuple(fieldcount(T)) do i
v = getfield(m, i)
if v isa AbstractMask
v = Indexer(v, dest_size)
elseif v isa Tuple{Vararg{AbstractMask}}
v = map(Base.Fix2(Indexer, dest_size), v)
end
return v
end
)
return Indexer{Base.typename(T).wrapper}(vs, dest_size)
end
function GetIndexer(mask::AbstractMask, destsize::Dims)
check_constraint(AxesConstraint(mask), destsize)
return Indexer(mask, destsize)
end

Base.length(I::Indexer) = prod(size(I))
Base.size(I::Indexer) = getfield(I, :dest_size)

IndexedType(::Indexer{T}) where T = T

set_dest_size(x, dest_size::Base.Dims) = x
set_dest_size(t::Tuple{Vararg{Indexer}}, dest_size::Base.Dims) = map(Base.Fix2(set_dest_size, dest_size), t)
set_dest_size(I::Indexer, dest_size::Base.Dims) =
Indexer{IndexedType(I)}(map(Base.Fix2(set_dest_size, dest_size), getfield(I, :__fields)), dest_size)

function Base.getproperty(I::Indexer, x::Symbol)
fs = getfield(I, :__fields)
haskey(fs, x) && return fs[x]
x == :__fields && return fs
x == :dest_size && return getfield(I, :dest_size)
error("type Indexer{$(IndexedType(I))} has no field $x")
end

const MaskIndexer = Indexer{<:AbstractMask}
Base.@propagate_inbounds Broadcast.newindex(arg::MaskIndexer, I::CartesianIndex) = I
Base.@propagate_inbounds Broadcast.newindex(arg::MaskIndexer, I::Integer) = I
Base.eltype(::MaskIndexer) = Bool
Base.@propagate_inbounds Base.getindex(m::MaskIndexer, i::CartesianIndex) = m[Tuple(i)]
Base.@propagate_inbounds Base.getindex(m::MaskIndexer, I::Tuple) = m[I...]

function GetIndexer(m::AbstractMask, dest_size::Base.Dims)
check_constraint(AxesConstraint(m), dest_size)
return Indexer(m, dest_size)
end
Base.size(I::Indexer) = getfield(I, :destsize)
Base.eltype(::Indexer) = Bool

@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Integer...) = __maskgetindex__(m.destsize, m.mask, I...)
@inline Base.@propagate_inbounds Base.getindex(m::Indexer, I::Tuple) = __maskgetindex__(m.destsize, m.mask, I...)

using Adapt
import Adapt: adapt_structure
adapt_structure(to, x::Indexer) = Indexer{IndexedType(x)}(adapt(to, getfield(x, :__fields)), getfield(x, :dest_size))
Base.print_array(io::IO, I::Indexer) = invoke(Base.print_array, Tuple{IO, AbstractArray{Bool, ndims(I)}}, io, Adapt.adapt(Array, I))
adapt_structure(to, m::Indexer) = Indexer(adapt(to, m.mask), m.destsize)
Base.print_array(io::IO, m::Indexer) = invoke(Base.print_array, Tuple{IO, AbstractArray{Bool, ndims(m)}}, io, Adapt.adapt(Array, m))

using FuncTransforms: FuncTransforms, FuncTransform, FA, VA
function _maskgetindex_generator(world, source, self, destsize, mask, I)
caller = Core.Compiler.specialize_method(
FuncTransforms.method_by_ftype(Tuple{self, destsize, mask, I...}, nothing, world))
sig = Base.to_tuple_type((typeof(maskgetindex), destsize, mask, I...))
ft = FuncTransform(sig, world, [FA(:maskgetindex, 1), FA(:destsize, 2), FA(:mask, 3), VA(:I, 3)]; caller)
for (ssavalue, code) in FuncTransforms.FuncInfoIter(ft.fi)
stmt = code.stmt
newstmt = FuncTransforms.walk(stmt) do x
FuncTransforms.resolve(x) isa typeof(maskgetindex) ? FuncTransforms.getparg(ft.fi, 1) : x
end
FuncTransforms.inlineflag!(code)
code.stmt = newstmt
end
ci = FuncTransforms.toCodeInfo(ft; inline = true, propagate_inbounds = true)
return ci
end
@eval function __maskgetindex__(destsize::Dims, mask::AbstractMask, I::Integer...)
$(Expr(:meta, :generated, _maskgetindex_generator))
$(Expr(:meta, :generated_only))
end
31 changes: 10 additions & 21 deletions src/mask/mask.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,30 @@
Base.@enum MASKDATA::UInt8 DATALESS ARRAYDATA MIXDATA
Base.@enum MASKTYPE::UInt8 ATTENTION SEQUENCE MIXTYPE
const MASKTAG = Union{MASKDATA, MASKTYPE}
abstract type AbstractMask{D, T} end

abstract type AbstractWrapperMask{D, T} <: AbstractMask{D, T} end

const AbstractAttenMask{D} = AbstractMask{D, ATTENTION}
const AbstractSeqMask{D} = AbstractMask{D, SEQUENCE}
const AbstractArrayMask{T} = AbstractMask{ARRAYDATA, T}
const AbstractDatalessMask{T} = AbstractMask{DATALESS, T}

const AbstractDatalessAttenMask = AbstractAttenMask{DATALESS}
const AbstractArrayDataAttenMask = AbstractAttenMask{ARRAYDATA}
const AbstractDatalessSeqMask = AbstractSeqMask{DATALESS}
const AbstractArrayDataSeqMask = AbstractSeqMask{ARRAYDATA}

MASKDATA(t::MASKDATA) = t
MASKTYPE(t::MASKTYPE) = t
MASKDATA(::AbstractMask{D, T}) where {D, T} = D
MASKTYPE(::AbstractMask{D, T}) where {D, T} = T
MASKDATA(t1::MASKDATA, t2::MASKDATA) = t1 == t2 ? t1 : MIXDATA
MASKTYPE(t1::MASKTYPE, t2::MASKTYPE) = t1 == t2 ? t1 : MIXTYPE

_combine_masktag(f, t1::T, t2::T) where {T <: Union{MASKDATA, MASKTYPE}} = f(t1, t2)
_combine_masktag(f, t::Tuple{T}) where {T <: Union{MASKDATA, MASKTYPE}} = t[1]
_combine_masktag(f, t::NTuple{2, T}) where {T <: Union{MASKDATA, MASKTYPE}} = _combine_masktag(f, t[1], t[2])
function _combine_masktag(f, t::Tuple{T, T, T, Vararg{T}}) where {T <: Union{MASKDATA, MASKTYPE}}
return _combine_masktag(f, _combine_masktag(f, t[1], t[2]), Base.tail(Base.tail(t)))
end
_combine_masktag(f, t0::T, ::Tuple{}) where {T <: Union{MASKDATA, MASKTYPE}} = t0
function _combine_masktag(f, t0::T, t::Tuple{T, Vararg{T}}) where {T <: Union{MASKDATA, MASKTYPE}}
return _combine_masktag(f, _combine_masktag(f, t0, t[1]), Base.tail(t))
end

function _combine_masktag(f, t0::T, m::Tuple{AbstractMask, Vararg{AbstractMask}}) where {T <: Union{MASKDATA, MASKTYPE}}
_combine_masktag(f, _combine_masktag(f, t0, T(m[1])), Base.tail(m))
end
function _combine_masktag(f::Type{T}, m::Tuple{AbstractMask, Vararg{AbstractMask}}) where {T <: Union{MASKDATA, MASKTYPE}}
return _combine_masktag(f, T(m[1]), Base.tail(m))
end

combine_maskdatatag(args...) = _combine_masktag(MASKDATA, args...)
combine_masktypetag(args...) = _combine_masktag(MASKTYPE, args...)
_combine_masktag(::Type{T}, t::NTuple{1}) where T <: MASKTAG = T(t[1])
_combine_masktag(::Type{T}, t::Tuple) where T <: MASKTAG = T(T(first(t)), _combine_masktag(T, Base.tail(t)))
combine_maskdatatag(args...) = _combine_masktag(MASKDATA, args)
combine_masktypetag(args...) = _combine_masktag(MASKTYPE, args)

"""
AbstractMask
Expand Down
8 changes: 4 additions & 4 deletions src/mask/sequence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Base.ndims(::GenericSequenceMask{N}) where N = N

adapt_structure(to, x::GenericSequenceMask{N}) where N = GenericSequenceMask{N}(adapt(to, x.mask))

Base.@propagate_inbounds Base.getindex(m::Indexer{<:GenericSequenceMask}, I::Integer...) = m.mask[1, Base.tail(I)...]
Base.@propagate_inbounds maskgetindex(::Dims, m::GenericSequenceMask, I::Integer...) = m.mask[1, Base.tail(I)...]

AxesConstraint(m::GenericSequenceMask) = (NDimConstraint(ndims(m)), ntuple(i->DimConstraint(i+1, size(m.mask, i+1), i < 2), ndims(m)-1)...)

Expand All @@ -45,7 +45,7 @@ LengthMask(len::AbstractArray) = LengthMask(convert(AbstractArray{Int32}, len))

adapt_structure(to, x::LengthMask) = LengthMask(adapt(to, x.len))

Base.@propagate_inbounds function Base.getindex(m::Indexer{<:LengthMask}, _::Integer, j::Integer, J::Integer...)
Base.@propagate_inbounds function maskgetindex(::Dims, m::LengthMask, _::Integer, j::Integer, J::Integer...)
l = m.len[J...]
return j <= l
end
Expand Down Expand Up @@ -78,8 +78,8 @@ RevLengthMask(len::AbstractArray) = RevLengthMask(convert(AbstractArray{Int32},

adapt_structure(to, x::RevLengthMask) = RevLengthMask(adapt(to, x.len))

Base.@propagate_inbounds function Base.getindex(m::Indexer{<:RevLengthMask}, _::Integer, j::Integer, J::Integer...)
cl = m.dest_size[2]
Base.@propagate_inbounds function maskgetindex(destsize::Dims, m::RevLengthMask, _::Integer, j::Integer, J::Integer...)
cl = destsize[2]
l = m.len[J...]
return cl - l < j
end
Expand Down
Loading

0 comments on commit b4f03c5

Please sign in to comment.