Skip to content

Commit

Permalink
typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-B committed Nov 29, 2023
1 parent 105dea7 commit c5beed5
Show file tree
Hide file tree
Showing 20 changed files with 41 additions and 40 deletions.
4 changes: 2 additions & 2 deletions spinnman/connections/udp_packet_connections/bmp_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import struct
from typing import Sequence, Tuple
from typing import Optional, Sequence, Tuple
from spinn_utilities.overrides import overrides
from .udp_connection import UDPConnection
from spinnman.constants import SCP_SCAMP_PORT
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_scp_data(self, scp_request: AbstractSCPRequest) -> bytes:
return _TWO_SKIP.pack() + scp_request.bytestring

@overrides(AbstractSCPConnection.receive_scp_response)
def receive_scp_response(self, timeout=1.0) -> Tuple[
def receive_scp_response(self, timeout: Optional[float] =1.0) -> Tuple[
SCPResult, int, bytes, int]:
data = self.receive(timeout)
result, sequence = _TWO_SHORTS.unpack_from(data, 10)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def send_eieio_message_to(
ip_address: str, port: int):
self.send_to(eieio_message.bytestring, (ip_address, port))

@overrides(Listenable.get_receive_method)
@overrides(Listenable.get_receive_method, return_narrowing=True)
def get_receive_method(self) -> Callable[ # type: ignore[override]
[], AbstractEIEIOMessage]:
return self.receive_eieio_message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def send_sdp_message(self, sdp_message: SDPMessage):
sdp_message.sdp_header.update_for_send(0, 0)
self.send(_TWO_SKIP.pack() + sdp_message.bytestring)

@overrides(Listenable.get_receive_method)
@overrides(Listenable.get_receive_method, return_narrowing=True)
def get_receive_method( # type: ignore[override]
self) -> Callable[[], SDPMessage]:
return self.receive_sdp_message
Expand Down
4 changes: 2 additions & 2 deletions spinnman/connections/udp_packet_connections/udp_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,6 @@ def __repr__(self) -> str:
self.local_ip_address, self.local_port,
self.remote_ip_address, self.remote_port)

@overrides(Listenable.get_receive_method)
@overrides(Listenable.get_receive_method, return_narrowing=True)
def get_receive_method(self) -> Callable[[], bytes]:
return self.receive
return self.receive
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from spinn_utilities.overrides import overrides
from spinnman.messages.eieio import AbstractEIEIOMessage
from spinnman.messages.eieio.command_messages import EIEIOCommandHeader


class EIEIOCommandMessage(AbstractEIEIOMessage):
Expand All @@ -24,7 +25,8 @@ class EIEIOCommandMessage(AbstractEIEIOMessage):
"_eieio_command_header",
"_offset")

def __init__(self, eieio_command_header, data=None, offset=0):
def __init__(self, eieio_command_header: EIEIOCommandHeader,
data=None, offset=0):
"""
:param EIEIOCommandHeader eieio_command_header:
The header of the message
Expand All @@ -39,8 +41,8 @@ def __init__(self, eieio_command_header, data=None, offset=0):
self._offset = offset

@property
@overrides(AbstractEIEIOMessage.eieio_header)
def eieio_header(self):
@overrides(AbstractEIEIOMessage.eieio_header, return_narrowing=True)
def eieio_header(self) -> EIEIOCommandHeader:
"""
:rtype: EIEIOCommandHeader
"""
Expand All @@ -60,7 +62,7 @@ def from_bytestring(command_header, data, offset):

@property
@overrides(AbstractEIEIOMessage.bytestring)
def bytestring(self):
def bytestring(self) -> bytes:
return self._eieio_command_header.bytestring

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def create(
data=data, offset=offset)

@property
@overrides(AbstractEIEIOMessage.eieio_header)
@overrides(AbstractEIEIOMessage.eieio_header, return_narrowing=True)
def eieio_header(self) -> EIEIODataHeader:
"""
:rtype: EIEIODataHeader
Expand Down
2 changes: 1 addition & 1 deletion spinnman/messages/scp/abstract_messages/scp_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def __str__(self):
return self.__repr__()

