From 1b7f923c68c79f8b17ec5ee8d2d907dae2e2fe29 Mon Sep 17 00:00:00 2001 From: Giordon Stark Date: Sat, 21 Dec 2024 10:15:34 -0500 Subject: [PATCH] fix typehints in 3.13 --- pyproject.toml | 1 + src/pyhf/readxml.py | 2 +- src/pyhf/tensor/numpy_backend.py | 17 ++++++++++------- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 629c3ee83e..250987e44f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,6 +242,7 @@ warn_unused_configs = true strict = true enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] warn_unreachable = true +plugins = "numpy.typing.mypy_plugin" [[tool.mypy.overrides]] module = [ diff --git a/src/pyhf/readxml.py b/src/pyhf/readxml.py index 52612ae082..2ad5fe6642 100644 --- a/src/pyhf/readxml.py +++ b/src/pyhf/readxml.py @@ -216,7 +216,7 @@ def process_sample( modtag.attrib.get('HistoPath', ''), modtag.attrib['HistoName'], ) - staterr = np.multiply(extstat, data).tolist() + staterr = cast(list[float], np.multiply(extstat, data).tolist()) if not staterr: raise RuntimeError('cannot determine stat error.') modifier_staterror: StatError = { diff --git a/src/pyhf/tensor/numpy_backend.py b/src/pyhf/tensor/numpy_backend.py index 98d29a185d..423d54c152 100644 --- a/src/pyhf/tensor/numpy_backend.py +++ b/src/pyhf/tensor/numpy_backend.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Callable, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union from collections.abc import Mapping, Sequence import numpy as np @@ -27,7 +27,7 @@ log = logging.getLogger(__name__) -class _BasicPoisson: +class _BasicPoisson(Generic[T]): def __init__(self, rate: Tensor[T]): self.rate = rate @@ -39,7 +39,7 @@ def log_prob(self, value: NDArray[np.number[T]]) -> ArrayLike: return tensorlib.poisson_logpdf(value, self.rate) -class _BasicNormal: +class _BasicNormal(Generic[T]): def __init__(self, loc: Tensor[T], scale: Tensor[T]): self.loc = loc self.scale = scale @@ -199,9 +199,12 @@ def conditional( """ return true_callable() if predicate else false_callable() - def tolist(self, tensor_in: Tensor[T] | list[T]) -> list[T]: + def tolist( + self, tensor_in: Tensor[T] | list[T] + ) -> int | float | complex | list[T] | list[Any]: try: - return tensor_in.tolist() # type: ignore[union-attr,no-any-return] + # unused-ignore for [no-any-return] in python 3.9 + return tensor_in.tolist() # type: ignore[union-attr,no-any-return,unused-ignore] except AttributeError: if isinstance(tensor_in, list): return tensor_in @@ -551,7 +554,7 @@ def normal_cdf( """ return norm.cdf(x, loc=mu, scale=sigma) # type: ignore[no-any-return] - def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson: + def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson[T]: r""" The Poisson distribution with rate parameter :code:`rate`. @@ -572,7 +575,7 @@ def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson: """ return _BasicPoisson(rate) - def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal: + def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal[T]: r""" The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.