Model
U-Net
U-Net is a type of convolutional neural network designed for fast and precise segmentation of images. It was developed by Olaf Ronneberger, Philipp Fischer, and Thomas Brox in 2015. The network is based on the fully convolutional network and its architecture was modified and extended to work with fewer training images and to yield more precise segmentations. Segmentation is the process of partitioning an image into multiple segments, which can be used to identify objects and boundaries in an image.
def get_model_1_unet(self):
    """Build a plain 3-level 3D U-Net.

    Encoder: three stages of (Conv3D -> BatchNorm -> ReLU) x 2 with
    16/32/64 filters, each followed by max pooling; a 128-filter
    bottleneck; decoder mirrors the encoder with transposed convolutions
    and channel-wise concatenation of the matching encoder skip tensors.
    A 1x1x1 softmax Conv3D emits per-voxel class scores.

    Assumes `self.input` is a Keras input tensor and `self.classes` is the
    number of output classes (both set elsewhere on this object).
    """

    def double_conv(tensor, filters, names=(None, None)):
        # Two Conv3D -> BatchNormalization -> ReLU stages at one width.
        for layer_name in names:
            tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                            kernel_initializer='he_normal', name=layer_name)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation(activation='relu')(tensor)
        return tensor

    def upsample(tensor, filters):
        # Strided transposed conv doubles spatial size, then BN + ReLU.
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    # --- encoder ---
    skip1 = double_conv(self.input, 16, ('conv1_1', 'conv1_2'))
    pooled1 = MaxPooling3D()(skip1)
    skip2 = double_conv(pooled1, 32, ('conv2_1', 'conv2_2'))
    pooled2 = MaxPooling3D()(skip2)
    skip3 = double_conv(pooled2, 64, ('conv3_1', 'conv3_2'))
    pooled3 = MaxPooling3D()(skip3)

    # --- bottleneck (no attention in this plain variant) ---
    bottleneck = double_conv(pooled3, 128, ('conv4_1', 'conv4_2'))

    # --- decoder ---
    up3 = upsample(bottleneck, 64)
    dec3 = double_conv(Concatenate(axis=-1)([up3, skip3]), 64,
                       ('up_conv7_1', 'up_conv7_2'))
    up2 = upsample(dec3, 32)
    dec2 = double_conv(Concatenate(axis=-1)([up2, skip2]), 32,
                       ('up_conv8_1', 'up_conv8_2'))
    up1 = upsample(dec2, 16)
    dec1 = double_conv(Concatenate(axis=-1)([up1, skip1]), 16,
                       ('up_conv9_1', 'up_conv9_2'))

    # Per-voxel class probabilities.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output', activation='softmax')(dec1)
    return Model(self.input, output)
Dual Attention Network
Dual Attention Network is a segmentation model that uses spatial and channel attention to improve the segmentation results.
def get_model_1_dual_attention(self):
    """Build the 3-level 3D U-Net with a dual-attention bottleneck.

    Identical topology to `get_model_1_unet` except that the project's
    `ATT` attention module (defined elsewhere in this file) is applied to
    the bottleneck features before the decoder.

    Assumes `self.input` is a Keras input tensor and `self.classes` is the
    number of output classes.
    """

    def double_conv(tensor, filters, names=(None, None)):
        # Two Conv3D -> BatchNormalization -> ReLU stages at one width.
        for layer_name in names:
            tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                            kernel_initializer='he_normal', name=layer_name)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation(activation='relu')(tensor)
        return tensor

    def upsample(tensor, filters):
        # Strided transposed conv doubles spatial size, then BN + ReLU.
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    # --- encoder ---
    skip1 = double_conv(self.input, 16, ('conv1_1', 'conv1_2'))
    pooled1 = MaxPooling3D()(skip1)
    skip2 = double_conv(pooled1, 32, ('conv2_1', 'conv2_2'))
    pooled2 = MaxPooling3D()(skip2)
    skip3 = double_conv(pooled2, 64, ('conv3_1', 'conv3_2'))
    pooled3 = MaxPooling3D()(skip3)

    # --- bottleneck with dual attention ---
    bottleneck = double_conv(pooled3, 128, ('conv4_1', 'conv4_2'))
    attended = ATT(bottleneck)

    # --- decoder ---
    up3 = upsample(attended, 64)
    dec3 = double_conv(Concatenate(axis=-1)([up3, skip3]), 64,
                       ('up_conv7_1', 'up_conv7_2'))
    up2 = upsample(dec3, 32)
    dec2 = double_conv(Concatenate(axis=-1)([up2, skip2]), 32,
                       ('up_conv8_1', 'up_conv8_2'))
    up1 = upsample(dec2, 16)
    dec1 = double_conv(Concatenate(axis=-1)([up1, skip1]), 16,
                       ('up_conv9_1', 'up_conv9_2'))

    # Per-voxel class probabilities.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output', activation='softmax')(dec1)
    return Model(self.input, output)
