仅给出数组列的索引时的索引范围



我正在寻找一种有效的方法来索引具有多个范围的numpy数组的列,当只给出所需范围的索引时。

例如,给定以下数组,范围大小为r_size=3:

import numpy as np
arr = np.arange(18).reshape((2,9))
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
[ 9, 10, 11, 12, 13, 14, 15, 16, 17]])

这意味着总共有3组范围[r0, r1, r2],它们在数组中的元素分布如下:

[[r0_00, r0_01, r0_02, r1_00, r1_01, r1_02, r2_00, r2_01, r2_02]
[r0_10, r0_11, r0_12, r1_10, r1_11, r1_12, r2_10, r2_11, r2_12]]

所以如果我想访问范围r0r2,那么我将得到:

arr    = np.arange(18).reshape((2,9))
r_size = 3
ranges = [0, 2]
# --------------------------------------------------------
# Line that index arr, with the variable ranges... Output:
# --------------------------------------------------------
array([[ 0,  1,  2,  6,  7,  8],
[ 9, 10, 11, 15, 16, 17]])

我找到的最快的方法是:

import numpy as np
from itertools import chain
arr    = np.arange(18).reshape((2,9))
r_size = 3
ranges = [0,2]
arr[:, list(chain(*[range(r_size*x,r_size*x+r_size) for x in ranges]))]
array([[ 0,  1,  2,  6,  7,  8],
[ 9, 10, 11, 15, 16, 17]])

但我不确定它是否可以在速度方面得到改进。

提前感谢!

你可以从把数组分成r_size块开始:

>>> splits = np.split(arr, r_size, axis=1)
[array([[ 0,  1,  2],
[ 9, 10, 11]]), 
array([[ 3,  4,  5],
[12, 13, 14]]), 
array([[ 6,  7,  8],
[15, 16, 17]])]

np.stack叠加并选择正确的ranges:

>>> stack = np.stack(splits)[ranges]
array([[[ 0,  1,  2],
[ 9, 10, 11]],
[[ 6,  7,  8],
[15, 16, 17]]])

并在axis=1上与np.hstacknp.concantenate水平连接:

>>> np.stack(stack)
array([[ 0,  1,  2,  6,  7,  8],
[ 9, 10, 11, 15, 16, 17]])

整体看起来像:

>>> np.hstack(np.stack(np.split(arr, r_size, axis=1))[ranges])
array([[ 0,  1,  2,  6,  7,  8],
[ 9, 10, 11, 15, 16, 17]])

或者,您可以专门使用np.reshapes,这将更快:

初始重塑:

>>> arr.reshape(len(arr), -1, r_size)
array([[[ 0,  1,  2],
[ 3,  4,  5],
[ 6,  7,  8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])

索引ranges:

>>> arr.reshape(len(arr), -1, r_size)[:, ranges]
array([[[ 0,  1,  2],
[ 6,  7,  8]],
[[ 9, 10, 11],
[15, 16, 17]]])

然后,重塑成最终的形式:

>>> arr.reshape(len(arr),  -1, r_size)[:, ranges].reshape(len(arr), -1)

您将不可避免地需要复制数据以在连续数组中获得所需的结果。为了提高效率,我建议尽量减少复制数据的次数。任何一种整形操作都可以用np.lib.stride_tricks.as_strided表示。

假设原始数组包含64位整数,则每个元素为8字节,按某种形状排列:

import numpy as np
arr = np.arange(18).reshape((2,9))
arr.shape, arr.strides

输出:

((2, 9), (72, 8))

所以每列跳过8个字节,每行跳过72个字节。arr.reshape(len(arr), -1, r_size)可以表示为:

np.lib.stride_tricks.as_strided(arr, (2,3,3), (72,24,8))

输出:

array([[[ 0,  1,  2],
[ 3,  4,  5],
[ 6,  7,  8]],
[[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]])

arr.reshape(len(arr), -1, r_size)[:, ranges]可以表示为:

np.lib.stride_tricks.as_strided(arr, (2,2,3), (72,24*2,8))

输出:

array([[[ 0,  1,  2],
[ 6,  7,  8]],
[[ 9, 10, 11],
[15, 16, 17]]])
到目前为止,我们只更改了数组的元数据,这意味着没有复制数据。该操作的性能成本几乎为零。但是要得到最终的数组,你需要复制数据:
np.lib.stride_tricks.as_strided(arr, (2,2,3), (72,24*2,8)).reshape(len(arr), -1)

输出:

array([[ 0,  1,  2,  6,  7,  8],
[ 9, 10, 11, 15, 16, 17]])

这不是一个通用的解决方案,但它可能会给你一些关于如何优化的想法。

不幸的是,我的计时不支持这些声明,但它仍然是直观的,值得对一些更大的数组进行测试。

相关内容

  • 没有找到相关文章

最新更新