Skip to content

Commit

Permalink
Merge pull request #139 from Gaffey/snn_mlp_branch
Browse files Browse the repository at this point in the history
Snn mlp branch
  • Loading branch information
Gaffey authored Sep 16, 2022
2 parents f4ffbe5 + d206ac7 commit 89c8758
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
39 changes: 38 additions & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# 2020.06.09-GhostNet definition for pytorch hub
# 2022.09.16-GhostNet & SNN-MLP definition for pytorch hub
# Huawei Technologies Co., Ltd. <[email protected]>
dependencies = ['torch']
import torch
from ghostnet_pytorch.ghostnet import ghostnet
from snnmlp_pytorch.models.snn_mlp import SNNMLP


state_dict_url = 'https://github.com/huawei-noah/ghostnet/raw/master/ghostnet_pytorch/models/state_dict_73.98.pth'
state_dict_url_snnmlp_t = 'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_tiny_81.88.pt'
state_dict_url_snnmlp_s = 'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_small_83.30.pt'
state_dict_url_snnmlp_b = 'https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_base_83.59.pt'


def ghostnet_1x(pretrained=False, **kwargs):
Expand All @@ -18,3 +22,36 @@ def ghostnet_1x(pretrained=False, **kwargs):
state_dict = torch.hub.load_state_dict_from_url(state_dict_url, progress=True)
model.load_state_dict(state_dict)
return model

def snnmlp_t(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
SNN-MLP tiny model
pretrained (bool): kwargs, load pretrained weights into the model
"""
model = SNNMLP(num_classes=1000, embed_dim=96, depths=[2, 2, 6, 2], drop_path_rate=0.2)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(state_dict_url_snnmlp_t, progress=True)
model.load_state_dict(state_dict)
return model

def snnmlp_s(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
SNN-MLP small model
pretrained (bool): kwargs, load pretrained weights into the model
"""
model = SNNMLP(num_classes=1000, embed_dim=96, depths=[2, 2, 18, 2], drop_path_rate=0.3)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(state_dict_url_snnmlp_s, progress=True)
model.load_state_dict(state_dict)
return model

def snnmlp_b(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
SNN-MLP base model
pretrained (bool): kwargs, load pretrained weights into the model
"""
model = SNNMLP(num_classes=1000, embed_dim=128, depths=[2, 2, 18, 2], drop_path_rate=0.5)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(state_dict_url_snnmlp_b, progress=True)
model.load_state_dict(state_dict)
return model
12 changes: 5 additions & 7 deletions snnmlp_pytorch/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,11 @@ Dataset used: [ImageNet2012](http://www.image-net.org/)

```

Download checkpoints: [[BaiduDisk]](https://pan.baidu.com/s/1YuxSJNOUyPZUKUPy419HPg), password: 02tb.

## Result

|Model|Params (M)|FLOPs (B)|Top-1|
|-|-|-|-|
|SNN-MLP-T|28.3|4.4|81.9|
|SNN-MLP-S|49.7|8.5|83.3|
|SNN-MLP-B|87.9|15.2|83.6|
|Model|Params (M)|FLOPs (B)|Top-1|Download URL|
|-|-|-|-|-|
|SNN-MLP-T|28.3|4.4|81.9|[[ckpt]](https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_tiny_81.88.pt) & [[log]]( https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_tiny_81.88.log)|
|SNN-MLP-S|49.7|8.5|83.3|[[ckpt]](https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_small_83.30.pt) & [[log]]( https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_small_83.30.log)|
|SNN-MLP-B|87.9|15.2|83.6|[[ckpt]](https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_base_83.59.pt) & [[log]]( https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/snnmlp/snnmlp_base_83.59.log)|

0 comments on commit 89c8758

Please sign in to comment.