Skip to content

Commit

Permalink
Encapsulate code table details in a CodeTable class
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Dec 2, 2021
1 parent 4ac5da3 commit a4d705f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 36 deletions.
107 changes: 72 additions & 35 deletions dahuffman/huffmancodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pickle
from pathlib import Path
from typing import Union, Any, Callable, Iterator, Optional, Mapping, Iterable
from typing import Union, Any, Callable, Iterator, Optional, Mapping, Iterable, Tuple

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,58 +65,97 @@ def ensure_dir(path: Union[str, Path]) -> Path:
return path


class CodeTable:
    """
    Code table: mapping a symbol to codes (and vice versa).

    The symbols are the things you want to encode, usually characters in a string
    or byte sequence, but it can be anything hashable.
    The codes are the corresponding bit sequences, represented as a tuple (bits, value)
    where `bits` is the number of bits and `value` the integer interpretation of these bits.
    """

    def __init__(self, symbol_code_map: dict):
        """
        Build both lookup directions from a symbol-to-code mapping.

        :param symbol_code_map: mapping of symbol to code tuple (bits, value)
        """
        # Forward lookup (symbol -> code) and reverse lookup (code -> symbol).
        self._symbol_map = {}
        self._code_map = {}
        for symbol, (bits, value) in symbol_code_map.items():
            # NOTE: these are `assert` based checks, so they are skipped
            # when Python runs with the -O flag.
            assert isinstance(bits, int) and bits >= 1, f"Invalid bit count {bits}"
            assert isinstance(value, int) and value >= 0, f"Invalid code value {value}"
            self._symbol_map[symbol] = (bits, value)
            self._code_map[(bits, value)] = symbol
        # TODO check if code table is actually a prefix code

    def __len__(self) -> int:
        """Number of symbols in the code table."""
        return len(self._symbol_map)

    def get_code(self, symbol: Any) -> Tuple[int, int]:
        """
        Get code for given symbol (encode).

        :param symbol: symbol to look up
        :return: code tuple (bits, value)
        :raises KeyError: if the symbol is not in the code table
        """
        # TODO: raise custom EncodeException instead of KeyError?
        return self._symbol_map[symbol]

    def has_code(self, bits: int, value: int) -> bool:
        """Check if code is valid or defined in code table."""
        return (bits, value) in self._code_map

    def get_symbol(self, bits: int, value: int) -> Any:
        """
        Get symbol for given code (decode).

        :param bits: number of bits of the code
        :param value: integer interpretation of the code bits
        :return: symbol stored for that code
        :raises KeyError: if the code is not in the code table
        """
        # TODO: raise custom DecodeException instead of KeyError?
        return self._code_map[(bits, value)]

    def print(self, out: Optional[IOBase] = None) -> None:
        """
        Print code table overview

        :param out: stream to write the table to (anything with a `write` method
            accepting strings). When omitted, the current `sys.stdout` is used.
        """
        # Resolve the default stream at call time: a `sys.stdout` default in the
        # signature would be bound once at definition time and would ignore
        # later replacements of `sys.stdout` (e.g. `contextlib.redirect_stdout`).
        if out is None:
            out = sys.stdout
        # TODO: add sort options?
        # Render table cells as string
        columns = list(zip(*itertools.chain(
            [('Bits', 'Code', 'Value', 'Symbol')],
            (
                (str(bits), bin(val)[2:].rjust(bits, '0'), str(val), repr(symbol))
                for symbol, (bits, val) in self._symbol_map.items()
            )
        )))
        # Find column widths and build row template
        # (the last column is written as-is, so only three widths are needed)
        widths = tuple(max(len(s) for s in col) for col in columns)
        template = '{0:>%d} {1:%d} {2:>%d} {3}\n' % widths[:3]
        for row in zip(*columns):
            out.write(template.format(*row))


class PrefixCodec:
"""
Prefix code codec, using given code table.
"""

