张量流,检查张量中的哪些值是整数



我有一个张量:

numbers = tf.constant([[4.00, 3.33], [2.34, 7.00]])

我想做的是得到一个与"数字"具有相同维度的张量,但在数字是整数的索引处有 1,在不是整数的索引处有 0,如下所示:

ans = [[1, 0],[0, 1]]

我猜我将不得不使用 tf.where(( 也许?我真的不确定如何使用张量流做这样的事情。谢谢

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
with tf.Session() as sess:
fake_data = np.asarray([[1,2.4, 3.5], [3.4, 2.00, 10.001], [105.1, 100, 10]])
a = tf.constant(data)
# find where floor == actual value (thus, is a whole number)
mask = tf.equal(tf.floor(data), data)
# Get the indices
idx = tf.where(mask)

print(sess.run(a))
print(sess.run(idx))

这应该可以解决问题:)我尝试评论它,以便我正在做的事情很清楚,我认为这很容易理解。我写的是基于这个评论

最新更新