我有一个关于新数据集API(tensorflow 1.4(的问题。我有两个数据集,我需要创建一个组合的不平衡数据集,即每个批次应包含来自第一个数据集的一定数量的元素和来自第二个数据集的一定数量的元素。例如
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([1,1,1,1,1,1]
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([2,2,2,2,2,2]))
假设批量大小为 4,我希望组合数据集中的批次看起来像 [1,1,1,2]。我知道如何使用 zip 和 flat_map 生成平衡的数据集但我对这个不知所措。
提前感谢!
为了解决这个问题,我的解决方案是单独批处理数据集,压缩它们,然后在生成的数据集上映射tf.concat
运算符。
在您的示例中,它将给出类似以下内容(我将第二个数据集重命名为dataset2
(:
def concat(*tensor_list):
return tf.concat(tensor_list, axis=0)
zipped_ds = tf.data.Dataset.zip((dataset1.batch(3), dataset2))
unbalanced_ds = zipped_ds.map(concat)
如果数据集是张量的嵌套结构,则可以使用以下版本的 concat
:def concat(*ds_elements):
#Create one empty list for each component of the dataset
lists = [[] for _ in ds_elements[0]]
for element in ds_elements:
for i, tensor in enumerate(element):
#For each element, add all its component to the associated list
lists[i].append(tensor)
#Concatenate each component list
return tuple(tf.concat(l, axis=0) for l in lists)
如果所有数据集元素(要合并的数据集的一部分(都是张量,只有最外层维度(相对批量大小(不同,则有效。它为数据集元素的每个组件构建一个列表,并将这些组件彼此独立地连接起来。
处理一级嵌套。如果你需要更多,你可以使用递归去解压缩嵌套的巢穴,但它可能会给出一个不太干净的计算图......