diff --git a/go/border/rpkt/path.go b/go/border/rpkt/path.go index 9a2adf86ea..d27a99e023 100644 --- a/go/border/rpkt/path.go +++ b/go/border/rpkt/path.go @@ -104,7 +104,7 @@ func (rp *RtrPkt) validateLocalIF(ifid *spath.IntfID) *common.Error { return nil } // Check that we have a revocation for the current epoch. - if revInfo.Epoch() < crypto.GetCurrentHashTreeEpoch() { + if !crypto.VerifyHashTreeEpoch(revInfo.Epoch()) { // If the BR does not have a revocation for the current epoch, it considers // the interface as active until it receives a new revocation. ifstate.Activate(*ifid) diff --git a/go/lib/crypto/htree.go b/go/lib/crypto/htree.go index cfd3e2b7fe..cd5d82ebc6 100644 --- a/go/lib/crypto/htree.go +++ b/go/lib/crypto/htree.go @@ -18,15 +18,33 @@ import ( "time" ) -// HashTreeTTL is the TTL of one hash tree (in seconds). -// FIXME(shitz): This should really be matching spath.MaxTTL, but more importantly, -// it needs to match the hash tree ttl used by the BS, which is currently set to 30 mins. -const HashTreeTTL = 30 * 60 +const ( + // HashTreeTTL is the TTL of one hash tree (in seconds). + // FIXME(shitz): This should really be matching spath.MaxTTL, but more importantly, + // it needs to match the hash tree ttl used by the BS, which is currently set to 30 mins. + HashTreeTTL = 30 * 60 * time.Second -// HashTreeEpochTime is the duration of one epoch (in seconds). -const HashTreeEpochTime = 10 + // HashTreeEpochTime is the duration of one epoch (in seconds). + HashTreeEpochTime = 10 * time.Second -func GetCurrentHashTreeEpoch() uint16 { - window := time.Now().Unix() % HashTreeTTL - return uint16(window / HashTreeEpochTime) + // HashTreeEpochTolerance is the duration after a revocation expired within which a + // revocation is still accepted by a verifier. + HashTreeEpochTolerance = 2 * time.Second +) + +// GetCurrentHashTreeEpoch returns the current epoch ID. +func GetCurrentHashTreeEpoch() uint64 { + return uint64(time.Now().Unix() / int64(HashTreeEpochTime.Seconds())) +} + +// GetTimeSinceHashTreeEpoch returns the time since the start of epoch. +func GetTimeSinceHashTreeEpoch(epoch uint64) time.Duration { + epochStart := time.Unix(0, int64(epoch)*HashTreeEpochTime.Nanoseconds()) + return time.Since(epochStart) +} + +// VerifyHashTreeEpoch verifies a given hash tree epoch. An epoch is valid if it is +// equal to the current epoch or within the tolerance limit of the next epoch. +func VerifyHashTreeEpoch(epoch uint64) bool { + return GetTimeSinceHashTreeEpoch(epoch) < (HashTreeEpochTime + HashTreeEpochTolerance) } diff --git a/infrastructure/beacon_server/base.py b/infrastructure/beacon_server/base.py index f3f2a1fcdf..c99e73954e 100644 --- a/infrastructure/beacon_server/base.py +++ b/infrastructure/beacon_server/base.py @@ -70,6 +70,7 @@ from lib.thread import thread_safety_net, kill_self from lib.types import ( CertMgmtType, + HashType, PathMgmtType as PMT, PayloadClass, ) @@ -165,8 +166,8 @@ def __init__(self, server_id, conf_dir): def _init_hash_tree(self): ifs = list(self.ifid2br.keys()) - self._hash_tree = ConnectedHashTree(self.addr.isd_as, - ifs, self.hashtree_gen_key) + self._hash_tree = ConnectedHashTree( + self.addr.isd_as, ifs, self.hashtree_gen_key, HashType.SHA256) def _get_ht_proof(self, if_id): with self._hash_tree_lock: @@ -439,8 +440,8 @@ def _create_next_tree(self): ht_start = time.time() ifs = list(self.ifid2br.keys()) - tree = ConnectedHashTree.get_next_tree(self.addr.isd_as, ifs, - self.hashtree_gen_key) + tree = ConnectedHashTree.get_next_tree( + self.addr.isd_as, ifs, self.hashtree_gen_key, HashType.SHA256) ht_end = time.time() with self._hash_tree_lock: self._next_tree = tree diff --git a/infrastructure/cert_server/main.py b/infrastructure/cert_server/main.py index c056b09ea8..ea27047e95 100644 --- a/infrastructure/cert_server/main.py +++ b/infrastructure/cert_server/main.py @@ -174,18 +174,18 @@ def process_cert_chain_request(self, req, meta): """Process a certificate chain request.""" assert isinstance(req, CertChainRequest) key = req.isd_as(), req.p.version - logging.info("Cert chain request received for %sv%s", *key) + logging.info("Cert chain request received for %sv%s from %s", *key, meta) local = meta.ia == self.addr.isd_as - if not self._check_cc(key) and not local: - logging.warning( - "Dropping CC request from %s for %sv%s: " - "CC not found && requester is not local)", - meta.get_addr(), *key) - return - if req.p.cacheOnly: - self._reply_cc(key, meta) + if not self._check_cc(key): + if not local: + logging.warning( + "Dropping CC request from %s for %sv%s: " + "CC not found && requester is not local)", + meta.get_addr(), *key) + else: + self.cc_requests.put((key, (meta, req))) return - self.cc_requests.put((key, meta)) + self._reply_cc(key, (meta, req)) def process_cert_chain_reply(self, rep, meta, from_zk=False): """Process a certificate chain reply.""" @@ -206,19 +206,26 @@ def _check_cc(self, key): logging.debug('Cert chain not found for %sv%s', *key) return False - def _fetch_cc(self, key, _): + def _fetch_cc(self, key, req_info): + # Do not attempt to fetch the CertChain from a remote AS if the cacheOnly flag is set. + _, orig_req = req_info + if orig_req.p.cacheOnly: + return isd_as, ver = key req = CertChainRequest.from_values(isd_as, ver, cache_only=True) - path = self._get_path_via_api(isd_as) - meta = self._build_meta(isd_as, host=SVCType.CS_A, path=path) - if path and self.send_meta(req, meta): - logging.info("Cert chain request sent: %s", req.short_desc()) + path_meta = self._get_path_via_api(isd_as) + if path_meta: + meta = self._build_meta(isd_as, host=SVCType.CS_A, path=path_meta.fwd_path()) + self.send_meta(req, meta) + logging.info("Cert chain request sent to %s via [%s]: %s", + meta, path_meta.short_desc(), req.short_desc()) else: logging.warning("Cert chain request (for %s) not sent: " - "no destination found", req.short_desc()) + "no path found", req.short_desc()) - def _reply_cc(self, key, meta): + def _reply_cc(self, key, req_info): isd_as, ver = key + meta = req_info[0] cert_chain = self.trust_store.get_cert(isd_as, ver) self.send_meta(CertChainReply.from_values(cert_chain), meta) logging.info("Cert chain for %sv%s sent to %s:%s", isd_as, ver, meta.get_addr(), meta.port) @@ -227,18 +234,18 @@ def process_trc_request(self, req, meta): """Process a TRC request.""" assert isinstance(req, TRCRequest) key = req.isd_as()[0], req.p.version - logging.info("TRC request received for %sv%s", *key) + logging.info("TRC request received for %sv%s from %s", *key, meta) local = meta.ia == self.addr.isd_as - if not self._check_trc(key) and not local: - logging.warning( - "Dropping TRC request from %s for %sv%s: " - "TRC not found && requester is not local)", - meta.get_addr(), *key) - return - if req.p.cacheOnly: - self._reply_trc(key, meta) + if not self._check_trc(key): + if not local: + logging.warning( + "Dropping TRC request from %s for %sv%s: " + "TRC not found && requester is not local)", + meta.get_addr(), *key) + else: + self.trc_requests.put((key, (meta, req))) return - self.trc_requests.put((key, (meta, req.isd_as()[1]),)) + self._reply_trc(key, (meta, req)) def process_trc_reply(self, trc_rep, meta, from_zk=False): """ @@ -264,21 +271,26 @@ def _check_trc(self, key): logging.debug('TRC not found for %sv%s', *key) return False - def _fetch_trc(self, key, info): + def _fetch_trc(self, key, req_info): + # Do not attempt to fetch the TRC from a remote AS if the cacheOnly flag is set. + _, orig_req = req_info + if orig_req.p.cacheOnly: + return isd, ver = key - isd_as = ISD_AS.from_values(isd, info[1]) + isd_as = ISD_AS.from_values(isd, orig_req.isd_as()[1]) trc_req = TRCRequest.from_values(isd_as, ver, cache_only=True) - path = self._get_path_via_api(isd_as) - meta = self._build_meta(isd_as, host=SVCType.CS_A, path=path) - if path and self.send_meta(trc_req, meta): - logging.info("TRC request sent for %sv%s.", *key) + path_meta = self._get_path_via_api(isd_as) + if path_meta: + meta = self._build_meta(isd_as, host=SVCType.CS_A, path=path_meta.fwd_path()) + self.send_meta(trc_req, meta) + logging.info("TRC request sent to %s via [%s]: %s", + meta, path_meta.short_desc(), trc_req.short_desc()) else: - logging.warning("TRC request not sent for %sv%s: " - "no destination found.", *key) + logging.warning("TRC request not sent for %s: no path found.", trc_req.short_desc()) - def _reply_trc(self, key, info): + def _reply_trc(self, key, req_info): isd, ver = key - meta = info[0] + meta = req_info[0] trc = self.trust_store.get_trc(isd, ver) self.send_meta(TRCReply.from_values(trc), meta) logging.info("TRC for %sv%s sent to %s:%s", isd, ver, meta.get_addr(), meta.port) @@ -294,7 +306,7 @@ def _get_path_via_api(self, isd_as, flush=False): logging.error("Error during path lookup: %s" % e) continue if path_entries: - return path_entries[0].path().fwd_path() + return path_entries[0].path() logging.warning("Unable to get path to %s from local api.", isd_as) return None diff --git a/infrastructure/path_server/base.py b/infrastructure/path_server/base.py index 499c82128a..3fce51fef4 100644 --- a/infrastructure/path_server/base.py +++ b/infrastructure/path_server/base.py @@ -439,8 +439,8 @@ def _send_waiting_queries(self, dst_isd, pcb): (seg_req, logger) = targets.pop(0) meta = self._build_meta(ia=src_ia, path=path, host=SVCType.PS_A, reuse=True) self.send_meta(seg_req, meta) - logger.info("Waiting request (%s) sent via %s", - seg_req.short_desc(), pcb.short_desc()) + logger.info("Waiting request (%s) sent to %s via %s", + seg_req.short_desc(), meta, pcb.short_desc()) def _share_via_zk(self): if not self._segs_to_zk: diff --git a/infrastructure/scion_elem.py b/infrastructure/scion_elem.py index 607da54865..4dd3c3d21d 100644 --- a/infrastructure/scion_elem.py +++ b/infrastructure/scion_elem.py @@ -341,7 +341,7 @@ def request_missing_trcs(self, seg_meta): continue self.requested_trcs.add((isd, ver)) isd_as = ISD_AS.from_values(isd, 0) - trc_req = TRCRequest.from_values(isd_as, ver) + trc_req = TRCRequest.from_values(isd_as, ver, cache_only=True) logging.info("Requesting %sv%s TRC for PCB %s", isd, ver, seg_meta.seg.short_id()) if not seg_meta.meta: meta = self.get_cs() @@ -367,7 +367,7 @@ def request_missing_certs(self, seg_meta): if (isd_as, ver) in self.requested_certs: continue self.requested_certs.add((isd_as, ver)) - cert_req = CertChainRequest.from_values(isd_as, ver) + cert_req = CertChainRequest.from_values(isd_as, ver, cache_only=True) meta = seg_meta.meta if not meta: meta = self.get_cs() diff --git a/lib/crypto/hash_tree.py b/lib/crypto/hash_tree.py index 5276a955c4..a5a140b51d 100644 --- a/lib/crypto/hash_tree.py +++ b/lib/crypto/hash_tree.py @@ -20,7 +20,7 @@ import time # SCION -from lib.crypto.symcrypto import crypto_hash +from lib.crypto.symcrypto import hash_func_for_type from lib.defines import ( HASHTREE_EPOCH_TIME, HASHTREE_EPOCH_TOLERANCE, @@ -37,21 +37,23 @@ class HashTree(object): The used hash function needs to implement the hashlib interface. """ - def __init__(self, isd_as, if_ids, seed, hash_func=crypto_hash): + def __init__(self, isd_as, if_ids, seed, ttl_window, hash_type): """ :param ISD_AS isd_as: The ISD_AS of the AS. :param List[int] if_ids: List of interface IDs of the AS. :param str seed: Seed for creating hash-tree nonces. - :param hash_func: Hash function. hash_func(msg) outputs hash of the msg. + :param int ttl_window: The TTL window for which this hash tree is valid. + :param hash_type: Hash function type. """ self._isd_as = isd_as self._seed = seed - self._n_epochs = HASHTREE_N_EPOCHS - self._hash_func = hash_func + self._ttl_window = ttl_window + self._hash_type = hash_type + self._hash_func = hash_func_for_type(hash_type) self._setup(if_ids) def _setup(self, if_ids): - self.calc_tree_depth(len(if_ids) * self._n_epochs) + self.calc_tree_depth(len(if_ids) * HASHTREE_N_EPOCHS) self.create_tree(if_ids) def calc_tree_depth(self, leaf_count): @@ -90,14 +92,15 @@ def create_tree(self, if_ids): self._if2idx = {} for if_id in if_ids: # For given (if_id, epoch) leaves self._if2idx[if_id] = idx - for i in range(self._n_epochs): + for i in range(HASHTREE_N_EPOCHS): raw_nonce = (self._seed + struct.pack("!qq", if_id, i)) nonce = self._hash_func(raw_nonce) if_tuple = struct.pack("!qq", if_id, i) + nonce self._nodes[idx] = self._hash_func(if_tuple) idx = idx + 1 + null_hash = self._hash_func(b"0") while idx < node_count: # For extra leaves added to complete tree - self._nodes[idx] = self._hash_func(b"0") + self._nodes[idx] = null_hash idx = idx + 1 # Compute and fill in the hash values for internal nodes (bottom up). @@ -115,13 +118,14 @@ def get_proof(self, if_id, epoch, prev_root, next_root): :param bytes next_root: hash of the next root. """ assert if_id in self._if2idx.keys(), "if_id not found in AS" + relative_epoch = epoch % HASHTREE_N_EPOCHS # Obtain the nonce for the (if_id, epoch) pair using the seed. - raw_nonce = self._seed + struct.pack("!qq", if_id, epoch) + raw_nonce = self._seed + struct.pack("!qq", if_id, relative_epoch) nonce = self._hash_func(raw_nonce) # Obtain the sibling hashes along with their left/right position info. siblings = [] - idx = self._if2idx[if_id] + epoch + idx = self._if2idx[if_id] + relative_epoch while idx > 0: if idx % 2 == 0: siblings.append((True, self._nodes[idx - 1])) @@ -131,7 +135,8 @@ def get_proof(self, if_id, epoch, prev_root, next_root): # Using the above fields, construct a RevInfo capnp as the proof. return RevocationInfo.from_values( - self._isd_as, if_id, epoch, nonce, siblings, prev_root, next_root) + self._isd_as, if_id, epoch, nonce, siblings, prev_root, next_root, + self._hash_type) class ConnectedHashTree(object): @@ -143,13 +148,12 @@ class ConnectedHashTree(object): """ - def __init__(self, isd_as, if_ids, seed, - hash_func=crypto_hash): # pragma: no cover + def __init__(self, isd_as, if_ids, seed, hash_type): # pragma: no cover """ :param ISD_AS isd_as: The ISD_AS of the AS. :param List[int] if_ids: list of interface IDs of the AS. :param List[str] seeds: list of 3 seeds for creating hash-tree nonces. - :param hash_func: Hash function. hash_func(msg) outputs hash of the msg. + :param hash_type: Hash function type. """ assert len(if_ids)*HASHTREE_N_EPOCHS >= 1, "Must have at least 1 leaf" ttl_window = self.get_ttl_window() @@ -157,20 +161,18 @@ def __init__(self, isd_as, if_ids, seed, seed2 = seed + (ttl_window + 0).to_bytes(8, 'big') seed3 = seed + (ttl_window + 1).to_bytes(8, 'big') - self._hash_func = hash_func - self._ht0_root = hash_func(str(seed1).encode('utf-8')) - self._ht1 = HashTree(isd_as, if_ids, seed2, hash_func) - self._ht2 = HashTree(isd_as, if_ids, seed3, hash_func) + self._hash_func = hash_func_for_type(hash_type) + self._ht0_root = self._hash_func(str(seed1).encode('utf-8')) + self._ht1 = HashTree(isd_as, if_ids, seed2, ttl_window, hash_type) + self._ht2 = HashTree(isd_as, if_ids, seed3, ttl_window + 1, hash_type) @classmethod def get_ttl_window(cls): - cur_time = int(time.time()) - return cur_time // HASHTREE_TTL + return int(time.time()) // HASHTREE_TTL @classmethod def get_current_epoch(cls): - cur_window = int(time.time()) % HASHTREE_TTL - return cur_window // HASHTREE_EPOCH_TIME + return int(time.time()) // HASHTREE_EPOCH_TIME @classmethod def get_time_since_epoch(cls): @@ -185,9 +187,10 @@ def get_time_till_next_ttl(cls): return HASHTREE_TTL - cls.get_time_since_ttl() @classmethod - def get_next_tree(cls, isd_as, if_ids, seed, hash_func=crypto_hash): - seed += (cls.get_ttl_window() + 2).to_bytes(8, 'big') - return HashTree(isd_as, if_ids, seed, hash_func) + def get_next_tree(cls, isd_as, if_ids, seed, hash_type): + ttl_window = cls.get_ttl_window() + 2 + seed += ttl_window.to_bytes(8, 'big') + return HashTree(isd_as, if_ids, seed, ttl_window, hash_type) def update(self, next_tree): self._ht0_root = self._ht1._nodes[0] @@ -210,13 +213,15 @@ def get_proof(self, if_id): if_id, epoch, self._ht0_root, self._ht2._nodes[0]) @classmethod - def get_possible_hashes(cls, revProof, hash_func=crypto_hash): + def get_possible_hashes(cls, rev_info): """ - Compute the hashes of the connected hash-tree roots given revProof. + Compute the hashes of the connected hash-tree roots given rev_info. """ + proof = rev_info.p + hash_func = hash_func_for_type(proof.hashType) # Calculate the hashes upwards till the tree root (of T). - proof = revProof.p - if_tuple = struct.pack("!qq", proof.ifID, proof.epoch) + proof.nonce + relative_epoch = proof.epoch % HASHTREE_N_EPOCHS + if_tuple = struct.pack("!qq", proof.ifID, relative_epoch) + proof.nonce curr_hash = hash_func(if_tuple) for i in range(len(proof.siblings)): @@ -234,17 +239,16 @@ def get_possible_hashes(cls, revProof, hash_func=crypto_hash): return (hash01, hash12) @classmethod - def verify(cls, revProof, root, hash_func=crypto_hash): # pragma: no cover + def verify(cls, rev_info, root): # pragma: no cover """ - Verify whether revProof proves the revocation for the current epoch, + Verify whether rev_info proves the revocation for the current epoch, given the root of the connected hash-tree. - :param RevInfo revProof: proof for the revocation. + :param RevInfo rev_info: proof for the revocation. :param bytes root: hash of the root, used for validating the proof. - :param hash_func: hash function that implements hashlib interface. """ - assert not isinstance(revProof.p, bytes) - h01, h12 = cls.get_possible_hashes(revProof, hash_func) + assert not isinstance(rev_info.p, bytes) + h01, h12 = cls.get_possible_hashes(rev_info) return h01 == root or h12 == root @classmethod diff --git a/lib/crypto/symcrypto.py b/lib/crypto/symcrypto.py index e9d417446a..ed95026ccf 100644 --- a/lib/crypto/symcrypto.py +++ b/lib/crypto/symcrypto.py @@ -16,13 +16,17 @@ ===================================================== """ # Stdlib -from hashlib import pbkdf2_hmac, sha256 +import hashlib # External packages from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers.algorithms import AES from cryptography.hazmat.primitives.cmac import CMAC +# SCION +from lib.errors import SCIONTypeError +from lib.types import HashType + def mac(key, msg): """ @@ -52,13 +56,26 @@ def kdf(secret, phrase): """ Default key derivation function. """ - return pbkdf2_hmac('sha256', secret, phrase, 1000)[:16] + return hashlib.pbkdf2_hmac('sha256', secret, phrase, 1000)[:16] -def crypto_hash(data): +def sha256(data): """ Default hash function. """ - digest = sha256() + digest = hashlib.sha256() digest.update(data) return digest.digest() + + +# Default hash function +crypto_hash = sha256 + + +def hash_func_for_type(type): + """ + Returns a callable corresponding to 'type'. + """ + if type == HashType.SHA256: + return sha256 + raise SCIONTypeError("Unknown hash function type.") diff --git a/lib/defines.py b/lib/defines.py index 5d1a572b46..198615105b 100644 --- a/lib/defines.py +++ b/lib/defines.py @@ -134,7 +134,7 @@ # Time per Epoch HASHTREE_EPOCH_TIME = 10 # The tolerable error in epoch in seconds. -HASHTREE_EPOCH_TOLERANCE = 5 +HASHTREE_EPOCH_TOLERANCE = 2 # Max time to live HASHTREE_TTL = MAX_SEGMENT_TTL # Number of epochs in one TTL per interface diff --git a/lib/packet/path_mgmt/rev_info.py b/lib/packet/path_mgmt/rev_info.py index be5a2d5152..64b7045031 100644 --- a/lib/packet/path_mgmt/rev_info.py +++ b/lib/packet/path_mgmt/rev_info.py @@ -37,7 +37,7 @@ class RevocationInfo(PathMgmtPayloadBase): @classmethod def from_values(cls, isd_as, if_id, epoch, nonce, siblings, prev_root, - next_root): + next_root, hash_type): """ Returns a RevocationInfo object with the specified values. @@ -48,10 +48,11 @@ def from_values(cls, isd_as, if_id, epoch, nonce, siblings, prev_root, :param list[(bool, bytes)] siblings: Positions and hashes of siblings :param bytes prev_root: Hash of the tree root at time T-1 :param bytes next_root: Hash of the tree root at time T+1 + :param hash_type: The hash function needed to verify the revocation. """ # Put the isd_as, if_id, epoch and nonce of the leaf into the proof. p = cls.P_CLS.new_message(isdas=int(isd_as), ifID=if_id, epoch=epoch, - nonce=nonce) + nonce=nonce, hashType=hash_type) # Put the list of sibling hashes (along with l/r) into the proof. sibs = p.init('siblings', len(siblings)) for i, sibling in enumerate(siblings): @@ -66,9 +67,9 @@ def isd_as(self): def cmp_str(self): b = [] - b.append(self.p.isdas.to_bytes(8, 'big')) + b.append(self.p.isdas.to_bytes(4, 'big')) b.append(self.p.ifID.to_bytes(8, 'big')) - b.append(self.p.epoch.to_bytes(2, 'big')) + b.append(self.p.epoch.to_bytes(8, 'big')) b.append(self.p.nonce) return b"".join(b) diff --git a/lib/sciond_api/path_meta.py b/lib/sciond_api/path_meta.py index d1e151848e..69de31319c 100644 --- a/lib/sciond_api/path_meta.py +++ b/lib/sciond_api/path_meta.py @@ -54,7 +54,7 @@ def iter_ifs(self): def short_desc(self): if_str = ", ".join([if_.short_desc() for if_ in self.iter_ifs()]) - return "MTU: %d Interfaces: %s" % (self.p.mtu, if_str) + return "Interfaces: %s MTU: %d" % (if_str, self.p.mtu) def __eq__(self, other): return list(self.iter_ifs()) == list(other.iter_ifs()) diff --git a/lib/types.py b/lib/types.py index c49c8d2d59..49cd342e18 100644 --- a/lib/types.py +++ b/lib/types.py @@ -178,3 +178,10 @@ class SCIONDMsgType(TypeBase): IF_REPLY = "ifInfoReply" SERVICE_REQUEST = "serviceInfoRequest" SERVICE_REPLY = "serviceInfoReply" + + +####################### +# Hash function types +####################### +class HashType(TypeBase): + SHA256 = 0 diff --git a/proto/rev_info.capnp b/proto/rev_info.capnp index 886021b00a..d353c5ed7a 100644 --- a/proto/rev_info.capnp +++ b/proto/rev_info.capnp @@ -10,10 +10,11 @@ struct SiblingHash { struct RevInfo { ifID @0 :UInt64; # ID of the interface to be revoked - epoch @1 :UInt16; # Epoch for which interface is to be revoked + epoch @1 :UInt64; # Epoch for which interface is to be revoked nonce @2 :Data; # Nonce corresponding to the (ifID,epoch) leaf in hashtree siblings @3 :List(SiblingHash); # Hash values of siblings, bottom to top prevRoot @4 :Data; # Root of the hashtree of previous time block (T-1) nextRoot @5 :Data; # Root of the hashtree of next time block (T+1) isdas @6 :UInt32; # ISD-AS of the revocation issuer. + hashType @7 :UInt16; # The hash function type needed to verify the revocation. } diff --git a/test/lib/crypto/hash_tree_test.py b/test/lib/crypto/hash_tree_test.py index b33222ebf8..fb7a4905a7 100644 --- a/test/lib/crypto/hash_tree_test.py +++ b/test/lib/crypto/hash_tree_test.py @@ -32,6 +32,7 @@ HASHTREE_EPOCH_TOLERANCE, ) from lib.packet.scion_addr import ISD_AS +from lib.types import HashType from test.testcommon import create_mock_full @@ -42,7 +43,7 @@ class TestHashTreeCalcTreeDepth(object): @patch("lib.crypto.hash_tree.HashTree._setup", autospec=True) def test_for_non2power(self, _): # Setup - inst = HashTree(ISD_AS("1-11"), "if_ids", "seed") + inst = HashTree(ISD_AS("1-11"), "if_ids", "seed", 1, HashType.SHA256) # Call inst.calc_tree_depth(6) # Tests @@ -53,7 +54,7 @@ def test_for_2power(self, _): # Setup if_ids = [1, 2, 3, 4] seed = b"abc" - inst = HashTree(ISD_AS("1-11"), if_ids, seed) + inst = HashTree(ISD_AS("1-11"), if_ids, seed, 1, HashType.SHA256) # Call inst.calc_tree_depth(8) # Tests @@ -64,16 +65,18 @@ class TestHashTreeCreateTree(object): """ Unit test for lib.crypto.hash_tree.HashTree.create_tree """ + @patch("lib.crypto.hash_tree.HASHTREE_N_EPOCHS", 1) @patch("lib.crypto.hash_tree.HashTree._setup", autospec=True) - def test(self, _): + @patch("lib.crypto.hash_tree.hash_func_for_type", autospec=True) + def test(self, hash_func_for_type, _): # Setup isd_as = ISD_AS("1-11") if_ids = [1, 2, 3] hashes = [b"s10", b"10s10", b"s20", b"20s20", b"s30", b"30s30", b"0", b"30s300", b"10s1020s20", b"10s1020s2030s300"] hash_func = create_mock_full(side_effect=hashes) - inst = HashTree(isd_as, if_ids, b"s", hash_func) - inst._n_epochs = 1 + hash_func_for_type.return_value = hash_func + inst = HashTree(isd_as, if_ids, b"s", 1, HashType.SHA256) inst._depth = 2 # Call inst.create_tree(if_ids) @@ -87,16 +90,18 @@ class TestHashTreeGetProof(object): """ Unit test for lib.crypto.hash_tree.HashTree.get_proof """ + @patch("lib.crypto.hash_tree.HASHTREE_N_EPOCHS", 1) @patch("lib.crypto.hash_tree.HashTree._setup", autospec=True) - def test(self, _): + @patch("lib.crypto.hash_tree.hash_func_for_type", autospec=True) + def test(self, hash_func_for_type, _): # Setup isd_as = ISD_AS("1-11") if_ids = [1, 2, 3] hashes = [b"s10", b"10s10", b"s20", b"20s20", b"s30", b"30s30", b"0", b"30s300", b"10s1020s20", b"10s1020s2030s300", b"s20"] hash_func = create_mock_full(side_effect=hashes) - inst = HashTree(isd_as, if_ids, b"s", hash_func) - inst._n_epochs = 1 + hash_func_for_type.return_value = hash_func + inst = HashTree(isd_as, if_ids, b"s", 1, HashType.SHA256) inst._depth = 2 inst.create_tree(if_ids) # Call @@ -110,6 +115,7 @@ def test(self, _): ntools.eq_(proof.p.siblings[1].hash, b"30s300") +@patch("lib.crypto.hash_tree.HASHTREE_N_EPOCHS", 1) class TestConnectedHashTreeUpdate(object): """ Unit test for lib.crypto.hash_tree.ConnectedHashTree.update @@ -119,11 +125,11 @@ def test(self): isd_as = ISD_AS("1-11") if_ids = [23, 35, 120] initial_seed = b"qwerty" - inst = ConnectedHashTree(isd_as, if_ids, initial_seed) + inst = ConnectedHashTree(isd_as, if_ids, initial_seed, HashType.SHA256) root1_before_update = inst._ht1._nodes[0] root2_before_update = inst._ht2._nodes[0] # Call - new_tree = inst.get_next_tree(isd_as, if_ids, b"new!!seed") + new_tree = inst.get_next_tree(isd_as, if_ids, b"new!!seed", HashType.SHA256) inst.update(new_tree) # Tests root0_after_update = inst._ht0_root @@ -136,21 +142,22 @@ class TestConnectedHashtreeGetPossibleHashes(object): """ Unit test for lib.crypto.hash_tree.ConnectedHashTree.get_possible_hashes """ - def test(self): + @patch("lib.crypto.hash_tree.hash_func_for_type", autospec=True) + def test(self, hash_func_for_type): # Setup siblings = [] siblings.append(create_mock_full({"isLeft": True, "hash": "10s10"})) siblings.append(create_mock_full({"isLeft": False, "hash": "30s300"})) p = create_mock_full( {"ifID": 2, "epoch": 0, "nonce": b"s20", "siblings": siblings, - "prevRoot": "p", "nextRoot": "n"}) - revProof = create_mock_full({"p": p}) + "prevRoot": "p", "nextRoot": "n", "hashType": 0}) + rev_info = create_mock_full({"p": p}) hashes = ["20s20", "10s1020s20", "10s1020s2030s300", "p10s1020s2030s300", "10s1020s2030s300n"] hash_func = create_mock_full(side_effect=hashes) + hash_func_for_type.return_value = hash_func # Call - hash01, hash12 = ConnectedHashTree.get_possible_hashes( - revProof, hash_func) + hash01, hash12 = ConnectedHashTree.get_possible_hashes(rev_info) # Tests ntools.eq_(hash01, "p10s1020s2030s300") ntools.eq_(hash12, "10s1020s2030s300n") @@ -161,16 +168,28 @@ class TestConnectedHashTreeUpdateAndVerify(object): Unit tests for lib.crypto.hash_tree.ConnectedHashTree.verify used along with lib.crypto.hash_tree.ConnectedHashTree.update """ + def test(self): + # Check that the revocation proof is verifiable in T. + isd_as = ISD_AS("1-11") + if_ids = [23, 35, 120] + initial_seed = b"qwerty" + inst = ConnectedHashTree(isd_as, if_ids, initial_seed, HashType.SHA256) + root = inst.get_root() + # Call + proof = inst.get_proof(120) + # Tests + ntools.eq_(ConnectedHashTree.verify(proof, root), True) + def test_one_timestep(self): # Check that the revocation proof is verifiable across T and T+1. # Setup isd_as = ISD_AS("1-11") if_ids = [23, 35, 120] initial_seed = b"qwerty" - inst = ConnectedHashTree(isd_as, if_ids, initial_seed) + inst = ConnectedHashTree(isd_as, if_ids, initial_seed, HashType.SHA256) root = inst.get_root() # Call - next_tree = inst.get_next_tree(isd_as, if_ids, b"new!!seed") + next_tree = inst.get_next_tree(isd_as, if_ids, b"new!!seed", HashType.SHA256) inst.update(next_tree) # Tests proof = inst.get_proof(35) # if_id = 35. @@ -182,12 +201,12 @@ def test_two_timesteps(self): isd_as = ISD_AS("1-11") if_ids = [23, 35, 120] initial_seed = b"qwerty" - inst = ConnectedHashTree(isd_as, if_ids, initial_seed) + inst = ConnectedHashTree(isd_as, if_ids, initial_seed, HashType.SHA256) root = inst.get_root() # Call - new_tree = inst.get_next_tree(isd_as, if_ids, b"newseed.@1") + new_tree = inst.get_next_tree(isd_as, if_ids, b"newseed.@1", HashType.SHA256) inst.update(new_tree) - new_tree = inst.get_next_tree(isd_as, if_ids, b"newseed.@2") + new_tree = inst.get_next_tree(isd_as, if_ids, b"newseed.@2", HashType.SHA256) inst.update(new_tree) # Tests proof = inst.get_proof(35) # if_id = 35.