我在MATLAB中有两个矩阵。每个位置都充满了1
和0
。我想比较每个元素:
- 如果有
1
匹配项,我希望它记录为"真阳性"。 - 如果有
0
匹配项,我希望它记录为"真阴性"。 - 如果一个说
1
,另一个说0
,我想记录为误报。 - 如果一个说
0
,另一个说1
,我想记录为假阴性。
我尝试只比较两个矩阵:
idx = A == B
但是,这给了我一个简单的匹配,没有告诉我什么时候有真阳性或阴性等。
是否有任何我可以使用的特定功能,或任何替代方案?
您可以按照规定的方式添加矩阵....
a = [1 0 1 0
1 1 0 0
0 0 1 1];
b = [1 0 0 0
0 0 0 1
0 0 1 0];
C = a + 2*b;
% For pairs [a,b] we expect
% [0,0]: C = 0, true negative
% [1,0]: C = 1, false positive
% [0,1]: C = 2, false negative
% [1,1]: C = 3, true positive
% C =
% [ 3 0 1 0
% 1 1 0 2
% 0 0 3 1 ]
如果您有统计和机器学习工具箱,并且只需要摘要,则可能只需要函数confusionmat
。
从文档中:
C = confusionmat(group,grouphat)
返回由已知组和预测组确定的混淆矩阵C
group
和grouphat
。[...].C
是一个方阵,其大小等于group
和grouphat
中不同元素的总数。C(i,j)
是已知位于组i
中但预测位于组j
中的观测值计数。
例如:
a = [1 0 1 0
1 1 0 0
0 0 1 1];
b = [1 0 0 0
0 0 0 1
0 0 1 0];
C = confusionmat( a(:), b(:) );
% C =
% [ 5 1
% 4 2]
% So for each pair [a,b], we have 5*[0,0], 2*[1,1], 4*[1,0], 1*[0,1]
对于使用神经网络工具箱的人来说,类似的功能将是confusion
.
您可以使用按位运算符来生成四个不同的值:
bitor(bitshift(uint8(b),1),uint8(a))
生成一个具有
0 : 真 阴性 1 : 假阴性
(a 为真,b 为假)2 : 假阳性(a 为假,但 b 为真)
3 : 真阳性
一种天真的方法是逐案进行四种比较:
% Set up some artificial data
ground_truth = randi(2, 5) - 1
compare = randi(2, 5) - 1
% Determine true positives, false positives, etc.
tp = ground_truth & compare
fp = ~ground_truth & compare
tn = ~ground_truth & ~compare
fn = ground_truth & ~compare
输出:
ground_truth =
1 0 1 0 0
0 1 1 0 1
1 1 0 1 0
0 1 0 1 1
0 0 0 1 0
compare =
0 1 1 0 1
0 1 1 1 0
1 1 0 0 1
1 1 1 0 0
1 1 1 1 1
tp =
0 0 1 0 0
0 1 1 0 0
1 1 0 0 0
0 1 0 0 0
0 0 0 1 0
fp =
0 1 0 0 1
0 0 0 1 0
0 0 0 0 1
1 0 1 0 0
1 1 1 0 1
tn =
0 0 0 1 0
1 0 0 0 0
0 0 1 0 0
0 0 0 0 0
0 0 0 0 0
fn =
1 0 0 0 0
0 0 0 0 1
0 0 0 1 0
0 0 0 1 1
0 0 0 0 0
这是有效的,因为0
和1
(或任何正值)是true
和false
的替代表示。
为了保持主代码干净,请设置一个单独的函数,例如my_stats.m
function [tp, fp, tn, fn] = my_stats(ground_truth, compare)
% Determine true positives, false positives, etc.
tp = ground_truth & compare;
fp = ~ground_truth & compare;
tn = ~ground_truth & ~compare;
fn = ground_truth & ~compare;
end
并在主代码中调用它:
% Set up some artificial data
ground_truth = randi(2, 5) - 1
compare = randi(2, 5) - 1
[tp, fp, tn, fn] = my_stats(ground_truth, compare)
希望对您有所帮助!
我发现我可以使用find方法并设置两个条件,然后只需在每个变量中找到元素的数量
TruePositive = length(find(A==B & A==1))
TrueNegative = length(find(A==B & A==0))
FalsePositive = length(find(A~=B & A==1))
FalseNegative = length(find(A~=B & A==0))
@Wolfie建议的confusionmatrix()方法也非常简洁,特别是如果你使用confusionchart()提供了一个很好的可视化。