Model
===================

**U-Net**

U-Net is a type of convolutional neural network designed for fast and precise
segmentation of images. It was developed by Olaf Ronneberger, Philipp Fischer,
and Thomas Brox in 2015. The network is based on the fully convolutional
network and its architecture was modified and extended to work with fewer
training images and to yield more precise segmentations. Segmentation is the
process of partitioning an image into multiple segments, which can be used to
identify objects and boundaries in an image.

::

    def get_model_1_unet(self):
        """Build the baseline 3D U-Net.

        Four-level convolutional encoder (16/32/64/128 filters, each level:
        two Conv3D->BN->ReLU stacks followed by max pooling), then a
        three-level Conv3DTranspose decoder with skip concatenations back to
        the matching encoder level, ending in a 1x1x1 softmax classifier with
        ``self.classes`` output channels.

        Returns a ``Model`` mapping ``self.input`` to the softmax volume.
        """
        # layer 1
        # bn0_1 = BatchNormalization()(self.input)
        conv1_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_1')(self.input)
        bn1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation(activation='relu')(bn1_1)
        conv1_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_2')(acti1_1)
        bn1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation(activation='relu')(bn1_2)
        add1 = acti1_2  # alias kept for symmetry with the residual variants
        maxpool1 = MaxPooling3D()(add1)
        # layer 2
        conv2_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_1')(maxpool1)
        bn2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation(activation='relu')(bn2_1)
        conv2_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_2')(acti2_1)
        bn2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation(activation='relu')(bn2_2)
        add2 = acti2_2
        maxpool2 = MaxPooling3D()(add2)
        # layer 3
        conv3_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_1')(maxpool2)
        bn3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation(activation='relu')(bn3_1)
        conv3_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_2')(acti3_1)
        bn3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation(activation='relu')(bn3_2)
        add3 = acti3_2
        maxpool3 = MaxPooling3D()(add3)
        # layer 4 (bottleneck)
        deep4 = maxpool3
        conv4_1 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_1')(deep4)
        bn4_1 = BatchNormalization()(conv4_1)
        acti4_1 = Activation(activation='relu')(bn4_1)
        conv4_2 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_2')(acti4_1)
        bn4_2 = BatchNormalization()(conv4_2)
        acti4_2 = Activation(activation='relu')(bn4_2)
        add4 = acti4_2
        attention = add4  # plain U-Net: no attention module at the bottleneck
        # layer 7 (equal to layer 3)
        upsample7 = Conv3DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(attention)
        bn7 = BatchNormalization()(upsample7)
        acti7 = Activation(activation='relu')(bn7)
        concat7 = Concatenate(axis=-1)([acti7, add3])
        deep7 = concat7
        conv7_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_1')(deep7)
        bn7_1 = BatchNormalization()(conv7_1)
        acti7_1 = Activation(activation='relu')(bn7_1)
        conv7_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_2')(acti7_1)
        bn7_2 = BatchNormalization()(conv7_2)
        acti7_2 = Activation(activation='relu')(bn7_2)
        # layer 8 (equal to layer 2)
        upsample8 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti7_2)
        bn8 = BatchNormalization()(upsample8)
        acti8 = Activation(activation='relu')(bn8)
        concat8 = Concatenate(axis=-1)([acti8, add2])
        deep8 = concat8
        conv8_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_1')(deep8)
        bn8_1 = BatchNormalization()(conv8_1)
        acti8_1 = Activation(activation='relu')(bn8_1)
        conv8_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_2')(acti8_1)
        bn8_2 = BatchNormalization()(conv8_2)
        acti8_2 = Activation(activation='relu')(bn8_2)
        # layer 9 (equal to layer 1)
        upsample9 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti8_2)
        bn9 = BatchNormalization()(upsample9)
        acti9 = Activation(activation='relu')(bn9)
        concat9 = Concatenate(axis=-1)([acti9, add1])
        deep9 = concat9
        conv9_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_1')(deep9)
        bn9_1 = BatchNormalization()(conv9_1)
        acti9_1 = Activation(activation='relu')(bn9_1)
        conv9_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_2')(acti9_1)
        bn9_2 = BatchNormalization()(conv9_2)
        acti9_2 = Activation(activation='relu')(bn9_2)
        # voxel-wise classification head
        output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                        kernel_initializer='he_normal', name='output',
                        activation='softmax')(acti9_2)
        model = Model(self.input, output)
        return model

