什么是最好的蟒蛇解决方案

  • 本文关键字:解决方案 python numpy
  • 更新时间 :
  • 英文 :


由于X是一个形状为(n,m(的数组,Y是一个长度为n的列表,哪些值是二进制的,使用numpy,下面代码的最佳Python替代方案是什么?

p1 = np.zeros(X.shape[1])
p0 = np.zeros(X.shape[1])
for i in range(len(X[0])):        
sum_1 = np.where(Y==1,X[:,i],0).sum()
sum_0 = np.where(Y==0,X[:,i],0).sum()
p1[i] = sum_1
p0[i] = sum_0

这里有一个更快更简单的版本:

p1 = X.T @ Y # or np.dot(X.T, Y) if on Python < 3.5
p0 = X.T @ (1 - Y)

这利用了Y数组是0和1的事实,并计算了一个快速点积。


以下框架的计时结果:

import numpy as np
n = 2000
m = 1000
X = np.random.random((n, m))
Y = (np.random.random((n,)) > 0.5).astype(int)
def v0():
p1 = np.zeros(X.shape[1])
p0 = np.zeros(X.shape[1])
for i in range(len(X[0])):
sum_1 = np.where(Y==1,X[:,i],0).sum()
sum_0 = np.where(Y==0,X[:,i],0).sum()
p1[i] = sum_1
p0[i] = sum_0
return p0, p1
def v1():
p1 = np.sum(X[np.where(Y==1)], axis=0)
p0 = np.sum(X[np.where(Y==0)], axis=0)
return p0, p1
def v2():
p1 = X.T @ Y # or np.dot(X.T, Y) if on Python < 3.5
p0 = X.T @ (1 - Y)
return p0, p1
p0_0, p1_0 = v0()
p0_1, p1_1 = v1()
p0_2, p1_2 = v2()
assert np.allclose(p0_0, p0_1)
assert np.allclose(p0_0, p0_2)
assert np.allclose(p1_0, p1_1)
assert np.allclose(p1_0, p1_2)
$ python3 -m timeit -s 'import test' 'test.v0()'
10 loops, best of 5: 33.5 msec per loop
$ python3 -m timeit -s 'import test' 'test.v1()'
100 loops, best of 5: 3.81 msec per loop
$ python3 -m timeit -s 'import test' 'test.v2()'
500 loops, best of 5: 794 usec per loop

对于这套尺寸,这个版本比原来的快40倍以上。

在某些条件下,在行的第一个X轴上求和

p1 = np.sum(X[np.where(Y==1)], axis=0)
p0 = np.sum(X[np.where(Y==0)], axis=0)

如果Y是布尔区域,则可以直接使用Y作为索引,numpy会屏蔽相应的行。您也可以用~Y否定Y,以获得其他行:

>>> X
array([[1, 2],
[2, 3],
[3, 4]])
>>> Y
array([False, False,  True])
>>> X[Y]
array([[3, 4]])
>>> X[~Y]
array([[1, 2],
[2, 3]])
>>> X[Y].sum(axis=0)
array([3, 4])
>>> X[~Y].sum(axis=0)
array([3, 5])

相关内容

最新更新