UNet++
UNet++ is a convolutional neural network for biomedical image segmentation. It extends the original U-Net architecture with nested, densely connected skip pathways (intermediate decoder nodes between the encoder and decoder) to improve segmentation results.
def get_model_1_unetpp(self):
    """Build a 3-level 3D UNet++.

    Same encoder/bottleneck/decoder backbone as `get_model_1_unet`, plus
    the UNet++ nested skip nodes: X^{0,1} (level-1 + upsampled level-2
    skip), X^{1,1} (level-2 + upsampled level-3 skip), and X^{0,2}
    (X^{0,1} + upsampled X^{1,1}).  The decoder concatenations at levels
    2 and 1 additionally receive these nested features.

    Assumes `self.input` is a Keras input tensor and `self.classes` is the
    number of output classes.
    """

    def double_conv(tensor, filters, names=(None, None)):
        # Two Conv3D -> BatchNormalization -> ReLU stages at one width.
        for layer_name in names:
            tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                            kernel_initializer='he_normal', name=layer_name)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation(activation='relu')(tensor)
        return tensor

    def upsample(tensor, filters):
        # Strided transposed conv doubles spatial size, then BN + ReLU.
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    # --- encoder ---
    skip1 = double_conv(self.input, 16, ('conv1_1', 'conv1_2'))
    pooled1 = MaxPooling3D()(skip1)
    skip2 = double_conv(pooled1, 32, ('conv2_1', 'conv2_2'))
    pooled2 = MaxPooling3D()(skip2)
    skip3 = double_conv(pooled2, 64, ('conv3_1', 'conv3_2'))
    pooled3 = MaxPooling3D()(skip3)

    # --- bottleneck ---
    bottleneck = double_conv(pooled3, 128, ('conv4_1', 'conv4_2'))

    # --- nested skip nodes ---
    # X^{0,1}: level-1 skip fused with upsampled level-2 skip.
    up12 = upsample(skip2, 16)
    nested12 = double_conv(Concatenate(axis=-1)([skip1, up12]), 16)
    # X^{1,1}: level-2 skip fused with upsampled level-3 skip.
    up22 = upsample(skip3, 32)
    nested22 = double_conv(Concatenate(axis=-1)([skip2, up22]), 32)
    # X^{0,2}: X^{0,1} fused with upsampled X^{1,1}.
    up13 = upsample(nested22, 16)
    nested13 = double_conv(Concatenate(axis=-1)([nested12, up13]), 16)

    # --- decoder ---
    up3 = upsample(bottleneck, 64)
    dec3 = double_conv(Concatenate(axis=-1)([up3, skip3]), 64,
                       ('up_conv7_1', 'up_conv7_2'))
    up2 = upsample(dec3, 32)
    dec2 = double_conv(Concatenate(axis=-1)([up2, skip2, nested22]), 32,
                       ('up_conv8_1', 'up_conv8_2'))
    up1 = upsample(dec2, 16)
    dec1 = double_conv(Concatenate(axis=-1)([up1, skip1, nested12, nested13]), 16,
                       ('up_conv9_1', 'up_conv9_2'))

    # Per-voxel class probabilities.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output', activation='softmax')(dec1)
    return Model(self.input, output)
