-
Notifications
You must be signed in to change notification settings - Fork 25
/
util.lua
324 lines (295 loc) · 6.08 KB
/
util.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
-- Cache frequently used globals/library functions as locals (faster
-- lookups in the helpers below).
-- `unpack` is a global only in Lua 5.1 / LuaJIT; Lua 5.2+ moved it to
-- `table.unpack`, so fall back to that to keep this file portable.
local sort, pairs, select, unpack, error =
  table.sort, pairs, select, unpack or table.unpack, error
local type, setmetatable, getmetatable =
  type, setmetatable, getmetatable
local random = math.random
local max = math.max
-- Returns true as soon as `fun` yields a truthy value for some element of
-- the array `tab`, false otherwise (including for an empty array).
-- Traversal uses ipairs, so it stops at the first nil element.
function any(fun, tab)
  local found = false
  for _, elem in ipairs(tab) do
    if fun(elem) then
      found = true
      break
    end
  end
  return found
end
-- Clamps b into the closed range [a, c]: a if b is below a, c if b is
-- above c, otherwise b unchanged. Intended for a <= c; the lower bound
-- is tested first.
function bound(a, b, c)
  if b < a then
    return a
  end
  if b > c then
    return c
  end
  return b
end
-- Wraps b into the range [a, c] by reducing it modulo the range width
-- (c - a + 1); e.g. wrap(1, 0, 3) == 3 and wrap(1, 4, 3) == 1.
-- Intended for integer arguments.
function wrap(a, b, c)
  local width = c - a + 1
  return a + (b - a) % width
