切片一个2D张量,类似于numpy np.ix_



我已经学会了如何在一维上切割张量。

我已经学会了如何切片一个二维张量给出一个特定值的一维张量。

都使用tf.gather(),但我很确定我需要tf.gather_nd(),虽然我显然使用错了。

在numpy中,我有一个5x5的2D数组,我可以通过使用np.ix_()与行和列索引切片2x2数组(我总是需要相同的行和列索引,导致平方矩阵):

import numpy as np
a = np.array([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])
a
array([[ 1,  2,  3,  4,  5],
[ 2,  1,  6,  7,  8],
[ 3,  6,  1,  9, 10],
[ 4,  7,  9,  1, 11],
[ 5,  8, 10, 11,  1]])
a[np.ix_([1,3], [1,3])]
array([[1, 7],
[7, 1]])

阅读tf.gather_nd()文档,我认为这是在TF中做到这一点的方式,但我使用错误:

import tensorflow as tf
a = tf.constant([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])
tf.gather_nd(a, [[1,3], [1,3]])
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>

我必须这样做:

tf.gather_nd(a, [[[1,1], [1,3]],[[3,1],[3,3]]])
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 7],
[7, 1]])>

这让我掉进了另一个我不喜欢的兔子洞。我的下标向量当然要长很多

我的指标,顺便说一句,是1D整数张量本身。所以底线是,我想用与np._ix()相同的行和列索引对a进行切片,我的索引是这样的:

idx = tf.constant([1, 3])
# tf.gather_nd(a, indices = "something with idx")

要用长度为d的1D张量对nxn 2D数组进行切片,从而得到具有指定索引的dxd 2D数组,可以使用tf.repeat,tf.tiletf.stack:

n = 5
a = tf.constant(np.arange(n * n).reshape(n, n)) # 2D nxn array
idx = [1,2,4] # 1D tensor with length d
d = tf.shape(idx)[0]
ix_ = tf.reshape(tf.stack([tf.repeat(idx,d),tf.tile(idx,[d])],1),[d,d,2])
target = tf.gather_nd(a,ix_) # 2D dxd array
print(a)
print(target)

预期输出:

tf.Tensor(
[[ 0  1  2  3  4]
[ 5  6  7  8  9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]], shape=(5, 5), dtype=int64)
tf.Tensor(
[[ 6  7  9]
[11 12 14]
[21 22 24]], shape=(3, 3), dtype=int64)

相关内容

  • 没有找到相关文章

最新更新