在大数组中使用广播、转置和重塑的Numpy速度效率



是否有一种方法可以加速以下代码行:

desired_channel=32
len_indices=50000
fast_idx = np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)

谢谢。

最后一行代码等于np.tile(np.arange(desired_channel), len_indices)

在我的机器上,np.tile的性能像许多Numpy调用一样受到操作系统(页面错误)、内存分配器和内存吞吐量的限制。有两种方法可以克服这个限制:不分配/填充临时缓冲区,生成更小的数组在内存中使用较短的类型,如np.uint8np.uint16,根据您的需要。

由于np.tile函数没有out参数,所以可以使用Numba来生成一个快速的替代函数。下面是一个例子:

import numba as nb
@nb.njit('int32[::1](int32, int32, int32[::1])', parallel=True)
def generate(desired_channel, len_indices, out):
for i in nb.prange(len_indices):
for j in range(desired_channel):
out[i*desired_channel+j] = j
return out
desired_channel=32
len_indices=50000
buffer = np.full(desired_channel * len_indices, 0, dtype=np.int32)
%timeit -n 200 generate(desired_channel, len_indices, fast_idx)

下面是性能结果:

Original code: 1.25 ms
np.tile:       1.24 ms
Numba:         0.20 ms

我是新手jax图书馆。我用以下代码比较了你的代码Colab TPU:

import numpy as np
from jax import jit
import jax.numpy as jnp
import timeit
desired_channel=32
len_indices=50000
def ex_():
return np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
%timeit -n1000 -r10 ex_()
@jit
def exj_():
return jnp.broadcast_to(jnp.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
%timeit -n1000 -r10 exj_()

在我的一次努力中,结果如下:

1000个循环,最好的10:901µs/循环
1000个循环,最好的10:317µs/循环

jax可以使你的代码速度提高两到三倍。

相关内容

  • 没有找到相关文章

最新更新