DataLoader按顺序返回多个值,而不是列表或元组


def __init__():
def __len__():
def __getitem__(self, idx):    
cat_cols = (self.cat_cols.values.astype(np.float32))
cont_cols = (self.cont_cols.values.astype(np.float32))
label = (self.label.astype(np.int32))
return (cont_cols[idx], cat_cols[idx], label[idx])

当我在上面的类中使用数据加载器时,我得到了索引为0、1和2的cont_cols、cat_cols和标签作为输出。而我希望他们在一起。我已经尝试返回值作为字典,但后来我有索引问题。

我必须将dataloader的输出读取为

dl = DataLoader(dataset[0], batch_size = 1)

for i, data in enumerate(dl):
if i == 0:
cont = data
if i == 1:
cat = data
if i == 2:
label = data

当前

的输出
for i, data in enumerate(dl):
print(i, data) 

0张量([[3.2800e+02, 4.8000e+01, 1.00e +03, 1.4069e+03, 4.6613e+05, 5.3300e+04,0.00000 e+00, 5.000 e+00, 1.000 e+00, 1.000 e+00, 2.000 e+00, 7.1610e+04,6.5100 e + 03, 1.3020 e + 04, 5.2080 e + 04, 2.0040 e + 03]])

1张量([2];1。1。4。2。, 17岁。, 0。2。3。, 0。4。4。1。、2。2.10。1。]])

2张量([1],dtype=torch.int32)

我想要的是数据[0],数据[1]和数据[2]访问的输出,但数据加载器只给我返回数据[0]。它首先返回cont_cols,然后是cat_cols,然后是label。

我想你在这里感到困惑,你的数据集确实可以返回元组s,但你必须以不同的方式处理它。

你的数据集被定义为:

class MyDataset(Dataset):
def __init__(self):
pass
def __len__():
pass
def __getitem__(self, idx):    
cat_cols = (self.cat_cols.values.astype(np.float32))
cont_cols = (self.cont_cols.values.astype(np.float32))
label = (self.label.astype(np.int32))
return (cont_cols[idx], cat_cols[idx], label[idx])

然后定义数据集和数据加载器。注意,这里不应该提供dataset[0],而应该提供dataset:

>>> dataset = Dataset()
>>> dl = DataLoader(dataset, batch_size=1)

然后在循环中访问数据加载器内容:

>>> for cont, cat, label in dl:
...   print(cont, cat, label)

最新更新