我可以用全局dict并行化这个小Python脚本吗



我在Python3的MBA课程中遇到了这个问题,大约需要2.8秒。由于它的核心是一个缓存字典,我认为哪个调用首先到达缓存并不重要,所以也许我可以从线程中获得一些好处。不过我想不通。这比我通常问的问题要高一点,但有人能带我完成这个问题的并行化过程吗?

import time
import threading
even = lambda n: n%2==0
next_collatz = lambda n: n//2 if even(n) else 3*n+1
cache = {1: 1}
def collatz_chain_length(n):
if n not in cache: cache[n] = 1 + collatz_chain_length(next_collatz(n))
return cache[n]
if __name__ == '__main__':
valid = range(1, 1000000)
for n in valid:
# t = threading.Thread(target=collatz_chain_length, args=[n] )
# t.start()
collatz_chain_length(n)
print( max(valid, key=cache.get) )

或者,如果这是一个糟糕的候选人,为什么?

如果您的工作负载是CPU密集型的,那么Python中的线程将不会得到很好的提升。这是因为由于GIL(全局解释器锁定),一次只有一个线程实际使用处理器。

但是,如果您的工作负载是I/O绑定的(例如,等待来自网络请求的响应),线程会给您带来一些提升,因为如果您的线程在等待网络响应时被阻止,另一个线程可以做有用的工作。

正如HDN所提到的,使用多处理将有所帮助——这将使用多个Python解释器来完成工作。

我的方法是将迭代次数除以您计划创建的进程数。例如,如果您创建了4个进程,则为每个进程提供一个工作的1000000/4切片。

最后,您需要汇总每个过程的结果,并应用max()来获得结果。

线程不会给您带来太多性能提升,因为它不会绕过全局解释器锁,全局解释器锁在任何给定时刻都只运行一个线程。实际上,由于上下文切换,它甚至可能会减慢您的速度。

如果您想在Python中利用并行化来提高性能,那么就必须使用多处理来一次实际利用多个核心。

我在单核上成功地将您的代码加速了16.5x倍,请继续阅读。

正如前面所说,由于全局解释器锁定,多线程在纯Python中没有任何改进。

关于多处理,有两种选择:1)实现共享字典,并直接从不同的进程读取/写入。2) 将值的范围划分为多个部分,并为不同进程上的单独子范围求解任务,然后从所有进程的答案中取最大值。

第一个选项将非常缓慢,因为在您的代码读取/写入字典是主要耗时的操作,使用进程之间共享的字典将使其速度慢5倍以上,而多核不会带来任何改进。

第二种选择会带来一些改进,但也不是很大,因为不同的过程会多次重新计算相同的值。只有当您在集群中有很多核心或使用许多独立的机器时,此选项才会提供相当大的改进。

我决定实现另一种改进任务的方法(选项3)——使用Numba并进行其他优化。我的解决方案也适用于选项2(子范围的并行化)。

Numba是实时编译器和优化器,它将纯Python代码转换为优化的C++,然后编译为机器代码。Numba通常可以提供10x-100倍的加速。

要使用numba运行代码,您只需要安装pip install numba(目前Python版本<=3.8支持numba,3.9也将很快支持!)。

我所做的所有改进都使单核的16.5x速度提高了一倍(例如,如果在你的算法上,某个范围的速度是64秒,那么在我的代码上,速度是4秒)。

我不得不重写你的代码,算法和想法和你的一样,但我让算法是非递归的(因为Numba不能很好地处理递归),还使用列表而不是字典来表示不太大的值。

我的单核基于numba的版本有时可能会占用太多内存,这只是因为cs参数控制了使用列表而不是字典的阈值,目前这个cs被设置为stop * 10(在代码中搜索),如果你没有太多的内存,只需将其设置为例如stop * 2(但不小于stop * 1)。我有16GB的内存,即使在64000000的上限下程序也能正常运行。

此外,除了Numba代码,我实现了C++解决方案,它在速度上似乎和Numba一样,这意味着Numba做得很好!C++代码位于Python代码之后。

我对你的算法(solve_py())和我的算法(solve_nm())进行了计时测量,并对它们进行了比较。时间在代码后面列出。

作为参考,我也使用我的numba解决方案进行了多核处理版本,但与单核版本相比,它没有任何改进,甚至速度减慢。这一切的发生是因为多核版本多次计算出相同的值。也许多机版本会带来显著的改进,但可能不会是多核版本。

由于这些免费在线服务器上的内存有限,下面的在线链接只允许运行小范围!

在线试用!

