我有一个形状为(100,4,30)的numpy数组。这表示长度为30的编码的4个样本中的100个样本。每行的4个样本是相关的。
我想获得一个TensorFlow数据集,批处理,其中相关样本在同一批处理中。
I'm try to do:
首先,使用np.vsplit
获取长度为100的列表,其中列表中的每个元素都是4个相关样本的列表。
现在,如果我在这个列表的列表上调用tf.data.Dataset.from_tensor_slices(...).batch(1)
,我得到一个包含形状为(4,1,30)张量的批。
我希望这批包含4个形状为(1,30)的张量。
我怎样才能做到这一点?
我可能误解了你的意思,但是如果你省略了"vsplit":
data = np.zeros((100, 4, 30))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(1)
for element in data_ds.take(1):
print(element.shape)
你将得到:
(1, 4, 30)
(因此一个批次包含所有4个相关编码)。
如果你真的希望批处理中的尺寸是(1,30)的4倍,你可以这样做:
data = np.expand_dims(data, axis=2)
编辑:
我想我明白你的问题了。你希望每批有4个元素这些是相关的编码吗?您可以这样做:data = np.swapaxes(data, 0, 1)
data = np.reshape(data, (100*4, -1))
data_ds = tf.data.Dataset.from_tensor_slices(data).batch(4)