-
Notifications
You must be signed in to change notification settings - Fork 2
/
cbor.lua
587 lines (520 loc) · 14.2 KB
/
cbor.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
-- Concise Binary Object Representation (CBOR)
-- RFC 7049
--- Load an optional dependency without raising when it is missing.
-- @param pkg module name handed to require
-- @param field optional key; when given, that entry of the module is
--   returned instead of the module itself
-- @return the module (or module[field]); nothing at all if require failed
local function softreq(pkg, field)
	local loaded, module = pcall(require, pkg);
	if loaded then
		if not field then
			return module;
		end
		return module[field];
	end
end
--- Compile and run a chunk of Lua source, tolerating syntax errors.
-- The source is handed to load() through a reader function, a form that
-- load accepts both in Lua 5.1 and in later versions.
-- load returns the compiled chunk (or nil plus a message); it does not
-- return a pcall-style success flag.
-- @param s Lua source text
-- @return whatever the chunk returns, or nothing if it does not compile
local dostring = function(s)
	local chunk = load(function()
		-- hand out the whole source once, then signal end of input
		local piece = s;
		s = nil;
		return piece;
	end);
	if chunk then
		return chunk();
	end
end
-- Capture every global and library function used below in a local.
-- This has to happen before `_ENV` is cleared at the end of this section.
local setmetatable = setmetatable;
local getmetatable = getmetatable;
local dbg_getmetatable = debug.getmetatable;
local assert = assert;
local error = error;
local type = type;
local pairs = pairs;
local ipairs = ipairs;
local tostring = tostring;
local s_char = string.char;
local t_concat = table.concat;
local t_sort = table.sort;
local m_floor = math.floor;
local m_abs = math.abs;
local m_huge = math.huge;
local m_max = math.max;
-- Integer limits; where math.maxinteger/mininteger are missing, fall back
-- to +/- 2^53.
local maxint = math.maxinteger or 9007199254740992;
local minint = math.mininteger or -9007199254740992;
local NaN = 0/0;
local m_frexp = math.frexp;
-- Substitutes for library functions that may be absent.
local m_ldexp = math.ldexp or function (x, exp) return x * 2.0 ^ exp; end;
-- The math.type substitute reports "integer" for whole numbers inside
-- [minint, maxint] and "float" for everything else.
local m_type = math.type or function (n) return n % 1 == 0 and n <= maxint and n >= minint and "integer" or "float" end;
-- Binary packing: the built-in string functions, else the "struct" module
-- when it can be loaded, else nil.
local s_pack = string.pack or softreq("struct", "pack");
local s_unpack = string.unpack or softreq("struct", "unpack");
-- Logical right shift. Candidates are tried in order: bit32, bit, a chunk
-- compiled from source using the `>>` operator, and finally plain arithmetic.
local b_rshift = softreq("bit32", "rshift") or softreq("bit", "rshift") or
dostring "return function(a,b) return a >> b end" or
function (a, b) return m_max(0, m_floor(a / (2 ^ b))); end;
-- sanity check: discard pack/unpack implementations that do not handle the
-- big-endian ">I2" format as expected
if s_pack and s_pack(">I2", 0) ~= "\0\0" then
s_pack = nil;
end
if s_unpack and s_unpack(">I2", "\1\2\3\4") ~= 0x102 then
s_unpack = nil;
end
-- With _ENV set to nil (Lua 5.2+), any later access to a global raises an
-- error, so the code below must only use the locals captured above.
local _ENV = nil; -- luacheck: ignore 211
-- Per-type encoder functions; encode() looks them up by the name that
-- type() reports for the value.
local encoder = {};

--- Encode a Lua value as CBOR.
-- @param obj the value to encode
-- @param opts optional options, passed through unchanged to the encoder
-- @return the encoded string
-- Raises "can't encode <type>" when no encoder is registered for the type.
local function encode(obj, opts)
	local typ = type(obj);
	local encode_typ = encoder[typ];
	if not encode_typ then
		-- e.g. coroutines: fail with a readable message instead of
		-- "attempt to call a nil value"
		error("can't encode " .. typ);
	end
	return encode_typ(obj, opts);
