我正在尝试将某些数据从张量中取出,但我会遇到奇怪的错误。在这里,我将尝试生成错误:
a=np.random.randn(5, 10, 5, 5)
a[:, [1, 6], np.triu_indices(5, 0)[0], np.triu_indices(5, 0)[1]].shape
我得到此错误
形状不匹配:索引阵列无法与形状一起播放
我什至没有做任何广播!这都是切片的东西。
我想要什么?将零轴按原样(获取所有内容(,从第一个轴获取[1]和[6],仅使用上层三角元素将最后两个轴从[5,5]重塑为[5,5]。
我们需要将第二轴索引数组扩展到 2D
,以便与np.triu_indices
的索引形成外平面。因此,它为我们提供了mxn
数组的2D
网格,m
是第二轴索引数组的长度,而n
是np.triu_indices
的长度。因此,从本质上讲,整个解决方案将简化为这样的东西 -
r,c = np.triu_indices(5, 0)
out = a[:, np.array([1, 6])[:,None], r, c]
或以该扩展版本为列表,即 -
out = a[:, [[1],[6]], r, c]
我们还可以使用np.tri/np.triu
的基于掩码的掩码,它可能会在较大的数组上更快,因为我们会跳过创建所有整数索引,例如So -o -
mask = ~np.tri(5, k=-1, dtype=bool)
out = a[:, np.array([1, 6])[:,None], mask]