Collatz猜想程序Python的最大化效率



我的问题很简单。

我写这个节目纯粹是为了娱乐。它接受一个数字输入,并找到每个Collatz序列的长度,直到并包括该数字。

我想在算法或数学上让它更快(也就是说,我知道我可以通过并行运行多个版本或用C++编写它来让它更快,但这有什么乐趣呢?(。

欢迎任何帮助,谢谢!

编辑:代码在dankal444的帮助下进一步优化

from matplotlib import pyplot as plt
import numpy as np
import numba as nb
# Get Range to Check
top_range = int(input('Top Range: '))
@nb.njit('int64[:](int_)')
def collatz(top_range):
# Initialize mem
mem = np.zeros(top_range + 1, dtype = np.int64)
for start in range(2, top_range + 1):
# If mod4 == 1: (3x + 1)/4
if start % 4 == 1:
mem[start] = mem[(start + (start >> 1) + 1) // 2] + 3

# If 4mod == 3: 3(3x + 1) + 1 and continue
elif start % 4 == 3:
num = start + (start >> 1) + 1
num += (num >> 1) + 1
count = 4
while num >= start:
if num % 2:
num += (num >> 1) + 1
count += 2
else:
num //= 2
count += 1
mem[start] = mem[num] + count
# If 4mod == 2 or 0: x/2
else:
mem[start] = mem[(start // 2)] + 1
return mem
mem = collatz(top_range)
# Plot each starting number with the length of it's sequence
plt.scatter([*range(1, len(mem) + 1)], mem, color = 'black', s = 1)
plt.show()

在代码中应用numba确实有很大帮助。

我删除了tqdm,因为它对性能没有帮助。

import time
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import numba as nb
@nb.njit('int64[:](int_)')
def collatz2(top_range):
mem = np.zeros(top_range + 1, dtype=np.int64)
for start in range(2, top_range + 1):
# If mod(4) == 1: Value 2 or 3 Cached
if start % 4 == 1:
mem[start] = mem[(start + (start >> 1) + 1) // 2] + 3
# If mod(4) == 3: Use Algorithm
elif start % 4 == 3:
num = start
count = 0
while num >= start:
if num % 2:
num += (num >> 1) + 1
count += 2
else:
num //= 2
count += 1
mem[start] = mem[num] + count
# If mod(4) == 2 or 4: Value 1 Cached
else:
mem[start] = mem[(start // 2)] + 1
return mem

def collatz(top_range):
mem = [0] * (top_range + 1)
for start in range(2, top_range + 1):
# If mod(4) == 1: Value 2 or 3 Cached
if start % 4 == 1:
mem[start] = mem[(start + (start >> 1) + 1) // 2] + 3
# If mod(4) == 3: Use Algorithm
elif start % 4 == 3:
num = start
count = 0
while num >= start:
if num % 2:
num += (num >> 1) + 1
count += 2
else:
num //= 2
count += 1
mem[start] = mem[num] + count
# If mod(4) == 2 or 4: Value 1 Cached
else:
mem[start] = mem[(start // 2)] + 1
return mem
# profiling here
def main():
top_range = 1_000_000
mem = collatz(top_range)
mem2 = collatz2(top_range)
assert np.allclose(np.array(mem), mem2)

对于top_range=1.000,优化函数的速度快约100倍。对于top_range=1_0000_000,优化后的函数大约快600倍:

79                                           def main():
81         1          3.0      3.0      0.0      top_range = 1_000_000
83         1   24633045.0 24633045.0     98.7      mem = collatz(top_range)
85         1      39311.0  39311.0      0.2      mem2 = collatz2(top_range)

最新更新