我正在检查tensorflow.keras
中非常简单的指标对象,如BinaryAccuracy或AUC。他们都有reset_states()
和update_state()
参数,但我发现他们的文档不充分且不清楚。
你能解释一下它们的意思吗?
update_state
测量指标(平均值、auc、准确度),并将它们存储在对象中,以便以后可以使用result
检索:
import tensorflow as tf
mean_object = tf.metrics.Mean()
values = [1, 2, 3, 4, 5]
for ix, val in enumerate(values):
mean_object.update_state(val)
print(mean_object.result().numpy(), 'is the mean of', values[:ix+1])
1.0 is the mean of [1]
1.5 is the mean of [1, 2]
2.0 is the mean of [1, 2, 3]
2.5 is the mean of [1, 2, 3, 4]
3.0 is the mean of [1, 2, 3, 4, 5]
reset_states
将度量重置为零:
mean_object.reset_states()
mean_object.result().numpy()
0.0
我不确定我是否比文档更清楚,在我看来,它解释得很好。
调用对象,例如,mean_object([1, 2, 3, 4])
将更新度量,和返回result
。
import tensorflow as tf
mean_object = tf.metrics.Mean()
values = [1, 2, 3, 4, 5]
print(mean_object.result())
returned_mean = mean_object(values)
print(mean_object.result())
print(returned_mean)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)
tf.Tensor(3.0, shape=(), dtype=float32)