Model
===================




**U-Net**

U-Net is a type of convolutional neural network designed for fast and precise segmentation of images.
It was developed by Olaf Ronneberger, Philipp Fischer, and Thomas Brox in 2015.
The network is based on the fully convolutional network and its architecture was modified and extended to work with fewer training images and to yield more precise segmentations.
Segmentation is the process of partitioning an image into multiple segments, which can be used to identify objects and boundaries in an image.

::

   def get_model_1_unet(self):
        """Build the plain 3D U-Net baseline and return it as a Keras ``Model``.

        Encoder: three stages with 16, 32 and 64 filters, each made of two
        Conv3D -> BatchNormalization -> ReLU units and followed by
        ``MaxPooling3D``. Bottleneck: two such units with 128 filters.
        Decoder: three stages (64, 32, 16 filters) that upsample with a
        stride-2 ``Conv3DTranspose``, concatenate the matching encoder
        output on the last axis and apply two more units.

        Returns:
            ``Model(self.input, output)`` where ``output`` is a kernel-size-1
            softmax convolution with ``self.classes`` filters.
        """

        def conv_unit(tensor, filters, name):
            # kernel-size-3 convolution -> batch normalisation -> ReLU
            tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                            kernel_initializer='he_normal', name=name)(tensor)
            tensor = BatchNormalization()(tensor)
            return Activation(activation='relu')(tensor)

        def up_unit(tensor, filters):
            # stride-2 transposed convolution -> batch normalisation -> ReLU
            tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(tensor)
            tensor = BatchNormalization()(tensor)
            return Activation(activation='relu')(tensor)

        # Encoder (layers 1-3): remember each stage output for the skip connections.
        skips = []
        x = self.input
        for stage, filters in enumerate((16, 32, 64), start=1):
            x = conv_unit(x, filters, 'conv{}_1'.format(stage))
            x = conv_unit(x, filters, 'conv{}_2'.format(stage))
            skips.append(x)
            x = MaxPooling3D()(x)

        # Bottleneck (layer 4).
        x = conv_unit(x, 128, 'conv4_1')
        x = conv_unit(x, 128, 'conv4_2')

        # Decoder (layers 7-9 mirror layers 3-1); skips are consumed deepest first.
        for stage, filters in zip((7, 8, 9), (64, 32, 16)):
            x = up_unit(x, filters)
            x = Concatenate(axis=-1)([x, skips.pop()])
            x = conv_unit(x, filters, 'up_conv{}_1'.format(stage))
            x = conv_unit(x, filters, 'up_conv{}_2'.format(stage))

        # Per-voxel class scores.
        probabilities = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                               kernel_initializer='he_normal', name='output',
                               activation='softmax')(x)

        return Model(self.input, probabilities)


**Dual Attention Network**

Dual Attention Network is a segmentation model that uses spatial and channel attention to improve the segmentation results.

