tf中reset_states()和update_state()的含义是什么?keras指标?



我正在检查tensorflow.keras中非常简单的指标对象,如BinaryAccuracy或AUC。他们都有reset_states()update_state()参数,但我发现他们的文档不充分且不清楚。

你能解释一下它们的意思吗?

update_state测量指标(平均值、auc、准确度),并将它们存储在对象中,以便以后可以使用result检索:

import tensorflow as tf
mean_object = tf.metrics.Mean()
values = [1, 2, 3, 4, 5]
for ix, val in enumerate(values):
    mean_object.update_state(val)
    print(mean_object.result().numpy(), 'is the mean of', values[:ix+1])
1.0 is the mean of [1]
1.5 is the mean of [1, 2]
2.0 is the mean of [1, 2, 3]
2.5 is the mean of [1, 2, 3, 4]
3.0 is the mean of [1, 2, 3, 4, 5]

reset_states将度量重置为零:

mean_object.reset_states()
mean_object.result().numpy()
0.0

我不确定我是否比文档更清楚,在我看来,它解释得很好。

调用对象,例如,mean_object([1, 2, 3, 4])将更新度量,返回result

import tensorflow as tf
mean_object = tf.metrics.Mean()
values = [1, 2, 3, 4, 5]
print(mean_object.result())
returned_mean = mean_object(values)
print(mean_object.result())
print(returned_mean)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)

相关内容

  • 没有找到相关文章

最新更新