Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Don't force inlining--let Julia figure that out on its own #160

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ function show(io::IO, t::Tape)
end
end
end
@inline getindex(t::Tape, n::Int) = getindex(tape(t), n)
@inline getindex(t::Tape, node::Node) = getindex(t, pos(node))
@inline lastindex(t::Tape) = length(t)
@inline setindex!(t::Tape, x, n::Int) = (tape(t)[n] = x; t)
@inline eachindex(t::Tape) = eachindex(tape(t))
@inline length(t::Tape) = length(tape(t))
@inline push!(t::Tape, node::Node) = (push!(tape(t), node); t)
@inline isassigned(t::Tape, n::Int) = isassigned(tape(t), n)
@inline isassigned(t::Tape, node::Node) = isassigned(t, pos(node))
getindex(t::Tape, n::Int) = getindex(tape(t), n)
getindex(t::Tape, node::Node) = getindex(t, pos(node))
lastindex(t::Tape) = length(t)
setindex!(t::Tape, x, n::Int) = (tape(t)[n] = x; t)
eachindex(t::Tape) = eachindex(tape(t))
length(t::Tape) = length(tape(t))
push!(t::Tape, node::Node) = (push!(tape(t), node); t)
isassigned(t::Tape, n::Int) = isassigned(tape(t), n)
isassigned(t::Tape, node::Node) = isassigned(t, pos(node))

# Make `Tape`s broadcast as scalars without a warning on 0.7
Base.Broadcast.broadcastable(tape::Tape) = Ref(tape)
Expand Down Expand Up @@ -122,7 +122,7 @@ zero(n::Node) = zero(unbox(n))
one(n::Node) = one(unbox(n))

