阅读用于点云学习的动态图CNN代码时,我看到了以下片段:
idx_ = tf.range(batch_size) * num_points
idx_ = tf.reshape(idx_, [batch_size, 1, 1])
point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_) <--- what happens here?
point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)
调试线路我确保调光是
point_cloud_flat:(32768,3) nn_idx:(32,1024,20), idx_:(32,1,1)
// indices are (32,1024,20) after broadcasting
阅读tf.gather文档时,我不明白该函数对高于输入尺寸的尺寸做了什么
numpy中的等效函数是np.take
,一个简单的例子:
import numpy as np
params = np.array([4, 3, 5, 7, 6, 8])
# Scalar indices; (output is rank(params) - 1), i.e. 0 here.
indices = 0
print(params[indices])
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [0, 1, 4]
print(params[indices]) # [4 3 6]
# Vector indices; (output is rank(params)), i.e. 1 here.
indices = [2, 3, 4]
print(params[indices]) # [5 7 6]
# Higher rank indices; (output is rank(params) + rank(indices) - 1), i.e. 2 here
indices = np.array([[0, 1, 4], [2, 3, 4]])
print(params[indices]) # equivalent to np.take(params, indices, axis=0)
# [[4 3 6]
# [5 7 6]]
在您的情况下,indices
的秩高于params
,因此输出为秩(params
(+秩(indices
(-1(即2+3-1=4,即(321024,20,3((。- 1
是因为此时tf.gather(axis=0)
和axis
的秩必须为0(因此是标量(。因此,indices
以一种"奇特"的索引方式获取第一维度(axis=0
(的元素。
编辑:
简而言之,在你的情况下,(如果我没有误解代码的话(
point_cloud
是(321024,3(,32批1024个点,其中有3个坐标nn_idx
是(321024,20(32批1024分。这些索引用于point_cloud
中的索引nn_idx+idx_
(321024,20(,的20个邻居的索引32批1024分。这些索引用于point_cloud_flat
中的索引- CCD_ 16最终为(321024,20,3(,与
nn_idx+idx_
相同,只是point_cloud_neighbors
是它们的3个坐标,而nn_idx+idx_
只是它们的索引