scikit Learn SGD分类器问题预测



我可能无法在这里找到我需要的帮助,但我希望互联网上聪明的程序员可以帮助我。我正在尝试使用Python的Sci-Kit学习SGDClassifier对物理事件进行分类。这些物理事件创建了一个轨道图像(黑色和白色),我试图让分类器对它们进行分类。这些图像大约是500 * 400像素(不太确定),但出于机器学习的目的,它给了我一个200640维向量。我在包含200个事件的数据包中序列化了20000个火车事件。然后我有额外的2000个火车事件。以下是我如何提取和训练的。

>>> from sklearn.linear_model import SGDClassifier
>>> import dill
>>> import glob
>>> import numpy as np
>>> clf = SGDClassifier(loss='hinge')
>>>for file in glob.glob('./SerializedData/Batch1/*.pkl'):
...    with open(file, 'rb') as stream:
...    minibatch = dill.load(stream)
...        clf.partial_fit(minibatch.data, minibatch.target, classes=np.classes([1, 2]))
(Some output stuff about the classifier)
>>>

这是我的代码的火车部分,或者至少是它的一个粗略的缩写。我有一个更复杂的分类器初始化。为了获得更多信息,minibatch.data给出了形状和特征的numpy数组,因此这是一个"二维numpy数组",形状为200,特征为200640。为了澄清这一点,有一些数组描述了每个图像的像素值,然后其中200个包含在一个大数组中。minibatch.target给出了每个事件的所有类值的numpy数组

然而,在20000个事件的训练之后,我试图测试分类器,它似乎根本没有被训练过:

>>> file = open('./SerializedData/Batch2/train1.pkl', 'rb')
>>> test = dill.load(file)
>>> clf.predict(test.data)
array([ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
    2,  2,  2,  2,  2])
>>> clf.score(test.data)
.484999999999999999999

可以看到,分类器预测所有测试事件的类别2。目前我能想到的唯一问题是我没有足够的测试赛事,但我觉得这很难相信。有人对这种困境有什么建议/解决方案/答案吗?

除非您的图像非常简单,否则如果您的输入是图像,那么仅使用scikit learn将无法获得良好的结果。你需要以某种方式变换图像以获得真正有用的特征,像素值产生糟糕的特征。你可以尝试使用scikit-image中的一些工具来创建更好的特征,或者你可以使用一些预训练的卷积神经网络来为你提取特征。如果你觉得更有冒险精神,你可以尝试训练CNN对你的问题进行分类。

最新更新