将PyTorch数据加载器转换为tf.Dataset



如何将 torch.utils.data.DataLoader(PyTorch 数据加载器)转换为 tf.data.Dataset?

我发现了这个片段

def convert_pytorch_dataloader_to_tf_dataset(
    dataloader, batch_size, shuffle=True, sample_shapes=((256, 512), (2,))
):
    """Wrap a PyTorch ``DataLoader`` in a batched ``tf.data.Dataset``.

    The loader is consumed batch by batch, but each batch is split back into
    individual ``(features, labels)`` samples before it reaches TensorFlow.
    That keeps the declared per-sample shapes valid and lets TensorFlow do
    the shuffling and the (re-)batching on single samples.

    Args:
        dataloader: Iterable yielding ``(features, labels)`` batches.
            Assumes both are torch tensors with the batch on axis 0 (the
            default collate behaviour) -- confirm for custom ``collate_fn``.
            Must expose ``.dataset`` with a length when ``shuffle`` is true.
        batch_size: Batch size of the resulting TensorFlow dataset.
        shuffle: Whether to shuffle samples with a buffer covering the
            whole underlying dataset.
        sample_shapes: ``(feature_shape, label_shape)`` of ONE sample,
            without the batch axis. Defaults to the shapes that were
            previously hard-coded.

    Returns:
        A ``tf.data.Dataset`` of float32 ``(features, labels)`` batches.
    """

    def generate_samples():
        # Called afresh by TensorFlow for every pass over the dataset.
        for features, labels in dataloader:
            # Torch tensors may live on an accelerator or carry autograd
            # state; move them to plain NumPy before handing them over.
            features = features.detach().cpu().numpy()
            labels = labels.detach().cpu().numpy()
            # Undo the loader's batching: emit one sample at a time.
            for sample in zip(features, labels):
                yield sample

    feature_shape, label_shape = sample_shapes
    dataset = tf.data.Dataset.from_generator(
        generate_samples,
        output_types=(tf.float32, tf.float32),
        output_shapes=(
            tf.TensorShape(feature_shape),
            tf.TensorShape(label_shape),
        ),
    )
    if shuffle:
        # Elements are single samples here, so the sample count is the
        # right buffer size for a full shuffle.
        dataset = dataset.shuffle(buffer_size=len(dataloader.dataset))
    dataset = dataset.batch(batch_size)
    return dataset

但它根本不起作用。

是否有内置的方法可以轻松地将 DataLoader 导出为 tf.data.Dataset?我的数据加载器非常复杂,所以希望有一个简单的方案,以确保转换过程不出错 :(

对于 HDF5(h5py)格式的数据,您可以使用下面的脚本。name_x 是 h5py 文件中特征数据集的名称,name_y 是标签数据集的名称。这种方法内存占用小,并且可以逐批地提供数据。

class Generator(object):
    """Serve shuffled mini-batches from two datasets of an HDF5 file.

    Samples are read from disk one at a time while a batch is assembled, so
    the whole file is never loaded into memory. ``get_dataset`` exposes the
    batches as a ``tf.data.Dataset``.
    """

    def __init__(self, open_directory, batch_size, name_x, name_y):
        """Open the HDF5 file and prepare the first shuffled epoch.

        Args:
            open_directory: Path of the HDF5 file to read.
            batch_size: Number of samples per batch (the last batch of an
                epoch may be smaller).
            name_x: Key of the feature dataset inside the file.
            name_y: Key of the label dataset inside the file.
        """
        self.open_directory = open_directory
        # Keep the handle so the file can be released through close().
        self._data_file = h5py.File(open_directory, "r")
        self.x = self._data_file[name_x]
        self.y = self._data_file[name_y]
        # Works for any rank, not just 3-D/4-D data (e.g. 2-D labels).
        self.shape_x = self._batch_shape(self.x.shape)
        self.shape_y = self._batch_shape(self.y.shape)
        self.num_samples = self.x.shape[0]
        self.batch_size = batch_size
        # Ceiling division: a trailing partial batch counts as a batch.
        self.epoch_size = (
            self.num_samples + self.batch_size - 1
        ) // self.batch_size
        self.pointer = 0
        self.sample_nums = np.arange(0, self.num_samples)
        np.random.shuffle(self.sample_nums)

    @staticmethod
    def _batch_shape(shape):
        """Return ``shape`` with its leading axis replaced by ``None``."""
        return (None,) + tuple(shape[1:])

    def data_generator(self):
        """Yield one epoch of ``(x, y)`` float32 batches.

        The sample order is reshuffled once the last sample of the epoch
        has been consumed, so consecutive epochs use different orders.
        """
        # Start from the beginning of the permutation even if a previous
        # epoch was abandoned half-way; otherwise a short batch would show
        # up in the middle of this epoch.
        self.pointer = 0
        for _ in range(self.epoch_size):
            stop = min(self.pointer + self.batch_size, self.num_samples)
            batch_ids = self.sample_nums[self.pointer:stop]
            # Build the arrays before any reshuffle below: batch_ids is a
            # view into sample_nums, which shuffle() mutates in place.
            x = np.array([self.x[i] for i in batch_ids], dtype=np.float32)
            y = np.array([self.y[i] for i in batch_ids], dtype=np.float32)
            self.pointer = stop
            if self.pointer == self.num_samples:
                self.pointer = 0
                np.random.shuffle(self.sample_nums)
            yield x, y

    def get_dataset(self):
        """Return a ``tf.data.Dataset`` over ``data_generator`` batches."""
        dataset = tf.data.Dataset.from_generator(
            self.data_generator,
            output_types=(tf.float32, tf.float32),
            output_shapes=(
                tf.TensorShape(self.shape_x),
                tf.TensorShape(self.shape_y),
            ),
        )
        # Prepare the next batch while the current one is being consumed.
        dataset = dataset.prefetch(1)
        return dataset

    def close(self):
        """Close the underlying HDF5 file; the generator is unusable after."""
        self._data_file.close()

最新更新