diff --git a/src/huffman_decoder.zig b/src/huffman_decoder.zig index 4116dc9..12aee2e 100644 --- a/src/huffman_decoder.zig +++ b/src/huffman_decoder.zig @@ -9,10 +9,10 @@ pub const Symbol = packed struct { }; symbol: u8 = 0, // symbol from alphabet - code_bits: u4 = 0, // code bits count + code_bits: u4 = 0, // number of bits in code 0-15 kind: Kind = .literal, - code: u16 = 0, + code: u16 = 0, // huffman code of the symbol next: u16 = 0, // pointer to the next symbol in linked list // it is safe to use 0 as null pointer, when sorted 0 has shortest code and fits into lookup @@ -60,7 +60,7 @@ fn HuffmanDecoder( const Self = @This(); - /// Builds symbols and lookup tables from list of code lens for each symbol. + /// Generates symbols and lookup tables from list of code lens for each symbol. pub fn generate(self: *Self, lens: []const u4) void { // init alphabet with code_bits for (self.symbols, 0..) |_, i| { @@ -109,7 +109,7 @@ fn HuffmanDecoder( } /// Finds symbol for lookup table code. - pub fn find(self: *Self, code: u16) Symbol { + pub fn find(self: *Self, code: u16) !Symbol { // try to find in lookup table const idx = code >> lookup_shift; const sym = self.lookup[idx]; @@ -118,7 +118,7 @@ fn HuffmanDecoder( return self.findLinked(code, sym.next); } - inline fn findLinked(self: *Self, code: u16, start: u16) Symbol { + inline fn findLinked(self: *Self, code: u16, start: u16) !Symbol { var pos = start; while (pos > 0) { const sym = self.symbols[pos]; @@ -127,7 +127,7 @@ fn HuffmanDecoder( if ((code ^ sym.code) >> shift == 0) return sym; pos = sym.next; } - return .{}; + return error.CorruptInput; } }; } @@ -180,32 +180,32 @@ test "Huffman init/find" { for (expected, 12..) |e, i| { try testing.expectEqual(e.sym.symbol, h.symbols[i].symbol); try testing.expectEqual(e.sym.code_bits, h.symbols[i].code_bits); - const sym_from_code = h.find(e.code); + const sym_from_code = try h.find(e.code); try testing.expectEqual(e.sym.symbol, sym_from_code.symbol); } // All possible codes for each symbol. // Lookup table has 126 elements, to cover all possible 7 bit codes. for (0b0000_000..0b0100_000) |c| // 0..32 (32) - try testing.expectEqual(3, h.find(@intCast(c)).symbol); + try testing.expectEqual(3, (try h.find(@intCast(c))).symbol); for (0b0100_000..0b1000_000) |c| // 32..64 (32) - try testing.expectEqual(18, h.find(@intCast(c)).symbol); + try testing.expectEqual(18, (try h.find(@intCast(c))).symbol); for (0b1000_000..0b1010_000) |c| // 64..80 (16) - try testing.expectEqual(1, h.find(@intCast(c)).symbol); + try testing.expectEqual(1, (try h.find(@intCast(c))).symbol); for (0b1010_000..0b1100_000) |c| // 80..96 (16) - try testing.expectEqual(4, h.find(@intCast(c)).symbol); + try testing.expectEqual(4, (try h.find(@intCast(c))).symbol); for (0b1100_000..0b1110_000) |c| // 96..112 (16) - try testing.expectEqual(17, h.find(@intCast(c)).symbol); + try testing.expectEqual(17, (try h.find(@intCast(c))).symbol); for (0b1110_000..0b1111_000) |c| // 112..120 (8) - try testing.expectEqual(0, h.find(@intCast(c)).symbol); + try testing.expectEqual(0, (try h.find(@intCast(c))).symbol); for (0b1111_000..0b1_0000_000) |c| // 120...128 (8) - try testing.expectEqual(16, h.find(@intCast(c)).symbol); + try testing.expectEqual(16, (try h.find(@intCast(c))).symbol); } const print = std.debug.print; @@ -250,7 +250,7 @@ test "full " { if (c.len == 0) continue; const s_code: u15 = @bitReverse(@as(u15, @intCast(c.code))); - const s = dec.find(s_code); + const s = try dec.find(s_code); try expect(s.code == s_code); try expect(s.code_bits == c.len); } diff --git a/src/inflate.zig b/src/inflate.zig index f642a64..a62f017 100644 --- a/src/inflate.zig +++ b/src/inflate.zig @@ -51,8 +51,8 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { hasher: container.Hasher() = .{}, // dynamic block huffman code decoders - lit_h: hfd.LiteralDecoder = .{}, // literals - dst_h: hfd.DistanceDecoder = .{}, // distances + lit_dec: hfd.LiteralDecoder = .{}, // literals + dst_dec: hfd.DistanceDecoder = .{}, // distances // current read state bfinal: u1 = 0, @@ -146,41 +146,42 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { const hclen: u8 = @as(u8, try self.bits.read(u4)) + 4; // hclen + 4 code lenths are encoded // lengths for code lengths - var cl_l = [_]u4{0} ** 19; + var cl_lens = [_]u4{0} ** 19; for (0..hclen) |i| { - cl_l[codegen_order[i]] = try self.bits.read(u3); + cl_lens[codegen_order[i]] = try self.bits.read(u3); } - var cl_h: hfd.CodegenDecoder = .{}; - cl_h.generate(&cl_l); + var cl_dec: hfd.CodegenDecoder = .{}; + cl_dec.generate(&cl_lens); // literal code lengths - var lit_l = [_]u4{0} ** (286); + var lit_lens = [_]u4{0} ** (286); var pos: usize = 0; while (pos < hlit) { - const sym = cl_h.find(try self.bits.peekF(u7, F.reverse)); + const sym = try cl_dec.find(try self.bits.peekF(u7, F.reverse)); try self.bits.shift(sym.code_bits); - pos += try self.dynamicCodeLength(sym.symbol, &lit_l, pos); + pos += try self.dynamicCodeLength(sym.symbol, &lit_lens, pos); } // distance code lenths - var dst_l = [_]u4{0} ** (30); + var dst_lens = [_]u4{0} ** (30); pos = 0; while (pos < hdist) { - const sym = cl_h.find(try self.bits.peekF(u7, F.reverse)); + const sym = try cl_dec.find(try self.bits.peekF(u7, F.reverse)); try self.bits.shift(sym.code_bits); - pos += try self.dynamicCodeLength(sym.symbol, &dst_l, pos); + pos += try self.dynamicCodeLength(sym.symbol, &dst_lens, pos); } - self.lit_h.generate(&lit_l); - self.dst_h.generate(&dst_l); + self.lit_dec.generate(&lit_lens); + self.dst_dec.generate(&dst_lens); } // Decode code length symbol to code length. Writes decoded length into // lens slice starting at position pos. Returns number of positions // advanced. fn dynamicCodeLength(self: *Self, code: u16, lens: []u4, pos: usize) !usize { - if (pos >= lens.len or code > 18) + if (pos >= lens.len) return error.CorruptInput; + switch (code) { 0...15 => { // Represent code lengths of 0 - 15 @@ -201,7 +202,7 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { 17 => return @as(u8, try self.bits.read(u3)) + 3, // Repeat a code length of 0 for 11 - 138 times (7 bits of length) 18 => return @as(u8, try self.bits.read(u7)) + 11, - else => unreachable, + else => return error.CorruptInput, } } @@ -211,14 +212,14 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { // Hot path loop! while (!self.hist.full()) { try self.bits.fill(15); // optimization so other bit reads can be buffered (avoiding one `if` in hot path) - const sym = try self.decodeSymbol(&self.lit_h); + const sym = try self.decodeSymbol(&self.lit_dec); switch (sym.kind) { .literal => self.hist.write(sym.symbol), .match => { // Decode match backreference try self.bits.fill(5 + 15 + 13); // so we can use buffered reads const length = try self.decodeLength(sym.symbol); - const dsm = try self.decodeSymbol(&self.dst_h); + const dsm = try self.decodeSymbol(&self.dst_dec); const distance = try self.decodeDistance(dsm.symbol); try self.hist.writeMatch(length, distance); }, @@ -233,8 +234,7 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { // used. Shift bit reader for that much bits, those bits are used. And // return symbol. fn decodeSymbol(self: *Self, decoder: anytype) !hfd.Symbol { - const sym = decoder.find(try self.bits.peekF(u15, F.buffered | F.reverse)); - if (sym.code_bits == 0) return error.CorruptInput; + const sym = try decoder.find(try self.bits.peekF(u15, F.buffered | F.reverse)); try self.bits.shift(sym.code_bits); return sym; } @@ -243,7 +243,6 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { switch (self.state) { .protocol_header => { try container.parseHeader(&self.bits); - self.state = .block_header; }, .block_header => { @@ -264,7 +263,6 @@ pub fn Inflate(comptime container: Container, comptime ReaderType: type) type { }, .protocol_footer => { self.bits.alignToByte(); - try container.parseFooter(&self.hasher, &self.bits); self.state = .end; },