Loss
Vessel Tree Dice Loss: this code computes the Dice coefficient for the vessel tree (channel 1 of the last axis), from which the Dice loss is formed.
def vt_dice(self, y_true, y_pred, axis=[1, 2, 3, 4], smooth=1e-5):
    """Return the smoothed Dice coefficient of channel 1 (vessel tree).

    Both tensors are indexed on their fifth axis, so they are expected to
    be 5-D. Channel 1 is extracted and kept as a trailing singleton axis
    so the reduction axes still line up.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor, laid out like ``y_true``.
        axis: Axes that are summed over. For 5-D input the default leaves
            only axis 0. The list default is only read, never mutated.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both sums are zero.

    Returns:
        ``(2 * sum(pred * true) + smooth)
        / (sum(pred ** 2) + sum(true ** 2) + smooth)``.
    """
    # Integer index drops the channel axis; newaxis puts a size-1 one back.
    pred = y_pred[:, :, :, :, 1, np.newaxis]
    truth = y_true[:, :, :, :, 1, np.newaxis]

    overlap = tf.reduce_sum(pred * truth, axis=axis)
    pred_energy = tf.reduce_sum(pred * pred, axis=axis)  # sum of squares
    truth_energy = tf.reduce_sum(truth * truth, axis=axis)  # sum of squares

    return (2. * overlap + smooth) / (pred_energy + truth_energy + smooth)
Aneurysm Dice Loss
This code computes the Dice coefficient for the aneurysm (channel 2 of the last axis), from which the Dice loss is formed.
def an_dice(self, y_true, y_pred, axis=[1, 2, 3, 4], smooth=1e-5):
    """Return the smoothed Dice coefficient of channel 2 (aneurysm).

    Both tensors are indexed on their fifth axis, so they are expected to
    be 5-D. Channel 2 is extracted and kept as a trailing singleton axis
    so the reduction axes still line up.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor, laid out like ``y_true``.
        axis: Axes that are summed over. For 5-D input the default leaves
            only axis 0. The list default is only read, never mutated.
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both sums are zero.

    Returns:
        ``(2 * sum(pred * true) + smooth)
        / (sum(pred ** 2) + sum(true ** 2) + smooth)``.
    """
    # Integer index drops the channel axis; newaxis puts a size-1 one back.
    pred = y_pred[:, :, :, :, 2, np.newaxis]
    truth = y_true[:, :, :, :, 2, np.newaxis]

    overlap = tf.reduce_sum(pred * truth, axis=axis)
    pred_energy = tf.reduce_sum(pred * pred, axis=axis)  # sum of squares
    truth_energy = tf.reduce_sum(truth * truth, axis=axis)  # sum of squares

    return (2. * overlap + smooth) / (pred_energy + truth_energy + smooth)
Dice Loss
This code computes the combined Dice loss for the vessel tree and the aneurysm, weighting the two terms by 0.2 and 0.8 respectively.
def dice_loss(self, y_true, y_pred):
    """Return the weighted Dice loss over vessel tree and aneurysm.

    Each term is ``1 - dice`` for its structure; the vessel-tree term is
    weighted 0.2 and the aneurysm term 0.8.

    Args:
        y_true: Ground-truth tensor, forwarded to ``vt_dice``/``an_dice``.
        y_pred: Predicted tensor, forwarded to ``vt_dice``/``an_dice``.

    Returns:
        ``0.2 * (1 - vt_dice) + 0.8 * (1 - an_dice)``.
    """
    vessel_term = 1 - self.vt_dice(y_true, y_pred)
    aneurysm_term = 1 - self.an_dice(y_true, y_pred)
    return 0.2 * vessel_term + 0.8 * aneurysm_term
Boundary Aware MSE Loss
This code computes the boundary-aware MSE loss. It first applies a Fourier transform to obtain the frequency-domain representation of the input image, then uses a mask to keep only the high-frequency part, and then applies the inverse Fourier transform to bring that high-frequency part back to the image domain. Finally, it computes the loss between the high-frequency part of the target image and the high-frequency part of the output image (in the code below, a weighted sum of the mean absolute error and the mean squared error). This loss is used to make the output image smoother and to pay more attention to small targets.
def MSE_loss(self, y_true, y_pred):
    """Return the boundary-aware loss between high-pass filtered volumes.

    Each input is transformed with a 3-D FFT over its spatial axes, the
    central (low-frequency) box of the shifted spectrum is zeroed, and the
    result is transformed back. The loss compares the magnitudes of the
    two filtered volumes.

    Assumes 5-D channels-last tensors ``(batch, D, H, W, C)``.

    Changes from the earlier implementation:
        * ``tf.signal.fft3d`` transforms the innermost three axes, so the
          channel axis is moved ahead of the spatial axes first. Before,
          the transform ran over ``(H, W, C)`` while the mask was applied
          over ``(D, H, W)``.
        * The shift and its inverse are limited to the spatial axes
          instead of rolling every axis, batch and channels included.
        * The mask is derived from the tensor shape instead of a fixed
          ``(4, 64, 64, 64, 3)`` array. For a size-64 axis it zeroes
          16:48 as before, and any batch size is accepted.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor with the same shape as ``y_true``.

    Returns:
        Per-sample tensor ``0.6 * MAE + 0.6 * MSE`` of the filtered
        volumes, averaged over every non-batch axis.
    """
    spatial_axes = [2, 3, 4]  # positions of D, H, W once channels lead
    spatial_shape = tf.shape(y_pred)[1:4]

    # One 0/1 indicator per spatial axis marking the central half
    # [n // 4, 3 * n // 4), where low frequencies sit after fftshift.
    bands = []
    for i in range(3):
        n = spatial_shape[i]
        idx = tf.range(n)
        in_band = tf.logical_and(idx >= n // 4, idx < (3 * n) // 4)
        bands.append(tf.cast(in_band, tf.float32))

    # Outer product gives the central box; its complement is the
    # high-pass mask, shaped (D, H, W) so it broadcasts over batch/channel.
    low_box = (
        bands[0][:, tf.newaxis, tf.newaxis]
        * bands[1][tf.newaxis, :, tf.newaxis]
        * bands[2][tf.newaxis, tf.newaxis, :]
    )
    high_pass_mask = K.cast(1.0 - low_box, "complex64")

    def _high_pass(volume):
        """Return the magnitude of ``volume`` with low frequencies removed."""
        # (batch, D, H, W, C) -> (batch, C, D, H, W) so fft3d sees D, H, W.
        channels_first = tf.transpose(
            K.cast(volume, "complex64"), [0, 4, 1, 2, 3]
        )
        spectrum = tf.signal.fftshift(
            tf.signal.fft3d(channels_first), axes=spatial_axes
        )
        filtered = tf.multiply(x=spectrum, y=high_pass_mask)
        restored = tf.signal.ifft3d(
            tf.signal.ifftshift(filtered, axes=spatial_axes)
        )
        # Left channels-first: the means below cover every non-batch
        # axis, so the axis order does not affect the result.
        return tf.abs(restored)

    diff = _high_pass(y_pred) - _high_pass(y_true)
    mae = K.mean(K.abs(diff), axis=[1, 2, 3, 4])
    mse = K.mean(K.square(diff), axis=[1, 2, 3, 4])

    # NOTE(review): the weights sum to 1.2, not 1.0; kept as in the
    # original. Confirm whether 0.6 / 0.4 was intended.
    total_loss = mae * 0.6 + mse * 0.6
    return total_loss