Getting the class probabilities p(c|x) for each data point with sklearn



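How can we get the class probabilities for each test data point? Some classifiers include a predict_proba() function, which returns the probability that a data point belongs to each class.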

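However, no such function is defined in https://sklearn-lvq.readthedocs.io/en/stable/rslvq.html#

I need to compute the probability of every class so that I can apply a reject option. The idea is to compute p(c|x) and reject the data point if that value is below a threshold.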

You can try one of the following.

Method 1: override the predict function, as shown below:

from sklearn_lvq import RslvqModel
from sklearn.utils.validation import check_is_fitted
from sklearn.utils import validation
import numpy as np


class RslvqModel_custom(RslvqModel):
    def predict(self, x):
        """Return normalized per-prototype scores for each input sample.

        With one prototype per class, as in the example below, each
        column corresponds to one class.

        Parameters
        ----------
        x : array-like, shape = [n_samples, n_features]

        Returns
        -------
        P : array, shape = (n_samples, n_prototypes)
            Scores normalized so that each row sums to 1.
        """
        check_is_fitted(self, ['w_', 'c_w_'])
        x = validation.check_array(x)
        if x.shape[1] != self.w_.shape[1]:
            raise ValueError("X has wrong number of features\n"
                             "found=%d\n"
                             "expected=%d" % (x.shape[1], self.w_.shape[1]))

        def foo(e):
            # Evaluate the model's cost function between sample e and every prototype.
            fun = np.vectorize(lambda w: self._costf(e, w),
                               signature='(n)->()')
            return fun(self.w_)

        # '(n)->(m)': n features in, m prototype scores out per sample.
        predictions = np.vectorize(foo, signature='(n)->(m)')(x)
        # Normalize each row (named total to avoid shadowing the built-in sum).
        total = np.sum(predictions, axis=1, keepdims=True)
        return predictions / total


np.random.seed(1)
nb_ppc = 100
x = np.append(
    np.random.multivariate_normal([0, 0], np.eye(2) / 2, size=nb_ppc),
    np.random.multivariate_normal([5, 0], np.eye(2) / 2, size=nb_ppc), axis=0)
y = np.append(np.zeros(nb_ppc), np.ones(nb_ppc), axis=0)
# Each initial prototype is [feature_1, feature_2, class_label].
rslvq = RslvqModel_custom(initial_prototypes=[[5, 0, 0], [0, 0, 1]])
model = rslvq.fit(x, y)
predictions = model.predict([[3.67, 6.50], [4.97, 1.49], [1.14, -4.3]])
print('============================================================')
print('Predictions: ', predictions)
print('-------------------------------------------------------------')

Output:

Predictions:  [[0.5977081  0.4022919 ]
 [0.945568   0.054432  ]
 [0.33533978 0.66466022]]
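
Given a matrix like the one printed above, the reject option described in the question comes down to a threshold on the largest value in each row. The following is a minimal sketch and not part of sklearn_lvq: the function name apply_reject_option, the threshold of 0.7 and the reject label -1 are all illustrative assumptions.

```python
def apply_reject_option(probabilities, threshold=0.7, reject_label=-1):
    """Return the best column index per row, or reject_label.

    probabilities: sequence of rows, one row per sample, one value per class.
    Hypothetical helper; threshold and reject_label are assumed values.
    """
    labels = []
    for row in probabilities:
        # Index of the largest value in this row (first one wins on ties).
        best = max(range(len(row)), key=lambda i: row[i])
        # Reject the sample when even the best class is below the threshold.
        labels.append(best if row[best] >= threshold else reject_label)
    return labels


# Values copied from the output above.
probabilities = [
    [0.5977081, 0.4022919],
    [0.945568, 0.054432],
    [0.33533978, 0.66466022],
]
print(apply_reject_option(probabilities))  # [-1, 0, -1]
```

The function returns column indices, not class labels. In the example above there is one prototype per class and the columns follow the prototype order, so index 0 stands for class 0 and index 1 for class 1; with several prototypes per class the columns would first have to be summed per class.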

Method 2: alternatively, you can simply use the built-in posterior method:

for data in [[3.67, 6.50], [4.97, 1.49], [1.14, -4.3]]:
    print('Class 0:  posterior: ', model.posterior(0, data))
    print('Class 1:  posterior: ', model.posterior(1, data))
    print('='*100)

Output:

-------------------------------------------------------------
Class 0:  posterior:  [[0.5977081]]
Class 1:  posterior:  [[0.4022919]]
========================================================
Class 0:  posterior:  [[0.945568]]
Class 1:  posterior:  [[0.054432]]
========================================================
Class 0:  posterior:  [[0.33533978]]
Class 1:  posterior:  [[0.66466022]]
========================================================
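
To threshold the per-class values from Method 2, it can help to collect the posterior calls into one row per sample. The sketch below assumes only what the loop above shows: that the model exposes posterior(label, sample) and returns a nested single value such as [[0.5977081]]. The helper names unwrap and posterior_rows and the FakeModel stand-in are hypothetical, used here so the example runs without a fitted model.

```python
def unwrap(value):
    """Reduce a nested single value such as [[0.59]] to a plain float."""
    if hasattr(value, "item"):
        # NumPy array or scalar holding exactly one element.
        return float(value.item())
    while isinstance(value, (list, tuple)):
        value = value[0]
    return float(value)


def posterior_rows(model, samples, classes):
    """Build one row of per-class posterior values for each sample."""
    return [[unwrap(model.posterior(c, s)) for c in classes] for s in samples]


class FakeModel:
    """Hypothetical stand-in that mimics the nested return shape seen above."""

    def posterior(self, label, sample):
        # Toy rule for illustration only: class 1 always gets a share of 0.25.
        share = 0.25 if label == 1 else 0.75
        return [[share]]


rows = posterior_rows(FakeModel(), [[3.67, 6.50], [4.97, 1.49]], classes=[0, 1])
print(rows)  # [[0.75, 0.25], [0.75, 0.25]]
```

With a fitted model in place of FakeModel, each row can be compared against a threshold in the same way as the Method 1 output.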
