如何提取三维图像张量的1个切片



假设我有一个64x64x64的3D图像。我还有一个向量x,它的长度是64。

我想采取这样的"argmax(x("层:

2d_image = 3d_image[:,argmax(x),:]

更精确(对于张量流(:

def extract_slice(x,3d_image):
slice_index = tf.math.argmax(x,axis=1,output_type=tf.dtype.int32) #it has to be int for indexing
return 3d_image[:,slice_index,:]

错误为:

Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor 'ArgMax_50:0' shape=(None,) dtype=int32>

参数的np.shape是:

3d图像形状为(None, 64, 64, 64, 1)

x形状为(None, 64)

切片_索引形状为(None,)

->3d_image形状的,1维度是因为它是来自数组的样本。。我认为这无关紧要

我知道没有一个形状是批量大小,这是未知的,但其他形状看起来很好。。那么问题出在哪里呢?

据我所知,这个索引看起来不是int32,但我确实把它转换成了tf.int,所以问题出在哪里??也许int32与tf.int32不同?或者我使用的索引方法在tensorflow中无效?也许它应该是这样一个函数:tf.index(image,[:,slice_index,:](。。?

谢谢!

Argmax返回1D张量。将其转换为标量:

slice_index = tf.reshape(slice_index, ())

相关内容

  • 没有找到相关文章

最新更新