我使用scikit学习接口创建了一个自定义分类器,只是为了学习。所以,我想出了以下代码:
import numpy as np
from sklearn.utils.estimator_checks import check_estimator
from sklearn.base import BaseEstimator, ClassifierMixin, check_X_y
from sklearn.utils.validation import check_array, check_is_fitted, check_random_state
class TemplateEstimator(BaseEstimator, ClassifierMixin):
def __init__(self, threshold=0.5, random_state=None):
self.threshold = threshold
self.random_state = random_state
def fit(self, X, y):
self.random_state_ = check_random_state(self.random_state)
X, y = check_X_y(X, y)
self.classes_ = np.unique(y)
self.fitted_ = True
return self
def predict(self, X):
check_is_fitted(self)
X = check_array(X)
y_hat = self.random_state_.choice(self.classes_, size=X.shape[0])
return y_hat
check_estimator(TemplateEstimator())
这个分类器只是进行随机猜测。我尽力遵循scikit学习文档和指导方针来开发我自己的估计器。然而,我得到以下错误:
AssertionError:
Arrays are not equal
Classifier cant predict when only one class is present.
Mismatched elements: 10 / 10 (100%)
Max absolute difference: 1.
Max relative difference: 1.
x: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
y: array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
我不能确定,但我猜是随机性(即self.random_state_
(导致了错误。我使用的是sklearn版本1.0.2
。
首先要注意的是,如果使用parametrize_with_checks
和pytest
而不是check_estimator
,则可以获得更好的输出。它看起来像:
@parametrize_with_checks([TemplateEstimator()])
def test_sklearn_compatible_estimator(estimator, check):
check(estimator)
如果你用pytest运行它,你会得到以下失败测试的输出:
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_pipeline_consistency] - AssertionError:
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_train] - AssertionError
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_train(readonly_memmap=True)] - AssertionError
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_train(readonly_memmap=True,X_dtype=float32)] - AssertionError
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_classifiers_regression_target] - AssertionError: Did not raise: [<class 'ValueErr...
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_methods_sample_order_invariance] - AssertionError:
FAILED ../../../../tmp/1.py::test_sklearn_compatible_estimator[TemplateEstimator()-check_methods_subset_invariance] - AssertionError:
其中一些测试检查一些输出一致性,这与您的情况无关,因为您返回了随机值。在这种情况下,您需要设置non_deterministic
estimator tag
。其他一些测试,如check_classifiers_regression_target
,会检查您是否进行了正确的验证并引发了正确的错误,但您没有。因此,您要么需要修复此问题,要么添加no_validation
标记。另一个问题是check_classifier_train
检查您的模型是否为给定的问题提供了合理的输出。但由于您返回的是随机值,因此这些条件不满足。您可以设置poor_score
估计器标记来跳过它。
您可以将这些标签添加到您的估计器中:
class TemplateEstimator(BaseEstimator, ClassifierMixin):
...
def _more_tags(self):
return {
"non_deterministic": True,
"no_validation": True,
"poor_score": True,
}
但即便如此,如果使用scikit-learn的main
分支或夜间构建,两个测试也会失败。我认为这需要一个修复程序,我已经为它打开了一个问题(编辑:该修复程序现在与上游版本合并,将在下一个版本中提供(。您可以通过在标记中将这些测试设置为预期失败来避免这些失败。最后,你的估计器会看起来像:
import numpy as np
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.base import BaseEstimator, ClassifierMixin, check_X_y
from sklearn.utils.validation import check_array, check_is_fitted, check_random_state
class TemplateEstimator(BaseEstimator, ClassifierMixin):
def __init__(self, threshold=0.5, random_state=None):
self.threshold = threshold
self.random_state = random_state
def fit(self, X, y):
self.random_state_ = check_random_state(self.random_state)
X, y = check_X_y(X, y)
self.classes_ = np.unique(y)
self.fitted_ = True
return self
def predict(self, X):
check_is_fitted(self)
X = check_array(X)
y_hat = self.random_state_.choice(self.classes_, size=X.shape[0])
return y_hat
def _more_tags(self):
return {
"non_deterministic": True,
"no_validation": True,
"poor_score": True,
"_xfail_checks": {
"check_methods_sample_order_invariance": "This test shouldn't be running at all!",
"check_methods_subset_invariance": "This test shouldn't be running at all!",
},
}
@parametrize_with_checks([TemplateEstimator()])
def test_sklearn_compatible_estimator(estimator, check):
check(estimator)