**Dual Attention Network**

Dual Attention Network is a segmentation model that uses spatial and channel
attention to improve the segmentation results.
::

    def get_model_1_dual_attention(self):
        """Build the dual-attention variant of the baseline 3D U-Net.

        Identical to ``get_model_1_unet`` except that the bottleneck features
        are passed through the ``ATT`` dual-attention module (defined
        elsewhere in this project) before decoding.

        Returns a ``Model`` mapping ``self.input`` to the softmax volume.
        """
        # layer 1
        # bn0_1 = BatchNormalization()(self.input)
        conv1_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_1')(self.input)
        bn1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation(activation='relu')(bn1_1)
        conv1_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_2')(acti1_1)
        bn1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation(activation='relu')(bn1_2)
        add1 = acti1_2
        maxpool1 = MaxPooling3D()(add1)
        # layer 2
        conv2_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_1')(maxpool1)
        bn2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation(activation='relu')(bn2_1)
        conv2_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_2')(acti2_1)
        bn2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation(activation='relu')(bn2_2)
        add2 = acti2_2
        maxpool2 = MaxPooling3D()(add2)
        # layer 3
        conv3_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_1')(maxpool2)
        bn3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation(activation='relu')(bn3_1)
        conv3_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_2')(acti3_1)
        bn3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation(activation='relu')(bn3_2)
        add3 = acti3_2
        maxpool3 = MaxPooling3D()(add3)
        # layer 4 (bottleneck)
        deep4 = maxpool3
        conv4_1 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_1')(deep4)
        bn4_1 = BatchNormalization()(conv4_1)
        acti4_1 = Activation(activation='relu')(bn4_1)
        conv4_2 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_2')(acti4_1)
        bn4_2 = BatchNormalization()(conv4_2)
        acti4_2 = Activation(activation='relu')(bn4_2)
        add4 = acti4_2
        # dual attention module applied to the bottleneck features
        attention = ATT(add4)
        # layer 7 (equal to layer 3)
        upsample7 = Conv3DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(attention)
        bn7 = BatchNormalization()(upsample7)
        acti7 = Activation(activation='relu')(bn7)
        concat7 = Concatenate(axis=-1)([acti7, add3])
        deep7 = concat7
        conv7_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_1')(deep7)
        bn7_1 = BatchNormalization()(conv7_1)
        acti7_1 = Activation(activation='relu')(bn7_1)
        conv7_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_2')(acti7_1)
        bn7_2 = BatchNormalization()(conv7_2)
        acti7_2 = Activation(activation='relu')(bn7_2)
        # layer 8 (equal to layer 2)
        upsample8 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti7_2)
        bn8 = BatchNormalization()(upsample8)
        acti8 = Activation(activation='relu')(bn8)
        concat8 = Concatenate(axis=-1)([acti8, add2])
        deep8 = concat8
        conv8_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_1')(deep8)
        bn8_1 = BatchNormalization()(conv8_1)
        acti8_1 = Activation(activation='relu')(bn8_1)
        conv8_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_2')(acti8_1)
        bn8_2 = BatchNormalization()(conv8_2)
        acti8_2 = Activation(activation='relu')(bn8_2)
        # layer 9 (equal to layer 1)
        upsample9 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti8_2)
        bn9 = BatchNormalization()(upsample9)
        acti9 = Activation(activation='relu')(bn9)
        concat9 = Concatenate(axis=-1)([acti9, add1])
        deep9 = concat9
        conv9_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_1')(deep9)
        bn9_1 = BatchNormalization()(conv9_1)
        acti9_1 = Activation(activation='relu')(bn9_1)
        conv9_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_2')(acti9_1)
        bn9_2 = BatchNormalization()(conv9_2)
        acti9_2 = Activation(activation='relu')(bn9_2)
        # voxel-wise classification head
        output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                        kernel_initializer='he_normal', name='output',
                        activation='softmax')(acti9_2)
        model = Model(self.input, output)
        return model

**UNet++**

UNet++ is a convolutional neural network for biomedical image segmentation. It
is an extension of the original UNet architecture that uses a nested, or
"chained" architecture to improve segmentation results.

