-
Notifications
You must be signed in to change notification settings - Fork 1
/
dmatrix.t
123 lines (101 loc) · 3.15 KB
/
dmatrix.t
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
-- SPDX-FileCopyrightText: 2024 René Hiemstra <[email protected]>
-- SPDX-FileCopyrightText: 2024 Torsten Keßler <[email protected]>
--
-- SPDX-License-Identifier: MIT
local alloc = require("alloc")
local base = require("base")
local concepts = require("concepts")
local matrix = require("matrix")
local err = require("assert")
local fun = require("fun")
local tupl = require("tuple")
local Allocator = alloc.Allocator
local size_t = uint64
-- Build (memoized per element type T) a dynamically sized dense matrix type.
-- Entries live in a single block `data` and entry (i, j) is stored at linear
-- index `j + ld * i`, i.e. rows are contiguous and consecutive rows are `ld`
-- elements apart. Matrices created by `new` use `ld == cols`.
local DynamicMatrix = terralib.memoize(function(T)
    local S = alloc.SmartBlock(T)

    local struct M {
        data: S        -- element storage (a SmartBlock of T)
        rows: size_t   -- number of rows
        cols: size_t   -- number of columns
        ld: size_t     -- leading dimension: element stride between rows
    }

    M.eltype = T

    function M.metamethods.__typename(self)
        return ("DynamicMatrix(%s)"):format(tostring(T))
    end

    -- NOTE(review): presumably installs the shared base machinery
    -- (e.g. the `staticmethods` table used below) -- confirm in `base`.
    base.AbstractBase(M)

    -- Number of rows.
    terra M:rows()
        return self.rows
    end

    -- Number of columns.
    terra M:cols()
        return self.cols
    end

    -- Return entry (i, j). Both indices are 0-based and are checked with
    -- err.assert against rows()/cols().
    terra M:get(i: size_t, j: size_t)
        err.assert(i < self:rows() and j < self:cols())
        return self.data:get(j + self.ld * i)
    end

    -- Store `a` at entry (i, j). Both indices are 0-based and are checked
    -- with err.assert against rows()/cols().
    terra M:set(i: size_t, j: size_t, a: T)
        err.assert(i < self:rows() and j < self:cols())
        self.data:set(j + self.ld * i, a)
    end

    -- `m(i, j)` element access. Unlike get/set this performs NO bounds
    -- check; it forwards directly to `self.data(j + self.ld * i)`.
    M.metamethods.__apply = macro(function(self, i, j)
        return `self.data(j + self.ld * i)
    end)

    -- NOTE(review): presumably mixes in generic matrix operations built on
    -- rows/cols/get/set -- confirm in `matrix`.
    matrix.MatrixBase(M)

    -- For BLAS-compatible element types, expose the raw dense layout
    -- (rows, cols, data pointer, leading dimension) and mix in the
    -- BLAS-backed dense methods.
    if concepts.BLASNumber(T) then
        terra M:getblasdenseinfo()
            return self:rows(), self:cols(), self.data.ptr, self.ld
        end
        local matblas = require("matrix_blas_dense")
        matblas.BLASDenseMatrixBase(M)
    end

    -- Allocate a rows x cols matrix from `alloc` with `ld == cols`.
    -- This function does not write any entries.
    terra M.staticmethods.new(alloc: Allocator, rows: size_t, cols: size_t)
        return M {alloc:allocate(sizeof(T), rows * cols), rows, cols, cols}
    end

    -- Allocate a matrix with the same dimensions as `m` (entries not set).
    terra M.staticmethods.like(alloc: Allocator, m: &M)
        return M.new(alloc, m:rows(), m:cols())
    end

    -- Allocate a rows x cols matrix with every entry set to `a`.
    terra M.staticmethods.all(alloc: Allocator, rows: size_t, cols: size_t, a: T)
        var m = M.new(alloc, rows, cols)
        -- Terra numeric for-loops exclude the upper bound.
        for i = 0, rows do
            for j = 0, cols do
                m:set(i, j, a)
            end
        end
        return m
    end

    -- Allocate a rows x cols matrix filled with zeros.
    terra M.staticmethods.zeros(alloc: Allocator, rows: size_t, cols: size_t)
        return M.all(alloc, rows, cols, 0)
    end

    -- Allocate a matrix shaped like `m` with every entry set to `a`.
    terra M.staticmethods.all_like(alloc: Allocator, m: &M, a: T)
        return M.all(alloc, m:rows(), m:cols(), a)
    end

    -- Allocate a zero matrix shaped like `m`.
    terra M.staticmethods.zeros_like(alloc: Allocator, m: &M)
        return M.zeros(alloc, m:rows(), m:cols())
    end

    -- Macro: build a matrix from a nested tuple expression (for example a
    -- `{{a, b}, {c, d}}` literal). Its shape is obtained at compile time
    -- from tupl.tensortuple and must be 2-dimensional. The tuple expression
    -- is evaluated exactly once.
    M.staticmethods.from = macro(function(alloc, tabl)
        local dim = tupl.tensortuple(tabl.tree.type)
        assert(#dim == 2, "DynamicMatrix.from expects a 2-dimensional tuple")
        local rows, cols = unpack(dim)
        local m = symbol(M)
        -- Bind the macro argument to a local. Splicing `tabl` directly into
        -- every element access would re-evaluate the whole expression
        -- (including any side effects) rows * cols times.
        local tpl = symbol(tabl.tree.type)
        local loop = terralib.newlist()
        local function get(t, i, j)
            return `t.["_" .. tostring(i)].["_" .. tostring(j)]
        end
        for i = 0, rows - 1 do
            for j = 0, cols - 1 do
                loop:insert(quote [m]:set(i, j, [get(tpl, i, j)]) end)
            end
        end
        return quote
            var [tpl] = [tabl]
            var [m] = M.new(alloc, rows, cols)
            [loop]
        in
            [m]
        end
    end)

    return M
end)
-- Public interface of this module: the memoized DynamicMatrix type builder.
local exports = {}
exports.DynamicMatrix = DynamicMatrix
return exports