Blocks
MLP
A Multi-Layer Perceptron (MLP) is a class of feedforward artificial neural network (ANN). The term MLP is used ambiguously: sometimes loosely, to mean any feedforward ANN, and sometimes strictly, to refer to networks composed of multiple layers of perceptrons (with threshold activation); see § Terminology. Multilayer perceptrons are sometimes colloquially referred to as "vanilla" neural networks, especially when they have a single hidden layer. The helper below applies, for each entry of hidden_units, a Dense layer with GELU activation followed by a Dropout layer.
def mlp(x, hidden_units, dropout_rate):
    """Stack Dense (GELU) + Dropout layers on top of ``x``.

    Args:
        x: Input tensor fed to the first Dense layer.
        hidden_units: Iterable giving the number of units of each Dense
            layer, applied in order.
        dropout_rate: Rate passed to every Dropout layer.

    Returns:
        The output of the last Dropout layer, or ``x`` itself when
        ``hidden_units`` is empty.
    """
    out = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(out)
        out = layers.Dropout(dropout_rate)(hidden)
    return out
Dual Attention Module
In the latent layer, if the feature map has 256 channels, use the following code to implement the dual attention module. The reshape sizes in the code are hard-coded for an 8*8*8 feature map.
def ATT256(acti5_2):
    """Dual (position + channel) attention for an 8x8x8 feature map with 256 channels.

    All reshape targets are hard-coded, so the input must hold
    8 * 8 * 8 = 512 positions with 256 channels each (assumed
    channels-last -- TODO confirm against the caller).

    Args:
        acti5_2: Latent feature tensor.

    Returns:
        The element-wise sum of the position-attention branch and the
        channel-attention branch. Each branch is BatchNormalization + ReLU
        applied to (attended features + input).
    """
    # ---- Position attention -------------------------------------------
    b = Conv3D(filters=32, kernel_size=1, padding='same', use_bias=False,
               kernel_initializer='he_normal')(acti5_2)
    c = Conv3D(filters=32, kernel_size=1, padding='same', use_bias=False,
               kernel_initializer='he_normal')(acti5_2)
    d = Conv3D(filters=256, kernel_size=1, padding='same', use_bias=False,
               kernel_initializer='he_normal')(acti5_2)

    vec_b = Reshape((512, 32))(b)
    vec_cT = Reshape((512, 32))(c)
    vec_cT = Permute((2, 1))(vec_cT)
    # Position-to-position affinity, (512, 512) per sample; the softmax
    # normalises over the last axis.
    bcT = Dot(axes=(1, 2))([vec_cT, vec_b])
    softmax_bcT = Activation('softmax')(bcT)

    vec_d = Reshape((512, 256))(d)
    # Contract over the attended positions so the result is laid out as
    # (positions, channels) = (512, 256). A (channels, positions) layout
    # would be silently scrambled by the Reshape below.
    bcTd = Dot(axes=(2, 1))([softmax_bcT, vec_d])
    bcTd = Reshape((8, 8, 8, 256))(bcTd)
    out1 = Add()([bcTd, acti5_2])
    pam = BatchNormalization()(out1)
    pam = Activation('relu')(pam)

    # ---- Channel attention --------------------------------------------
    vec_a = Reshape((512, 256))(acti5_2)
    vec_aT = Permute((2, 1))(vec_a)
    # Channel-to-channel affinity, (256, 256) per sample.
    aTa = Dot(axes=(1, 2))([vec_a, vec_aT])
    softmax_aTa = Activation('softmax')(aTa)
    # Mix the channels while keeping the (positions, channels) layout
    # that the Reshape below expects.
    aaTa = Dot(axes=(2, 1))([vec_a, softmax_aTa])
    aaTa = Reshape((8, 8, 8, 256))(aaTa)
    out2 = Add()([aaTa, acti5_2])
    cam = BatchNormalization()(out2)
    cam = Activation('relu')(cam)

    attention = Add()([pam, cam])
    return attention
