矢量化张量流步长的正确方法



我正在处理4维张量,需要进行一些计算,这些计算与下面的示例类似。取A为形状为(6,64,64,64)的张量。我想使用函数tf.where来获得具有大于0.75的值的每个(64,64,64)体积的体素。我唯一能做到这一点的方法是这样的:

X = tf.convert_to_tensor([tf.where(A[i,:,:,:] > 0.75) for i in range(A.shape[0])]

这似乎是一个非常粗糙的解决方案。有更好的方法来实现这一点吗?

您尝试执行的操作的问题在于,它要求每个(64, 64, 64)卷具有相同数量的大于0.75的值。如果是这样的话,您可以执行以下操作:

X = tf.reshape(tf.where(A > 0.75)[:, 1:], (A.shape[0], -1, A.shape.ndims - 1))

但如果不是这样的话,你就不能有这样的张量,因为第二个维度需要有多个大小。

相关内容

最新更新