从pytorch数据集返回索引:更改__getitem__的函数导致元类冲突



我有多个类(用于不同的数据集(,它们继承自pytorch的Dataset类。它们有一个通用的结构,就像这样:

from torch.utils.data import Dataset
class SomeDataset(Dataset):
def __init__(self, data, labels):
super(SomeDataset, self).__init__()
self.data = data
self.labels = labels
self.__name__ = 'SomeDataset'
def __getitem__(self, index):
return {'data': self.data[index], 'label': self.labels[index]}
def __len__(self):
return len(data)

最近我意识到,在批处理时跟踪传递到Dataloader的标签是有益的,所以在谷歌上搜索如何做到这一点时,我遇到了这个线程,这就是我修改代码来编写这个函数的地方:

def return_indices(dataset_class):

def __getitem__(self, index):
return {'index':1, **dataset_class.__getitem__(self, index)}
return type(dataset_class.__name__, (dataset_class, ), {'__getitem__': __getitem__})

我以前从未见过type这样使用,但经过一些谷歌搜索,它让有了一些的意义,所以我尝试了一下。不幸的是,这导致了这个错误:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases

这导致了更多的谷歌搜索,尽管我开始了解元类是什么以及它们是如何使用的,但我仍然不知道这种方法有什么问题,也不知道如何解决它——我开始认为,也许把这个功能重写到我的数据集类中会更容易,而不是有一些整洁的包装器来为我做这件事。有人能说出我遗漏的东西吗?

只需执行以下操作:

def return_indices(dataset_class):

def __getitem__(self, index):
return {'index':1, **dataset_class.__getitem__(self, index)}
metacls = type(dataset_class)
return metacls(dataset_class.__name__, (dataset_class, ), {'__getitem__': __getitem__})

发生的事情:正如您所发现的,对type的3参数调用是用Python编程创建新类的方法,而不需要";类";声明及其正文。

但CCD_ 3是;基本元类"-虽然它的实例将是普通类;硬编码";您正在创建的类的元类本身-相反,使用class语句将使Python在您正在创建类的基中搜索合适的元类。

只需使用派生类元类(如上所述,它是通过类型的单参数形式获得的,或者通过类的__class__属性获得的,就像在dataset_class.__class__中一样(。

将其作为可调用的类型来代替类型,将其自身作为元类,一切都应该正常。

NB:由于元类还有更多的机制,比如__prepare__,所以仅仅调用元类而不是type并不总是有效的——正确的通用方法包括调用types.prepare_classtypes.new_class,并有一个回调来执行类语句体中的类体执行。在大多数情况下,这是不需要的。

最新更新