在numpy中删除for循环和追加到列表



我需要优化我的代码,并且已经能够在几乎所有地方删除for循环,但在这一小部分上很吃力。我看过numpy.where,但不认为我能用它,但我不太确定。如果有人知道我应该考虑哪些功能来优化这一部分,我将不胜感激。这是经常使用的,因为在pyton中循环非常慢,所以我需要做一些优化。

def f(x):
return np.pi * x * np.cos((np.pi / 2) * x ** 2)
x_samples = np.random.uniform(low=0.0, high=1.0, size=500000)
y_samples = np.random.uniform(low=0.0, high=1.0, size=500000) * 1.61
y_functie = f(x_samples)
y_hit, x_hit, x_miss, y_miss = [], [], [], []
for n in range(len(x_samples)):
if y_samples[n] <= y_functie[n]:
x_hit.append(x_samples[n]), y_hit.append(y_samples[n])
else:
x_miss.append(x_samples[n]), y_miss.append(y_samples[n])

通过进行索引,您可以执行以下操作:

import numpy as np
def f(x):
return np.pi * x * np.cos((np.pi / 2) * x ** 2)
x_samples = np.random.uniform(low=0.0, high=1.0, size=500000)
y_samples = np.random.uniform(low=0.0, high=1.0, size=500000) * 1.61
y_functie = f(x_samples)
def test1():
y_hit, x_hit, x_miss, y_miss = [], [], [], []
for n in range(len(x_samples)):
if y_samples[n] <= y_functie[n]:
x_hit.append(x_samples[n]), y_hit.append(y_samples[n])
else:
x_miss.append(x_samples[n]), y_miss.append(y_samples[n])
def test2():
x_hit = x_samples[y_samples<=y_functie]
x_miss = x_samples[y_samples>y_functie]
y_hit = y_samples[y_samples<=y_functie]
y_miss = y_samples[y_samples>y_functie]

test1()
test2()

根据您的测试数据,我得到test1的700ms和test2的21ms的时间差。当然,您可以通过只使用一个布尔表来简化更多,并一次性完成所有操作。

根据您想要优化的内容(速度或内存占用(,您可以使用预先计算布尔表

hit_test = y_samples<=y_functie

您可以使用fancy indexing的逻辑掩码来实现所需的结果:

import numpy as np
def f(x):
return np.pi * x * np.cos((np.pi / 2) * x ** 2)
x_samples = np.random.uniform(low=0.0, high=1.0, size=500000)
y_samples = np.random.uniform(low=0.0, high=1.0, size=500000) * 1.61
y_functie = f(x_samples)
hits = y_samples <= y_functie
misses = y_samples > y_functie
y_hit, x_hit, x_miss, y_miss = x_samples[hits], y_samples[hits], x_samples[misses], y_samples[misses]

干杯。

最新更新