在函数keras API中使用tf.keras.metrics.MMean时出错



我正试图将均值度量添加到Keras函数模型(Tensorflow 2.5(中,但收到以下错误:

ValueError: Expected a symbolic Tensor for the metric value, received: tf.Tensor(0.0, shape=(), dtype=float32)

这是代码:

x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [5 + i * 3 for i in x]
a = Input(shape=(1,))
output = Dense(1)(a)
model = Model(inputs=a,outputs=output)
model.add_metric(tf.keras.metrics.Mean()(output))
model.compile(loss='mse')
model.fit(x=x, y=y, epochs=100)

如果我删除以下行(从中抛出异常(:

model.add_metric(tf.keras.metrics.Mean()(output))

代码按预期工作。

我尝试禁用急切执行,但我得到了以下错误:

ValueError: Using the result of calling a `Metric` object when calling `add_metric` on a Functional Model is not supported. Please pass the Tensor to monitor directly.

上面的用法几乎是从tf.keras.metrics.Mean文档中复制的(请参阅compile((API的用法(

我找到了一种绕过问题的方法,方法是完全避免使用model.add_metric,并将Metric对象传递给compile()方法
但是,当按如下方式传递tf.keras.metrics.Mean的实例时:

model.compile(loss='mse', metrics=tf.keras.metrics.Mean())

我从compile()方法中得到以下错误:

TypeError: update_state() got multiple values for argument 'sample_weight'

为了解决这个问题,我不得不扩展tf.keras.metrics.Mean并更改update_state的签名以匹配期望的签名
这是最后的(工作(代码:

class FixedMean(tf.keras.metrics.Mean):
def update_state(self, y_true, y_pred, sample_weight=None):
super().update_state(y_pred, sample_weight=sample_weight)
x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [5 + i * 3 for i in x]
a = Input(shape=(1,))
output = Dense(1)(a)
model = Model(inputs=a,outputs=output)
model.compile(loss='mse', metrics=FixedMean())
model.fit(x=x, y=y, epochs=100)

相关内容

  • 没有找到相关文章

最新更新