More python ledger optimisations: Use hashlib rather than hazmat (#6708)
eddyashton authored Dec 17, 2024
1 parent c85da08 commit 81abbc4
Showing 1 changed file with 74 additions and 79 deletions.
python/src/ccf/ledger.py: 153 changes (74 additions, 79 deletions)
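The commit replaces the `cryptography.hazmat` hashing primitives with the standard library's `hashlib` throughout the ledger parser, and swaps `io.BytesIO` stream reads for direct slicing over a `bytes` buffer tracked by a cursor. As a minimal sketch of why the hashlib swap helps (not part of the commit; the function names and iteration count below are illustrative assumptions), both paths produce identical SHA-256 digests, but `hashlib.sha256` is a thin wrapper over OpenSSL and avoids constructing a fresh hazmat `Hash` object per call:

    # Sketch (not from this commit): comparing the two hashing approaches.
    # The hazmat variant requires the "cryptography" package.
    import timeit
    from hashlib import sha256
    from cryptography.hazmat.primitives import hashes

    def digest_hazmat(data: bytes) -> bytes:
        h = hashes.Hash(hashes.SHA256())  # fresh hash object per call
        h.update(data)
        return h.finalize()

    def digest_hashlib(data: bytes) -> bytes:
        return sha256(data).digest()

    data = b"x" * 1024
    assert digest_hazmat(data) == digest_hashlib(data)  # identical digests
    print("hazmat: ", timeit.timeit(lambda: digest_hazmat(data), number=100_000))
    print("hashlib:", timeit.timeit(lambda: digest_hashlib(data), number=100_000))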
@@ -23,7 +23,6 @@
 from ccf.cose import validate_cose_sign1
 import ccf.receipt
 from hashlib import sha256
-import functools
 
 GCM_SIZE_TAG = 16
 GCM_SIZE_IV = 12
@@ -44,6 +43,8 @@
 # Key used by CCF to record single-key tables
 WELL_KNOWN_SINGLETON_TABLE_KEY = bytes(bytearray(8))
 
+SHA256_DIGEST_SIZE = sha256().digest_size
+
 
 class NodeStatus(Enum):
     PENDING = "Pending"
@@ -86,10 +87,8 @@ def is_ledger_chunk_committed(file_name):
     return file_name.endswith(COMMITTED_FILE_SUFFIX)
 
 
-def digest(algo, data):
-    h = hashes.Hash(algo)
-    h.update(data)
-    return h.finalize()
+def digest(data):
+    return sha256(data).digest()
 
 
 def unpack(stream, fmt):
@@ -100,10 +99,7 @@ def unpack(stream, fmt):
     return struct.unpack(fmt, buf)[0]
 
 
-def unpack_array(stream, fmt, length):
-    buf = stream.read(length)
-    if not buf:
-        raise EOFError  # Reached end of stream
+def unpack_array(buf, fmt):
     unpack_iter = struct.iter_unpack(fmt, buf)
     ret = []
     while True:
@@ -131,20 +127,18 @@ def range_from_filename(filename: str) -> Tuple[int, Optional[int]]:
 
 
 class GcmHeader:
-    _gcm_tag = ["\0"] * GCM_SIZE_TAG
-    _gcm_iv = ["\0"] * GCM_SIZE_IV
 
     view: int
     seqno: int
 
     def __init__(self, buffer):
         if len(buffer) < GcmHeader.size():
             raise ValueError("Corrupt GCM header")
-        self._gcm_tag = struct.unpack(f"@{GCM_SIZE_TAG}B", buffer[:GCM_SIZE_TAG])
-        self._gcm_iv = struct.unpack(f"@{GCM_SIZE_IV}B", buffer[GCM_SIZE_TAG:])
-
-        self.seqno = struct.unpack("@Q", bytes(self._gcm_iv[:8]))[0]
-        self.view = struct.unpack("@I", bytes(self._gcm_iv[8:]))[0] & 0x7FFFFFFF
+        # _gcm_tag = buffer[:GCM_SIZE_TAG]  # Unused
+        _gcm_iv = buffer[GCM_SIZE_TAG : GCM_SIZE_TAG + GCM_SIZE_IV]
+
+        self.seqno = struct.unpack("@Q", _gcm_iv[:8])[0]
+        self.view = struct.unpack("@I", _gcm_iv[8:])[0] & 0x7FFFFFFF
 
     @staticmethod
     def size():
@@ -156,88 +150,91 @@ class PublicDomain:
     All public tables within a :py:class:`ccf.ledger.Transaction`.
     """
 
-    _buffer: io.BytesIO
-    _buffer_size: int
+    _buffer: bytes
+    _cursor: int
     _entry_type: EntryType
     _claims_digest: bytes
    _version: int
     _max_conflict_version: int
     _tables: dict
 
-    def __init__(self, buffer: io.BytesIO):
+    def __init__(self, buffer: bytes):
+        self._entry_type = EntryType(buffer[0])
+
+        # Already read a 1-byte entry-type, so start from 1 not 0
+        self._cursor = 1
         self._buffer = buffer
-        self._buffer_size = self._buffer.getbuffer().nbytes
-        self._entry_type = self._read_entry_type()
-        self._version = self._read_version()
+
+        self._version = self._read_int64()
 
         if self._entry_type.has_claims():
-            self._claims_digest = self._read_claims_digest()
+            self._claims_digest = self._read_buffer(SHA256_DIGEST_SIZE)
 
         if self._entry_type.has_commit_evidence():
-            self._commit_evidence_digest = self._read_commit_evidence_digest()
-        self._max_conflict_version = self._read_version()
+            self._commit_evidence_digest = self._read_buffer(SHA256_DIGEST_SIZE)
+
+        self._max_conflict_version = self._read_int64()
 
         if self._entry_type == EntryType.SNAPSHOT:
             self._read_snapshot_header()
 
         self._tables = {}
         self._read()
 
-    def is_deprecated(self):
-        return self._entry_type.is_deprecated()
+    def _read_buffer(self, size):
+        prev_cursor = self._cursor
+        self._cursor += size
+        return self._buffer[prev_cursor : self._cursor]
 
-    def _read_entry_type(self):
-        val = unpack(self._buffer, "<B")
-        return EntryType(val)
+    def _read8(self):
+        return self._read_buffer(8)
 
-    def _read_claims_digest(self):
-        return self._buffer.read(hashes.SHA256.digest_size)
+    def _read_int64(self):
+        return struct.unpack("<q", self._read8())[0]
 
-    def _read_commit_evidence_digest(self):
-        return self._buffer.read(hashes.SHA256.digest_size)
+    def _read_uint64(self):
+        return struct.unpack("<Q", self._read8())[0]
 
-    def _read_version(self):
-        return unpack(self._buffer, "<q")
+    def is_deprecated(self):
+        return self._entry_type.is_deprecated()
 
     def get_version_size(self):
-        return struct.calcsize("<q")
+        return 8
 
     def _read_versioned_value(self, size):
         if size < self.get_version_size():
             raise ValueError(f"Invalid versioned value of size {size}")
-        return (self._read_version(), self._buffer.read(size - self.get_version_size()))
+        return (self._read_uint64(), self._read_buffer(size - self.get_version_size()))
 
-    def _read_size(self):
-        return unpack(self._buffer, "<Q")
+    def _read_next_entry(self):
+        size = self._read_uint64()
+        return self._read_buffer(size)
 
     def _read_string(self):
-        size = self._read_size()
-        return self._buffer.read(size).decode()
-
-    def _read_next_entry(self):
-        size = self._read_size()
-        return self._buffer.read(size)
+        return self._read_next_entry().decode()
 
     def _read_snapshot_header(self):
         # read hash of entry at snapshot
-        hash_size = self._read_size()
-        buffer = unpack(self._buffer, f"<{hash_size}s")
+        hash_size = self._read_uint64()
+        buffer = self._read_buffer(hash_size)
         self._hash_at_snapshot = buffer.hex()
 
         # read view history
-        view_history_size = self._read_size()
-        self._view_history = unpack_array(self._buffer, "<Q", view_history_size)
+        view_history_size = self._read_uint64()
+        self._view_history = unpack_array(self._read_buffer(view_history_size), "<Q")
 
     def _read_snapshot_entry_padding(self, size):
         padding = -size % 8  # Padded to 8 bytes
-        self._buffer.read(padding)
+        self._cursor += padding
 
     def _read_snapshot_key(self):
-        size = self._read_size()
-        key = self._buffer.read(size)
+        size = self._read_uint64()
+        key = self._read_buffer(size)
         self._read_snapshot_entry_padding(size)
         return key
 
     def _read_snapshot_versioned_value(self):
-        size = self._read_size()
+        size = self._read_uint64()
         ver, value = self._read_versioned_value(size)
         if ver < 0:
             assert (
@@ -248,44 +245,42 @@ def _read_snapshot_versioned_value(self):
         return value
 
     def _read(self):
-        while True:
-            try:
-                map_name = self._read_string()
-            except EOFError:
-                break
+        buffer_size = len(self._buffer)
+        while self._cursor < buffer_size:
+            map_name = self._read_string()
 
             records = {}
             self._tables[map_name] = records
 
             if self._entry_type == EntryType.SNAPSHOT:
                 # map snapshot version
-                self._read_version()
+                self._read8()
 
                 # size of map entry
-                map_size = self._read_size()
-                start_map_pos = self._buffer.tell()
+                map_size = self._read_uint64()
+                start_map_pos = self._cursor
 
-                while self._buffer.tell() - start_map_pos < map_size:
+                while self._cursor - start_map_pos < map_size:
                     k = self._read_snapshot_key()
                     val = self._read_snapshot_versioned_value()
                     records[k] = val
             else:
                 # read_version
-                self._read_version()
+                self._read8()
 
                 # read_count
                 # Note: Read keys are not currently included in ledger transactions
-                read_count = self._read_size()
+                read_count = self._read_uint64()
                 assert read_count == 0, f"Unexpected read count: {read_count}"
 
-                write_count = self._read_size()
+                write_count = self._read_uint64()
                 if write_count:
                     for _ in range(write_count):
                         k = self._read_next_entry()
                         val = self._read_next_entry()
                         records[k] = val
 
-                remove_count = self._read_size()
+                remove_count = self._read_uint64()
                 if remove_count:
                     for _ in range(remove_count):
                         k = self._read_next_entry()
@@ -384,7 +379,7 @@ def __init__(self, accept_deprecated_entry_types: bool = True):
         # Start with empty bytes array. CCF MerkleTree uses an empty array as the first leaf of its merkle tree.
         # Don't hash empty bytes array.
         self.merkle = MerkleTree()
-        empty_bytes_array = bytearray(hashes.SHA256.digest_size)
+        empty_bytes_array = bytearray(SHA256_DIGEST_SIZE)
         self.merkle.add_leaf(empty_bytes_array, do_hash=False)
 
         self.last_verified_seqno = 0
@@ -609,7 +604,7 @@ class TransactionHeader:
     size: int
 
     def __init__(self, buffer):
-        if len(buffer) < TransactionHeader.get_size():
+        if len(buffer) != TransactionHeader.get_size():
            raise ValueError("Incomplete transaction header")
 
         self.version = int.from_bytes(
@@ -637,7 +632,7 @@ def get_size():
 
 
 class Entry:
-    _file: Optional[BinaryIO] = None
+    _file: BinaryIO
     _header: TransactionHeader
     _public_domain_size: int = 0
     _public_domain: Optional[PublicDomain] = None
@@ -648,9 +643,9 @@ def __init__(self, filename: str):
         if type(self) is Entry:
             raise TypeError("Entry is not instantiable")
 
-        self._file = open(filename, mode="rb")
-        if self._file is None:
-            raise RuntimeError(f"File {filename} could not be opened")
+        with open(filename, mode="rb") as f:
+            self._buffer = f.read()
+        self._file = io.BytesIO(self._buffer)
 
     def __enter__(self):
         return self
@@ -690,7 +685,9 @@ def get_public_domain(self) -> PublicDomain:
         :return: :py:class:`ccf.ledger.PublicDomain`
         """
         if self._public_domain is None:
-            buffer = io.BytesIO(_byte_read_safe(self._file, self._public_domain_size))
+            current_pos = self._file.tell()
+            buffer = self._buffer[current_pos : current_pos + self._public_domain_size]
+            self._file.seek(self._public_domain_size, 1)
             self._public_domain = PublicDomain(buffer)
         return self._public_domain
 
@@ -714,7 +711,6 @@ class Transaction(Entry):
     _next_offset: int = LEDGER_HEADER_SIZE
     _tx_offset: int = 0
     _ledger_validator: Optional[LedgerValidator] = None
-    _dgst = functools.partial(digest, hashes.SHA256())
 
     def __init__(
         self, filename: str, ledger_validator: Optional[LedgerValidator] = None
@@ -758,8 +754,7 @@ def get_offsets(self) -> Tuple[int, int]:
         return (self._tx_offset, self._next_offset)
 
     def get_write_set_digest(self) -> bytes:
-        self._dgst = functools.partial(digest, hashes.SHA256())
-        return self._dgst(self.get_raw_tx())
+        return digest(self.get_raw_tx())
 
     def get_tx_digest(self) -> bytes:
         claims_digest = self.get_public_domain().get_claims_digest()
@@ -769,12 +764,12 @@ def get_tx_digest(self) -> bytes:
            if commit_evidence_digest is None:
                return write_set_digest
            else:
-                return self._dgst(write_set_digest + commit_evidence_digest)
+                return digest(write_set_digest + commit_evidence_digest)
        else:
            assert (
                commit_evidence_digest
            ), "Invalid transaction: commit_evidence_digest not set"
-            return self._dgst(write_set_digest + commit_evidence_digest + claims_digest)
+            return digest(write_set_digest + commit_evidence_digest + claims_digest)
 
     def _complete_read(self):
         self._file.seek(self._next_offset, 0)
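For reference, here is a standalone sketch of the cursor-over-bytes pattern that the new PublicDomain code uses in place of io.BytesIO (simplified, hypothetical names; not the commit's code). The whole entry is read into memory once, and every subsequent "read" is a slice plus a cursor bump, with no per-read stream bookkeeping:

    # Sketch (not from this commit): cursor-based reading over a bytes buffer,
    # mirroring _read_buffer / _read_uint64 / _read_next_entry above.
    import struct

    class BufferReader:
        def __init__(self, buffer: bytes):
            self._buffer = buffer
            self._cursor = 0

        def read(self, size: int) -> bytes:
            prev = self._cursor
            self._cursor += size
            return self._buffer[prev : self._cursor]  # pure slicing, no stream state

        def read_uint64(self) -> int:
            return struct.unpack("<Q", self.read(8))[0]

    # Length-prefixed entry, as parsed by _read_next_entry:
    reader = BufferReader(struct.pack("<Q", 5) + b"hello")
    size = reader.read_uint64()
    assert reader.read(size) == b"hello"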