Extracting 3D patches from a 3D image, with and without overlap, and recovering the image



I am working with 3D images of shape 172x220x156. To feed an image to the network and get an output, I need to extract patches of size 32x32x32 from the image and then add those patches back together to obtain the image again. Since my image dimensions are not multiples of the patch size, I have to use overlapping patches. I would like to know how to do this.

I am working in PyTorch and there are options such as unfold and fold, but I am not sure how they work.

You can use unfold (PyTorch docs):

import torch

batch_size, n_channels, n_rows, n_cols = 1, 172, 220, 156
x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)
kernel_c, kernel_h, kernel_w = 32, 32, 32
step = 32
# Tensor.unfold(dimension, size, step)
windows_unpacked = x.unfold(1, kernel_c, step).unfold(2, kernel_h, step).unfold(3, kernel_w, step)
print(windows_unpacked.shape)
# result: torch.Size([1, 5, 6, 4, 32, 32, 32])
windows = windows_unpacked.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([120, 32, 32, 32])
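
If your patches do not overlap (stride equal to the kernel size), reassembling them is just the reverse of the reshape above. Note, however, that with the asker's 172x220x156 volume and stride 32, Tensor.unfold simply drops the trailing 12, 28 and 28 voxels per dimension, which is why overlapping patches are needed. A minimal sketch of the non-overlapping round trip, assuming example sizes of 160x192x128 that are exact multiples of 32:

import torch

d, h, w = 160, 192, 128  # assumed example sizes, each divisible by 32
x = torch.arange(d * h * w, dtype=torch.float).view(1, d, h, w)
patches = x.unfold(1, 32, 32).unfold(2, 32, 32).unfold(3, 32, 32)
# (1, d//32, h//32, w//32, 32, 32, 32)
restored = patches.permute(0, 1, 4, 2, 5, 3, 6).reshape(1, d, h, w)
print(torch.equal(x, restored))  # True: non-overlapping patches tile the volume exactly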

To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only process 4D tensors (2D images), but you can use them to process one dimension at a time.
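
As a small illustration of this one-dimension-at-a-time idea, here is a minimal 2D round trip with torch.nn.functional.unfold/fold (a sketch with assumed toy sizes; the patches are non-overlapping here, so the round trip is exact without any renormalization):

import torch
import torch.nn.functional as F

img = torch.arange(1., 17.).view(1, 1, 4, 4)          # (B, C, H, W)
cols = F.unfold(img, kernel_size=2, stride=2)          # (B, C * 2 * 2, L), here L = 4 blocks
patches = cols.transpose(1, 2).reshape(-1, 1, 2, 2)    # (B * L, C, 2, 2): individual patches
back = F.fold(cols, output_size=(4, 4), kernel_size=2, stride=2)
print(torch.equal(img, back))                          # True for non-overlapping patches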

A few notes:

  1. This approach needs the fold/unfold methods from PyTorch; unfortunately, I have not yet found a similar method in the TF API.

  2. We can extract the patches in two ways whose outputs are the same. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions. The latter uses torch.Tensor.unfold() and has fewer lines of code. (The output is identical, it just cannot use dilation.)

  3. The methods extract_patches_Xd and combine_patches_Xd are inverse methods: the combiner reverses the steps of the extractor, step by step.

  4. Each line of code is followed by a comment stating the dimensions, such as (B, C, D, H, W). The following abbreviations are used:

    1. B: batch size
    2. C: channels
    3. D: depth dimension
    4. H: height dimension
    5. W: width dimension
    6. x_dim_in: in an extraction method, this is the number of input pixels in dimension x. In a combine method, this is the number of sliding windows in dimension x.
    7. x_dim_out: in an extraction method, this is the number of sliding windows in dimension x. In a combine method, this is the number of output pixels in dimension x.
  5. I have a public notebook where you can try out the code.

  6. The get_dim_blocks() method is the function given on the PyTorch docs site to compute the output shape of a convolutional layer (a worked example for the asker's dimensions follows the extract_patches_3d code below).

  7. Note that if you have overlapping patches and combine them, the overlapping elements will be summed. If you want to get the initial input back again, there is a way:

    1. Create a ones tensor of the same shape as the patches with torch.ones_like(patches_tensor).
    2. Combine the ones patches into a full image with the same output shape. (This creates a counter for the overlapping elements.)
    3. Divide the combined image by the combined ones; this should reverse any double summing of elements.

(3D): We need to use fold and unfold twice each. We first unfold over the D dimension while leaving H and W untouched, by setting their kernel to 1, padding to 0, stride to 1 and dilation to 1. After that we view the tensor and unfold over the H and W dimensions. Folding happens in the reverse order: first over H and W, then over D.
import torch

def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x

def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], d_dim_out * H * W)
    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * d_dim_out, H, W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
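
As a quick check on the get_dim_blocks() helper (note 6 above), these are the window counts it yields for the asker's 172x220x156 volume; the kernel of 32 with stride 16 and padding 2 are assumed example settings, not values from the question:

# Formula: (dim_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
print((172 + 2 * 2 - (32 - 1) - 1) // 16 + 1)  # 10 windows along D
print((220 + 2 * 2 - (32 - 1) - 1) // 16 + 1)  # 13 windows along H
print((156 + 2 * 2 - (32 - 1) - 1) // 16 + 1)  #  9 windows along W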

def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)
    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)
    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)
    return x

a = torch.arange(1, 129, dtype=torch.float).view(2, 2, 2, 4, 4)
print(a.shape)
print(a)

b = extract_patches_3d(a, 2, padding=1, stride=1)
bs = extract_patches_3ds(a, 2, padding=1, stride=1)
print(b.shape)
print(b)

c = combine_patches_3d(b, kernel_size=2, output_shape=(2, 2, 2, 4, 4), padding=1, stride=1)
print(c.shape)
print(c)

ones = torch.ones_like(b)
ones = combine_patches_3d(ones, kernel_size=2, output_shape=(2, 2, 2, 4, 4), padding=1, stride=1)
print(torch.all(a == c))
print(c.shape, ones.shape)

d = c / ones
print(d)
print(torch.all(a == d))

Output (3D):

torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
[  5.,   6.,   7.,   8.],
[  9.,  10.,  11.,  12.],
[ 13.,  14.,  15.,  16.]],
[[ 17.,  18.,  19.,  20.],
[ 21.,  22.,  23.,  24.],
[ 25.,  26.,  27.,  28.],
[ 29.,  30.,  31.,  32.]]],

[[[ 33.,  34.,  35.,  36.],
[ 37.,  38.,  39.,  40.],
[ 41.,  42.,  43.,  44.],
[ 45.,  46.,  47.,  48.]],
[[ 49.,  50.,  51.,  52.],
[ 53.,  54.,  55.,  56.],
[ 57.,  58.,  59.,  60.],
[ 61.,  62.,  63.,  64.]]]],

[[[[ 65.,  66.,  67.,  68.],
[ 69.,  70.,  71.,  72.],
[ 73.,  74.,  75.,  76.],
[ 77.,  78.,  79.,  80.]],
[[ 81.,  82.,  83.,  84.],
[ 85.,  86.,  87.,  88.],
[ 89.,  90.,  91.,  92.],
[ 93.,  94.,  95.,  96.]]],

[[[ 97.,  98.,  99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
torch.Size([150, 2, 2, 2, 2])
tensor([[[[[  0.,   0.],
[  0.,   0.]],
[[  0.,   0.],
[  0.,   1.]]],

[[[  0.,   0.],
[  0.,   0.]],
[[  0.,   0.],
[  1.,   2.]]]],

[[[[  0.,   0.],
[  0.,   0.]],
[[  0.,   0.],
[  2.,   3.]]],

[[[  0.,   0.],
[  0.,   0.]],
[[  0.,   0.],
[  3.,   4.]]]],

[[[[  0.,   0.],
[  0.,   0.]],
[[  0.,   0.],
[  4.,   0.]]],

[[[  0.,   0.],
[  0.,   0.]],
[[  0.,   1.],
[  0.,   5.]]]],

...,

[[[[124.,   0.],
[128.,   0.]],
[[  0.,   0.],
[  0.,   0.]]],

[[[  0., 125.],
[  0.,   0.]],
[[  0.,   0.],
[  0.,   0.]]]],

[[[[125., 126.],
[  0.,   0.]],
[[  0.,   0.],
[  0.,   0.]]],

[[[126., 127.],
[  0.,   0.]],
[[  0.,   0.],
[  0.,   0.]]]],

[[[[127., 128.],
[  0.,   0.]],
[[  0.,   0.],
[  0.,   0.]]],

[[[128.,   0.],
[  0.,   0.]],
[[  0.,   0.],
[  0.,   0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[   8.,   16.,   24.,   32.],
[  40.,   48.,   56.,   64.],
[  72.,   80.,   88.,   96.],
[ 104.,  112.,  120.,  128.]],
[[ 136.,  144.,  152.,  160.],
[ 168.,  176.,  184.,  192.],
[ 200.,  208.,  216.,  224.],
[ 232.,  240.,  248.,  256.]]],

[[[ 264.,  272.,  280.,  288.],
[ 296.,  304.,  312.,  320.],
[ 328.,  336.,  344.,  352.],
[ 360.,  368.,  376.,  384.]],
[[ 392.,  400.,  408.,  416.],
[ 424.,  432.,  440.,  448.],
[ 456.,  464.,  472.,  480.],
[ 488.,  496.,  504.,  512.]]]],

[[[[ 520.,  528.,  536.,  544.],
[ 552.,  560.,  568.,  576.],
[ 584.,  592.,  600.,  608.],
[ 616.,  624.,  632.,  640.]],
[[ 648.,  656.,  664.,  672.],
[ 680.,  688.,  696.,  704.],
[ 712.,  720.,  728.,  736.],
[ 744.,  752.,  760.,  768.]]],

[[[ 776.,  784.,  792.,  800.],
[ 808.,  816.,  824.,  832.],
[ 840.,  848.,  856.,  864.],
[ 872.,  880.,  888.,  896.]],
[[ 904.,  912.,  920.,  928.],
[ 936.,  944.,  952.,  960.],
[ 968.,  976.,  984.,  992.],
[1000., 1008., 1016., 1024.]]]]])
tensor(False)
torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
tensor([[[[[  1.,   2.,   3.,   4.],
[  5.,   6.,   7.,   8.],
[  9.,  10.,  11.,  12.],
[ 13.,  14.,  15.,  16.]],
[[ 17.,  18.,  19.,  20.],
[ 21.,  22.,  23.,  24.],
[ 25.,  26.,  27.,  28.],
[ 29.,  30.,  31.,  32.]]],

[[[ 33.,  34.,  35.,  36.],
[ 37.,  38.,  39.,  40.],
[ 41.,  42.,  43.,  44.],
[ 45.,  46.,  47.,  48.]],
[[ 49.,  50.,  51.,  52.],
[ 53.,  54.,  55.,  56.],
[ 57.,  58.,  59.,  60.],
[ 61.,  62.,  63.,  64.]]]],

[[[[ 65.,  66.,  67.,  68.],
[ 69.,  70.,  71.,  72.],
[ 73.,  74.,  75.,  76.],
[ 77.,  78.,  79.,  80.]],
[[ 81.,  82.,  83.,  84.],
[ 85.,  86.,  87.,  88.],
[ 89.,  90.,  91.,  92.],
[ 93.,  94.,  95.,  96.]]],

[[[ 97.,  98.,  99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
tensor(True)
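
To connect this back to the question, here is a sketch of how the helpers above could be applied to the 172x220x156 volume. The kernel of 32 with stride 16 and padding 2 are assumptions chosen so that (dim + 2*padding - 32) is divisible by the stride in every dimension, which lets the overlapping patches cover the whole volume; your_model is a placeholder for the asker's network:

volume = torch.randn(1, 1, 172, 220, 156)              # (B, C, D, H, W)
patches = extract_patches_3d(volume, kernel_size=32, padding=2, stride=16)
print(patches.shape)                                    # (1170, 1, 32, 32, 32): 10 * 13 * 9 windows

# processed = torch.stack([your_model(p) for p in patches])  # placeholder for the network pass

recon = combine_patches_3d(patches, kernel_size=32, output_shape=(1, 1, 172, 220, 156),
                           padding=2, stride=16)
counts = combine_patches_3d(torch.ones_like(patches), kernel_size=32,
                            output_shape=(1, 1, 172, 220, 156), padding=2, stride=16)
print(torch.allclose(volume, recon / counts))           # expected: True (sum divided by counts)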

Is all of your data 172x220x156? If so, it seems like you could just use a for loop and index into the tensor to get 32x32x32 chunks, right? (Possibly with a few things hard-coded.)

However, I can't fully answer this, because it is not clear how you want to combine the results. To be clear, is this your goal?

1) Take a 32x32x32 patch from the image, 2) do some arbitrary processing on it, 3) save that patch into some result tensor at the correct index, 4) repeat.

If so, how do you plan to combine the overlapping patches? Sum them? Average them?

But, for the indexing:

out_tensor = torch.zeros_like(tensor)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            patch = tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32]
            output = your_model(patch)
            out_tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] = output

This is not optimized at all, but I imagine the bulk of the computation will be the actual neural network, and there is no way around that, so optimizing the loop is probably pointless.
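
If averaging is what you want, a minimal extension of the loop above (same hard-coded start indices; tensor and your_model remain placeholders) would track how many windows touched each voxel and divide at the end:

out_tensor = torch.zeros(172, 220, 156)
counts = torch.zeros(172, 220, 156)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            patch = tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32]
            out_tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] += your_model(patch)
            counts[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] += 1
out_tensor /= counts  # average wherever the windows overlapped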
