-- topp_sampler.lua
-- Class table for the sampler; instances resolve their methods through it.
local ToppSampler = {}
ToppSampler.__index = ToppSampler

--- Create a sampler instance.
-- @param maxNumberOfElements number of scratch slots to allocate in `indices`
-- @param temperature value stored on the instance as `temperature`
-- @param topp value stored on the instance as `topp`
-- @return a new table whose metatable is ToppSampler
function ToppSampler:new(maxNumberOfElements, temperature, topp)
  -- Scratch buffer, zero-filled up front so it is a proper sequence of
  -- length maxNumberOfElements.
  local slots = {}
  for slot = 1, maxNumberOfElements do
    slots[slot] = 0
  end

  return setmetatable({
    indices = slots,
    topp = topp,
    temperature = temperature,
  }, ToppSampler)
end
--- Restore the max-heap property for the subtree rooted at position `from`.
--
-- `array` stores ids in 1-based positions 1..n laid out as a binary heap:
-- the children of position p live at 2p and 2p + 1. The ordering key of an
-- id is `logits:GetFloat(id)`, and the largest key belongs at position 1.
--
-- The previous version used 0-based child arithmetic (2p + 1 / 2p + 2 with an
-- exclusive bound) on this 1-based array and promoted the *smaller* key, so
-- the "heap" neither covered positions 2 and n nor surfaced the most probable
-- id first. Both are corrected here; the call sites in SampleToken already
-- pass 1-based parents and element counts (`n0`, `i - 1`).
--
-- @param logits object exposing `GetFloat(id)` that returns the key of an id
-- @param array  table of ids, reordered in place
-- @param from   1-based position whose subtree should be repaired
-- @param n      number of elements in the heap (positions 1..n); 0 is allowed
local function sift_down(logits, array, from, n)
  local parent = from
  local child = 2 * parent
  while child <= n do
    -- Pick the child with the larger key; the right child exists only when
    -- its position is still inside the heap.
    local right = child + 1
    if right <= n
        and logits:GetFloat(array[right]) > logits:GetFloat(array[child]) then
      child = right
    end

    -- Stop as soon as the parent is not smaller than its largest child.
    if not (logits:GetFloat(array[child]) > logits:GetFloat(array[parent])) then
      break
    end

    array[parent], array[child] = array[child], array[parent]
    parent = child
    child = 2 * parent
  end
end
--- Sample one id from `logits` using the instance's temperature and top-p.
--
-- Steps, in order:
--   1. `logits:DivideInPlace` and `logits:SoftMaxInPlace` are applied over
--      [0, logits.size) (presumably temperature scaling followed by softmax;
--      confirm against the tensor implementation).
--   2. Ids whose value is >= (1 - topp) / (size - 1) are packed at the front
--      of `self.indices`; the rest are written from the back.
--   3. The front segment is heapified with `sift_down`, then the heap root is
--      repeatedly moved to the end of the shrinking segment while the running
--      sum of values is accumulated, stopping once the sum exceeds topp.
--   4. A uniform draw scaled by that sum selects one id from the extracted
--      tail.
--
-- Side effects: `logits` is modified through its in-place methods and
-- `self.indices` is overwritten.
--
-- @param logits object providing `size`, `GetFloat(id)`, `DivideInPlace` and
--        `SoftMaxInPlace`; ids passed to `GetFloat` start at 0
-- @return one of the ids stored in `self.indices`
function ToppSampler:SampleToken(logits)
  logits:DivideInPlace(0, logits.size, self.temperature)
  logits:SoftMaxInPlace(0, logits.size)

  local slots = self.indices
  local topp = self.topp
  local total = logits.size

  -- Split ids around the threshold: keepers grow from the front, rejects
  -- from the back of the same buffer.
  local threshold = (1.0 - topp) / (total - 1)
  local front, back = 1, total
  for slot = 1, #slots do
    local id = slot - 1
    if logits:GetFloat(id) >= threshold then
      slots[front] = id
      front = front + 1
    else
      slots[back] = id
      back = back - 1
    end
  end

  -- Heapify the kept prefix.
  local kept = front - 1
  for parent = math.floor(kept / 2), 1, -1 do
    sift_down(logits, slots, parent, kept)
  end

  -- Pop the root to the end of the shrinking prefix until the accumulated
  -- mass passes topp; `stop` is the lowest position that was extracted.
  local mass = 0.0
  local stop = 1
  local pos = kept
  while pos >= 1 do
    slots[1], slots[pos] = slots[pos], slots[1]
    mass = mass + logits:GetFloat(slots[pos])
    if mass > topp then
      stop = pos
      break
    end
    sift_down(logits, slots, 1, pos - 1)
    pos = pos - 1
  end

  -- Inverse-CDF draw over the extracted ids, walking in extraction order.
  local target = math.random() * mass
  local running = 0.0
  for at = kept, stop, -1 do
    local id = slots[at]
    running = running + logits:GetFloat(id)
    if target < running then
      return id
    end
  end

  -- Reached only if the draw was not below the final running sum.
  return slots[stop]
end
-- Module value: the ToppSampler class table (constructor `new`, method `SampleToken`).
return ToppSampler