加载张量流数据集时的"Tuple indices must be integers or slices, not str"



我正在尝试加载MNIST数据集,但我得到了

TypeError:元组索引必须是整数或切片,而不是str

这是我的代码:

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
mnist_dataset = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

这行给了我错误:

mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

如果包含with_info=True,则需要相应地解压缩:

mnist_dataset, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

按照您的做法,mnist_dataset是一个元组,包含一个2项字典和一个tfds.core.DatasetInfo对象:

(
{
'test': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>,
'train': <PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>
},

tfds.core.DatasetInfo(name='mnist', etc)
)

相关内容

最新更新