PyTorch中的同时批处理和通道切片



在PyTorch中,我有一个批大小为256的RGB张量imgA。我想保留前128批的绿色通道和剩余128批的红色通道,如下所示:

imgA[:128,2,:,:] = imgA[:128,1,:,:]
imgA[128:,2,:,:] = imgA[128:,0,:,:]
imgA = imgA[:,2,:,:].unsqueeze(1)

或者可以像一样实现

imgA = torch.cat((imgA[:128,1,:,:].unsqueeze(1),imgA[128:,0,:,:].unsqueeze(1)),dim=0)

但是,由于我有多个这样的图像,如imgA、imgB、imgC等,实现上述目标的最快方法是什么?

使用torch.gatherrepeat_interleave:可以实现基于切片的解决方案

select = torch.tensor([1, 0], device=imgA.device)
imgA = = imgA.gather(dim=1, index=select.repeat_interleave(128, dim=0).view(256, 1, 1, 1).expand(-1, -1, *imgA.shape[-2:]))

您也可以使用矩阵乘法和repeat_interleave:

# select c=1 for first half and c=0 for second
select = torch.tensor([[0, 1],[1, 0],[0, 0]], dtype=imgA.dtype, device=imgA.device)
imgA = torch.einsum('cb,bchw->bhw',select.repeat_interleave(128, dim=1), imgA).unsqueeze(dim=1)

最新更新