Skip to content

Commit

Permalink
Fix array tangent gen functionality (#344)
Browse files Browse the repository at this point in the history
* Fix array tangent gen functionality

* Fix errors for length-zero mems

* Add extra test case

* Add tests which can pass

* Fix rules

* Bump patch

* Fix bug
  • Loading branch information
willtebbutt authored Nov 5, 2024
1 parent ff56e8e commit 624f638
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 89 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.35"
version = "0.4.36"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
47 changes: 47 additions & 0 deletions src/rrules/array_legacy.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,50 @@
# Zero tangent for an `Array`. `stackdict` maps primals that have already been visited to
# their tangents; if `x` is in there, that tangent is returned instead of building a new one.
@inline function zero_tangent_internal(x::Array{P, N}, stackdict::IdDict) where {P, N}
    if haskey(stackdict, x)
        return stackdict[x]::tangent_type(typeof(x))
    end

    # Register the new array before visiting the elements: the element-wise calls consult
    # the same `stackdict`, so an element which is `x` itself resolves to this array.
    out = Array{tangent_type(P), N}(undef, size(x)...)
    stackdict[x] = out
    zero_elem = Base.Fix2(zero_tangent_internal, stackdict)
    return _map_if_assigned!(zero_elem, out, x)::Array{tangent_type(P), N}
end

# Random tangent for an `Array`, drawn using `rng`. A tangent already recorded for `x` in
# `stackdict` is returned as is.
function randn_tangent_internal(rng::AbstractRNG, x::Array{T, N}, stackdict::IdDict) where {T, N}
    if haskey(stackdict, x)
        return stackdict[x]::tangent_type(typeof(x))
    end

    # Record the output before filling it, so the element-wise calls below can find it in
    # `stackdict`.
    out = Array{tangent_type(T), N}(undef, size(x)...)
    stackdict[x] = out
    return _map_if_assigned!(out, x) do elem
        return randn_tangent_internal(rng, elem, stackdict)
    end
end

# Element-wise `increment!!` of `x` by `y`, written back into `x`. When both arguments are
# the very same object, `x` is returned without touching any element.
function increment!!(x::T, y::T) where {P, N, T<:Array{P, N}}
    x === y && return x
    return _map_if_assigned!(increment!!, x, x, y)
end

# Apply `set_to_zero!!` to the elements of `x`, storing the results back into `x`.
function set_to_zero!!(x::Array)
    return _map_if_assigned!(set_to_zero!!, x, x)
end

# Scale the elements of `t` by `a`, writing the results into a freshly allocated array of
# the same element type and size.
function _scale(a::Float64, t::Array{T, N}) where {T, N}
    scaled = Array{T, N}(undef, size(t)...)
    scale_by_a = Base.Fix1(_scale, a)
    return _map_if_assigned!(scale_by_a, scaled, t)
end

# Inner product of two array tangents of the same type: the sum of `_dot` over matching
# elements.
function _dot(t::T, s::T) where {T<:Array}
    # Fast path for bits-type elements. The check has to be made on the element type: `T`
    # is the array type itself, which is never a bits type. `init=0.0` keeps the empty
    # case well-defined, matching the generic path below.
    isbitstype(eltype(T)) && return sum(_map(_dot, t, s); init=0.0)

    # Generic path: positions which are unassigned in either array contribute `0.0`.
    return sum(
        _map(eachindex(t)) do n
            (isassigned(t, n) && isassigned(s, n)) ? _dot(t[n], s[n]) : 0.0
        end;
        init=0.0,
    )
end

# Element-wise `_add_to_primal` of the tangent `t` onto the primal `x`. The result goes
# into a new array with the element type and size of `x`; `x` itself is not written to.
function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}) where {P, N}
    updated = similar(x)
    return _map_if_assigned!(_add_to_primal, updated, x, t)
end

# Element-wise `_diff` of two primal arrays of the same type, collected in a new array
# whose element type is `tangent_type(V)`.
function _diff(p::P, q::P) where {V, N, P<:Array{V, N}}
    TangentArray = Array{tangent_type(V), N}
    delta = TangentArray(undef, size(p))
    return _map_if_assigned!(_diff, delta, p, q)
end

