计算张量中两个值的出现次数



我想计算张量中两个值的出现次数。以下代码有效,但张量中不存在一个或两个值的情况除外。在这种情况下,它会崩溃并显示(预期的(错误:InvalidArgumentError: Expected begin and size arguments to be 1-D tensors of size 1, but got shapes [0] and [1] instead.

我如何修改此代码(不使用条件(,以便它只为缺失值提供 0 计数而不是崩溃。

wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def count_occurrences(t, val1, val2):
y, idx, count = tf.unique_with_counts(tf.reshape(t, [-1]))
idx_val1 = tf.reshape(tf.where(tf.equal(y, val1)), [-1])
idx_val2 = tf.reshape(tf.where(tf.equal(y, val2)), [-1])
return tf.slice(count, idx_val1, [1]) + tf.slice(count, idx_val2, [1])
print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))

您可以简单地执行以下操作:

wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def count_occurrences(t, val1, val2):
eq = tf.logical_or(tf.equal(t, val1), tf.equal(t, val2))
return tf.count_nonzero(eq)
print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))

但请注意,一般来说,比较浮点数是否相等并不是最佳选择。具有一定容忍度的可能替代方案可能是:

wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def count_occurrences(t, val1, val2, epsilon=1e-8):
eq1 = tf.abs(t - val1) < epsilon
eq2 = tf.abs(t - val2) < epsilon
eq = tf.logical_or(eq1, eq2)
return tf.count_nonzero(eq)
print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))

我认为你可以做这样的事情

wts = tf.Variable([[-2.0, 0.0, 0.05], [-0.95, 0.0, -0.05], [1.0, -2.5, 1.0]])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
def count_occurrences(t, val1, val2):
y, idx, count = tf.unique_with_counts(tf.reshape(t, [-1]))
idx_val1 = tf.reshape(tf.where(tf.equal(y, val1)), [-1])
idx_val2 = tf.reshape(tf.where(tf.equal(y, val2)), [-1])
temp = tf.cond(tf.greater(tf.shape(idx_val1)[0], 0), 
lambda: tf.slice(count, idx_val1, [1]), 
lambda: [0]) 
temp = temp + tf.cond(tf.greater(tf.shape(idx_val2)[0], 0), 
lambda: tf.slice(count, idx_val2, [1]), 
lambda: [0])
return temp
print(count_occurrences(wts, 1.0, -2.0).eval(session=sess))

最新更新