From 28bbc6543b2ed147bc6cec9f933c8ededf67c5e0 Mon Sep 17 00:00:00 2001 From: Erik Larsson Date: Sat, 20 Jan 2024 10:07:12 +0100 Subject: [PATCH] ESAPI: fix type hints Signed-off-by: Erik Larsson --- pyproject.toml | 1 - setup.py | 2 +- src/tpm2_pytss/ESAPI.py | 75 +++++++++++++++++++++----------- src/tpm2_pytss/constants.py | 27 ++++++------ src/tpm2_pytss/internal/utils.py | 2 +- src/tpm2_pytss/types.py | 8 +++- 6 files changed, 72 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 50dc3959..337f44ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ exclude = ''' mypy_path = "mypy_stubs" exclude = [ 'src/tpm2_pytss/encoding.py', - 'src/tpm2_pytss/ESAPI.py', 'src/tpm2_pytss/FAPI.py', 'src/tpm2_pytss/internal/crypto.py', 'src/tpm2_pytss/fapi_info.py', diff --git a/setup.py b/setup.py index 36901b05..2939cb09 100644 --- a/setup.py +++ b/setup.py @@ -310,7 +310,7 @@ def build_function(self, d): if param.type.type.type.names else None ) - if tn == "char": + if tn in ("char", "uint8_t"): ft = "CData | bytes" elif isinstance( param.type, cparser.pycparser.c_ast.TypeDecl diff --git a/src/tpm2_pytss/ESAPI.py b/src/tpm2_pytss/ESAPI.py index 7cb88eab..be8355a5 100644 --- a/src/tpm2_pytss/ESAPI.py +++ b/src/tpm2_pytss/ESAPI.py @@ -10,15 +10,30 @@ ) from .TCTI import TCTI from .TCTILdr import TCTILdr +from ._libtpm2_pytss import ffi, lib -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, List, Any, Type, Callable, Sequence + +try: + from typing import Self +except ImportError: + # assume mypy is running on python 3.11+ + pass +from types import TracebackType # Work around this FAPI dependency if FAPI is not present with the constant value _fapi_installed_ = _lib_version_atleast("tss2-fapi", "3.0.0") _DEFAULT_LOAD_BLOB_SELECTOR = FAPI_ESYSBLOB.CONTEXTLOAD if _fapi_installed_ else 1 -def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs): +def _get_cdata( + value: Any, + expected: Type["TPM_OBJECT"], + varname: str, + allow_none: bool = False, + *args: Any, + **kwargs: Any, +) -> ffi.CData: tname = expected.__name__ if value is None and allow_none: @@ -36,12 +51,12 @@ def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs): return value vname = type(value).__name__ - parse_method = getattr(expected, "parse", None) + parse_method: Optional[Callable[..., ffi.CData]] = getattr(expected, "parse", None) if isinstance(value, (bytes, str)) and issubclass(expected, TPM2B_SIMPLE_OBJECT): bo = expected(value) return bo._cdata elif isinstance(value, str) and parse_method and callable(parse_method): - return expected.parse(value, *args, **kwargs)._cdata + return parse_method(value, *args, **kwargs)._cdata elif issubclass(expected, TPML_OBJECT) and isinstance(value, list): return expected(value)._cdata elif not isinstance(value, expected): @@ -50,7 +65,9 @@ def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs): return value._cdata -def _check_handle_type(handle, varname, expected=None): +def _check_handle_type( + handle: ESYS_TR, varname: str, expected: Optional[Sequence[ESYS_TR]] = None +) -> None: if not isinstance(handle, ESYS_TR): raise TypeError(f"expected {varname} to be type ESYS_TR, got {type(handle)}") @@ -129,10 +146,12 @@ def __init__(self, tcti: Union[TCTI, str, None] = None): _chkrc(lib.Esys_Initialize(self._ctx_pp, tctx, ffi.NULL)) self._ctx = self._ctx_pp[0] - def __enter__(self): + def __enter__(self) -> "Self": return self - def __exit__(self, _type, value, traceback) -> None: + def __exit__( + self, _type: Type[Exception], value: Exception, traceback: TracebackType + ) -> None: self.close() # @@ -154,7 +173,7 @@ def close(self) -> None: self._ctx = ffi.NULL self._ctx_pp = ffi.NULL if self._did_load_tcti and self._tcti is not None: - self._tcti.close() + self._tcti.finalize() self._tcti = None def get_tcti(self) -> Optional[TCTI]: @@ -856,7 +875,7 @@ def load( def load_external( self, in_public: TPM2B_PUBLIC, - in_private: TPM2B_SENSITIVE = None, + in_private: Optional[TPM2B_SENSITIVE] = None, hierarchy: ESYS_TR = ESYS_TR.NULL, session1: ESYS_TR = ESYS_TR.NONE, session2: ESYS_TR = ESYS_TR.NONE, @@ -901,7 +920,7 @@ def load_external( in_public_cdata = _get_cdata(in_public, TPM2B_PUBLIC, "in_public") - hierarchy = ESAPI._fixup_hierarchy(hierarchy) + fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy) object_handle = ffi.new("ESYS_TR *") _chkrc( @@ -912,7 +931,7 @@ def load_external( session3, in_private_cdata, in_public_cdata, - hierarchy, + fixed_hierarchy, object_handle, ) ) @@ -2103,7 +2122,7 @@ def hash( _check_friendly_int(hash_alg, "hash_alg", TPM2_ALG) _check_friendly_int(hierarchy, "hierarchy", ESYS_TR) - hierarchy = ESAPI._fixup_hierarchy(hierarchy) + fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy) data_cdata = _get_cdata(data, TPM2B_MAX_BUFFER, "data") @@ -2117,7 +2136,7 @@ def hash( session3, data_cdata, hash_alg, - hierarchy, + fixed_hierarchy, out_hash, validation, ) @@ -2616,7 +2635,7 @@ def sequence_complete( _check_handle_type(session3, "session3") _check_friendly_int(hierarchy, "hierarchy", ESYS_TR) - hierarchy = ESAPI._fixup_hierarchy(hierarchy) + fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy) buffer_cdata = _get_cdata(buffer, TPM2B_MAX_BUFFER, "buffer", allow_none=True) @@ -2630,7 +2649,7 @@ def sequence_complete( session2, session3, buffer_cdata, - hierarchy, + fixed_hierarchy, result, validation, ) @@ -5107,7 +5126,7 @@ def hierarchy_control( "enable", expected=(ESYS_TR.ENDORSEMENT, ESYS_TR.OWNER, ESYS_TR.PLATFORM), ) - enable = ESAPI._fixup_hierarchy(enable) + fixed_enable = ESAPI._fixup_hierarchy(enable) if not isinstance(state, bool): raise TypeError(f"Expected state to be a bool, got {type(state)}") @@ -5118,7 +5137,13 @@ def hierarchy_control( _chkrc( lib.Esys_HierarchyControl( - self._ctx, auth_handle, session1, session2, session3, enable, state + self._ctx, + auth_handle, + session1, + session2, + session3, + fixed_enable, + state, ) ) @@ -5554,7 +5579,7 @@ def pp_commands( def set_algorithm_set( self, - algorithm_set: Union[List[int], int], + algorithm_set: int, auth_handle: ESYS_TR = ESYS_TR.PLATFORM, session1: ESYS_TR = ESYS_TR.PASSWORD, session2: ESYS_TR = ESYS_TR.NONE, @@ -5567,7 +5592,7 @@ def set_algorithm_set( available. Args: - algorithm_set (Union[List[int], int]): A TPM vendor-dependent value indicating the + algorithm_set (int): A TPM vendor-dependent value indicating the algorithm set selection. auth_handle (ESYS_TR): ESYS_TR.PLATFORM. Defaults to ESYS_TR.PLATFORM. session1 (ESYS_TR): A session for securing the TPM command (optional). Defaults to ESYS_TR.PASSWORD. @@ -6138,7 +6163,7 @@ def ac_get_capability( _check_friendly_int(capability, "capability", TPM_AT) if not isinstance(count, int): - raise TypeError(f"Expected count to be an int, got {type(prop)}") + raise TypeError(f"Expected count to be an int, got {type(count)}") _check_handle_type(ac, "ac") _check_handle_type(session1, "session1") @@ -6228,7 +6253,7 @@ def ac_send( ac_data_out, ) ) - return TPMS_AC_OUTPUT(_get_dptr(acDataOut, lib.Esys_Free)) + return TPMS_AC_OUTPUT(_get_dptr(ac_data_out, lib.Esys_Free)) def policy_ac_send_select( self, @@ -7169,7 +7194,7 @@ def tr_serialize(self, esys_handle: ESYS_TR) -> bytes: _chkrc(lib.Esys_TR_Serialize(self._ctx, esys_handle, buffer, buffer_size)) buffer_size = buffer_size[0] buffer = _get_dptr(buffer, lib.Esys_Free) - return bytes(ffi.buffer(buffer, buffer_size)) + return bytes(ffi.buffer(buffer, int(buffer_size))) def tr_deserialize(self, buffer: bytes) -> ESYS_TR: """Deserialization of an ESYS_TR from a byte buffer. @@ -7235,6 +7260,6 @@ def _fixup_hierarchy(hierarchy: ESYS_TR) -> Union[TPM2_RH, ESYS_TR]: "Expected hierarchy to be one of ESYS_TR.NULL, ESYS_TR.PLATFORM, ESYS_TR.OWNER, ESYS_TR.ENDORSMENT" ) - hierarchy = fixup_map[hierarchy] - - return hierarchy + return fixup_map[hierarchy] + else: + return hierarchy diff --git a/src/tpm2_pytss/constants.py b/src/tpm2_pytss/constants.py index 870f309e..4f694d94 100644 --- a/src/tpm2_pytss/constants.py +++ b/src/tpm2_pytss/constants.py @@ -21,6 +21,7 @@ TYPE_CHECKING, Any, SupportsIndex, + cast, ) try: @@ -422,8 +423,8 @@ class ESYS_TR(TPM_FRIENDLY_INT): or a persistent key use :func:`tpm2_pytss.ESAPI.tr_from_tpmpublic` """ - NONE = lib.ESYS_TR_NONE - PASSWORD = lib.ESYS_TR_PASSWORD + NONE = cast("ESYS_TR", lib.ESYS_TR_NONE) + PASSWORD = cast("ESYS_TR", lib.ESYS_TR_PASSWORD) PCR0 = lib.ESYS_TR_PCR0 PCR1 = lib.ESYS_TR_PCR1 PCR2 = lib.ESYS_TR_PCR2 @@ -456,11 +457,11 @@ class ESYS_TR(TPM_FRIENDLY_INT): PCR29 = lib.ESYS_TR_PCR29 PCR30 = lib.ESYS_TR_PCR30 PCR31 = lib.ESYS_TR_PCR31 - OWNER = lib.ESYS_TR_RH_OWNER - NULL = lib.ESYS_TR_RH_NULL - LOCKOUT = lib.ESYS_TR_RH_LOCKOUT - ENDORSEMENT = lib.ESYS_TR_RH_ENDORSEMENT - PLATFORM = lib.ESYS_TR_RH_PLATFORM + OWNER = cast("ESYS_TR", lib.ESYS_TR_RH_OWNER) + NULL = cast("ESYS_TR", lib.ESYS_TR_RH_NULL) + LOCKOUT = cast("ESYS_TR", lib.ESYS_TR_RH_LOCKOUT) + ENDORSEMENT = cast("ESYS_TR", lib.ESYS_TR_RH_ENDORSEMENT) + PLATFORM = cast("ESYS_TR", lib.ESYS_TR_RH_PLATFORM) PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV RH_OWNER = lib.ESYS_TR_RH_OWNER RH_NULL = lib.ESYS_TR_RH_NULL @@ -546,21 +547,21 @@ def parts_to_blob(handle: "TPM2_HANDLE", public: "TPM2B_PUBLIC") -> bytes: @TPM_FRIENDLY_INT._fix_const_type class TPM2_RH(TPM_FRIENDLY_INT): SRK = lib.TPM2_RH_SRK - OWNER = lib.TPM2_RH_OWNER + OWNER = cast("TPM2_RH", lib.TPM2_RH_OWNER) REVOKE = lib.TPM2_RH_REVOKE TRANSPORT = lib.TPM2_RH_TRANSPORT OPERATOR = lib.TPM2_RH_OPERATOR ADMIN = lib.TPM2_RH_ADMIN EK = lib.TPM2_RH_EK - NULL = lib.TPM2_RH_NULL + NULL = cast("TPM2_RH", lib.TPM2_RH_NULL) UNASSIGNED = lib.TPM2_RH_UNASSIGNED try: PW = lib.TPM2_RS_PW except AttributeError: PW = lib.TPM2_RH_PW LOCKOUT = lib.TPM2_RH_LOCKOUT - ENDORSEMENT = lib.TPM2_RH_ENDORSEMENT - PLATFORM = lib.TPM2_RH_PLATFORM + ENDORSEMENT = cast("TPM2_RH", lib.TPM2_RH_ENDORSEMENT) + PLATFORM = cast("TPM2_RH", lib.TPM2_RH_PLATFORM) PLATFORM_NV = lib.TPM2_RH_PLATFORM_NV @@ -1131,8 +1132,8 @@ class TPM2_ST(TPM_FRIENDLY_INT): @TPM_FRIENDLY_INT._fix_const_type class TPM2_SU(TPM_FRIENDLY_INT): - CLEAR = lib.TPM2_SU_CLEAR - STATE = lib.TPM2_SU_STATE + CLEAR = cast("TPM2_SU", lib.TPM2_SU_CLEAR) + STATE = cast("TPM2_SU", lib.TPM2_SU_STATE) @TPM_FRIENDLY_INT._fix_const_type diff --git a/src/tpm2_pytss/internal/utils.py b/src/tpm2_pytss/internal/utils.py index 680ae0b8..80199bd2 100644 --- a/src/tpm2_pytss/internal/utils.py +++ b/src/tpm2_pytss/internal/utils.py @@ -366,7 +366,7 @@ def _get_dptr(dptr: ffi.CData, free_func: Callable[[ffi.CData], None]) -> ffi.CD def _check_friendly_int( - friendly: int, varname: int, clazz: Type["TPM_FRIENDLY_INT"] + friendly: int, varname: str, clazz: Type["TPM_FRIENDLY_INT"] ) -> None: if not isinstance(friendly, int): diff --git a/src/tpm2_pytss/types.py b/src/tpm2_pytss/types.py index 33fb6cd2..8ce63a7d 100644 --- a/src/tpm2_pytss/types.py +++ b/src/tpm2_pytss/types.py @@ -77,6 +77,8 @@ class TPM2_HANDLE(int): class TPM_OBJECT(object): """ Abstract Base class for all TPM Objects. Not suitable for direct instantiation.""" + _cdata: ffi.CData + def __init__(self, _cdata: Optional[Any] = None, **kwargs: Any): # Rather than trying to mock the FFI interface, just avoid it and return @@ -253,7 +255,9 @@ class TPM2B_SIMPLE_OBJECT(TPM_OBJECT): """ Abstract Base class for all TPM2B Simple Objects. A Simple object contains only a size and byte buffer fields. This is not suitable for direct instantiation.""" - def __init__(self, _cdata: Optional[Union[ffi.CData, bytes]] = None, **kwargs: Any): + def __init__( + self, _cdata: Optional[Union[ffi.CData, bytes, str]] = None, **kwargs: Any + ): _cdata, kwargs = _fixup_cdata_kwargs(self, _cdata, kwargs) _bytefield = type(self)._get_bytefield() @@ -299,7 +303,7 @@ def __getattribute__(self, key: str) -> Any: if key == _bytefield: b = getattr(self._cdata, _bytefield) rb = _ref_parent(b, self._cdata) - return memoryview(ffi.buffer(rb, self._cdata.size)) + return memoryview(ffi.buffer(rb, int(self._cdata.size))) return super().__getattribute__(key) def __len__(self) -> int: