-
Notifications
You must be signed in to change notification settings - Fork 0
/
reductions.py
63 lines (48 loc) · 1.91 KB
/
reductions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import (
annotations as __annotations__,
) # Delayed parsing of type annotations
import mitsuba as mi
import drjit as dr
if __name__ == "__main__":
    # Script mode only: select the Mitsuba variant before any `mi.*` array
    # type is used further down. When this file is imported as a module no
    # variant is selected here, so the importing code has to do it.
    mi.set_variant("cuda_ad_rgb")
def scatter_reduce_with(func, target, value, index, active=True):
    """Reduce ``value`` into ``target`` at ``index`` with a custom ``func``.

    For every processed lane ``i`` this applies
    ``target[index[i]] = func(target[index[i]], value[i])`` and writes the
    result back into ``target`` in place.

    Lanes that address the same target slot cannot be applied in a single
    scatter, so the work is done in passes: in each pass every contested
    slot is claimed by one of the lanes queued for it, the claiming lanes
    are applied, and all other lanes are queued again for the next pass.
    Which of the conflicting lanes wins a pass is whatever write the
    scatter leaves behind, so the application order is unspecified;
    ``func`` should not depend on it.

    Args:
        func: Callable ``func(a, b)`` that combines the gathered current
            target entries ``a`` with the gathered values ``b`` and returns
            the new target entries.
        target: Dr.Jit array that is updated in place.
        value: Dr.Jit array with one entry per lane; its width determines
            the number of lanes.
        index: Array (gathered as ``mi.UInt``) holding the target slot of
            each lane.
        active: ``True`` (default) processes every lane. Any other value is
            converted to ``mi.Bool`` and used as a per-lane mask; lanes
            whose entry is false are skipped.

    Returns:
        None. ``target`` is modified in place.
    """
    n_value = dr.width(value)
    n_target = dr.width(target)

    # slot_owner[t] is the id of the lane that is allowed to update slot t
    # in the current pass. Entries left over from earlier passes are
    # harmless: a slot is only read in a pass by lanes that also wrote to
    # it in that same pass.
    slot_owner = dr.zeros(mi.UInt, n_target)

    # Ids of the lanes that still have to be applied.
    pending = dr.arange(mi.UInt, n_value)

    # `active` may be a Dr.Jit mask, so compare by identity with the default
    # instead of relying on truthiness. The `&` with an all-true mask of
    # width `n_value` broadcasts scalar masks to one entry per lane.
    if active is not True and n_value > 0:
        keep = mi.Bool(active) & dr.full(mi.Bool, True, n_value)
        pending = dr.gather(mi.UInt, pending, dr.compress(keep))

    while len(pending) > 0:
        # Every pending lane writes its own id into the slot it targets.
        # When several lanes target the same slot, one of the ids survives.
        slot = dr.gather(mi.UInt, index, pending)
        dr.scatter(slot_owner, pending, slot)

        # A lane is applied in this pass iff its id is the one that
        # survived in its slot; all other lanes stay queued.
        wins = dr.eq(dr.gather(mi.UInt, slot_owner, slot), pending)
        winners = dr.gather(mi.UInt, pending, dr.compress(wins))
        pending = dr.gather(mi.UInt, pending, dr.compress(~wins))

        # The winners address pairwise distinct slots, so a plain
        # gather / compute / scatter is free of write conflicts.
        winner_slot = dr.gather(mi.UInt, index, winners)
        a = dr.gather(type(target), target, winner_slot)
        b = dr.gather(type(value), value, winners)

        dr.scatter(target, func(a, b), winner_slot)
if __name__ == "__main__":
    # Small demo: 25 lanes, each carrying the value 1, are summed into 10
    # slots via `index = lane % 10`. Slots 0-4 are hit three times and
    # slots 5-9 twice.
    n_slots = 10
    n_lanes = 25

    def _add(a, b):
        return a + b

    target = dr.zeros(mi.Float, n_slots)
    index = dr.arange(mi.UInt, n_lanes) % n_slots
    value = dr.ones(mi.Float, n_lanes)

    scatter_reduce_with(_add, target, value, index)

    print(f"{target=}")