练习作业 AWS 计算机视觉:get_Cifar10_dataset



我对这种方法有问题,它应该同时返回训练数据集和验证数据集,并对其进行检查以返回对应于CIFAR10中每个类第一次出现的索引。

这是代码:def get_cifar10_dataset(): """应创建 cifar 10 网络并识别每个新类第一次的数据集索引 出现

:return: tuple of training and validation dataset as well as label indices
:rtype: (gluon.data.Dataset, 'dict_values' object is not subscriptable, gluon.data.Dataset, 
dict[int:int])
"""
train_data = None
val_data = None
# YOUR CODE HERE
train_data = datasets.CIFAR10(train=True, root=M5_IMAGES)
val_data = datasets.CIFAR10(train=False, root=M5_IMAGES)

系统会要求您返回一个包含标签和相应索引的字典。使用以下函数可以解决您的问题。

def get_idx_dict(data):
lis = []
idx = []
indices = {}

for i in range(len(data)):
if data[i][1] not in lis:
lis.append(data[i][1])
idx.append(i)

indices = {lis[i]: idx[i] for i in range(len(lis))}
return indices

该函数返回具有所需输出的字典。对来自训练和验证集的数据使用此函数。

train_indices = get_idx_dict(train_data)
val_indices = get_idx_dict(val_data)

你可以这样做

def get_cifar10_dataset():
"""
Should create the cifar 10 network and identify the dataset index of the first time each new class appears

:return: tuple of training and validation dataset as well as label indices
:rtype: (gluon.data.Dataset, dict[int:int], gluon.data.Dataset, dict[int:int])
"""
train_data = None
val_data = None
train_indices = {}
val_indices = {}

# Use `root=M5_IMAGES` for your dataset
train_data = gluon.data.vision.datasets.CIFAR10(train=True, root=M5_IMAGES)
val_data   = gluon.data.vision.datasets.CIFAR10(train=False, root=M5_IMAGES)

#for train
for i in range(len(train_data)):
if train_data[i][1] not in train_indices:
train_indices[train_data[i][1]] = i
#for valid
for i in range(len(val_data)):
if val_data[i][1] not in val_indices:
val_indices[val_data[i][1]] = i

#raise NotImplementedError()

return train_data, train_indices, val_data, val_indices

最新更新