def __init__(
self, code_table: dict, concat: Callable = list, check: bool = True, eof=_EOF
self, code_table: Union[CodeTable, dict], concat: Callable = list, eof=_EOF
):
"""
Initialize codec with given code table.
:param code_table: mapping of symbol to code tuple (bitsize, value)
:param code_table: mapping between symbols and bit codes
:param concat: function to concatenate symbols
:param check: whether to check the code table
:param eof: "end of file" symbol (customizable for advanced usage)
"""
# Code table is dictionary mapping symbol to (bitsize, value)
self._table = code_table
self._table = (
code_table if isinstance(code_table, CodeTable) else CodeTable(code_table)
)
self._concat = concat
self._eof = eof
if check:
assert isinstance(self._table, dict) and all(
isinstance(b, int) and b >= 1 and isinstance(v, int) and v >= 0
for (b, v) in self._table.values()
)
# TODO check if code table is actually a prefix code

def get_code_table(self) -> dict:
def get_code_table(self) -> CodeTable:
"""
Get code table
:return: dictionary mapping symbol to code tuple (bitsize, value)
:return: `CodeTable` object
"""
return self._table

def print_code_table(self, out: IOBase = sys.stdout) -> None:
"""
Print code table overview
"""
# TODO: add sort options?
# Render table cells as string
columns = list(zip(*itertools.chain(
[('Bits', 'Code', 'Value', 'Symbol')],
(
(str(bits), bin(val)[2:].rjust(bits, '0'), str(val), repr(symbol))
for symbol, (bits, val) in self._table.items()
)
)))
# Find column widths and build row template
widths = tuple(max(len(s) for s in col) for col in columns)
template = "{0:>%d} {1:%d} {2:>%d} {3}\n" % widths[:3]
for row in zip(*columns):
out.write(template.format(*row))
return self._table.print(out=out)

def encode(self, data: Union[str, bytes, Iterable]) -> bytes:
"""
Expand All @@ -137,8 +177,7 @@ def encode_streaming(self, data: Union[str, bytes, Iterable]) -> Iterator[int]:
buffer = 0
size = 0
for s in data:
# TODO: raise custom EncodeException instead of KeyError?
b, v = self._table[s]
b, v = self._table.get_code(s)
# Shift new bits in the buffer
buffer = (buffer << b) + v
size += b
Expand All @@ -156,7 +195,7 @@ def encode_streaming(self, data: Union[str, bytes, Iterable]) -> Iterator[int]:
# the end of the current byte and cut off there.
# No new byte has to be started for the remainder, saving us one (or more) output bytes.
if size > 0:
b, v = self._table[self._eof]
b, v = self._table.get_code(self._eof)
buffer = (buffer << b) + v
size += b
if size >= 8:
Expand Down Expand Up @@ -184,17 +223,15 @@ def decode_streaming(self, data: Union[bytes, Iterable[int]]) -> Iterator:
:param data: sequence of bytes (string, list or generator of bytes)
:return: generator of symbols
"""
# Reverse lookup table: map (bitsize, value) to symbols
lookup = {(b, v): s for s, (b, v) in self._table.items()}

buffer = 0
size = 0
for byte in data:
for m in [128, 64, 32, 16, 8, 4, 2, 1]:
buffer = (buffer << 1) + bool(byte & m)
size += 1
if (size, buffer) in lookup:
symbol = lookup[size, buffer]
if self._table.has_code(bits=size, value=buffer):
symbol = self._table.get_symbol(size, buffer)
if symbol == self._eof:
return
yield symbol
Expand Down Expand Up @@ -286,9 +323,9 @@ def from_frequencies(
heappush(heap, merged)

# Code table is dictionary mapping symbol to (bitsize, value)
table = dict(heappop(heap)[1])
table = CodeTable(dict(heappop(heap)[1]))

return cls(table, concat=concat, check=False, eof=eof)
return cls(table, concat=concat, eof=eof)

@classmethod
def from_data(cls, data: Union[str, bytes, Iterable]) -> "HuffmanCodec":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dahuffman.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def test_prefix_codec():
code_table = {"A": (2, 0), "B": (2, 1), _EOF: (2, 3)}
codec = PrefixCodec(code_table, check=True)
codec = PrefixCodec(code_table)
encoded = codec.encode("ABBA")
assert encoded == b"\x14"

Expand Down

0 comments on commit a4d705f

Please sign in to comment.