From c5beed52a2f33a666855f5543dd87102bea7f2df Mon Sep 17 00:00:00 2001 From: "Christian Y. Brenninkmeijer" Date: Wed, 29 Nov 2023 09:26:13 +0000 Subject: [PATCH] typing fixes --- .../udp_packet_connections/bmp_connection.py | 4 ++-- .../eieio_connection.py | 2 +- .../udp_packet_connections/sdp_connection.py | 2 +- .../udp_packet_connections/udp_connection.py | 4 ++-- .../command_messages/eieio_command_message.py | 10 +++++---- .../eieio/data_messages/eieio_data_message.py | 2 +- .../scp/abstract_messages/scp_request.py | 2 +- .../messages/scp/impl/check_ok_response.py | 2 +- spinnman/messages/scp/impl/sdram_alloc.py | 2 +- ..._connection_process_connection_selector.py | 6 ++--- .../processes/fixed_connection_selector.py | 22 +++++++++---------- .../processes/get_exclude_cpu_info_process.py | 2 +- .../processes/get_include_cpu_info_process.py | 2 +- .../most_direct_connection_selector.py | 2 +- .../round_robin_connection_selector.py | 5 +++-- spinnman/spalloc/spalloc_boot_connection.py | 2 +- spinnman/spalloc/spalloc_eieio_connection.py | 2 +- spinnman/spalloc/spalloc_eieio_listener.py | 2 +- spinnman/spalloc/spalloc_scp_connection.py | 4 ++-- spinnman/transceiver/mockable_transceiver.py | 2 +- 20 files changed, 41 insertions(+), 40 deletions(-) diff --git a/spinnman/connections/udp_packet_connections/bmp_connection.py b/spinnman/connections/udp_packet_connections/bmp_connection.py index 63e45379e..b9258f13a 100644 --- a/spinnman/connections/udp_packet_connections/bmp_connection.py +++ b/spinnman/connections/udp_packet_connections/bmp_connection.py @@ -13,7 +13,7 @@ # limitations under the License. import struct -from typing import Sequence, Tuple +from typing import Optional, Sequence, Tuple from spinn_utilities.overrides import overrides from .udp_connection import UDPConnection from spinnman.constants import SCP_SCAMP_PORT @@ -74,7 +74,7 @@ def get_scp_data(self, scp_request: AbstractSCPRequest) -> bytes: return _TWO_SKIP.pack() + scp_request.bytestring @overrides(AbstractSCPConnection.receive_scp_response) - def receive_scp_response(self, timeout=1.0) -> Tuple[ + def receive_scp_response(self, timeout: Optional[float] =1.0) -> Tuple[ SCPResult, int, bytes, int]: data = self.receive(timeout) result, sequence = _TWO_SHORTS.unpack_from(data, 10) diff --git a/spinnman/connections/udp_packet_connections/eieio_connection.py b/spinnman/connections/udp_packet_connections/eieio_connection.py index 3d9b403b7..cc9b29c73 100644 --- a/spinnman/connections/udp_packet_connections/eieio_connection.py +++ b/spinnman/connections/udp_packet_connections/eieio_connection.py @@ -74,7 +74,7 @@ def send_eieio_message_to( ip_address: str, port: int): self.send_to(eieio_message.bytestring, (ip_address, port)) - @overrides(Listenable.get_receive_method) + @overrides(Listenable.get_receive_method, return_narrowing=True) def get_receive_method(self) -> Callable[ # type: ignore[override] [], AbstractEIEIOMessage]: return self.receive_eieio_message diff --git a/spinnman/connections/udp_packet_connections/sdp_connection.py b/spinnman/connections/udp_packet_connections/sdp_connection.py index 37915eccc..372f5daaf 100644 --- a/spinnman/connections/udp_packet_connections/sdp_connection.py +++ b/spinnman/connections/udp_packet_connections/sdp_connection.py @@ -100,7 +100,7 @@ def send_sdp_message(self, sdp_message: SDPMessage): sdp_message.sdp_header.update_for_send(0, 0) self.send(_TWO_SKIP.pack() + sdp_message.bytestring) - @overrides(Listenable.get_receive_method) + @overrides(Listenable.get_receive_method, return_narrowing=True) def get_receive_method( # type: ignore[override] self) -> Callable[[], SDPMessage]: return self.receive_sdp_message diff --git a/spinnman/connections/udp_packet_connections/udp_connection.py b/spinnman/connections/udp_packet_connections/udp_connection.py index 92d10f8c1..9e6dfd9b9 100644 --- a/spinnman/connections/udp_packet_connections/udp_connection.py +++ b/spinnman/connections/udp_packet_connections/udp_connection.py @@ -269,6 +269,6 @@ def __repr__(self) -> str: self.local_ip_address, self.local_port, self.remote_ip_address, self.remote_port) - @overrides(Listenable.get_receive_method) + @overrides(Listenable.get_receive_method, return_narrowing=True) def get_receive_method(self) -> Callable[[], bytes]: - return self.receive + return self.receive diff --git a/spinnman/messages/eieio/command_messages/eieio_command_message.py b/spinnman/messages/eieio/command_messages/eieio_command_message.py index 906af845c..ba75a1790 100644 --- a/spinnman/messages/eieio/command_messages/eieio_command_message.py +++ b/spinnman/messages/eieio/command_messages/eieio_command_message.py @@ -13,6 +13,7 @@ # limitations under the License. from spinn_utilities.overrides import overrides from spinnman.messages.eieio import AbstractEIEIOMessage +from spinnman.messages.eieio.command_messages import EIEIOCommandHeader class EIEIOCommandMessage(AbstractEIEIOMessage): @@ -24,7 +25,8 @@ class EIEIOCommandMessage(AbstractEIEIOMessage): "_eieio_command_header", "_offset") - def __init__(self, eieio_command_header, data=None, offset=0): + def __init__(self, eieio_command_header: EIEIOCommandHeader, + data=None, offset=0): """ :param EIEIOCommandHeader eieio_command_header: The header of the message @@ -39,8 +41,8 @@ def __init__(self, eieio_command_header, data=None, offset=0): self._offset = offset @property - @overrides(AbstractEIEIOMessage.eieio_header) - def eieio_header(self): + @overrides(AbstractEIEIOMessage.eieio_header, return_narrowing=True) + def eieio_header(self) -> EIEIOCommandHeader: """ :rtype: EIEIOCommandHeader """ @@ -60,7 +62,7 @@ def from_bytestring(command_header, data, offset): @property @overrides(AbstractEIEIOMessage.bytestring) - def bytestring(self): + def bytestring(self) -> bytes: return self._eieio_command_header.bytestring @staticmethod diff --git a/spinnman/messages/eieio/data_messages/eieio_data_message.py b/spinnman/messages/eieio/data_messages/eieio_data_message.py index 306b83b12..c17485873 100644 --- a/spinnman/messages/eieio/data_messages/eieio_data_message.py +++ b/spinnman/messages/eieio/data_messages/eieio_data_message.py @@ -89,7 +89,7 @@ def create( data=data, offset=offset) @property - @overrides(AbstractEIEIOMessage.eieio_header) + @overrides(AbstractEIEIOMessage.eieio_header, return_narrowing=True) def eieio_header(self) -> EIEIODataHeader: """ :rtype: EIEIODataHeader diff --git a/spinnman/messages/scp/abstract_messages/scp_request.py b/spinnman/messages/scp/abstract_messages/scp_request.py index 4fbae6a07..48eec8fb2 100644 --- a/spinnman/messages/scp/abstract_messages/scp_request.py +++ b/spinnman/messages/scp/abstract_messages/scp_request.py @@ -146,7 +146,7 @@ def __str__(self): return self.__repr__() @abstractmethod - def get_scp_response(self) -> R: + def get_scp_response(self) -> AbstractSCPResponse: """ Get an SCP response message to be used to process any response received. diff --git a/spinnman/messages/scp/impl/check_ok_response.py b/spinnman/messages/scp/impl/check_ok_response.py index 662ea9d9b..4f4a140b0 100644 --- a/spinnman/messages/scp/impl/check_ok_response.py +++ b/spinnman/messages/scp/impl/check_ok_response.py @@ -37,7 +37,7 @@ def __init__(self, operation: str, command): self._command = command @overrides(AbstractSCPResponse.read_data_bytestring) - def read_data_bytestring(self, data: bytes, offset: int) -> None: + def read_data_bytestring(self, data: bytes, offset: int): result = self.scp_response_header.result if result != SCPResult.RC_OK: raise SpinnmanUnexpectedResponseCodeException( diff --git a/spinnman/messages/scp/impl/sdram_alloc.py b/spinnman/messages/scp/impl/sdram_alloc.py index 30821d2f5..ad904146b 100644 --- a/spinnman/messages/scp/impl/sdram_alloc.py +++ b/spinnman/messages/scp/impl/sdram_alloc.py @@ -42,7 +42,7 @@ def __init__(self, size: int): self._base_address: Optional[int] = None @overrides(AbstractSCPResponse.read_data_bytestring) - def read_data_bytestring(self, data, offset): + def read_data_bytestring(self, data: bytes, offset: int): result = self.scp_response_header.result if result != SCPResult.RC_OK: raise SpinnmanUnexpectedResponseCodeException( diff --git a/spinnman/processes/abstract_multi_connection_process_connection_selector.py b/spinnman/processes/abstract_multi_connection_process_connection_selector.py index 4e2433d88..7d4b50460 100644 --- a/spinnman/processes/abstract_multi_connection_process_connection_selector.py +++ b/spinnman/processes/abstract_multi_connection_process_connection_selector.py @@ -21,10 +21,10 @@ #: Type of connections selected between. #: :meta private: -Conn = TypeVar("Conn", SCAMPConnection, BMPConnection) +ConnectionType = TypeVar("ConnectionType", SCAMPConnection, BMPConnection) -class ConnectionSelector(Generic[Conn], metaclass=AbstractBase): +class ConnectionSelector(Generic[ConnectionType], metaclass=AbstractBase): """ A connection selector for multi-connection processes. """ @@ -32,7 +32,7 @@ class ConnectionSelector(Generic[Conn], metaclass=AbstractBase): @abstractmethod def get_next_connection( - self, message: AbstractSCPRequest) -> Conn: + self, message: AbstractSCPRequest) -> ConnectionType: """ Get the index of the next connection for the process from a list of connections. diff --git a/spinnman/processes/fixed_connection_selector.py b/spinnman/processes/fixed_connection_selector.py index e8972a76f..d7eff5721 100644 --- a/spinnman/processes/fixed_connection_selector.py +++ b/spinnman/processes/fixed_connection_selector.py @@ -11,31 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Generic, TypeVar +from typing import Generic from spinn_utilities.overrides import overrides from .abstract_multi_connection_process_connection_selector import ( - ConnectionSelector) -from spinnman.connections.udp_packet_connections import ( - SCAMPConnection, BMPConnection) + ConnectionSelector, ConnectionType) +from spinnman.connections.udp_packet_connections import (SCAMPConnection) +from spinnman.messages.scp.abstract_messages import AbstractSCPRequest -#: Type of connections selected between. -#: :meta private: -Conn = TypeVar("Conn", SCAMPConnection, BMPConnection) - -class FixedConnectionSelector(ConnectionSelector[Conn], Generic[Conn]): +class FixedConnectionSelector( + ConnectionSelector[ConnectionType], Generic[ConnectionType]): """ A connection selector that only uses a single connection. """ __slots__ = ("__connection", ) - def __init__(self, connection: Conn): + def __init__(self, connection: ConnectionType): """ :param SCAMPConnection connection: The connection to be used """ - self.__connection: Conn = connection + self.__connection: ConnectionType = connection @overrides(ConnectionSelector.get_next_connection) - def get_next_connection(self, message: Any) -> Conn: + def get_next_connection( + self, message: AbstractSCPRequest) -> ConnectionType: return self.__connection diff --git a/spinnman/processes/get_exclude_cpu_info_process.py b/spinnman/processes/get_exclude_cpu_info_process.py index 7a5bbbbca..ca5bca04e 100644 --- a/spinnman/processes/get_exclude_cpu_info_process.py +++ b/spinnman/processes/get_exclude_cpu_info_process.py @@ -30,5 +30,5 @@ def __init__(self, connection_selector: ConnectionSelector, self.__states = states @overrides(GetCPUInfoProcess._is_desired) - def _is_desired(self, cpu_info: CPUInfo): + def _is_desired(self, cpu_info: CPUInfo) -> bool: return cpu_info.state not in self.__states diff --git a/spinnman/processes/get_include_cpu_info_process.py b/spinnman/processes/get_include_cpu_info_process.py index 2fd1154c4..749549983 100644 --- a/spinnman/processes/get_include_cpu_info_process.py +++ b/spinnman/processes/get_include_cpu_info_process.py @@ -35,5 +35,5 @@ def __init__(self, connection_selector: ConnectionSelector, self.__states = states @overrides(GetCPUInfoProcess._is_desired) - def _is_desired(self, cpu_info: CPUInfo): + def _is_desired(self, cpu_info: CPUInfo) -> bool: return cpu_info.state in self.__states diff --git a/spinnman/processes/most_direct_connection_selector.py b/spinnman/processes/most_direct_connection_selector.py index 84dc334df..aa194e18d 100644 --- a/spinnman/processes/most_direct_connection_selector.py +++ b/spinnman/processes/most_direct_connection_selector.py @@ -44,7 +44,7 @@ def __init__(self, connections: List[SCAMPConnection]): lead_connection = next(iter(connections)) self._lead_connection = lead_connection - @overrides(ConnectionSelector.get_next_connection) + @overrides(ConnectionSelector.get_next_connection, return_narrowing=True) def get_next_connection( self, message: AbstractSCPRequest) -> SCAMPConnection: key = (message.sdp_header.destination_chip_x, diff --git a/spinnman/processes/round_robin_connection_selector.py b/spinnman/processes/round_robin_connection_selector.py index c2021969e..48d5c3b1d 100644 --- a/spinnman/processes/round_robin_connection_selector.py +++ b/spinnman/processes/round_robin_connection_selector.py @@ -14,6 +14,7 @@ from typing import Any, List from spinn_utilities.overrides import overrides from spinnman.connections.udp_packet_connections import SCAMPConnection +from spinnman.messages.scp.abstract_messages import AbstractSCPRequest from .abstract_multi_connection_process_connection_selector import ( ConnectionSelector) @@ -34,8 +35,8 @@ def __init__(self, connections: List[SCAMPConnection]): self._connections = connections self._next_connection_index = 0 - @overrides(ConnectionSelector.get_next_connection) - def get_next_connection(self, message: Any) -> SCAMPConnection: + @overrides(ConnectionSelector.get_next_connection, return_narrowing=True) + def get_next_connection(self, message: AbstractSCPRequest) -> SCAMPConnection: index = self._next_connection_index self._next_connection_index = (index + 1) % len(self._connections) return self._connections[index] diff --git a/spinnman/spalloc/spalloc_boot_connection.py b/spinnman/spalloc/spalloc_boot_connection.py index 0a2c30a2c..6fedcbcf3 100644 --- a/spinnman/spalloc/spalloc_boot_connection.py +++ b/spinnman/spalloc/spalloc_boot_connection.py @@ -57,7 +57,7 @@ def receive_boot_message( data = self.receive(timeout) return SpinnakerBootMessage.from_bytestring(data, 0) - @overrides(Listenable.get_receive_method) + @overrides(Listenable.get_receive_method, return_narrowing=True) def get_receive_method( # type: ignore[override] self) -> Callable[[], SpinnakerBootMessage]: return self.receive_boot_message diff --git a/spinnman/spalloc/spalloc_eieio_connection.py b/spinnman/spalloc/spalloc_eieio_connection.py index 613c7c7db..1a8c1c6b9 100644 --- a/spinnman/spalloc/spalloc_eieio_connection.py +++ b/spinnman/spalloc/spalloc_eieio_connection.py @@ -45,7 +45,7 @@ class SpallocEIEIOConnection( __slots__ = () @overrides(EIEIOConnection.send_eieio_message) - def send_eieio_message(self, eieio_message): + def send_eieio_message(self, eieio_message: AbstractEIEIOMessage): # Not normally used, as packets need headers to go to SpiNNaker self.send(eieio_message.bytestring) diff --git a/spinnman/spalloc/spalloc_eieio_listener.py b/spinnman/spalloc/spalloc_eieio_listener.py index 435fba604..9e7e2b94a 100644 --- a/spinnman/spalloc/spalloc_eieio_listener.py +++ b/spinnman/spalloc/spalloc_eieio_listener.py @@ -57,7 +57,7 @@ def receive_eieio_message( return read_eieio_data_message(data, 0) @overrides(SpallocProxiedConnection.send) - def send(self, data): + def send(self, data: bytes): """ .. note:: This class does not allow sending. diff --git a/spinnman/spalloc/spalloc_scp_connection.py b/spinnman/spalloc/spalloc_scp_connection.py index a88c140fc..56a48a8ad 100644 --- a/spinnman/spalloc/spalloc_scp_connection.py +++ b/spinnman/spalloc/spalloc_scp_connection.py @@ -55,8 +55,8 @@ def send_sdp_message(self, sdp_message: SDPMessage): self.send(_TWO_SKIP + sdp_message.bytestring) @overrides(SCAMPConnection.receive_scp_response) - def receive_scp_response( - self, timeout=1.0) -> Tuple[SCPResult, int, bytes, int]: + def receive_scp_response(self, timeout: Optional[float] = 1.0) -> Tuple[ + SCPResult, int, bytes, int]: data = self.receive(timeout) result, sequence = _TWO_SHORTS.unpack_from(data, 10) return SCPResult(result), sequence, data, 2 diff --git a/spinnman/transceiver/mockable_transceiver.py b/spinnman/transceiver/mockable_transceiver.py index 558cc7505..14028f7ba 100644 --- a/spinnman/transceiver/mockable_transceiver.py +++ b/spinnman/transceiver/mockable_transceiver.py @@ -80,7 +80,7 @@ def get_cpu_infos( raise NotImplementedError("Needs to be mocked") @overrides(Transceiver.get_clock_drift) - def get_clock_drift(self, x, y): + def get_clock_drift(self, x: int, y: int) -> float: raise NotImplementedError("Needs to be mocked") @overrides(Transceiver.read_user)