::

    def get_model_1_unetpp(self):
        """Build the UNet++ (nested U-Net) variant.

        Same encoder/decoder backbone as ``get_model_1_unet``, plus the
        intermediate nested nodes (layers ``12``, ``22`` and ``13``) whose
        outputs are concatenated into the decoder's skip connections
        (dense skip pathways).

        Returns a ``Model`` mapping ``self.input`` to the softmax volume.
        """
        # layer 1
        # bn0_1 = BatchNormalization()(self.input)
        conv1_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_1')(self.input)
        bn1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation(activation='relu')(bn1_1)
        conv1_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_2')(acti1_1)
        bn1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation(activation='relu')(bn1_2)
        add1 = acti1_2
        maxpool1 = MaxPooling3D()(add1)
        # layer 2
        conv2_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_1')(maxpool1)
        bn2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation(activation='relu')(bn2_1)
        conv2_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_2')(acti2_1)
        bn2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation(activation='relu')(bn2_2)
        add2 = acti2_2
        maxpool2 = MaxPooling3D()(add2)
        # layer 3
        conv3_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_1')(maxpool2)
        bn3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation(activation='relu')(bn3_1)
        conv3_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_2')(acti3_1)
        bn3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation(activation='relu')(bn3_2)
        add3 = acti3_2
        maxpool3 = MaxPooling3D()(add3)
        # layer 4 (bottleneck)
        deep4 = maxpool3
        conv4_1 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_1')(deep4)
        bn4_1 = BatchNormalization()(conv4_1)
        acti4_1 = Activation(activation='relu')(bn4_1)
        conv4_2 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_2')(acti4_1)
        bn4_2 = BatchNormalization()(conv4_2)
        acti4_2 = Activation(activation='relu')(bn4_2)
        add4 = acti4_2
        attention = add4  # no attention module in the UNet++ variant
        # layer12 (nested node X_{0,1}: upsample level 2 back to level 1)
        upsample12 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add2)
        bn12 = BatchNormalization()(upsample12)
        acti12 = Activation(activation='relu')(bn12)
        concat12 = Concatenate(axis=-1)([add1, acti12])
        conv12_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat12)
        bn12_1 = BatchNormalization()(conv12_1)
        acti12_1 = Activation(activation='relu')(bn12_1)
        conv12_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti12_1)
        bn12_2 = BatchNormalization()(conv12_2)
        acti12_2 = Activation(activation='relu')(bn12_2)
        # layer22 (nested node X_{1,1}: upsample level 3 back to level 2)
        upsample22 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add3)
        bn22 = BatchNormalization()(upsample22)
        acti22 = Activation(activation='relu')(bn22)
        concat22 = Concatenate(axis=-1)([add2, acti22])
        conv22_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat22)
        bn22_1 = BatchNormalization()(conv22_1)
        acti22_1 = Activation(activation='relu')(bn22_1)
        conv22_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti22_1)
        bn22_2 = BatchNormalization()(conv22_2)
        acti22_2 = Activation(activation='relu')(bn22_2)
        # layer13 (nested node X_{0,2}: upsample node 22 back to level 1)
        upsample13 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(acti22_2)
        bn13 = BatchNormalization()(upsample13)
        acti13 = Activation(activation='relu')(bn13)
        concat13 = Concatenate(axis=-1)([acti12_2, acti13])
        conv13_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat13)
        bn13_1 = BatchNormalization()(conv13_1)
        acti13_1 = Activation(activation='relu')(bn13_1)
        conv13_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti13_1)
        bn13_2 = BatchNormalization()(conv13_2)
        acti13_2 = Activation(activation='relu')(bn13_2)
        # layer 7 (equal to layer 3)
        upsample7 = Conv3DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(attention)
        bn7 = BatchNormalization()(upsample7)
        acti7 = Activation(activation='relu')(bn7)
        concat7 = Concatenate(axis=-1)([acti7, add3])
        deep7 = concat7
        conv7_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_1')(deep7)
        bn7_1 = BatchNormalization()(conv7_1)
        acti7_1 = Activation(activation='relu')(bn7_1)
        conv7_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_2')(acti7_1)
        bn7_2 = BatchNormalization()(conv7_2)
        acti7_2 = Activation(activation='relu')(bn7_2)
        # layer 8 (equal to layer 2) — skip now also carries nested node 22
        upsample8 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti7_2)
        bn8 = BatchNormalization()(upsample8)
        acti8 = Activation(activation='relu')(bn8)
        concat8 = Concatenate(axis=-1)([acti8, add2, acti22_2])
        deep8 = concat8
        conv8_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_1')(deep8)
        bn8_1 = BatchNormalization()(conv8_1)
        acti8_1 = Activation(activation='relu')(bn8_1)
        conv8_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_2')(acti8_1)
        bn8_2 = BatchNormalization()(conv8_2)
        acti8_2 = Activation(activation='relu')(bn8_2)
        # layer 9 (equal to layer 1) — skip carries nested nodes 12 and 13
        upsample9 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti8_2)
        bn9 = BatchNormalization()(upsample9)
        acti9 = Activation(activation='relu')(bn9)
        concat9 = Concatenate(axis=-1)([acti9, add1, acti12_2, acti13_2])
        deep9 = concat9
        conv9_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_1')(deep9)
        bn9_1 = BatchNormalization()(conv9_1)
        acti9_1 = Activation(activation='relu')(bn9_1)
        conv9_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_2')(acti9_1)
        bn9_2 = BatchNormalization()(conv9_2)
        acti9_2 = Activation(activation='relu')(bn9_2)
        # voxel-wise classification head
        output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                        kernel_initializer='he_normal', name='output',
                        activation='softmax')(acti9_2)
        model = Model(self.input, output)
        return model