Auto Dual Attention Module
In the latent layer, if the number of channels of the feature map is not known in advance, use the following code to implement the dual attention module. acti5_2: the input latent feature. org_channel: the channel of the feature map; in the code it is used as the number of filters of the b and c convolutions. For example, if the input channel of the feature map is 256, then org_channel=256. channels: the channel of the feature map after the dual attention module. For example, if the output channel of the feature map is 256, then channels=256. fsize: the size of the feature map. For example, if the size of the feature map is 8*8*8, then fsize=8. Note that the code reads the size from the input tensor's shape (acti5_2.shape[-2]).
def ATT_auto(acti5_2, org_channel, channels, fsize):
    """Dual (position + channel) attention for a cubic feature map of configurable size.

    Example configuration: org_channel=32, channels=128, fsize=8 for an
    input of shape [8, 8, 8, 128].

    Args:
        acti5_2: Latent feature tensor. It is reshaped to
            (fsize**3, channels) and added to the attended features, so
            its channel count is expected to equal ``channels`` and its
            three spatial dimensions are expected to be equal (assumed
            channels-last -- TODO confirm against the caller).
        org_channel: Number of filters of the two 1x1x1 convolutions
            (``b`` and ``c``) used to build the position affinity matrix.
        channels: Number of filters of the 1x1x1 convolution ``d`` and
            channel count of the returned tensor.
        fsize: Edge length of the feature map. Only used as a fallback:
            when the static shape of ``acti5_2`` provides the size, that
            value takes precedence.

    Returns:
        The element-wise sum of the position-attention branch and the
        channel-attention branch. Each branch is BatchNormalization + ReLU
        applied to (attended features + input).
    """
    # Prefer the size recorded in the tensor's static shape; fall back to
    # the caller-supplied ``fsize`` when that dimension is unknown.
    static_size = acti5_2.shape[-2]
    if static_size is not None:
        fsize = static_size
    positions = fsize * fsize * fsize

    # ---- Position attention -------------------------------------------
    b = Conv3D(filters=org_channel, kernel_size=1, padding='same', use_bias=False,
               kernel_initializer='he_normal')(acti5_2)
    c = Conv3D(filters=org_channel, kernel_size=1, padding='same', use_bias=False,
               kernel_initializer='he_normal')(acti5_2)
    d = Conv3D(filters=channels, kernel_size=1, padding='same', use_bias=False,
               kernel_initializer='he_normal')(acti5_2)

    vec_b = Reshape((positions, org_channel))(b)
    vec_cT = Reshape((positions, org_channel))(c)
    vec_cT = Permute((2, 1))(vec_cT)
    # Position-to-position affinity, (positions, positions) per sample;
    # the softmax normalises over the last axis.
    bcT = Dot(axes=(1, 2))([vec_cT, vec_b])
    softmax_bcT = Activation('softmax')(bcT)

    vec_d = Reshape((positions, channels))(d)
    # Contract over the attended positions so the result is laid out as
    # (positions, channels). A (channels, positions) layout would be
    # silently scrambled by the Reshape below.
    bcTd = Dot(axes=(2, 1))([softmax_bcT, vec_d])
    bcTd = Reshape((fsize, fsize, fsize, channels))(bcTd)
    out1 = Add()([bcTd, acti5_2])
    pam = BatchNormalization()(out1)
    pam = Activation('relu')(pam)

    # ---- Channel attention --------------------------------------------
    vec_a = Reshape((positions, channels))(acti5_2)
    vec_aT = Permute((2, 1))(vec_a)
    # Channel-to-channel affinity, (channels, channels) per sample.
    aTa = Dot(axes=(1, 2))([vec_a, vec_aT])
    softmax_aTa = Activation('softmax')(aTa)
    # Mix the channels while keeping the (positions, channels) layout
    # that the Reshape below expects.
    aaTa = Dot(axes=(2, 1))([vec_a, softmax_aTa])
    aaTa = Reshape((fsize, fsize, fsize, channels))(aaTa)
    out2 = Add()([aaTa, acti5_2])
    cam = BatchNormalization()(out2)
    cam = Activation('relu')(cam)

    attention = Add()([pam, cam])
    return attention
