Skip to content

Commit

Permalink
fix typehints in 3.13
Browse files Browse the repository at this point in the history
  • Loading branch information
kratsg committed Dec 21, 2024
1 parent a8d518c commit 1b7f923
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ warn_unused_configs = true
strict = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_unreachable = true
plugins = "numpy.typing.mypy_plugin"

[[tool.mypy.overrides]]
module = [
Expand Down
2 changes: 1 addition & 1 deletion src/pyhf/readxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def process_sample(
modtag.attrib.get('HistoPath', ''),
modtag.attrib['HistoName'],
)
staterr = np.multiply(extstat, data).tolist()
staterr = cast(list[float], np.multiply(extstat, data).tolist())
if not staterr:
raise RuntimeError('cannot determine stat error.')
modifier_staterror: StatError = {
Expand Down
17 changes: 10 additions & 7 deletions src/pyhf/tensor/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable, Generic, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
from collections.abc import Mapping, Sequence

import numpy as np
Expand All @@ -27,7 +27,7 @@
log = logging.getLogger(__name__)


class _BasicPoisson:
class _BasicPoisson(Generic[T]):
def __init__(self, rate: Tensor[T]):
self.rate = rate

Expand All @@ -39,7 +39,7 @@ def log_prob(self, value: NDArray[np.number[T]]) -> ArrayLike:
return tensorlib.poisson_logpdf(value, self.rate)


class _BasicNormal:
class _BasicNormal(Generic[T]):
def __init__(self, loc: Tensor[T], scale: Tensor[T]):
self.loc = loc
self.scale = scale
Expand Down Expand Up @@ -199,9 +199,12 @@ def conditional(
"""
return true_callable() if predicate else false_callable()

def tolist(self, tensor_in: Tensor[T] | list[T]) -> list[T]:
def tolist(
self, tensor_in: Tensor[T] | list[T]
) -> int | float | complex | list[T] | list[Any]:
try:
return tensor_in.tolist() # type: ignore[union-attr,no-any-return]
# unused-ignore for [no-any-return] in python 3.9
return tensor_in.tolist() # type: ignore[union-attr,no-any-return,unused-ignore]
except AttributeError:
if isinstance(tensor_in, list):
return tensor_in
Expand Down Expand Up @@ -551,7 +554,7 @@ def normal_cdf(
"""
return norm.cdf(x, loc=mu, scale=sigma) # type: ignore[no-any-return]

def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson:
def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson[T]:
r"""
The Poisson distribution with rate parameter :code:`rate`.
Expand All @@ -572,7 +575,7 @@ def poisson_dist(self, rate: Tensor[T]) -> _BasicPoisson:
"""
return _BasicPoisson(rate)

def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal:
def normal_dist(self, mu: Tensor[T], sigma: Tensor[T]) -> _BasicNormal[T]:
r"""
The Normal distribution with mean :code:`mu` and standard deviation :code:`sigma`.
Expand Down

0 comments on commit 1b7f923

Please sign in to comment.