Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Error Message #312

Merged
merged 5 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.22"
version = "0.4.23"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
63 changes: 42 additions & 21 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1479,30 +1479,51 @@

_copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode)

@noinline function _build_rule!(rule::LazyDerivedRule{sig, Trule}) where {sig, Trule}
@inline function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N}
isdefined(rule, :rule) || _build_rule!(rule, args)
return rule.rule(args...)
end

struct BadRuleTypeException <: Exception
mi::Core.MethodInstance
sig::Type
actual_rule_type::Type
expected_rule_type::Type
end

function Base.showerror(io::IO, err::BadRuleTypeException)
println(io, "BadRuleTypeException:")
println(io)
println(io, "Rule is of type:")
println(io, err.actual_rule_type)
println(io)
println(io, "However, expected rule to be of type:")
println(io, err.expected_rule_type)
println(io)
println(io, "This error occured for $(err.mi) with signature:")
println(io, err.sig)
println(io)
msg = "Usually this error is indicative of something having gone wrong in the " *

Check warning on line 1506 in src/interpreter/s2s_reverse_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_reverse_mode_ad.jl#L1494-L1506

Added lines #L1494 - L1506 were not covered by tests
"compilation of the rule in question. Look at the error message for the error " *
"which caused this error (below) for more details. If the error below does not " *
"immediately give you enough information to debug what is going on, consider " *
"building the rule for the signature above, and inspecting the IR."
println(io, msg)

Check warning on line 1511 in src/interpreter/s2s_reverse_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_reverse_mode_ad.jl#L1511

Added line #L1511 was not covered by tests
end

@noinline function _build_rule!(rule::LazyDerivedRule{sig, Trule}, args) where {sig, Trule}
derived_rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode)
if derived_rule isa Trule
rule.rule = derived_rule
else
@warn "Unable to put rule in rule field. Rule should error."
println("MethodInstance is")
display(rule.mi)
println()
println("with signature")
display(sig)
println()
println("derived_rule is of type")
display(typeof(derived_rule))
println()
println("Expected type is")
display(Trule)
println()
derived_rule(args...)
error("Rule with bad type ran without error.")
@warn "Unable to put rule in rule field. A `BadRuleTypeException` should be thrown."
err = BadRuleTypeException(rule.mi, sig, typeof(derived_rule), Trule)
try
derived_rule(args...)
catch
throw(err)
end
@warn "`BadRuleTypException was _not_ thrown. Throwing now."
throw(err)

Check warning on line 1527 in src/interpreter/s2s_reverse_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_reverse_mode_ad.jl#L1526-L1527

Added lines #L1526 - L1527 were not covered by tests
end
end

@inline function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N}
isdefined(rule, :rule) || _build_rule!(rule)
return rule.rule(args...)
end
15 changes: 13 additions & 2 deletions test/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module S2SGlobals
using LinearAlgebra
using LinearAlgebra, Mooncake

non_const_global = 5.0
const const_float = 5.0
Expand All @@ -11,6 +11,13 @@ module S2SGlobals
data
end
f(a, x) = dot(a.data, x)

# Test cases designed to cause `LazyDerivedRule` to throw an error when attempting to
# construct a rule for `bar`.
foo(x) = x
@noinline bar(x) = foo(x)
baz(x) = bar(x)
Mooncake.@is_primitive Mooncake.MinimalCtx Tuple{typeof(foo), Any}
end

@testset "s2s_reverse_mode_ad" begin
Expand Down Expand Up @@ -246,7 +253,11 @@ end
rule = Mooncake.build_rrule(interp, sig; debug_mode)
@test rule isa Mooncake.rule_type(interp, sig; debug_mode)
end

@testset "LazyDerivedRule" begin
fargs = (S2SGlobals.baz, 5.0)
rule = build_rrule(fargs...)
@test_throws Mooncake.BadRuleTypeException rule(map(zero_fcodual, fargs)...)
end
@testset "$(_typeof((f, x...)))" for (n, (interface_only, perf_flag, bnds, f, x...)) in
collect(enumerate(TestResources.generate_test_functions()))

Expand Down
Loading