diff --git a/lib/complex.dx b/lib/complex.dx
index 3b2b40191..293012322 100644
--- a/lib/complex.dx
+++ b/lib/complex.dx
@@ -100,6 +100,9 @@ def complex_erf(x:Complex) -> Complex =
 def complex_erfc(x:Complex) -> Complex =
   todo
 
+def complex_erfinv(x:Complex) -> Complex =
+  todo
+
 def complex_log1p(x:Complex) -> Complex =
   case x.re == 0.0 of
     True -> x
@@ -130,3 +133,4 @@ instance Floating(Complex)
   def lgamma(x) = complex_lgamma(x)
   def erf(x) = complex_erf(x)
   def erfc(x) = complex_erfc(x)
+  def erfinv(x) = complex_erfinv(x)
diff --git a/lib/prelude.dx b/lib/prelude.dx
index 8c62e15b5..b7dfe5419 100644
--- a/lib/prelude.dx
+++ b/lib/prelude.dx
@@ -1050,6 +1050,7 @@ interface Floating(a:Type)
   lgamma : (a) -> a
   erf : (a) -> a
   erfc : (a) -> a
+  erfinv : (a) -> a
 
 def lbeta(x:a, y:a) -> a given (a|Sub|Floating) =
   lgamma x + lgamma y - lgamma (x + y)
@@ -1066,6 +1067,136 @@ def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f6
 def float64_tanh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x)))
                                                ,%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))))
 
+# Polynomial evaluation by Horner's method.
+# `ys` holds the coefficients in ascending order of degree, so the result is
+# `ys[0] + x * (ys[1] + x * (ys[2] + ...))`.
+# Unsafe: `ys` must be non-empty, because its last element seeds the accumulator.
+def unsafe_horner(x:a, ys:n=>a) -> a given (a|Add|Mul, n|Ix) =
+  n' = unsafe_i_to_n(n_to_i(size n) - 1)
+  yield_state ys[unsafe_from_ordinal n'] \ref. rof i:(Fin n').
+    ref := ys[unsafe_from_ordinal (ordinal i)] + x * get ref
+
+# `erfinv` implementations for `Float32` and `Float64` are based on those in Julia in
+# https://github.com/JuliaMath/SpecialFunctions.jl, which uses the following reference:
+# Blair, J. M., Edwards, C. A., & Johnson, J. H. (1976). Rational Chebyshev approximations
+# for the inverse of the error function. In Mathematics of Computation (Vol. 30, Issue 136,
+# pp. 827–830). American Mathematical Society (AMS).
+# https://doi.org/10.1090/s0025-5718-1976-0421040-7
+#
+# `float32_erfinv` yields +/-inf at x = +/-1 and NaN for any other |x| >= 1; inside (-1, 1) it
+# selects one of three rational approximations according to |x|.
+def float32_erfinv(x:Float32) -> Float32 =
+  a = select(x > 0.0, x, -x)
+  if a >= 1.0
+  then
+    inf = f_to_f32(1.0 / 0.0)
+    if x == 1.0
+    then inf
+    else
+      if x == -1.0
+      then -inf
+      else f_to_f32(0.0 / 0.0) # TODO: this should probably error but `error` is not defined yet
+  else
+    if a <= 0.75 # Blair table 10
+    then
+      t = x * x - 0.5625
+      p1 = unsafe_horner t [-0.130959967422e+2, 0.26785225760e+2, -0.9289057365e+1]
+      p2 = unsafe_horner t [-0.120749426297e+2, 0.30960614529e+2, -0.17149977991e+2, 0.1e+1]
+      f_to_f32(x * (p1 / p2))
+    else
+      if a <= 0.9375 # Blair table 29
+      then
+        t = x * x - 0.87890625
+        p1 = unsafe_horner t [-0.12402565221, 0.10688059574e+1, -0.19594556078e+1, 0.4230581357]
+        p2 = unsafe_horner t [-0.8827697997e-1, 0.8900743359, -0.21757031196e+1, 0.1e+1]
+        f_to_f32(x * (p1 / p2))
+      else # Blair table 50
+        t = 1.0 / %sqrt(-%log1p(-a))
+        p1 = unsafe_horner t [0.1550470003116, 0.1382719649631e+1, 0.690969348887, -0.1128081391617e+1, 0.680544246825, -0.16444156791]
+        p2 = unsafe_horner t [0.155024849822, 0.1385228141995e+1, 0.1e+1]
+        s = select(x > 0.0, t, select(x < 0.0, (-t), 0.0)) # `t` carrying the sign of `x`
+        f_to_f32(p1 / (s * p2))
+
+# Float64 counterpart of `float32_erfinv`, using the higher-degree Blair tables 17, 37 and 57.
+# NOTE(review): the coefficients are written as default float literals and widened with
+# `f_to_f64`; confirm they are not rounded to Float32 precision before the conversion.
+def float64_erfinv(x:Float64) -> Float64 =
+  zero64 = (zero::Float64)
+  one64 = (one::Float64)
+  a = select(x > zero64, x, %fsub(zero64, x))
+  if a >= one64
+  then
+    inf = %fdiv(one64, zero64)
+    if x == one64
+    then inf
+    else
+      if x == f_to_f64(-1.0)
+      then %fsub(zero64, inf)
+      else %fdiv(zero64, zero64)
+  else
+    if a <= f_to_f64(0.75) # Blair table 17
+    then
+      t = %fsub(%fmul(x, x), f_to_f64(0.5625))
+      p1 = unsafe_horner t [f_to_f64( 0.160304955844066229311e2),
+                            f_to_f64(-0.90784959262960326650e2),
+                            f_to_f64( 0.18644914861620987391e3),
+                            f_to_f64(-0.16900142734642382420e3),
+                            f_to_f64( 0.6545466284794487048e2),
+                            f_to_f64(-0.864213011587247794e1),
+                            f_to_f64( 0.1760587821390590)]
+      p2 = unsafe_horner t [f_to_f64( 0.147806470715138316110e2),
+                            f_to_f64(-0.91374167024260313936e2),
+                            f_to_f64( 0.21015790486205317714e3),
+                            f_to_f64(-0.22210254121855132366e3),
+                            f_to_f64( 0.10760453916055123830e3),
+                            f_to_f64(-0.206010730328265443e2),
+                            f_to_f64( 0.1e1)]
+      %fmul(x, %fdiv(p1, p2))
+    else
+      if a <= f_to_f64(0.9375) # Blair table 37
+      then
+        t = %fsub(%fmul(x, x), f_to_f64(0.87890625))
+        p1 = unsafe_horner t [f_to_f64(-0.152389263440726128e-1),
+                              f_to_f64( 0.3444556924136125216),
+                              f_to_f64(-0.29344398672542478687e1),
+                              f_to_f64( 0.11763505705217827302e2),
+                              f_to_f64(-0.22655292823101104193e2),
+                              f_to_f64( 0.19121334396580330163e2),
+                              f_to_f64(-0.5478927619598318769e1),
+                              f_to_f64( 0.237516689024448)]
+        p2 = unsafe_horner t [f_to_f64(-0.108465169602059954e-1),
+                              f_to_f64( 0.2610628885843078511),
+                              f_to_f64(-0.24068318104393757995e1),
+                              f_to_f64( 0.10695129973387014469e2),
+                              f_to_f64(-0.23716715521596581025e2),
+                              f_to_f64( 0.24640158943917284883e2),
+                              f_to_f64(-0.10014376349783070835e2),
+                              f_to_f64( 0.1e1)]
+        %fmul(x, %fdiv(p1, p2))
+      else # Blair table 57
+        t = %fdiv(one64, %sqrt(%fsub(zero64, %log1p(%fsub(zero64, a)))))
+        p1 = unsafe_horner t [f_to_f64(0.10501311523733438116e-3),
+                              f_to_f64(0.1053261131423333816425e-1),
+                              f_to_f64(0.26987802736243283544516),
+                              f_to_f64(0.23268695788919690806414e1),
+                              f_to_f64(0.71678547949107996810001e1),
+                              f_to_f64(0.85475611822167827825185e1),
+                              f_to_f64(0.68738088073543839802913e1),
+                              f_to_f64(0.3627002483095870893002e1),
+                              f_to_f64(0.886062739296515468149)]
+        p2 = unsafe_horner t [f_to_f64(0.10501266687030337690e-3),
+                              f_to_f64(0.1053286230093332753111e-1),
+                              f_to_f64(0.27019862373751554845553),
+                              f_to_f64(0.23501436397970253259123e1),
+                              f_to_f64(0.76078028785801277064351e1),
+                              f_to_f64(0.111815861040569078273451e2),
+                              f_to_f64(0.119487879184353966678438e2),
+                              f_to_f64(0.81922409747269907893913e1),
+                              f_to_f64(0.4099387907636801536145e1),
+                              f_to_f64(0.1e1)]
+        s = select(x > zero64, t, select(x < zero64, %fsub(zero64, t), zero64)) # `t` carrying the sign of `x`
+        %fdiv(p1, %fmul(s, p2))
+
 instance Floating(Float64)
   def exp(x) = %exp(x)
   def exp2(x) = %exp2(x)