end
-- Builds a set from the array tab[1..#tab]: each element becomes a key
-- mapped to true.
function arr_to_set(tab)
  local set = {}
  local n = #tab
  for idx = 1, n do
    local elem = tab[idx]
    set[elem] = true
  end
  return set
end
-- Collects the keys of `tab` into a new array. The order follows pairs()
-- and is therefore unspecified.
function set_to_arr(tab)
  local arr, n = {}, 0
  for key in pairs(tab) do
    n = n + 1
    arr[n] = key
  end
  return arr
end
-- Filter for arrays: returns a new array holding, in order, the elements of
-- tab[1..#tab] for which func returns a truthy value. `tab` is not modified.
function filter(func, tab)
  local kept = {}
  local idx, last = 1, #tab
  while idx <= last do
    if func(tab[idx]) then
      kept[#kept + 1] = tab[idx]
    end
    idx = idx + 1
  end
  return kept
end
-- Map for arrays: returns a new array whose i-th entry is func(tab[i]) for
-- i = 1..#tab. `tab` is not modified.
function map(func, tab)
  local out = {}
  local n = #tab
  for idx = 1, n do
    local mapped = func(tab[idx])
    out[idx] = mapped
  end
  return out
end
-- Map for dicts: returns a new table with the same keys as `tab`, each value
-- replaced by func(value). Keys whose mapped value is nil are absent.
function map_dict(func, tab)
  local out = {}
  for key, val in pairs(tab) do
    local mapped = func(val)
    out[key] = mapped
  end
  return out
end
-- In-place map for arrays: replaces tab[i] with func(tab[i]) for
-- i = 1..#tab and returns the same table.
function map_inplace(func, tab)
  local n = #tab
  local idx = 1
  while idx <= n do
    tab[idx] = func(tab[idx])
    idx = idx + 1
  end
  return tab
end
-- In-place map for dicts: replaces every value of `tab` with func(value)
-- and returns the same table. Only existing keys are reassigned during the
-- pairs() traversal, which Lua permits.
function map_dict_inplace(func, tab)
  for key, old in pairs(tab) do
    local new = func(old)
    tab[key] = new
  end
  return tab
end
-- Left fold over the array tab[1..#tab].
-- If a third argument is supplied (even an explicit nil), it seeds the fold
-- and every element is folded in. Otherwise tab[1] is the seed and folding
-- starts at tab[2]. Raises an error for an empty array with no seed.
function reduce(func, tab, ...)
  local acc, first
  if select("#", ...) == 0 then
    if #tab == 0 then
      error("Tried to reduce empty table with no initial value")
    end
    acc, first = tab[1], 2
  else
    acc, first = (...), 1
  end
  for idx = first, #tab do
    acc = func(acc, tab[idx])
  end
  return acc
end
-- Returns the first element of the array `tab` (nil if it is empty).
function car(tab)
  local head = tab[1]
  return head
end
-- Returns a new array holding every element of `tab` except the first,
-- i.e. tab[2..#tab] shifted down to start at index 1. `tab` is unchanged.
-- A plain copy loop is used instead of {select(2, unpack(tab))}: that form
-- pushes every element onto the Lua stack, so it fails with "too many
-- results to unpack" on large arrays, and it needs a global `unpack`, which
-- Lua 5.2+ does not provide.
function cdr(tab)
  local rest = {}
  for i = 2, #tab do
    rest[i - 1] = tab[i]
  end
  return rest
end
-- Splits a string into an array of its single-byte substrings. It is a right
-- inverse of table.concat: table.concat(procat(s)) == s.
function procat(str)
  local chars, len = {}, #str
  local i = 1
  while i <= len do
    chars[i] = str:sub(i, i)
    i = i + 1
  end
  return chars
end
-- Returns an iterator over the key/value pairs of `tab` in sorted key order.
-- Extra arguments are forwarded to table.sort (e.g. a comparator). Keys and
-- values are snapshotted when spairs is called, so later changes to `tab`
-- do not affect the iteration.
function spairs(tab, ...)
  local sorted_keys = {}
  for key in pairs(tab) do
    sorted_keys[#sorted_keys + 1] = key
  end
  sort(sorted_keys, ...)
  local snapshot = {}
  local count = #sorted_keys
  for i = 1, count do
    snapshot[i] = tab[sorted_keys[i]]
  end
  local pos = 0
  return function()
    pos = pos + 1
    return sorted_keys[pos], snapshot[pos]
  end
end
-- Eager version of spairs(tab, ...): returns an array of {key, value} pairs
-- in sorted key order instead of an iterator.
function tspairs(tab, ...)
  local pair_list = {}
  for key, val in spairs(tab, ...) do
    pair_list[#pair_list + 1] = {key, val}
  end
  return pair_list
end
-- Picks one element of the array t[1..#t] uniformly at random
-- (math.random); returns nil for an empty array.
function uniformly(t)
  local n = #t
  if n == 0 then
    return nil
  end
  local pick = random(n)
  return t[pick]
end
-- Two calling modes:
--   shuffled(t)        -> a new array that is a random permutation of
--                         t[1..#t]; `t` itself is left untouched.
--   shuffled(a, b, ...) -> the arguments returned in a random order.
-- The mode is chosen from the length of {first_arg, ...}, so nil arguments
-- can confuse the detection.
function shuffled(first_arg, ...)
  local args = {first_arg, ...}
  local from_varargs = #args > 1
  local out
  if from_varargs then
    out = args
  else
    out = {}
    for i = 1, #first_arg do
      out[i] = first_arg[i]
    end
  end
  local len = #out
  for i = 1, len do
    -- forward Fisher-Yates: swap slot i with a random slot in [i, len]
    local pick = random(i, len)
    out[i], out[pick] = out[pick], out[i]
  end
  if not from_varargs then
    return out
  end
  return unpack(out)
end
-- Shuffles the array tab[1..#tab] in place (forward Fisher-Yates using
-- math.random) and returns the same table.
function shuffle(tab)
  local len = #tab
  local i = 1
  while i <= len do
    local j = random(i, len)
    tab[i], tab[j] = tab[j], tab[i]
    i = i + 1
  end
  return tab
end
-- Reverses the array tab[1..#tab] in place and returns the same table.
function reverse(tab)
  local lo, hi = 1, #tab
  while lo < hi do
    tab[lo], tab[hi] = tab[hi], tab[lo]
    lo = lo + 1
    hi = hi - 1
  end
  return tab
end
-- Returns a one-level copy of `tab`: a new table with the same keys and
-- values. Nested tables are shared rather than copied, and the metatable is
-- not carried over. Non-table arguments are returned unchanged.
-- NOTE(review): `shallowcpy` is defined a second time further down this
-- file, and that later definition is the one in effect once the file has
-- loaded. This definition includes the same non-table guard so the two
-- cannot disagree. One of them should be deleted.
function shallowcpy(tab)
  if type(tab) ~= "table" then return tab end
  local ret = {}
  for k, v in pairs(tab) do
    ret[k] = v
  end
  return ret
end
-- Recursively copies `tab` into a fresh table. Keys and values that are
-- tables are copied as well; all other values are shared. Each copy gets the
-- same metatable object as its source (metatables themselves are not
-- copied).
-- `mapping` records source table -> copy. A table reachable along several
-- paths, including cycles, is therefore copied exactly once, and the copy
-- keeps the same sharing structure.
local real_deepcpy
function real_deepcpy(tab, mapping)
  mapping = mapping or {}
  local existing = mapping[tab]
  if existing ~= nil then
    return existing
  end
  local ret = {}
  mapping[tab] = ret
  for k, v in pairs(tab) do
    if type(k) == "table" then
      k = real_deepcpy(k, mapping)
    end
    if type(v) == "table" then
      v = real_deepcpy(v, mapping)
    end
    ret[k] = v
  end
  return setmetatable(ret, getmetatable(tab))
end
-- Returns a deep copy of `tab` (see real_deepcpy); non-table values are
-- returned as-is.
-- The source->copy mapping is created per call and passed down explicitly,
-- instead of living in module-level state that is only cleared after a
-- successful copy. Otherwise an error part-way through (e.g. setmetatable
-- rejecting a protected `__metatable` value) would leave stale entries, and
-- later deepcpy calls would silently return those outdated copies.
function deepcpy(tab)
  if type(tab) ~= "table" then return tab end
  return real_deepcpy(tab, {})
end
-- One-level copy: a new table with the same keys and values as `tab`
-- (nested tables shared, metatable not copied). Anything that is not a
-- table is handed back unchanged.
function shallowcpy(tab)
  if type(tab) == "table" then
    local copy = {}
    for key, val in pairs(tab) do
      copy[key] = val
    end
    return copy
  end
  return tab
end
-- Recursive worker for deepeq.
-- `seen[a][b]` marks table pairs whose comparison is already in progress or
-- done. Any mismatch makes the whole comparison return false immediately,
-- so meeting a marked pair again can safely answer true. This is what lets
-- cyclic tables terminate instead of recursing until the stack overflows.
local function deepeq_impl(a, b, seen)
  if type(a) ~= "table" or type(b) ~= "table" then
    return a == b
  end
  local seen_a = seen[a]
  if seen_a == nil then
    seen_a = {}
    seen[a] = seen_a
  elseif seen_a[b] then
    return true
  end
  seen_a[b] = true
  local done_k = {}
  for k in pairs(a) do
    done_k[k] = true
    if not deepeq_impl(a[k], b[k], seen) then
      return false
    end
  end
  -- any key present only in b makes the tables differ
  for k in pairs(b) do
    if not done_k[k] then
      return false
    end
  end
  return true
end
-- Structural equality. Non-table values (and a table compared with a
-- non-table) are compared with ==. Two tables are equal when they have the
-- same set of keys and deepeq holds for the value under every key.
-- Keys are looked up as-is, so table keys are matched by identity, not by
-- structure ("pls no table keys"). Metatables are not compared.
-- Cyclic tables are supported.
function deepeq(a, b)
  return deepeq_impl(a, b, {})
end
-- Returns the whole contents of `filename` as a string.
-- Under LÖVE (global `love` present) the file is read through
-- love.filesystem; otherwise through the standard io library.
-- In both branches the file handle is closed after reading. In the io branch
-- a missing or unreadable file raises an error carrying io.open's message
-- (via assert) instead of an opaque "attempt to index a nil value".
function file_contents(filename)
  if love then
    local file = love.filesystem.newFile(filename)
    file:open("r")
    local ret = file:read(file:getSize())
    file:close()
    return ret
  end
  local file = assert(io.open(filename, "r"))
  local ret = file:read("*a")
  file:close()
  return ret
end
-- Writes `contents` to `filename`, replacing any existing file.
-- Under LÖVE (global `love` present) this goes through
-- love.filesystem.write; otherwise through the io library.
-- Both branches report failure the same way: by printing
-- "error writing to <filename>" rather than raising. Previously the io
-- branch crashed with "attempt to index a nil value" when the file could not
-- be opened.
function set_file(filename, contents)
  if love then
    local success = love.filesystem.write(filename, contents)
    if not success then
      print("error writing to "..filename)
    end
    return
  end
  local file = io.open(filename, "w")
  if not file then
    print("error writing to "..filename)
    return
  end
  -- file:write returns a truthy value on success, nil on failure
  local ok = file:write(contents)
  file:close()
  if not ok then
    print("error writing to "..filename)
  end
end
-- Counts occurrences: returns a table mapping each value found in t[1..#t]
-- to the number of times it appears.
function arr_to_counter(t)
  local counts = {}
  local n = #t
  for i = 1, n do
    local key = t[i]
    if counts[key] == nil then
      counts[key] = 1
    else
      counts[key] = counts[key] + 1
    end
  end
  return counts
end
-- Rebuilds `t`, and recursively every nested table value, so that keys
-- which tonumber() can parse become number keys; e.g. {["1"] = "a"} becomes
-- {[1] = "a"}. Meant for numeric dicts that went through JSON encoding
-- (which stringifies keys).
-- When two keys collapse onto the same number (say "1" and 1) and both
-- values are numbers, the larger value wins. Otherwise the value visited
-- later by pairs() wins, which is unspecified.
-- TODO: this may become a performance bottleneck for the server later
function fix_num_keys(t)
  local fixed = {}
  for key, val in pairs(t) do
    if type(val) == "table" then
      val = fix_num_keys(val)
    end
    local num_key = tonumber(key) or key
    local prev = fixed[num_key]
    -- Keep the existing entry only if both are numbers and the new value is
    -- not larger. `not (prev < val)` is deliberate: it differs from
    -- `prev >= val` when NaN is involved.
    local keep_prev = type(prev) == "number" and type(val) == "number"
      and not (prev < val)
    if not keep_prev then
      fixed[num_key] = val
    end
  end
  return fixed
end
-- Merges counters (tables of key -> count) by keeping, for every key, the
-- largest count found across all of them. Keys start from a baseline of 0
-- before the first count is applied.
function union_counters(list_of_counters)
  local merged = {}
  for _, counter in pairs(list_of_counters) do
    for key, count in pairs(counter) do
      local best = merged[key]
      if best == nil then
        best = 0
      end
      merged[key] = max(count, best)
    end
  end
  return merged
end