Trainer
Data Loader
This code loads the NumPy data we processed earlier. You can load 3DXY_0.npy and 3DXY_1.npy, or additional data files, for training and validation; the train:validation split is 9:1. Change if False to if True inside load_data4 to augment the training data by rotating it by 90 and -90 degrees.
# Required imports (module level), e.g.:
#   import numpy as np
#   from scipy.ndimage import rotate
#   from sklearn.utils import shuffle
#   from keras.utils import to_categorical

def load_data4(self, path):
    # Channel 0 holds the image volumes, channel 1 holds the integer labels.
    data0 = np.load(path + '3DXY_0.npy')[:, :, :, :, 0][:, :, :, :, np.newaxis]
    data1 = np.load(path + '3DXY_1.npy')[:, :, :, :, 0][:, :, :, :, np.newaxis]
    # data2 = np.load(path + '3DXY_2.npy')[:, :, :, :, 0][:, :, :, :, np.newaxis]
    # data3 = np.load(path + '3DXY_3.npy')[:, :, :, :, 0][:, :, :, :, np.newaxis]
    Y_data0 = np.load(path + '3DXY_0.npy')[:, :, :, :, 1]
    Y_data1 = np.load(path + '3DXY_1.npy')[:, :, :, :, 1]
    # Y_data2 = np.load(path + '3DXY_2.npy')[:, :, :, :, 1]
    # Y_data3 = np.load(path + '3DXY_3.npy')[:, :, :, :, 1]
    # One-hot encode the labels (3 classes).
    Y_data0 = to_categorical(Y_data0, num_classes=3)
    Y_data1 = to_categorical(Y_data1, num_classes=3)
    # Y_data2 = to_categorical(Y_data2, num_classes=3)
    # Y_data3 = to_categorical(Y_data3, num_classes=3)
    # X_train = np.concatenate((data0, data1, data2, data3), axis=0)
    # Y_train = np.concatenate((Y_data0, Y_data1, Y_data2, Y_data3), axis=0)
    X_train = np.concatenate((data0, data1), axis=0)
    Y_train = np.concatenate((Y_data0, Y_data1), axis=0)
    # Shuffle, then split 9:1 into training and validation sets.
    X_train1, Y_train1 = shuffle(X_train, Y_train)
    divide = int(np.shape(X_train)[0] / 10 * 9)
    X_train = X_train1[:divide, :, :, :, :]
    Y_train = Y_train1[:divide, :, :, :, :]
    X_test = X_train1[divide:, :, :, :, :]
    Y_test = Y_train1[divide:, :, :, :, :]
    print("###before data augmentation###", X_train.shape, Y_train.shape)
    if False:  # set to True to augment with +/-90 degree in-plane rotations
        X_tmp = rotate(X_train, 90, (1, 2))
        X_tmp = np.append(X_tmp, rotate(X_train, -90, (1, 2)), axis=0)
        X_train = np.append(X_train, X_tmp, axis=0)
        Y_tmp = rotate(Y_train, 90, (1, 2))
        Y_tmp = np.append(Y_tmp, rotate(Y_train, -90, (1, 2)), axis=0)
        Y_train = np.append(Y_train, Y_tmp, axis=0)
        X_train, Y_train = shuffle(X_train, Y_train)
        del X_tmp, Y_tmp
    print("###after data augmentation###", X_train.shape, Y_train.shape)
    del data0, data1, Y_data0, Y_data1
    print(X_test.shape, Y_test.shape)
    return X_train, Y_train, X_test, Y_test
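For reference, a minimal usage sketch of load_data4; the Trainer class name and data_path value are assumptions for illustration, not part of the code above.

# Minimal usage sketch (assumes a Trainer class exposing load_data4 and a
# data_path pointing at the directory that contains 3DXY_0.npy / 3DXY_1.npy).
trainer = Trainer()                     # hypothetical instantiation
data_path = '/path/to/processed/data/'  # hypothetical location of the .npy files

X_train, Y_train, X_test, Y_test = trainer.load_data4(data_path)

# X_* have shape (N, D, H, W, 1): single-channel 3D patches.
# Y_* have shape (N, D, H, W, 3): one-hot labels for the 3 classes.
print(X_train.shape, Y_train.shape)   # ~90% of the patches
print(X_test.shape, Y_test.shape)     # ~10% held out for validation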
Train Function
This code is used to train the network; it saves the best model (by validation loss) and the training history.
def train1_cv4(self):
    model = self.get_model_2()
    # model.load_weights(weight_path + 'train01_01.hdf5')
    # model.summary()
    X_train, Y_train, X_test, Y_test = self.load_data4(data_path)
    adam = Adam(lr=0.0003, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    model.compile(optimizer=adam, loss=self.loss,
                  metrics=['accuracy', self.vt_dice, self.an_dice, self.MSE_loss])
    # Keep only the weights with the best (lowest) validation loss.
    checkpointer = ModelCheckpoint(filepath=weight_path + 'train01_01.hdf5',
                                   monitor='val_loss', verbose=1, save_best_only=True)
    hist = model.fit(X_train, Y_train, validation_data=(X_test, Y_test),
                     epochs=100, batch_size=4, shuffle=True,
                     callbacks=[checkpointer], verbose=1)
    # Save the full training history for later inspection.
    with open(json_path + 'train01_01.json', 'w') as f:
        json.dump(hist.history, f)
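After training, the saved history can be reloaded for inspection; a minimal sketch, assuming json_path matches the value used in train1_cv4:

import json

json_path = '/path/to/json/'  # hypothetical; must match json_path used above

with open(json_path + 'train01_01.json') as f:
    history = json.load(f)

# Keys mirror the compiled loss and metrics, e.g. 'loss', 'val_loss', ...
best_epoch = min(range(len(history['val_loss'])), key=lambda i: history['val_loss'][i])
print('best epoch:', best_epoch + 1, 'val_loss:', history['val_loss'][best_epoch])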
Test Function
This code is used to test the network and obtain patch-level predictions. We use both test and test_twice because we have two different test sets cropped from the same image at different anchor positions.
def test(self, all_patch, weight_name):
    model = self.get_model_2()
    model.load_weights(weight_name)  # val_dice:
    X_test = all_patch
    pred = model.predict(X_test, verbose=1, batch_size=8)
    # dome: keep only class 2 as foreground
    pred1 = np.argmax(pred, axis=4)
    pred2 = np.where(pred1 == 1, 0, pred1)
    pred3 = np.where(pred2 == 2, 1, pred2)
    # vessels: merge classes 1 and 2 into a single foreground mask
    pred4 = np.argmax(pred, axis=4)
    pred5 = np.where(pred4 == 2, 1, pred4)
    return pred3, pred5
def test_twice(self, all_patch2, weight_name):
    model = self.get_model_2()
    model.load_weights(weight_name)  # val_dice:
    X_test = all_patch2
    pred = model.predict(X_test, verbose=1, batch_size=8)
    # dome: keep only class 2 as foreground
    pred1 = np.argmax(pred, axis=4)
    pred2 = np.where(pred1 == 1, 0, pred1)
    pred3 = np.where(pred2 == 2, 1, pred2)
    return pred3
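A minimal sketch of how the two prediction passes might be invoked; patch extraction and reassembly into a full volume happen elsewhere, so trainer, all_patch, all_patch2, and weight_path here are assumptions for illustration.

# Hypothetical inputs: two patch sets cropped from the same image at
# different anchors, plus the checkpoint saved by train1_cv4.
dome_mask, vessel_mask = trainer.test(all_patch, weight_path + 'train01_01.hdf5')
dome_mask2 = trainer.test_twice(all_patch2, weight_path + 'train01_01.hdf5')

# Each mask has shape (num_patches, D, H, W) with values in {0, 1}.
# The two dome predictions can later be mapped back to their anchor
# positions and merged during reassembly of the full volume.
print(dome_mask.shape, vessel_mask.shape, dome_mask2.shape)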