是否有一种方法可以加速以下代码行:
desired_channel=32
len_indices=50000
fast_idx = np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
谢谢。
最后一行代码等于np.tile(np.arange(desired_channel), len_indices)
在我的机器上,np.tile
的性能像许多Numpy调用一样受到操作系统(页面错误)、内存分配器和内存吞吐量的限制。有两种方法可以克服这个限制:不分配/填充临时缓冲区,生成更小的数组在内存中使用较短的类型,如np.uint8
或np.uint16
,根据您的需要。
由于np.tile
函数没有out
参数,所以可以使用Numba来生成一个快速的替代函数。下面是一个例子:
import numba as nb
@nb.njit('int32[::1](int32, int32, int32[::1])', parallel=True)
def generate(desired_channel, len_indices, out):
for i in nb.prange(len_indices):
for j in range(desired_channel):
out[i*desired_channel+j] = j
return out
desired_channel=32
len_indices=50000
buffer = np.full(desired_channel * len_indices, 0, dtype=np.int32)
%timeit -n 200 generate(desired_channel, len_indices, fast_idx)
下面是性能结果:
Original code: 1.25 ms
np.tile: 1.24 ms
Numba: 0.20 ms
我是新手jax图书馆。我用以下代码比较了你的代码Colab TPU:
import numpy as np
from jax import jit
import jax.numpy as jnp
import timeit
desired_channel=32
len_indices=50000
def ex_():
return np.broadcast_to(np.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
%timeit -n1000 -r10 ex_()
@jit
def exj_():
return jnp.broadcast_to(jnp.arange(desired_channel)[:, None], (desired_channel, len_indices)).T.reshape(-1)
%timeit -n1000 -r10 exj_()
在我的一次努力中,结果如下:
jax可以使你的代码速度提高两到三倍。1000个循环,最好的10:901µs/循环
1000个循环,最好的10:317µs/循环