我正在使用 Keras OCR 示例:https://github.com/keras-team/keras/blob/master/examples/image_ocr.py 用于在线手写识别,但在模型训练后面临内存分配问题,同时使用 theano 函数获取 softmax 输出。 x_train形状:(1200,1586,4(。我正在以 1200 个批次喂食 12 个冲程序列。 下面是代码片段:
inputs = Input(name='the_input', shape=x_train.shape[1:], dtype='float32')
rnn_encoded = Bidirectional(LSTM(64, return_sequences=True,kernel_initializer=init,bias_initializer=bias),name='bidirectional_1',merge_mode='concat',trainable=trainable)(inputs)
birnn_encoded = Bidirectional(LSTM(32, return_sequences=True,kernel_initializer=init,bias_initializer=bias),name='bidirectional_2',merge_mode='concat',trainable=trainable)(rnn_encoded)
trirnn_encoded=Bidirectional(LSTM(16,return_sequences=True,kernel_initializer=init,bias_initializer=bias),name='bidirectional_3',merge_mode='concat',trainable=trainable)(birnn_encoded)
output = TimeDistributed(Dense(28, name='dense',kernel_initializer=init,bias_initializer=bias))(trirnn_encoded)
y_pred = Activation('softmax', name='softmax')(output)
model=Model(inputs=inputs,outputs=y_pred)
labels = Input(name='the_labels', shape=[max_len], dtype='int32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
model = Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)
opt=RMSprop(lr=0.001,clipnorm=1.)
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=opt)
gc.collect()
my_generator = generator(x_train,y_train,batch_size)
hist= model.fit_generator(my_generator,epochs=80,steps_per_epoch=100,shuffle=True,use_multiprocessing=False,workers=1)
model.save(mfile)
test_func = K.function([inputs], [y_pred])
内存分配错误发生在最后一行。我在 AWS 上使用一台带有 8vCPU 的 32GB RAM 机器。当我运行较少的纪元数(大约 30-40 个(时,但主要是当我运行大量纪元(如 80-100(时,不会发生错误。我还附上了错误的屏幕截图.1 除了减少数据集大小或纪元数之外,如果还有其他解决方案,请告诉我。
我无法消除内存错误,但我能够获得测试结果。我刚刚保存了这个模型,稍后在另一个脚本中加载了它。在加载时,我使用此处建议的答案从模型中获取相同的测试函数,并在循环中为其提供训练/测试数据,以获取解码的字符串并计算准确性。这也不会导致任何内存错误。