Skip to content

Commit

Permalink
return error if code can't be found
Browse files Browse the repository at this point in the history
Instead of catching that later.
  • Loading branch information
ianic committed Feb 9, 2024
1 parent 5ca1454 commit 05d294e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 37 deletions.
30 changes: 15 additions & 15 deletions src/huffman_decoder.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ pub const Symbol = packed struct {
};

symbol: u8 = 0, // symbol from alphabet
code_bits: u4 = 0, // code bits count
code_bits: u4 = 0, // number of bits in code 0-15
kind: Kind = .literal,

code: u16 = 0,
code: u16 = 0, // huffman code of the symbol
next: u16 = 0, // pointer to the next symbol in linked list
// It is safe to use 0 as the null pointer: once sorted, symbol 0 has the shortest code and fits into the lookup table.

Expand Down Expand Up @@ -60,7 +60,7 @@ fn HuffmanDecoder(

const Self = @This();

/// Builds symbols and lookup tables from list of code lens for each symbol.
/// Generates symbols and lookup tables from list of code lens for each symbol.
pub fn generate(self: *Self, lens: []const u4) void {
// init alphabet with code_bits
for (self.symbols, 0..) |_, i| {
Expand Down Expand Up @@ -109,7 +109,7 @@ fn HuffmanDecoder(
}

/// Finds symbol for lookup table code.
pub fn find(self: *Self, code: u16) Symbol {
pub fn find(self: *Self, code: u16) !Symbol {
// try to find in lookup table
const idx = code >> lookup_shift;
const sym = self.lookup[idx];
Expand All @@ -118,7 +118,7 @@ fn HuffmanDecoder(
return self.findLinked(code, sym.next);
}

inline fn findLinked(self: *Self, code: u16, start: u16) Symbol {
inline fn findLinked(self: *Self, code: u16, start: u16) !Symbol {
var pos = start;
while (pos > 0) {
const sym = self.symbols[pos];
Expand All @@ -127,7 +127,7 @@ fn HuffmanDecoder(
if ((code ^ sym.code) >> shift == 0) return sym;
pos = sym.next;
}
return .{};
return error.CorruptInput;
}
};
}
Expand Down Expand Up @@ -180,32 +180,32 @@ test "Huffman init/find" {
for (expected, 12..) |e, i| {
try testing.expectEqual(e.sym.symbol, h.symbols[i].symbol);
try testing.expectEqual(e.sym.code_bits, h.symbols[i].code_bits);
const sym_from_code = h.find(e.code);
const sym_from_code = try h.find(e.code);
try testing.expectEqual(e.sym.symbol, sym_from_code.symbol);
}

// All possible codes for each symbol.
// Lookup table has 128 elements, to cover all possible 7 bit codes.
for (0b0000_000..0b0100_000) |c| // 0..32 (32)
try testing.expectEqual(3, h.find(@intCast(c)).symbol);
try testing.expectEqual(3, (try h.find(@intCast(c))).symbol);

for (0b0100_000..0b1000_000) |c| // 32..64 (32)
try testing.expectEqual(18, h.find(@intCast(c)).symbol);
try testing.expectEqual(18, (try h.find(@intCast(c))).symbol);

for (0b1000_000..0b1010_000) |c| // 64..80 (16)
try testing.expectEqual(1, h.find(@intCast(c)).symbol);
try testing.expectEqual(1, (try h.find(@intCast(c))).symbol);

for (0b1010_000..0b1100_000) |c| // 80..96 (16)
try testing.expectEqual(4, h.find(@intCast(c)).symbol);
try testing.expectEqual(4, (try h.find(@intCast(c))).symbol);

for (0b1100_000..0b1110_000) |c| // 96..112 (16)
try testing.expectEqual(17, h.find(@intCast(c)).symbol);
try testing.expectEqual(17, (try h.find(@intCast(c))).symbol);

for (0b1110_000..0b1111_000) |c| // 112..120 (8)
try testing.expectEqual(0, h.find(@intCast(c)).symbol);
try testing.expectEqual(0, (try h.find(@intCast(c))).symbol);

for (0b1111_000..0b1_0000_000) |c| // 120..128 (8)
try testing.expectEqual(16, h.find(@intCast(c)).symbol);
try testing.expectEqual(16, (try h.find(@intCast(c))).symbol);
}

const print = std.debug.print;
Expand Down Expand Up @@ -250,7 +250,7 @@ test "full " {
if (c.len == 0) continue;

const s_code: u15 = @bitReverse(@as(u15, @intCast(c.code)));
const s = dec.find(s_code);
const s = try dec.find(s_code);
try expect(s.code == s_code);
try expect(s.code_bits == c.len);
}
Expand Down
42 changes: 20 additions & 22 deletions src/inflate.zig
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
hasher: container.Hasher() = .{},

// dynamic block huffman code decoders
lit_h: hfd.LiteralDecoder = .{}, // literals
dst_h: hfd.DistanceDecoder = .{}, // distances
lit_dec: hfd.LiteralDecoder = .{}, // literals
dst_dec: hfd.DistanceDecoder = .{}, // distances

// current read state
bfinal: u1 = 0,
Expand Down Expand Up @@ -146,41 +146,42 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
const hclen: u8 = @as(u8, try self.bits.read(u4)) + 4; // hclen + 4 code lengths are encoded

// lengths for code lengths
var cl_l = [_]u4{0} ** 19;
var cl_lens = [_]u4{0} ** 19;
for (0..hclen) |i| {
cl_l[codegen_order[i]] = try self.bits.read(u3);
cl_lens[codegen_order[i]] = try self.bits.read(u3);
}
var cl_h: hfd.CodegenDecoder = .{};
cl_h.generate(&cl_l);
var cl_dec: hfd.CodegenDecoder = .{};
cl_dec.generate(&cl_lens);

// literal code lengths
var lit_l = [_]u4{0} ** (286);
var lit_lens = [_]u4{0} ** (286);
var pos: usize = 0;
while (pos < hlit) {
const sym = cl_h.find(try self.bits.peekF(u7, F.reverse));
const sym = try cl_dec.find(try self.bits.peekF(u7, F.reverse));
try self.bits.shift(sym.code_bits);
pos += try self.dynamicCodeLength(sym.symbol, &lit_l, pos);
pos += try self.dynamicCodeLength(sym.symbol, &lit_lens, pos);
}

// distance code lengths
var dst_l = [_]u4{0} ** (30);
var dst_lens = [_]u4{0} ** (30);
pos = 0;
while (pos < hdist) {
const sym = cl_h.find(try self.bits.peekF(u7, F.reverse));
const sym = try cl_dec.find(try self.bits.peekF(u7, F.reverse));
try self.bits.shift(sym.code_bits);
pos += try self.dynamicCodeLength(sym.symbol, &dst_l, pos);
pos += try self.dynamicCodeLength(sym.symbol, &dst_lens, pos);
}

self.lit_h.generate(&lit_l);
self.dst_h.generate(&dst_l);
self.lit_dec.generate(&lit_lens);
self.dst_dec.generate(&dst_lens);
}

// Decode code length symbol to code length. Writes decoded length into
// lens slice starting at position pos. Returns number of positions
// advanced.
fn dynamicCodeLength(self: *Self, code: u16, lens: []u4, pos: usize) !usize {
if (pos >= lens.len or code > 18)
if (pos >= lens.len)
return error.CorruptInput;

switch (code) {
0...15 => {
// Represent code lengths of 0 - 15
Expand All @@ -201,7 +202,7 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
17 => return @as(u8, try self.bits.read(u3)) + 3,
// Repeat a code length of 0 for 11 - 138 times (7 bits of length)
18 => return @as(u8, try self.bits.read(u7)) + 11,
else => unreachable,
else => return error.CorruptInput,
}
}

Expand All @@ -211,14 +212,14 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
// Hot path loop!
while (!self.hist.full()) {
try self.bits.fill(15); // optimization so other bit reads can be buffered (avoiding one `if` in hot path)
const sym = try self.decodeSymbol(&self.lit_h);
const sym = try self.decodeSymbol(&self.lit_dec);

switch (sym.kind) {
.literal => self.hist.write(sym.symbol),
.match => { // Decode match backreference <length, distance>
try self.bits.fill(5 + 15 + 13); // so we can use buffered reads
const length = try self.decodeLength(sym.symbol);
const dsm = try self.decodeSymbol(&self.dst_h);
const dsm = try self.decodeSymbol(&self.dst_dec);
const distance = try self.decodeDistance(dsm.symbol);
try self.hist.writeMatch(length, distance);
},
Expand All @@ -233,8 +234,7 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
// used. Shift the bit reader by that many bits, since those bits are now
// consumed, and return the symbol.
fn decodeSymbol(self: *Self, decoder: anytype) !hfd.Symbol {
const sym = decoder.find(try self.bits.peekF(u15, F.buffered | F.reverse));
if (sym.code_bits == 0) return error.CorruptInput;
const sym = try decoder.find(try self.bits.peekF(u15, F.buffered | F.reverse));
try self.bits.shift(sym.code_bits);
return sym;
}
Expand All @@ -243,7 +243,6 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
switch (self.state) {
.protocol_header => {
try container.parseHeader(&self.bits);

self.state = .block_header;
},
.block_header => {
Expand All @@ -264,7 +263,6 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type {
},
.protocol_footer => {
self.bits.alignToByte();

try container.parseFooter(&self.hasher, &self.bits);
self.state = .end;
},
Expand Down

0 comments on commit 05d294e

Please sign in to comment.