我正在将一批图像传递给我的神经网络。假设批次的形状是(4, 224,224,3)
.现在我想对我的批处理应用切片操作,以便我可以分别获得两个形状为(2,224,224,3)
的张量。如何使用tf.slice()
或类似的东西来做到这一点?
我认为您更想使用tf.split
.例如,在您的情况下,
tf.split(my_tensor, 2)