numpy 4D 数组高级索引示例



我正在阅读一些深度学习代码。我在 numpy 数组中的高级索引方面遇到问题。我正在测试的代码:

import numpy
x = numpy.arange(2 * 8 * 3 * 64).reshape((2, 8, 3, 64))
x.shape
p1 = numpy.arange(2)[:, None]
sd = numpy.ones(2 * 64, dtype=int).reshape((2, 64))
p4 = numpy.arange(128 // 2)[None, :]
y = x[p1, :, sd, p4]
y.shape

为什么y的形状是(2, 64, 8)的?

这是上述代码的输出:

>>> x.shape
(2, 8, 3, 64)
>>> p1
array([[0], [1]])
>>> sd
array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
>>> p4
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]])
>>> y.shape
(2, 64, 8)

我读到这个:https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

我认为这与广播有关:

x形状是(2, 8, 3, 64)

p1很简单,它array([[0], [1]]),只是意味着选择第一维的 IND0, 1。 双阵列用于广播。

p2:,这意味着在二维中选择所有8个元素。

p3很棘手,它包含两个"列表",可以从维度 3 中的 3 个元素中选择一个,因此生成的新 3rd 维度应该是 1。

p4意味着它在第四维度中选择所有 64 个元素。

所以我认为y.shape应该是(2, 8, 1, 64).

但正确的是(2, 64, 8).为什么?

当我第一次在 numpy 中遇到花哨的索引时,我遇到了同样的问题。简短的回答是,它没有任何技巧:花哨的索引只是将元素选择到与索引形状相同的输出中。使用纯粹花哨的索引,您的输出数组将与广播索引数组的形状相同(在此处描述)。输出的形状与输入的形状几乎没有关系,除非您也加入常规的切片索引(在此处描述)。你的情况是后者,这增加了混乱。

让我们看一下您的指数,看看发生了什么:

y = x[p1, :, sd, p4]
x.shape -> 2, 8, 3, 64
p1.shape -> 2, 1
sd.shape -> 2, 64
p4.shape -> 1, 64

有关如何继续的具体文档如下:

需要区分指数组合的两种情况:

  • 高级索引由切片、Ellipsisnewaxis分隔。例如x[arr1, :, arr2].
  • 高级索引彼此相邻。例如,x[..., arr1, arr2, :]但不是 x[arr1, :, 1],因为1是这方面的高级索引。

在第一种情况下,高级索引操作生成的维度在结果数组中排在第一位,子空间维度排在结果数组之后。在第二种情况下,高级索引操作中的维度入到结果数组中,位置与初始数组中的相同位置(后一种逻辑使简单的高级索引的行为与切片一样)。

强调我的

请记住,在上述两种情况下,花式索引部分的维度是索引数组的维度,而不是您正在索引的数组。

那么,你应该期望看到的是广播维度为p1sdp4(2, 64),其次是x(8)的第二维度的大小。这确实是你得到的:

>>> y.shape
(2, 64, 8)

最新更新