假设我有一个64x64x64的3D图像。我还有一个向量x,它的长度是64。
我想采取这样的"argmax(x("层:
2d_image = 3d_image[:,argmax(x),:]
更精确(对于张量流(:
def extract_slice(x,3d_image):
slice_index = tf.math.argmax(x,axis=1,output_type=tf.dtype.int32) #it has to be int for indexing
return 3d_image[:,slice_index,:]
错误为:
Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor 'ArgMax_50:0' shape=(None,) dtype=int32>
参数的np.shape是:
3d图像形状为(None, 64, 64, 64, 1)
x形状为(None, 64)
切片_索引形状为(None,)
->3d_image形状的,1维度是因为它是来自数组的样本。。我认为这无关紧要
我知道没有一个形状是批量大小,这是未知的,但其他形状看起来很好。。那么问题出在哪里呢?
据我所知,这个索引看起来不是int32,但我确实把它转换成了tf.int,所以问题出在哪里??也许int32与tf.int32不同?或者我使用的索引方法在tensorflow中无效?也许它应该是这样一个函数:tf.index(image,[:,slice_index,:](。。?
谢谢!
Argmax返回1D张量。将其转换为标量:
slice_index = tf.reshape(slice_index, ())