VASeg
Our network structure.
def get_model_2(self):
    """Build the VASeg network.

    A UNet++-style backbone where each encoder stage fuses a
    (Conv3D-BN-ReLU)x2 path with a parallel `MultiView` path via
    element-wise addition, downsampling uses `reduction_block`, the
    bottleneck and decoder stages pass through `deep_block`, and the
    bottleneck features go through the `ATT` attention module.
    `MultiView`, `reduction_block`, `deep_block`, and `ATT` are project
    helpers defined elsewhere in this file.

    Assumes `self.input` is a Keras input tensor and `self.classes` is the
    number of output classes.
    """

    def double_conv(tensor, filters, names=(None, None)):
        # Two Conv3D -> BatchNormalization -> ReLU stages at one width.
        for layer_name in names:
            tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                            kernel_initializer='he_normal', name=layer_name)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation(activation='relu')(tensor)
        return tensor

    def upsample(tensor, filters):
        # Strided transposed conv doubles spatial size, then BN + ReLU.
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    def encoder_stage(tensor, filters, names):
        # Conv path plus a parallel MultiView path, fused by addition.
        conv_path = double_conv(tensor, filters, names)
        view_path = MultiView(tensor, filters)
        return Add()([conv_path, view_path])

    # --- encoder ---
    skip1 = encoder_stage(self.input, 16, ('conv1_1', 'conv1_2'))
    pooled1 = reduction_block(skip1, 16)
    skip2 = encoder_stage(pooled1, 32, ('conv2_1', 'conv2_2'))
    pooled2 = reduction_block(skip2, 32)
    skip3 = encoder_stage(pooled2, 64, ('conv3_1', 'conv3_2'))
    pooled3 = reduction_block(skip3, 64)

    # --- bottleneck: deep features, encoder stage, then attention ---
    deep = deep_block(pooled3, 128)
    attended = ATT(encoder_stage(deep, 128, ('conv4_1', 'conv4_2')))

    # --- nested skip nodes (UNet++ style) ---
    up12 = upsample(skip2, 16)
    nested12 = double_conv(Concatenate(axis=-1)([skip1, up12]), 16)
    up22 = upsample(skip3, 32)
    nested22 = double_conv(Concatenate(axis=-1)([skip2, up22]), 32)
    up13 = upsample(nested22, 16)
    nested13 = double_conv(Concatenate(axis=-1)([nested12, up13]), 16)

    # --- decoder (each stage refines through deep_block first) ---
    up3 = upsample(attended, 64)
    merged3 = Concatenate(axis=-1)([up3, skip3])
    dec3 = double_conv(deep_block(merged3, 64), 64, ('up_conv7_1', 'up_conv7_2'))
    up2 = upsample(dec3, 32)
    merged2 = Concatenate(axis=-1)([up2, skip2, nested22])
    dec2 = double_conv(deep_block(merged2, 32), 32, ('up_conv8_1', 'up_conv8_2'))
    up1 = upsample(dec2, 16)
    merged1 = Concatenate(axis=-1)([up1, skip1, nested12, nested13])
    dec1 = double_conv(deep_block(merged1, 16), 16, ('up_conv9_1', 'up_conv9_2'))

    # Per-voxel class probabilities.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output', activation='softmax')(dec1)
    return Model(self.input, output)
