关于 fit_generator() / fit() 和线程安全



上下文

为了在 Keras 中使用fit_generator(),我使用了一个生成器函数,例如以下伪代码

def generator(data: np.array) -> (np.array, np.array):
"""Simple generator yielding some samples and targets"""
while True:
for batch in range(number_of_batches):
yield data[batch * length_sequence], data[(batch + 1) * length_sequence]

在 Keras 的fit_generator()函数中,我想使用workers=4use_multiprocessing=True- 因此,我需要一个线程安全的生成器。

在 stackoverflow 的答案中,例如 这里 或 这里 或 Keras 文档中,我读到了创建一个继承自Keras.utils.Sequence()的类,如下所示:

class generatorClass(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return ...

通过使用SequencesKeras不会使用多个工作和多处理抛出任何警告;生成器应该是线程安全的。

无论如何,由于我正在使用我的自定义函数,我偶然发现了 github 上提供的 Omer Zohars 代码,它允许通过添加装饰器使我的generator()线程安全。 代码如下所示:

import threading
class threadsafe_iter:
"""
Takes an iterator/generator and makes it thread-safe by
serializing call to the `next` method of given iterator/generator.
"""
def __init__(self, it):
self.it = it
self.lock = threading.Lock()
def __iter__(self):
return self
def __next__(self):
with self.lock:
return self.it.__next__()

def threadsafe_generator(f):
"""A decorator that takes a generator function and makes it thread-safe."""
def g(*a, **kw):
return threadsafe_iter(f(*a, **kw))
return g

现在我可以做:

@threadsafe_generator
def generator(data):
...

问题是:使用此版本的线程安全生成器,Keras 仍然会发出警告,指出在使用workers > 1use_multiprocessing=True时生成器必须是线程安全的,这可以通过使用Sequences来避免。


我现在的问题是:

Keras
  1. 发出此警告只是因为生成器没有继承Sequences,还是 Keras 还会检查生成器是否通常是线程安全的?
  2. 使用我选择的方法是否与使用 Keras-文档中的generatorClass(Sequence)版本一样线程安全?
  3. 是否有任何其他方法导致线程安全生成器 Keras 可以处理与这两个示例不同的方法?


编辑:在较新的tensorflow/keras版本(tf>2)中,fit_generator()已被弃用。相反,建议将fit()与生成器一起使用。但是,这个问题仍然适用于使用生成器的fit()

在我对此进行研究的过程中,我遇到了一些回答我的问题的信息。

注意:在较新的tensorflow/keras版本(tf> 2)中问题中更新fit_generator()已弃用。相反,建议将fit()与生成器一起使用。但是,答案仍然适用于使用生成器的fit()

<小时 />

1.Keras 发出此警告只是因为生成器没有继承序列,还是 Keras 还会检查生成器是否通常是线程安全的?

取自 Keras 的 gitRepo (training_generators.py),我在以下46-52行中找到:

use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
warnings.warn(
UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the `keras.utils.Sequence'
' class.'))

624-635行中取自training_utils.pyis_sequence()的定义是:

def is_sequence(seq):
"""Determine if an object follows the Sequence API.
# Arguments
seq: a possible Sequence object
# Returns
boolean, whether the object follows the Sequence API.
"""
# TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
return (getattr(seq, 'use_sequence_api', False)
or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))

关于这段代码,Keras 只检查传递的生成器是否是 Keras 序列(或者更确切地说是使用 Keras 的序列 API),并且通常不检查生成器是否是线程安全的。

<小时 />

2.使用我选择的方法是否与使用 Keras-docs 中的 generatorClass(Sequence)-version 一样线程安全?

正如 Omer Zohar 在 gitHub 上展示的那样,他的装饰器是线程安全的 - 我看不出有任何理由为什么它不应该对 Keras 是线程安全的(即使 Keras 会发出警告,如 1 所示)。 根据文档,thread.Lock()的实现可以理解为线程安全:

返回新的基元锁对象的工厂函数。一旦线程获取了它,后续获取它的尝试就会阻止,直到它被释放;任何线程都可以释放它。

生成器也是可挑选的,可以像这样进行测试(有关更多信息,请参阅此处的SO-Q&A):

#Dump yielded data in order to check if picklable
with open("test.pickle", "wb") as outfile:
for yielded_data in generator(data):
pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)

恢复这一点,我什至建议在扩展 KerasSequence()时实现thread.Lock(),例如:

import threading
class generatorClass(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
self.lock = threading.Lock()   #Set self.lock
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
with self.lock:                #Use self.lock
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return ...

编辑24/04/2020:

通过使用self.lock = threading.Lock()您可能会遇到以下错误:

类型错误:无法腌制_thread.锁定对象

如果发生这种情况,请尝试用with threading.Lock():替换__getitem__内部的with self.lock:,并注释掉/删除__init__内的self.lock = threading.Lock()

在类中存储lock对象时似乎存在一些问题(例如参见此问答)。

<小时 />

3。是否有任何其他方法导致线程安全生成器 Keras 可以处理与这两个示例不同的方法?

在我的研究过程中,我没有遇到任何其他方法。 当然,我不能100%肯定地说。

相关内容

  • 没有找到相关文章

最新更新