Loss
==================

**Vessel Tree Dice Loss**

This code computes the soft Dice coefficient for the vessel tree (channel 1 of the one-hot label volume). The corresponding loss term, ``1 - dice``, is formed in ``dice_loss``.

::

    import numpy as np
    import tensorflow as tf

    def vt_dice(self, y_true, y_pred, axis=[1, 2, 3, 4], smooth=1e-5):
        # Keep only the vessel-tree channel: (B, D, H, W, C) -> (B, D, H, W, 1)
        y_pred = y_pred[:, :, :, :, 1][:, :, :, :, np.newaxis]
        y_true = y_true[:, :, :, :, 1][:, :, :, :, np.newaxis]
        inse = tf.reduce_sum(y_pred * y_true, axis=axis)  # soft intersection
        l = tf.reduce_sum(y_pred * y_pred, axis=axis)  # sum of squared predictions
        r = tf.reduce_sum(y_true * y_true, axis=axis)  # sum of squared targets (voxel count for binary labels)
        dice = (2. * inse + smooth) / (l + r + smooth)  # per-sample Dice coefficient
        return dice

**Aneurysm Dice Loss**

This code computes the soft Dice coefficient for the aneurysm (channel 2 of the one-hot label volume).

::

    def an_dice(self, y_true, y_pred, axis=[1, 2, 3, 4], smooth=1e-5):
        # Keep only the aneurysm channel: (B, D, H, W, C) -> (B, D, H, W, 1)
        y_pred = y_pred[:, :, :, :, 2][:, :, :, :, np.newaxis]
        y_true = y_true[:, :, :, :, 2][:, :, :, :, np.newaxis]
        inse = tf.reduce_sum(y_pred * y_true, axis=axis)  # soft intersection
        l = tf.reduce_sum(y_pred * y_pred, axis=axis)  # sum of squared predictions
        r = tf.reduce_sum(y_true * y_true, axis=axis)  # sum of squared targets (voxel count for binary labels)
        dice = (2. * inse + smooth) / (l + r + smooth)  # per-sample Dice coefficient
        return dice

**Dice Loss**

This code combines the two coefficients into a single Dice loss for the vessel tree and the aneurysm. Each term is ``1 - dice``, and the aneurysm term is weighted more heavily (0.8 versus 0.2 for the vessel tree).

::

    def dice_loss(self, y_true, y_pred):
        # Weighted sum of the per-structure Dice losses
        loss = 0.2 * (1 - self.vt_dice(y_true, y_pred)) + 0.8 * (1 - self.an_dice(y_true, y_pred))
        return loss

**Boundary Aware MSE Loss**

This code computes the boundary-aware MSE loss. It first applies a 3-D Fourier transform over the spatial axes of the ground-truth volume and shifts the zero-frequency component to the center of the spectrum. A mask then zeroes the central (low-frequency) cube of the spectrum, keeping only the high-frequency part. The inverse Fourier transform maps the filtered spectrum back to the spatial domain, giving a high-pass-filtered volume that emphasizes edges and fine structures. The same steps are applied to the prediction, and the loss is a weighted sum of the mean absolute error and the mean squared error between the two high-pass-filtered volumes. The loss is intended to give the output smoother, more accurate boundaries and to make the network pay more attention to small targets.
::

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import backend as K

    def high_pass(self, y, size=64, keep_low=16):
        # fft3d transforms the innermost three axes, so move the channel axis in
        # front of the spatial axes: (B, D, H, W, C) -> (B, C, D, H, W).
        y = tf.transpose(K.cast(y, "complex64"), perm=[0, 4, 1, 2, 3])
        # Shift only the spatial axes (the default shifts batch and channel too).
        f = tf.signal.fftshift(tf.signal.fft3d(y), axes=[2, 3, 4])
        # Zero the central low-frequency cube (16:48 for size 64); the mask
        # broadcasts over batch and channels, so no batch size is hard-coded.
        mask = np.ones((size, size, size), dtype=np.complex64)
        lo, hi = size // 2 - keep_low, size // 2 + keep_low
        mask[lo:hi, lo:hi, lo:hi] = 0
        f = f * tf.constant(mask)
        img = tf.signal.ifft3d(tf.signal.ifftshift(f, axes=[2, 3, 4]))
        return tf.abs(img)  # real-valued high-pass volume, (B, C, D, H, W)

    def MSE_loss(self, y_true, y_pred):
        y_true2 = self.high_pass(y_true)
        y_pred2 = self.high_pass(y_pred)
        mae = K.mean(K.abs(y_pred2 - y_true2), axis=[1, 2, 3, 4])
        mse = K.mean(K.square(y_pred2 - y_true2), axis=[1, 2, 3, 4])
        total_loss = mae * 0.6 + mse * 0.6
        return total_loss
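To make the Dice computation in ``vt_dice``, ``an_dice`` and ``dice_loss`` easy to check outside a TensorFlow graph, here is a minimal NumPy sketch of the same formula, assuming one-hot labels of shape (batch, depth, height, width, channels) with channel 1 for the vessel tree and channel 2 for the aneurysm. The function names ``soft_dice`` and ``combined_dice_loss`` are hypothetical and not part of the original code.

```python
import numpy as np


def soft_dice(y_true, y_pred, channel, smooth=1e-5):
    """Per-sample soft Dice coefficient for one channel of a
    (batch, D, H, W, channels) array (hypothetical helper)."""
    t = y_true[..., channel]  # (batch, D, H, W)
    p = y_pred[..., channel]
    axes = (1, 2, 3)  # reduce over spatial axes, keep the batch axis
    inse = np.sum(p * t, axis=axes)  # soft intersection
    l = np.sum(p * p, axis=axes)  # sum of squared predictions
    r = np.sum(t * t, axis=axes)  # sum of squared targets
    return (2.0 * inse + smooth) / (l + r + smooth)


def combined_dice_loss(y_true, y_pred, w_vt=0.2, w_an=0.8):
    """Weighted Dice loss: w_vt * (1 - Dice_vt) + w_an * (1 - Dice_an)."""
    vt = soft_dice(y_true, y_pred, channel=1)  # vessel tree
    an = soft_dice(y_true, y_pred, channel=2)  # aneurysm
    return w_vt * (1.0 - vt) + w_an * (1.0 - an)
```

A perfect prediction gives a Dice coefficient of 1 for both structures and therefore a loss of 0, while an all-zero prediction drives both coefficients toward 0 and the loss toward ``w_vt + w_an = 1``.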
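The high-pass step of the boundary-aware loss can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the original implementation: it assumes a channels-last (batch, D, H, W, channels) layout and zeroes the same central cube of half-width 16 (indices 16:48 for a 64-voxel axis) after ``fftshift``. Passing ``axes`` explicitly keeps the transform and the shift on the spatial axes only; ``tf.signal.fft3d`` always transforms the innermost three axes and ``tf.signal.fftshift`` shifts every axis by default, which is why the axis handling matters. The names ``high_pass_3d`` and ``boundary_aware_loss`` are hypothetical.

```python
import numpy as np


def high_pass_3d(x, keep_low=16):
    """|IFFT(mask * FFT(x))| over the spatial axes of a (batch, D, H, W, C)
    array; the central low-frequency cube (after fftshift) is zeroed."""
    axes = (1, 2, 3)  # spatial axes only; batch and channel are untouched
    f = np.fft.fftshift(np.fft.fftn(x, axes=axes), axes=axes)
    d, h, w = x.shape[1:4]
    mask = np.ones((d, h, w))
    cd, ch, cw = d // 2, h // 2, w // 2  # zero frequency sits at n // 2
    mask[cd - keep_low:cd + keep_low,
         ch - keep_low:ch + keep_low,
         cw - keep_low:cw + keep_low] = 0.0
    f = f * mask[None, :, :, :, None]  # broadcast over batch and channels
    img = np.fft.ifftn(np.fft.ifftshift(f, axes=axes), axes=axes)
    return np.abs(img)


def boundary_aware_loss(y_true, y_pred, w_mae=0.6, w_mse=0.6, keep_low=16):
    """Per-sample weighted MAE + MSE between high-pass-filtered volumes
    (weights default to the 0.6 / 0.6 used in MSE_loss)."""
    diff = high_pass_3d(y_pred, keep_low) - high_pass_3d(y_true, keep_low)
    mae = np.mean(np.abs(diff), axis=(1, 2, 3, 4))
    mse = np.mean(diff ** 2, axis=(1, 2, 3, 4))
    return w_mae * mae + w_mse * mse
```

A slowly varying volume (for example one cosine period along the depth axis of a 64-voxel cube) is removed almost entirely by the filter, whereas a rapidly varying one passes through unchanged up to the absolute value. Only differences in fine detail therefore contribute to the loss.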