Skip to content

Commit

Permalink
fix broadcast perf
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed Jun 1, 2024
1 parent 587ea7e commit dd3a56b
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ function ChainRulesCore.rrule(config::RuleConfig, pf::PrefixedFunction, args...)
return y, PrefixedFunctionPullback(back, num_input, num_f_args)
end

struct SkipFirstArg{F} <: Function
f::F
# https://github.com/FluxML/NNlib.jl/blob/7369244c1a64317eef5b0a20c142316947a85bb3/src/utils.jl#L131-L141
function _fast_broadcast2!(f::F, dst::Array, x, yz...) where {F<:Function}
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
@simd ivdep for I in eachindex(bc)
@inbounds x[I] = bc[I]
end
return x
end
function _fast_broadcast2!(f::F, dst::AbstractArray, x, yz...) where {F<:Function}
return broadcast!(f, dst, x, yz...)
end
@inline (_f::SkipFirstArg)(dst, xs...) = _f.f(xs...)

using NNlib: _fast_broadcast!
@inline _fast_broadcast(f, x, yz...) = _fast_broadcast!(f, copy(x), yz...)
@inline _fast_broadcast2(f, x, yz...) = _fast_broadcast2!(f, similar(x), x, yz...)
@inline _fast_broadcast2!(f, dst, x, yz...) = _fast_broadcast!(SkipFirstArg(f), dst, x, yz...)

0 comments on commit dd3a56b

Please sign in to comment.