我试图用这个Github页面后面的tf.keras.utils.Sequence-方法实现生成器:https://mahmoudyusof.github.io/facial-keypoint-detection/data-generator/
所以我的生成器有一个表单:
class Generator(tf.keras.utils.Sequence):
def __init__(self, *args, **kwargs):
self.on_epoch_end()
def on_epoch_end(self):
#shuffle indices for batches
def __len__(self):
def __getitem__(self, idx):
#returning the idxth batch of the shuffled dataset
return X, y
不幸的是,有了这个生成器,我的模型的训练过程变得很长,所以我想预取它
我试过
Train_Generator = tf.data.Dataset.from_generator(Generator(Training_Files, batch_size=64, shuffle = True), output_types=(np.array, np.array))
以将生成器转换为预取工作的类型。我收到错误消息:
`generator` must be callable.
我知道生成器必须支持Iter((-协议才能工作。但是我该如何实现呢?或者你们知道提高这类发电机性能的其他方法吗?
谢谢!!
我建议这样做:
Train_Generator = tf.data.Dataset.from_generator(Generator, args=[Training_Files, 64, True], output_types=(np.array, np.array))