有人能澄清一下_getitem_函数中发生了什么吗?谢谢



我知道输出包含所有编码、令牌类型id、attention\ymask和作为张量的相应标签。我想了解getitem函数的内部工作原理,以及使用len功能获取标签长度的必要性。

class NewsGroupsDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
item["labels"] = torch.tensor([self.labels[idx]])
return item
def __len__(self):
return len(self.labels)
# convert our tokenized data into a torch Dataset
train_dataset = NewsGroupsDataset(train_encodings, train_labels)
valid_dataset = NewsGroupsDataset(valid_encodings, valid_labels)

Python为类定义了许多特殊的方法。这些方法定义了类在某些情况下的行为。您可能已经熟悉__init__特殊方法,该方法在创建类的新实例时被调用。__getitem__是另一种特殊方法,当您对类的实例(即方括号[](使用订阅时会调用它,当您使用将类的实例传递给内置的len函数时会调用__len__

至于Pytorch,我们必须实现这些方法,因为这正是Pytorch的DataLoader对象所期望的。它使用这些方法对数据集进行采样,并知道何时完成对数据集的采样。尽管DataLoader使用了许多抽象来支持不同的采样和多进程操作,但它基本上需要__len__才能知道它可以从数据集中查询的最大索引,并且它使用__getitem__来对它需要的索引进行采样。

例如,当您使用0个工人的随机采样时,以下片段有效地完成了相同的

from torch.utils.data import DataLoader
train_dataset = NewsGroupsDataset(train_encodings, train_labels)
data_loader = DataLoader(train_dataset, batch_size=5, shuffle=True)
for items in data_loader:
# items now contains batches of samples of size 5 from your dataset

# For demonstration purposes only, do NOT sample your datasets like this (use DataLoader)!
import random
from torch.utils.data import default_collate
def random_batches(dataset, batch_size, shuffle):
indices = list(range(len(dataset)))  # uses Dataset.__len__
if shuffle:
random.shuffle(indices)
batch = []
for i in indices:
batch.append(dataset[i])  # uses Dataset.__getitem__
if len(batch) == batch_size:
yield default_collate(batch)
batch = []
if batch:
yield default_collate(batch)
train_dataset = NewsGroupsDataset(train_encodings, train_labels)
for items in random_batches(train_dataset, batch_size=5, shuffle=True):
# items now contains batches of samples of size 5 from your dataset

注意,default_collate是一个获取样本列表并将其转换为批量大小张量堆栈的函数。如果您对细节感兴趣,可以在这里找到实现。

DataLoader还支持许多其他很酷的东西,比如多个worker(可能是最重要的(、自定义采样方案、自定义数据排序、固定内存、丢弃最后一个非完整批等等。Pytorch为您完成了此类的大部分工作,您只需要使用__len____getitem__实现编写数据集对象。

__getitem__是一个用于从调用实例的属性中获取项的方法。__getitem__主要与listtuple等容器配合使用。

class Example:
def __init__(self, item):
self.item = item
def __getitem__(self, index):
return self.item[index]

e = Example([1, 2, 3])
print(f"First item: {e[0]}")
# First item: 1
print(f"Second item: {e[1]}")
# Second item: 2
print(f"Third item: {e[2]}")
# Third item: 3

最新更新