How to implement tf.gather_nd in PyTorch with the argument batch_dims?



I have been working on a project about image matching, so I need to find correspondences between two images. To obtain descriptors, I will need an interpolation function. However, when I read the equivalent function written in TensorFlow, I still don't know how to implement tf.gather_nd(params, indices, batch_dims) in PyTorch, especially when the argument batch_dims is involved. I have gone through Stack Overflow and there is no perfect equivalent yet.
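
For reference, here is a minimal illustration of the TensorFlow behaviour I am trying to reproduce (the shapes are arbitrary): with batch_dims=1, indices[b] is used to index into params[b], so the output keeps the leading batch dimension.

import tensorflow as tf

params = tf.ones((2, 4, 4, 128))                 # [N, H, W, C]
indices = tf.constant([[[0, 1], [2, 3]],
                       [[1, 1], [3, 0]]])        # [N, 2, 2], each innermost pair is (i, j)
# indices[b] indexes into params[b], so the result has shape [N, 2, C]
out = tf.gather_nd(params, indices, batch_dims=1)
print(out.shape)  # (2, 2, 128)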

The interpolation function referenced in TensorFlow is shown below, and I have been trying to implement it in PyTorch. The information about the arguments is as follows:

inputs is the dense feature map[i] coming from a for loop over the batch size, which means it is 3D, [H, W, C] (in PyTorch it would be [C, H, W]).

pos is a set of random point coordinates shaped like [[i, j], [i, j], ..., [i, j]], so it is 2D when it goes into the interpolation function (in PyTorch it is [[i, i, ..., i], [j, j, ..., j]]).

Then, when they go into this function, the dimensions of both of them are expanded.

I just want a perfect implementation of tf.gather_nd with the argument batch_dims. Thanks! Below is a simple example of how it is used:

pos = tf.ones((12, 2)) ## stands for a set of coordinates [[i, i,…, i], [j, j,…, j]]
inputs = tf.ones((4, 4, 128)) ## stands for [H, W, C] of dense feature map
outputs = interpolate(pos, inputs, batched=False)
print(outputs.get_shape()) # We get (12, 128) here

The interpolation function (TF version):

def interpolate(pos, inputs, nd=True):
    pos = tf.expand_dims(pos, 0)
    inputs = tf.expand_dims(inputs, 0)
    h = tf.shape(inputs)[1]
    w = tf.shape(inputs)[2]

    i = pos[:, :, 0]
    j = pos[:, :, 1]

    i_top_left = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_top_right = tf.clip_by_value(tf.cast(tf.math.floor(i), tf.int32), 0, h - 1)
    j_top_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    i_bottom_left = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_left = tf.clip_by_value(tf.cast(tf.math.floor(j), tf.int32), 0, w - 1)

    i_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(i), tf.int32), 0, h - 1)
    j_bottom_right = tf.clip_by_value(tf.cast(tf.math.ceil(j), tf.int32), 0, w - 1)

    dist_i_top_left = i - tf.cast(i_top_left, tf.float32)
    dist_j_top_left = j - tf.cast(j_top_left, tf.float32)
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    if nd:
        w_top_left = w_top_left[..., None]
        w_top_right = w_top_right[..., None]
        w_bottom_left = w_bottom_left[..., None]
        w_bottom_right = w_bottom_right[..., None]

    interpolated_val = (
        w_top_left * tf.gather_nd(inputs, tf.stack([i_top_left, j_top_left], axis=-1), batch_dims=1) +
        w_top_right * tf.gather_nd(inputs, tf.stack([i_top_right, j_top_right], axis=-1), batch_dims=1) +
        w_bottom_left * tf.gather_nd(inputs, tf.stack([i_bottom_left, j_bottom_left], axis=-1), batch_dims=1) +
        w_bottom_right * tf.gather_nd(inputs, tf.stack([i_bottom_right, j_bottom_right], axis=-1), batch_dims=1)
    )

    interpolated_val = tf.squeeze(interpolated_val, axis=0)
    return interpolated_val

As far as I'm aware, there is no direct equivalent of tf.gather_nd in PyTorch, and implementing a generic version with batch_dims is not that simple. However, you likely don't need a generic version, and given the context of your interpolate function, a version for [C, H, W] would suffice.

At the beginning of interpolate you add a singular dimension to the front, which is the batch dimension. Setting batch_dims=1 in tf.gather_nd means that there is one batch dimension at the beginning, therefore it is applied per batch, i.e. inputs[0] is indexed with pos[0] and so on. There is no benefit to adding a singular batch dimension, because you could have just used the direct computation.

# Adding singular batch dimension
# Shape: [1, num_pos, 2]
pos = tf.expand_dims(pos, 0)
# Shape: [1, H, W, C]
inputs = tf.expand_dims(inputs, 0)

batched_result = tf.gather_nd(inputs, pos, batch_dims=1)
single_result = tf.gather_nd(inputs[0], pos[0])
# The first element in the batched result is the same as the single result
# Hence there is no benefit to adding a singular batch dimension.
tf.reduce_all(batched_result[0] == single_result) # => True

Single version

In PyTorch, the implementation for [H, W, C] can be done with Python's indexing. While PyTorch usually uses [C, H, W] for images, it's only a matter of which dimension to index, but let's keep them the same as in TensorFlow for the comparison. If you were to index them manually, you would do it as such: inputs[pos_h[0], pos_w[0]], inputs[pos_h[1], pos_w[1]] and so on. PyTorch allows you to do that automatically by providing the indices as lists: inputs[pos_h, pos_w], where pos_h and pos_w have the same length. All you need to do is split your pos into two separate tensors, one for the indices along the height dimension and the other along the width dimension, which you also did in the TensorFlow version.

inputs = torch.randn(4, 4, 128)
# Random positions 0-3, shape: [12, 2]
pos = torch.randint(4, (12, 2))
# Positions split by dimension
pos_h = pos[:, 0]
pos_w = pos[:, 1]
# Index the inputs with the indices per dimension
gathered = inputs[pos_h, pos_w]
# Verify that it's identical to TensorFlow's output
inputs_tf = tf.convert_to_tensor(inputs.numpy())
pos_tf = tf.convert_to_tensor(pos.numpy())
gathered_tf = tf.gather_nd(inputs_tf, pos_tf)
gathered_tf = torch.from_numpy(gathered_tf.numpy())
torch.equal(gathered_tf, gathered) # => True

If you want to apply it to a tensor of size [C, H, W] instead, you only need to change the dimension you want to index:

# For [H, W, C]
gathered = inputs[pos_h, pos_w]
# For [C, H, W]
gathered = inputs[:, pos_h, pos_w]
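
To tie this back to your interpolate, a rough unbatched PyTorch port using this indexing could look like the sketch below. This is only a sketch: interpolate_torch is my own name, it assumes pos is a float tensor of shape [num_pos, 2] holding (i, j) coordinates and inputs is [H, W, C], and the tf.gather_nd calls on the singleton batch become plain indexing.

import torch

def interpolate_torch(pos, inputs):
    # pos: float tensor [num_pos, 2] with (i, j) coordinates, inputs: [H, W, C]
    h, w = inputs.shape[0], inputs.shape[1]
    i, j = pos[:, 0], pos[:, 1]

    # Corner indices (floor/ceil of the positions, clamped to the feature map)
    i_top = i.floor().long().clamp(0, h - 1)
    j_left = j.floor().long().clamp(0, w - 1)
    i_bottom = i.ceil().long().clamp(0, h - 1)
    j_right = j.ceil().long().clamp(0, w - 1)

    # Bilinear weights
    dist_i = i - i_top.float()
    dist_j = j - j_left.float()
    w_tl = ((1 - dist_i) * (1 - dist_j)).unsqueeze(-1)
    w_tr = ((1 - dist_i) * dist_j).unsqueeze(-1)
    w_bl = (dist_i * (1 - dist_j)).unsqueeze(-1)
    w_br = (dist_i * dist_j).unsqueeze(-1)

    # Each tf.gather_nd(..., batch_dims=1) becomes a plain index into [H, W, C]
    return (w_tl * inputs[i_top, j_left]
            + w_tr * inputs[i_top, j_right]
            + w_bl * inputs[i_bottom, j_left]
            + w_br * inputs[i_bottom, j_right])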

Batched version

Making it a batched version (for [N, H, W, C] or [N, C, H, W]) is not that difficult, and using it is more appropriate, since you're dealing with batches anyway. The only tricky part is that each element in the batch should only be applied to its corresponding batch. For this the batch dimension needs to be enumerated, which can be done with torch.arange. The batch enumeration is just a list with the batch indices, which will be combined with the pos_h and pos_w indices, resulting in inputs[0, pos_h[0, 0], pos_w[0, 0]], inputs[0, pos_h[0, 1], pos_w[0, 1]] ... inputs[1, pos_h[1, 0], pos_w[1, 0]] and so on.

batch_size = 3
inputs = torch.randn(batch_size, 4, 4, 128)
# Random positions 0-3, different for each batch, shape: [3, 12, 2]
pos = torch.randint(4, (batch_size, 12, 2))
# Positions split by dimension
pos_h = pos[:, :, 0]
pos_w = pos[:, :, 1]
batch_enumeration = torch.arange(batch_size) # => [0, 1, 2]
# pos_h and pos_w have shape [3, 12], so the batch enumeration needs to be
# repeated 12 times per batch.
# Unsqueeze to get shape [3, 1], now the 1 could be repeated to 12, but
# broadcasting will do that automatically.
batch_enumeration = batch_enumeration.unsqueeze(1)
# Index the inputs with the indices per dimension
gathered = inputs[batch_enumeration, pos_h, pos_w]
# Again, verify that it's identical to TensorFlow's output
inputs_tf = tf.convert_to_tensor(inputs.numpy())
pos_tf = tf.convert_to_tensor(pos.numpy())
# This time with batch_dims=1
gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=1)
gathered_tf = torch.from_numpy(gathered_tf.numpy())
torch.equal(gathered_tf, gathered) # => True

Again, for [N, C, H, W] only the dimensions that are indexed need to be changed:

# For [N, H, W, C]
gathered = inputs[batch_enumeration, pos_h, pos_w]
# For [N, C, H, W]
gathered = inputs[batch_enumeration, :, pos_h, pos_w]

Just a little side note on that interpolate implementation: rounding the positions (floor and ceil respectively) doesn't make sense, because indices must be integers, so it has no effect as long as your positions are actual indices. That also results in i_top_left and i_bottom_left being the same value, but even if they were rounded differently, they would always be 1 position apart. Furthermore, i_top_left and i_top_right are literally identical. I don't think this function produces a meaningful output. I don't know what you're trying to achieve, but if you're looking for image interpolation you could have a look at torch.nn.functional.interpolate.
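
If what you're ultimately after is bilinear sampling of a dense feature map at fractional positions, torch.nn.functional.grid_sample covers that use case directly. A minimal sketch (my own, assuming a channels-first feature map and positions given as pixel coordinates that have to be normalized to [-1, 1]):

import torch
import torch.nn.functional as F

C, H, W = 128, 4, 4
feature_map = torch.randn(1, C, H, W)                     # [N, C, H, W]
pos = torch.rand(12, 2) * torch.tensor([H - 1., W - 1.])  # [num_pos, 2] of (i, j) pixel coords

# grid_sample expects (x, y) pairs in [-1, 1], where x indexes the width and y the height
x = pos[:, 1] / (W - 1) * 2 - 1
y = pos[:, 0] / (H - 1) * 2 - 1
grid = torch.stack([x, y], dim=-1).view(1, 1, -1, 2)      # [N, H_out=1, W_out=num_pos, 2]

sampled = F.grid_sample(feature_map, grid, mode='bilinear', align_corners=True)
print(sampled.shape)  # torch.Size([1, 128, 1, 12]); squeeze/permute to get [12, 128]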

This is just an extension of Michael Jungo's batched version answer, for when pos is a 2D array instead of a 1D array (excluding the batch dimension).

bs = 2
H = 4
W = 6
C = 3
inputs = torch.randn(bs, H, W, C)
pos_h = torch.randint(H, (bs, H, W))
pos_w = torch.randint(W, (bs, H, W))
batch_enumeration = torch.arange(bs)
batch_enumeration = batch_enumeration.unsqueeze(1).unsqueeze(2)
inputs.shape
Out[34]: torch.Size([2, 4, 6, 3])
pos_h.shape
Out[35]: torch.Size([2, 4, 6])
pos_w.shape
Out[36]: torch.Size([2, 4, 6])
batch_enumeration.shape
Out[37]: torch.Size([2, 1, 1])
gathered = inputs[batch_enumeration, pos_h, pos_w]
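
As a sanity check, this can be compared against tf.gather_nd with batch_dims=1 by stacking the two index tensors (my own verification, analogous to the checks above):

import tensorflow as tf

inputs_tf = tf.convert_to_tensor(inputs.numpy())
# Stack (i, j) into shape [bs, H, W, 2] so each innermost pair indexes into [H, W]
pos_tf = tf.stack([tf.convert_to_tensor(pos_h.numpy()),
                   tf.convert_to_tensor(pos_w.numpy())], axis=-1)
gathered_tf = torch.from_numpy(tf.gather_nd(inputs_tf, pos_tf, batch_dims=1).numpy())
torch.equal(gathered_tf, gathered)  # => True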

For channels-first, we also need to enumerate the channels:

inputs = torch.randn(bs, C, H, W)
pos_h = torch.randint(H, (bs, 1, H, W))
pos_w = torch.randint(W, (bs, 1, H, W))
batch_enumeration = torch.arange(bs)
batch_enumeration = batch_enumeration.unsqueeze(1).unsqueeze(2).unsqueeze(3)
channel_enumeration = torch.arange(C)
channel_enumeration = channel_enumeration.unsqueeze(0).unsqueeze(2).unsqueeze(3)
inputs.shape
Out[49]: torch.Size([2, 3, 4, 6])
pos_h.shape
Out[50]: torch.Size([2, 1, 4, 6])
pos_w.shape
Out[51]: torch.Size([2, 1, 4, 6])
batch_enumeration.shape
Out[52]: torch.Size([2, 1, 1, 1])
channel_enumeration.shape
Out[57]: torch.Size([1, 3, 1, 1])
gathered = inputs[batch_enumeration, channel_enumeration, pos_h, pos_w]
gathered.shape
Out[59]: torch.Size([2, 3, 4, 6])

Let's verify:

inputs_np = inputs.numpy()
pos_h_np = pos_h.numpy()
pos_w_np = pos_w.numpy()
gathered_np = gathered.numpy()
pos_h_np[0,0,0,0]
Out[68]: 0
pos_w_np[0,0,0,0]
Out[69]: 3
inputs_np[0,:,0,3]
Out[71]: array([ 0.79122806, -2.190181  , -0.16741803], dtype=float32)
gathered_np[0,:,0,0]
Out[72]: array([ 0.79122806, -2.190181  , -0.16741803], dtype=float32)
pos_h_np[1,0,3,4]
Out[73]: 1
pos_w_np[1,0,3,4]
Out[74]: 2
inputs_np[1,:,1,2]
Out[75]: array([ 0.9282498 , -0.34945545,  0.9136222 ], dtype=float32)
gathered_np[1,:,3,4]
Out[77]: array([ 0.9282498 , -0.34945545,  0.9136222 ], dtype=float32)
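
As a side note (my own variation, not part of Michael Jungo's answer), the channel enumeration can be avoided by gathering in channels-last layout and permuting back:

# Gather in channels-last layout, then move the channels back to the front
inputs_last = inputs.permute(0, 2, 3, 1)                      # [bs, H, W, C]
gathered_alt = inputs_last[batch_enumeration.view(bs, 1, 1),  # [bs, 1, 1]
                           pos_h[:, 0], pos_w[:, 0]]          # each [bs, H, W]
gathered_alt = gathered_alt.permute(0, 3, 1, 2)               # [bs, C, H, W]
torch.equal(gathered_alt, gathered)  # => True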

I improved Michael Jungo's implementation. It now supports arbitrary leading batch dimensions.

import numpy as np
import torch


def gather_nd_torch(params, indices, batch_dim=1):
    """ A PyTorch port of tensorflow.gather_nd.
    This implementation can handle leading batch dimensions in params, see below for a detailed explanation.
    The majority of this implementation is from Michael Jungo @ https://stackoverflow.com/a/61810047/6670143
    I just adapted it to be compatible with leading batch dimensions.

    Args:
      params: a tensor of dimension [b1, ..., bn, g1, ..., gm, c].
      indices: a tensor of dimension [b1, ..., bn, x, m].
      batch_dim: indicates how many batch dimensions you have; in the above example, batch_dim = n.

    Returns:
      gathered: a tensor of dimension [b1, ..., bn, x, c].

    Example:
    >>> batch_size = 5
    >>> inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32)
    >>> pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3))
    >>> gathered = gather_nd_torch(inputs, pos, batch_dim=3)
    >>> gathered.shape
    torch.Size([5, 5, 5, 12, 32])
    >>> inputs_tf = tf.convert_to_tensor(inputs.numpy())
    >>> pos_tf = tf.convert_to_tensor(pos.numpy())
    >>> gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3)
    >>> gathered_tf.shape
    TensorShape([5, 5, 5, 12, 32])
    >>> gathered_tf = torch.from_numpy(gathered_tf.numpy())
    >>> torch.equal(gathered_tf, gathered)
    True
    """
    batch_dims = params.size()[:batch_dim]  # [b1, ..., bn]
    batch_size = np.cumprod(list(batch_dims))[-1]  # b1 * ... * bn
    c_dim = params.size()[-1]  # c
    grid_dims = params.size()[batch_dim:-1]  # [g1, ..., gm]
    n_indices = indices.size(-2)  # x
    n_pos = indices.size(-1)  # m

    # reshape leading batch dims to a single batch dim
    params = params.reshape(batch_size, *grid_dims, c_dim)
    indices = indices.reshape(batch_size, n_indices, n_pos)

    # build gather indices:
    # gather for each of the data points in this "batch"
    batch_enumeration = torch.arange(batch_size).unsqueeze(1)
    gather_dims = [indices[:, :, i] for i in range(len(grid_dims))]
    gather_dims.insert(0, batch_enumeration)
    # index with a tuple of index tensors (indexing with a plain list is deprecated in newer PyTorch)
    gathered = params[tuple(gather_dims)]

    # reshape back to the shape with leading batch dims
    gathered = gathered.reshape(*batch_dims, n_indices, c_dim)
    return gathered
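
For example, with the single-batch-dimension shapes used earlier in this thread, it could be called like this (a quick usage sketch):

inputs = torch.randn(3, 4, 4, 128)        # [N, H, W, C]
pos = torch.randint(4, (3, 12, 2))        # [N, num_pos, 2]
gathered = gather_nd_torch(inputs, pos, batch_dim=1)
print(gathered.shape)  # torch.Size([3, 12, 128])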

I have also made a demo Colab notebook, which you can check out here. According to my speed test on a Colab server with a GPU instance, this implementation is much faster than the original TF one.

Latest update