合并两个tensorflow数据集,尽管速度不同



我正在寻找一种将Dataset与另一个合并的方法,但只是偶尔从中提取样本。

例如,给定这两个Datasets

ds1 = tf.data.Dataset.range(1, 10).repeat()
ds10 = tf.data.Dataset.range(10, 100, 10).repeat()

我想把ds10的样本添加到ds1的样本中,但只针对每两个样本,这样结果就会是

ds = my_merge(ds1, ds10)
list(ds)
# 11, 2, 23, 4, 35, 6, 47...

这可能吗?我希望避免解决方案丢弃ds10中的样本,因为在我的情况下,这将是低效的。

EDIT生成的ds需要是Dataset,以便可以应用进一步的输入管道操作(例如批处理(。

基于跳过参数修改ds10数据集

skip = 2
pattern = np.concatenate(([0], np.ones((skip-1)))).astype(np.int64)
choice_dataset = tf.data.Dataset.from_tensor_slices((pattern)).repeat()
zeros = tf.data.Dataset.range(0,1).repeat()
ds10 = tf.data.Dataset.choose_from_datasets([ds10, zeros], choice_dataset)
#[10, 0, 20, 0, 30, 0, 40, 0, 50]

压缩并添加两个数据集值

ds = tf.data.Dataset.zip((ds1,ds10))
ds = ds.map(lambda x,y:x+y)
#[11, 2, 23, 4, 35, 6, 47, 8, 59]

检查性能,

def time_ds():
for element in ds.take(1000):
pass
def time_ds1():
for element in ds1.take(1000):
pass
%timeit time_ds() 29.3 ms ± 133 µs 
%timeit time_ds1() 23.5 ms ± 94.7 µs per loop

您可以创建自己的生成器:

import tensorflow as tf
from functools import partial
ds1_unrepeated = tf.data.Dataset.range(1, 10)  # because repeat prevents element_spec
ds1_spec = ds1_unrepeated.element_spec
ds1 = ds1_unrepeated.repeat()
ds10 = tf.data.Dataset.range(10, 100, 10).repeat()
def my_merge(iter1,iter2):
sliced_iter2 = iter(iter2)
sliced_iter1 = iter(iter1)
while True:
yield next(sliced_iter1)+next(sliced_iter2)
yield next(sliced_iter1)

ds = tf.data.Dataset.from_generator(partial(my_merge,ds1,ds10),output_signature=ds1_spec)
for element in ds:
print(element)
tf.Tensor(11, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(23, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(35, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64)
tf.Tensor(47, shape=(), dtype=int64)

编辑:我已经将其更新为一个数据集,但我认为顶部的答案更有效,这个答案只有在答案应该尽可能懒散地评估的情况下,而对输入知之甚少,即:合并可以任意复杂。

最新更新