end
-- Major types 0, 1 and length encoding for others

-- Split a value below 2^32 into four byte values, most significant first.
local function uint32_bytes(n)
	return b_rshift(n, 24) % 0x100,
		b_rshift(n, 16) % 0x100,
		b_rshift(n, 8) % 0x100,
		n % 0x100;
end

--- Encode an initial byte plus its integer argument using arithmetic only.
-- @param num the argument (value or length)
-- @param m the major type already shifted into the top bits (0, 32, 64, ...)
-- @return the encoded bytes
-- A negative `num` with `m == 0` is re-expressed as base 32 with the
-- argument `-num - 1`. Raises "int too large" when `num` is not below 2^64.
local function integer(num, m)
	if m == 0 and num < 0 then
		-- negative integer, major type 1
		m = 32;
		num = -num - 1;
	end
	if num < 24 then
		-- small arguments live in the initial byte itself
		return s_char(m + num);
	end
	if num < 0x100 then
		return s_char(m + 24, num);
	end
	if num < 0x10000 then
		return s_char(m + 25, b_rshift(num, 8), num % 0x100);
	end
	if num < 2 ^ 32 then
		return s_char(m + 26, uint32_bytes(num));
	end
	if not (num < 2 ^ 64) then
		error "int too large";
	end
	-- eight bytes: upper and lower 32-bit halves handled separately
	local high = m_floor(num / 2 ^ 32);
	local h1, h2, h3, h4 = uint32_bytes(high);
	return s_char(m + 27, h1, h2, h3, h4, uint32_bytes(num % 2 ^ 32));
end
if s_pack then
	-- When a pack function is available, replace integer() with a version
	-- that lets it produce the big-endian argument bytes.
	-- `m` defaults to 0; the initial byte is `m` plus either the value itself
	-- (below 24) or the code 24..27 selecting a 1, 2, 4 or 8 byte argument.
	function integer(num, m)
		local base = m or 0;
		if num < 24 then
			return s_pack(">B", base + num, num);
		end
		if num < 256 then
			return s_pack(">BB", base + 24, num);
		end
		if num < 65536 then
			return s_pack(">BI2", base + 25, num);
		end
		if num < 4294967296 then
			return s_pack(">BI4", base + 26, num);
		end
		return s_pack(">BI8", base + 27, num);
	end
end
-- Metatable shared by every object created with simple().
local simple_mt = {};

-- Text form: the name if one was given, otherwise "simple(N)".
simple_mt.__tostring = function(self)
	if self.name then
		return self.name;
	end
	return ("simple(%d)"):format(self.value);
end

-- CBOR form: the pre-encoded string if one was given, otherwise the value
-- encoded through integer() with base 224.
simple_mt.__tocbor = function(self)
	if self.cbor then
		return self.cbor;
	end
	return integer(self.value, 224);
end

--- Create an object representing a CBOR simple value.
-- @param value number in the range 0..255
-- @param name optional display name
-- @param cbor optional pre-encoded representation
local function simple(value, name, cbor)
	local in_range = value >= 0 and value <= 255;
	assert(in_range, "bad argument #1 to 'simple' (integer in range 0..255 expected)");
	local obj = { value = value, name = name, cbor = cbor };
	return setmetatable(obj, simple_mt);
end
-- Metatable shared by every object created with tagged().
local tagged_mt = {};

-- Text form: "TAG(VALUE)".
tagged_mt.__tostring = function(self)
	local shown = tostring(self.value);
	return ("%d(%s)"):format(self.tag, shown);
end

-- CBOR form: the tag number encoded with base 192, followed by the
-- encoded value.
tagged_mt.__tocbor = function(self, opts)
	local head = integer(self.tag, 192);
	return head .. encode(self.value, opts);
end

--- Attach a tag number to a value.
-- @param tag non-negative tag number
-- @param value the value being tagged
local function tagged(tag, value)
	local non_negative = tag >= 0;
	assert(non_negative, "bad argument #1 to 'tagged' (positive integer expected)");
	local obj = { tag = tag, value = value };
	return setmetatable(obj, tagged_mt);
end
-- Pre-defined simple values.
local null = simple(22, "null"); -- explicit null
local undefined = simple(23, "undefined"); -- undefined or nil
-- Marker returned by the decoder for the "break" code. It carries its own
-- encoded form ("\255"), and the readers for indefinite-length items
-- compare against this exact object to know when to stop.
local BREAK = simple(31, "break", "\255");
-- Number types dispatch: forwards to encoder.integer or encoder.float
-- according to what m_type reports for the number.
function encoder.number(num)
	local subtype = m_type(num);
	local encode_subtype = encoder[subtype];
	return encode_subtype(num);
end
-- Major types 0, 1
-- Non-negative numbers are encoded with base 0; negative numbers with
-- base 32 and the argument -1 - num.
function encoder.integer(num)
	local base, argument = 0, num;
	if num < 0 then
		base, argument = 32, -1 - num;
	end
	return integer(argument, base);
end
-- Major type 7
-- Arithmetic-only encoder for floats: always emits byte 251 followed by
-- eight bytes built from sign, an 11-bit exponent and a 52-bit fraction.
function encoder.float(num)
if num ~= num then -- NaN shortcut
return "\251\127\255\255\255\255\255\255\255";
end
-- 1 / num tells +0 from -0, which compare equal to each other
local sign = (num > 0 or 1 / num > 0) and 0 or 1;
num = m_abs(num)
if num == m_huge then
-- infinity: all exponent bits set, fraction zero
return s_char(251, sign * 128 + 128 - 1) .. "\240\0\0\0\0\0\0";
end
local fraction, exponent = m_frexp(num)
if fraction == 0 then
-- zero: only the sign bit may be set
return s_char(251, sign * 128) .. "\0\0\0\0\0\0\0";
end
-- m_frexp yields a fraction in [0.5, 1); rescale it to [1, 2) and turn the
-- exponent into its biased form
fraction = fraction * 2;
exponent = exponent + 1024 - 2;
if exponent <= 0 then
-- too small for a normal number: exponent field 0, fraction scaled down
-- and stored without the leading one removed
fraction = fraction * 2 ^ (exponent - 1)
exponent = 0;
else
-- normal number: the leading one is implied and not stored
fraction = fraction - 1;
end
-- Byte 1: sign bit and the upper 7 exponent bits. Byte 2: the lower 4
-- exponent bits and the top 4 fraction bits. Bytes 3..8: the remaining
-- fraction bits, eight at a time.
return s_char(251,
sign * 2 ^ 7 + m_floor(exponent / 2 ^ 4) % 2 ^ 7,
exponent % 2 ^ 4 * 2 ^ 4 +
m_floor(fraction * 2 ^ 4 % 0x100),
m_floor(fraction * 2 ^ 12 % 0x100),
m_floor(fraction * 2 ^ 20 % 0x100),
m_floor(fraction * 2 ^ 28 % 0x100),
m_floor(fraction * 2 ^ 36 % 0x100),
m_floor(fraction * 2 ^ 44 % 0x100),
m_floor(fraction * 2 ^ 52 % 0x100)
)
end
if s_pack then
	-- With a pack function available the float is serialised directly:
	-- byte 251 followed by the number in big-endian "d" format.
	encoder.float = function(num)
		return s_pack(">Bd", 251, num);
	end
