我有一个大小为NXN的浮点矩阵a和另一个大小为NXN的布尔矩阵B
对于每一行,我需要找到A中属于索引的所有值的平均值,其中True是矩阵B中该索引的对应值
同样,我需要找到A中属于索引的所有值的平均值,其中False是矩阵B
中该索引的对应值最后,我需要找到"为真"平均值小于"false";意思是
例如:
A = [[1.0, 2.0, 3.0]
[4.0, 5.0, 6.0]
[7.0, 8.0, 9.0]]
B = [[True, True, False]
[False, False, True]
[True, False, True]]
count = 0
For row 1, true_mean = 1.0+2.0/2 = 1.5 and false = 3.0 For row 2, true_mean = 6.0 and谬误= 4.0+5.0/2 = 4.5 对于第3行,true_mean = 7.0+9.0/2 = 8.0 and谬误= 8.0 最终计数值= 1 我的尝试:- 但这实际上给出了错误的答案,因为分母不是该行中正确/错误值的数量,而是'N' 我只需要count,不需要true_mean和false_mean 如何修复它?
true_mean
true_mean>false =平均值
true_mean ==谬误= 8.0,因此计数保持不变true_mat = np.where(B, A, 0)
false_mat = np.where(B, 0, A)
true_mean = true_mat.mean(axis=1)
false_mean = false_mat.mean(axis=1)
均值问题可以通过计算mask
来解决:
mask_norm = tf.reduce_sum(tf.clip_by_value(true_mat, 0., 1.),axis=0)
true_mean = tf.math.divide(tf.reduce_sum(true_mat, axis=1), mask_norm)
#true_mean : [1.5, 6. , 8. ]
您可以使用tf.reduce_sum(tf.where(true_mean < false_mean, 1, 0))
您也可以尝试这样做:
import tensorflow as tf
A = tf.constant([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0]])
B = tf.constant([[True, True, False],
[False, False, True],
[True, False, True]])
t_rows = tf.where(B)
f_rows = tf.where(~B)
_true = tf.gather_nd(A, t_rows)
_false = tf.gather_nd(A, f_rows)
count = tf.reduce_sum(tf.cast(tf.math.greater(tf.math.segment_mean(_false, f_rows[:, 0]), tf.math.segment_mean(_true, t_rows[:, 0])), dtype=tf.int32))
tf.print(count)
1
也适用于所有True
或False
行:
B = tf.constant([[True, True, True],
[False, False, True],
[True, False, True]])
# 0
B = tf.constant([[False, False, False],
[False, False, False],
[True, False, True]])
# 2
我想说你的开始很好
true_mat = np.where(B, A, 0)
false_mat = np.where(B, 0, A)
但是我们想要分别除以true或false的个数,所以…
true_sum = np.sum(B, axis = 1) #sum of Trues per row
false_sum = N-true_sum # if you don't have N given, do N=A.shape[0]
true_mean = np.sum(true_mat, axis = 1)/true_sum #add up rows of true_mat and divide by true_sum
false_mean = np.sum(false_mat, axis = 1)/false_sum
对于您的示例,这给出
[1.5 6. 8. ]
[3. 4.5 8. ]
那么现在我们只需要比较第二个比第一个大的地方:
count = np.sum(np.where(false_mean > true_mean, 1, 0))