使用StratifiedKFold在Auto Keras中训练的NumPy数组值错误



背景

我的情绪分析研究涉及各种数据集。最近我遇到了一个数据集,不知怎么的,我就是无法成功训练。我主要使用.CSV文件格式的开放数据,因此PandasNumPy被大量使用。

在我的研究过程中,其中一种方法是尝试集成自动机器学习(AutoML),我选择使用的库是Auto-Keras,主要使用其TextClassifier()包装函数来实现AutoML

主要问题

我已经用官方文档验证了TextClassifier()采用NumPy数组格式的数据。然而,当我将数据加载到Pandas DataFrame中,并在需要训练的列上使用.to_numpy()时,以下错误不断显示:


ValueError                                Traceback (most recent call last)
<ipython-input-13-1444bf2a605c> in <module>()
16     clf = ak.TextClassifier(overwrite=True, max_trials=2)
17 
---> 18     clf.fit(x_train, y_train, epochs=3, callbacks=cbs)
19 
20 
ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type float).

与错误相关的代码扇区

使用.drop()删除不需要的Pandas DataFrame列,并使用Pandas提供的to_numpy()函数将所需列转换为NumPy阵列的扇区。


df_src = pd.read_csv(get_data)
df_src = df_src.drop(columns=["Name", "Cast", "Plot", "Direction",
"Soundtrack", "Acting", "Cinematography"])
df_src = df_src.reset_index(drop=True)
X = df_src["Review"].to_numpy()
Y = df_src["Overall Sentiment"].to_numpy()
print(X, "n")
print("n", Y)

main错误代码部分,我在其中执行StratifedKFold(),同时使用TextClassifier()来训练和测试模型。


fold = 0
for train, test in skf.split(X, Y):
fold += 1
print(f"Fold #{fold}n")

x_train = X[train]
y_train = Y[train]

x_test = X[test]
y_test = Y[test]


cbs = [tf.keras.callbacks.EarlyStopping(patience=3)]

clf = ak.TextClassifier(overwrite=True, max_trials=2)


# The line where it indicated the error.
clf.fit(x_train, y_train, epochs=3, callbacks=cbs)


pred = clf.predict(x_test) # result data type is in lists of `string`

ceval = clf.evaluate(x_test, y_test)

metrics_test = metrics.classification_report(y_test, np.array(list(pred), dtype=int))

print(metrics_test, "n")

print(f"Fold #{fold} finishedn")

补充

我通过Google Colab分享了与该错误相关的完整代码,您可以在这里帮助我进行诊断。

编辑笔记

我尝试过潜在的解决方案,例如:

x_train = np.asarray(x_train).astype(np.float32)
y_train = np.asarray(y_train).astype(np.float32)

x_train = tf.data.Dataset.from_tensor_slices((x_train,))
y_train = tf.data.Dataset.from_tensor_slices((y_train,))

然而,问题依然存在。

其中一个字符串等于nan。只需删除此条目和相应的标签。

相关内容

  • 没有找到相关文章

最新更新