我使用的是形状为(N,2,127,52(的4-D张量
我用过:
tf.keras.backend.repeat_elements(tensor, 2, axis=3)
这通过重复每个值来复制从52到104的最后一个轴大小:
shape=(N,2127104(
现在我想要相同的,但只有来自第三轴的最后10个元素现在有:
shape=(N,2127114(
我也在考虑如何添加一个额外的">列";通过在最后一个轴张量的中间添加零向量,得到:
shape=(N,2127115(
我该怎么做?
我认为使用tf.concat
是一种简单的方法:
import tensorflow as tf
N = 2
tensor = tf.random.normal((N, 2, 127, 52))
tensor = tf.repeat(tensor, 2, axis=3)
# (N, 2, 127, 114)
tensor = tf.concat([tensor, tensor[..., tf.shape(tensor)[-1]-10:]], axis=-1)
# (N, 2, 127, 115)
middle = tf.shape(tensor)[-1]//2
tensor = tf.concat([tensor[..., :middle], tf.zeros((N, 2, 127, 1)), tensor[..., middle:]], axis=-1)
print(tensor.shape)
(2, 2, 127, 115)