diff --git a/python/src/ccf/ledger.py b/python/src/ccf/ledger.py index 2b35b754e4e..e243c8abef7 100644 --- a/python/src/ccf/ledger.py +++ b/python/src/ccf/ledger.py @@ -23,7 +23,6 @@ from ccf.cose import validate_cose_sign1 import ccf.receipt from hashlib import sha256 -import functools GCM_SIZE_TAG = 16 GCM_SIZE_IV = 12 @@ -44,6 +43,8 @@ # Key used by CCF to record single-key tables WELL_KNOWN_SINGLETON_TABLE_KEY = bytes(bytearray(8)) +SHA256_DIGEST_SIZE = sha256().digest_size + class NodeStatus(Enum): PENDING = "Pending" @@ -86,10 +87,8 @@ def is_ledger_chunk_committed(file_name): return file_name.endswith(COMMITTED_FILE_SUFFIX) -def digest(algo, data): - h = hashes.Hash(algo) - h.update(data) - return h.finalize() +def digest(data): + return sha256(data).digest() def unpack(stream, fmt): @@ -100,10 +99,7 @@ def unpack(stream, fmt): return struct.unpack(fmt, buf)[0] -def unpack_array(stream, fmt, length): - buf = stream.read(length) - if not buf: - raise EOFError # Reached end of stream +def unpack_array(buf, fmt): unpack_iter = struct.iter_unpack(fmt, buf) ret = [] while True: @@ -131,20 +127,18 @@ def range_from_filename(filename: str) -> Tuple[int, Optional[int]]: class GcmHeader: - _gcm_tag = ["\0"] * GCM_SIZE_TAG - _gcm_iv = ["\0"] * GCM_SIZE_IV - view: int seqno: int def __init__(self, buffer): if len(buffer) < GcmHeader.size(): raise ValueError("Corrupt GCM header") - self._gcm_tag = struct.unpack(f"@{GCM_SIZE_TAG}B", buffer[:GCM_SIZE_TAG]) - self._gcm_iv = struct.unpack(f"@{GCM_SIZE_IV}B", buffer[GCM_SIZE_TAG:]) - self.seqno = struct.unpack("@Q", bytes(self._gcm_iv[:8]))[0] - self.view = struct.unpack("@I", bytes(self._gcm_iv[8:]))[0] & 0x7FFFFFFF + # _gcm_tag = buffer[:GCM_SIZE_TAG] # Unused + _gcm_iv = buffer[GCM_SIZE_TAG : GCM_SIZE_TAG + GCM_SIZE_IV] + + self.seqno = struct.unpack("@Q", _gcm_iv[:8])[0] + self.view = struct.unpack("@I", _gcm_iv[8:])[0] & 0x7FFFFFFF @staticmethod def size(): @@ -156,24 +150,30 @@ class PublicDomain: All public tables within a :py:class:`ccf.ledger.Transaction`. """ - _buffer: io.BytesIO - _buffer_size: int + _buffer: bytes + _cursor: int _entry_type: EntryType _claims_digest: bytes _version: int _max_conflict_version: int _tables: dict - def __init__(self, buffer: io.BytesIO): + def __init__(self, buffer: bytes): + self._entry_type = EntryType(buffer[0]) + + # Already read a 1-byte entry-type, so start from 1 not 0 + self._cursor = 1 self._buffer = buffer - self._buffer_size = self._buffer.getbuffer().nbytes - self._entry_type = self._read_entry_type() - self._version = self._read_version() + + self._version = self._read_int64() + if self._entry_type.has_claims(): - self._claims_digest = self._read_claims_digest() + self._claims_digest = self._read_buffer(SHA256_DIGEST_SIZE) + if self._entry_type.has_commit_evidence(): - self._commit_evidence_digest = self._read_commit_evidence_digest() - self._max_conflict_version = self._read_version() + self._commit_evidence_digest = self._read_buffer(SHA256_DIGEST_SIZE) + + self._max_conflict_version = self._read_int64() if self._entry_type == EntryType.SNAPSHOT: self._read_snapshot_header() @@ -181,63 +181,60 @@ def __init__(self, buffer: io.BytesIO): self._tables = {} self._read() - def is_deprecated(self): - return self._entry_type.is_deprecated() + def _read_buffer(self, size): + prev_cursor = self._cursor + self._cursor += size + return self._buffer[prev_cursor : self._cursor] - def _read_entry_type(self): - val = unpack(self._buffer, " PublicDomain: :return: :py:class:`ccf.ledger.PublicDomain` """ if self._public_domain is None: - buffer = io.BytesIO(_byte_read_safe(self._file, self._public_domain_size)) + current_pos = self._file.tell() + buffer = self._buffer[current_pos : current_pos + self._public_domain_size] + self._file.seek(self._public_domain_size, 1) self._public_domain = PublicDomain(buffer) return self._public_domain @@ -714,7 +711,6 @@ class Transaction(Entry): _next_offset: int = LEDGER_HEADER_SIZE _tx_offset: int = 0 _ledger_validator: Optional[LedgerValidator] = None - _dgst = functools.partial(digest, hashes.SHA256()) def __init__( self, filename: str, ledger_validator: Optional[LedgerValidator] = None @@ -758,8 +754,7 @@ def get_offsets(self) -> Tuple[int, int]: return (self._tx_offset, self._next_offset) def get_write_set_digest(self) -> bytes: - self._dgst = functools.partial(digest, hashes.SHA256()) - return self._dgst(self.get_raw_tx()) + return digest(self.get_raw_tx()) def get_tx_digest(self) -> bytes: claims_digest = self.get_public_domain().get_claims_digest() @@ -769,12 +764,12 @@ def get_tx_digest(self) -> bytes: if commit_evidence_digest is None: return write_set_digest else: - return self._dgst(write_set_digest + commit_evidence_digest) + return digest(write_set_digest + commit_evidence_digest) else: assert ( commit_evidence_digest ), "Invalid transaction: commit_evidence_digest not set" - return self._dgst(write_set_digest + commit_evidence_digest + claims_digest) + return digest(write_set_digest + commit_evidence_digest + claims_digest) def _complete_read(self): self._file.seek(self._next_offset, 0)