Downsample Block
Rather than relying on a MaxPooling layer alone, the downsample block also downsamples the feature map with a Conv3D layer with strides=2; the pooled branch and the strided-convolution branch are then concatenated.
def downsample_block(x, filters):
    """Downsample ``x`` through two parallel branches and concatenate them.

    Branch 1 applies MaxPooling3D (default arguments) followed by a 1x1x1
    convolution. Branch 2 applies a 1x1x1 convolution followed by a 3x3x3
    convolution with strides=2. Every convolution is followed by
    BatchNormalization and ReLU.

    Args:
        x: Input feature tensor.
        filters: Number of filters of every convolution in the block.

    Returns:
        ``Concatenate()`` (default axis) of the two branch outputs.
    """
    def conv_bn_relu(tensor, kernel_size, strides):
        # Conv3D -> BatchNormalization -> ReLU with the block-wide settings.
        conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides,
                      padding='same', kernel_initializer='he_normal')(tensor)
        norm = BatchNormalization()(conv)
        return Activation('relu')(norm)

    # Branch 1: pooling, then a pointwise convolution.
    pooled = MaxPooling3D()(x)
    pooled = conv_bn_relu(pooled, kernel_size=1, strides=1)

    # Branch 2: pointwise convolution, then a strided 3x3x3 convolution.
    strided = conv_bn_relu(x, kernel_size=1, strides=1)
    strided = conv_bn_relu(strided, kernel_size=3, strides=2)

    return Concatenate()([pooled, strided])
MultiView Block
In addition to the traditional 3D convolution, the MultiView block uses 3D convolutions with different kernel sizes to extract features. Because vessel features are long and thin along some planes, the MultiView block can extract long features along different planes without introducing too many parameters.
def MultiView(x, filters):
    """Extract features from ``x`` with three planar-kernel branches and fuse them.

    Each branch runs a 3x3x3 convolution on ``x`` followed by one planar
    convolution -- (1, 3, 3), (3, 1, 3) or (3, 3, 1). The three branch
    outputs are summed and passed through a final (1, 1, 1) convolution.
    Every convolution is followed by BatchNormalization and ReLU.

    Args:
        x: Input feature tensor.
        filters: Number of filters of every convolution in the block.

    Returns:
        The activated output of the final (1, 1, 1) convolution.
    """
    def conv_bn_relu(tensor, kernel_size):
        # Conv3D (stride 1) -> BatchNormalization -> ReLU.
        conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=1,
                      padding='same', kernel_initializer='he_normal')(tensor)
        norm = BatchNormalization()(conv)
        return Activation('relu')(norm)

    # Stage 1: three independent 3x3x3 convolutions over the same input.
    stem_a = conv_bn_relu(x, 3)
    stem_b = conv_bn_relu(x, 3)
    stem_c = conv_bn_relu(x, 3)

    # Stage 2: one planar kernel per branch.
    view_a = conv_bn_relu(stem_a, (1, 3, 3))
    view_b = conv_bn_relu(stem_b, (3, 1, 3))
    view_c = conv_bn_relu(stem_c, (3, 3, 1))

    # Fuse the views and mix channels with a pointwise convolution.
    fused = Add()([view_a, view_b, view_c])
    return conv_bn_relu(fused, (1, 1, 1))
Transformer Block
The Transformer Block includes three MultiHeadAttention processes.
def Transformer_Block(input):
    """Run three attention + MLP stages over ``input`` and normalise the result.

    Each stage does, in order: GroupNormalization, MultiHeadAttention(4,
    key_dim=3, dropout=0) called with the normalised tensor for both
    arguments, a residual Add with the normalised tensor,
    GroupNormalization, ``mlp`` with hidden units [2*C, C] (C is the size
    of the last axis of ``input``), and a residual Add with the MLP input.

    Args:
        input: Input tensor; its last dimension sets the MLP widths.

    Returns:
        GroupNormalization applied to the output of the third stage.
    """
    width = input.shape[-1]
    features = input
    for _ in range(3):
        normed = GroupNormalization()(features)
        attended = MultiHeadAttention(4, key_dim=3, dropout=0)(normed, normed)
        residual = Add()([attended, normed])
        residual = GroupNormalization()(residual)
        expanded = mlp(residual, hidden_units=[2*width, width], dropout_rate=0)
        features = Add()([expanded, residual])
    return GroupNormalization()(features)
