This repository contains code to reproduce the key findings of "Training Spiking Neural Networks with Forward Propagation Through Time". It implements spiking recurrent networks with Liquid Time-Constant (LTC) spiking neurons in PyTorch, trained via FPTT, for various tasks. The notebook illustrates the functionality of LTC spiking neurons.
This is scientific software and, as such, subject to frequent modification; we aim to make it more user-friendly and extensible in the future.
- S/P-MNIST, R-MNIST: these datasets are derived from MNIST, available via torchvision.datasets.MNIST
- Fashion-MNIST: this dataset can be accessed via torchvision.datasets.FashionMNIST
- DVS datasets: SpikingJelly includes the neuromorphic datasets (Gesture128-DVS and Cifar10-DVS). You can also download them from their official sites. Our preprocessing of the DVS datasets is also supported in SpikingJelly.
- The PASCAL Visual Object Classes (VOC) dataset contains 20 object categories. Each image in this dataset has pixel-level segmentation annotations, bounding box annotations, and object class annotations. The dataset has been widely used as a benchmark for object detection, semantic segmentation, and classification tasks. In this paper, the Spiking-YOLO (SPYv4) network was trained and tested on VOC07+12.
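For the sequential MNIST variants, each 28x28 image is presented as a time sequence: S-MNIST reads the 784 pixels one per timestep, and P-MNIST applies one fixed random permutation to that pixel order. A minimal sketch of this sequentialization (using a random array as a stand-in for a real MNIST image, which would be loaded via torchvision.datasets.MNIST):

```python
import numpy as np

# Stand-in for a single 28x28 MNIST image (load real images via torchvision.datasets.MNIST).
rng = np.random.default_rng(0)
image = rng.random((28, 28))

# S-MNIST: flatten the image into a length-784 pixel sequence, one pixel per timestep.
s_mnist_seq = image.reshape(-1)   # shape (784,)

# P-MNIST: apply one fixed random permutation to the pixel order,
# shared across the whole dataset so the task stays consistent.
perm = rng.permutation(784)
p_mnist_seq = s_mnist_seq[perm]   # shape (784,)

print(s_mnist_seq.shape, p_mnist_seq.shape)
```

R-MNIST instead presents the image row by row (28 timesteps of 28-dimensional inputs), which shortens the sequence at the cost of larger per-step inputs.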
- Python 3.8.10
- A working installation of Python and PyTorch. This should be easy: either use Google Colab, or install locally, most simply via pip (Start Locally | PyTorch). torch==1.7.1
- SpikingJelly
- For the object detection tasks, OpenCV 2 is required.
```python
for e in range(epochs):                                  # epoch iteration
    for i in range(sequence_len):                        # read the sequence step by step
        if i == 0:
            model.init_h(x_in.shape[0])                  # at the first step, initialize the hidden states
        else:
            model.h = list(v.detach() for v in model.h)  # detach the computation graph from the previous timestep
        out = model.forward_t(x_in[:, :, i])             # read input and generate output
        loss_c = (i / sequence_len) * criterion(out, targets)             # prediction loss
        loss_r = get_regularizer_named_params(named_params, _lambda=1.0)  # regularizer loss
        loss = loss_c + loss_r
        optimizer.zero_grad()
        loss.backward()                                  # gradient of the current timestep only
        optimizer.step()                                 # update the network
        post_optimizer_updates(named_params, epoch)      # update the traces \bar{w} and \delta_l
    reset_named_parameter(named_params)                  # reset the traces
```
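The helpers `get_regularizer_named_params`, `post_optimizer_updates`, and `reset_named_parameter` implement FPTT's dynamic regularization, which keeps the per-timestep weights close to a running weight trace \bar{w}. A simplified numpy sketch of the trace logic (the names `init_traces`, `regularizer`, and `post_optimizer_update` are illustrative, not the repository's API, and the full FPTT regularizer also includes a gradient correction term omitted here):

```python
import numpy as np

alpha = 1.0  # regularization strength (assumed value, for illustration)

def init_traces(params):
    # One running trace \bar{w} per parameter, started at the current weights.
    return {name: w.copy() for name, w in params.items()}

def regularizer(params, traces, _lambda=1.0):
    # Penalize drift of the current weights away from the running trace \bar{w}.
    return _lambda * sum(
        0.5 * alpha * np.sum((w - traces[name]) ** 2) for name, w in params.items()
    )

def post_optimizer_update(params, traces):
    # After each optimizer step, move each trace halfway toward the new weights.
    for name, w in params.items():
        traces[name] = 0.5 * (traces[name] + w)

# Toy usage with a single two-element parameter.
params = {"w": np.array([1.0, -2.0])}
traces = init_traces(params)
print(regularizer(params, traces))  # 0.0: the weights equal their trace

params["w"] = params["w"] + 0.1     # pretend the optimizer moved the weights
r = regularizer(params, traces)     # now positive, pulling w back toward the trace
post_optimizer_update(params, traces)
```

Because the penalty anchors every timestep's update to the same slowly moving trace, gradients from single timesteps can be applied immediately without the full backward pass through time that BPTT requires.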
A video demo of Spiking-YOLO (SPYv4):
You can find more details in the README file of each task.
Finally, we’d love to hear from you if you have any comments or suggestions.
[1]. https://github.com/bubbliiiing/yolov4-tiny-pytorch
MIT