进行Python组操作


def f1(x):
for i in range(1, 100):
x *= 2
x /= 3.14159
x *= i**.25
return x
def f2(x):
for i in range(1, 100):
x *= 2 / 3.14159 * i**.25
return x

两个函数的计算完全相同,但即使使用@numba.njitf1的计算时间也要长3倍。Python是否能够识别编译中的等价性,就像它在dis中通过抛出未使用的赋值等其他方式进行优化一样?

请注意,我知道浮点运算关心顺序,所以这两个函数的输出可能略有不同,但如果对数组值进行更多的单独编辑,则的准确性更低,因此这将是一个二合一的优化。


x = np.random.randn(10000, 1000)
%timeit f1(x.copy())        # 2.68 s ± 50.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f2(x.copy())        # 894 ms ± 36.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit njit(f1)(x.copy())  # 2.59 s ± 65.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit njit(f2)(x.copy())  # 901 ms ± 41.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

使用numba.jit可能是目前此类函数的最佳优化。您可能还想尝试pypy并进行一些基准比较。

尽管如此,我想指出为什么这两个函数不是等价的,所以你不应该期望f1被简化为f2

f1的操作顺序如下:

x1 = (x * 2)            # First binary operation
x2 = (x1 / 3.14159      # Second binary operation
x3 = x2 * (i ** 0.25)   # Third and fourth binary operation
# Order: Multiplication, division, exponent, multiplication

这与f2:不同

x *= ((2 / 3.14159) * (i ** 0.25))
#  ^     ^          ^     ^
#  |     |          |     |
#  4     1          3     2
# Order: Division, exponent, multiplication, multiplication

由于浮点运算不是关联的,所以它们可能不会产生相同的结果。出于这个原因,编译器或解释器进行您期望的优化是错误的,除非它是为了优化浮点精度。

我不知道有什么Python工具可以进行这种特定类型的优化。

可能无法使用jit。我已经尝试了在api中指定的fastmath和nogil-kwarg:https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html

CCD_ 10在去除溢出或非正规数后仍略慢于CCD_ 11。绘制

from timeit import default_timer as timer
import numpy as np
import matplotlib.pyplot as plt
import numba as nb

def f0(x):
for i in range(1, 1000):
x *= 3.000001
x /= 3
return x

def f1(x):
for i in range(1, 1000):
x *= 3.000001 / 3
return x

def timing(f, **kwarg):
x = np.ones(1000, dtype=np.float32)
times = []
n_iter = list(range(100, 1000, 100))
f2 = nb.njit(f, **kwarg)
for i in n_iter:
print(i)
s = timer()
for j in range(i):
f2(x)
e = timer()
times.append(e - s)
print(x)
m, b = np.polyfit(n_iter, times, 1)
return times, m, b, n_iter

def main():
results = []
for fastmath in [True, False]:
for i, f in enumerate([f0, f1]):
kwarg = {
"fastmath": fastmath,
"nogil": True
}
r1, m, b, n_iter = timing(f, **kwarg)
label = "f%d with %s" % (i, kwarg)
plt.plot(n_iter, r1, label=label)
results.append((m, b, label))
for m, b, kwarg in results:
print(m * 1e5, b, kwarg)
plt.legend(loc="upper left")
plt.xlabel("n iterations")
plt.ylabel("timing")
plt.show()
plt.close()

if __name__ == '__main__':
main()