我有两个 numpy 数组,X_train 和 Y_train,其中第一个维度 (700,1000) 由值 0、1、2、3、4 和 10 填充。维度的第二个(700)由值"新鲜"或"烂"填充,因为我正在使用烂番茄的API。出于某种原因,当我执行时:
nb = MultinomialNB()
nb.fit(X_train, Y_train)
我得到:
ValueError: Unknown label type
我尝试构建一对较小的数组:
print xs, 'n', ys
给
[[0 0 0 0 1]
[1 0 0 2 5]
[3 2 5 5 0]
[3 2 0 0 1]
[1 5 1 0 0]]
['rotten' 'fresh' 'fresh' 'rotten' 'fresh']
多项式 NB 拟合不会给出未知标签误差。关于为什么会发生这种情况的任何想法?
我还检查了X_train中的唯一值,Y_train numpy.unique,似乎没有任何奇怪或输入错误的标签——它都是"新鲜"或"烂"。
我用于生成X_train和Y_train的代码:
def make_xy(critics, vectorizer=None):
stext = critics['quote'].tolist() # need to have a list
if vectorizer == None:
vectorizer = CountVectorizer(min_df=0)
vectorizer.fit(stext)
X = vectorizer.transform(stext).toarray() # this is X
Y = np.asarray(critics['fresh'])
return X[0:1000,0:1000], Y[0:1000] # this is X_train, Y_train
其中"批评者"是从CSV文件(https://www.dropbox.com/s/0lu5oujfm483wtr/critics.csv)导入的熊猫数据帧,并清除了任何缺失的数据:
critics = pd.read_csv('critics.csv')
critics = critics[~critics.quote.isnull()]
critics = critics[critics.fresh != 'none']
critics = critics[critics.quote.str.len() > 0]
问题似乎是 y 的 dtype。 看起来 Numpy 没有设法弄清楚它是一个字符串。 所以它被设置为一个通用对象。如果更改:
Y = np.asarray(critics['fresh'])
Y = np.asarray(critics['fresh'], dtype="|S6")
我认为它应该有效。
我也遇到了同样的问题。Numpy 有时无法检测数组的数据类型。所以,我们明确地给出它。这是 NumPy 的所有类型的文档。根据您的要求选择数据类型,并将其作为"dtype="属性提供。