我想用model.fit()方法训练一个keras模型(比如说一个简单的FFNN),而不是"手工"(即通过使用梯度)。磁带方法解释(例如这里)。然而,我需要使用的损失函数非常复杂,不能在随机生成的数据批上计算。因此,我需要使用"手工"计算的数据批次来训练模型(即进入每个批次的数据需要具有某些属性,不能随机分配)。
我可以通过某种方式传递预先计算的批到fit()方法吗?
一个解决方案是对Tensorflow序列进行子类化。您可以使用__getitem__
方法为给定索引创建自己的批处理。
class MySequence(tf.keras.utils.Sequence):
def __init__(self, x_batch, y_batch) -> None:
super().__init__()
self.x_batch = x_batch # ordered list of batches
self.y_batch = y_batch # idem
self.leny = len(y_batch)
def __len__(self):
return self.leny
def __getitem__(self, idx):
x = self.x_batch[idx]
y = self.y_batch[idx]
return x, y
你可以传递一个Sequence子类的实例给Modelfit
方法。
同时在fit
模型参数中设置shuffle=False
。