**VASeg**

Our network structure.
::

    def get_model_2(self):
        """Build the base VASeg network.

        UNet++-style nested backbone augmented with project-specific blocks
        (all defined elsewhere in this project):

        * ``MultiView`` branches added (``Add``) to each encoder level,
        * ``reduction_block`` instead of plain max pooling for downsampling,
        * ``deep_block`` at the bottleneck and on each decoder level,
        * ``ATT`` dual attention on the bottleneck features.

        Returns a ``Model`` mapping ``self.input`` to the softmax volume.
        """
        # layer 1
        # bn0_1 = BatchNormalization()(self.input)
        conv1_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_1')(self.input)
        bn1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation(activation='relu')(bn1_1)
        conv1_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_2')(acti1_1)
        bn1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation(activation='relu')(bn1_2)
        multiview1 = MultiView(self.input, 16)
        add1 = Add()([acti1_2, multiview1])
        maxpool1 = reduction_block(add1, 16)
        # layer 2
        conv2_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_1')(maxpool1)
        bn2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation(activation='relu')(bn2_1)
        conv2_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_2')(acti2_1)
        bn2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation(activation='relu')(bn2_2)
        multiview2 = MultiView(maxpool1, 32)
        add2 = Add()([acti2_2, multiview2])
        maxpool2 = reduction_block(add2, 32)
        # layer 3
        conv3_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_1')(maxpool2)
        bn3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation(activation='relu')(bn3_1)
        conv3_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_2')(acti3_1)
        bn3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation(activation='relu')(bn3_2)
        multiview3 = MultiView(maxpool2, 64)
        add3 = Add()([acti3_2, multiview3])
        maxpool3 = reduction_block(add3, 64)
        # layer 4 (bottleneck)
        deep4 = deep_block(maxpool3, 128)
        conv4_1 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_1')(deep4)
        bn4_1 = BatchNormalization()(conv4_1)
        acti4_1 = Activation(activation='relu')(bn4_1)
        conv4_2 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_2')(acti4_1)
        bn4_2 = BatchNormalization()(conv4_2)
        acti4_2 = Activation(activation='relu')(bn4_2)
        multiview4 = MultiView(deep4, 128)
        add4 = Add()([acti4_2, multiview4])
        attention = ATT(add4)
        # layer12 (nested node: upsample level 2 back to level 1)
        upsample12 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add2)
        bn12 = BatchNormalization()(upsample12)
        acti12 = Activation(activation='relu')(bn12)
        concat12 = Concatenate(axis=-1)([add1, acti12])
        conv12_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat12)
        bn12_1 = BatchNormalization()(conv12_1)
        acti12_1 = Activation(activation='relu')(bn12_1)
        conv12_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti12_1)
        bn12_2 = BatchNormalization()(conv12_2)
        acti12_2 = Activation(activation='relu')(bn12_2)
        # layer22 (nested node: upsample level 3 back to level 2)
        upsample22 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add3)
        bn22 = BatchNormalization()(upsample22)
        acti22 = Activation(activation='relu')(bn22)
        concat22 = Concatenate(axis=-1)([add2, acti22])
        conv22_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat22)
        bn22_1 = BatchNormalization()(conv22_1)
        acti22_1 = Activation(activation='relu')(bn22_1)
        conv22_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti22_1)
        bn22_2 = BatchNormalization()(conv22_2)
        acti22_2 = Activation(activation='relu')(bn22_2)
        # layer13 (nested node: upsample node 22 back to level 1)
        upsample13 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(acti22_2)
        bn13 = BatchNormalization()(upsample13)
        acti13 = Activation(activation='relu')(bn13)
        concat13 = Concatenate(axis=-1)([acti12_2, acti13])
        conv13_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat13)
        bn13_1 = BatchNormalization()(conv13_1)
        acti13_1 = Activation(activation='relu')(bn13_1)
        conv13_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti13_1)
        bn13_2 = BatchNormalization()(conv13_2)
        acti13_2 = Activation(activation='relu')(bn13_2)
        # layer 7 (equal to layer 3)
        upsample7 = Conv3DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(attention)
        bn7 = BatchNormalization()(upsample7)
        acti7 = Activation(activation='relu')(bn7)
        concat7 = Concatenate(axis=-1)([acti7, add3])
        deep7 = deep_block(concat7, 64)
        conv7_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_1')(deep7)
        bn7_1 = BatchNormalization()(conv7_1)
        acti7_1 = Activation(activation='relu')(bn7_1)
        conv7_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_2')(acti7_1)
        bn7_2 = BatchNormalization()(conv7_2)
        acti7_2 = Activation(activation='relu')(bn7_2)
        # layer 8 (equal to layer 2) — skip also carries nested node 22
        upsample8 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti7_2)
        bn8 = BatchNormalization()(upsample8)
        acti8 = Activation(activation='relu')(bn8)
        concat8 = Concatenate(axis=-1)([acti8, add2, acti22_2])
        deep8 = deep_block(concat8, 32)
        conv8_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_1')(deep8)
        bn8_1 = BatchNormalization()(conv8_1)
        acti8_1 = Activation(activation='relu')(bn8_1)
        conv8_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_2')(acti8_1)
        bn8_2 = BatchNormalization()(conv8_2)
        acti8_2 = Activation(activation='relu')(bn8_2)
        # layer 9 (equal to layer 1) — skip carries nested nodes 12 and 13
        upsample9 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti8_2)
        bn9 = BatchNormalization()(upsample9)
        acti9 = Activation(activation='relu')(bn9)
        concat9 = Concatenate(axis=-1)([acti9, add1, acti12_2, acti13_2])
        deep9 = deep_block(concat9, 16)
        conv9_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_1')(deep9)
        bn9_1 = BatchNormalization()(conv9_1)
        acti9_1 = Activation(activation='relu')(bn9_1)
        conv9_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_2')(acti9_1)
        bn9_2 = BatchNormalization()(conv9_2)
        acti9_2 = Activation(activation='relu')(bn9_2)
        # voxel-wise classification head
        output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                        kernel_initializer='he_normal', name='output',
                        activation='softmax')(acti9_2)
        model = Model(self.input, output)
        return model

