-
Notifications
You must be signed in to change notification settings - Fork 1
/
tuple.t
99 lines (86 loc) · 2.41 KB
/
tuple.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
-- SPDX-FileCopyrightText: 2024 René Hiemstra <[email protected]>
-- SPDX-FileCopyrightText: 2024 Torsten Keßler <[email protected]>
--
-- SPDX-License-Identifier: MIT
-- Helper functions for tuples
local function istuple(T)
    --[=[
    Decide whether a Terra type is a tuple type.

    The unit type is treated as the empty tuple and is accepted directly.
    Any other type qualifies only if it is a struct with at least one entry
    whose entries are named "_0", "_1", "_2", ... in order.
    Args:
        T: Terra type (asserted with terralib.types.istype)
    Returns:
        true if T is a tuple type, false otherwise
    --]=]
    assert(terralib.types.istype(T))
    if T:isunit() then
        return true
    end
    if not T:isstruct() then
        return false
    end
    local fields = T.entries
    -- Unit (the only empty tuple) was accepted above, so an entry-less
    -- struct here is an ordinary empty struct, not a tuple.
    if #fields == 0 then
        return false
    end
    local count = #fields
    for pos = 1, count do
        local expected = "_" .. tostring(pos - 1)
        if fields[pos][1] ~= expected then
            return false
        end
    end
    return true
end
local function unpacktuple(T)
    --[=[
    Return list of types in given tuple type
    Args:
        T: Tuple type (anything accepted by istuple, including the unit type)
    Returns:
        One-based terra list composed of the types in the tuple
    Raises:
        An error (reported at the caller) if T is a Terra type that is not
        a tuple; istuple itself asserts if T is not a Terra type at all.
    Examples:
        print(unpacktuple(tuple(int, double))[2])
        -- double
    --]=]
    -- The entries key of a tuple type is a terra list of two-element lists
    -- {name, type}, where name is "_0", "_1", ... (the zero-based position,
    -- as checked by istuple). Hence we can use the map method of a terra list
    -- to extract a list of terra types. For details, see the implementation
    -- of the tuples type
    -- https://github.com/terralang/terra/blob/4d32a10ffe632694aa973c1457f1d3fb9372c737/src/terralib.lua#L1762
    if not istuple(T) then
        -- Level 2 attributes the error to the caller that passed the bad type.
        error("unpacktuple: expected a tuple type but got " .. tostring(T), 2)
    end
    return T.entries:map(function(e) return e[2] end)
end
local function dimensiontuple(T)
    --[=[
    Return the nesting structure of a (possibly nested) tuple type.
    Args:
        T: Terra type
    Returns:
        A plain Lua table mirroring the nesting of T. For a tuple with k
        entries, it is a sequence of k sub-tables, and the i-th sub-table
        recursively describes the i-th entry type. Any non-tuple type maps
        to an empty table.
    Examples:
        dimensiontuple(tuple(int, double))             -- { {}, {} }
        dimensiontuple(tuple(tuple(int, int), int))    -- { { {}, {} }, {} }
        dimensiontuple(int)                            -- {}
    --]=]
    -- Recursively fill `node` with one sub-table per tuple entry of S.
    local function go(S, node)
        if istuple(S) then
            -- ipairs: the entry list is a sequence and position i matters.
            for i, e in ipairs(unpacktuple(S)) do
                node[i] = {}
                go(e, node[i])
            end
        end
    end
    local dim = {}
    go(T, dim)
    return dim
end
local function tensortuple(T)
    --[=[
    Return the shape of a regularly nested tuple type, like a tensor shape.
    Args:
        T: Terra type
    Returns:
        A Lua sequence of sizes. Entry 1 is the number of top-level tuple
        entries, entry 2 is the number of entries of each of those, and so on.
        A non-tuple type yields an empty sequence.
    Raises:
        An assertion error "Dimension <expected> expected but got <actual>"
        if the nesting is ragged, i.e. some sub-tuple at a given depth has a
        different number of entries than the first one at that depth, or a
        leaf appears at a depth where a tuple was expected (or vice versa).
    Examples:
        tensortuple(tuple(tuple(int, int, int), tuple(int, int, int)))
        -- { 2, 3 }
    --]=]
    local tree = dimensiontuple(T)

    -- Reference shape: follow the first child at every level and record how
    -- many children were found along that path.
    local shape = {}
    local cursor = tree
    while #cursor > 0 do
        table.insert(shape, #cursor)
        cursor = cursor[1]
    end

    -- Every node at `depth` must have exactly shape[depth] children; nodes
    -- past the last recorded depth must be leaves (no children).
    local function check(node, depth)
        local expected = shape[depth]
        local count = #node
        assert((count == 0 and expected == nil) or count == expected,
            string.format("Dimension %d expected but got %d", expected or 0, count))
        for child = 1, count do
            check(node[child], depth + 1)
        end
    end

    check(tree, 1)
    return shape
end
-- Public interface of the tuple helper module.
local M = {
    istuple = istuple,
    unpacktuple = unpacktuple,
    dimensiontuple = dimensiontuple,
    tensortuple = tensortuple,
}
return M