@abstractmethod
def get_scp_response(self) -> R:
def get_scp_response(self) -> AbstractSCPResponse:
"""
Get an SCP response message to be used to process any response
received.
Expand Down
2 changes: 1 addition & 1 deletion spinnman/messages/scp/impl/check_ok_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, operation: str, command):
self._command = command

@overrides(AbstractSCPResponse.read_data_bytestring)
def read_data_bytestring(self, data: bytes, offset: int) -> None:
def read_data_bytestring(self, data: bytes, offset: int):
result = self.scp_response_header.result
if result != SCPResult.RC_OK:
raise SpinnmanUnexpectedResponseCodeException(
Expand Down
2 changes: 1 addition & 1 deletion spinnman/messages/scp/impl/sdram_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, size: int):
self._base_address: Optional[int] = None

@overrides(AbstractSCPResponse.read_data_bytestring)
def read_data_bytestring(self, data, offset):
def read_data_bytestring(self, data: bytes, offset: int):
result = self.scp_response_header.result
if result != SCPResult.RC_OK:
raise SpinnmanUnexpectedResponseCodeException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@

#: Type of connections selected between.
#: :meta private:
Conn = TypeVar("Conn", SCAMPConnection, BMPConnection)
ConnectionType = TypeVar("ConnectionType", SCAMPConnection, BMPConnection)


class ConnectionSelector(Generic[Conn], metaclass=AbstractBase):
class ConnectionSelector(Generic[ConnectionType], metaclass=AbstractBase):
"""
A connection selector for multi-connection processes.
"""
__slots__ = ()

@abstractmethod
def get_next_connection(
self, message: AbstractSCPRequest) -> Conn:
self, message: AbstractSCPRequest) -> ConnectionType:
"""
Get the index of the next connection for the process from a list
of connections.
Expand Down
22 changes: 10 additions & 12 deletions spinnman/processes/fixed_connection_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Generic, TypeVar
from typing import Generic
from spinn_utilities.overrides import overrides
from .abstract_multi_connection_process_connection_selector import (
ConnectionSelector)
from spinnman.connections.udp_packet_connections import (
SCAMPConnection, BMPConnection)
ConnectionSelector, ConnectionType)
from spinnman.connections.udp_packet_connections import (SCAMPConnection)
from spinnman.messages.scp.abstract_messages import AbstractSCPRequest

#: Type of connections selected between.
#: :meta private:
Conn = TypeVar("Conn", SCAMPConnection, BMPConnection)


class FixedConnectionSelector(ConnectionSelector[Conn], Generic[Conn]):
class FixedConnectionSelector(
ConnectionSelector[ConnectionType], Generic[ConnectionType]):
"""
A connection selector that only uses a single connection.
"""
__slots__ = ("__connection", )

def __init__(self, connection: Conn):
def __init__(self, connection: ConnectionType):
"""
:param SCAMPConnection connection:
The connection to be used
"""
self.__connection: Conn = connection
self.__connection: ConnectionType = connection