**VASeg Enhanced Version**

In this version we added *Magnitude Attention Gate* and *Transformer Block*
based on the base VASeg model.

::

    def get_model_6(self):
        """Build the enhanced VASeg network.

        Extends the base VASeg backbone with:

        * ``downsample_block`` in place of ``reduction_block`` for encoder
          downsampling,
        * a three-block transformer stack (MultiHeadAttention + MLP with
          pre-LayerNorm residuals) after ``ATT`` at the bottleneck,
        * frequency-domain *Magnitude Attention Gates* (FFT -> fftshift ->
          add -> ifftshift -> iFFT -> abs) fused into each decoder level.

        NOTE(review): the ``output_shape`` tuples passed to the FFT ``Lambda``
        layers suggest feature maps of 16^3 / 32^3 / 64^3 at the respective
        decoder levels (i.e. a 64^3 input volume) — confirm against the
        caller that builds ``self.input``.

        Returns a ``Model`` mapping ``self.input`` to the softmax volume.
        """
        # layer 1
        # bn0_1 = BatchNormalization()(self.input)
        conv1_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_1')(self.input)
        bn1_1 = BatchNormalization()(conv1_1)
        acti1_1 = Activation(activation='relu')(bn1_1)
        conv1_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv1_2')(acti1_1)
        bn1_2 = BatchNormalization()(conv1_2)
        acti1_2 = Activation(activation='relu')(bn1_2)
        multiview1 = MultiView(self.input, 16)
        add1 = Add()([acti1_2, multiview1])
        maxpool1 = downsample_block(add1, 16)
        # layer 2
        conv2_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_1')(maxpool1)
        bn2_1 = BatchNormalization()(conv2_1)
        acti2_1 = Activation(activation='relu')(bn2_1)
        conv2_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv2_2')(acti2_1)
        bn2_2 = BatchNormalization()(conv2_2)
        acti2_2 = Activation(activation='relu')(bn2_2)
        multiview2 = MultiView(maxpool1, 32)
        add2 = Add()([acti2_2, multiview2])
        maxpool2 = downsample_block(add2, 32)
        # layer 3
        conv3_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_1')(maxpool2)
        bn3_1 = BatchNormalization()(conv3_1)
        acti3_1 = Activation(activation='relu')(bn3_1)
        conv3_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv3_2')(acti3_1)
        bn3_2 = BatchNormalization()(conv3_2)
        acti3_2 = Activation(activation='relu')(bn3_2)
        multiview3 = MultiView(maxpool2, 64)
        add3 = Add()([acti3_2, multiview3])
        maxpool3 = downsample_block(add3, 64)
        # layer 4 (bottleneck)
        # deep4 = deep_block(maxpool3, 128)
        deep4 = maxpool3
        conv4_1 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_1')(deep4)
        bn4_1 = BatchNormalization()(conv4_1)
        acti4_1 = Activation(activation='relu')(bn4_1)
        conv4_2 = Conv3D(filters=128, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='conv4_2')(acti4_1)
        bn4_2 = BatchNormalization()(conv4_2)
        acti4_2 = Activation(activation='relu')(bn4_2)
        multiview4 = MultiView(deep4, 128)
        add4 = Add()([acti4_2, multiview4])
        add4 = ATT(add4)
        # --- abandoned positional-embedding experiment, kept for reference ---
        # test = tf.signal.fft3d
        # embedding = Embedding(196, 768, input_length=196)
        # emb_pos = np.arange(128)
        # emb = Embedding(128, 512, input_length=128)(emb_pos)
        # print(emb.shape)
        # # emb = tf.transpose(emb)
        # emb = tf.reshape(emb, shape=(-1, 8, 8, 8, 128))
        # print(emb.shape)
        # embedded = Add()([add4, emb])
        print(add4.shape)  # debug: bottleneck feature shape
        # transformer block 1 (pre-LN MHA + MLP, residual connections)
        add4 = LayerNormalization()(add4)
        attention1 = MultiHeadAttention(4, key_dim=3, dropout=0)(add4, add4)
        att_add1 = Add()([attention1, add4])
        att_add1 = LayerNormalization()(att_add1)
        mlp1 = mlp(att_add1, hidden_units=[256, 128], dropout_rate=0)
        att_add1 = Add()([mlp1, att_add1])
        # transformer block 2
        att_add2 = LayerNormalization()(att_add1)
        attention2 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add2, att_add2)
        att_add2 = Add()([attention2, att_add2])
        att_add2 = LayerNormalization()(att_add2)
        mlp2 = mlp(att_add2, hidden_units=[256, 128], dropout_rate=0)
        att_add2 = Add()([mlp2, att_add2])
        # transformer block 3
        att_add3 = LayerNormalization()(att_add2)
        attention3 = MultiHeadAttention(4, key_dim=3, dropout=0)(att_add3, att_add3)
        att_add3 = Add()([attention3, att_add3])
        att_add3 = LayerNormalization()(att_add3)
        mlp3 = mlp(att_add3, hidden_units=[256, 128], dropout_rate=0)
        att_add3 = Add()([mlp3, att_add3])
        attention = LayerNormalization()(att_add3)
        # attention = tf.reshape(attention, shape=(8, 8, 8, 128))
        # layer12 (nested node: upsample level 2 back to level 1)
        upsample12 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add2)
        bn12 = BatchNormalization()(upsample12)
        acti12 = Activation(activation='relu')(bn12)
        concat12 = Concatenate(axis=-1)([add1, acti12])
        conv12_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat12)
        bn12_1 = BatchNormalization()(conv12_1)
        acti12_1 = Activation(activation='relu')(bn12_1)
        conv12_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti12_1)
        bn12_2 = BatchNormalization()(conv12_2)
        acti12_2 = Activation(activation='relu')(bn12_2)
        # layer22 (nested node: upsample level 3 back to level 2)
        upsample22 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                     kernel_initializer='he_normal')(add3)
        bn22 = BatchNormalization()(upsample22)
        acti22 = Activation(activation='relu')(bn22)
        concat22 = Concatenate(axis=-1)([add2, acti22])
        conv22_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat22)
        bn22_1 = BatchNormalization()(conv22_1)
        acti22_1 = Activation(activation='relu')(bn22_1)
        conv22_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti22_1)
        bn22_2 = BatchNormalization()(conv22_2)
        acti22_2 = Activation(activation='relu')(bn22_2)
        # # Attention Gate
        # # AG1_1 = acti22_2 +
        # layer13 (nested node: upsample node 22 back to level 1)
        upsample13 = Conv3DTranspose(filters=16, kernel_size=3, strides=2,
                                     padding='same', kernel_initializer='he_normal')(acti22_2)
        bn13 = BatchNormalization()(upsample13)
        acti13 = Activation(activation='relu')(bn13)
        concat13 = Concatenate(axis=-1)([acti12_2, acti13])
        conv13_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(concat13)
        bn13_1 = BatchNormalization()(conv13_1)
        acti13_1 = Activation(activation='relu')(bn13_1)
        conv13_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                          kernel_initializer='he_normal')(acti13_1)
        bn13_2 = BatchNormalization()(conv13_2)
        acti13_2 = Activation(activation='relu', name='sag1_input_l')(bn13_2)
        # layer 7 (equal to layer 3)
        upsample7 = Conv3DTranspose(filters=64, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(attention)
        bn7 = BatchNormalization()(upsample7)
        acti7 = Activation(activation='relu')(bn7)
        ######### Magnitude Attention Gate 3 ################
        # Combine skip (add3) and gating (acti7) features in the frequency
        # domain, then transform back and take the magnitude.
        mag3_1_0 = K.cast(add3, "complex64")
        mag3_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 16, 16, 16))(mag3_1_0)
        mag3_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 16, 16, 16))(mag3_1_1)
        acti7mag = Activation(activation='sigmoid')(acti7)
        acti7mag = K.cast(acti7mag, "complex64")
        mag3_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 16, 16, 16))(acti7mag)
        mag3_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 16, 16, 16))(mag3_2_1)
        mag_attention3 = mag3_1_2 + mag3_2_2
        mag3_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 16, 16, 16))(mag_attention3)
        mag3_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 16, 16, 16))(mag3_3_1)
        mag3_output = Lambda(tf.abs, output_shape=(None, 16, 16, 16))(mag3_3_2)
        mag3_output = K.cast(mag3_output, "float32")
        ##################################################
        concat7 = Concatenate(axis=-1)([acti7, add3])
        # deep7 = deep_block(concat7, 64)
        deep7 = concat7
        conv7_1 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_1')(deep7)
        bn7_1 = BatchNormalization()(conv7_1)
        acti7_1 = Activation(activation='relu')(bn7_1)
        #########################
        acti7_1 = Add()([acti7_1, mag3_output])  # inject magnitude attention
        #########################
        conv7_2 = Conv3D(filters=64, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv7_2')(acti7_1)
        bn7_2 = BatchNormalization()(conv7_2)
        acti7_2 = Activation(activation='relu')(bn7_2)
        # layer 8 (equal to layer 2)
        upsample8 = Conv3DTranspose(filters=32, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti7_2)
        bn8 = BatchNormalization()(upsample8)
        acti8 = Activation(activation='relu')(bn8)
        ######### Magnitude Attention Gate 2 ################
        mag2_1_0 = K.cast(acti22_2, "complex64")
        mag2_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 32, 32, 32))(mag2_1_0)
        mag2_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 32, 32, 32))(mag2_1_1)
        acti8mag = Activation(activation='sigmoid')(acti8)
        acti8mag = K.cast(acti8mag, "complex64")
        mag2_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 32, 32, 32))(acti8mag)
        mag2_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 32, 32, 32))(mag2_2_1)
        mag_attention2 = mag2_1_2 + mag2_2_2
        mag2_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 32, 32, 32))(mag_attention2)
        mag2_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 32, 32, 32))(mag2_3_1)
        mag2_output = Lambda(tf.abs, output_shape=(None, 32, 32, 32))(mag2_3_2)
        mag2_output = K.cast(mag2_output, "float32")
        ##################################################
        concat8 = Concatenate(axis=-1)([acti8, add2, acti22_2])
        # deep8 = deep_block(concat8, 32)
        deep8 = concat8
        conv8_1 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_1')(deep8)
        bn8_1 = BatchNormalization()(conv8_1)
        acti8_1 = Activation(activation='relu')(bn8_1)
        #########################
        acti8_1 = Add()([acti8_1, mag2_output])  # inject magnitude attention
        #########################
        conv8_2 = Conv3D(filters=32, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv8_2')(acti8_1)
        bn8_2 = BatchNormalization()(conv8_2)
        acti8_2 = Activation(activation='relu')(bn8_2)
        # layer 9 (equal to layer 1)
        upsample9 = Conv3DTranspose(filters=16, kernel_size=3, strides=2, padding='same',
                                    kernel_initializer='he_normal')(acti8_2)
        bn9 = BatchNormalization()(upsample9)
        acti9 = Activation(activation='relu', name='sag1_input_r')(bn9)
        ######### Magnitude Attention Gate 1 ################
        mag1_1_0 = K.cast(acti13_2, "complex64")
        mag1_1_1 = Lambda(tf.signal.fft3d, output_shape=(None, 64, 64, 64), name='mag1_1_1')(mag1_1_0)
        mag1_1_2 = Lambda(tf.signal.fftshift, output_shape=(None, 64, 64, 64), name='mag1_1_2')(mag1_1_1)
        acti9mag = Activation(activation='sigmoid', name='acti9mag')(acti9)
        acti9mag = K.cast(acti9mag, "complex64")
        mag1_2_1 = Lambda(tf.signal.fft3d, output_shape=(None, 64, 64, 64), name='mag1_2_1')(acti9mag)
        mag1_2_2 = Lambda(tf.signal.fftshift, output_shape=(None, 64, 64, 64), name='mag1_2_2')(mag1_2_1)
        mag_attention = mag1_1_2 + mag1_2_2
        mag1_3_1 = Lambda(tf.signal.ifftshift, output_shape=(None, 64, 64, 64), name='mag1_3_1')(mag_attention)
        mag1_3_2 = Lambda(tf.signal.ifft3d, output_shape=(None, 64, 64, 64), name='mag1_3_2')(mag1_3_1)
        mag1_output = Lambda(tf.abs, output_shape=(None, 64, 64, 64), name='mag1_output')(mag1_3_2)
        mag1_output = K.cast(mag1_output, "float32")
        concat9 = Concatenate(axis=-1)([acti9, add1, acti12_2, acti13_2])
        deep9 = concat9
        conv9_1 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_1')(deep9)
        bn9_1 = BatchNormalization()(conv9_1)
        acti9_1 = Activation(activation='relu')(bn9_1)
        acti9_1 = Add()([acti9_1, mag1_output])  # inject magnitude attention
        conv9_2 = Conv3D(filters=16, kernel_size=3, strides=1, padding='same',
                         kernel_initializer='he_normal', name='up_conv9_2')(acti9_1)
        bn9_2 = BatchNormalization()(conv9_2)
        acti9_2 = Activation(activation='relu')(bn9_2)
        # voxel-wise classification head
        output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                        kernel_initializer='he_normal',
                        name='output', activation='softmax')(acti9_2)
        model = Model(self.input, output)
        return model