Reduction Block
The Reduction Block is used to reduce the feature map size in three different ways and then concatenate them.
def reduction_block(x, filters):
    """Reduce the feature map through three parallel paths and concatenate them.

    Path 1: 1x1x1 convolution with strides=2, then (1, 1, 5), (1, 5, 1)
    and (5, 1, 1) convolutions applied in sequence.
    Path 2: MaxPooling3D (default arguments), then a 1x1x1 convolution.
    Path 3: 1x1x1 convolution, then a 5x5x5 convolution with strides=2.
    Every convolution is followed by BatchNormalization and ReLU.

    Args:
        x: Input feature tensor.
        filters: Number of filters of every convolution in the block.

    Returns:
        ``Concatenate()`` (default axis) of the three path outputs.
    """
    def conv_bn_relu(tensor, kernel_size, strides=1):
        # Conv3D -> BatchNormalization -> ReLU with the block-wide settings.
        conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides,
                      padding='same', kernel_initializer='he_normal')(tensor)
        norm = BatchNormalization()(conv)
        return Activation('relu')(norm)

    # Path 1: strided pointwise convolution, then one 5-tap kernel per axis.
    axial = conv_bn_relu(x, 1, strides=2)
    axial = conv_bn_relu(axial, (1, 1, 5))
    axial = conv_bn_relu(axial, (1, 5, 1))
    axial = conv_bn_relu(axial, (5, 1, 1))

    # Path 2: pooling, then a pointwise convolution.
    pooled = MaxPooling3D()(x)
    pooled = conv_bn_relu(pooled, 1)

    # Path 3: pointwise convolution, then a strided 5x5x5 convolution.
    strided = conv_bn_relu(x, 1)
    strided = conv_bn_relu(strided, 5, strides=2)

    return Concatenate()([axial, pooled, strided])
Deep Block
Like the MultiView block, the Deep block uses 3D convolutions with different kernel sizes to extract features along different axes.
def deep_block(x, filters):
    """Extract axis-oriented features through three paths and concatenate them.

    Path 1: 1x1x1 convolution, then (1, 1, 7), (1, 7, 1) and (7, 1, 1)
    convolutions applied in sequence.
    Path 2: 1x1x1 convolution whose output feeds three parallel
    convolutions -- (7, 1, 1), (1, 7, 1) and (1, 1, 7); all three outputs
    are kept.
    Path 3: 1x1x1 convolution, then a (5, 5, 5) convolution.
    Every convolution uses strides=1 and is followed by BatchNormalization
    and ReLU.

    Args:
        x: Input feature tensor.
        filters: Number of filters of every convolution in the block.

    Returns:
        ``Concatenate()`` (default axis) of five tensors: the path-1
        output, the three path-2 outputs and the path-3 output.
    """
    def conv_bn_relu(tensor, kernel_size):
        # Conv3D (stride 1) -> BatchNormalization -> ReLU.
        conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=1,
                      padding='same', kernel_initializer='he_normal')(tensor)
        norm = BatchNormalization()(conv)
        return Activation('relu')(norm)

    # Path 1: the 7-tap kernels are chained one after another.
    chained = conv_bn_relu(x, 1)
    chained = conv_bn_relu(chained, (1, 1, 7))
    chained = conv_bn_relu(chained, (1, 7, 1))
    chained = conv_bn_relu(chained, (7, 1, 1))

    # Path 2: the 7-tap kernels all read the same pointwise projection.
    shared = conv_bn_relu(x, 1)
    along_first = conv_bn_relu(shared, (7, 1, 1))
    along_second = conv_bn_relu(shared, (1, 7, 1))
    along_third = conv_bn_relu(shared, (1, 1, 7))

    # Path 3: pointwise projection followed by a full 5x5x5 kernel.
    cubic = conv_bn_relu(x, 1)
    cubic = conv_bn_relu(cubic, (5, 5, 5))

    return Concatenate()([chained, along_first, along_second, along_third, cubic])