我正在尝试处理 Keras 中的大型训练数据集。
我将model.fit_generator
与从SQL文件中读取数据的自定义生成器一起使用。
我收到一条错误消息,告诉我无法在两个不同的线程中使用 SQLite 对象:
ProgrammingError: SQLite objects created in a thread can only be used in that
same thread.The object was created in thread id 140736714019776 and this is
thread id 123145449209856
我尝试对 HDF5 文件执行相同的操作,并遇到了分段错误,我现在认为这也与fit_generator
的多线程字符有关(请参阅此处报告的错误)。
使用这些生成器的正确方法是什么,因为我相信必须从不适合内存的数据集的文件中批量读取数据是很常见的。
下面是生成器的代码:
class DataGenerator:
def __init__(self, inputfile, batch_size, **kwargs):
self.inputfile = inputfile
self.batch_size = batch_size
def generate(self, labels, idlist):
while 1:
for batch in self._read_data_from_hdf(idlist):
batch = pandas.merge(batch, labels, how='left', on=['id'])
Y = batch['label']
X = batch.drop(['id', 'label'], axis=1)
yield (X, Y)
def _read_data_from_hdf(self, idlist):
chunklist = [idlist[i:i + self.batch_size] for i in range(0, len(idlist), self.batch_size)]
for chunk in chunklist:
yield pandas.read_hdf(self.inputfile, key='data', where='id in {}'.format(chunk))
# [...]
model.fit_generator(generator=training_generator,
steps_per_epoch=len(partitions['train']) //
config['batch_size'],
validation_data=validation_generator,
validation_steps=len(partitions['validation']) //
config['batch_size'],
epochs=config['epochs'])
在此处查看完整的示例存储库。
感谢您的支持。
干杯
本
面对同样的问题,我通过将线程安全装饰器与可以管理对数据库的并发访问的sqlalchemy
引擎相结合,找到了解决方案:
import pandas
from sqlalchemy import create_engine
class threadsafe_iter:
def __init__(self, it):
self.it = it
self.lock = threading.Lock()
def __iter__(self):
return self
def __next__(self):
with self.lock:
return next(self.it)
def threadsafe_generator(f):
def g(*a, **kw):
return threadsafe_iter(f(*a, **kw))
return g
class DataGenerator:
def __init__(self, inputfile, batch_size, **kwargs):
self.inputfile = inputfile
self.batch_size = batch_size
self.sqlengine = create_engine('sqlite:///' + self.inputfile)
def __del__(self):
self.sqlengine.dispose()
@threadsafe_generator
def generate(self, labels, idlist):
while 1:
for batch in self._read_data_from_sql(idlist):
Y = batch['label']
X = batch.drop(['id', 'label'], axis=1)
yield (X, Y)
def _read_data_from_sql(self, idlist):
chunklist = [idlist[i:i + self.batch_size]
for i in range(0, len(idlist), self.batch_size)]
for chunk in chunklist:
query = 'select * from data where id in {}'.format(tuple(chunk))
df = pandas.read_sql(query, self.sqlengine)
yield df
# Build keras model and instantiate generators
model.fit_generator(generator=training_generator,
steps_per_epoch=train_steps,
validation_data=validation_generator,
validation_steps=valid_steps,
epochs=10,
workers=4)
我希望这有所帮助!