我已经学会了如何在一维上切割张量。
我已经学会了如何切片一个二维张量给出一个特定值的一维张量。
都使用tf.gather()
,但我很确定我需要tf.gather_nd()
,虽然我显然使用错了。
在numpy中,我有一个5x5的2D数组,我可以通过使用np.ix_()
与行和列索引切片2x2数组(我总是需要相同的行和列索引,导致平方矩阵):
import numpy as np
a = np.array([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])
a
array([[ 1, 2, 3, 4, 5], [ 2, 1, 6, 7, 8], [ 3, 6, 1, 9, 10], [ 4, 7, 9, 1, 11], [ 5, 8, 10, 11, 1]])
a[np.ix_([1,3], [1,3])]
array([[1, 7], [7, 1]])
阅读tf.gather_nd()
文档,我认为这是在TF中做到这一点的方式,但我使用错误:
import tensorflow as tf
a = tf.constant([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])
tf.gather_nd(a, [[1,3], [1,3]])
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>
我必须这样做:
tf.gather_nd(a, [[[1,1], [1,3]],[[3,1],[3,3]]])
<tf.Tensor: shape=(2, 2), dtype=int32, numpy= array([[1, 7], [7, 1]])>
这让我掉进了另一个我不喜欢的兔子洞。我的下标向量当然要长很多
我的指标,顺便说一句,是1D整数张量本身。所以底线是,我想用与np._ix()
相同的行和列索引对a
进行切片,我的索引是这样的:
idx = tf.constant([1, 3])
# tf.gather_nd(a, indices = "something with idx")
要用长度为d的1D张量对nxn 2D数组进行切片,从而得到具有指定索引的dxd 2D数组,可以使用tf.repeat
,tf.tile
和tf.stack
:
n = 5
a = tf.constant(np.arange(n * n).reshape(n, n)) # 2D nxn array
idx = [1,2,4] # 1D tensor with length d
d = tf.shape(idx)[0]
ix_ = tf.reshape(tf.stack([tf.repeat(idx,d),tf.tile(idx,[d])],1),[d,d,2])
target = tf.gather_nd(a,ix_) # 2D dxd array
print(a)
print(target)
预期输出:
tf.Tensor(
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]], shape=(5, 5), dtype=int64)
tf.Tensor(
[[ 6 7 9]
[11 12 14]
[21 22 24]], shape=(3, 3), dtype=int64)