TensorFlow:将tf.Dataset转换为tf.Tensor



我想生成10:范围的窗口

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)

我想在这个数据集上训练我的模型。

为此,必须将这些窗口转换为张量。但是这些窗口的数据类型不能通过tf.convert_to_tensor转换为张量。做tf.convert_to_tensor(list(window))是可能的,但这是非常低效的。

有人知道如何有效地将tf.VariantDataset转换为tf.Tensor吗?

谢谢你的帮助!

如果你想创建一个滑动窗口的张量,那么通过数据集进行创建并不是最好的方法,效率和灵活性要低得多。我不认为有合适的操作,但对于2D和3D阵列有两个类似的操作,tf.image.extract_patchestf.extract_volume_patches。您可以重塑1D数据以使用它们:

import tensorflow as tf
a = tf.range(10)
win_size = 5
stride = 1
# Option 1
a_win = tf.image.extract_patches(tf.reshape(a, [1, -1, 1, 1]),
sizes=[1, win_size, 1, 1],
strides=[1, stride, 1, 1],
rates=[1, 1, 1, 1],
padding='VALID')[0, :, 0]
# Option 2
a_win = tf.extract_volume_patches(tf.reshape(a, [1, -1, 1, 1, 1]),
ksizes=[1, win_size, 1, 1, 1],
strides=[1, stride, 1, 1, 1],
padding='VALID')[0, :, 0, 0]
# Print result
print(a_win.numpy())
# [[0 1 2 3 4]
#  [1 2 3 4 5]
#  [2 3 4 5 6]
#  [3 4 5 6 7]
#  [4 5 6 7 8]
#  [5 6 7 8 9]]

相关内容

  • 没有找到相关文章

最新更新