VASeg Enhanced Version
In this version we add a Magnitude Attention Gate and a Transformer block on top of the base VASeg model.
def get_model_6(self):
    """Build the enhanced VASeg network.

    VASeg backbone (MultiView encoder stages, `downsample_block`
    downsampling, UNet++ nested skips) extended with:
      * an `ATT` + 3-block pre-norm Transformer bottleneck, and
      * FFT-based "Magnitude Attention Gates" that fuse each decoder
        stage's upsampled features with the corresponding skip tensor in
        the frequency domain, injected between the two decoder convs.

    `MultiView`, `downsample_block`, `ATT`, and `mlp` are project helpers
    defined elsewhere in this file.  Assumes `self.input` is a Keras input
    tensor and `self.classes` is the number of output classes.

    Changes vs. the previous revision: removed the leftover debug
    `print(add4.shape)` and the dead commented-out embedding/attention-gate
    experiment code; behavior (layer order, layer names, graph topology)
    is otherwise unchanged.
    """

    def double_conv(tensor, filters, conv_names=(None, None), act_names=(None, None)):
        # Two Conv3D -> BatchNormalization -> ReLU stages at one width.
        for conv_name, act_name in zip(conv_names, act_names):
            tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                            kernel_initializer='he_normal', name=conv_name)(tensor)
            tensor = BatchNormalization()(tensor)
            tensor = Activation(activation='relu', name=act_name)(tensor)
        return tensor

    def upsample(tensor, filters, act_name=None):
        # Strided transposed conv doubles spatial size, then BN + ReLU.
        tensor = Conv3DTranspose(filters=filters, kernel_size=3, strides=2, padding='same',
                                 kernel_initializer='he_normal')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu', name=act_name)(tensor)

    def encoder_stage(tensor, filters, names):
        # Conv path plus a parallel MultiView path, fused by addition.
        conv_path = double_conv(tensor, filters, names)
        view_path = MultiView(tensor, filters)
        return Add()([conv_path, view_path])

    def magnitude_gate(skip, gate, size, names=(None,) * 8):
        """FFT-domain magnitude attention gate.

        Sigmoid-gates the decoder tensor, moves both tensors to the
        frequency domain, sums the centered spectra, and returns the
        magnitude of the inverse transform.
        NOTE(review): the spectra are fused by *addition*, not a
        multiplicative gate — confirm this is intended.
        NOTE(review): `output_shape` mirrors the original code and omits
        the channel axis — confirm against the Lambda layer contract.
        """
        shape = (None, size, size, size)
        skip_c = K.cast(skip, "complex64")
        skip_fft = Lambda(tf.signal.fft3d, output_shape=shape, name=names[0])(skip_c)
        skip_fft = Lambda(tf.signal.fftshift, output_shape=shape, name=names[1])(skip_fft)
        gated = Activation(activation='sigmoid', name=names[2])(gate)
        gated = K.cast(gated, "complex64")
        gate_fft = Lambda(tf.signal.fft3d, output_shape=shape, name=names[3])(gated)
        gate_fft = Lambda(tf.signal.fftshift, output_shape=shape, name=names[4])(gate_fft)
        fused = skip_fft + gate_fft
        fused = Lambda(tf.signal.ifftshift, output_shape=shape, name=names[5])(fused)
        fused = Lambda(tf.signal.ifft3d, output_shape=shape, name=names[6])(fused)
        magnitude = Lambda(tf.abs, output_shape=shape, name=names[7])(fused)
        return K.cast(magnitude, "float32")

    def gated_double_conv(tensor, filters, names, mag):
        # Like double_conv, but the magnitude-gate output is added between
        # the first and second conv stages.
        tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                        kernel_initializer='he_normal', name=names[0])(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Activation(activation='relu')(tensor)
        tensor = Add()([tensor, mag])
        tensor = Conv3D(filters=filters, kernel_size=3, strides=1, padding='same',
                        kernel_initializer='he_normal', name=names[1])(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation(activation='relu')(tensor)

    # --- encoder ---
    skip1 = encoder_stage(self.input, 16, ('conv1_1', 'conv1_2'))
    pooled1 = downsample_block(skip1, 16)
    skip2 = encoder_stage(pooled1, 32, ('conv2_1', 'conv2_2'))
    pooled2 = downsample_block(skip2, 32)
    skip3 = encoder_stage(pooled2, 64, ('conv3_1', 'conv3_2'))
    pooled3 = downsample_block(skip3, 64)

    # --- bottleneck: encoder stage + dual attention (deep_block disabled
    # in this variant) ---
    attended = ATT(encoder_stage(pooled3, 128, ('conv4_1', 'conv4_2')))

    # --- Transformer bottleneck: three pre-norm attention/MLP blocks ---
    # NOTE(review): MultiHeadAttention is applied to the 5-D feature map
    # directly; confirm this matches the intended token layout.
    tokens = attended
    for _ in range(3):
        normed = LayerNormalization()(tokens)
        attn_out = MultiHeadAttention(4, key_dim=3, dropout=0)(normed, normed)
        residual = Add()([attn_out, normed])
        residual = LayerNormalization()(residual)
        mlp_out = mlp(residual, hidden_units=[256, 128], dropout_rate=0)
        tokens = Add()([mlp_out, residual])
    transformed = LayerNormalization()(tokens)

    # --- nested skip nodes (UNet++ style) ---
    up12 = upsample(skip2, 16)
    nested12 = double_conv(Concatenate(axis=-1)([skip1, up12]), 16)
    up22 = upsample(skip3, 32)
    nested22 = double_conv(Concatenate(axis=-1)([skip2, up22]), 32)
    up13 = upsample(nested22, 16)
    # Final activation is named so downstream code can locate it.
    nested13 = double_conv(Concatenate(axis=-1)([nested12, up13]), 16,
                           act_names=(None, 'sag1_input_l'))

    # --- decoder level 3 with Magnitude Attention Gate ---
    up3 = upsample(transformed, 64)
    mag3 = magnitude_gate(skip3, up3, 16)
    merged3 = Concatenate(axis=-1)([up3, skip3])
    dec3 = gated_double_conv(merged3, 64, ('up_conv7_1', 'up_conv7_2'), mag3)

    # --- decoder level 2 with Magnitude Attention Gate ---
    up2 = upsample(dec3, 32)
    mag2 = magnitude_gate(nested22, up2, 32)
    merged2 = Concatenate(axis=-1)([up2, skip2, nested22])
    dec2 = gated_double_conv(merged2, 32, ('up_conv8_1', 'up_conv8_2'), mag2)

    # --- decoder level 1 with Magnitude Attention Gate ---
    up1 = upsample(dec2, 16, act_name='sag1_input_r')
    mag1 = magnitude_gate(nested13, up1, 64,
                          names=('mag1_1_1', 'mag1_1_2', 'acti9mag', 'mag1_2_1',
                                 'mag1_2_2', 'mag1_3_1', 'mag1_3_2', 'mag1_output'))
    merged1 = Concatenate(axis=-1)([up1, skip1, nested12, nested13])
    dec1 = gated_double_conv(merged1, 16, ('up_conv9_1', 'up_conv9_2'), mag1)

    # Per-voxel class probabilities.
    output = Conv3D(filters=self.classes, kernel_size=1, strides=1, padding='same',
                    kernel_initializer='he_normal', name='output', activation='softmax')(dec1)
    return Model(self.input, output)