Blocks
===================


**MLP**

A multilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN). The term is used ambiguously: sometimes loosely to mean any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons (with threshold activation). Multilayer perceptrons are sometimes colloquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.

::

    def mlp(x, hidden_units, dropout_rate):
        # Stack of Dense (GELU) + Dropout layers applied along the last axis.
        for units in hidden_units:
            x = layers.Dense(units, activation=tf.nn.gelu)(x)
            x = layers.Dropout(dropout_rate)(x)
        return x
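
The snippets in this section use the Keras functional API. They assume imports along the lines of the sketch below (an assumption, not part of the original code; GroupNormalization requires a recent TensorFlow release, otherwise tensorflow_addons provides an equivalent layer).

::

    import tensorflow as tf
    from tensorflow.keras import layers
    from tensorflow.keras.layers import (
        Input, Conv3D, MaxPooling3D, BatchNormalization, Activation, Add,
        Concatenate, Reshape, Permute, Dot, MultiHeadAttention, GroupNormalization,
    )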


**Dual Attention Module**

In the latent layer, if the feature map has 256 channels and an 8x8x8 spatial size, use the following code to implement the dual attention module.

::

    def ATT256(acti5_2):
        # Position attention branch (PAM): b and c are reduced to 32 channels
        # to build the spatial attention map, while d keeps all 256 channels.

        b = Conv3D(filters=32, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(acti5_2)
        c = Conv3D(filters=32, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(acti5_2)
        d = Conv3D(filters=256, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(acti5_2)

        # Flatten the 8x8x8 spatial grid into 512 positions.
        vec_b = Reshape((512, 32))(b)
        vec_cT = Reshape((512, 32))(c)
        vec_cT = Permute((2, 1))(vec_cT)
        bcT = Dot(axes=(1, 2))([vec_cT, vec_b])
        softmax_bcT = Activation('softmax')(bcT)
        vec_d = Reshape((512, 256))(d)
        bcTd = Dot(axes=(1, 2))([vec_d, softmax_bcT])

        bcTd = Reshape((8, 8, 8, 256))(bcTd)
        out1 = Add()([bcTd, acti5_2])
        pam = BatchNormalization()(out1)
        pam = Activation('relu')(pam)

        # Channel attention branch (CAM): attention between the 256 channels of the input itself.
        vec_a = Reshape((512, 256))(acti5_2)
        vec_aT = Permute((2, 1))(vec_a)
        aTa = Dot(axes=(1, 2))([vec_a, vec_aT])
        softmax_aTa = Activation('softmax')(aTa)
        aaTa = Dot(axes=(1, 2))([softmax_aTa, vec_a])
        aaTa = Reshape((8, 8, 8, 256))(aaTa)
        out2 = Add()([aaTa, acti5_2])
        cam = BatchNormalization()(out2)
        cam = Activation('relu')(cam)

        # Fuse the position and channel attention outputs.
        attention = Add()([pam, cam])

        return attention
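
A minimal usage sketch (the names and shapes below are illustrative, not part of the original code). Because the Reshape layers hard-code 512 = 8*8*8 positions and 256 channels, ATT256 only accepts an 8x8x8 latent feature with 256 channels and returns a tensor of the same shape.

::

    # Illustrative example: build a model around the attention module.
    latent = Input(shape=(8, 8, 8, 256))   # hard-coded latent shape
    att = ATT256(latent)                   # -> (None, 8, 8, 8, 256)
    model = tf.keras.Model(latent, att)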


**Auto Dual Attention Module**

In the latent layer, if the number of channels of the feature map is not fixed in advance, use the following code to implement the dual attention module.

- acti5_2: the input latent feature map.
- org_channel: the number of channels used for the b and c projections that build the spatial attention map (ATT256 above uses 32 for a 256-channel input).
- channels: the number of channels of the input and output feature map; it must match the channel count of acti5_2 (for example, 256).
- fsize: the spatial size of the feature map (for example, 8 for an 8x8x8 feature map); the function re-derives this value from the shape of acti5_2.

::

    def ATT_auto(acti5_2, org_channel, channels, fsize):
        # Example configuration: org_channel=32, channels=128, fsize=8 for an [8, 8, 8, 128] input.
        # The spatial size is re-derived from the input shape, overriding the fsize argument.
        fsize = acti5_2.shape[-2]
        print('ATT_auto input shape:', acti5_2.shape)

        # Position attention branch (see ATT256 above).
        b = Conv3D(filters=org_channel, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(acti5_2)
        c = Conv3D(filters=org_channel, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(acti5_2)
        d = Conv3D(filters=channels, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_normal')(acti5_2)

        vec_b = Reshape((fsize*fsize*fsize, org_channel))(b)
        vec_cT = Reshape((fsize*fsize*fsize, org_channel))(c)
        vec_cT = Permute((2, 1))(vec_cT)
        bcT = Dot(axes=(1, 2))([vec_cT, vec_b])
        softmax_bcT = Activation('softmax')(bcT)
        vec_d = Reshape((fsize*fsize*fsize, channels))(d)
        bcTd = Dot(axes=(1, 2))([vec_d, softmax_bcT])

        bcTd = Reshape((fsize, fsize, fsize, channels))(bcTd)
        out1 = Add()([bcTd, acti5_2])
        pam = BatchNormalization()(out1)
        pam = Activation('relu')(pam)

        # Channel attention branch.
        vec_a = Reshape((fsize*fsize*fsize, channels))(acti5_2)
        vec_aT = Permute((2, 1))(vec_a)
        aTa = Dot(axes=(1, 2))([vec_a, vec_aT])
        softmax_aTa = Activation('softmax')(aTa)
        aaTa = Dot(axes=(1, 2))([softmax_aTa, vec_a])
        aaTa = Reshape((fsize, fsize, fsize, channels))(aaTa)
        out2 = Add()([aaTa, acti5_2])
        cam = BatchNormalization()(out2)
        cam = Activation('relu')(cam)

        attention = Add()([pam, cam])

        return attention
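
A usage sketch with assumed shapes, matching the example configuration in the comment above: channels must equal the channel count of acti5_2, while org_channel is free and controls the size of the attention projections; fsize is re-derived from the input inside the function.

::

    # Illustrative example: 8x8x8 latent feature with 128 channels.
    latent = Input(shape=(8, 8, 8, 128))
    att = ATT_auto(latent, org_channel=32, channels=128, fsize=8)   # -> (None, 8, 8, 8, 128)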


**Downsample Block**

Instead of using only a MaxPooling layer, the downsample block also uses a Conv3D layer with strides=2 to downsample the feature map, and concatenates the outputs of the two branches.

::

    def downsample_block(x, filters):

        # Branch 1: max pooling followed by a 1x1x1 convolution.
        maxpool2 = MaxPooling3D()(x)
        conv2_1 = Conv3D(filters=filters, kernel_size=1, strides=1, padding='same',
                         kernel_initializer='he_normal')(maxpool2)
        norm2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation('relu')(norm2_1)

        # Branch 2: 1x1x1 convolution followed by a strided (strides=2) 3x3x3 convolution.
        conv3_1 = Conv3D(filters=filters, kernel_size=1, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation('relu')(norm3_1)

        conv3_2 = Conv3D(filters=filters, kernel_size=3, strides=2, padding='same',
                         kernel_initializer='he_normal')(acti3_1)
        norm3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation('relu')(norm3_2)

        concat = Concatenate()([acti2_1, acti3_2])

        return concat
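
A shape sketch with illustrative sizes: both branches halve the spatial dimensions, and concatenating them yields 2 * filters output channels.

::

    x = Input(shape=(16, 16, 16, 32))
    y = downsample_block(x, filters=64)   # -> (None, 8, 8, 8, 128), i.e. 2 * filters channels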


**MultiView Block**

In addition to the traditional 3D convolution, the MultiView block uses 3D convolutions with different kernel sizes to extract features.
Because vessel structures are long and thin within certain planes, the MultiView block can extract these elongated features along different planes without introducing too many parameters.

::

    def MultiView(x, filters):
        conv1_1 = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation('relu')(norm1_1)

        conv2_1 = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation('relu')(norm2_1)

        conv3_1 = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation('relu')(norm3_1)

        # Anisotropic convolutions along the three orthogonal planes.
        conv1_2 = Conv3D(filters=filters, kernel_size=(1, 3, 3), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti1_1)
        norm1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation('relu')(norm1_2)

        conv2_2 = Conv3D(filters=filters, kernel_size=(3, 1, 3), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti2_1)
        norm2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation('relu')(norm2_2)

        conv3_2 = Conv3D(filters=filters, kernel_size=(3, 3, 1), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti3_1)
        norm3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation('relu')(norm3_2)

        # Fuse the three views and mix channels with a 1x1x1 convolution.
        add = Add()([acti1_2, acti2_2, acti3_2])
        conv = Conv3D(filters=filters, kernel_size=(1, 1, 1), strides=1, padding='same',
                      kernel_initializer='he_normal')(add)
        norm = BatchNormalization()(conv)
        acti = Activation('relu')(norm)
        return acti
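
A shape sketch with illustrative sizes: the block preserves the spatial size and outputs filters channels.

::

    x = Input(shape=(32, 32, 32, 16))
    y = MultiView(x, filters=32)   # -> (None, 32, 32, 32, 32), spatial size unchanged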


**Transformer Block**

The Transformer block stacks three multi-head self-attention sub-blocks. Each sub-block applies GroupNormalization, MultiHeadAttention with a residual connection, and a GroupNormalization plus MLP with a second residual connection; a final GroupNormalization is applied to the output.

::

    def Transformer_Block(inputs):

        # Feature (channel) dimension of the input.
        C = inputs.shape[-1]

        # Sub-block 1: pre-norm multi-head self-attention and MLP, each with a residual connection.
        add4 = GroupNormalization()(inputs)
        attention1 = MultiHeadAttention(4, key_dim=3, dropout=0)(add4, add4)
        att_add1 = Add()([attention1, add4])
        att_add1 = GroupNormalization()(att_add1)
        mlp1 = mlp(att_add1, hidden_units=[2*C, C], dropout_rate=0)
        att_add1 = Add()([mlp1, att_add1])

        # Sub-block 2.
        att_add2 = GroupNormalization()(att_add1)
        attention2 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add2, att_add2)
        att_add2 = Add()([attention2, att_add2])
        att_add2 = GroupNormalization()(att_add2)
        mlp2 = mlp(att_add2, hidden_units=[2*C, C], dropout_rate=0)
        att_add2 = Add()([mlp2, att_add2])

        # Sub-block 3.
        att_add3 = GroupNormalization()(att_add2)
        attention3 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add3, att_add3)
        att_add3 = Add()([attention3, att_add3])
        att_add3 = GroupNormalization()(att_add3)
        mlp3 = mlp(att_add3, hidden_units=[2*C, C], dropout_rate=0)
        att_add3 = Add()([mlp3, att_add3])
        attention = GroupNormalization()(att_add3)
        return attention
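
A usage sketch, assuming the block is applied to a flattened token sequence of shape (num_tokens, C); this layout is an assumption, since Keras MultiHeadAttention also accepts higher-rank inputs and by default attends over all axes except the batch and feature axes. Note that GroupNormalization defaults to 32 groups, so C should be divisible by 32 or a smaller groups value should be passed.

::

    # Illustrative example: 512 tokens with C = 64 channels (divisible by 32).
    tokens = Input(shape=(512, 64))
    out = Transformer_Block(tokens)   # -> (None, 512, 64)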


**Reduction Block**

The Reduction block halves the spatial size of the feature map along three parallel paths (a strided 1x1x1 convolution followed by factorized (1, 1, 5), (1, 5, 1) and (5, 1, 1) convolutions; a MaxPooling branch; and a strided 5x5x5 convolution) and then concatenates their outputs.

::

    def reduction_block(x, filters):
        # Path 1: strided 1x1x1 convolution followed by factorized 5-tap convolutions.
        conv1_1 = Conv3D(filters=filters, kernel_size=1, strides=2, padding='same',
                         kernel_initializer='he_normal')(x)
        norm1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation('relu')(norm1_1)

        conv1_2 = Conv3D(filters=filters, kernel_size=(1, 1, 5), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti1_1)
        norm1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation('relu')(norm1_2)

        conv1_3 = Conv3D(filters=filters, kernel_size=(1, 5, 1), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti1_2)
        norm1_3 = BatchNormalization()(conv1_3)
        acti1_3 = Activation('relu')(norm1_3)

        conv1_4 = Conv3D(filters=filters, kernel_size=(5, 1, 1), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti1_3)
        norm1_4 = BatchNormalization()(conv1_4)
        acti1_4 = Activation('relu')(norm1_4)

        # Path 2: max pooling followed by a 1x1x1 convolution.
        maxpool2 = MaxPooling3D()(x)
        conv2_1 = Conv3D(filters=filters, kernel_size=1, strides=1, padding='same',
                         kernel_initializer='he_normal')(maxpool2)
        norm2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation('relu')(norm2_1)

        # Path 3: 1x1x1 convolution followed by a strided 5x5x5 convolution.
        conv3_1 = Conv3D(filters=filters, kernel_size=1, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation('relu')(norm3_1)

        conv3_2 = Conv3D(filters=filters, kernel_size=5, strides=2, padding='same',
                         kernel_initializer='he_normal')(acti3_1)
        norm3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation('relu')(norm3_2)

        concat = Concatenate()([acti1_4, acti2_1, acti3_2])

        return concat
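
A shape sketch with illustrative sizes: every path halves the spatial dimensions, and concatenating the three paths yields 3 * filters output channels.

::

    x = Input(shape=(16, 16, 16, 32))
    y = reduction_block(x, filters=48)   # -> (None, 8, 8, 8, 144), i.e. 3 * filters channels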


**Deep Block**

Like the MultiView block, the Deep block uses 3D convolutions with different kernel sizes to extract features along different axes. It combines a sequential factorized branch, a parallel per-axis branch, and a full 5x5x5 branch, and concatenates their outputs.

::

    def deep_block(x, filters):
        # Branch 1: sequential factorized 7-tap convolutions along each axis.
        conv1_1 = Conv3D(filters=filters, kernel_size=1, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation('relu')(norm1_1)

        conv1_2 = Conv3D(filters=filters, kernel_size=(1, 1, 7), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti1_1)
        norm1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation('relu')(norm1_2)

        conv1_3 = Conv3D(filters=filters, kernel_size=(1, 7, 1), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti1_2)
        norm1_3 = BatchNormalization()(conv1_3)
        acti1_3 = Activation('relu')(norm1_3)

        conv1_4 = Conv3D(filters=filters, kernel_size=(7, 1, 1), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti1_3)
        norm1_4 = BatchNormalization()(conv1_4)
        acti1_4 = Activation('relu')(norm1_4)

        # Branch 2: parallel 7-tap convolutions, one along each axis.
        conv2_1 = Conv3D(filters=filters, kernel_size=1, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation('relu')(norm2_1)

        conv2_2 = Conv3D(filters=filters, kernel_size=(7, 1, 1), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti2_1)
        norm2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation('relu')(norm2_2)

        conv2_3 = Conv3D(filters=filters, kernel_size=(1, 7, 1), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti2_1)
        norm2_3 = BatchNormalization()(conv2_3)
        acti2_3 = Activation('relu')(norm2_3)

        conv2_4 = Conv3D(filters=filters, kernel_size=(1, 1, 7), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti2_1)
        norm2_4 = BatchNormalization()(conv2_4)
        acti2_4 = Activation('relu')(norm2_4)

        # Branch 3: full 5x5x5 convolution.
        conv3_1 = Conv3D(filters=filters, kernel_size=1, strides=1, padding='same',
                         kernel_initializer='he_normal')(x)
        norm3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation('relu')(norm3_1)

        conv3_2 = Conv3D(filters=filters, kernel_size=(5, 5, 5), strides=1, padding='same',
                         kernel_initializer='he_normal')(acti3_1)
        norm3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation('relu')(norm3_2)

        concat = Concatenate()([acti1_4, acti2_2, acti2_3, acti2_4, acti3_2])

        return concat
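
A shape sketch with illustrative sizes: all branches keep the spatial size, and concatenating the five outputs yields 5 * filters channels.

::

    x = Input(shape=(16, 16, 16, 32))
    y = deep_block(x, filters=32)   # -> (None, 16, 16, 16, 160), i.e. 5 * filters channels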