::

   def get_model_1_dual_attention(self):
        """Build a 3D U-Net whose bottleneck output is passed through ``ATT``.

        Encoder: three stages (16, 32, 64 filters) of two
        Conv3D -> BatchNormalization -> ReLU units, each followed by
        ``MaxPooling3D``. The 128-filter bottleneck output is fed to ``ATT``
        before decoding. Decoder: three stages (64, 32, 16 filters) of
        stride-2 ``Conv3DTranspose`` upsampling, concatenation with the
        matching encoder output and two more units.

        NOTE(review): ``ATT`` is defined outside this snippet; presumably it
        is the dual (spatial + channel) attention module - confirm.

        Returns:
            ``Model(self.input, output)`` where ``output`` is a kernel-size-1
            softmax convolution with ``self.classes`` filters.
        """

        def conv_unit(tensor, filters, name):
            # kernel-size-3 convolution -> batch normalisation -> ReLU
            tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                            kernel_initializer='he_normal', name=name)(tensor)
            tensor = BatchNormalization()(tensor)
            return Activation(activation='relu')(tensor)

        def up_unit(tensor, filters):
            # stride-2 transposed convolution -> batch normalisation -> ReLU
            tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(tensor)
            tensor = BatchNormalization()(tensor)
            return Activation(activation='relu')(tensor)

        # Encoder (layers 1-3): remember each stage output for the skip connections.
        skips = []
        x = self.input
        for stage, filters in enumerate((16, 32, 64), start=1):
            x = conv_unit(x, filters, 'conv{}_1'.format(stage))
            x = conv_unit(x, filters, 'conv{}_2'.format(stage))
            skips.append(x)
            x = MaxPooling3D()(x)

        # Bottleneck (layer 4), then the attention module.
        x = conv_unit(x, 128, 'conv4_1')
        x = conv_unit(x, 128, 'conv4_2')
        x = ATT(x)

        # Decoder (layers 7-9 mirror layers 3-1); skips are consumed deepest first.
        for stage, filters in zip((7, 8, 9), (64, 32, 16)):
            x = up_unit(x, filters)
            x = Concatenate(axis=-1)([x, skips.pop()])
            x = conv_unit(x, filters, 'up_conv{}_1'.format(stage))
            x = conv_unit(x, filters, 'up_conv{}_2'.format(stage))

        # Per-voxel class scores.
        probabilities = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                               kernel_initializer='he_normal', name='output',
                               activation='softmax')(x)

        return Model(self.input, probabilities)


**UNet++**

UNet++ is a convolutional neural network for biomedical image segmentation.
It extends the original U-Net architecture with a nested (or "chained") arrangement of skip pathways to improve segmentation results.

::

   def get_model_1_unetpp(self):
        """Build the 3D UNet++ variant and return it as a Keras ``Model``.

        The encoder and bottleneck are those of a plain U-Net (16/32/64
        filters with ``MaxPooling3D`` between stages, 128 filters at the
        bottom). Three intermediate nodes are inserted on the skip paths:

        * node 12 - encoder 2 upsampled and merged with encoder 1
        * node 22 - encoder 3 upsampled and merged with encoder 2
        * node 13 - node 22 upsampled and merged with node 12

        The decoder stages then concatenate the upsampled deeper features
        with the encoder output *and* every intermediate node of that level.

        Returns:
            ``Model(self.input, output)`` where ``output`` is a kernel-size-1
            softmax convolution with ``self.classes`` filters.
        """

        def double_conv(tensor, filters, prefix=None):
            # Two Conv3D -> BatchNormalization -> ReLU units in a row. With a
            # prefix the convolutions are named '<prefix>_1' and '<prefix>_2';
            # without one Keras picks the layer names.
            for index in (1, 2):
                name = None if prefix is None else '{}_{}'.format(prefix, index)
                tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                                kernel_initializer='he_normal', name=name)(tensor)
                tensor = BatchNormalization()(tensor)
                tensor = Activation(activation='relu')(tensor)
            return tensor

        def upsample(tensor, filters):
            # stride-2 transposed convolution -> batch normalisation -> ReLU
            tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(tensor)
            tensor = BatchNormalization()(tensor)
            return Activation(activation='relu')(tensor)

        # Encoder (layers 1-3).
        stage_outputs = []
        x = self.input
        for stage, filters in enumerate((16, 32, 64), start=1):
            x = double_conv(x, filters, 'conv{}'.format(stage))
            stage_outputs.append(x)
            x = MaxPooling3D()(x)
        enc1, enc2, enc3 = stage_outputs

        # Bottleneck (layer 4).
        bottleneck = double_conv(x, 128, 'conv4')

        # Nested node 12.
        up12 = upsample(enc2, 16)
        node12 = double_conv(Concatenate(axis=-1)([enc1, up12]), 16)

        # Nested node 22.
        up22 = upsample(enc3, 32)
        node22 = double_conv(Concatenate(axis=-1)([enc2, up22]), 32)

        # Nested node 13.
        up13 = upsample(node22, 16)
        node13 = double_conv(Concatenate(axis=-1)([node12, up13]), 16)

        # Decoder layer 7 (mirrors layer 3).
        up7 = upsample(bottleneck, 64)
        dec7 = double_conv(Concatenate(axis=-1)([up7, enc3]), 64, 'up_conv7')

        # Decoder layer 8 (mirrors layer 2).
        up8 = upsample(dec7, 32)
        dec8 = double_conv(Concatenate(axis=-1)([up8, enc2, node22]), 32, 'up_conv8')

        # Decoder layer 9 (mirrors layer 1).
        up9 = upsample(dec8, 16)
        dec9 = double_conv(Concatenate(axis=-1)([up9, enc1, node12, node13]), 16, 'up_conv9')

        # Per-voxel class scores.
        probabilities = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                               kernel_initializer='he_normal', name='output',
                               activation='softmax')(dec9)

        return Model(self.input, probabilities)


