我有一些二进制2D numpy数组(prediction
),比如:
[
[1 0 1 0 1 1],
[0 0 1 0 0 1],
[1 1 1 1 1 0],
[1 1 0 0 1 1],
]
2D阵列中的每一行是作为特定类别的句子的分类,并且2D阵列中每一列对应于该句子的类别的分类。类别(categories
阵列)是作为示例的['A','B','C','D','E','F']
。
我有另一个2D数组(catIndex
),它包含每行中要检查的值的索引,例如
[[0],
[4],
[5],
[2]
]
对于上面的4个实例。
我现在要做的是循环遍历二进制2D数组,对于为每个句子指定的列索引,检查它是1
还是0
,然后将相应的类别附加到新数组(catResult = []
)。如果它是0
,我会将"no_region"
附加到新数组中。
例如,在句子1中,我查看该句子的索引0
,并检查它是0
还是1
。它是一个1
,所以我将'A'
附加到我的新数组中。在句子2中,我查看该句子的索引4
,发现它是一个0
,所以我将"no_region"
附加到数组中。
当前代码:
for index in catIndex:
for i,sentence in enumerate(prediction):
for j,binaryLabel in enumerate(sentence):
if prediction[i][index]==1:
catResult.append(categories[index])
else:
catResult.append("no_region")
制作2d数组:
In [54]: M=[[1,0,1,0,1,1],[0,0,1,0,0,1],[1,1,1,1,1,0],[1,1,0,0,1,1]]
In [55]: M=np.array(M)
以ind
为列索引,以[0,1,2,3]为行索引:
In [56]: ind=[0,4,5,2]
In [57]: m=M[np.arange(len(ind)),ind]
In [58]: m
Out[58]: array([1, 0, 0, 0])
带有ind
:的地图标签
In [59]: lbl=np.array(list('ABCDEF'),dtype=object)
In [60]: res=lbl[ind]
In [61]: res
Out[61]: array(['A', 'E', 'F', 'C'], dtype=object)
使用where
来确定是使用该映射值,还是使用某些None
。使用object
数据类型可以很容易地将字符串标签替换为其他内容,如None
或no_region
等。
In [62]: np.where(m, res, None)
Out[62]: array(['A', None, None, None], dtype=object)
沿着这些路线应该可以有效地完成这项工作,尽管现在还不能进行测试:
rows = len(prediction)
p = prediction[np.arange(rows), catIndex.flatten()]
catResult = np.empty(rows, 'S1').fill('n')
catResult[p] = categories[catIndex.flatten()][p]