Model

U-Net

U-Net is a type of convolutional neural network designed for fast and precise segmentation of images. It was developed by Olaf Ronneberger, Philipp Fischer, and Thomas Brox in 2015. The network is based on the fully convolutional network and its architecture was modified and extended to work with fewer training images and to yield more precise segmentations. Segmentation is the process of partitioning an image into multiple segments, which can be used to identify objects and boundaries in an image.

def get_model_1_unet(self):
    """Build a plain 3D U-Net segmentation model.

    Encoder: three double-conv stages (16/32/64 filters), each followed by
    3D max pooling; a 128-filter double-conv bottleneck. Decoder: three
    stride-2 transposed-conv upsampling steps, each concatenated with the
    matching encoder skip and refined by another double conv. A final
    1x1x1 softmax conv emits ``self.classes`` channels.

    Returns:
        Keras ``Model`` mapping ``self.input`` to the softmax segmentation map.
    """

    def _conv_bn_relu(tensor, filters, name=None):
        # Conv3D -> BatchNorm -> ReLU. Layer-creation order is kept identical
        # to the original hand-unrolled code so the auto-generated names of
        # the unnamed BatchNormalization/Activation layers do not change.
        tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                        kernel_initializer='he_normal', name=name)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    def _double_conv(tensor, filters, prefix):
        # Two stacked conv-BN-ReLU blocks named ``<prefix>_1`` / ``<prefix>_2``.
        tensor = _conv_bn_relu(tensor, filters, prefix + '_1')
        return _conv_bn_relu(tensor, filters, prefix + '_2')

    def _upsample(tensor, filters):
        # Stride-2 transposed conv -> BatchNorm -> ReLU (doubles spatial size).
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    # ---- encoder ----
    skip1 = _double_conv(self.input, 16, 'conv1')             # layer 1
    skip2 = _double_conv(MaxPooling3D()(skip1), 32, 'conv2')  # layer 2
    skip3 = _double_conv(MaxPooling3D()(skip2), 64, 'conv3')  # layer 3

    # ---- bottleneck (layer 4) ----
    bottom = _double_conv(MaxPooling3D()(skip3), 128, 'conv4')

    # ---- decoder ----
    # layer 7 (mirrors layer 3)
    up7 = _upsample(bottom, 64)
    dec7 = _double_conv(Concatenate(axis=-1)([up7, skip3]), 64, 'up_conv7')

    # layer 8 (mirrors layer 2)
    up8 = _upsample(dec7, 32)
    dec8 = _double_conv(Concatenate(axis=-1)([up8, skip2]), 32, 'up_conv8')

    # layer 9 (mirrors layer 1)
    up9 = _upsample(dec8, 16)
    dec9 = _double_conv(Concatenate(axis=-1)([up9, skip1]), 16, 'up_conv9')

    # 1x1x1 softmax classifier head.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output',
                    activation='softmax')(dec9)

    return Model(self.input, output)

Dual Attention Network

Dual Attention Network is a segmentation model that uses spatial and channel attention to improve the segmentation results.

def get_model_1_dual_attention(self):
    """Build a 3D U-Net with a dual-attention module at the bottleneck.

    Identical to :meth:`get_model_1_unet` except that the bottleneck
    features are passed through ``ATT`` (the project's attention block,
    defined elsewhere in this module) before being upsampled.

    Returns:
        Keras ``Model`` mapping ``self.input`` to the softmax segmentation map.
    """

    def _conv_bn_relu(tensor, filters, name=None):
        # Conv3D -> BatchNorm -> ReLU. Layer-creation order matches the
        # original hand-unrolled code so auto-generated layer names of the
        # unnamed BatchNormalization/Activation layers are preserved.
        tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                        kernel_initializer='he_normal', name=name)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    def _double_conv(tensor, filters, prefix):
        # Two stacked conv-BN-ReLU blocks named ``<prefix>_1`` / ``<prefix>_2``.
        tensor = _conv_bn_relu(tensor, filters, prefix + '_1')
        return _conv_bn_relu(tensor, filters, prefix + '_2')

    def _upsample(tensor, filters):
        # Stride-2 transposed conv -> BatchNorm -> ReLU (doubles spatial size).
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    # ---- encoder ----
    skip1 = _double_conv(self.input, 16, 'conv1')             # layer 1
    skip2 = _double_conv(MaxPooling3D()(skip1), 32, 'conv2')  # layer 2
    skip3 = _double_conv(MaxPooling3D()(skip2), 64, 'conv3')  # layer 3

    # ---- bottleneck (layer 4) with dual attention ----
    bottom = _double_conv(MaxPooling3D()(skip3), 128, 'conv4')
    attention = ATT(bottom)

    # ---- decoder ----
    # layer 7 (mirrors layer 3)
    up7 = _upsample(attention, 64)
    dec7 = _double_conv(Concatenate(axis=-1)([up7, skip3]), 64, 'up_conv7')

    # layer 8 (mirrors layer 2)
    up8 = _upsample(dec7, 32)
    dec8 = _double_conv(Concatenate(axis=-1)([up8, skip2]), 32, 'up_conv8')

    # layer 9 (mirrors layer 1)
    up9 = _upsample(dec8, 16)
    dec9 = _double_conv(Concatenate(axis=-1)([up9, skip1]), 16, 'up_conv9')

    # 1x1x1 softmax classifier head.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output',
                    activation='softmax')(dec9)

    return Model(self.input, output)

UNet++

UNet++ is a convolutional neural network for biomedical image segmentation. It is an extension of the original UNet architecture that uses nested, dense skip connections between the encoder and decoder to improve segmentation results.

def get_model_1_unetpp(self):
    """Build a 3D UNet++-style model with nested (dense) skip pathways.

    On top of the plain U-Net backbone, three intermediate skip nodes are
    computed ("layer12", "layer22", "layer13" in the original numbering)
    and concatenated into the decoder stages, so that each decoder level
    receives features from every same-resolution node that precedes it.

    Returns:
        Keras ``Model`` mapping ``self.input`` to the softmax segmentation map.
    """

    def _conv_bn_relu(tensor, filters, name=None):
        # Conv3D -> BatchNorm -> ReLU. Layer-creation order matches the
        # original hand-unrolled code so auto-generated layer names of the
        # unnamed BatchNormalization/Activation layers are preserved.
        tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                        kernel_initializer='he_normal', name=name)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    def _double_conv(tensor, filters, prefix=None):
        # Two stacked conv-BN-ReLU blocks. Named ``<prefix>_1``/``<prefix>_2``
        # when a prefix is given; the nested-skip convs stay unnamed, as in
        # the original.
        first = prefix + '_1' if prefix else None
        second = prefix + '_2' if prefix else None
        tensor = _conv_bn_relu(tensor, filters, first)
        return _conv_bn_relu(tensor, filters, second)

    def _upsample(tensor, filters):
        # Stride-2 transposed conv -> BatchNorm -> ReLU (doubles spatial size).
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    # ---- encoder backbone ----
    skip1 = _double_conv(self.input, 16, 'conv1')             # layer 1
    skip2 = _double_conv(MaxPooling3D()(skip1), 32, 'conv2')  # layer 2
    skip3 = _double_conv(MaxPooling3D()(skip2), 64, 'conv3')  # layer 3
    bottom = _double_conv(MaxPooling3D()(skip3), 128, 'conv4')  # layer 4

    # ---- nested skip nodes ----
    # layer12: upsampled layer-2 features fused with layer-1 skip
    up12 = _upsample(skip2, 16)
    node12 = _double_conv(Concatenate(axis=-1)([skip1, up12]), 16)

    # layer22: upsampled layer-3 features fused with layer-2 skip
    up22 = _upsample(skip3, 32)
    node22 = _double_conv(Concatenate(axis=-1)([skip2, up22]), 32)

    # layer13: upsampled layer22 fused with layer12
    up13 = _upsample(node22, 16)
    node13 = _double_conv(Concatenate(axis=-1)([node12, up13]), 16)

    # ---- decoder ----
    # layer 7 (mirrors layer 3)
    up7 = _upsample(bottom, 64)
    dec7 = _double_conv(Concatenate(axis=-1)([up7, skip3]), 64, 'up_conv7')

    # layer 8 (mirrors layer 2): dense skips from encoder layer 2 and layer22
    up8 = _upsample(dec7, 32)
    dec8 = _double_conv(Concatenate(axis=-1)([up8, skip2, node22]), 32, 'up_conv8')

    # layer 9 (mirrors layer 1): dense skips from layer 1, layer12, layer13
    up9 = _upsample(dec8, 16)
    dec9 = _double_conv(Concatenate(axis=-1)([up9, skip1, node12, node13]), 16, 'up_conv9')

    # 1x1x1 softmax classifier head.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output',
                    activation='softmax')(dec9)

    return Model(self.input, output)

