Skip to content

Commit

Permalink
Organize code out of compiler.jl (#2137)
Browse files Browse the repository at this point in the history
* Organize code out of compiler.jl

* More cleanup

* more cleaning

* add file

* Add file

* fix

* Update sugar.jl

* Update sugar.jl

* Update sugar.jl

* fixup

* fix

* Update sugar.jl

* Update sugar.jl

* Update sugar.jl
  • Loading branch information
wsmoses authored Nov 29, 2024
1 parent a207b27 commit 2bfc9b5
Show file tree
Hide file tree
Showing 9 changed files with 3,507 additions and 3,494 deletions.
1,036 changes: 3 additions & 1,033 deletions src/Enzyme.jl

Large diffs are not rendered by default.

457 changes: 457 additions & 0 deletions src/analyses/activity.jl

Large diffs are not rendered by default.

File renamed without changes.
2,567 changes: 106 additions & 2,461 deletions src/compiler.jl

Large diffs are not rendered by default.

640 changes: 640 additions & 0 deletions src/errors.jl

Large diffs are not rendered by default.

1,060 changes: 1,060 additions & 0 deletions src/llvm/attributes.jl

Large diffs are not rendered by default.

1,155 changes: 1,155 additions & 0 deletions src/sugar.jl

Large diffs are not rendered by default.

File renamed without changes.
86 changes: 86 additions & 0 deletions src/typeutils/recursive_add.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Recursively return x + f(y), where y is active, otherwise x

@inline function recursive_add(
x::T,
y::T,
f::F = identity,
forcelhs::F2 = guaranteed_const,
) where {T,F,F2}
if forcelhs(T)
return x
end
splatnew(T, ntuple(Val(fieldcount(T))) do i
Base.@_inline_meta
prev = getfield(x, i)
next = getfield(y, i)
recursive_add(prev, next, f, forcelhs)
end)
end

@inline function recursive_add(
x::T,
y::T,
f::F = identity,
forcelhs::F2 = guaranteed_const,
) where {T<:AbstractFloat,F,F2}
if forcelhs(T)
return x
end
return x + f(y)
end

@inline function recursive_add(
x::T,
y::T,
f::F = identity,
forcelhs::F2 = guaranteed_const,
) where {T<:Complex,F,F2}
if forcelhs(T)
return x
end
return x + f(y)
end

@inline mutable_register(::Type{T}) where {T<:Integer} = true
@inline mutable_register(::Type{T}) where {T<:AbstractFloat} = false
@inline mutable_register(::Type{Complex{T}}) where {T<:AbstractFloat} = false
@inline mutable_register(::Type{T}) where {T<:Tuple} = false
@inline mutable_register(::Type{T}) where {T<:NamedTuple} = false
@inline mutable_register(::Type{Core.Box}) = true
@inline mutable_register(::Type{T}) where {T<:Array} = true
@inline mutable_register(::Type{T}) where {T} = ismutabletype(T)

# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y)
@inline function recursive_accumulate(x::Array{T}, y::Array{T}, f::F = identity) where {T,F}
if !mutable_register(T)
for I in eachindex(x)
prev = x[I]
@inbounds x[I] = recursive_add(x[I], (@inbounds y[I]), f, mutable_register)
end
end
end


# Recursively In-place accumulate(aka +=). E.g. generalization of x .+= f(y)
@inline function recursive_accumulate(x::Core.Box, y::Core.Box, f::F = identity) where {F}
recursive_accumulate(x.contents, y.contents, seen, f)
end

@inline function recursive_accumulate(x::T, y::T, f::F = identity) where {T,F}
@assert !Base.isabstracttype(T)
@assert Base.isconcretetype(T)
nf = fieldcount(T)

for i = 1:nf
if isdefined(x, i)
xi = getfield(x, i)
ST = Core.Typeof(xi)
if !mutable_register(ST)
@assert ismutable(x)
yi = getfield(y, i)
nexti = recursive_add(xi, yi, f, mutable_register)
setfield!(x, i, nexti)
end
end
end
end

0 comments on commit 2bfc9b5

Please sign in to comment.