使用Dataset和.batch后的爆炸张量



我有一个形状为(100,4,30)的numpy数组。这表示长度为30的编码的4个样本中的100个样本。每行的4个样本是相关的。

我想获得一个TensorFlow数据集,批处理,其中相关样本在同一批处理中。

I'm try to do:

首先,使用np.vsplit获取长度为100的列表,其中列表中的每个元素都是4个相关样本的列表。

现在,如果我在这个列表的列表上调用tf.data.Dataset.from_tensor_slices(...).batch(1),我得到一个包含形状为(4,1,30)张量的批。

我希望这批包含4个形状为(1,30)的张量。

我怎样才能做到这一点?

我可能误解了你的意思,但是如果你省略了"vsplit":

data = np.zeros((100, 4, 30))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(1)
for element in data_ds.take(1):
print(element.shape)

你将得到:

(1, 4, 30)

(因此一个批次包含所有4个相关编码)。

如果你真的希望批处理中的尺寸是(1,30)的4倍,你可以这样做:

data = np.expand_dims(data, axis=2)

编辑:

我想我明白你的问题了。你希望每批有4个元素这些是相关的编码吗?您可以这样做:
data = np.swapaxes(data, 0, 1)
data = np.reshape(data, (100*4, -1))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(4)

相关内容

  • 没有找到相关文章

最新更新