How does a PyTorch Dataset object know it has reached the end when used in a for loop?



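I'm writing a custom PyTorch dataset. In __init__, the dataset object loads a file containing some data. In my program, however, I only want to access part of that data (to implement a train/validation split, if that helps). At first I thought this behavior was controlled by overriding __len__, but it turns out that modifying __len__ doesn't help. A simple example follows: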

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # but I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i, x in enumerate(ds):
    print(i)

The output is 0 to 9, whereas the desired behavior is 0 to 4.

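When it is used in a for loop like this, how does the dataset object know that the enumeration has finished? Any other way of achieving a similar effect is also welcome.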

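You are creating a custom dataset with the Dataset class and then enumerating it directly with a for loop. That is not how it is meant to be used. To enumerate it, you have to pass the Dataset to the DataLoader class. Your code will work fine like this: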

from torch.utils.data import Dataset, DataLoader
import torch

class NewDS(Dataset):
    def __init__(self):
        self.data = torch.randn(10, 2)  # suppose there are 10 items in the data file

    def __len__(self):
        return len(self.data) - 5  # but I only want to access the first 5 items

    def __getitem__(self, index):
        return self.data[index]

ds = NewDS()
for i in range(len(ds)):  # if you don't want to use DataLoader, loop over range(len(ds)) instead of enumerate
    print(i, ds[i])
# output
# 0 tensor([-0.2351,  1.3037])
# 1 tensor([ 0.4032, -0.2739])
# 2 tensor([-0.5687, -0.7300])
# 3 tensor([0.5418, 0.8572])
# 4 tensor([ 1.9973, -0.2939])

dl = DataLoader(ds, batch_size=1)  # pass the ds object to DataLoader
for i, x in enumerate(dl):  # now you can use enumerate
    print(i, x)
# output (each item is a batch of size 1, hence the extra dimension)
# 0 tensor([[-0.2351,  1.3037]])
# 1 tensor([[ 0.4032, -0.2739]])
# 2 tensor([[-0.5687, -0.7300]])
# 3 tensor([[0.5418, 0.8572]])
# 4 tensor([[ 1.9973, -0.2939]])

More details can be found in the official PyTorch tutorial.

You can use torch.utils.data.Subset to get a subset of the data:

top_five = torch.utils.data.Subset(ds, indices=range(5))  # get the first five items
for i, x in enumerate(top_five):
    print(i)
# output
# 0
# 1
# 2
# 3
# 4
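Since the question mentions a train/validation split, the same idea extends to two complementary subsets. The following is a minimal sketch, assuming the dataset reports its full length; `FullDS`, `train_ds` and `valid_ds` are hypothetical names, and the 8/2 split is arbitrary.

```python
import torch
from torch.utils.data import Dataset, Subset


class FullDS(Dataset):
    """Hypothetical dataset that exposes all 10 loaded items."""

    def __init__(self):
        self.data = torch.randn(10, 2)  # placeholder for the loaded file

    def __len__(self):
        return len(self.data)  # report the true length

    def __getitem__(self, index):
        return self.data[index]


full = FullDS()
train_ds = Subset(full, range(0, 8))   # items 0..7 for training
valid_ds = Subset(full, range(8, 10))  # items 8..9 for validation

print(len(train_ds), len(valid_ds))
for i, x in enumerate(valid_ds):  # stops after the two validation items
    print(i, x.shape)
```

Each Subset can be iterated directly or passed to a DataLoader. For a shuffled split, torch.utils.data.random_split may be a better fit.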
In a for loop, enumerate keeps returning items until it receives a StopIteration exception.

len(ds)  # returns the modified length
# 5

# `enumerate` calls `next` on the iterator on every pass through the loop.
# When no more data is available, a StopIteration exception is raised instead.
iter_ds = iter(ds)
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))
print(next(iter_ds))  # 11th call: StopIteration is raised because no item is left to iterate over

Output:

tensor([-1.5952, -0.0826])
tensor([-2.2254,  0.2461])
tensor([-0.8268,  0.1956])
tensor([ 0.3157, -0.3403])
tensor([0.8971, 1.1255])
tensor([0.3922, 1.3184])
tensor([-0.4311, -0.8898])
tensor([ 0.1128, -0.5708])
tensor([-0.5403, -0.9036])
tensor([0.6550, 1.6777])
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-99-7a9910e027c3> in <module>
10 print(next(iter_ds))
11 
---> 12 print(next(iter_ds))  #11th time StopIteration exception raised as no item left to iterate
StopIteration: 
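The ten items printed above follow from Python's iteration rules rather than from anything specific to PyTorch. Dataset does not define __iter__, so a for loop falls back to calling __getitem__ with 0, 1, 2, ... until an IndexError is raised, and __len__ is never consulted. Below is a minimal sketch of one way to make direct iteration stop at the shortened length. It assumes a plain Python class (`FirstFive`, a hypothetical name) standing in for the dataset, with a list in place of the tensor, so that it runs without PyTorch.

```python
class FirstFive:
    """Stand-in for the question's dataset: 10 stored items, 5 exposed."""

    def __init__(self):
        self.data = list(range(10))  # placeholder for the 10 loaded items

    def __len__(self):
        return len(self.data) - 5  # expose only the first 5 items

    def __getitem__(self, index):
        # Without this check, iteration would continue until self.data
        # itself raises IndexError, i.e. over all 10 items.
        if not 0 <= index < len(self):
            raise IndexError(index)
        return self.data[index]


ds = FirstFive()
seen = [i for i, x in enumerate(ds)]
print(seen)  # [0, 1, 2, 3, 4]
```

The same bounds check can be placed in the __getitem__ of a torch.utils.data.Dataset subclass, at the cost of rejecting negative indices in this form.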