**VASeg**

Our network structure.

::

   def get_model_2(self):
        """Build the base VASeg network and return it as a Keras ``Model``.

        Layout, as visible in this snippet:

        * Encoder stages 1-3 (16/32/64 filters): two
          Conv3D -> BatchNormalization -> ReLU units whose result is added to
          ``MultiView(stage_input, filters)``; the sum is kept as the skip
          tensor and passed to ``reduction_block`` to feed the next stage.
        * Stage 4 (128 filters): ``deep_block`` on the incoming tensor, two
          conv units, addition with ``MultiView``, then ``ATT``.
        * Nested nodes 12, 22 and 13 between the skip tensors.
        * Decoder stages 7-9 (64/32/16 filters): stride-2 transposed
          convolution, concatenation with the skip tensor and the nested
          nodes of that level, ``deep_block``, two conv units.

        NOTE(review): ``MultiView``, ``reduction_block``, ``deep_block`` and
        ``ATT`` are defined outside this snippet; their internals are not
        documented here. ``reduction_block`` presumably downsamples, since it
        takes the place a pooling layer would occupy - confirm.

        Returns:
            ``Model(self.input, output)`` where ``output`` is a kernel-size-1
            softmax convolution with ``self.classes`` filters.
        """

        def double_conv(tensor, filters, prefix=None):
            # Two Conv3D -> BatchNormalization -> ReLU units in a row. With a
            # prefix the convolutions are named '<prefix>_1' and '<prefix>_2';
            # without one Keras picks the layer names.
            for index in (1, 2):
                name = None if prefix is None else '{}_{}'.format(prefix, index)
                tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                                kernel_initializer='he_normal', name=name)(tensor)
                tensor = BatchNormalization()(tensor)
                tensor = Activation(activation='relu')(tensor)
            return tensor

        def upsample(tensor, filters):
            # stride-2 transposed convolution -> batch normalisation -> ReLU
            tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(tensor)
            tensor = BatchNormalization()(tensor)
            return Activation(activation='relu')(tensor)

        # Encoder (layers 1-3): conv branch + MultiView branch, summed.
        stage_outputs = []
        x = self.input
        for stage, filters in enumerate((16, 32, 64), start=1):
            conv_branch = double_conv(x, filters, 'conv{}'.format(stage))
            view_branch = MultiView(x, filters)
            merged = Add()([conv_branch, view_branch])
            stage_outputs.append(merged)
            x = reduction_block(merged, filters)
        enc1, enc2, enc3 = stage_outputs

        # Layer 4: deep block, conv branch + MultiView branch, attention.
        deep4 = deep_block(x, 128)
        conv_branch = double_conv(deep4, 128, 'conv4')
        view_branch = MultiView(deep4, 128)
        attention = ATT(Add()([conv_branch, view_branch]))

        # Nested node 12.
        up12 = upsample(enc2, 16)
        node12 = double_conv(Concatenate(axis=-1)([enc1, up12]), 16)

        # Nested node 22.
        up22 = upsample(enc3, 32)
        node22 = double_conv(Concatenate(axis=-1)([enc2, up22]), 32)

        # Nested node 13.
        up13 = upsample(node22, 16)
        node13 = double_conv(Concatenate(axis=-1)([node12, up13]), 16)

        # Decoder layer 7 (mirrors layer 3).
        up7 = upsample(attention, 64)
        deep7 = deep_block(Concatenate(axis=-1)([up7, enc3]), 64)
        dec7 = double_conv(deep7, 64, 'up_conv7')

        # Decoder layer 8 (mirrors layer 2).
        up8 = upsample(dec7, 32)
        deep8 = deep_block(Concatenate(axis=-1)([up8, enc2, node22]), 32)
        dec8 = double_conv(deep8, 32, 'up_conv8')

        # Decoder layer 9 (mirrors layer 1).
        up9 = upsample(dec8, 16)
        deep9 = deep_block(Concatenate(axis=-1)([up9, enc1, node12, node13]), 16)
        dec9 = double_conv(deep9, 16, 'up_conv9')

        # Per-voxel class scores.
        probabilities = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                               kernel_initializer='he_normal', name='output',
                               activation='softmax')(dec9)

        return Model(self.input, probabilities)


