Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ModuleNotFoundError. #6

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
__pycache__
/MA_SNN/DVSGait/data
/MA_SNN/DVSGestures/data/DvsGesture
/MA_SNN/DVSGestures/data/train
/MA_SNN/DVSGestures/data/test
MA_SNN/DVSGestures/data/DvsGesture.tar.gz
info.txt
Attention-SNN.code-workspace
83 changes: 42 additions & 41 deletions MA_SNN/DVSGestures/Att_SNN_CNN.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,42 @@
import os

import sys

sys.path.append(os.path.dirname("__file__"))
from DVSGestures.CNN import Att_SNN

rootPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.split(rootPath)[0]
sys.path.append(rootPath)

from DVSGestures.CNN import Config

os.environ["CUDA_VISIBLE_DEVICES"] = "4,"


class Logger(object):
def __init__(self, fileN="Default.log"):
self.terminal = sys.stdout
self.log = open(fileN, "w")

def write(self, message):
self.terminal.write(message)
self.log.write(message)

def flush(self):
pass


logPath = Config.configs().recordPath
if not os.path.exists(logPath):
os.makedirs(logPath)
sys.stdout = Logger(logPath + os.sep + "log_DVS_Gesture_SNN.txt")


def main():
Att_SNN.main()


if __name__ == "__main__":
main()
import os

import sys

# sys.path.append(os.path.dirname("__file__"))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from DVSGestures.CNN import Att_SNN

rootPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.split(rootPath)[0]
sys.path.append(rootPath)

from DVSGestures.CNN import Config

# os.environ["CUDA_VISIBLE_DEVICES"] = "4,"


class Logger(object):
def __init__(self, fileN="Default.log"):
self.terminal = sys.stdout
self.log = open(fileN, "w")

def write(self, message):
self.terminal.write(message)
self.log.write(message)

def flush(self):
pass


logPath = Config.configs().recordPath
if not os.path.exists(logPath):
os.makedirs(logPath)
sys.stdout = Logger(logPath + os.sep + "log_DVS_Gesture_SNN.txt")


def main():
Att_SNN.main()


if __name__ == "__main__":
main()
42 changes: 42 additions & 0 deletions MA_SNN/DVSGestures/Att_SNN_CNN_SpikingJelly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os

import sys

# sys.path.append(os.path.dirname("__file__"))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from DVSGestures.CNN import Att_SNN_SpikingJelly as Att_SNN

rootPath = os.path.abspath(os.path.dirname(__file__))
rootPath = os.path.split(rootPath)[0]
sys.path.append(rootPath)

from DVSGestures.CNN import Config

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"


class Logger(object):
def __init__(self, fileN="Default.log"):
self.terminal = sys.stdout
self.log = open(fileN, "w")

def write(self, message):
self.terminal.write(message)
self.log.write(message)

def flush(self):
pass


logPath = Config.configs().recordPath
if not os.path.exists(logPath):
os.makedirs(logPath)
sys.stdout = Logger(logPath + os.sep + "log_DVS_Gesture_SNN.txt")


def main():
Att_SNN.main()


if __name__ == "__main__":
main()
44 changes: 44 additions & 0 deletions MA_SNN/DVSGestures/CNN/Att_SNN_SpikingJelly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
from utils import util
from DVSGestures.DVS_Gesture_utils.dataset import create_data
from DVSGestures.CNN.Networks.Att_SNN_SpikingJelly import create_net
from DVSGestures.CNN.Config import configs
from DVSGestures.DVS_Gesture_utils.process import process
from DVSGestures.DVS_Gesture_utils.save import save_csv


def main():

config = configs()
config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(config.device)

config.device_ids = range(torch.cuda.device_count())
print(config.device_ids)

config.name = (
config.attention
+ "_SNN(CNN)-DVS-Gesture_dt="
+ str(config.dt)
+ "ms"
+ "_T="
+ str(config.T)
)
config.modelNames = config.name + ".t7"
config.recordNames = config.name + ".csv"

print(config)

create_data(config=config)

create_net(config=config)

print(config.model)

print(util.get_parameter_number(config.model))

process(config=config)

print("best acc:", config.best_acc, "best_epoch:", config.best_epoch)

save_csv(config=config)
17 changes: 12 additions & 5 deletions MA_SNN/DVSGestures/CNN/Config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import os, torch
import torch.nn as nn
from spikingjelly.activation_based.neuron import surrogate


class configs(object):
def __init__(self):
self.dt = 25
self.dt = 15
self.T = 60

self.attention = "no"
self.attention = "TCSA"
self.c_ratio=8
self.t_ratio=5
self.epoch = 0
Expand All @@ -16,7 +17,7 @@ def __init__(self):
self.pretrained_path = None

self.batch_size = 128
self.batch_size_test = 128
self.batch_size_test = 28

# None 'kaiming' 'xavier'
self.init_method = None
Expand All @@ -37,18 +38,24 @@ def __init__(self):
self.interval_scaling = False

# network
self.beta = 0
self.beta = 0.
self.alpha = 0.3
self.Vreset = 0
self.Vreset = 0.
self.Vthres = 0.3
self.reduction = 16
self.T_extend_Conv = False
self.T_extend_BN = False
self.h_conv = False
self.step_mode = "m"
# self.surrogate_function = surrogate.Sigmoid()
self.surrogate_function = surrogate.LeakyKReLU()
self.backend = "cupy"
# Old parameters
self.mem_act = torch.relu
self.mode_select = "spike"
self.TR_model = "NTR"


# BatchNorm
self.track_running_stats = True

Expand Down
8 changes: 7 additions & 1 deletion MA_SNN/DVSGestures/CNN/Networks/Att_SNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,24 @@ def __init__(
)

def forward(self, input):
# print("input: ",input.shape)
b, t, _, _, _ = input.size()
outputs = input

outputs = self.convAttLIF0(outputs)
# print("convAttLIF0: ",outputs.shape)
outputs = self.convAttLIF1(outputs)
# print("convAttLIF1: ",outputs.shape)
outputs = self.convAttLIF2(outputs)
# print("convAttLIF1: ",outputs.shape)

outputs = outputs.reshape(b, t, -1)

# print("fc_input: ",outputs.shape)
outputs = self.FC0(outputs)
# print("FC0: ",outputs.shape)

outputs = self.FC1(outputs)
# print("FC1: ",outputs.shape)
outputs = torch.sum(outputs, dim=1)
outputs = outputs / t

Expand Down
Loading