我正在用一个大数据集训练一个神经网络,因此我需要使用多个工人/多处理来加快训练。
以前我使用keras生成器,并使用fit生成器,其中multiprocessing设置为false,workers设置为16,但最近我不得不使用自己的生成器,所以我创建了自己的flow_from_directory生成器,如下所示:
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(image_size, image_size),
batch_size=training_batch_size,
class_mode='categorical') # set as training data
bal_gen = balanced_flow_from_directory(train_generator)
def balanced_flow_from_directory(flow_from_directory):
for x, y in flow_from_directory:
yield custom_balance(x, y)
然而,在合适的生成器中,当我保持工人>1和MultiProcessing为False,这会告诉我我的生成器不安全,因此不能与工人一起使用>1和Multiprocessing设置为False。当我要留住工人>1并将MultiProcessing设置为True,代码会运行,但它会给我警告,比如:
警告:tensorflow:使用带有
use_multiprocessing=True
的生成器和多个工作线程可能会复制您的数据。请考虑使用tf.data.Dataset
此外,验证会给出非常奇怪的输出,例如:
1661/1661==========================]-ETA:0s-损失:0.1420-精度:0.9662警告:tensorflow:使用带有
use_multiprocessing=True
的生成器和多个工作程序可能会复制您的数据。请考虑使用tf.data.Dataset
。1661/1661【=========================】-475s 286ms/步-损耗:0.1420-精度:0.9662-val_loss:6.2723-val_accuracy:0.0108elines tf.data为推荐值。
验证精度总是很低,val_loss总是很高。我能做些什么来解决这个问题吗?
更新:我发现了一个代码,使生成器函数线程安全,如下所示:
import threading
class threadsafe_iter:
"""
Takes an iterator/generator and makes it thread-safe by
serializing call to the `next` method of given iterator/generator.
"""
def __init__(self, it):
self.it = it
self.lock = threading.Lock()
def __iter__(self):
return self
def __next__(self):
with self.lock:
return self.it.__next__()
def threadsafe_generator(f):
def g(*a, **kw):
return threadsafe_iter(f(*a, **kw))
return g
@threadsafe_generator
def balanced_flow_from_directory(flow_from_directory):
for x, y in flow_from_directory:
yield custom_balance(x, y)
现在,我可以使用worker=16,多处理设置为False,就像我在制作自定义生成器之前使用的那样。然而,当我这样做时,每个历元需要30分钟,而以前需要7分钟。
当我使用worker=16和multiprocessing设置为true时,它会给我带来与上面设置multiprocessing为true时相同的问题,即验证准确性破坏。
也许您应该将相同的数据平衡函数应用于验证数据生成器?