tf.gather具有比输入数据更高维度的索引



阅读用于点云学习的动态图CNN代码时,我看到了以下片段:

idx_ = tf.range(batch_size) * num_points
idx_ = tf.reshape(idx_, [batch_size, 1, 1]) 
point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)  <--- what happens here?
point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)

调试线路我确保调光是

point_cloud_flat:(32768,3) nn_idx:(32,1024,20), idx_:(32,1,1) 
// indices are (32,1024,20) after broadcasting

阅读tf.gather文档时,我不明白该函数对高于输入尺寸的尺寸做了什么

numpy中的等效函数是np.take,一个简单的例子:

import numpy as np
params = np.array([4, 3, 5, 7, 6, 8])
# Scalar indices; (output is rank(params) - 1), i.e. 0 here.
indices = 0
print(params[indices])
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [0, 1, 4]
print(params[indices])  # [4 3 6]
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [2, 3, 4]
print(params[indices])  # [5 7 6]
# Higher rank indices; (output is rank(params) + rank(indices) - 1), i.e. 2 here
indices = np.array([[0, 1, 4], [2, 3, 4]])
print(params[indices])  # equivalent to np.take(params, indices, axis=0)
# [[4 3 6]
# [5 7 6]]

在您的情况下,indices的秩高于params,因此输出为秩(params(+秩(indices(-1(即2+3-1=4,即(321024,20,3((。- 1是因为此时tf.gather(axis=0)axis的秩必须为0(因此是标量(。因此,indices以一种"奇特"的索引方式获取第一维度(axis=0(的元素。

编辑

简而言之,在你的情况下,(如果我没有误解代码的话(

  • point_cloud是(321024,3(,32批1024个点,其中有3个坐标
  • nn_idx是(321024,20(32批1024分。这些索引用于point_cloud中的索引
  • nn_idx+idx_(321024,20(,的20个邻居的索引32批1024分。这些索引用于point_cloud_flat中的索引
  • CCD_ 16最终为(321024,20,3(,与nn_idx+idx_相同,只是point_cloud_neighbors是它们的3个坐标,而nn_idx+idx_只是它们的索引

最新更新