在Keras中,当在tensorflow数据集上循环时,如何使用Model.predict函数



我想在训练批次的循环中使用Model.product((函数。要将train_data划分为批次,我使用数据集类和批次函数:

@tf.function
def training(modell, train_data, batch_size):
train_dataset = tf.data.Dataset.from_tensor_slices(train_data)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(parameters['batch_size'])
iterator = tf.compat.v1.data.make_initializable_iterator(train_dataset)
for step in range(4):
with tf.GradientTape() as tape:
modell.predict(iterator.get_next(), steps=1)

training(model, train, 400)

但是当运行代码时,我得到了这个错误:

file.py:36 training*modell.predict(迭代器.get_next((,steps=1(ValueError:当向模型提供符号张量时,我们希望张量具有静态批量大小。得到具有形状的张量:(无,2,48,48,1(

我查看了其他帖子,但找不到任何解决方案。

感谢

训练循环显然不应该是TF函数,只需要考虑它应该构建的图的类型。。。

忽略图形将毫无用处的事实,因为您只会调用此函数一次。。。相反,训练循环应该是"循环";正常python代码";,列车阶段(列车环路内的所有内容(可以是TF函数

还要考虑modell.predict很可能不会产生任何梯度,相反,您应该使用__call__

最新更新