在numpy narray中重新排序和重塑元素



我有一个外部模块返回的numpy narray。数组的形状为(3,3,128)。基本上是128个贴图的堆叠,每个贴图是3x3。

我如何重新排序,使形状变成(128,3,3)。这样就可以更容易地按tile号进行索引。最后一步是将其扁平化为(128,9),这样128个贴图中的每一个都可以作为9个值向量轻松访问。

您可以对指定的新数组顺序使用转置,例如

a = np.arange(0,3*3*128).reshape(3,3,128) 
a_reorder = a.transpose([2,0,1])

你可以通过比较所有的贴图来检查它是否正确,

np.all([np.all(a[:,:,i]==a_reorder[i,:,:]) for i in range(128)])

a_flat = a_reorder.reshape(128,9)

重塑3 * 3 * 128到128 * 3 * 3:

y = einops.rearrange(x, 'x y tile -> tile x y')

或者可以在一次操作中直接重塑为128 * 9

y = einops.rearrange(x, 'x y tile -> tile (x y)')

相关内容

  • 没有找到相关文章

最新更新