-
Notifications
You must be signed in to change notification settings - Fork 1
/
mathfuns.t
150 lines (131 loc) · 3.91 KB
/
mathfuns.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
-- SPDX-FileCopyrightText: 2024 René Hiemstra <[email protected]>
-- SPDX-FileCopyrightText: 2024 Torsten Keßler <[email protected]>
--
-- SPDX-License-Identifier: MIT
import "terraform"
local math = {}
local C = terralib.includecstring[[
#include <stdlib.h>
#include <math.h>
#include <tgmath.h>
]]
local concepts = require("concepts")
-- Constants.
-- pi, wrapped as a Terra constant from a Lua number literal.
math.pi = constant(3.14159265358979323846264338327950288419716939937510)
-- Machine epsilon for the two IEEE floating-point types, attached as a method
-- on the Terra type objects so generic code can call T:eps().
-- 0x1p-23 is 2^-23 (float), 0x1p-52 is 2^-52 (double).
float.eps = function(self) return 0x1p-23 end
double.eps = function(self) return 0x1p-52 end
-- C math routines exported as overloaded Terra functions with one float and
-- one double definition each. Keys are the names exported from this module;
-- values are the C names of the double-precision routines. The
-- single-precision routine is the same name with an "f" suffix (sin/sinf).
local funs_single_var = {
    sin = "sin",
    cos = "cos",
    tan = "tan",
    asin = "asin",
    acos = "acos",
    atan = "atan",
    sinh = "sinh",
    cosh = "cosh",
    tanh = "tanh",
    asinh = "asinh",
    acosh = "acosh",
    atanh = "atanh",
    exp = "exp",
    expm1 = "expm1",
    exp2 = "exp2",
    log = "log",
    log1p = "log1p",
    log10 = "log10",
    sqrt = "sqrt",
    cbrt = "cbrt",
    erf = "erf",
    erfc = "erfc",
    gamma = "tgamma",
    loggamma = "lgamma",
    abs = "fabs",
    floor = "floor",
    ceil = "ceil",
    round = "round"
}
local funs_two_var = {
    pow = "pow",
    atan2 = "atan2",
    hypot = "hypot",
    fmod = "fmod"
}
local funs_three_var = {
    fusedmuladd = "fma"
}

-- For every entry of `funs`, create the overloaded function math[tname] with
-- one definition per floating-point type. Each definition takes `arity`
-- arguments of that type and forwards them to the matching C routine.
local function wrap_c_functions(funs, arity)
    for tname, cname in pairs(funs) do
        local f = terralib.overloadedfunction(tname)
        for _, T in ipairs{float, double} do
            -- Resolve the exact C symbol and fail loudly if it is missing,
            -- rather than silently substituting the double-precision routine
            -- (which would make the "float" overload return double).
            local symname = (T == float) and (cname .. "f") or cname
            local cfun = C[symname]
            assert(cfun, "mathfuns: C function '" .. symname .. "' not found")
            local params = terralib.newlist()
            for _ = 1, arity do
                params:insert(symbol(T))
            end
            f:adddefinition(terra([params]) return cfun([params]) end)
        end
        math[tname] = f
    end
end

wrap_c_functions(funs_single_var, 1)
wrap_c_functions(funs_two_var, 2)
wrap_c_functions(funs_three_var, 3)
-- Beta function B(x, y) = gamma(x) * gamma(y) / gamma(x + y), for float and
-- double.
math.beta = terralib.overloadedfunction("beta")
for _, T in ipairs{float,double} do
    math.beta:adddefinition(
        terra(x : T, y : T) : T
            var num = math.gamma(x) * math.gamma(y)
            var den = math.gamma(x + y)
            -- For positive arguments gamma is bounded below (minimum ~0.8856),
            -- so the direct quotient can only go wrong through overflow of
            -- num or den. For example, float beta(20, 20) gives 0 because
            -- gamma(40) is inf, and double beta(100, 100) gives inf/inf = nan.
            -- `v - v == 0` is false exactly when v is inf or nan.
            if x > 0 and y > 0 and not (num - num == 0 and den - den == 0) then
                -- Evaluate in log space, where nothing overflows:
                -- log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y).
                return math.exp(math.loggamma(x) + math.loggamma(y) - math.loggamma(x + y))
            end
            -- No overflow, or non-positive arguments (where lgamma would lose
            -- the sign of gamma): keep the direct formula.
            return num / den
        end
    )
end
-- Add some missing definitions.
-- Integer overloads for math.abs (the C wrappers only cover float and double).
-- llabs (long long) is used for int64 because `long` is only 32 bits on LLP64
-- platforms such as 64-bit Windows, where labs would truncate the argument.
math.abs:adddefinition(terra(x : int) return C.abs(x) end)
math.abs:adddefinition(terra(x : int64) return C.llabs(x) end)

-- sign(x): -1 when x < 0 and 1 otherwise. Note that sign(0) == 1 (and a NaN
-- argument also yields 1), so the result is never zero.
math.sign = terralib.overloadedfunction("sign")
for _, T in ipairs{int32, int64, float, double} do
    math.sign:adddefinition(
        terra(x: T): T
            return terralib.select(x < 0, -1, 1)
        end
    )
end
-- Convenience functions.
-- cot(x) = cos(x) / sin(x), with one definition per floating-point type
-- (float first, then double).
math.cot = terralib.overloadedfunction("cot")
for _, T in ipairs{float, double} do
    math.cot:adddefinition(terra(x : T) return math.cos(x) / math.sin(x) end)
end
-- ldexp: forwards straight to the C routines (double and float variants).
local ldexp_defs = terralib.newlist{C.ldexp, C.ldexpf}
math.ldexp = terralib.overloadedfunction("ldexp", ldexp_defs)
-- Min and max for the integer and floating-point scalar types, built with
-- terralib.select. If the comparison is false (including when an argument is
-- NaN), both min(x, y) and max(x, y) return y.
math.min = terralib.overloadedfunction("min")
math.max = terralib.overloadedfunction("max")
do
    local scalar_types = {int32, int64, float, double}
    for i = 1, #scalar_types do
        local T = scalar_types[i]
        math.min:adddefinition(terra(x : T, y : T) return terralib.select(x < y, x, y) end)
        math.max:adddefinition(terra(x : T, y : T) return terralib.select(y < x, x, y) end)
    end
end
-- Distance |a - b| between two values of the same numeric type T.
terraform math.dist(a: T, b: T) where {T: concepts.Number}
    var diff = a - b
    return math.abs(diff)
end
-- Comparison functions.
-- isapprox(a, b, atol): true when math.dist(a, b) is at most the absolute
-- tolerance atol. The bound is inclusive (as in Julia's isapprox and NumPy's
-- isclose), so isapprox(a, a, 0) is true. A NaN distance never compares true.
-- Requires a math.dist definition for T.
terraform math.isapprox(a: T, b: T, atol: S)
    where {T: concepts.Any, S: concepts.Any}
    return math.dist(a, b) <= atol
end
-- real, imag and conj for the real scalar types. A real number is its own
-- real part and its own conjugate, and its imaginary part is zero.
do
    local builders = {
        real = function(T) return terra(x: T) return x end end,
        imag = function(T) return terra(x: T) return [T](0) end end,
        conj = function(T) return terra(x: T) return x end end,
    }
    for name, build in pairs(builders) do
        local overloaded = terralib.overloadedfunction(name)
        for _, T in ipairs{int32, int64, float, double} do
            overloaded:adddefinition(build(T))
        end
        math[name] = overloaded
    end
end
return math