model.py
import jax
import flax.linen as nn
import jax.numpy as jnp
class CAMlp(nn.Module):
    """
    Shared MLP in the Channel Attention Module, implemented as two 1x1
    convolutions with a channel reduction ratio of 16.
    """
    @nn.compact
    def __call__(self, inputs):
        # Squeeze: reduce channels by 16x (assumes at least 16 input channels).
        x = nn.Conv(
            features=inputs.shape[-1] // 16,
            kernel_size=(1, 1),
            use_bias=False)(inputs)
        x = nn.relu(x)
        # Excite: restore the original channel count.
        x = nn.Conv(
            features=inputs.shape[-1],
            kernel_size=(1, 1),
            use_bias=False)(x)
        return x
class ChannelAttention(nn.Module):
    """
    Channel Attention Module.
    """
    @nn.compact
    def __call__(self, x):
        # A single CAMlp instance called twice, so the MLP weights are shared
        # between the average- and max-pooled paths.
        mlp = CAMlp()
        # Global average / max pooling over the spatial dimensions (NHWC),
        # rather than a fixed 64x64 window, so any input size works.
        avg_out = mlp(jnp.mean(x, axis=(1, 2), keepdims=True))
        max_out = mlp(jnp.max(x, axis=(1, 2), keepdims=True))
        return nn.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """
    Spatial Attention Module.
    """
    @nn.compact
    def __call__(self, x):
        # Channel-wise average and max, concatenated into a 2-channel map.
        avg_out = jnp.mean(x, axis=-1, keepdims=True)
        max_out = jnp.max(x, axis=-1, keepdims=True)
        x = jnp.concatenate((avg_out, max_out), axis=-1)
        # Collapse the 2-channel map into a single-channel attention map.
        x = nn.Conv(features=1, kernel_size=(3, 3), use_bias=False)(x)
        return nn.sigmoid(x)
class CBAMBlock(nn.Module):
    """
    CBAM Block: channel attention followed by spatial attention.
    """
    @nn.compact
    def __call__(self, x):
        x = ChannelAttention()(x) * x
        x = SpatialAttention()(x) * x
        return x
class CBAMResBlock(nn.Module):
    """
    CBAM integrated with a residual block (ResNet style).
    """
    @nn.compact
    def __call__(self, x):
        residual = x
        out = nn.Conv(64, (3, 3), use_bias=False)(x)
        out = nn.BatchNorm(use_running_average=True)(out)
        out = nn.gelu(out)
        out = nn.Conv(64, (3, 3), use_bias=False)(out)
        out = nn.BatchNorm(use_running_average=True)(out)
        # Refine the features with CBAM before the residual addition.
        out = CBAMBlock()(out)
        if out.shape != residual.shape:
            # 1x1 projection so the shortcut matches the output channels.
            residual = nn.Conv(out.shape[-1], (1, 1),
                               name='conv_projection')(residual)
        out += residual
        return nn.gelu(out)
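

# Minimal usage sketch (not part of the original file): initializes a
# CBAMResBlock on a dummy NHWC batch and runs one inference pass. The input
# shape and PRNG seed are illustrative assumptions.
if __name__ == '__main__':
    model = CBAMResBlock()
    dummy = jnp.ones((1, 32, 32, 3))  # assumed batch of 32x32 RGB images
    variables = model.init(jax.random.PRNGKey(0), dummy)
    out = model.apply(variables, dummy)
    print(out.shape)  # (1, 32, 32, 64): channels projected to 64 by the block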