我有一个非常大的数据图像文件,我将其划分为较小的文件并将它们存储为pickle。现在,我需要使用它们来训练多个epoch(10,50或100)的模型?我,首先,在训练部分之前阅读它们,
pickle_in1 = open(path + "TrainPairs1.pickle", "rb")
trainPixel1 = pickle.load(pickle_in1)
trainPixel1 = np.asarray(trainPixel1)
tr_pairs1 = trainPixel1.reshape(trainPixel1.shape[0],trainPixel1.shape[1],71,71,1)
pickle_in2 = open(path + "TrainPairs2.pickle", "rb")
trainPixel2 = pickle.load(pickle_in2)
trainPixel2 = np.asarray(trainPixel2)
tr_pairs2 = trainPixel2.reshape(trainPixel2.shape[0],trainPixel2.shape[1],71,71,1)
pickle_in3 = open(path + "TrainPairs3.pickle", "rb")
trainPixel3 = pickle.load(pickle_in3)
trainPixel3 = np.asarray(trainPixel3)
tr_pairs3 = trainPixel3.reshape(trainPixel3.shape[0],trainPixel3.shape[1],71,71,1)
# train labels:
pickle_lb1 = open(path + "TrainLabels1.pickle", "rb")
tr_y1 = pickle.load(pickle_lb1)
tr_y1 = np.array(tr_y1)
# train labels:
pickle_lb2 = open(path + "TrainLabels2.pickle", "rb")
tr_y2 = pickle.load(pickle_lb2)
tr_y2 = np.array(tr_y2)
# train labels:
pickle_lb3 = open(path + "TrainLabels3.pickle", "rb")
tr_y3 = pickle.load(pickle_lb3)
tr_y3 = np.array(tr_y3)
现在我需要训练模型50个epoch,
for epoch in range(50):
# Load weights:
isExist = os.path.exists(path + "Saved_Weights")
#print(isExist)
if isExist == True:
print("Loading weights...")
model.load_weights(path + "Saved_Weights/weights.ckpt")
else:
print("No weights yet...")
# train data:
#train_pairs.append(f"tr_pairs{i}")
# train labels:
#train_y.append(f"tr_y{i}")
history = model.fit([tr_pairs1[:, 0], tr_pairs1[:, 1]], tr_y1,
batch_size=128,
epochs=epoch+1,
initial_epoch = epoch,
shuffle=True,
validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))
training_loss += history.history['loss']
test_loss += history.history['val_loss']
base_network.save(path + "my_model")
model.save_weights(path + "Saved_Weights/weights.ckpt")
print("Saving weights...")
从上面的代码中,我只能运行一个pickle文件,但是我需要训练所有的pickle文件(三个pickle文件需要在每个epoch中进行适配)。什么是最有效的方法呢?
- 更改批大小(16 ->逐渐增大(32,64,128)
- 使用上下文管理器,每次加载
with open(path + "TrainPairs1.pickle", "rb") as pickle_in1 :
trainPixel1 = pickle.load(pickle_in1)
- 如果它仍然不工作,尝试改变图像的大小之前加载为np数组