是否有办法提高这种分形计算算法的性能?



昨天我看到了3Blue1Brown关于牛顿分形的新视频,我真的被他对分形的现场表现迷住了。(这里是视频链接,感兴趣的人可以观看,时间是13:40:https://www.youtube.com/watch?v=-RdOwhmqP5s)

我想自己试一试,试着用python写代码(我想他也用python)。

我花了几个小时试图改进我天真的实现,但我不知道如何才能使它更快。

代码如下:

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from time import time

def print_fractal(state):
fig = plt.figure(figsize=(8, 8))
gs = GridSpec(1, 1)
axs = [fig.add_subplot(gs[0, 0])]
fig.tight_layout(pad=5)
axs[0].matshow(state)
axs[0].set_xticks([])
axs[0].set_yticks([])
plt.show()
plt.close()

def get_function_value(z):
return z**5 + z**2 - z + 1

def get_function_derivative_value(z):
return 5*z**4 + 2*z - 1

def check_distance(state, roots):
roots2 = np.zeros((roots.shape[0], state.shape[0], state.shape[1]), dtype=complex)
for r in range(roots.shape[0]):
roots2[r] = np.full((state.shape[0], state.shape[1]), roots[r])
dist_2 = np.abs((roots2 - state))
original_state = np.argmin(dist_2, axis=0) + 1
return original_state

def static():
time_start = time()
s = 4
c = [0, 0]
n = 800
polynomial = [1, 0, 0, 1, -1, 1]
roots = np.roots(polynomial)
state = np.transpose((np.linspace(c[0] - s/2, c[0] + s/2, n)[:, None] + 1j*np.linspace(c[1] - s/2, c[1] + s/2, n)))
n_steps = 15
time_setup = time()
for _ in range(n_steps):
state -= (get_function_value(state) / get_function_derivative_value(state))
time_evolution = time()
original_state = check_distance(state, roots)
time_check = time()
print_fractal(original_state)
print("{0:<40}".format("Time to setup the initial configuration:"), "{:20.3f}".format(time_setup - time_start))
print("{0:<40}".format("Time to evolve the state:"), "{:20.3f}".format(time_evolution - time_setup))
print("{0:<40}".format("Time to check the closest roots:"), "{:20.3f}".format(time_check - time_evolution))

平均输出如下所示:

设置初始配置的时间:0.004

状态演化时间:0.796

检查最接近根的时间:0.094

很明显,是进化部分阻碍了这个过程。它不是"慢",但我认为它不足以呈现像视频中那样的现场效果。我已经通过使用numpy向量和避免循环做了我能做的但我想这还不够。还有什么其他技巧可以应用在这里?

注意:我尝试使用numpy.多项式.多项式类来计算函数,但它比这个版本慢。

  1. 我得到了改进(~40%快)通过使用单一复杂(np.complex64)精度。
(...)
state = np.transpose((np.linspace(c[0] - s/2, c[0] + s/2, n)[:, None] 
+ 1j*np.linspace(c[1] - s/2, c[1] + s/2, n)))
state = state.astype(np.complex64)
(...)
  1. 3Blue1Brown在描述中添加了这个链接:https://codepen.io/mherreshoff/pen/RwZPazd你可以看看它是如何做到的(旁注:这支笔的作者也使用了单精度)
for _ in range(n_steps):
state -= (get_function_value(state) / get_function_derivative_value(state))

如果你有足够的内存,你可以尝试对这个循环进行矢量化,并用矩阵计算来存储每个迭代步骤。

试试pypy3、numba或cython。

Pypy3是一个快速的cpython替代品。对于纯python代码,这可以大大加快。

Numba的nopython模式可以显著加快cpython中的数学运算速度。不过,它除了数学之外,做不了什么。

Cython是一种python方言,可以翻译成c语言,并允许您混合使用cpython和c语言类型。使用的c类型越多越好(通常)。查看cython -a option