VASeg

Our own network structure.

def get_model_2(self):
    """Build the VASeg model.

    Extends the UNet++-style backbone with project-specific blocks
    (defined elsewhere in this module):

    * ``MultiView`` — a parallel branch on each encoder stage, merged into
      the conv trunk by element-wise ``Add``.
    * ``reduction_block`` — learned downsampling in place of max pooling.
    * ``deep_block`` — feature refinement before the bottleneck and on
      each decoder stage's concatenated input.
    * ``ATT`` — attention applied to the bottleneck features.

    Returns:
        Keras ``Model`` mapping ``self.input`` to the softmax segmentation map.
    """

    def _conv_bn_relu(tensor, filters, name=None):
        # Conv3D -> BatchNorm -> ReLU. Layer-creation order matches the
        # original hand-unrolled code so auto-generated layer names of the
        # unnamed BatchNormalization/Activation layers are preserved.
        tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                        kernel_initializer='he_normal', name=name)(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    def _double_conv(tensor, filters, prefix=None):
        # Two stacked conv-BN-ReLU blocks; named only when a prefix is given.
        first = prefix + '_1' if prefix else None
        second = prefix + '_2' if prefix else None
        tensor = _conv_bn_relu(tensor, filters, first)
        return _conv_bn_relu(tensor, filters, second)

    def _upsample(tensor, filters):
        # Stride-2 transposed conv -> BatchNorm -> ReLU (doubles spatial size).
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    def _encoder_stage(tensor, filters, prefix):
        # Named double conv plus a parallel MultiView branch on the stage
        # input, fused by element-wise addition (residual-style merge).
        trunk = _double_conv(tensor, filters, prefix)
        branch = MultiView(tensor, filters)
        return Add()([trunk, branch])

    # ---- encoder ----
    add1 = _encoder_stage(self.input, 16, 'conv1')   # layer 1
    pool1 = reduction_block(add1, 16)
    add2 = _encoder_stage(pool1, 32, 'conv2')        # layer 2
    pool2 = reduction_block(add2, 32)
    add3 = _encoder_stage(pool2, 64, 'conv3')        # layer 3
    pool3 = reduction_block(add3, 64)

    # ---- bottleneck (layer 4): deep_block, conv+MultiView stage, attention ----
    deep4 = deep_block(pool3, 128)
    add4 = _encoder_stage(deep4, 128, 'conv4')
    attention = ATT(add4)

    # ---- nested skip nodes (UNet++-style) ----
    # layer12: upsampled layer-2 features fused with layer-1 skip
    up12 = _upsample(add2, 16)
    node12 = _double_conv(Concatenate(axis=-1)([add1, up12]), 16)

    # layer22: upsampled layer-3 features fused with layer-2 skip
    up22 = _upsample(add3, 32)
    node22 = _double_conv(Concatenate(axis=-1)([add2, up22]), 32)

    # layer13: upsampled layer22 fused with layer12
    up13 = _upsample(node22, 16)
    node13 = _double_conv(Concatenate(axis=-1)([node12, up13]), 16)

    # ---- decoder ----
    # layer 7 (mirrors layer 3)
    up7 = _upsample(attention, 64)
    deep7 = deep_block(Concatenate(axis=-1)([up7, add3]), 64)
    dec7 = _double_conv(deep7, 64, 'up_conv7')

    # layer 8 (mirrors layer 2): dense skips from encoder layer 2 and layer22
    up8 = _upsample(dec7, 32)
    deep8 = deep_block(Concatenate(axis=-1)([up8, add2, node22]), 32)
    dec8 = _double_conv(deep8, 32, 'up_conv8')

    # layer 9 (mirrors layer 1): dense skips from layer 1, layer12, layer13
    up9 = _upsample(dec8, 16)
    deep9 = deep_block(Concatenate(axis=-1)([up9, add1, node12, node13]), 16)
    dec9 = _double_conv(deep9, 16, 'up_conv9')

    # 1x1x1 softmax classifier head.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output',
                    activation='softmax')(dec9)

    return Model(self.input, output)