@overrides(ConnectionSelector.get_next_connection)
def get_next_connection(self, message: Any) -> Conn:
def get_next_connection(
self, message: AbstractSCPRequest) -> ConnectionType:
return self.__connection
2 changes: 1 addition & 1 deletion spinnman/processes/get_exclude_cpu_info_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ def __init__(self, connection_selector: ConnectionSelector,
self.__states = states

@overrides(GetCPUInfoProcess._is_desired)
def _is_desired(self, cpu_info: CPUInfo):
def _is_desired(self, cpu_info: CPUInfo) -> bool:
return cpu_info.state not in self.__states
2 changes: 1 addition & 1 deletion spinnman/processes/get_include_cpu_info_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ def __init__(self, connection_selector: ConnectionSelector,
self.__states = states

@overrides(GetCPUInfoProcess._is_desired)
def _is_desired(self, cpu_info: CPUInfo):
def _is_desired(self, cpu_info: CPUInfo) -> bool:
return cpu_info.state in self.__states
2 changes: 1 addition & 1 deletion spinnman/processes/most_direct_connection_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, connections: List[SCAMPConnection]):
lead_connection = next(iter(connections))
self._lead_connection = lead_connection

@overrides(ConnectionSelector.get_next_connection)
@overrides(ConnectionSelector.get_next_connection, return_narrowing=True)
def get_next_connection(
self, message: AbstractSCPRequest) -> SCAMPConnection:
key = (message.sdp_header.destination_chip_x,
Expand Down
5 changes: 3 additions & 2 deletions spinnman/processes/round_robin_connection_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, List
from spinn_utilities.overrides import overrides
from spinnman.connections.udp_packet_connections import SCAMPConnection
from spinnman.messages.scp.abstract_messages import AbstractSCPRequest
from .abstract_multi_connection_process_connection_selector import (
ConnectionSelector)

Expand All @@ -34,8 +35,8 @@ def __init__(self, connections: List[SCAMPConnection]):
self._connections = connections
self._next_connection_index = 0

@overrides(ConnectionSelector.get_next_connection)
def get_next_connection(self, message: Any) -> SCAMPConnection:
@overrides(ConnectionSelector.get_next_connection, return_narrowing=True)
def get_next_connection(self, message: AbstractSCPRequest) -> SCAMPConnection:
index = self._next_connection_index
self._next_connection_index = (index + 1) % len(self._connections)
return self._connections[index]
2 changes: 1 addition & 1 deletion spinnman/spalloc/spalloc_boot_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def receive_boot_message(
data = self.receive(timeout)
return SpinnakerBootMessage.from_bytestring(data, 0)

@overrides(Listenable.get_receive_method)
@overrides(Listenable.get_receive_method, return_narrowing=True)
def get_receive_method( # type: ignore[override]
self) -> Callable[[], SpinnakerBootMessage]:
return self.receive_boot_message
2 changes: 1 addition & 1 deletion spinnman/spalloc/spalloc_eieio_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SpallocEIEIOConnection(
__slots__ = ()

@overrides(EIEIOConnection.send_eieio_message)
def send_eieio_message(self, eieio_message):
def send_eieio_message(self, eieio_message: AbstractEIEIOMessage):
# Not normally used, as packets need headers to go to SpiNNaker
self.send(eieio_message.bytestring)

Expand Down
2 changes: 1 addition & 1 deletion spinnman/spalloc/spalloc_eieio_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def receive_eieio_message(
return read_eieio_data_message(data, 0)

@overrides(SpallocProxiedConnection.send)
def send(self, data):
def send(self, data: bytes):
"""
.. note::
This class does not allow sending.
Expand Down
4 changes: 2 additions & 2 deletions spinnman/spalloc/spalloc_scp_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def send_sdp_message(self, sdp_message: SDPMessage):
self.send(_TWO_SKIP + sdp_message.bytestring)

@overrides(SCAMPConnection.receive_scp_response)
def receive_scp_response(
self, timeout=1.0) -> Tuple[SCPResult, int, bytes, int]:
def receive_scp_response(self, timeout: Optional[float] = 1.0) -> Tuple[
SCPResult, int, bytes, int]:
data = self.receive(timeout)
result, sequence = _TWO_SHORTS.unpack_from(data, 10)
return SCPResult(result), sequence, data, 2
Expand Down
2 changes: 1 addition & 1 deletion spinnman/transceiver/mockable_transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_cpu_infos(
raise NotImplementedError("Needs to be mocked")

@overrides(Transceiver.get_clock_drift)
def get_clock_drift(self, x, y):
def get_clock_drift(self, x: int, y: int) -> float:
raise NotImplementedError("Needs to be mocked")

@overrides(Transceiver.read_user)
Expand Down

0 comments on commit c5beed5

Please sign in to comment.