@@ -1087,6 +1218,7 @@ instance Floating(Float64)
   def lgamma(x)= %lgamma(x)
   def erf(x) = %erf(x)
   def erfc(x) = %erfc(x)
+  def erfinv(x)= float64_erfinv(x)
 
 instance Floating(Float32)
   def exp(x) = %exp(x)
@@ -1109,6 +1241,7 @@ instance Floating(Float32)
   def lgamma(x)= %lgamma(x)
   def erf(x) = %erf(x)
   def erfc(x) = %erfc(x)
+  def erfinv(x)= float32_erfinv(x)
 
 '## Raw pointer operations
 
@@ -1249,6 +1382,7 @@ instance Floating(n=>a) given (a|Floating, n|Ix)
   def lgamma(x) = each x lgamma
   def erf(x) = each x erf
   def erfc(x) = each x erfc
+  def erfinv(x) = each x erfinv
 
 '### Reductions
 
diff --git a/tests/eval-tests.dx b/tests/eval-tests.dx
index ef5d528b7..28e12d6ca 100644
--- a/tests/eval-tests.dx
+++ b/tests/eval-tests.dx
@@ -123,6 +123,21 @@ fun = \y. sum (map n_to_f arr) + y
 :p f_to_i $ round 3.6
 > 4
 
+:p erfinv(f_to_f64 0.84270079294971486934)
+> 1.
+
+:p erfinv 1.0
+> inf
+
+# TODO: This should actually be an error since it's outside of the domain of the function
+:p erfinv 2.0
+> nan
+
+:p
+  xs = each [-0.99, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 0.99] f_to_f64
+  erf(erfinv xs) ~~ erfinv(erf xs) && erfinv(xs) ~~ each xs \x. zero - erfinv(zero - x)
+> True
+
 s = 1.0
 :p s