OOM error in Keras when using a Python generator for input



So: I used the same autoencoder model with a batch size of 10 and no generator (by loading the elements in memory), and the model ran without any problem.

I then defined a Python generator so that I can feed in more data, as follows:

from sklearn.utils import shuffle
import numpy as np

def nifti_gen(samples, batch_size=5):
    num_samples = len(samples)
    while True:
        for bat in range(0, num_samples, batch_size):
            temp_batch = samples[bat:bat + batch_size]
            batch_data = None
            for i, element in enumerate(temp_batch):
                temp = get_input(element)
                if i == 0:
                    batch_data = temp
                else:
                    batch_data = np.concatenate((batch_data, temp))
            yield batch_data, batch_data

from sklearn.model_selection import train_test_split
train_samples, validation_samples = train_test_split(IO_paths[:400], test_size=0.1)
train_generator = nifti_gen(train_samples, batch_size=5)
validation_generator = nifti_gen(validation_samples, batch_size=5)
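One detail worth checking in a generator like this: `np.concatenate` joins arrays along their first axis. If `get_input` returns an array whose leading axis holds slices (e.g. 100 slices per volume), the yielded "batch" has a first dimension of batch_size × slices rather than batch_size, which matches the 500 in the error below (5 × 100). A minimal sketch with a dummy `get_input`; the shapes here are illustrative assumptions, not the real data:

```python
import numpy as np

def get_input(element):
    # Dummy stand-in: pretend each sample is a volume of 100 slices of 4x4.
    return np.zeros((100, 4, 4))

def nifti_gen(samples, batch_size=5):
    num_samples = len(samples)
    while True:
        for bat in range(0, num_samples, batch_size):
            temp_batch = samples[bat:bat + batch_size]
            batch_data = None
            for i, element in enumerate(temp_batch):
                temp = get_input(element)
                if i == 0:
                    batch_data = temp
                else:
                    # Concatenates along axis 0, so slice counts add up.
                    batch_data = np.concatenate((batch_data, temp))
            yield batch_data, batch_data

gen = nifti_gen(list(range(20)), batch_size=5)
x, y = next(gen)
print(x.shape)  # (500, 4, 4): 5 samples x 100 slices each, not 5
```

If the intent is a leading batch axis of 5, `np.stack` (which adds a new axis) would be the alternative to compare against.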

However, when I try to train the model, I get the following error before even one epoch completes:

autoencoder_train = MRA_autoencoder.fit(train_generator, steps_per_epoch= 36 , callbacks= [es,mc] , epochs= 300)
Epoch 1/300
---------------------------------------------------------------------------
ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input-32-65838b7c908e> in <module>()
----> 1 autoencoder_train = MRA_autoencoder.fit(train_generator, steps_per_epoch= 36 , callbacks= [es,mc] , epochs= 300)
8 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58     ctx.ensure_initialized()
59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
61   except core._NotOkStatusException as e:
62     if name is not None:
ResourceExhaustedError:  OOM when allocating tensor with shape[500,84,400,400] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[node functional_5/functional_1/conv2d/Conv2D (defined at <ipython-input-32-65838b7c908e>:1) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
[Op:__inference_train_function_4760]
Function call stack:
train_function

I have no idea why this happens, since I am sure I have more than enough memory to handle at least 10 batches. Any help would be much appreciated! Thanks.

It looks like your data is very large. A tensor of shape [500, 84, 400, 400] is a very, very large amount of data to process at every layer; it would be best to fall back to a batch size of 5, or to move to multi-GPU cloud-based training.
