From b29b7ace937c10f9ff3670e0969f8c86cd37c78d Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 8 Nov 2022 10:13:15 -0800 Subject: [PATCH] Add rem2pi (#544) * Add rem2pi * fix --- src/compiler.jl | 4 +++- src/compiler/interpreter.jl | 5 +++++ test/runtests.jl | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 71094514a5..27d69981a1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -29,6 +29,7 @@ unsafe_to_pointer(ptr) = ccall(Base.@cfunction(x->x, Ptr{Cvoid}, (Ptr{Cvoid},)), # Julia function to LLVM stem and arity const known_ops = Dict( Base.cbrt => (:cbrt, 1), + Base.rem2pi => (:jl_rem2pi, 2), Base.sqrt => (:sqrt, 1), Base.sin => (:sin, 1), Base.sinc => (:sincn, 1), @@ -6044,6 +6045,7 @@ end elseif sparam_vals[2] != T continue end + elseif name == :jl_rem2pi else all(==(T), sparam_vals) || continue end @@ -6060,7 +6062,7 @@ end handleCustom(name, [EnumAttribute("readnone", 0; ctx), StringAttribute("enzyme_shouldrecompute"; ctx)]) end - + @assert actualRetType !== nothing if must_wrap diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index aba5f7456c..687dc1ab78 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -88,6 +88,11 @@ function is_primitive_func(@nospecialize(TT)) if in(ft, PrimitiveFuncs) return true end + if ft === typeof(Base.rem2pi) + if TT <: Tuple{ft, Float32, <:Any} || TT <: Tuple{ft, Float64, <:Any} || TT <: Tuple{ft, Float16, <:Any} + return true + end + end if ft === typeof(Base.cbrt) || ft === typeof(Base.sin) || ft === typeof(Base.cos) || ft === typeof(Base.sinc) || ft === typeof(Base.tan) || ft === typeof(Base.exp) || ft === typeof(Base.FastMath.exp_fast) || diff --git a/test/runtests.jl b/test/runtests.jl index 3d71a452be..80ca246e9b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -146,6 +146,7 @@ end test_scalar(Base.exp10, 1.0) test_scalar(Base.exp2, 1.0) test_scalar(Base.expm1, 1.0) + test_scalar(x->rem2pi(x,RoundDown), 0.7) @test autodiff(Reverse, (x)->log(x), Active(2.0)) == (0.5,) end