# Leafs do nothing, Branches compute their own sensitivities and update others.
@inline propagate(y::Leaf, rvs_tape::Tape) = nothing
propagate(y::Leaf, rvs_tape::Tape) = nothing
function propagate(y::Branch, rvs_tape::Tape)
tape = Nabla.tape(rvs_tape)
ȳ, f = tape[pos(y)], getfield(y, :f)
Expand Down Expand Up @@ -177,11 +177,11 @@ is the output of `preprocess`. `x1`, `x2`,... are the inputs to the function, `y
output and `ȳ` the reverse-mode sensitivity of `y`.
"""
∇(y::Node, ȳ) = propagate(tape(y), reverse_tape(y, ȳ))
@inline ∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y)))
∇(y::Node{<:∇Scalar}) = ∇(y, one(unbox(y)))

# This is a fallback method where we don't necessarily know what we'll be adding and whether
# we can update the value in-place, so we'll try to be clever and dispatch.
@inline ∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...))
∇(x̄, f, ::Type{Arg{N}}, args...) where {N} = update!(x̄, ∇(f, Arg{N}, args...))

# Update regular arrays in-place. Structured array types should not be updated in-place,
# even though it technically "works" (https://github.com/JuliaLang/julia/issues/31674),
Expand Down Expand Up @@ -242,14 +242,14 @@ for (f_name, scalar_init, array_init) in
(:zero, :one, nothing),
(:zeros, :ones, nothing))
if scalar_init !== nothing
@eval @inline $f_name(x::Number) = $scalar_init(x)
@eval $f_name(x::Number) = $scalar_init(x)
end
if array_init !== nothing
@eval @inline $f_name(x::AbstractArray{<:Real}) = $array_init(eltype(x), size(x))
@eval $f_name(x::AbstractArray{<:Real}) = $array_init(eltype(x), size(x))
end
eval(quote
@inline $f_name(x::Tuple) = map($f_name, x)
@inline function $f_name(x)
$f_name(x::Tuple) = map($f_name, x)
function $f_name(x)
y = Base.copy(x)
for n in eachindex(y)
@inbounds y[n] = $f_name(y[n])
Expand All @@ -258,10 +258,10 @@ for (f_name, scalar_init, array_init) in
end
end)
end
@inline randned_container(x::Number) = randn(typeof(x))
@inline randned_container(x::AbstractArray{<:Real}) = randn(eltype(x), size(x)...)
randned_container(x::Number) = randn(typeof(x))
randned_container(x::AbstractArray{<:Real}) = randn(eltype(x), size(x)...)
for T in (:Diagonal, :UpperTriangular, :LowerTriangular)
@eval @inline randned_container(x::$T{<:Real}) = $T(randn(eltype(x), size(x)...))
@eval randned_container(x::$T{<:Real}) = $T(randn(eltype(x), size(x)...))
end

# Bare-bones FMAD implementation based on DualNumbers. Accepts a Tuple of args and returns
Expand Down
20 changes: 10 additions & 10 deletions src/sensitivities/functional/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,39 +101,39 @@ _∇(::typeof(broadcast), ::Type{Arg{N}}, p, y, ȳ, f, A...) where N =
# Addition.
import Base: +
@eval @explicit_intercepts $(Symbol("+")) Tuple{∇Array, ∇Array}
@inline ∇(::typeof(+), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(::typeof(+), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(broadcast, Arg{2}, p, z, z̄, +, x, y)
@inline ∇(::typeof(+), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(::typeof(+), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(broadcast, Arg{3}, p, z, z̄, +, x, y)

# Multiplication.
import Base: *
@eval @explicit_intercepts $(Symbol("*")) Tuple{∇ArrayOrScalar, ∇ArrayOrScalar}
@inline ∇(::typeof(*), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(::typeof(*), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(broadcast, Arg{2}, p, z, z̄, *, x, y)
@inline ∇(::typeof(*), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(::typeof(*), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(broadcast, Arg{3}, p, z, z̄, *, x, y)

# Subtraction.
import Base: -
@eval @explicit_intercepts $(Symbol("-")) Tuple{∇Array, ∇Array}
@inline ∇(::typeof(-), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(::typeof(-), ::Type{Arg{1}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(broadcast, Arg{2}, p, z, z̄, -, x, y)
@inline ∇(::typeof(-), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(::typeof(-), ::Type{Arg{2}}, p, z, z̄, x::∇Array, y::∇Array) =
∇(broadcast, Arg{3}, p, z, z̄, -, x, y)

# Division from the right by a scalar.
import Base: /
@eval @explicit_intercepts $(Symbol("/")) Tuple{∇Array, ∇Scalar}
@inline ∇(::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(::typeof(/), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(broadcast, Arg{2}, p, z, z̄, /, x, y)
@inline ∇(::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(::typeof(/), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(broadcast, Arg{3}, p, z, z̄, /, x, y)

# Division from the left by a scalar.
import Base: \
@eval @explicit_intercepts $(Symbol("\\")) Tuple{∇Scalar, ∇Array}
@inline ∇(::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(::typeof(\), ::Type{Arg{1}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(broadcast, Arg{2}, p, z, z̄, \, x, y)
@inline ∇(::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(::typeof(\), ::Type{Arg{2}}, p, z, z̄, x::∇ArrayOrScalar, y::∇ArrayOrScalar) =
∇(broadcast, Arg{3}, p, z, z̄, \, x, y)
8 changes: 4 additions & 4 deletions src/sensitivities/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ using DiffRules: DiffRules, @define_diffrule, diffrule, diffrules, hasdiffrule
# gradient implemented for use in higher-order functions.
import Base.identity
@explicit_intercepts identity Tuple{Any}
@inline ∇(::typeof(identity), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ
@inline ∇(::typeof(identity), ::Type{Arg{1}}, x::Real) = one(x)
∇(::typeof(identity), ::Type{Arg{1}}, p, y, ȳ, x) = ȳ
∇(::typeof(identity), ::Type{Arg{1}}, x::Real) = one(x)

# Ignore functions that have complex ranges. This may change when Nabla supports complex
# numbers.
Expand All @@ -29,8 +29,8 @@ for (package, f, arity) in diffrules()
push!(unary_sensitivities, (package, f))
∂f∂x = diffrule(package, f, :x)
@eval @explicit_intercepts $f Tuple{∇Scalar}
@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x
@eval @inline ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x
@eval ∇(::typeof($f), ::Type{Arg{1}}, p, y, ȳ, x::∇Scalar) = ȳ * $∂f∂x
@eval ∇(::typeof($f), ::Type{Arg{1}}, x::∇Scalar) = $∂f∂x
elseif arity == 2
push!(binary_sensitivities, (package, f))
∂f∂x, ∂f∂y = diffrule(package, f, :x, :y)
Expand Down
2 changes: 1 addition & 1 deletion src/sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,4 @@ Default implementation of preprocess returns an empty Tuple. Individual sensitiv
implementations should add methods specific to their use case. The output is passed
in to `∇` as the 3rd or 4th argument in the new-x̄ and update-x̄ cases respectively.
"""
@inline preprocess(::Any, args...) = ()
preprocess(::Any, args...) = ()