# `@zero_adjoint` declarations (in `MinimalCtx`) for the `undef` constructors of `Array`.
# The three signatures cover dimensions passed as trailing arguments, as an empty tuple,
# and as an `N`-tuple respectively.
@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Vararg} where {T, N}
@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), Tuple{}} where {T, N}
@zero_adjoint MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), NTuple{N}} where {T, N}
Expand Down
193 changes: 144 additions & 49 deletions src/rrules/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function randn_tangent_internal(rng::AbstractRNG, x::Memory, stackdict::Maybe{Id
haskey(stackdict, x) && return stackdict[x]::T

t = T(undef, length(x))
stackdict[t] = t
stackdict[x] = t
return _map_if_assigned!(x -> randn_tangent_internal(rng, x, stackdict), t, x)::T
end

Expand Down Expand Up @@ -85,7 +85,8 @@ function _dot(t::Memory{T}, s::Memory{T}) where {T}
return sum(
_map(eachindex(t)) do n
(isassigned(t, n) && isassigned(s, n)) ? _dot(t[n], s[n]) : 0.0
end
end;
init=0.0,
)
end

Expand All @@ -109,11 +110,77 @@ tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Memory} = F

# For a `Memory` paired with `NoRData`, the tangent is the `Memory` `f` itself.
function tangent(f::Memory, ::NoRData)
    return f
end

function _verify_fdata_value(p::Memory{P}, f::Memory{T}) where {P, T}
@assert length(p) == length(f)
# Check that the primal `p` and its fdata `f` have the same length.
#
# Returns `nothing` when the lengths agree; otherwise raises an `ErrorException` whose
# message reports both lengths and both element types.
function _verify_fdata_value(p::Memory{P}, f::Memory{F}) where {P, F}
    if length(p) != length(f)
        # `error` raises by itself, so there is no need to wrap it in `throw`.
        error(
            "length(p) == $(length(p)) but length(f) == $(length(f)). " *
            "p isa Memory{$P} and f isa Memory{$F}",
        )
    end
    return nothing
end

#
# Array -- tangent interface implementation
#

# Zero tangent for an `Array`, built field by field from its `size` and `ref` fields.
# `stackdict` records the tangent associated with each primal visited so far.
@inline function zero_tangent_internal(x::Array, stackdict::Maybe{IdDict})
T = tangent_type(typeof(x))

# If `x` has already been given a tangent, return that one rather than building another.
haskey(stackdict, x) && return stackdict[x]::T

# Build the tangent in stages. NOTE(review): `_new_(T)` presumably creates an instance of
# `T` whose fields are not yet set -- confirm against its definition.
dx = _new_(T)
# The tangent takes its `size` directly from the primal.
Base.setfield!(dx, :size, x.size)
# Register `dx` before recursing into `x.ref`: the recursive call receives the same
# `stackdict`, so it can see that `x` already has a tangent.
stackdict[x] = dx
Base.setfield!(dx, :ref, zero_tangent_internal(x.ref, stackdict))
return dx::T
end

# Random tangent for an `Array`, built field by field from its `size` and `ref` fields,
# with `rng` forwarded to the call which handles `x.ref`. `stackdict` records the tangent
# associated with each primal visited so far.
function randn_tangent_internal(rng::AbstractRNG, x::Array, stackdict::Maybe{IdDict})
T = tangent_type(typeof(x))

# If `x` has already been given a tangent, return that one rather than building another.
haskey(stackdict, x) && return stackdict[x]::T

# Build the tangent in stages. NOTE(review): `_new_(T)` presumably creates an instance of
# `T` whose fields are not yet set -- confirm against its definition.
dx = _new_(T)
# The tangent takes its `size` directly from the primal.
Base.setfield!(dx, :size, x.size)
# Register `dx` before recursing into `x.ref`: the recursive call receives the same
# `stackdict`, so it can see that `x` already has a tangent.
stackdict[x] = dx
Base.setfield!(dx, :ref, randn_tangent_internal(rng, x.ref, stackdict))
return dx::T
end

# Element-wise `increment!!` of `x` by `y`, written back into `x`. If `x` and `y` are the
# same object, `x` is handed back unchanged.
function increment!!(x::T, y::T) where {T<:Array}
    if x === y
        return x
    else
        return _map_if_assigned!(increment!!, x, x, y)
    end
end

# Apply `set_to_zero!!` to the elements of `x`, storing the results back into `x`.
function set_to_zero!!(x::Array)
    return _map_if_assigned!(set_to_zero!!, x, x)
end

# Scale the elements of `t` by `a`, writing the results into a freshly allocated array of
# the same type and size as `t`.
function _scale(a::Float64, t::T) where {T<:Array}
    scaled = T(undef, size(t)...)
    return _map_if_assigned!(scaled, t) do elem
        return _scale(a, elem)
    end
end

# Inner product of two array tangents of the same type: the sum of `_dot` over matching
# elements.
function _dot(t::T, s::T) where {T<:Array}
    # Fast path for bits-type elements. The check has to be made on the element type: `T`
    # is the array type itself, which is never a bits type. `init=0.0` keeps the empty
    # case well-defined, matching the generic path below.
    isbitstype(eltype(T)) && return sum(_map(_dot, t, s); init=0.0)

    # Generic path: positions which are unassigned in either array contribute `0.0`.
    return sum(
        _map(eachindex(t)) do n
            (isassigned(t, n) && isassigned(s, n)) ? _dot(t[n], s[n]) : 0.0
        end;
        init=0.0,
    )
end

# Element-wise `_add_to_primal` of the tangent `t` onto the primal `x`. The result goes
# into a new array with the element type and size of `x`; `x` itself is not written to.
function _add_to_primal(x::Array{P, N}, t::Array{<:Any, N}) where {P, N}
    PrimalArray = Array{P, N}
    updated = PrimalArray(undef, size(x)...)
    return _map_if_assigned!(_add_to_primal, updated, x, t)
end

# Element-wise `_diff` of two primal arrays of the same type, collected in a new array of
# type `tangent_type(P)` with the size of `p`.
function _diff(p::P, q::P) where {P<:Array}
    delta = tangent_type(P)(undef, size(p))
    return _map_if_assigned!(_diff, delta, p, q)
end

# Rules

@is_primitive(
Expand Down Expand Up @@ -166,14 +233,29 @@ end

# The tangent type of a `MemoryRef{P}` is a `MemoryRef` over the tangent type of `P`.
function tangent_type(::Type{<:MemoryRef{P}}) where {P}
    return MemoryRef{tangent_type(P)}
end

#=
Given a new chunk of memory `m`, construct a `MemoryRef` which points to the same relative
position in `m` as `x` points to in its underlying `Memory` object. For example, in the
following:
```julia
original_mem = Memory{Float64}(undef, 10)
x = memoryref(original_mem, 4)
new_mem = Memory{Float64}(undef, 10)
new_x = construct_ref(x, new_mem)
```
`new_x` will point towards the 4th element of `new_mem`. Care is required if the length
of `original_mem` is `0`. See implementation for details.
=#
function construct_ref(x::MemoryRef, m::Memory)
    # An empty `m` has no element to offset to, so take the plain reference to `m`.
    isempty(m) && return memoryref(m)

    # Otherwise reuse the offset which `x` has in its own memory.
    offset = Core.memoryrefoffset(x)
    return memoryref(m, offset)
end

function zero_tangent_internal(x::MemoryRef, stackdict::Maybe{IdDict})
t_mem = zero_tangent_internal(x.mem, stackdict)::Memory
return memoryref(t_mem, Core.memoryrefoffset(x))
return construct_ref(x, zero_tangent_internal(x.mem, stackdict))
end

function randn_tangent_internal(rng::AbstractRNG, x::MemoryRef, stackdict::Maybe{IdDict})
t_mem = randn_tangent_internal(rng, x.mem, stackdict)::Memory
return memoryref(t_mem, Core.memoryrefoffset(x))
return construct_ref(x, randn_tangent_internal(rng, x.mem, stackdict))
end

function TestUtils.has_equal_data_internal(
Expand All @@ -184,30 +266,26 @@ function TestUtils.has_equal_data_internal(
return equal_refs && equal_data
end

function increment!!(x::MemoryRef{P}, y::MemoryRef{P}) where {P}
return memoryref(increment!!(x.mem, y.mem), Core.memoryrefoffset(x))
end
# Increment the memory behind `x` by the memory behind `y`, then return a reference into
# the result built by `construct_ref` from `x`.
function increment!!(x::P, y::P) where {P<:MemoryRef}
    incremented_mem = increment!!(x.mem, y.mem)
    return construct_ref(x, incremented_mem)
end

# Apply `set_to_zero!!` to the memory behind `x` and return `x` itself; the value returned
# by the inner call is discarded.
set_to_zero!!(x::MemoryRef) = (set_to_zero!!(x.mem); x)

function _add_to_primal(p::MemoryRef, t::MemoryRef)
return memoryref(_add_to_primal(p.mem, t.mem), Core.memoryrefoffset(p))
end
# Add the tangent memory `t.mem` onto the primal memory `p.mem`, then return a reference
# into the result built by `construct_ref` from `p`.
function _add_to_primal(p::MemoryRef, t::MemoryRef)
    updated_mem = _add_to_primal(p.mem, t.mem)
    return construct_ref(p, updated_mem)
end

function _diff(p::MemoryRef{P}, q::MemoryRef{P}) where {P}
function _diff(p::P, q::P) where {P<:MemoryRef}
@assert Core.memoryrefoffset(p) == Core.memoryrefoffset(q)
return memoryref(_diff(p.mem, q.mem), Core.memoryrefoffset(p))
return construct_ref(p, _diff(p.mem, q.mem))
end

function _dot(t::MemoryRef{T}, s::MemoryRef{T}) where {T}
function _dot(t::T, s::T) where {T<:MemoryRef}
@assert Core.memoryrefoffset(t) == Core.memoryrefoffset(s)
return _dot(t.mem, s.mem)
end

_scale(a::Float64, t::MemoryRef) = memoryref(_scale(a, t.mem), Core.memoryrefoffset(t))
# Scale the memory behind `t` by `a`, then return a reference into the scaled memory built
# by `construct_ref` from `t`.
function _scale(a::Float64, t::MemoryRef)
    scaled_mem = _scale(a, t.mem)
    return construct_ref(t, scaled_mem)
end

function populate_address_map!(m::TestUtils.AddressMap, p::MemoryRef, t::MemoryRef)
return populate_address_map!(m, p.mem, t.mem)
Expand Down Expand Up @@ -561,28 +639,41 @@ function _mems()
(Memory{Vector{Float64}}(undef, 3)),
(Memory{Any}(randn(3))),
(mem_with_single_undef),
(Memory{Any}(undef, 0)),
]
sample_values = [1.0, 3, randn(2), randn(2), 5.0, Memory{Int}(undef, 5)]
sample_values = [1.0, 3, randn(2), randn(2), 5.0, Memory{Int}(undef, 5), nothing]
return mems, sample_values
end

function _mem_refs()

# Generate test cases of arbitrary length.
mems_1, sample_values_1 = _mems()
mems_2, sample_values_2 = _mems()

# Restrict to minimum length of 2.
_mems_2, _sample_values_2 = _mems()
inds = findall(x -> length(x) >= 2, _mems_2)
mems_2 = _mems_2[inds]
sample_values_2 = _sample_values_2[inds]

# Construct memoryref test cases.
mem_refs = vcat([memoryref(m) for m in mems_1], [memoryref(m, 2) for m in mems_2])
return mem_refs, vcat(sample_values_1, sample_values_2)
end

function generate_data_test_cases(rng_ctor, ::Val{:memory})
arrays = [
randn(2),
]
return vcat(_mems()[1], _mem_refs()[1], arrays)
return vcat(_mems()[1], _mem_refs()[1], [randn(2), Any[]])
end

function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:memory})
rng = rng_ctor(123)
mems, _ = _mems()
mem_refs, sample_mem_ref_values = _mem_refs()

assignable_refs = Iterators.filter(
x -> length(x[1].mem) >= Core.memoryrefoffset(x[1]),
zip(mem_refs, sample_mem_ref_values),
)
test_cases = vcat(

# Rules for `Memory`
Expand All @@ -601,78 +692,82 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:memory})
mem_ref in filter(isassigned, mem_refs) for bc in [false, true]
],
[(false, :none, nothing, memoryrefnew, mem) for mem in mems],
[(false, :none, nothing, memoryrefnew, mem, 1) for mem in mem_refs],
[(false, :none, nothing, memoryrefnew, mem, 1) for
mem in filter(x -> length(x.mem) > Core.memoryrefoffset(x), mem_refs)
],
[(false, :none, nothing, memoryrefnew, mem, 1, bc) for
mem in mem_refs for bc in [false, true]
mem in filter(x -> length(x.mem) > Core.memoryrefoffset(x), mem_refs) for
bc in [false, true]
],
[(false, :none, nothing, memoryrefoffset, mem_ref) for mem_ref in mem_refs],
[
(false, :none, nothing, lmemoryrefset!, mem_ref, sample_value, Val(:not_atomic), bc) for
(mem_ref, sample_value) in zip(mem_refs, sample_mem_ref_values) for
(mem_ref, sample_value) in assignable_refs for
bc in [Val(false), Val(true)]
],
[
(false, :none, nothing, memoryrefset!, mem_ref, sample_value, :not_atomic, bc) for
(mem_ref, sample_value) in zip(mem_refs, sample_mem_ref_values) for
(mem_ref, sample_value) in assignable_refs for
bc in [false, true]
],
(false, :stability, nothing, unsafe_copyto!, randn(10).ref, randn(8).ref, 5),
(false, :stability, nothing, unsafe_copyto!, randn(rng, 10).ref, randn(rng, 8).ref, 5),
(
false, :stability, nothing,
unsafe_copyto!,
memoryref(randn(10).ref, 2),
memoryref(randn(8).ref, 3),
memoryref(randn(rng, 10).ref, 2),
memoryref(randn(rng, 8).ref, 3),
4,
),
(
false, :stability, nothing,
unsafe_copyto!,
[randn(10), randn(5)].ref,
[randn(10), randn(3)].ref,
[randn(rng, 10), randn(rng, 5)].ref,
[randn(rng, 10), randn(rng, 3)].ref,
2,
),

