如何将tensorflow数据集拆分为N个数据集



我有一个tensorflow数据集ds,我想把它分成N个数据集,它们的联合是原始数据集,并且它们之间不共享样本。我试着:

ds_list = [ds.shard(N,index=i) for i in range(N)]

但不幸的是,它不是随机的:每个新数据集总是从原始数据集获得相同的样本。例如,ds_list[0]的样本数为0,N,2N,3N…,而ds_list[1]将有1,N+ 1,2n + 1,3n +1…是否有任何方法将原始数据集随机细分为相同大小的数据集?

不幸的是,简单地洗牌并不能解决问题:

import tensorflow as tf
import math
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])
N=2
ds = ds.shuffle(20)
ds_list = [ds.shard(N,index=i) for i in range(N)]

for ds in ds_list:
shard_set = sorted(set(list(ds.as_numpy_iterator())))
print(shard_set)

输出:

[3, 5, 6, 8, 11, 12, 14, 15, 19, 20]
[1, 2, 4, 5, 6, 7, 8, 14, 15, 20]

一样:

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 ,15, 16, 17, 18, 19, 20])
N=2
ds_list = []
ds = ds.shuffle(20)
size = ds.__len__()
sub = math.floor(size/N)
for n in range(N):
ds_sub = ds.take(sub)
remainder = ds.skip(sub)
ds_list.append(ds_sub)
ds = remainder  
for ds in ds_list:
shard_set = sorted(set(list(ds.as_numpy_iterator())))
print(shard_set)

也许(对于N个分片):

ds_list = []
ds = ds.shuffle()
size = ds.__len__()
sub = floor(size/N)
for n in range(N):
ds_sub = ds.take(sub)
remainder = ds.skip(sub)
ds_list.append(ds_sub)
ds = remainder  

您可以先对数据集进行洗牌,然后再对其进行分片:

ds = ds.shuffle(buffer_size)
ds_list = [ds.shard(N,index=i) for i in range(N)]

这里buffer_size是TF用于排序的缓冲区的大小。如果数据集的大小较小,可以将总样例数作为buffer_size传入。否则,一个更小的数字(比如100),可以放在内存中,就可以工作了。

相关内容

  • 没有找到相关文章

最新更新