如何使用KerasClassifier计算损失



我正在使用来自sklearn的KerasClassifier来包装我的Keras模型,以便执行K-fold交叉验证。

model = KerasClassifier(build_fn=create_model, epochs=20, batch_size=8, verbose = 1)    
kfold = KFold(n_splits=10)
scoring = ['accuracy', 'precision', 'recall', 'f1']
results = cross_validate(estimator=model,
X=x_train,
y=y_train,
cv=kfold,
scoring=scoring,
return_train_score=True,
return_estimator=True)

然后根据度量,从函数返回的10个估计器中选择最佳模型:

best_model = results['estimators'][2] #for example the second model

现在,我想对x_test执行预测得到精度和损耗指标。我该怎么做呢?我尝试了model.evaluate(x_test, y_test),但模型是KerasClassifier,所以我得到了一个错误。

要点是您的KerasClassifier实例模仿标准的scikit-learn分类器。换句话说,它是一种scikit-learn野兽,因此,它不提供方法.evaluate()

因此,您可以调用best_model.score(X_test, y_test),它将像标准sklearn分类器一样自动返回准确性。另一方面,您可以通过KerasClassifier实例的history_属性访问训练期间获得的损失值。

下面是一个例子:

!pip install scikeras    
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split, cross_validate, KFold
import tensorflow as tf
import tensorflow.keras
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from scikeras.wrappers import KerasClassifier
X, y = make_classification(n_samples=100, n_features=20, n_informative=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
def build_nn():
ann = Sequential()
ann.add(Dense(20, input_dim=X_train.shape[1], activation='relu', name="Hidden_Layer_1"))
ann.add(Dense(1, activation='sigmoid', name='Output_Layer'))
ann.compile(loss='binary_crossentropy', optimizer= 'adam', metrics = 'accuracy')
return ann
keras_clf = KerasClassifier(model = build_nn, optimizer="adam", optimizer__learning_rate=0.001, epochs=100, verbose=0)
kfold = KFold(n_splits=10)
scoring = ['accuracy', 'precision', 'recall', 'f1']
results = cross_validate(estimator=keras_clf, X=X_train, y=y_train, scoring=scoring, cv=kfold, return_train_score=True, return_estimator=True)
best_model = results['estimator'][2]
# accuracy
best_model.score(X_test, y_test)
# loss values
best_model.history_['loss']

最后观察到,当有疑问时,您可以调用dir(object)来获取指定对象的所有属性和方法的列表(在您的示例中为dir(best_model))。

最新更新