Skip to content

Commit

Permalink
Merge pull request #1514 from SpiNNakerManchester/t_fec
Browse files Browse the repository at this point in the history
Typing for FEC
  • Loading branch information
Christian-B authored Dec 30, 2024
2 parents 60ab4d1 + 9cc8237 commit e58a9df
Show file tree
Hide file tree
Showing 15 changed files with 107 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,41 +105,43 @@ def __init_callback_wrapper(
machine_timestep_ms)

@overrides(LiveEventConnection.add_start_callback)
def add_start_callback(self, label: str, start_callback: _Callback):
def add_start_callback(
self, label: str, start_callback: _Callback) -> None:
super().add_start_callback(
self.__control_label(label), functools.partial(
self.__callback_wrapper, start_callback))

@overrides(LiveEventConnection.add_start_resume_callback)
def add_start_resume_callback(
self, label: str, start_resume_callback: _Callback):
self, label: str, start_resume_callback: _Callback) -> None:
super().add_start_resume_callback(
self.__control_label(label), functools.partial(
self.__callback_wrapper, start_resume_callback))

@overrides(LiveEventConnection.add_init_callback)
def add_init_callback(self, label: str, init_callback: _InitCallback):
def add_init_callback(
self, label: str, init_callback: _InitCallback) -> None:
super().add_init_callback(
self.__control_label(label), functools.partial(
self.__init_callback_wrapper, init_callback))

@overrides(LiveEventConnection.add_receive_callback)
def add_receive_callback(
self, label: str, live_event_callback: _RcvTimeCallback,
translate_key: bool = True):
translate_key: bool = True) -> None:
raise ConfigurationException(
"SpynnakerPoissonControlPopulation can't receive data")

@overrides(LiveEventConnection.add_receive_no_time_callback)
def add_receive_no_time_callback(
self, label: str, live_event_callback: _RcvCallback,
translate_key: bool = True):
translate_key: bool = True) -> None:
raise ConfigurationException(
"SpynnakerPoissonControlPopulation can't receive data")

@overrides(LiveEventConnection.add_pause_stop_callback)
def add_pause_stop_callback(
self, label: str, pause_stop_callback: _Callback):
self, label: str, pause_stop_callback: _Callback) -> None:
super().add_pause_stop_callback(
self.__control_label(label), functools.partial(
self.__callback_wrapper, pause_stop_callback))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def parse_extra_provenance_items(
"or decrease the number of neurons per core.")

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
# reserve regions
self.reserve_memory_regions(spec)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def redundant_packet_count_report() -> None:

def _create_views() -> None:
with ProvenanceWriter() as db:
db.execute(REDUNDANCY_BY_CORE)
db.execute(REDUNDANCY_SUMMARY)
db.cursor().execute(REDUNDANCY_BY_CORE)
db.cursor().execute(REDUNDANCY_SUMMARY)