**VASeg Enhanced Version**

This version extends the base VASeg model with a *Magnitude Attention Gate* and a *Transformer Block*.

::

   def get_model_6(self):
        # layer 1
        # bn0_1 = BatchNormalization()(self.input)
        conv1_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_1')(self.input)
        bn1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation(activation='relu')(bn1_1)

        conv1_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_2')(acti1_1)
        bn1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation(activation='relu')(bn1_2)

        multiview1 = MultiView(self.input, 16)
        add1 = Add()([acti1_2, multiview1])

        maxpool1 = downsample_block(add1, 16)

        # layer 2
        conv2_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_1')(maxpool1)
        bn2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation(activation='relu')(bn2_1)

        conv2_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_2')(acti2_1)
        bn2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation(activation='relu')(bn2_2)

        multiview2 = MultiView(maxpool1, 32)
        add2 = Add()([acti2_2, multiview2])

        maxpool2 = downsample_block(add2, 32)

        # layer 3
        conv3_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_1')(maxpool2)
        bn3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation(activation='relu')(bn3_1)

        conv3_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_2')(acti3_1)
        bn3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation(activation='relu')(bn3_2)

        multiview3 = MultiView(maxpool2, 64)
        add3 = Add()([acti3_2, multiview3])

        maxpool3 = downsample_block(add3, 64)

        # layer 4
        # deep4 = deep_block(maxpool3, 128)
        deep4 = maxpool3

        conv4_1 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_1')(deep4)
        bn4_1 = BatchNormalization()(conv4_1)
        acti4_1 = Activation(activation='relu')(bn4_1)

        conv4_2 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_2')(acti4_1)
        bn4_2 = BatchNormalization()(conv4_2)
        acti4_2 = Activation(activation='relu')(bn4_2)

        multiview4 = MultiView(deep4, 128)
        add4 = Add()([acti4_2, multiview4])

        add4 = ATT(add4)

        # test = tf.signal.fft3d

        # embedding = Embedding(196, 768, input_length=196)
        # emb_pos = np.arange(128)
        # emb = Embedding(128, 512, input_length=128)(emb_pos)
        # print(emb.shape)
        # # emb = tf.transpose(emb)
        # emb = tf.reshape(emb, shape=(-1, 8, 8, 8, 128))
        # print(emb.shape)
        # embedded = Add()([add4, emb])

        print(add4.shape)
        add4 = LayerNormalization()(add4)
        attention1 = MultiHeadAttention(4, key_dim=3, dropout=0)(add4, add4)
        att_add1 = Add()([attention1, add4])
        att_add1 = LayerNormalization()(att_add1)
        mlp1 = mlp(att_add1, hidden_units=[256, 128], dropout_rate=0)
        att_add1 = Add()([mlp1, att_add1])

        att_add2 = LayerNormalization()(att_add1)
        attention2 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add2, att_add2)
        att_add2 = Add()([attention2, att_add2])
        att_add2 = LayerNormalization()(att_add2)
        mlp2 = mlp(att_add2, hidden_units=[256, 128], dropout_rate=0)
        att_add2 = Add()([mlp2, att_add2])

        att_add3 = LayerNormalization()(att_add2)
        attention3 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add3, att_add3)
        att_add3 = Add()([attention3, att_add3])
        att_add3 = LayerNormalization()(att_add3)
        mlp3 = mlp(att_add3, hidden_units=[256, 128], dropout_rate=0)
        att_add3 = Add()([mlp3, att_add3])
        attention = LayerNormalization()(att_add3)
        # attention = tf.reshape(attention, shape=(8, 8, 8, 128))

        # layer12
        upsample12 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add2)
        bn12 = BatchNormalization()(upsample12)
        acti12 = Activation(activation='relu')(bn12)

        concat12 = Concatenate(axis=-1)([add1, acti12])

        conv12_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat12)
        bn12_1 = BatchNormalization()(conv12_1)
        acti12_1 = Activation(activation='relu')(bn12_1)

        conv12_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti12_1)
        bn12_2 = BatchNormalization()(conv12_2)
        acti12_2 = Activation(activation='relu')(bn12_2)





        # layer22
        upsample22 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add3)
        bn22 = BatchNormalization()(upsample22)
        acti22 = Activation(activation='relu')(bn22)

        concat22 = Concatenate(axis=-1)([add2, acti22])

        conv22_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat22)
        bn22_1 = BatchNormalization()(conv22_1)
        acti22_1 = Activation(activation='relu')(bn22_1)

        conv22_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti22_1)
        bn22_2 = BatchNormalization()(conv22_2)
        acti22_2 = Activation(activation='relu')(bn22_2)





        # # Attention Gate
        #
        # AG1_1 = acti22_2 +



        # layer13
        upsample13 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(acti22_2)
        bn13 = BatchNormalization()(upsample13)
        acti13 = Activation(activation='relu')(bn13)

        concat13 = Concatenate(axis=-1)([acti12_2, acti13])

        conv13_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat13)
        bn13_1 = BatchNormalization()(conv13_1)
        acti13_1 = Activation(activation='relu')(bn13_1)

        conv13_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti13_1)
        bn13_2 = BatchNormalization()(conv13_2)
        acti13_2 = Activation(activation='relu', name='sag1_input_l')(bn13_2)

        # layer 7 (equal to layer 3)
        upsample7 = Conv3DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(attention)
        bn7 = BatchNormalization()(upsample7)
        acti7 = Activation(activation='relu')(bn7)

        ######### Magnitude Attention Gate 3 ################
        mag3_1_0 = K.cast(add3, "complex64")
        mag3_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 16, 16, 16))(mag3_1_0)
        mag3_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 16, 16, 16))(mag3_1_1)

        acti7mag = Activation(activation='sigmoid')(acti7)
        acti7mag = K.cast(acti7mag, "complex64")
        mag3_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 16, 16, 16))(acti7mag)
        mag3_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 16, 16, 16))(mag3_2_1)

        mag_attention3 = mag3_1_2 + mag3_2_2

        mag3_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 16, 16, 16))(mag_attention3)
        mag3_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 16, 16, 16))(mag3_3_1)
        mag3_output = Lambda(tf.abs, output_shape=(None, 16, 16, 16))(mag3_3_2)
        mag3_output = K.cast(mag3_output, "float32")
        ##################################################

        concat7 = Concatenate(axis=-1)([acti7, add3])

        # deep7 = deep_block(concat7, 64)
        deep7 = concat7

        conv7_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_1')(deep7)
        bn7_1 = BatchNormalization()(conv7_1)
        acti7_1 = Activation(activation='relu')(bn7_1)

        #########################
        acti7_1 = Add()([acti7_1, mag3_output])
        #########################

        conv7_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_2')(acti7_1)
        bn7_2 = BatchNormalization()(conv7_2)
        acti7_2 = Activation(activation='relu')(bn7_2)

        # layer 8 (equal to layer 2)
        upsample8 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti7_2)
        bn8 = BatchNormalization()(upsample8)
        acti8 = Activation(activation='relu')(bn8)

        ######### Magnitude Attention Gate 2 ################
        mag2_1_0 = K.cast(acti22_2, "complex64")
        mag2_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 32, 32, 32))(mag2_1_0)
        mag2_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 32, 32, 32))(mag2_1_1)

        acti8mag = Activation(activation='sigmoid')(acti8)
        acti8mag = K.cast(acti8mag, "complex64")
        mag2_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 32, 32, 32))(acti8mag)
        mag2_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 32, 32, 32))(mag2_2_1)

        mag_attention2 = mag2_1_2 + mag2_2_2

        mag2_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 32, 32, 32))(mag_attention2)
        mag2_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 32, 32, 32))(mag2_3_1)
        mag2_output = Lambda(tf.abs, output_shape=(None, 32, 32, 32))(mag2_3_2)
        mag2_output = K.cast(mag2_output, "float32")
        ##################################################

        concat8 = Concatenate(axis=-1)([acti8, add2, acti22_2])

        # deep8 = deep_block(concat8, 32)
        deep8 = concat8

        conv8_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_1')(deep8)
        bn8_1 = BatchNormalization()(conv8_1)
        acti8_1 = Activation(activation='relu')(bn8_1)

        #########################
        acti8_1 = Add()([acti8_1, mag2_output])
        #########################

        conv8_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_2')(acti8_1)
        bn8_2 = BatchNormalization()(conv8_2)
        acti8_2 = Activation(activation='relu')(bn8_2)

        # layer 9 (equal to layer 1)
        upsample9 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti8_2)
        bn9 = BatchNormalization()(upsample9)
        acti9 = Activation(activation='relu', name='sag1_input_r')(bn9)

        ######### Magnitude Attention Gate 1 ################
        mag1_1_0 = K.cast(acti13_2, "complex64")
        mag1_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 64, 64, 64), name='mag1_1_1')(mag1_1_0)
        mag1_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 64, 64, 64), name='mag1_1_2')(mag1_1_1)

        acti9mag = Activation(activation='sigmoid', name='acti9mag')(acti9)
        acti9mag = K.cast(acti9mag, "complex64")
        mag1_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 64, 64, 64), name='mag1_2_1')(acti9mag)
        mag1_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 64, 64, 64), name='mag1_2_2')(mag1_2_1)

        mag_attention = mag1_1_2 + mag1_2_2

        mag1_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 64, 64, 64), name='mag1_3_1')(mag_attention)
        mag1_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 64, 64, 64), name='mag1_3_2')(mag1_3_1)
        mag1_output = Lambda(tf.abs, output_shape=(None, 64, 64, 64), name='mag1_output')(mag1_3_2)
        mag1_output = K.cast(mag1_output, "float32")

        concat9 = Concatenate(axis=-1)([acti9, add1, acti12_2, acti13_2])

        deep9 = concat9

        conv9_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_1')(deep9)
        bn9_1 = BatchNormalization()(conv9_1)
        acti9_1 = Activation(activation='relu')(bn9_1)

        acti9_1 = Add()([acti9_1, mag1_output])

        conv9_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_2')(acti9_1)
        bn9_2 = BatchNormalization()(conv9_2)
        acti9_2 = Activation(activation='relu')(bn9_2)

        output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                        kernel_initializer='he_normal', name='output', activation='softmax')(acti9_2)





        model = Model(self.input, output)

        return model