如何在pytorch上加载预批处理的数据集



我有一个巨大的数据集,无法存储在内存中,所以我预先匹配了几个文件。如何使数据集和数据加载器类一次加载一个浴。

  • 所有文件都有相同的基本名称和唯一的批号
  • 示例文件将被称为o3_batch_1.hdf5或o3_batch_2.hdf5
  • 最大批号为o3_batch_102.hdf5

以下是我迄今为止所尝试的:

它行得通吗?length将是数据的总长度。

batchNum将是文件末尾的非唯一数字。

base是文件共享的通用名称。

类数据(数据集(:

# Constructor
def __init__(self, base, batchNum, length):
name = base + str(batchNum) 
with h5py.File(name, "r") as f:
puzz = np.array(f.get('puzzle'))
sol = np.array(f.get('Sol'))
self.puzz = torch.from_numpy(puzz)
self.sol = torch.from_numpy(sol)
self.len = length

# Getter
def __getitem__(self, batchNum, index):    
return self.puzz[index], self.sol[index]
# Get length
def __len__(self):
return self.len 

我认为您可以对Index数组进行迭代,并且可以通过迭代获得数据。

假设您的文件以以下方式组织

/yourFileDir 
o3_batch_1.hdf5
o3_batch_2.hdf5
...
o3_batch_102.hdf5

您的批次索引为0,1,2,。。。,102

h5_dir = '/yourFileDir'
for Index in range(103):
with h5py.File(h5_dir + 'o3_batch_{}'.format(Index), 'r') as f:
puzz = np.array(f['puzzle'])
sol = np.array(f['Sol']) # this depends on how you save your data

最新更新