Skip to content

Commit

Permalink
ESAPI: fix type hints
Browse files Browse the repository at this point in the history
Signed-off-by: Erik Larsson <[email protected]>
  • Loading branch information
whooo committed Jan 20, 2024
1 parent 212044f commit 28bbc65
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 43 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ exclude = '''
mypy_path = "mypy_stubs"
exclude = [
'src/tpm2_pytss/encoding.py',
'src/tpm2_pytss/ESAPI.py',
'src/tpm2_pytss/FAPI.py',
'src/tpm2_pytss/internal/crypto.py',
'src/tpm2_pytss/fapi_info.py',
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def build_function(self, d):
if param.type.type.type.names
else None
)
if tn == "char":
if tn in ("char", "uint8_t"):
ft = "CData | bytes"
elif isinstance(
param.type, cparser.pycparser.c_ast.TypeDecl
Expand Down
75 changes: 50 additions & 25 deletions src/tpm2_pytss/ESAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,30 @@
)
from .TCTI import TCTI
from .TCTILdr import TCTILdr
from ._libtpm2_pytss import ffi, lib

from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, List, Any, Type, Callable, Sequence

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'List' is not used.

try:
from typing import Self
except ImportError:
# assume mypy is running on python 3.11+
pass
from types import TracebackType

# Work around this FAPI dependency if FAPI is not present with the constant value
_fapi_installed_ = _lib_version_atleast("tss2-fapi", "3.0.0")
_DEFAULT_LOAD_BLOB_SELECTOR = FAPI_ESYSBLOB.CONTEXTLOAD if _fapi_installed_ else 1


def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs):
def _get_cdata(
value: Any,
expected: Type["TPM_OBJECT"],
varname: str,
allow_none: bool = False,
*args: Any,
**kwargs: Any,
) -> ffi.CData:
tname = expected.__name__

if value is None and allow_none:
Expand All @@ -36,12 +51,12 @@ def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs):
return value

vname = type(value).__name__
parse_method = getattr(expected, "parse", None)
parse_method: Optional[Callable[..., ffi.CData]] = getattr(expected, "parse", None)
if isinstance(value, (bytes, str)) and issubclass(expected, TPM2B_SIMPLE_OBJECT):
bo = expected(value)
return bo._cdata
elif isinstance(value, str) and parse_method and callable(parse_method):
return expected.parse(value, *args, **kwargs)._cdata
return parse_method(value, *args, **kwargs)._cdata
elif issubclass(expected, TPML_OBJECT) and isinstance(value, list):
return expected(value)._cdata
elif not isinstance(value, expected):
Expand All @@ -50,7 +65,9 @@ def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs):
return value._cdata


def _check_handle_type(handle, varname, expected=None):
def _check_handle_type(
handle: ESYS_TR, varname: str, expected: Optional[Sequence[ESYS_TR]] = None
) -> None:
if not isinstance(handle, ESYS_TR):
raise TypeError(f"expected {varname} to be type ESYS_TR, got {type(handle)}")

Expand Down Expand Up @@ -129,10 +146,12 @@ def __init__(self, tcti: Union[TCTI, str, None] = None):
_chkrc(lib.Esys_Initialize(self._ctx_pp, tctx, ffi.NULL))
self._ctx = self._ctx_pp[0]

def __enter__(self):
def __enter__(self) -> "Self":
return self

def __exit__(self, _type, value, traceback) -> None:
def __exit__(
self, _type: Type[Exception], value: Exception, traceback: TracebackType
) -> None:
self.close()

#
Expand All @@ -154,7 +173,7 @@ def close(self) -> None:
self._ctx = ffi.NULL
self._ctx_pp = ffi.NULL
if self._did_load_tcti and self._tcti is not None:
self._tcti.close()
self._tcti.finalize()
self._tcti = None

def get_tcti(self) -> Optional[TCTI]:
Expand Down Expand Up @@ -856,7 +875,7 @@ def load(
def load_external(
self,
in_public: TPM2B_PUBLIC,
in_private: TPM2B_SENSITIVE = None,
in_private: Optional[TPM2B_SENSITIVE] = None,
hierarchy: ESYS_TR = ESYS_TR.NULL,
session1: ESYS_TR = ESYS_TR.NONE,
session2: ESYS_TR = ESYS_TR.NONE,
Expand Down Expand Up @@ -901,7 +920,7 @@ def load_external(

in_public_cdata = _get_cdata(in_public, TPM2B_PUBLIC, "in_public")

hierarchy = ESAPI._fixup_hierarchy(hierarchy)
fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy)

object_handle = ffi.new("ESYS_TR *")
_chkrc(
Expand All @@ -912,7 +931,7 @@ def load_external(
session3,
in_private_cdata,
in_public_cdata,
hierarchy,
fixed_hierarchy,
object_handle,
)
)
Expand Down Expand Up @@ -2103,7 +2122,7 @@ def hash(
_check_friendly_int(hash_alg, "hash_alg", TPM2_ALG)

_check_friendly_int(hierarchy, "hierarchy", ESYS_TR)
hierarchy = ESAPI._fixup_hierarchy(hierarchy)
fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy)

data_cdata = _get_cdata(data, TPM2B_MAX_BUFFER, "data")

Expand All @@ -2117,7 +2136,7 @@ def hash(
session3,
data_cdata,
hash_alg,
hierarchy,
fixed_hierarchy,
out_hash,
validation,
)
Expand Down Expand Up @@ -2616,7 +2635,7 @@ def sequence_complete(
_check_handle_type(session3, "session3")

_check_friendly_int(hierarchy, "hierarchy", ESYS_TR)
hierarchy = ESAPI._fixup_hierarchy(hierarchy)
fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy)

buffer_cdata = _get_cdata(buffer, TPM2B_MAX_BUFFER, "buffer", allow_none=True)

Expand All @@ -2630,7 +2649,7 @@ def sequence_complete(
session2,
session3,
buffer_cdata,
hierarchy,
fixed_hierarchy,
result,
validation,
)
Expand Down Expand Up @@ -5107,7 +5126,7 @@ def hierarchy_control(
"enable",
expected=(ESYS_TR.ENDORSEMENT, ESYS_TR.OWNER, ESYS_TR.PLATFORM),
)
enable = ESAPI._fixup_hierarchy(enable)
fixed_enable = ESAPI._fixup_hierarchy(enable)

if not isinstance(state, bool):
raise TypeError(f"Expected state to be a bool, got {type(state)}")
Expand All @@ -5118,7 +5137,13 @@ def hierarchy_control(

_chkrc(
lib.Esys_HierarchyControl(
self._ctx, auth_handle, session1, session2, session3, enable, state
self._ctx,
auth_handle,
session1,
session2,
session3,
fixed_enable,
state,
)
)

Expand Down Expand Up @@ -5554,7 +5579,7 @@ def pp_commands(

def set_algorithm_set(
self,
algorithm_set: Union[List[int], int],
algorithm_set: int,
auth_handle: ESYS_TR = ESYS_TR.PLATFORM,
session1: ESYS_TR = ESYS_TR.PASSWORD,
session2: ESYS_TR = ESYS_TR.NONE,
Expand All @@ -5567,7 +5592,7 @@ def set_algorithm_set(
available.
Args:
algorithm_set (Union[List[int], int]): A TPM vendor-dependent value indicating the
algorithm_set (int): A TPM vendor-dependent value indicating the
algorithm set selection.
auth_handle (ESYS_TR): ESYS_TR.PLATFORM. Defaults to ESYS_TR.PLATFORM.
session1 (ESYS_TR): A session for securing the TPM command (optional). Defaults to ESYS_TR.PASSWORD.
Expand Down Expand Up @@ -6138,7 +6163,7 @@ def ac_get_capability(
_check_friendly_int(capability, "capability", TPM_AT)

if not isinstance(count, int):
raise TypeError(f"Expected count to be an int, got {type(prop)}")
raise TypeError(f"Expected count to be an int, got {type(count)}")

_check_handle_type(ac, "ac")
_check_handle_type(session1, "session1")
Expand Down Expand Up @@ -6228,7 +6253,7 @@ def ac_send(
ac_data_out,
)
)
return TPMS_AC_OUTPUT(_get_dptr(acDataOut, lib.Esys_Free))
return TPMS_AC_OUTPUT(_get_dptr(ac_data_out, lib.Esys_Free))

def policy_ac_send_select(
self,
Expand Down Expand Up @@ -7169,7 +7194,7 @@ def tr_serialize(self, esys_handle: ESYS_TR) -> bytes:
_chkrc(lib.Esys_TR_Serialize(self._ctx, esys_handle, buffer, buffer_size))
buffer_size = buffer_size[0]
buffer = _get_dptr(buffer, lib.Esys_Free)
return bytes(ffi.buffer(buffer, buffer_size))
return bytes(ffi.buffer(buffer, int(buffer_size)))

def tr_deserialize(self, buffer: bytes) -> ESYS_TR:
"""Deserialization of an ESYS_TR from a byte buffer.
Expand Down Expand Up @@ -7235,6 +7260,6 @@ def _fixup_hierarchy(hierarchy: ESYS_TR) -> Union[TPM2_RH, ESYS_TR]:
"Expected hierarchy to be one of ESYS_TR.NULL, ESYS_TR.PLATFORM, ESYS_TR.OWNER, ESYS_TR.ENDORSMENT"
)

hierarchy = fixup_map[hierarchy]

return hierarchy
return fixup_map[hierarchy]
else:
return hierarchy
27 changes: 14 additions & 13 deletions src/tpm2_pytss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TYPE_CHECKING,
Any,
SupportsIndex,
cast,
)

try:
Expand Down Expand Up @@ -422,8 +423,8 @@ class ESYS_TR(TPM_FRIENDLY_INT):
or a persistent key use :func:`tpm2_pytss.ESAPI.tr_from_tpmpublic`
"""

NONE = lib.ESYS_TR_NONE
PASSWORD = lib.ESYS_TR_PASSWORD
NONE = cast("ESYS_TR", lib.ESYS_TR_NONE)
PASSWORD = cast("ESYS_TR", lib.ESYS_TR_PASSWORD)
PCR0 = lib.ESYS_TR_PCR0
PCR1 = lib.ESYS_TR_PCR1
PCR2 = lib.ESYS_TR_PCR2
Expand Down Expand Up @@ -456,11 +457,11 @@ class ESYS_TR(TPM_FRIENDLY_INT):
PCR29 = lib.ESYS_TR_PCR29
PCR30 = lib.ESYS_TR_PCR30
PCR31 = lib.ESYS_TR_PCR31
OWNER = lib.ESYS_TR_RH_OWNER
NULL = lib.ESYS_TR_RH_NULL
LOCKOUT = lib.ESYS_TR_RH_LOCKOUT
ENDORSEMENT = lib.ESYS_TR_RH_ENDORSEMENT
PLATFORM = lib.ESYS_TR_RH_PLATFORM
OWNER = cast("ESYS_TR", lib.ESYS_TR_RH_OWNER)
NULL = cast("ESYS_TR", lib.ESYS_TR_RH_NULL)
LOCKOUT = cast("ESYS_TR", lib.ESYS_TR_RH_LOCKOUT)
ENDORSEMENT = cast("ESYS_TR", lib.ESYS_TR_RH_ENDORSEMENT)
PLATFORM = cast("ESYS_TR", lib.ESYS_TR_RH_PLATFORM)
PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV
RH_OWNER = lib.ESYS_TR_RH_OWNER
RH_NULL = lib.ESYS_TR_RH_NULL
Expand Down Expand Up @@ -546,21 +547,21 @@ def parts_to_blob(handle: "TPM2_HANDLE", public: "TPM2B_PUBLIC") -> bytes:
@TPM_FRIENDLY_INT._fix_const_type
class TPM2_RH(TPM_FRIENDLY_INT):
SRK = lib.TPM2_RH_SRK
OWNER = lib.TPM2_RH_OWNER
OWNER = cast("TPM2_RH", lib.TPM2_RH_OWNER)
REVOKE = lib.TPM2_RH_REVOKE
TRANSPORT = lib.TPM2_RH_TRANSPORT
OPERATOR = lib.TPM2_RH_OPERATOR
ADMIN = lib.TPM2_RH_ADMIN
EK = lib.TPM2_RH_EK
NULL = lib.TPM2_RH_NULL
NULL = cast("TPM2_RH", lib.TPM2_RH_NULL)
UNASSIGNED = lib.TPM2_RH_UNASSIGNED
try:
PW = lib.TPM2_RS_PW
except AttributeError:
PW = lib.TPM2_RH_PW
LOCKOUT = lib.TPM2_RH_LOCKOUT
ENDORSEMENT = lib.TPM2_RH_ENDORSEMENT
PLATFORM = lib.TPM2_RH_PLATFORM
ENDORSEMENT = cast("TPM2_RH", lib.TPM2_RH_ENDORSEMENT)
PLATFORM = cast("TPM2_RH", lib.TPM2_RH_PLATFORM)
PLATFORM_NV = lib.TPM2_RH_PLATFORM_NV


Expand Down Expand Up @@ -1131,8 +1132,8 @@ class TPM2_ST(TPM_FRIENDLY_INT):

@TPM_FRIENDLY_INT._fix_const_type
class TPM2_SU(TPM_FRIENDLY_INT):
CLEAR = lib.TPM2_SU_CLEAR
STATE = lib.TPM2_SU_STATE
CLEAR = cast("TPM2_SU", lib.TPM2_SU_CLEAR)
STATE = cast("TPM2_SU", lib.TPM2_SU_STATE)


@TPM_FRIENDLY_INT._fix_const_type
Expand Down
2 changes: 1 addition & 1 deletion src/tpm2_pytss/internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def _get_dptr(dptr: ffi.CData, free_func: Callable[[ffi.CData], None]) -> ffi.CD


def _check_friendly_int(
friendly: int, varname: int, clazz: Type["TPM_FRIENDLY_INT"]
friendly: int, varname: str, clazz: Type["TPM_FRIENDLY_INT"]
) -> None:

if not isinstance(friendly, int):
Expand Down
8 changes: 6 additions & 2 deletions src/tpm2_pytss/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class TPM2_HANDLE(int):
class TPM_OBJECT(object):
""" Abstract Base class for all TPM Objects. Not suitable for direct instantiation."""

_cdata: ffi.CData

def __init__(self, _cdata: Optional[Any] = None, **kwargs: Any):

# Rather than trying to mock the FFI interface, just avoid it and return
Expand Down Expand Up @@ -253,7 +255,9 @@ class TPM2B_SIMPLE_OBJECT(TPM_OBJECT):
""" Abstract Base class for all TPM2B Simple Objects. A Simple object contains only
a size and byte buffer fields. This is not suitable for direct instantiation."""

def __init__(self, _cdata: Optional[Union[ffi.CData, bytes]] = None, **kwargs: Any):
def __init__(
self, _cdata: Optional[Union[ffi.CData, bytes, str]] = None, **kwargs: Any
):

_cdata, kwargs = _fixup_cdata_kwargs(self, _cdata, kwargs)
_bytefield = type(self)._get_bytefield()
Expand Down Expand Up @@ -299,7 +303,7 @@ def __getattribute__(self, key: str) -> Any:
if key == _bytefield:
b = getattr(self._cdata, _bytefield)
rb = _ref_parent(b, self._cdata)
return memoryview(ffi.buffer(rb, self._cdata.size))
return memoryview(ffi.buffer(rb, int(self._cdata.size)))
return super().__getattribute__(key)

def __len__(self) -> int:
Expand Down

0 comments on commit 28bbc65

Please sign in to comment.