def _write_report(output: TextIO):
Expand Down
4 changes: 2 additions & 2 deletions spynnaker/pyNN/models/neuron/master_pop_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def _make_array(ctype: Type[_T], n_items: int) -> ctypes.Array[_T]:
:return: a ctype array
:rtype: _ctypes.PyCArrayType
"""
array_type = ctype * n_items
return array_type()
array_type = ctype * n_items # type: ignore
return array_type() # type: ignore


class _MasterPopEntryCType(ctypes.LittleEndianStructure):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ def get_recorded_region_ids(self) -> List[int]:
return ids

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
rec_regions = self._pop_vertex.neuron_recorder.get_region_sizes(
self.vertex_slice)
rec_regions.extend(self._pop_vertex.synapse_recorder.get_region_sizes(
Expand Down Expand Up @@ -323,8 +323,8 @@ def __write_local_only_data(self, spec: DataSpecificationGenerator):
spec.write_value(int(self._pop_vertex.drop_late_spikes))

@overrides(AbstractRewritesDataSpecification.regenerate_data_specification)
def regenerate_data_specification(
self, spec: DataSpecificationReloader, placement: Placement):
def regenerate_data_specification(self, spec: DataSpecificationReloader,
placement: Placement) -> None:
self._rewrite_neuron_data_spec(spec)

# close spec
Expand All @@ -335,7 +335,7 @@ def reload_required(self) -> bool:
return self.__regenerate_data

@overrides(AbstractRewritesDataSpecification.set_reload_required)
def set_reload_required(self, new_value: bool):
def set_reload_required(self, new_value: bool) -> None:
self.__regenerate_data = new_value

def _parse_local_only_provenance(
Expand Down
10 changes: 5 additions & 5 deletions spynnaker/pyNN/models/neuron/population_machine_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def get_recorded_region_ids(self) -> List[int]:
return ids

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
rec_regions = self._pop_vertex.neuron_recorder.get_region_sizes(
self.vertex_slice)
rec_regions.extend(self._pop_vertex.synapse_recorder.get_region_sizes(
Expand All @@ -340,8 +340,8 @@ def generate_data_specification(

@overrides(
AbstractRewritesDataSpecification.regenerate_data_specification)
def regenerate_data_specification(
self, spec: DataSpecificationReloader, placement: Placement):
def regenerate_data_specification(self, spec: DataSpecificationReloader,
placement: Placement) -> None:
if self.__regenerate_neuron_data:
self._rewrite_neuron_data_spec(spec)
self.__regenerate_neuron_data = False
Expand All @@ -360,7 +360,7 @@ def reload_required(self) -> bool:
return self.__regenerate_neuron_data or self.__regenerate_synapse_data

@overrides(AbstractRewritesDataSpecification.set_reload_required)
def set_reload_required(self, new_value: bool):
def set_reload_required(self, new_value: bool) -> None:
# These are set elsewhere once data is generated
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def get_recorded_region_ids(self) -> List[int]:
return ids

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
assert self.__sdram_partition is not None
rec_regions = self._pop_vertex.neuron_recorder.get_region_sizes(
self.vertex_slice)
Expand All @@ -281,8 +281,8 @@ def generate_data_specification(

@overrides(
AbstractRewritesDataSpecification.regenerate_data_specification)
def regenerate_data_specification(
self, spec: DataSpecificationReloader, placement: Placement):
def regenerate_data_specification(self, spec: DataSpecificationReloader,
placement: Placement) -> None:
# Write the other parameters
self._rewrite_neuron_data_spec(spec)

Expand All @@ -294,7 +294,7 @@ def reload_required(self) -> bool:
return self.__regenerate_data

@overrides(AbstractRewritesDataSpecification.set_reload_required)
def set_reload_required(self, new_value: bool):
def set_reload_required(self, new_value: bool) -> None:
self.__regenerate_data = new_value

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def get_recorded_region_ids(self) -> List[int]:
return ids

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
rec_regions = self._pop_vertex.synapse_recorder.get_region_sizes(
self.vertex_slice)
self._write_common_data_spec(spec, rec_regions)
Expand All @@ -140,8 +140,8 @@ def _parse_synapse_provenance(
self, label, x, y, p, provenance_data)

@overrides(AbstractRewritesDataSpecification.regenerate_data_specification)
def regenerate_data_specification(
self, spec: DataSpecificationReloader, placement: Placement):
def regenerate_data_specification(self, spec: DataSpecificationReloader,
placement: Placement) -> None:
# We don't need to do anything here because the originally written
# data can be used again
pass
Expand All @@ -151,7 +151,7 @@ def reload_required(self) -> bool:
return self.__regenerate_data

@overrides(AbstractRewritesDataSpecification.set_reload_required)
def set_reload_required(self, new_value: bool):
def set_reload_required(self, new_value: bool) -> None:
self.__regenerate_data = new_value

@overrides(PopulationMachineSynapses.set_do_synapse_regeneration)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def __init__(
self.__synapse_references = synapse_references

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
rec_regions = self._pop_vertex.synapse_recorder.get_region_sizes(
self.vertex_slice)
self._write_common_data_spec(spec, rec_regions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_n_keys_for_partition(self, partition_id: str) -> int:
return n_keys * n_colours

@overrides(ReverseIPTagMulticastSourceMachineVertex._fill_send_buffer_1d)
def _fill_send_buffer_1d(self, key_base: int):
def _fill_send_buffer_1d(self, key_base: int) -> None:
first_time_step = SpynnakerDataView.get_first_machine_time_step()
end_time_step = (
SpynnakerDataView.get_current_run_timesteps() or sys.maxsize)
Expand All @@ -58,7 +58,7 @@ def _fill_send_buffer_1d(self, key_base: int):
tick, keys + (tick & colour_mask))

@overrides(ReverseIPTagMulticastSourceMachineVertex._fill_send_buffer_2d)
def _fill_send_buffer_2d(self, key_base: int):
def _fill_send_buffer_2d(self, key_base: int) -> None:
first_time_step = SpynnakerDataView.get_first_machine_time_step()
end_time_step = (
SpynnakerDataView.get_current_run_timesteps() or sys.maxsize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ def reload_required(self) -> bool:
return SpynnakerDataView.get_first_machine_time_step() == 0

@overrides(AbstractRewritesDataSpecification.set_reload_required)
def set_reload_required(self, new_value: bool):
def set_reload_required(self, new_value: bool) -> None:
self.__rate_changed = new_value

@overrides(AbstractRewritesDataSpecification.regenerate_data_specification)
def regenerate_data_specification(
self, spec: DataSpecificationReloader, placement: Placement):
def regenerate_data_specification(self, spec: DataSpecificationReloader,
placement: Placement) -> None:
# write rates
self._write_poisson_rates(spec)

Expand All @@ -354,8 +354,8 @@ def __conn(synapse_info: SynapseInformation
'AbstractGenerateConnectorOnHost', synapse_info.connector)

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
spec.comment("\n*** Spec for SpikeSourcePoisson Instance ***\n\n")
# if we are here, the rates have changed!
self.__rate_changed = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def sdram_required(self) -> AbstractSDRAM:
parse_extra_provenance_items)
def parse_extra_provenance_items(
self, label: str, x: int, y: int, p: int,
provenance_data: Sequence[int]):
provenance_data: Sequence[int]) -> None:
(n_received, n_processed, n_added, n_sent, n_overflows, n_delays,
n_sat, n_bad_neuron, n_bad_keys, n_late_spikes, max_bg,
n_bg_overloads) = provenance_data
Expand Down Expand Up @@ -267,8 +267,8 @@ def get_binary_start_type(self) -> ExecutableType:
return ExecutableType.USES_SIMULATION_INTERFACE

@overrides(AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
def generate_data_specification(self, spec: DataSpecificationGenerator,
placement: Placement) -> None:
vertex = placement.vertex

# Reserve memory:
Expand Down
9 changes: 6 additions & 3 deletions spynnaker/pyNN/utilities/data_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from __future__ import annotations
import logging
from typing import (
Any, Dict, Iterable, Optional, overload, Sequence, Union, TYPE_CHECKING)
Any, cast, Dict, Iterable, Optional, overload, Sequence, Union,
TYPE_CHECKING)

import numpy
from numpy import floating
Expand Down Expand Up @@ -165,12 +166,13 @@ def id_to_index(self, id: Union[int, Iterable[int]]
_, first_id, _ = db.get_population_metadata(self.__label)
last_id = self._size + first_id
if not numpy.iterable(id):
id = cast(int, id)
if not first_id <= id <= last_id:
raise ValueError(
f"id should be in the range [{first_id},{last_id}], "
f"actually {id}")
return int(id - first_id) # assume IDs are consecutive
return id - first_id
return [_id - first_id for _id in id]

@overload
def index_to_id(self, index: int) -> int:
Expand All @@ -188,13 +190,14 @@ def index_to_id(self, index: Union[int, Iterable[int]]
with NeoBufferDatabase(self.__database_file) as db:
_, first_id, _ = db.get_population_metadata(self.__label)
if not numpy.iterable(index):
index = cast(int, index)
if index >= self._size:
raise ValueError(
f"indexes should be in the range [0,{self._size}],"
f" actually {index}")
return int(index + first_id)
# this assumes IDs are consecutive
return index + first_id
return [_index + first_id for _index in index]

def __getitem__(self, index_or_slice: Selector) -> DataPopulation:
"""
Expand Down
Loading

0 comments on commit e58a9df

Please sign in to comment.