Tensorflow中的预取生成器(序列)



我试图用这个Github页面后面的tf.keras.utils.Sequence-方法实现生成器:https://mahmoudyusof.github.io/facial-keypoint-detection/data-generator/

所以我的生成器有一个表单:

class Generator(tf.keras.utils.Sequence):
def __init__(self, *args, **kwargs):
self.on_epoch_end()
def on_epoch_end(self):
#shuffle indices for batches
def __len__(self):
def __getitem__(self, idx):    
#returning the idxth batch of the shuffled dataset    
return X, y

不幸的是,有了这个生成器,我的模型的训练过程变得很长,所以我想预取它

我试过

Train_Generator = tf.data.Dataset.from_generator(Generator(Training_Files, batch_size=64, shuffle = True), output_types=(np.array, np.array))

以将生成器转换为预取工作的类型。我收到错误消息:

`generator` must be callable.

我知道生成器必须支持Iter((-协议才能工作。但是我该如何实现呢?或者你们知道提高这类发电机性能的其他方法吗?

谢谢!!

我建议这样做:

Train_Generator = tf.data.Dataset.from_generator(Generator, args=[Training_Files, 64, True], output_types=(np.array, np.array))

相关内容

  • 没有找到相关文章

最新更新