Python:优化删除numpy数组中未对齐的元素




我使用numpy 1.6.2python 2.7处理非常大的矩阵。给定一个N x M矩阵A和一个映射B,其中我可以为每一行找到要删除的元素的索引。这里有一个例子:

A =
   26   55   29   30
   31   65   34   35
   36   75   39   40
   41   85   44   45
   46   95   49   50
B =
     2
     0
     1
     3
     2

结果将是:

A =
   26   55   30
   65   34   35
   36   39   40
   41   85   44
   46   95   50

事实上,为了获得这个,我创建了一个这样的循环:

for i in xrange(size(B)):
  A[i,:] = concatenate(A[i,0:B[i]],A[i,B[i]+1:])

但它真的很慢。有没有更快的方法删除我需要的元素?

谢谢大家!

您可以为A创建一个掩码,如下所示:

>>> mask = np.arange(4) != np.vstack(B)
>>> mask
array([[ True,  True, False,  True],
       [False,  True,  True,  True],
       [ True, False,  True,  True],
       [ True,  True,  True, False],
       [ True,  True, False,  True]], dtype=bool)

然后使用它从A中筛选出不需要的(False)值,重新整形,然后重新绑定到变量名A:

>>> A = A[mask].reshape(5, 3)
>>> A
array([[26, 55, 30],
       [65, 34, 35],
       [36, 39, 40],
       [41, 85, 44],
       [46, 95, 50]])

这应该比使用concatenate更快,因为它避免了为Python for循环的每次迭代复制数组。

最新更新