我正在尝试加载MNIST数据集,但我得到了
TypeError:元组索引必须是整数或切片,而不是str
这是我的代码:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
mnist_dataset = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
这行给了我错误:
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
如果包含with_info=True
,则需要相应地解压缩:
mnist_dataset, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
按照您的做法,mnist_dataset
是一个元组,包含一个2项字典和一个tfds.core.DatasetInfo
对象:
(
{
'test': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>,
'train': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>
},
tfds.core.DatasetInfo(name='mnist', etc)
)