在PyTorch中,我有一个批大小为256的RGB张量imgA
。我想保留前128批的绿色通道和剩余128批的红色通道,如下所示:
imgA[:128,2,:,:] = imgA[:128,1,:,:]
imgA[128:,2,:,:] = imgA[128:,0,:,:]
imgA = imgA[:,2,:,:].unsqueeze(1)
或者可以像一样实现
imgA = torch.cat((imgA[:128,1,:,:].unsqueeze(1),imgA[128:,0,:,:].unsqueeze(1)),dim=0)
但是,由于我有多个这样的图像,如imgA、imgB、imgC等,实现上述目标的最快方法是什么?
使用torch.gather
和repeat_interleave
:可以实现基于切片的解决方案
select = torch.tensor([1, 0], device=imgA.device)
imgA = = imgA.gather(dim=1, index=select.repeat_interleave(128, dim=0).view(256, 1, 1, 1).expand(-1, -1, *imgA.shape[-2:]))
您也可以使用矩阵乘法和repeat_interleave
:
# select c=1 for first half and c=0 for second
select = torch.tensor([[0, 1],[1, 0],[0, 0]], dtype=imgA.dtype, device=imgA.device)
imgA = torch.einsum('cb,bchw->bhw',select.repeat_interleave(128, dim=1), imgA).unsqueeze(dim=1)