Trainer

Data Loader

This code loads the NumPy data we processed before. You can load 3DXY_0.npy and 3DXY_1.npy, or more data files, for training and validation. The train:validation ratio is 9:1. You can set

if False

to

if True

to augment the data by rotating the data 90 degrees and -90 degrees.

def load_data4(self, path, augment=False, file_indices=(0, 1)):
    """Load the preprocessed ``3DXY_<i>.npy`` files and split them 9:1.

    Every file is indexed as a 5-D array: channel 0 of the last axis is
    used as the network input (kept as a trailing axis of length 1) and
    channel 1 is passed to ``to_categorical(..., num_classes=3)`` to build
    the targets.  All samples are concatenated along axis 0, shuffled, and
    the first 9/10 become the training split; the rest is the validation
    split.

    Args:
        path: Prefix concatenated directly with the file name (no separator
            is inserted), so it must end with one when it names a directory.
        augment: When True, the training split is tripled by appending
            copies turned by +90 and -90 degrees in the plane of axes 1
            and 2, then reshuffled.  This replaces editing the former
            hard-coded ``if False:`` switch; the default keeps the old
            behaviour (no augmentation).
        file_indices: Indices ``i`` of the ``3DXY_<i>.npy`` files to load.
            The default loads files 0 and 1, as before.

    Returns:
        Tuple ``(X_train, Y_train, X_test, Y_test)``.

    Raises:
        ValueError: If ``file_indices`` is empty.
    """
    if len(file_indices) == 0:
        raise ValueError("file_indices must name at least one file")

    inputs = []
    targets = []
    for index in file_indices:
        # Read each file once and take both channels from the same array
        # (the original loaded every file twice).
        volume = np.load(path + '3DXY_{}.npy'.format(index))
        inputs.append(volume[:, :, :, :, 0][:, :, :, :, np.newaxis])
        targets.append(to_categorical(volume[:, :, :, :, 1], num_classes=3))

    X_all = np.concatenate(inputs, axis=0)
    Y_all = np.concatenate(targets, axis=0)
    del inputs, targets
    # NOTE(review): `shuffle` is presumably sklearn.utils.shuffle, which
    # permutes both arrays consistently -- confirm against the file imports.
    X_all, Y_all = shuffle(X_all, Y_all)

    # 9:1 train/validation split, computed in integer arithmetic.
    divide = X_all.shape[0] * 9 // 10
    X_train = X_all[:divide, :, :, :, :]
    Y_train = Y_all[:divide, :, :, :, :]
    X_test = X_all[divide:, :, :, :, :]
    Y_test = Y_all[divide:, :, :, :, :]

    print("###before data augmentation###", X_train.shape, Y_train.shape)
    if augment:
        # np.rot90 only permutes elements, so quarter turns are exact and
        # the rotated targets remain valid 0/1 encodings (an interpolating
        # rotation gives no such guarantee).  Axes 1 and 2 must have equal
        # length for the rotated copies to be concatenable.
        X_train = np.concatenate(
            (X_train,
             np.rot90(X_train, 1, axes=(1, 2)),
             np.rot90(X_train, -1, axes=(1, 2))),
            axis=0)
        Y_train = np.concatenate(
            (Y_train,
             np.rot90(Y_train, 1, axes=(1, 2)),
             np.rot90(Y_train, -1, axes=(1, 2))),
            axis=0)

        X_train, Y_train = shuffle(X_train, Y_train)
        print("###after data augmentation###", X_train.shape, Y_train.shape)

    print(X_test.shape, Y_test.shape)

    return X_train, Y_train, X_test, Y_test

Train Function

This code is used to train the network. It saves the best model (lowest validation loss) and the training history.

def train1_cv4(self):
    """Train the model from ``get_model_2`` and persist the results.

    The data comes from ``load_data4(data_path)``.  The weights with the
    lowest ``val_loss`` are written to ``weight_path + 'train01_01.hdf5'``
    and the per-epoch history is dumped as JSON to
    ``json_path + 'train01_01.json'``.
    """
    network = self.get_model_2()
    train_x, train_y, val_x, val_y = self.load_data4(data_path)

    optimizer = Adam(lr=0.0003, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    network.compile(
        optimizer=optimizer,
        loss=self.loss,
        metrics=['accuracy', self.vt_dice, self.an_dice, self.MSE_loss],
    )

    # Only overwrite the checkpoint when the validation loss improves.
    best_weights_file = weight_path + 'train01_01.hdf5'
    keep_best = ModelCheckpoint(
        filepath=best_weights_file,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
    )

    history = network.fit(
        train_x,
        train_y,
        validation_data=(val_x, val_y),
        epochs=100,
        batch_size=4,
        shuffle=True,
        callbacks=[keep_best],
        verbose=1,
    )

    history_file = json_path + 'train01_01.json'
    with open(history_file, 'w') as out:
        json.dump(history.history, out)

Test Function

This code is used to test the network and get patch predictions. We use both test and test_twice because we have two different test sets, cropped from the same image at different anchors.

def test(self, all_patch, weight_name):
    """Predict on patches and derive the dome and vessel label maps.

    Args:
        all_patch: Input passed unchanged to ``model.predict``.
        weight_name: Weights file handed to ``model.load_weights``.

    Returns:
        Tuple ``(dome, vessels)`` computed from the arg-max over axis 4 of
        the prediction:
        ``dome`` has class 1 mapped to 0 and class 2 mapped to 1;
        ``vessels`` has class 2 mapped to 1 and every other class kept.
    """
    model = self.get_model_2()
    model.load_weights(weight_name)  # val_dice:
    pred = model.predict(all_patch, verbose=1, batch_size=8)

    # One arg-max serves both outputs; the original computed it twice and
    # had a stray character after the second call that broke parsing.
    labels = np.argmax(pred, axis=4)

    # dome
    without_class1 = np.where(labels == 1, 0, labels)
    dome = np.where(without_class1 == 2, 1, without_class1)

    # vessels
    vessels = np.where(labels == 2, 1, labels)

    return dome, vessels

def test_twice(self, all_patch2, weight_name):
    """Predict on the second patch set and return only the dome label map.

    The arg-max over axis 4 of the prediction is remapped so that class 2
    becomes 1 and class 1 becomes 0; any other value is left unchanged.
    """
    net = self.get_model_2()
    net.load_weights(weight_name)
    scores = net.predict(all_patch2, verbose=1, batch_size=8)

    # dome
    class_map = np.argmax(scores, axis=4)
    return np.where(class_map == 2, 1, np.where(class_map == 1, 0, class_map))