# Rules for `Array`
(false, :stability, nothing, _new_, Vector{Float64}, randn(10).ref, (10, )),
(false, :stability, nothing, _new_, Vector{Float64}, randn(rng, 10).ref, (10, )),
(
false, :stability, nothing,
_new_,
Vector{Vector{Float64}},
[randn(10), randn(5)].ref,
[randn(rng, 10), randn(rng, 5)].ref,
(2, ),
),
(
false, :none, nothing,
_new_,
Vector{Any},
[1, randn(5)].ref,
[1, randn(rng, 5)].ref,
(2, ),
),
(false, :stability, nothing, _new_, Matrix{Float64}, randn(12).ref, (4, 3)),
(false, :stability, nothing, _new_, Array{Float64, 3}, randn(12).ref, (4, 1, 3)),
(false, :stability, nothing, _new_, Matrix{Float64}, randn(rng, 12).ref, (4, 3)),
(false, :stability, nothing, _new_, Array{Float64, 3}, randn(rng, 12).ref, (4, 1, 3)),
[
(false, :stability, nothing, lgetfield, randn(10), f) for
(false, :stability, nothing, lgetfield, randn(rng, 10), f) for
f in [Val(:ref), Val(:size), Val(1), Val(2)]
],
[
(false, :none, nothing, getfield, randn(10), f) for
(false, :none, nothing, getfield, randn(rng, 10), f) for
f in [:ref, :size, 1, 2]
],
(false, :stability_and_allocs, nothing, lsetfield!, randn(10), Val(:ref), randn(10).ref),
(false, :stability_and_allocs, nothing, lsetfield!, randn(10), Val(1), randn(10).ref),
(false, :stability_and_allocs, nothing, lsetfield!, randn(10), Val(:size), (10, )),
(false, :stability_and_allocs, nothing, lsetfield!, randn(10), Val(2), (10, )),
(false, :none, nothing, setfield!, randn(10), :ref, randn(10).ref),
(false, :none, nothing, setfield!, randn(10), 1, randn(10).ref),
(false, :none, nothing, setfield!, randn(10), :size, (10, )),
(false, :none, nothing, setfield!, randn(10), 2, (10, )),
(false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(:ref), randn(rng, 10).ref),
(false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(1), randn(rng, 10).ref),
(false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(:size), (10, )),
(false, :stability_and_allocs, nothing, lsetfield!, randn(rng, 10), Val(2), (10, )),
(false, :none, nothing, setfield!, randn(rng, 10), :ref, randn(rng, 10).ref),
(false, :none, nothing, setfield!, randn(rng, 10), 1, randn(rng, 10).ref),
(false, :none, nothing, setfield!, randn(rng, 10), :size, (10, )),
(false, :none, nothing, setfield!, randn(rng, 10), 2, (10, )),
)
memory = Any[]
return test_cases, memory
end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:memory})
x = Memory{Float64}(randn(10))
rng = rng_ctor(123)
x = Memory{Float64}(randn(rng, 10))
test_cases = Any[
(true, :none, nothing, Array{Float64, 0}, undef),
(true, :none, nothing, Array{Float64, 1}, undef, 5),
Expand Down
Loading

2 comments on commit 624f638

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118775

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.36 -m "<description of version>" 624f638d8b2187c336a6cb806b1bf39c6dfdf11d
git push origin v0.4.36

Please sign in to comment.