diff --git a/spinn_utilities/ranged/abstract_dict.py b/spinn_utilities/ranged/abstract_dict.py index b6ffd4fc..2e4e4743 100644 --- a/spinn_utilities/ranged/abstract_dict.py +++ b/spinn_utilities/ranged/abstract_dict.py @@ -22,7 +22,7 @@ # Can't be Iterable[str] or Sequence[str] because that includes str itself _StrSeq: TypeAlias = Union[ MutableSequence[str], Tuple[str, ...], FrozenSet[str], Set[str]] -_Keys: TypeAlias = Union[None, str, _StrSeq] +_Keys: TypeAlias = Optional[Union[str, _StrSeq]] class AbstractDict(Generic[T], metaclass=AbstractBase): @@ -117,7 +117,7 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @abstractmethod - def iter_all_values(self, key, update_safe=False): + def iter_all_values(self, key: _Keys, update_safe: bool = False): """ Iterates over the value(s) for all IDs covered by this view. There will be one yield for each ID even if values are repeated. @@ -181,7 +181,7 @@ def iter_ranges(self, key: Optional[_StrSeq]) -> Iterator[Tuple[ ... @abstractmethod - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None): """ Iterates over the ranges(s) for all IDs covered by this view. There will be one yield for each range which may cover one or diff --git a/spinn_utilities/ranged/abstract_list.py b/spinn_utilities/ranged/abstract_list.py index 3f0ab846..a18ca3ef 100644 --- a/spinn_utilities/ranged/abstract_list.py +++ b/spinn_utilities/ranged/abstract_list.py @@ -645,20 +645,21 @@ def range_based(self): return self._a_list.range_based() @overrides(AbstractList.get_value_by_id) - def get_value_by_id(self, the_id): + def get_value_by_id(self, the_id: int) -> T: return self._operation(self._a_list.get_value_by_id(the_id)) @overrides(AbstractList.get_single_value_by_slice) - def get_single_value_by_slice(self, slice_start, slice_stop): + def get_single_value_by_slice( + self, slice_start: int, slice_stop: int) -> T: return self._operation(self._a_list.get_single_value_by_slice( slice_start, slice_stop)) @overrides(AbstractList.get_single_value_by_ids) - def get_single_value_by_ids(self, ids): + def get_single_value_by_ids(self, ids: IdsType) -> T: return self._operation(self._a_list.get_single_value_by_ids(ids)) @overrides(AbstractList.iter_ranges) - def iter_ranges(self): + def iter_ranges(self) -> Iterator[Tuple[int, int, T]]: for (start, stop, value) in self._a_list.iter_ranges(): yield (start, stop, self._operation(value)) @@ -667,7 +668,9 @@ def get_default(self): return self._operation(self._a_list.get_default()) @overrides(AbstractList.iter_ranges_by_slice) - def iter_ranges_by_slice(self, slice_start, slice_stop): + def iter_ranges_by_slice( + self, slice_start: int, slice_stop: int) -> Iterator[ + Tuple[int, int, T]]: for (start, stop, value) in \ self._a_list.iter_ranges_by_slice(slice_start, slice_stop): yield (start, stop, self._operation(value)) @@ -704,25 +707,26 @@ def range_based(self): return self._left.range_based() and self._right.range_based() @overrides(AbstractList.get_value_by_id) - def get_value_by_id(self, the_id): + def get_value_by_id(self, the_id: int) -> T: return self._operation( self._left.get_value_by_id(the_id), self._right.get_value_by_id(the_id)) @overrides(AbstractList.get_single_value_by_slice) - def get_single_value_by_slice(self, slice_start, slice_stop): + def get_single_value_by_slice( + self, slice_start: int, slice_stop: int) -> T: return self._operation( self._left.get_single_value_by_slice(slice_start, slice_stop), self._right.get_single_value_by_slice(slice_start, slice_stop)) @overrides(AbstractList.get_single_value_by_ids) - def get_single_value_by_ids(self, ids): + def get_single_value_by_ids(self, ids: IdsType) -> T: return self._operation( self._left.get_single_value_by_ids(ids), self._right.get_single_value_by_ids(ids)) @overrides(AbstractList.iter_by_slice) - def iter_by_slice(self, slice_start, slice_stop): + def iter_by_slice(self, slice_start: int, slice_stop: int) -> Iterator[T]: slice_start, slice_stop = self._check_slice_in_range( slice_start, slice_stop) if self._left.range_based(): @@ -772,7 +776,9 @@ def iter_ranges(self): return self._merge_ranges(left_iter, right_iter) @overrides(AbstractList.iter_ranges_by_slice) - def iter_ranges_by_slice(self, slice_start, slice_stop): + def iter_ranges_by_slice( + self, slice_start: int, slice_stop: int) -> Iterator[ + Tuple[int, int, T]]: left_iter = self._left.iter_ranges_by_slice(slice_start, slice_stop) right_iter = self._right.iter_ranges_by_slice(slice_start, slice_stop) return self._merge_ranges(left_iter, right_iter) diff --git a/spinn_utilities/ranged/ids_view.py b/spinn_utilities/ranged/ids_view.py index eb03b020..80ba3405 100644 --- a/spinn_utilities/ranged/ids_view.py +++ b/spinn_utilities/ranged/ids_view.py @@ -65,7 +65,8 @@ def get_value(self, key: _Keys): for k in key} @overrides(AbstractDict.set_value) - def set_value(self, key: str, value: T, use_list_as_value=False): + def set_value( + self, key: str, value: T, use_list_as_value: bool = False): ranged_list = self._range_dict.get_list(key) for _id in self._ids: ranged_list.set_value_by_id(the_id=_id, value=value) @@ -85,7 +86,7 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @overrides(AbstractDict.iter_all_values) - def iter_all_values(self, key: _Keys, update_safe=False): + def iter_all_values(self, key: _Keys, update_safe: bool = False): if isinstance(key, str): yield from self._range_dict.iter_values_by_ids( ids=self._ids, key=key, update_safe=update_safe) @@ -103,5 +104,5 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[Tuple[ ... @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None): return self._range_dict.iter_ranges_by_ids(key=key, ids=self._ids) diff --git a/spinn_utilities/ranged/range_dictionary.py b/spinn_utilities/ranged/range_dictionary.py index f4eb98c3..195d28bc 100644 --- a/spinn_utilities/ranged/range_dictionary.py +++ b/spinn_utilities/ranged/range_dictionary.py @@ -28,6 +28,8 @@ from .abstract_view import AbstractView _KeyType: TypeAlias = Union[int, slice, Iterable[int]] +_Keys: TypeAlias = Union[None, str, _StrSeq] + _Range: TypeAlias = Tuple[int, int, T] _SimpleRangeIter: TypeAlias = Iterator[_Range] _CompoundRangeIter: TypeAlias = Iterator[Tuple[int, int, Dict[str, T]]] @@ -246,7 +248,7 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @overrides(AbstractDict.iter_all_values, extend_defaults=True) - def iter_all_values(self, key=None, update_safe: bool = False): + def iter_all_values(self, key: _Keys, update_safe=False): if isinstance(key, str): if update_safe: return self._value_lists[key].iter() diff --git a/spinn_utilities/ranged/single_view.py b/spinn_utilities/ranged/single_view.py index 40f42118..b46d8767 100644 --- a/spinn_utilities/ranged/single_view.py +++ b/spinn_utilities/ranged/single_view.py @@ -74,7 +74,7 @@ def iter_all_values( ... @overrides(AbstractDict.iter_all_values) - def iter_all_values(self, key, update_safe=False): + def iter_all_values(self, key: _Keys, update_safe: bool = False): if isinstance(key, str): yield self._range_dict.get_list(key).get_value_by_id( the_id=self._id) @@ -82,7 +82,7 @@ def iter_all_values(self, key, update_safe=False): yield self._range_dict.get_values_by_id(key=key, the_id=self._id) @overrides(AbstractDict.set_value) - def set_value(self, key: str, value: T, use_list_as_value=False): + def set_value(self, key: str, value: T, use_list_as_value: bool = False): return self._range_dict.get_list(key).set_value_by_id( value=value, the_id=self._id) @@ -96,5 +96,5 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[ ... @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None): return self._range_dict.iter_ranges_by_id(key=key, the_id=self._id) diff --git a/spinn_utilities/ranged/slice_view.py b/spinn_utilities/ranged/slice_view.py index 4ed4519a..18a6e385 100644 --- a/spinn_utilities/ranged/slice_view.py +++ b/spinn_utilities/ranged/slice_view.py @@ -82,7 +82,7 @@ def iter_all_values( ... @overrides(AbstractDict.iter_all_values, extend_defaults=True) - def iter_all_values(self, key=None, update_safe=False): + def iter_all_values(self, key: _Keys = None, update_safe: bool = False): if isinstance(key, str): if update_safe: return self.update_safe_iter_all_values(key) @@ -94,7 +94,7 @@ def iter_all_values(self, key=None, update_safe=False): @overrides(AbstractDict.set_value) def set_value( - self, key: str, value: _ValueType, use_list_as_value=False): + self, key: str, value: _ValueType, use_list_as_value: bool = False): self._range_dict.get_list(key).set_value_by_slice( slice_start=self._start, slice_stop=self._stop, value=value, use_list_as_value=use_list_as_value) @@ -109,6 +109,6 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[ ... @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None): return self._range_dict.iter_ranges_by_slice( key=key, slice_start=self._start, slice_stop=self._stop) diff --git a/unittests/abstract_base/abstract_has_constraints.py b/unittests/abstract_base/abstract_has_constraints.py index fb960db9..d67ed6de 100644 --- a/unittests/abstract_base/abstract_has_constraints.py +++ b/unittests/abstract_base/abstract_has_constraints.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.abstract_base import ( AbstractBase, abstractproperty, abstractmethod) @@ -23,7 +24,7 @@ class AbstractHasConstraints(object, metaclass=AbstractBase): __slots__ = () @abstractmethod - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): """ Add a new constraint to the collection of constraints :param constraint: constraint to add @@ -33,7 +34,8 @@ def add_constraint(self, constraint): If the constraint is not valid """ - @abstractproperty + @property + @abstractmethod def constraints(self): """ An iterable of constraints diff --git a/unittests/abstract_base/grandparent.py b/unittests/abstract_base/grandparent.py index b142ebfb..99cc8392 100644 --- a/unittests/abstract_base/grandparent.py +++ b/unittests/abstract_base/grandparent.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.overrides import overrides from .abstract_grandparent import AbstractGrandParent from .abstract_has_constraints import AbstractHasConstraints @@ -26,7 +27,7 @@ def set_label(selfself, label): pass @overrides(AbstractHasConstraints.add_constraint) - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): raise NotImplementedError("We set our own constrainst") @overrides(AbstractHasConstraints.constraints) diff --git a/unittests/abstract_base/no_label.py b/unittests/abstract_base/no_label.py index 4b5d99ee..3ac3bb23 100644 --- a/unittests/abstract_base/no_label.py +++ b/unittests/abstract_base/no_label.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.overrides import overrides from .abstract_grandparent import AbstractGrandParent from .abstract_has_constraints import AbstractHasConstraints @@ -23,7 +24,7 @@ def set_label(selfself, label): pass @overrides(AbstractHasConstraints.add_constraint) - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): raise NotImplementedError("We set our own constraints") @overrides(AbstractHasConstraints.constraints) diff --git a/unittests/abstract_base/unchecked_bad_param.py b/unittests/abstract_base/unchecked_bad_param.py index b7289d77..fb11553a 100644 --- a/unittests/abstract_base/unchecked_bad_param.py +++ b/unittests/abstract_base/unchecked_bad_param.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.overrides import overrides from .abstract_grandparent import AbstractGrandParent from .abstract_has_constraints import AbstractHasConstraints @@ -25,7 +26,7 @@ def set_label(selfself, not_label): pass @overrides(AbstractHasConstraints.add_constraint) - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): raise NotImplementedError("We set our own constrainst") @overrides(AbstractHasConstraints.constraints) diff --git a/unittests/test_log.py b/unittests/test_log.py index 0e896399..67c04368 100644 --- a/unittests/test_log.py +++ b/unittests/test_log.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from datetime import datetime import logging +from typing import List, Optional, Tuple from spinn_utilities.log import ( _BraceMessage, ConfiguredFilter, ConfiguredFormatter, FormatAdapter, LogLevelTooHighException) @@ -54,13 +56,15 @@ def __init__(self): self.data = [] @overrides(LogStore.store_log) - def store_log(self, level, message, timestamp=None): + def store_log(self, level: int, message: str, + timestamp: Optional[datetime] = None): if level == logging.CRITICAL: 1/0 self.data.append((level, message)) @overrides(LogStore.retreive_log_messages) - def retreive_log_messages(self, min_level=0): + def retreive_log_messages( + self, min_level: int = 0) -> List[Tuple[int, str]]: result = [] for (level, message) in self.data: if level >= min_level: