在这种情况下,我应该使用哪个分类器或ML SDK



训练数据(包括训练集和验证集)约有80百万个样本,每个样本都有200密集浮点。有6标记的classe,它们是不平衡的。

在常用的ML库(例如,libsvmscikit-learnSpark MLlibrandom forestXGBoost或其他)中,我应该使用哪一个?关于硬件配置,该机器具有24 CPU内核和250 Gb内存。

我建议使用scikit learn的SGD分类器,因为它是在线的,这样你就可以将训练数据分块(小批量)加载到内存中,并逐渐训练分类器,这样你不需要将所有数据加载到内存。

它具有高度并行性,易于使用。您可以将warm_start参数设置为True,并在每个X,y块加载到内存时多次调用fit,或者使用partial_fit方法作为更好的选项。

clf = SGDClassifier(loss='hinge', alpha=1e-4, penalty='l2', l1_ratio=0.9, learning_rate='optimal', n_iter=10, shuffle=False, n_jobs=10, fit_intercept=True)
# len(classes) = n_classes
all_classes = np.array(set_of_all_classes)
while True:
    #load a minibatch from disk into memory
    X, y = load_next_chunk()
    clf.partial_fit(X, y, all_classes) 
X_test, y_test = load_test_data()    
y_pred = clf.predict(X_test)

相关内容

  • 没有找到相关文章

最新更新