为什么我的tensorflow生成器功能如此缓慢



下面的生成器函数太慢。有没有一种方法可以优化这个代码?。train_dataset_c1是表单图像的类别1的训练数据集,1train_dataset_c0是表单图像0类的列车数据集,0

def generator(positive_dataset, negative_dataset):
while True:
for pos_rec, neg_rec in zip(positive_dataset, negative_dataset):
pos_x, pos_y = pos_rec
neg_x, neg_y = neg_rec
x = tf.concat([pos_x, neg_x], axis=0)
y = tf.concat([pos_y, neg_y], axis=0)
yield x, y
train_generator = generator(train_dataset_c1, train_dataset_c0)
test_generator = generator(test_dataset_c1, test_dataset_c0)

如果您使用tensorflow 2.0,我建议您使用tf.data API来加快管道速度。

事实上,有一个from_generator函数,你可以应用到你的生成器来加速

使用此函数将其转换为tf.data.Dataset对象后,您可以使用本教程中的任何策略对其进行进一步优化

最新更新