The problem with tf.keras.datasets.imdb.load_data


from keras.datasets import imdb
(X_train, y_train), (X_test, y_test) = imdb.load_data(path="imdb.npz",
num_words=10,
skip_top=0,
maxlen=None,
seed=42,
start_char=1,
oov_char=2,
index_from=3)
s = set()
for i in X_train:
s.update(np.unique(i))
print(s)
>{1, 2, 4, 5, 6, 7, 8, 9}

我的问题是,我们已经将num_words指定为10,并且我们想要10个最频繁的单词。但在集合s中,1=序列的开始,2=oov_char,因此我们只有剩下的六个索引(4到9(表示单词。

当我们指定num_words必须是10时,为什么我们只得到六个最频繁的?此外,有人能举例说明index_from的含义以及X_train中的索引是如何分配的吗?

您使用了imdb.load_data((的默认输入。我认为返回的不是单词的数量,而是索引的数量。如果您更改如下所示的默认输入模式,您将看到返回了10个索引。

from keras.datasets import imdb
(X_train, y_train), (X_test, y_test) = imdb.load_data(
path="imdb.npz",
num_words=10,
skip_top=0,
maxlen=None,
seed=42,
start_char=0,
#oov_char="OOV",
index_from=0,
)
Set = set()
for i in X_train:
Set.update(np.unique(i))
print(Set)
print(len(Set))

输出:

{'8', '6', '0', '7', '5', '2', '4', '9', '1', '3'} # + 'OOV' if uncomment oov_char="OOV"
10 # will be 11, if uncomment oov_char="OOV"

在这个数据库中,num_words通常为10000,这与它是完全相同的数字还是稍大或稍小无关。

最新更新