Skip to content

Commit

Permalink
policy: fix policy type hints
Browse files Browse the repository at this point in the history
Signed-off-by: Erik Larsson <[email protected]>
  • Loading branch information
whooo committed Jan 19, 2024
1 parent cfe4af4 commit 212044f
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 35 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ exclude = '''
mypy_path = "mypy_stubs"
exclude = [
'src/tpm2_pytss/encoding.py',
'src/tpm2_pytss/policy.py',
'src/tpm2_pytss/ESAPI.py',
'src/tpm2_pytss/FAPI.py',
'src/tpm2_pytss/internal/crypto.py',
Expand Down
1 change: 1 addition & 0 deletions src/tpm2_pytss/_libtpm2_pytss/ffi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ def new_handle(python_object: Any) -> CData: ...
def cast(ctype: str, value: CData) -> CData: ...
def memmove(dest: CData | bytes, src: CData | bytes, n: int) -> None: ...
def addressof(cdata: CData, *fields_or_indexes: str | int) -> CData: ...
def unpack(cdata: CData, maxlen: Optional[int]) -> bytes: ...
99 changes: 65 additions & 34 deletions src/tpm2_pytss/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from ._libtpm2_pytss import ffi, lib
from .ESAPI import ESAPI
from enum import Enum
from typing import Callable, Union
from typing import Callable, Union, Any, Type, Optional, Dict
from types import TracebackType


class policy_cb_types(Enum):
Expand All @@ -39,7 +40,12 @@ class policy_cb_types(Enum):


@ffi.def_extern()
def _policy_cb_calc_pcr(selection, out_selection, out_digest, userdata):
def _policy_cb_calc_pcr(
selection: ffi.CData,
out_selection: ffi.CData,
out_digest: ffi.CData,
userdata: ffi.CData,
) -> int:
"""Callback wrapper for policy PCR calculations
Args:
Expand Down Expand Up @@ -73,7 +79,7 @@ def _policy_cb_calc_pcr(selection, out_selection, out_digest, userdata):


@ffi.def_extern()
def _policy_cb_calc_name(path, name, userdata):
def _policy_cb_calc_name(path: ffi.CData, name: ffi.CData, userdata: ffi.CData) -> int:
"""Callback wrapper for policy name calculations
Args:
Expand All @@ -100,7 +106,9 @@ def _policy_cb_calc_name(path, name, userdata):


@ffi.def_extern()
def _policy_cb_calc_public(path, public, userdata):
def _policy_cb_calc_public(
path: ffi.CData, public: ffi.CData, userdata: ffi.CData
) -> int:
"""Callback wrapper for getting the public part for a key path
Args:
Expand Down Expand Up @@ -132,7 +140,9 @@ def _policy_cb_calc_public(path, public, userdata):


@ffi.def_extern()
def _policy_cb_calc_nvpublic(path, nv_index, nv_public, userdata):
def _policy_cb_calc_nvpublic(
path: ffi.CData, nv_index: int, nv_public: ffi.CData, userdata: ffi.CData
) -> int:
"""Callback wrapper for getting the public part for a NV path
Args:
Expand Down Expand Up @@ -165,7 +175,13 @@ def _policy_cb_calc_nvpublic(path, nv_index, nv_public, userdata):


@ffi.def_extern()
def _policy_cb_exec_auth(name, object_handle, auth_handle, auth_session, userdata):
def _policy_cb_exec_auth(
name: ffi.CData,
object_handle: ffi.CData,
auth_handle: ffi.CData,
auth_session: ffi.CData,
userdata: ffi.CData,
) -> int:
"""Callback wrapper for getting authorization sessions for a name
Args:
Expand All @@ -180,7 +196,7 @@ def _policy_cb_exec_auth(name, object_handle, auth_handle, auth_session, userdat
if not cb:
return TSS2_RC.POLICY_RC_NULL_CALLBACK
try:
nb = ffi.unpack(name.name, name.size)
nb = ffi.unpack(name.name, int(name.size))
name2b = TPM2B_NAME(nb)
cb_object_handle, cb_auth_handle, cb_auth_session = cb(name2b)
object_handle[0] = cb_object_handle
Expand All @@ -197,8 +213,12 @@ def _policy_cb_exec_auth(name, object_handle, auth_handle, auth_session, userdat

@ffi.def_extern()
def _policy_cb_exec_polsel(
auth_object, branch_names, branch_count, branch_idx, userdata
):
auth_object: ffi.CData,
branch_names: ffi.CData,
branch_count: int,
branch_idx: ffi.CData,
userdata: ffi.CData,
) -> int:
"""Callback wrapper selection of a policy branch
Args:
Expand Down Expand Up @@ -233,15 +253,15 @@ def _policy_cb_exec_polsel(

@ffi.def_extern()
def _policy_cb_exec_sign(
key_pem,
public_key_hint,
key_pem_hash_alg,
buf,
buf_size,
signature,
signature_size,
userdata,
):
key_pem: ffi.CData,
public_key_hint: ffi.CData,
key_pem_hash_alg: int,
buf: ffi.CData,
buf_size: int,
signature: ffi.CData,
signature_size: ffi.CData,
userdata: ffi.CData,
) -> int:
"""Callback wrapper to signing an operation
Args:
Expand Down Expand Up @@ -277,8 +297,13 @@ def _policy_cb_exec_sign(

@ffi.def_extern()
def _policy_cb_exec_polauth(
key_public, hash_alg, digest, policy_ref, signature, userdata
):
key_public: ffi.CData,
hash_alg: int,
digest: ffi.CData,
policy_ref: ffi.CData,
signature: ffi.CData,
userdata: ffi.CData,
) -> int:
"""Callback for signing a policy
Args:
Expand All @@ -296,8 +321,8 @@ def _policy_cb_exec_polauth(
try:
key_pub = TPMT_PUBLIC(_cdata=key_public)
halg = TPM2_ALG(hash_alg)
db = ffi.unpack(digest.buffer, digest.size)
pb = ffi.unpack(policy_ref.buffer, policy_ref.size)
db = ffi.unpack(digest.buffer, int(digest.size))
pb = ffi.unpack(policy_ref.buffer, int(policy_ref.size))
dig = TPM2B_DIGEST(db)
polref = TPM2B_NONCE(pb)
cb_signature = cb(key_pub, halg, dig, polref)
Expand All @@ -313,7 +338,9 @@ def _policy_cb_exec_polauth(


@ffi.def_extern()
def _policy_cb_exec_polauthnv(nv_public, hash_alg, userdata):
def _policy_cb_exec_polauthnv(
nv_public: ffi.CData, hash_alg: int, userdata: ffi.CData
) -> int:
"""Callback wrapper for NV policy authorization
Args:
Expand All @@ -339,7 +366,7 @@ def _policy_cb_exec_polauthnv(nv_public, hash_alg, userdata):


@ffi.def_extern()
def _policy_cb_exec_poldup(name, userdata):
def _policy_cb_exec_poldup(name: ffi.CData, userdata: ffi.CData) -> int:
"""Callback wrapper to get name for duplication selection
Args:
Expand All @@ -364,7 +391,7 @@ def _policy_cb_exec_poldup(name, userdata):


@ffi.def_extern()
def _policy_cb_exec_polaction(action, userdata):
def _policy_cb_exec_polaction(action: ffi.CData, userdata: ffi.CData) -> int:
"""Callback wrapper for policy action
Args:
Expand Down Expand Up @@ -408,7 +435,7 @@ def __init__(self, policy: Union[bytes, str], hash_alg: TPM2_ALG):
policy = policy.encode()
self._policy = policy
self._hash_alg = hash_alg
self._callbacks = dict()
self._callbacks: Dict[policy_cb_types, Optional[Callable[..., Any]]] = dict()
self._callback_exception = None
self._ctx_pp = ffi.new("TSS2_POLICY_CTX **")
_chkrc(lib.Tss2_PolicyInit(policy, hash_alg, self._ctx_pp))
Expand All @@ -417,13 +444,15 @@ def __init__(self, policy: Union[bytes, str], hash_alg: TPM2_ALG):
self._calc_callbacks = ffi.new("TSS2_POLICY_CALC_CALLBACKS *")
self._exec_callbacks = ffi.new("TSS2_POLICY_EXEC_CALLBACKS *")

def __enter__(self):
def __enter__(self) -> "policy":
return self

def __exit__(self, _type, value, traceback):
def __exit__(
self, _type: Type[Exception], value: Exception, traceback: TracebackType
) -> None:
self.close()

def close(self):
def close(self) -> None:
"""Finalize the policy instance"""
lib.Tss2_PolicyFinalize(self._ctx_pp)
self._ctx_pp = ffi.NULL
Expand All @@ -439,12 +468,14 @@ def hash_alg(self) -> TPM2_ALG:
"""TPM2_ALG: The hash algorithm to be used during policy calculcation."""
return self._hash_alg

def _get_callback(self, callback_type: policy_cb_types) -> Callable:
def _get_callback(
self, callback_type: policy_cb_types
) -> Optional[Callable[..., Any]]:
return self._callbacks.get(callback_type)

def set_callback(
self, callback_type: policy_cb_types, callback: Union[None, Callable]
):
self, callback_type: policy_cb_types, callback: Union[None, Callable[..., Any]]
) -> None:
"""Set callback for policy calculaction or execution
Args:
Expand Down Expand Up @@ -522,7 +553,7 @@ def set_callback(
elif update_exec:
_chkrc(lib.Tss2_PolicySetExecCallbacks(self._ctx, self._exec_callbacks))

def execute(self, esys_ctx: ESAPI, session: ESYS_TR):
def execute(self, esys_ctx: ESAPI, session: ESYS_TR) -> None:
"""Executes the policy
Args:
Expand All @@ -541,7 +572,7 @@ def execute(self, esys_ctx: ESAPI, session: ESYS_TR):
finally:
self._callback_exception = None

def calculate(self):
def calculate(self) -> None:
"""Calculate the policy
Raises:
Expand Down

0 comments on commit 212044f

Please sign in to comment.