end
-- Major type 2 - byte strings
-- Output is the length header (base 64) followed by the raw bytes.
function encoder.bytestring(s)
	local header = integer(#s, 64);
	return header .. s;
end
-- Major type 3 - UTF-8 strings
-- Output is the length header (base 96) followed by the bytes as given;
-- the content is not inspected.
function encoder.utf8string(s)
	local header = integer(#s, 96);
	return header .. s;
end
-- Lua strings are byte strings
encoder.string = encoder.bytestring;

-- Booleans: "\245" for a true value, "\244" otherwise.
function encoder.boolean(bool)
	if bool then
		return "\245";
	end
	return "\244";
end

-- nil is always written as the single byte "\246".
encoder["nil"] = function()
	return "\246";
end
-- Userdata can only be encoded through a custom encoder: either one
-- registered in opts under the userdata's metatable, or the metatable's
-- __tocbor field. Raises "can't encode userdata" when neither exists.
function encoder.userdata(ud, opts)
	local mt = dbg_getmetatable(ud);
	local encode_ud = mt and (opts and opts[mt] or mt.__tocbor);
	if not encode_ud then
		error "can't encode userdata";
	end
	return encode_ud(ud, opts);
end
-- Encode a table. A custom encoder (opts[metatable] or the metatable's
-- __tocbor) takes precedence; otherwise the table is written as an array
-- or as a map depending on the keys seen while iterating.
function encoder.table(t, opts)
	local mt = getmetatable(t);
	if mt then
		local custom = opts and opts[mt] or mt.__tocbor;
		if custom then
			return custom(t, opts);
		end
	end
	-- The table is treated as an array only if pairs() hands out the keys
	-- 1, 2, 3, ... in exactly that order. Lua does not promise this order,
	-- it is merely what PUC-Rio Lua usually does (see the manual on the #
	-- operator). When it does not happen, the result is a map with
	-- integer keys, which is somewhat larger but still valid.
	-- Both forms are built side by side; slot 1 of each holds the header.
	local as_array = { integer(#t, 128) };
	local as_map = { "\191" };
	local count = 0;
	local sequential = true;
	for key, value in pairs(t) do
		count = count + 1;
		if sequential and key ~= count then
			sequential = false;
		end
		local encoded_value = encode(value, opts);
		as_array[count + 1] = encoded_value;
		as_map[2 * count] = encode(key, opts);
		as_map[2 * count + 1] = encoded_value;
	end
	-- definite-length map header replaces the placeholder
	as_map[1] = integer(count, 160);
	if sequential then
		return t_concat(as_array);
	end
	return t_concat(as_map);
end
-- Array or dict-only encoders, which can be set as __tocbor metamethod

-- Encode the ipairs() portion of a table as an array; other keys are
-- ignored.
function encoder.array(t, opts)
	local items, count = {}, 0;
	for _, value in ipairs(t) do
		count = count + 1;
		items[count] = encode(value, opts);
	end
	local header = integer(count, 128);
	return header .. t_concat(items);
end
-- Encode every key/value pair of a table as a map, in pairs() order.
function encoder.map(t, opts)
	local parts = { "\191" }; -- slot 1 is replaced by the header below
	local count = 0;
	for key, value in pairs(t) do
		count = count + 1;
		parts[2 * count] = encode(key, opts);
		parts[2 * count + 1] = encode(value, opts);
	end
	parts[1] = integer(count, 160);
	return t_concat(parts);
end
encoder.dict = encoder.map; -- COMPAT
-- Encode a table as a map with a defined key order.
-- If t[1] is set, the sequence part of t lists the keys in the wanted
-- order and each value is looked up as t[key]. Otherwise all keys of t
-- are collected and sorted with t_sort.
function encoder.ordered_map(t, opts)
	local keys;
	if t[1] then
		keys = t;
	else
		-- no predefined order
		keys = {};
		local n = 0;
		for key in pairs(t) do
			n = n + 1;
			keys[n] = key;
		end
		t_sort(keys);
	end
	local entries = {};
	for i, key in ipairs(keys) do
		entries[i] = encode(key, opts) .. encode(t[key], opts);
	end
	return integer(#entries, 160) .. t_concat(entries);
end
-- Functions have no encoded form; trying to encode one is an error.
encoder["function"] = function ()
	error("can't encode function");
end
-- Decoder
-- Reads from a file-handle like object

--- Read exactly `len` bytes.
-- @param fh object with a `read(self, n)` method
-- @param len number of bytes wanted
-- @return a string of exactly `len` bytes
-- Raises "input too short" if the source yields nil or fewer bytes.
local function read_bytes(fh, len)
	if len == 0 then
		-- Do not ask the source at all: a real file handle answers a
		-- zero-length read at end of file with nil rather than "".
		return "";
	end
	local data = fh:read(len);
	if not data or #data < len then
		error "input too short";
	end
	return data;
end
--- Read one byte and return its numeric value.
-- @param fh object with a `read(self, n)` method
-- Raises "input too short" if the source has nothing left (nil or "").
local function read_byte(fh)
	local char = fh:read(1);
	if not char or char == "" then
		error "input too short";
	end
	return char:byte();
end
--- Decode the argument that follows an initial byte.
-- @param fh source to read further bytes from
-- @param mintyp the low five bits of the initial byte
-- @return `mintyp` itself when below 24; otherwise the big-endian number
--   held in the next 1, 2, 4 or 8 bytes (for 24, 25, 26, 27)
-- Raises "invalid length" for any other `mintyp`.
local function read_length(fh, mintyp)
	if mintyp < 24 then
		return mintyp;
	end
	if not (mintyp < 28) then
		error "invalid length";
	end
	local value = 0;
	local width = 2 ^ (mintyp - 24);
	for _ = 1, width do
		value = value * 256 + read_byte(fh);
	end
	return value;
end
-- Reader functions indexed by major type number.
local decoder = {};

--- Read an initial byte and split it in two.
-- @return the top three bits (major type) and the low five bits
local function read_type(fh)
	local initial = read_byte(fh);
	local major = b_rshift(initial, 5);
	local info = initial % 32;
	return major, info;
end
--- Read one complete item: the initial byte selects the reader from the
-- `decoder` table, which receives the source, the low five bits and opts.
local function read_object(fh, opts)
	local major, info = read_type(fh);
	local read_major = decoder[major];
	return read_major(fh, info, opts);
end
--- Major type 0: the value is the argument itself.
local function read_integer(fh, mintyp)
	local value = read_length(fh, mintyp);
	return value;
end
--- Major type 1: the value is -1 minus the argument.
local function read_negative_integer(fh, mintyp)
	local argument = read_length(fh, mintyp);
	return -1 - argument;
end
--- Read a string.
-- With a definite length the bytes are read directly. With `mintyp` 31
-- (indefinite length) items are read one after another until BREAK is
-- seen, and their concatenation is returned.
local function read_string(fh, mintyp)
	if mintyp == 31 then
		local chunks, count = {}, 0;
		while true do
			local chunk = read_object(fh);
			if chunk == BREAK then
				break;
			end
			count = count + 1;
			chunks[count] = chunk;
		end
		return t_concat(chunks);
	end
	return read_bytes(fh, read_length(fh, mintyp));
end
-- Major type 3 reader. At present it simply delegates to read_string, so
-- the bytes are returned without any UTF-8 validation.
local function read_unicode_string(fh, mintyp)
return read_string(fh, mintyp);
-- Disabled sketch of a validation step, kept for reference:
-- local str = read_string(fh, mintyp);
-- if have_utf8 and not utf8.len(str) then
-- TODO How to handle this?
-- end
-- return str;
end
--- Read an array into a new table, elements stored from index 1.
-- With `mintyp` 31 the element count is open-ended and reading stops at
-- BREAK; otherwise exactly the announced number of items is read.
local function read_array(fh, mintyp, opts)
	local items = {};
	if mintyp ~= 31 then
		local len = read_length(fh, mintyp);
		for index = 1, len do
			items[index] = read_object(fh, opts);
		end
		return items;
	end
	local count = 0;
	while true do
		local item = read_object(fh, opts);
		if item == BREAK then
			break;
		end
		count = count + 1;
		items[count] = item;
	end
	return items;
end
--- Read a map into a new table.
-- Each entry is read as a key item followed by a value item. With
-- `mintyp` 31 entries are read until a key position holds BREAK;
-- otherwise exactly the announced number of entries is read.
local function read_map(fh, mintyp, opts)
	local entries = {};
	if mintyp ~= 31 then
		for _ = 1, read_length(fh, mintyp) do
			local key = read_object(fh, opts);
			entries[key] = read_object(fh, opts);
		end
		return entries;
	end
	while true do
		local key = read_object(fh, opts);
		if key == BREAK then
			break;
		end
		entries[key] = read_object(fh, opts);
	end
	return entries;
end
-- Default post-processing functions for tagged values, keyed by tag number.
local tagged_decoders = {};

--- Read a tag number and the item it applies to.
-- A handler is looked up by tag number, first in opts and then in
-- tagged_decoders; if one exists its result for the value is returned.
-- Without a handler the pair is returned wrapped by tagged().
local function read_semantic(fh, mintyp, opts)
	local tag = read_length(fh, mintyp);
	local value = read_object(fh, opts);
	local postproc = opts and opts[tag] or tagged_decoders[tag];
	if not postproc then
		return tagged(tag, value);
	end
	return postproc(value);
end
-- Decode a two-byte (half precision) float: the first byte holds the sign
-- bit, five exponent bits and the top two fraction bits; the second byte
-- holds the remaining eight fraction bits.
local function read_half_float(fh)
local exponent = read_byte(fh);
local fraction = read_byte(fh);
local sign = exponent < 128 and 1 or -1; -- sign is highest bit
-- (exponent * 256) % 1024 keeps only the two lowest bits of the first
-- byte, moved above the eight bits already in `fraction`
fraction = fraction + (exponent * 256) % 1024; -- add the top two fraction bits from the first byte
exponent = b_rshift(exponent, 2) % 32; -- drop the two fraction bits, then the sign bit
if exponent == 0 then
-- exponent field 0: zero or subnormal, no implied leading one
return sign * m_ldexp(fraction, -24);
elseif exponent ~= 31 then
-- normal number: adding 1024 restores the implied leading one
return sign * m_ldexp(fraction + 1024, exponent - 25);
elseif fraction == 0 then
-- exponent field all ones with zero fraction: infinity
return sign * m_huge;
else
return NaN;
end
end
--- Decode a four-byte (single precision) float using arithmetic only.
-- Layout, most significant bit first: 1 sign bit, 8 exponent bits,
-- 23 fraction bits.
-- @param fh source handed to read_byte
-- @return the decoded number
local function read_float(fh)
	local exponent = read_byte(fh);
	local fraction = read_byte(fh);
	local sign = exponent < 128 and 1 or -1; -- sign is highest bit
	-- low seven bits of byte one, plus the top bit of byte two
	exponent = exponent * 2 % 256 + b_rshift(fraction, 7);
	fraction = fraction % 128;
	fraction = fraction * 256 + read_byte(fh);
	fraction = fraction * 256 + read_byte(fh);
	if exponent == 0 then
		-- Zero or subnormal: no implied leading one, and the value is the
		-- fraction (not the exponent field, which is 0 here) times 2^-149.
		return sign * m_ldexp(fraction, -149);
	elseif exponent ~= 0xff then
		-- normal number: 2^23 restores the implied leading one
		return sign * m_ldexp(fraction + 2 ^ 23, exponent - 150);
	elseif fraction == 0 then
		return sign * m_huge;
	else
		return NaN;
	end
end
--- Decode an eight-byte (double precision) float using arithmetic only.
-- Layout, most significant bit first: 1 sign bit, 11 exponent bits,
-- 52 fraction bits.
-- @param fh source handed to read_byte
-- @return the decoded number
local function read_double(fh)
	local exponent = read_byte(fh);
	local fraction = read_byte(fh);
	local sign = exponent < 128 and 1 or -1; -- sign is highest bit
	-- low seven bits of byte one, plus the top four bits of byte two
	exponent = exponent % 128 * 16 + b_rshift(fraction, 4);
	fraction = fraction % 16;
	for _ = 1, 6 do
		fraction = fraction * 256 + read_byte(fh);
	end
	if exponent == 0 then
		-- Zero or subnormal: no implied leading one; the value is the
		-- fraction scaled by 2^-1074 (= 2^-1022 * 2^-52).
		return sign * m_ldexp(fraction, -1074);
	elseif exponent ~= 0x7ff then
		-- normal number: 2^52 restores the implied leading one
		return sign * m_ldexp(fraction + 2 ^ 52, exponent - 1075);
	elseif fraction == 0 then
		-- all eleven exponent bits set, zero fraction: infinity
		return sign * m_huge;
	else
		return NaN;
	end
end
if s_unpack then
	-- Prefer the unpack function for 4 and 8 byte floats when it is
	-- available. Unpacking also returns the position after the value; the
	-- extra parentheses keep only the number, so that the position does
	-- not travel up through the decoder as a second return value.
	function read_float(fh)
		return (s_unpack(">f", read_bytes(fh, 4)));
	end
	function read_double(fh)
		return (s_unpack(">d", read_bytes(fh, 8)));
	end
end
--- Read a major type 7 item.
-- @param fh source
-- @param value the low five bits of the initial byte
-- @param opts optional table; `opts.simple`, if set, is called with the
--   number of any simple value that has no dedicated meaning here
-- @return false/true for 20/21, the null and undefined objects for 22/23,
--   a number for the float codes 25/26/27, BREAK for 31; otherwise
--   the result of `opts.simple(value)` or a simple() object
local function read_simple(fh, value, opts)
	if value == 24 then
		-- Two-byte form: the following byte is the simple value number and
		-- nothing else. It must not be run through the dispatch below, or a
		-- value such as 25 would be taken for a float code and swallow
		-- further input.
		value = read_byte(fh);
	elseif value == 20 then
		return false;
	elseif value == 21 then
		return true;
	elseif value == 22 then
		return null;
	elseif value == 23 then
		return undefined;
	elseif value == 25 then
		return read_half_float(fh);
	elseif value == 26 then
		return read_float(fh);
	elseif value == 27 then
		return read_double(fh);
	elseif value == 31 then
		return BREAK;
	end
	if opts and opts.simple then
		return opts.simple(value);
	end
	return simple(value);
end
-- Reader for each major type. read_object indexes this table with the top
-- three bits of an item's initial byte.
decoder[0] = read_integer; -- non-negative integer
decoder[1] = read_negative_integer; -- negative integer
decoder[2] = read_string; -- byte string
decoder[3] = read_unicode_string; -- UTF-8 string
decoder[4] = read_array; -- array
decoder[5] = read_map; -- map
decoder[6] = read_semantic; -- tagged item
decoder[7] = read_simple; -- simple values, floats and break
-- opts.more(n) -> want more data
-- opts.simple -> decode simple value
-- opts[int] -> tagged decoder

--- Decode one CBOR item from a string.
-- @param s the encoded input
-- @param opts nil, an options table (see above), or a function used as
--   the `more` callback
-- @return the decoded item
-- When the input runs out, `more(missing, fh, opts)` is called. It may
-- return additional input, or add it itself with `fh:write(data)`. Without
-- a callback the decoder raises "input too short".
-- NOTE: a callback that never supplies data and never raises makes the
-- read retry forever.
local function decode(s, opts)
	local fh = {};
	local pos = 1; -- index in `s` of the next unread byte
	local more;
	-- Options handed to the item readers. They index it (opts[tag],
	-- opts.simple), so a bare callback function must not be passed on.
	local reader_opts;
	if type(opts) == "function" then
		more = opts;
	elseif type(opts) == "table" then
		more = opts.more;
		reader_opts = opts;
	elseif opts ~= nil then
		error(("bad argument #2 to 'decode' (function or table expected, got %s)"):format(type(opts)));
	end
	if type(more) ~= "function" then
		function more()
			error "input too short";
		end
	end
	function fh:read(bytes)
		local ret = s:sub(pos, pos + bytes - 1);
		if #ret < bytes then
			ret = more(bytes - #ret, fh, opts);
			if ret then self:write(ret); end
			return self:read(bytes);
		end
		pos = pos + bytes;
		return ret;
	end
	function fh:write(bytes) -- luacheck: no self
		s = s .. bytes;
		if pos > 256 then
			-- Drop the part that has been consumed already. `pos` itself is
			-- still unread, so the kept part starts at `pos`, not after it.
			s = s:sub(pos);
			pos = 1;
		end
		return #bytes;
	end
	return read_object(fh, reader_opts);
end
-- Public interface of the module.
return {
-- en-/decoder functions
encode = encode;
decode = decode;
-- decode_file(fh, opts) reads one item from any object providing fh:read(n)
decode_file = read_object;
-- tables of per-type en-/decoders
type_encoders = encoder;
type_decoders = decoder;
-- special treatment for tagged values
tagged_decoders = tagged_decoders;
-- constructors for annotated types
simple = simple;
tagged = tagged;
-- pre-defined simple values
null = null;
undefined = undefined;
};