diff --git a/src/bit_reader.zig b/src/bit_reader.zig index e79f0bf..adf2fe4 100644 --- a/src/bit_reader.zig +++ b/src/bit_reader.zig @@ -28,31 +28,34 @@ pub fn BitReader(comptime ReaderType: type) type { pub fn init(rdr: ReaderType) Self { var self = Self{ .rdr = rdr }; - self.ensureBits(1, 0) catch {}; + self.fill(1) catch {}; return self; } // Ensure that `nice` or at least `must` bits are available in buffer. // Reads from underlying reader if there is no `nice` bits in buffer. // Returns error if `must` bits can't be read. - pub inline fn ensureBits(self: *Self, nice: u6, must: u6) !void { - if (nice > self.nbits) { - // read more bits from underlying reader - var buf: [8]u8 = [_]u8{0} ** 8; - - const empty_bytes = - @as(u8, if (self.nbits & 0x7 == 0) 8 else 7) - // 8 for 8, 16, 24..., 7 otherwise - (self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8 - - const bytes_read = self.rdr.read(buf[0..empty_bytes]) catch 0; - if (bytes_read > 0) { - const u: u64 = std.mem.readInt(u64, buf[0..8], .little); - self.bits |= u << @as(u6, @intCast(self.nbits)); - self.nbits += 8 * @as(u8, @intCast(bytes_read)); - } - // than check again - if (must > self.nbits) return error.EndOfStream; + pub inline fn fill(self: *Self, nice: u6) !void { + if (self.nbits >= nice) + return; // we have enought bits + + // read more bits from underlying reader + var buf: [8]u8 = [_]u8{0} ** 8; + // number of empty bytes in bits + const empty_bytes = + @as(u8, if (self.nbits & 0x7 == 0) 8 else 7) - // 8 for 8, 16, 24..., 7 otherwise + (self.nbits >> 3); // 0 for 0-7, 1 for 8-16, ... same as / 8 + + const bytes_read = self.rdr.read(buf[0..empty_bytes]) catch 0; + if (bytes_read > 0) { + const u: u64 = std.mem.readInt(u64, buf[0..8], .little); + self.bits |= u << @as(u6, @intCast(self.nbits)); + self.nbits += 8 * @as(u8, @intCast(bytes_read)); + return; } + + if (self.bits == 0) + return error.EndOfStream; } // Read exactly buf.len bytes into buf. @@ -77,33 +80,33 @@ pub fn BitReader(comptime ReaderType: type) type { const n: u6 = @bitSizeOf(U); switch (how) { 0 => { - try self.ensureBits(n, n); + try self.fill(n); const u: U = @truncate(self.bits); - self.advance(n); + self.shift(n); return u; }, (flag.peek) => { - try self.ensureBits(n, n); + try self.fill(n); return @as(U, @truncate(self.bits)); }, flag.buffered => { const u: U = @truncate(self.bits); - self.advance(n); + self.shift(n); return u; }, (flag.reverse) => { - try self.ensureBits(n, n); + try self.fill(n); const u: U = @truncate(self.bits); - self.advance(n); + self.shift(n); return @bitReverse(u); }, (flag.peek | flag.reverse) => { - try self.ensureBits(n, n); + try self.fill(n); return @bitReverse(@as(U, @truncate(self.bits))); }, (flag.buffered | flag.reverse) => { const u: U = @truncate(self.bits); - self.advance(n); + self.shift(n); return @bitReverse(u); }, (flag.peek | flag.buffered | flag.reverse) => { @@ -113,27 +116,22 @@ pub fn BitReader(comptime ReaderType: type) type { } } - pub inline fn readBits(self: *Self, n: u4, comptime how: u3) !u16 { + pub inline fn readN(self: *Self, n: u4, comptime how: u3) !u16 { switch (how) { 0 => { - try self.ensureBits(n, n); - const mask: u16 = (@as(u16, 1) << n) - 1; - const u: u16 = @as(u16, @truncate(self.bits)) & mask; - self.advance(n); - return u; - }, - flag.buffered => { - const mask: u16 = (@as(u16, 1) << n) - 1; - const u: u16 = @as(u16, @truncate(self.bits)) & mask; - self.advance(n); - return u; + try self.fill(n); }, + flag.buffered => {}, else => unreachable, } + const mask: u16 = (@as(u16, 1) << n) - 1; + const u: u16 = @as(u16, @truncate(self.bits)) & mask; + self.shift(n); + return u; } // Advance buffer for n bits. - pub inline fn advance(self: *Self, n: u6) void { + pub inline fn shift(self: *Self, n: u6) void { assert(n <= self.nbits); self.bits >>= n; self.nbits -= n; @@ -142,8 +140,8 @@ pub fn BitReader(comptime ReaderType: type) type { // Skip n bytes. pub inline fn skipBytes(self: *Self, n: u16) !void { for (0..n) |_| { - try self.ensureBits(8, 8); - self.advance(8); + try self.fill(8); + self.shift(8); } } @@ -155,7 +153,7 @@ pub fn BitReader(comptime ReaderType: type) type { // Align stream to the byte boundary. pub inline fn alignToByte(self: *Self) void { const ab = self.alignBits(); - if (ab > 0) self.advance(ab); + if (ab > 0) self.shift(ab); } // Skip zero terminated string. @@ -179,7 +177,7 @@ pub fn BitReader(comptime ReaderType: type) type { // 280 - 287 8 11000000 through // 11000111 pub fn readFixedCode(self: *Self) !u16 { - try self.ensureBits(7 + 2, 7); + try self.fill(7 + 2); const code7 = try self.read(u7, flag.buffered | flag.reverse); if (code7 <= 0b0010_111) { // 7 bits, 256-279, codes 0000_000 - 0010_111 return @as(u16, code7) + 256; @@ -209,7 +207,7 @@ test "BitReader" { try testing.expect(try br.read(u8, F.peek) == 0b0001_1110); try testing.expect(try br.read(u9, F.peek) == 0b1_0001_1110); - br.advance(9); + br.shift(9); try testing.expectEqual(@as(u8, 36), br.nbits); try testing.expectEqual(@as(u3, 4), br.alignBits()); @@ -217,16 +215,16 @@ test "BitReader" { try testing.expectEqual(@as(u8, 32), br.nbits); try testing.expectEqual(@as(u3, 0), br.alignBits()); - br.advance(1); + br.shift(1); try testing.expectEqual(@as(u3, 7), br.alignBits()); - br.advance(1); + br.shift(1); try testing.expectEqual(@as(u3, 6), br.alignBits()); br.alignToByte(); try testing.expectEqual(@as(u3, 0), br.alignBits()); try testing.expectEqual(@as(u64, 0xc9), br.bits); - try testing.expectEqual(@as(u16, 0x9), try br.readBits(4, 0)); - try testing.expectEqual(@as(u16, 0xc), try br.readBits(4, 0)); + try testing.expectEqual(@as(u16, 0x9), try br.readN(4, 0)); + try testing.expectEqual(@as(u16, 0xc), try br.readN(4, 0)); } test "BitReader read block type 1 data" { @@ -262,19 +260,19 @@ test "BitReader init" { var br = bitReader(fbs.reader()); try testing.expectEqual(@as(u64, 0x08_07_06_05_04_03_02_01), br.bits); - br.advance(8); + br.shift(8); try testing.expectEqual(@as(u64, 0x00_08_07_06_05_04_03_02), br.bits); - try br.ensureBits(60, 0); // fill with 1 byte + try br.fill(60); // fill with 1 byte try testing.expectEqual(@as(u64, 0x01_08_07_06_05_04_03_02), br.bits); - br.advance(8 * 4 + 4); + br.shift(8 * 4 + 4); try testing.expectEqual(@as(u64, 0x00_00_00_00_00_10_80_70), br.bits); - try br.ensureBits(60, 0); // fill with 4 bytes (shift by 4) + try br.fill(60); // fill with 4 bytes (shift by 4) try testing.expectEqual(@as(u64, 0x00_50_40_30_20_10_80_70), br.bits); try testing.expectEqual(@as(u8, 8 * 7 + 4), br.nbits); - br.advance(@intCast(br.nbits)); // clear buffer - try br.ensureBits(8, 8); // refill with the rest of the bytes + br.shift(@intCast(br.nbits)); // clear buffer + try br.fill(8); // refill with the rest of the bytes try testing.expectEqual(@as(u64, 0x00_00_00_00_00_08_07_06), br.bits); } diff --git a/src/inflate.zig b/src/inflate.zig index 23b9c9c..5469d9a 100644 --- a/src/inflate.zig +++ b/src/inflate.zig @@ -29,7 +29,7 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { const F = BitReaderType.flag; return struct { - rdr: BitReaderType, + bits: BitReaderType, win: SlidingWindow = .{}, hasher: wrap.Hasher() = .{}, @@ -54,41 +54,41 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { const Self = @This(); pub fn init(rt: ReaderType) Self { - return .{ .rdr = BitReaderType.init(rt) }; + return .{ .bits = BitReaderType.init(rt) }; } inline fn decodeLength(self: *Self, code: u8) !u16 { assert(code <= 28); const ml = Token.matchLength(code); - return if (ml.extra_bits == 0) + return if (ml.extra_bits == 0) // 0 - 5 extra bits ml.base else - ml.base + try self.rdr.readBits(ml.extra_bits, F.buffered); + ml.base + try self.bits.readN(ml.extra_bits, F.buffered); } inline fn decodeDistance(self: *Self, code: u8) !u16 { assert(code <= 29); const mo = Token.matchOffset(code); - return if (mo.extra_bits == 0) + return if (mo.extra_bits == 0) // 0 - 13 extra bits mo.base else - mo.base + try self.rdr.readBits(mo.extra_bits, F.buffered); + mo.base + try self.bits.readN(mo.extra_bits, F.buffered); } fn blockHeader(self: *Self) !void { - self.bfinal = try self.rdr.read(u1, 0); - self.block_type = try self.rdr.read(u2, 0); + self.bfinal = try self.bits.read(u1, 0); + self.block_type = try self.bits.read(u2, 0); } fn storedBlock(self: *Self) !bool { - self.rdr.alignToByte(); // skip 5 bits (block header is 3 bits) - var len = try self.rdr.read(u16, 0); - const nlen = try self.rdr.read(u16, 0); + self.bits.alignToByte(); // skip 5 bits (block header is 3 bits) + var len = try self.bits.read(u16, 0); + const nlen = try self.bits.read(u16, 0); if (len != ~nlen) return error.DeflateWrongNlen; while (len > 0) { const buf = self.win.getWritable(len); - try self.rdr.readAll(buf); + try self.bits.readAll(buf); len -= @intCast(buf.len); } return true; @@ -102,7 +102,7 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { fn fixedBlock(self: *Self) !bool { while (!self.windowFull()) { - const code = try self.rdr.readFixedCode(); + const code = try self.bits.readFixedCode(); switch (code) { 0...255 => self.win.write(@intCast(code)), 256 => return true, // end of block @@ -116,22 +116,22 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { // Handles fixed block non literal (length) code. // Length code is followed by 5 bits of distance code. fn fixedDistanceCode(self: *Self, code: u8) !void { - try self.rdr.ensureBits(5 + 5 + 13, 5); + try self.bits.fill(5 + 5 + 13); const length = try self.decodeLength(code); - const distance = try self.decodeDistance(try self.rdr.read(u5, F.buffered)); + const distance = try self.decodeDistance(try self.bits.read(u5, F.buffered)); self.win.writeCopy(length, distance); } fn dynamicBlockHeader(self: *Self) !void { - const hlit: u16 = @as(u16, try self.rdr.read(u5, 0)) + 257; // number of ll code entries present - 257 - const hdist: u16 = @as(u16, try self.rdr.read(u5, 0)) + 1; // number of distance code entries - 1 - const hclen: u8 = @as(u8, try self.rdr.read(u4, 0)) + 4; // hclen + 4 code lenths are encoded + const hlit: u16 = @as(u16, try self.bits.read(u5, 0)) + 257; // number of ll code entries present - 257 + const hdist: u16 = @as(u16, try self.bits.read(u5, 0)) + 1; // number of distance code entries - 1 + const hclen: u8 = @as(u8, try self.bits.read(u4, 0)) + 4; // hclen + 4 code lenths are encoded // lengths for code lengths var cl_l = [_]u4{0} ** 19; const order = consts.huffman.codegen_order; for (0..hclen) |i| { - cl_l[order[i]] = try self.rdr.read(u3, 0); + cl_l[order[i]] = try self.bits.read(u3, 0); } self.cl_h.build(&cl_l); @@ -139,8 +139,8 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { var lit_l = [_]u4{0} ** (286); var pos: usize = 0; while (pos < hlit) { - const sym = self.cl_h.find(try self.rdr.read(u7, F.peek | F.reverse)); - self.rdr.advance(sym.code_bits); + const sym = self.cl_h.find(try self.bits.read(u7, F.peek | F.reverse)); + self.bits.shift(sym.code_bits); pos += try self.dynamicCodeLength(sym.symbol, &lit_l, pos); } @@ -148,8 +148,8 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { var dst_l = [_]u4{0} ** (30); pos = 0; while (pos < hdist) { - const sym = self.cl_h.find(try self.rdr.read(u7, F.peek | F.reverse)); - self.rdr.advance(sym.code_bits); + const sym = self.cl_h.find(try self.bits.read(u7, F.peek | F.reverse)); + self.bits.shift(sym.code_bits); pos += try self.dynamicCodeLength(sym.symbol, &dst_l, pos); } @@ -159,9 +159,9 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { fn dynamicBlock(self: *Self) !bool { while (!self.windowFull()) { - try self.rdr.ensureBits(15, 2); - const sym = self.lit_h.find(try self.rdr.read(u15, F.peek | F.buffered | F.reverse)); - self.rdr.advance(sym.code_bits); + try self.bits.fill(15); + const sym = self.lit_h.find(try self.bits.read(u15, F.peek | F.buffered | F.reverse)); + self.bits.shift(sym.code_bits); if (sym.kind == .literal) { self.win.write(sym.symbol); @@ -173,11 +173,11 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { } // decode backward pointer - try self.rdr.ensureBits(33, 2); + try self.bits.fill(5 + 15 + 13); const length = try self.decodeLength(sym.symbol); - const dsm = self.dst_h.find(try self.rdr.read(u15, F.peek | F.buffered | F.reverse)); // distance symbol - self.rdr.advance(dsm.code_bits); + const dsm = self.dst_h.find(try self.bits.read(u15, F.peek | F.buffered | F.reverse)); // distance symbol + self.bits.shift(dsm.code_bits); const distance = try self.decodeDistance(dsm.symbol); self.win.writeCopy(length, distance); @@ -193,16 +193,16 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { 16 => { // Copy the previous code length 3 - 6 times. // The next 2 bits indicate repeat length - const n: u8 = @as(u8, try self.rdr.read(u2, 0)) + 3; + const n: u8 = @as(u8, try self.bits.read(u2, 0)) + 3; for (0..n) |i| { lens[pos + i] = lens[pos + i - 1]; } return n; }, // Repeat a code length of 0 for 3 - 10 times. (3 bits of length) - 17 => return @as(u8, try self.rdr.read(u3, 0)) + 3, + 17 => return @as(u8, try self.bits.read(u3, 0)) + 3, // Repeat a code length of 0 for 11 - 138 times (7 bits of length) - 18 => return @as(u8, try self.rdr.read(u7, 0)) + 11, + 18 => return @as(u8, try self.bits.read(u7, 0)) + 11, else => { // Represent code lengths of 0 - 15 lens[pos] = @intCast(code); @@ -234,7 +234,7 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { } }, .protocol_footer => { - self.rdr.alignToByte(); + self.bits.alignToByte(); try self.parseFooter(); self.state = .end; }, @@ -243,11 +243,11 @@ pub fn Inflate(comptime wrap: Wrapper, comptime ReaderType: type) type { } fn parseHeader(self: *Self) !void { - try wrap.parseHeader(&self.rdr); + try wrap.parseHeader(&self.bits); } fn parseFooter(self: *Self) !void { - try wrap.parseFooter(&self.hasher, &self.rdr); + try wrap.parseFooter(&self.hasher, &self.bits); } /// Returns decompressed data from internal sliding window buffer.