Tensorflow数据集,如何在每个批次内连接/重复数据?



如果我有以下数据集:dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

当我使用batch_size=2时,我将得到[[1,2], [3,4], [5,6]]

但是,我希望得到以下输出:[[1,2,1,2], [3,4,3,4], [5,6,5,6]]

基本上,我想将批处理维度重复2x,并将其用作新批处理。很明显,这是一个简单的例子。在实际情况中,如果我有一个大小为(64, 300)的批,我想创建一个大小为(128, 300)的批。

可以通过定义一个map函数

def double_input(x):
x = tf.concat([x,x],axis=0)
return x
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
dataset = dataset.batch(2)
dataset = dataset.map(double_input)
for x in dataset.take(-1):
print(x)
>>>tf.Tensor([1 2 1 2], shape=(4,), dtype=int32)
>>>tf.Tensor([3 4 3 4], shape=(4,), dtype=int32)
>>>tf.Tensor([5 6 5 6], shape=(4,), dtype=int32)

相关内容

  • 没有找到相关文章

最新更新