import time, threading, time, numba
def solve_py(start, stop):
even = lambda n: n%2==0
next_collatz = lambda n: n//2 if even(n) else 3*n+1
cache = {1: 1}
def collatz_chain_length(n):
if n not in cache: cache[n] = 1 + collatz_chain_length(next_collatz(n))
return cache[n]
for n in range(start, stop):
collatz_chain_length(n)
r = max(range(start, stop), key = cache.get)
return r, cache[r]
@numba.njit(cache = True, locals = {'n': numba.int64, 'l': numba.int64, 'zero': numba.int64})
def solve_nm(start, stop):
zero, l, cs = 0, 0, stop * 10
ns = [zero] * 10000
cache_lo = [zero] * cs
cache_lo[1] = 1
cache_hi = {zero: zero}
for n in range(start, stop):
if cache_lo[n] != 0:
continue
nsc = 0
while True:
if n < cs:
cg = cache_lo[n]
else:
cg = cache_hi.get(n, zero)
if cg != 0:
l = 1 + cg
break
ns[nsc] = n
nsc += 1
n = (n >> 1) if (n & 1) == 0 else 3 * n + 1
for i in range(nsc - 1, -1, -1):
if ns[i] < cs:
cache_lo[ns[i]] = l
else:
cache_hi[ns[i]] = l
l += 1
maxn, maxl = 0, 0
for k in range(start, stop):
v = cache_lo[k]
if v > maxl:
maxn, maxl = k, v
return maxn, maxl
if __name__ == '__main__':
solve_nm(1, 100000) # heat-up, precompile numba
for stop in [1000000, 2000000, 4000000, 8000000, 16000000, 32000000, 64000000]:
tr, resr = None, None
for is_nm in [False, True]:
if stop > 16000000 and not is_nm:
continue
tb = time.time()
res = (solve_nm if is_nm else solve_py)(1, stop)
te = time.time()
print(('py', 'nm')[is_nm], 'limit', stop, 'time', round(te - tb, 2), 'secs', end = '')
if not is_nm:
resr, tr = res, te - tb
print(', n', res[0], 'len', res[1])
else:
if tr is not None:
print(', boost', round(tr / (te - tb), 2))
assert resr == res, (resr, res)
else:
print(', n', res[0], 'len', res[1])

输出:

py limit 1000000 time 3.34 secs, n 837799 len 525
nm limit 1000000 time 0.19 secs, boost 17.27
py limit 2000000 time 6.72 secs, n 1723519 len 557
nm limit 2000000 time 0.4 secs, boost 16.76
py limit 4000000 time 13.47 secs, n 3732423 len 597
nm limit 4000000 time 0.83 secs, boost 16.29
py limit 8000000 time 27.32 secs, n 6649279 len 665
nm limit 8000000 time 1.68 secs, boost 16.27
py limit 16000000 time 55.42 secs, n 15733191 len 705
nm limit 16000000 time 3.48 secs, boost 15.93
nm limit 32000000 time 7.38 secs, n 31466382 len 706
nm limit 64000000 time 16.83 secs, n 63728127 len 950

与Numba相同算法的C++版本如下:

在线试用!

#include <cstdint>
#include <vector>
#include <unordered_map>
#include <tuple>
#include <iostream>
#include <stdexcept>
#include <chrono>
typedef int64_t i64;
static std::tuple<i64, i64> Solve(i64 start, i64 stop) {
i64 cs = stop * 10, n = 0, l = 0, nsc = 0;
std::vector<i64> cache_lo(cs), ns(10000);
cache_lo[1] = 1;
std::unordered_map<i64, i64> cache_hi;
for (i64 i = start; i < stop; ++i) {
if (cache_lo[i] != 0)
continue;
n = i;
nsc = 0;
while (true) {
i64 cg = 0;
if (n < cs)
cg = cache_lo[n];
else {
auto it = cache_hi.find(n);
if (it != cache_hi.end())
cg = it->second;
}
if (cg != 0) {
l = 1 + cg;
break;
}
ns.at(nsc) = n;
++nsc;
n = (n & 1) ? 3 * n + 1 : (n >> 1);
}
for (i64 i = nsc - 1; i >= 0; --i) {
i64 n = ns[i];
if (n < cs)
cache_lo[n] = l;
else
cache_hi[n] = l;
++l;
}
}
i64 maxn = 0, maxl = 0;
for (size_t i = start; i < stop; ++i)
if (cache_lo[i] > maxl) {
maxn = i;
maxl = cache_lo[i];
}
return std::make_tuple(maxn, maxl);
}
int main() {
try {
for (auto stop: std::vector<i64>({1000000, 2000000, 4000000, 8000000, 16000000, 32000000, 64000000})) {
auto tb = std::chrono::system_clock::now();
auto r = Solve(1, stop);
auto te = std::chrono::system_clock::now();
std::cout << "cpp limit " << stop
<< " time " << double(std::chrono::duration_cast<std::chrono::milliseconds>(te - tb).count()) / 1000.0 << " secs"
<< ", n " << std::get<0>(r) << " len " << std::get<1>(r) << std::endl;
}
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}

输出:

cpp limit 1000000 time 0.17 secs, n 837799 len 525
cpp limit 2000000 time 0.357 secs, n 1723519 len 557
cpp limit 4000000 time 0.757 secs, n 3732423 len 597
cpp limit 8000000 time 1.571 secs, n 6649279 len 665
cpp limit 16000000 time 3.275 secs, n 15733191 len 705
cpp limit 32000000 time 7.112 secs, n 31466382 len 706
cpp limit 64000000 time 17.165 secs, n 63728127 len 950

最新更新