VASeg Enhanced Version

In this version we added a Magnitude Attention Gate and a Transformer Block on top of the base VASeg model.

def get_model_6(self):
     """Build the enhanced VASeg 3D segmentation model.

     Compared to the base VASeg U-Net this version adds:
       * a MultiView residual branch on every encoder level,
       * an ATT block followed by three Transformer sub-blocks
         (LayerNorm -> MultiHeadAttention -> residual -> LayerNorm ->
         MLP -> residual) at the bottleneck,
       * Magnitude Attention Gates on the three decoder levels: the skip
         features and the sigmoid-gated decoder features are transformed
         with a 3D FFT, summed in the frequency domain, inverse
         transformed, and the magnitude is added back into the decoder
         path.

     Returns:
         A ``Model`` mapping ``self.input`` to a per-voxel softmax over
         ``self.classes`` channels (1x1x1 Conv3D head, layer name
         'output').

     NOTE(review): the ``output_shape`` hints on the Lambda layers
     (64 -> 32 -> 16 across decoder levels) suggest a 64x64x64 input
     volume halved three times by ``downsample_block`` — TODO confirm
     against the code that constructs ``self.input``.
     """
     # ---------- encoder level 1: two Conv3D(16)->BN->ReLU stages ----------
     # bn0_1 = BatchNormalization()(self.input)
     conv1_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv1_1')(self.input)
     bn1_1 = BatchNormalization()(conv1_1)
     acti1_1 = Activation(activation='relu')(bn1_1)

     conv1_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv1_2')(acti1_1)
     bn1_2 = BatchNormalization()(conv1_2)
     acti1_2 = Activation(activation='relu')(bn1_2)

     # Residual fusion with a MultiView branch computed from this level's
     # input (MultiView is a helper defined elsewhere in this file).
     multiview1 = MultiView(self.input, 16)
     add1 = Add()([acti1_2, multiview1])

     # Learned downsampling (helper defined elsewhere; replaces the plain
     # MaxPooling3D of the base U-Net).
     maxpool1 = downsample_block(add1, 16)

     # ---------- encoder level 2: same pattern, 32 filters ----------
     conv2_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv2_1')(maxpool1)
     bn2_1 = BatchNormalization()(conv2_1)
     acti2_1 = Activation(activation='relu')(bn2_1)

     conv2_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv2_2')(acti2_1)
     bn2_2 = BatchNormalization()(conv2_2)
     acti2_2 = Activation(activation='relu')(bn2_2)

     multiview2 = MultiView(maxpool1, 32)
     add2 = Add()([acti2_2, multiview2])

     maxpool2 = downsample_block(add2, 32)

     # ---------- encoder level 3: same pattern, 64 filters ----------
     conv3_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv3_1')(maxpool2)
     bn3_1 = BatchNormalization()(conv3_1)
     acti3_1 = Activation(activation='relu')(bn3_1)

     conv3_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv3_2')(acti3_1)
     bn3_2 = BatchNormalization()(conv3_2)
     acti3_2 = Activation(activation='relu')(bn3_2)

     multiview3 = MultiView(maxpool2, 64)
     add3 = Add()([acti3_2, multiview3])

     maxpool3 = downsample_block(add3, 64)

     # ---------- bottleneck (level 4, 128 filters) ----------
     # deep_block was tried here and disabled; deep4 is a pass-through.
     # deep4 = deep_block(maxpool3, 128)
     deep4 = maxpool3

     conv4_1 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv4_1')(deep4)
     bn4_1 = BatchNormalization()(conv4_1)
     acti4_1 = Activation(activation='relu')(bn4_1)

     conv4_2 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='conv4_2')(acti4_1)
     bn4_2 = BatchNormalization()(conv4_2)
     acti4_2 = Activation(activation='relu')(bn4_2)

     multiview4 = MultiView(deep4, 128)
     add4 = Add()([acti4_2, multiview4])

     # Attention block on the bottleneck features (ATT is a helper
     # defined elsewhere in this file).
     add4 = ATT(add4)

     # test = tf.signal.fft3d

     # Abandoned positional-embedding experiment for the transformer
     # below — kept for reference.
     # embedding = Embedding(196, 768, input_length=196)
     # emb_pos = np.arange(128)
     # emb = Embedding(128, 512, input_length=128)(emb_pos)
     # print(emb.shape)
     # # emb = tf.transpose(emb)
     # emb = tf.reshape(emb, shape=(-1, 8, 8, 8, 128))
     # print(emb.shape)
     # embedded = Add()([add4, emb])

     # Debug print of the bottleneck tensor shape at graph-build time.
     print(add4.shape)
     # ---------- transformer block x3 (pre-norm MHSA + MLP, both with
     # residual connections) ----------
     # NOTE(review): MultiHeadAttention is applied to the 5D
     # (batch, d, h, w, c) tensor directly with no flattening and no
     # positional embedding; presumably it relies on Keras' default
     # attention over all non-batch/non-channel axes — verify this is
     # the intended behavior. key_dim=3 is unusually small — confirm.
     add4 = LayerNormalization()(add4)
     attention1 = MultiHeadAttention(4, key_dim=3, dropout=0)(add4, add4)
     att_add1 = Add()([attention1, add4])
     att_add1 = LayerNormalization()(att_add1)
     mlp1 = mlp(att_add1, hidden_units=[256, 128], dropout_rate=0)
     att_add1 = Add()([mlp1, att_add1])

     att_add2 = LayerNormalization()(att_add1)
     attention2 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add2, att_add2)
     att_add2 = Add()([attention2, att_add2])
     att_add2 = LayerNormalization()(att_add2)
     mlp2 = mlp(att_add2, hidden_units=[256, 128], dropout_rate=0)
     att_add2 = Add()([mlp2, att_add2])

     att_add3 = LayerNormalization()(att_add2)
     attention3 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add3, att_add3)
     att_add3 = Add()([attention3, att_add3])
     att_add3 = LayerNormalization()(att_add3)
     mlp3 = mlp(att_add3, hidden_units=[256, 128], dropout_rate=0)
     att_add3 = Add()([mlp3, att_add3])
     attention = LayerNormalization()(att_add3)
     # attention = tf.reshape(attention, shape=(8, 8, 8, 128))

     # ---------- auxiliary decoder path (dense nested skips) ----------
     # layer12: upsample level-2 features and fuse with level-1 skip.
     # Together with layer22/layer13 below this resembles a UNet++-style
     # dense skip pathway feeding the final decoder level.
     upsample12 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                  kernel_initializer='he_normal')(add2)
     bn12 = BatchNormalization()(upsample12)
     acti12 = Activation(activation='relu')(bn12)

     concat12 = Concatenate(axis=-1)([add1, acti12])

     conv12_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                       kernel_initializer='he_normal')(concat12)
     bn12_1 = BatchNormalization()(conv12_1)
     acti12_1 = Activation(activation='relu')(bn12_1)

     conv12_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                       kernel_initializer='he_normal')(acti12_1)
     bn12_2 = BatchNormalization()(conv12_2)
     acti12_2 = Activation(activation='relu')(bn12_2)




     # layer22: upsample level-3 features and fuse with level-2 skip.
     upsample22 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                  kernel_initializer='he_normal')(add3)
     bn22 = BatchNormalization()(upsample22)
     acti22 = Activation(activation='relu')(bn22)

     concat22 = Concatenate(axis=-1)([add2, acti22])

     conv22_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                       kernel_initializer='he_normal')(concat22)
     bn22_1 = BatchNormalization()(conv22_1)
     acti22_1 = Activation(activation='relu')(bn22_1)

     conv22_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                       kernel_initializer='he_normal')(acti22_1)
     bn22_2 = BatchNormalization()(conv22_2)
     acti22_2 = Activation(activation='relu')(bn22_2)




     # NOTE(review): abandoned draft of an additive Attention Gate —
     # left unfinished by the author.
     # # Attention Gate
     #
     # AG1_1 = acti22_2 +



     # layer13: upsample the layer22 output and fuse with layer12 output.
     upsample13 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                  kernel_initializer='he_normal')(acti22_2)
     bn13 = BatchNormalization()(upsample13)
     acti13 = Activation(activation='relu')(bn13)

     concat13 = Concatenate(axis=-1)([acti12_2, acti13])

     conv13_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                       kernel_initializer='he_normal')(concat13)
     bn13_1 = BatchNormalization()(conv13_1)
     acti13_1 = Activation(activation='relu')(bn13_1)

     # Named layer — presumably an externally referenced endpoint
     # ('sag1_input_l', paired with 'sag1_input_r' below); verify usage.
     conv13_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                       kernel_initializer='he_normal')(acti13_1)
     bn13_2 = BatchNormalization()(conv13_2)
     acti13_2 = Activation(activation='relu', name='sag1_input_l')(bn13_2)

     # ---------- main decoder level 7 (mirrors encoder level 3) ----------
     upsample7 = Conv3DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(attention)
     bn7 = BatchNormalization()(upsample7)
     acti7 = Activation(activation='relu')(bn7)

     ######### Magnitude Attention Gate 3 ################
     # FFT of the level-3 skip features plus FFT of the sigmoid-gated
     # decoder features; inverse transform and take the magnitude.
     # NOTE(review): the elementwise '+' on Lambda outputs (not Add())
     # and the per-level output_shape hints imply matching spatial sizes
     # (here 16^3) — confirm against the actual input resolution.
     mag3_1_0 = K.cast(add3, "complex64")
     mag3_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 16, 16, 16))(mag3_1_0)
     mag3_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 16, 16, 16))(mag3_1_1)

     acti7mag = Activation(activation='sigmoid')(acti7)
     acti7mag = K.cast(acti7mag, "complex64")
     mag3_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 16, 16, 16))(acti7mag)
     mag3_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 16, 16, 16))(mag3_2_1)

     mag_attention3 = mag3_1_2 + mag3_2_2

     mag3_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 16, 16, 16))(mag_attention3)
     mag3_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 16, 16, 16))(mag3_3_1)
     mag3_output = Lambda(tf.abs, output_shape=(None, 16, 16, 16))(mag3_3_2)
     mag3_output = K.cast(mag3_output, "float32")
     ##################################################

     concat7 = Concatenate(axis=-1)([acti7, add3])

     # deep_block disabled on this level too; deep7 is a pass-through.
     # deep7 = deep_block(concat7, 64)
     deep7 = concat7

     conv7_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='up_conv7_1')(deep7)
     bn7_1 = BatchNormalization()(conv7_1)
     acti7_1 = Activation(activation='relu')(bn7_1)

     # Inject the magnitude-attention features into the decoder path.
     #########################
     acti7_1 = Add()([acti7_1, mag3_output])
     #########################

     conv7_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='up_conv7_2')(acti7_1)
     bn7_2 = BatchNormalization()(conv7_2)
     acti7_2 = Activation(activation='relu')(bn7_2)

     # ---------- main decoder level 8 (mirrors encoder level 2) ----------
     upsample8 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(acti7_2)
     bn8 = BatchNormalization()(upsample8)
     acti8 = Activation(activation='relu')(bn8)

     ######### Magnitude Attention Gate 2 ################
     # Same FFT-domain gate, fusing the layer22 output with the gated
     # decoder features at 32^3 resolution.
     mag2_1_0 = K.cast(acti22_2, "complex64")
     mag2_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 32, 32, 32))(mag2_1_0)
     mag2_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 32, 32, 32))(mag2_1_1)

     acti8mag = Activation(activation='sigmoid')(acti8)
     acti8mag = K.cast(acti8mag, "complex64")
     mag2_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 32, 32, 32))(acti8mag)
     mag2_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 32, 32, 32))(mag2_2_1)

     mag_attention2 = mag2_1_2 + mag2_2_2

     mag2_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 32, 32, 32))(mag_attention2)
     mag2_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 32, 32, 32))(mag2_3_1)
     mag2_output = Lambda(tf.abs, output_shape=(None, 32, 32, 32))(mag2_3_2)
     mag2_output = K.cast(mag2_output, "float32")
     ##################################################

     # Dense skip: decoder features + level-2 skip + layer22 output.
     concat8 = Concatenate(axis=-1)([acti8, add2, acti22_2])

     # deep8 = deep_block(concat8, 32)
     deep8 = concat8

     conv8_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='up_conv8_1')(deep8)
     bn8_1 = BatchNormalization()(conv8_1)
     acti8_1 = Activation(activation='relu')(bn8_1)

     # Inject the magnitude-attention features into the decoder path.
     #########################
     acti8_1 = Add()([acti8_1, mag2_output])
     #########################

     conv8_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='up_conv8_2')(acti8_1)
     bn8_2 = BatchNormalization()(conv8_2)
     acti8_2 = Activation(activation='relu')(bn8_2)

     # ---------- main decoder level 9 (mirrors encoder level 1) ----------
     upsample9 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(acti8_2)
     bn9 = BatchNormalization()(upsample9)
     # Named endpoint paired with 'sag1_input_l' above — verify usage.
     acti9 = Activation(activation='relu', name='sag1_input_r')(bn9)

     ######### Magnitude Attention Gate 1 ################
     # Final FFT-domain gate at full (64^3) resolution, fusing the
     # layer13 output with the gated decoder features.
     mag1_1_0 = K.cast(acti13_2, "complex64")
     mag1_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 64, 64, 64), name='mag1_1_1')(mag1_1_0)
     mag1_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 64, 64, 64), name='mag1_1_2')(mag1_1_1)

     acti9mag = Activation(activation='sigmoid', name='acti9mag')(acti9)
     acti9mag = K.cast(acti9mag, "complex64")
     mag1_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 64, 64, 64), name='mag1_2_1')(acti9mag)
     mag1_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 64, 64, 64), name='mag1_2_2')(mag1_2_1)

     mag_attention = mag1_1_2 + mag1_2_2

     mag1_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 64, 64, 64), name='mag1_3_1')(mag_attention)
     mag1_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 64, 64, 64), name='mag1_3_2')(mag1_3_1)
     mag1_output = Lambda(tf.abs, output_shape=(None, 64, 64, 64), name='mag1_output')(mag1_3_2)
     mag1_output = K.cast(mag1_output, "float32")

     # Dense skip: decoder + level-1 skip + layer12 + layer13 outputs.
     concat9 = Concatenate(axis=-1)([acti9, add1, acti12_2, acti13_2])

     deep9 = concat9

     conv9_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='up_conv9_1')(deep9)
     bn9_1 = BatchNormalization()(conv9_1)
     acti9_1 = Activation(activation='relu')(bn9_1)

     # Inject the magnitude-attention features into the decoder path.
     acti9_1 = Add()([acti9_1, mag1_output])

     conv9_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                      kernel_initializer='he_normal', name='up_conv9_2')(acti9_1)
     bn9_2 = BatchNormalization()(conv9_2)
     acti9_2 = Activation(activation='relu')(bn9_2)

     # 1x1x1 classification head: per-voxel softmax over self.classes.
     output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                     kernel_initializer='he_normal', name='output', activation='softmax')(acti9_2)




     model = Model(self.input, output)

     return model