Keras 无法在回调中计算图形节点



为了说明这个问题,请考虑以下模型的(完全人为的(示例:

import numpy as np
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
class RandomSeq(Sequence):
def __len__(self):
return 5
def __getitem__(self, idx):
return (np.array([np.arange(39).reshape((39,)) for i in range(100)]), 
np.array([-np.arange(39).reshape((39,)) for i in range(100)]))
class Foo(Callback):
def __init__(self, d):
super(Foo, self).__init__()
self._d = d
def on_epoch_end(self, epoch, logs=None):
print(epoch)
print(K.eval(self._d))

x = Input(shape=(39,), dtype='float32', name='input')
y_pred = Dense(39)(x)
y_another = x * 2
m = Model(inputs=x, outputs=y_pred)
m.compile(optimizer='sgd', loss='mse')
seq = RandomSeq()
m.fit_generator(seq, epochs=5, callbacks=[Foo(y_another)])

RandomSeq只是一个返回xy批次的序列。Foo是一个回调,它将尝试在纪元结束时评估附加的数量d。对我来说,如果我选择dy_predy_another,那么 Keras 抱怨占位符x(input(没有喂食。

您必须为 dtype 浮点数和形状的占位符张量 'input_1' 输入一个值 [?,39]

这是预期行为吗?如果是这样,有没有另一种方法可以在 Keras 回调中计算节点?请注意,如果没有计算上述图形节点的回调,则该示例工作正常。

这不是正确的方法。 运行 K.eval(y_another(,您要求 keras 后端评估输入对象(它只是您要馈送到网络的数据的占位符(,而不向其提供任何数据,这就是您错误的原因。

因此,假设你想计算网络的输出,给定一个新的输入,这是一个随机序列乘以 2(这是对的吗?(,那么你应该在回调中修改on_epoch_end(self, epoch, logs=None)主体,如下所示:

x, _ = next(RandomSeq())     
print(self.model.predict(x*2))

最新更新