How to get n-best predictions from an sklearn Naive Bayes classifier? - Python



The example at http://scikit-learn.org/stable/modules/naive_bayes.html uses a multinomial Naive Bayes classifier and outputs only the single best target label.

How can I get the n-best results and their corresponding probabilities from a classifier fitted with sklearn's classifier.fit()?

I have tried this, but it only gives the single best target label:

from sklearn.naive_bayes import MultinomialNB
from sklearn import datasets

iris = datasets.load_iris()
mnb = MultinomialNB()
y_pred = mnb.fit(iris.data, iris.target).predict(iris.data)
print(y_pred)
print("Number of mislabeled points : %d" % (iris.target != y_pred).sum())

I also tried this:

mnb.fit(iris.data, iris.target).predict_proba(iris.data)

It outputs something that looks like an n-best list, but are the labels in sorted order, i.e. the first column is 0, the second is 1, and the third is 2? If so, do my iris.data and iris.target need to be sorted before fitting?

[[ 0.75203199  0.16090571  0.08706229]
 [ 0.68449076  0.19961428  0.11589496]
 [ 0.71655395  0.18031248  0.10313357]
 [ 0.66789673  0.20853796  0.12356531]
 [ 0.75923862  0.15641199  0.08434939]
 [ 0.71232842  0.18631933  0.10135225]
 [ 0.69283061  0.19413936  0.11303003]
 [ 0.72230231  0.1784813   0.09921639]
 [ 0.64908034  0.21795677  0.13296289]
 [ 0.71016301  0.18515188  0.10468511]
 [ 0.7706877   0.15071277  0.07859954]
 [ 0.69872772  0.19198923  0.10928305]
 [ 0.70928234  0.18501507  0.10570259]
 [ 0.73423744  0.16829862  0.09746394]
 [ 0.84434639  0.10536809  0.05028552]
 [ 0.80368141  0.13133624  0.06498235]
 [ 0.76946455  0.15105144  0.07948401]
 [ 0.7255851   0.17657409  0.09784082]
 [ 0.73970806  0.17070375  0.08958819]
 [ 0.74545506  0.16522966  0.08931528]
 [ 0.70811072  0.18854972  0.10333956]
 [ 0.70648697  0.18818499  0.10532804]
 [ 0.7969896   0.13199963  0.07101077]
 [ 0.58777843  0.25618981  0.15603175]
 [ 0.64951166  0.22132827  0.12916007]
 [ 0.65586031  0.21701444  0.12712524]
 [ 0.647116    0.22224047  0.13064353]
 [ 0.74167829  0.16758828  0.09073344]
 [ 0.74468171  0.16548141  0.08983688]
 [ 0.66886376  0.20873516  0.12240109]
 [ 0.66014838  0.21396507  0.12588655]
 [ 0.68161377  0.20339924  0.11498699]
 [ 0.82411012  0.11766438  0.0582255 ]
 [ 0.83205074  0.11304195  0.05490731]
 [ 0.71016301  0.18515188  0.10468511]
 [ 0.74308517  0.1652924   0.09162243]
 [ 0.77963634  0.14494298  0.07542068]
 [ 0.71016301  0.18515188  0.10468511]
 [ 0.67897883  0.20084314  0.12017803]
 [ 0.72628321  0.17640557  0.09731122]
 [ 0.73635517  0.16968093  0.09396391]
 [ 0.55384513  0.26927796  0.17687691]
 [ 0.70420272  0.18658505  0.10921223]
 [ 0.59635556  0.25075085  0.15289358]
 [ 0.65439735  0.21981277  0.12578988]
 [ 0.64951064  0.21917004  0.13131933]
 [ 0.75711005  0.1585194   0.08437054]
 [ 0.69691278  0.19168277  0.11140445]
 [ 0.76716311  0.15262537  0.08021152]
 [ 0.72544435  0.17629024  0.09826541]
 [ 0.05616304  0.51141127  0.4324257 ]
 [ 0.05045622  0.50190524  0.44763854]
 [ 0.03940412  0.50636507  0.4542308 ]
 [ 0.04710523  0.48480988  0.46808489]
 [ 0.03801645  0.49775392  0.46422963]
 [ 0.04556145  0.49250736  0.4619312 ]
 [ 0.03966245  0.49873204  0.46160551]
 [ 0.10701789  0.47535669  0.41762542]
 [ 0.0536651   0.50530072  0.44103419]
 [ 0.05251353  0.48336562  0.46412085]
 [ 0.07633356  0.47948527  0.44418117]
 [ 0.05056292  0.49342004  0.45601704]
 [ 0.07291747  0.49453526  0.43254727]
 [ 0.03955223  0.49594088  0.46450688]
 [ 0.08936249  0.48762276  0.42301475]
 [ 0.06195976  0.50612926  0.43191098]
 [ 0.03831622  0.48825946  0.47342431]
 [ 0.0863508   0.49450781  0.41914139]
 [ 0.02709303  0.48373491  0.48917206]
 [ 0.07519686  0.49028178  0.43452135]
 [ 0.02440995  0.4828392   0.49275084]
 [ 0.07042286  0.49642936  0.43314778]
 [ 0.02455796  0.48683709  0.48860495]
 [ 0.04880578  0.49984347  0.45135075]
 [ 0.06389132  0.50205774  0.43405094]
 [ 0.05742052  0.50419335  0.43838613]
 [ 0.03993759  0.50401099  0.45605142]
 [ 0.02528131  0.49378132  0.48093736]
 [ 0.03918748  0.49255696  0.46825556]
 [ 0.1206129   0.48274379  0.39664332]
 [ 0.0748217   0.48808052  0.43709778]
 [ 0.09117508  0.48671672  0.4221082 ]
 [ 0.07675253  0.49259019  0.43065728]
 [ 0.01952175  0.47919808  0.50128017]
 [ 0.0367986   0.48532662  0.47787478]
 [ 0.04575051  0.49665456  0.45759493]
 [ 0.04377224  0.50435344  0.45187432]
 [ 0.04137786  0.49450702  0.46411511]
 [ 0.06664472  0.49270322  0.44065206]
 [ 0.05283801  0.4871102   0.46005178]
 [ 0.04794017  0.4900128   0.46204703]
 [ 0.04505918  0.49770389  0.45723693]
 [ 0.0676616   0.49295412  0.43938428]
 [ 0.1033085   0.47658754  0.42010396]
 [ 0.05234803  0.49047033  0.45718164]
 [ 0.07213391  0.49516342  0.43270267]
 [ 0.0598313   0.49357345  0.44659525]
 [ 0.06145081  0.49977467  0.43877452]
 [ 0.12643795  0.47099626  0.40256579]
 [ 0.06070342  0.4924782   0.44681837]
 [ 0.0042955   0.43723717  0.55846733]
 [ 0.01234151  0.46132786  0.52633063]
 [ 0.00804196  0.46804151  0.52391653]
 [ 0.01220374  0.4719262   0.51587006]
 [ 0.00663872  0.45347587  0.53988541]
 [ 0.00526335  0.46448324  0.53025341]
 [ 0.01877963  0.45982004  0.52140033]
 [ 0.0088868   0.47816164  0.51295156]
 [ 0.00897388  0.46607468  0.52495144]
 [ 0.00578841  0.45971014  0.53450145]
 [ 0.01677457  0.48014283  0.50308259]
 [ 0.01204406  0.46873836  0.51921758]
 [ 0.01020869  0.46951911  0.5202722 ]
 [ 0.0100408   0.45127554  0.53868367]
 [ 0.00649696  0.43748008  0.55602296]
 [ 0.00930427  0.46015047  0.53054526]
 [ 0.01456352  0.47933745  0.50609903]
 [ 0.00702627  0.47944937  0.51352435]
 [ 0.00253012  0.4388473   0.55862257]
 [ 0.01793058  0.47433015  0.50773927]
 [ 0.00764407  0.46221754  0.53013839]
 [ 0.01270343  0.45832067  0.52897589]
 [ 0.00508153  0.46485081  0.53006765]
 [ 0.01829021  0.47785801  0.50385178]
 [ 0.01030084  0.47164088  0.51805829]
 [ 0.01305643  0.48854867  0.4983949 ]
 [ 0.02048499  0.47963663  0.49987838]
 [ 0.02097735  0.48087632  0.49814633]
 [ 0.00771013  0.45561321  0.53667666]
 [ 0.01775039  0.49657519  0.48567443]
 [ 0.00863886  0.47567646  0.51568469]
 [ 0.01213462  0.49867869  0.48918669]
 [ 0.00669855  0.45021395  0.5430875 ]
 [ 0.02528409  0.49048657  0.48422934]
 [ 0.01705235  0.48108164  0.50186601]
 [ 0.00592661  0.46457057  0.52950282]
 [ 0.00712125  0.45269329  0.54018547]
 [ 0.01514248  0.47987549  0.50498204]
 [ 0.02213426  0.48045724  0.4974085 ]
 [ 0.0119365   0.47535905  0.51270445]
 [ 0.00646304  0.45242581  0.54111115]
 [ 0.01131574  0.4700604   0.51862386]
 [ 0.01234151  0.46132786  0.52633063]
 [ 0.00643277  0.45678109  0.53678614]
 [ 0.00587581  0.45021579  0.5439084 ]
 [ 0.00947358  0.46237113  0.52815529]
 [ 0.01310372  0.46704681  0.51984947]
 [ 0.01380969  0.47411641  0.5120739 ]
 [ 0.00933287  0.45973346  0.53093367]
 [ 0.01733739  0.4748122   0.50785041]]
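As an aside on the column-ordering question: the fitted classifier exposes a `classes_` attribute that gives the column order of `predict_proba` directly. A minimal sketch (continuing the iris example, not from the original post):

```python
from sklearn.naive_bayes import MultinomialNB
from sklearn import datasets

iris = datasets.load_iris()
mnb = MultinomialNB().fit(iris.data, iris.target)

# The columns of predict_proba correspond to mnb.classes_,
# which holds the class labels sorted in ascending order.
print(mnb.classes_)          # for iris: [0 1 2]
proba = mnb.predict_proba(iris.data)
print(proba.shape)           # one row per sample, one column per class
```

There is no need to sort iris.data or iris.target before fitting; sklearn sorts the class labels itself when it builds `classes_`.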

Using predict_proba():

from sklearn.naive_bayes import MultinomialNB
from sklearn import datasets

iris = datasets.load_iris()
mnb = MultinomialNB()
y_pred = mnb.fit(iris.data, iris.target).predict_proba(iris.data)
# The columns of predict_proba follow the sorted class labels,
# so sort the tags rather than relying on set ordering.
tags = sorted(set(iris.target.tolist()))
probs = y_pred.tolist()
print([list(zip(i, tags)) for i in probs])

[output]:

[[(0.7520319948772209, 0), (0.160905710773243, 1), (0.08706229434953643, 2)], [(0.6844907609326716, 0), (0.1996142797317376, 1), (0.11589495933559005, 2)], [(0.7165539514487157, 0), (0.18031248114566703, 1), (0.10313356740561741, 2)], [(0.6678967279059345, 0), (0.20853796442897996, 1), (0.12356530766508479, 2)], [(0.7592386233982199, 0), (0.1564119915930661, 1), (0.08434938500871363, 2)], [(0.7123284219054676, 0), (0.18631932889818908, 1), (0.10135224919634335, 2)], [(0.692830614508067, 0), (0.1941393579908561, 1), (0.11303002750107637, 2)], [(0.7223023141841948, 0), (0.17848129978229782, 1), (0.09921638603350694, 2)], [(0.6490803351236442, 0), (0.2179567749596649, 1), (0.1329628899166902, 2)], [(0.7101630090806663, 0), (0.18515188117472428, 1), (0.10468510974460926, 2)], [(0.7706876980918721, 0), (0.15071276584157142, 1), (0.0785995360665561, 2)], [(0.6987277167692307, 0), (0.19198923174718688, 1), (0.10928305148358229, 2)], [(0.7092823355477573, 0), (0.18501507079143967, 1), (0.1057025936608035, 2)], [(0.7342374424798592, 0), (0.16829861767010512, 1), (0.0974639398500348, 2)], [(0.8443463905803837, 0), (0.10536809401960547, 1), (0.05028551540001017, 2)], [(0.8036814104223939, 0), (0.1313362382603278, 1), (0.06498235131727866, 2)], [(0.769464547835359, 0), (0.15105143900207532, 1), (0.07948401316256604, 2)], [(0.7255850952734539, 0), (0.17657408576774905, 1), (0.097840818958797, 2)], [(0.7397080632341082, 0), (0.1707037515928118, 1), (0.08958818517308087, 2)], [(0.7454550625206028, 0), (0.16522965912480264, 1), (0.08931527835459417, 2)], [(0.7081107158186839, 0), (0.1885497211367833, 1), (0.10333956304453229, 2)], [(0.7064869683171973, 0), (0.18818498727469904, 1), (0.10532804440810306, 2)], [(0.796989599495376, 0), (0.13199962567205156, 1), (0.0710107748325723, 2)], [(0.587778432243368, 0), (0.2561898147126928, 1), (0.1560317530439388, 2)], [(0.6495116627735238, 0), (0.22132827197437152, 1), (0.12916006525210533, 2)], [(0.6558603115752442, 0), 
(0.21701444369684725, 1), (0.1271252447279083, 2)], [(0.6471160032245035, 0), (0.22224046981900822, 1), (0.1306435269564876, 2)], [(0.7416782873328798, 0), (0.1675882755811794, 1), (0.09073343708594127, 2)], [(0.7446817125214586, 0), (0.16548141143388176, 1), (0.08983687604466006, 2)], [(0.6688637590831867, 0), (0.2087351551980403, 1), (0.12240108571877327, 2)], [(0.6601483794765286, 0), (0.21396507128778453, 1), (0.12588654923568643, 2)], [(0.6816137660789651, 0), (0.2033992437242013, 1), (0.11498699019683437, 2)], [(0.8241101192684347, 0), (0.11766437654228883, 1), (0.058225504189276536, 2)], [(0.83205074128301, 0), (0.11304194543139681, 1), (0.05490731328559234, 2)], [(0.7101630090806663, 0), (0.18515188117472428, 1), (0.10468510974460926, 2)], [(0.7430851694455511, 0), (0.16529239865700102, 1), (0.0916224318974473, 2)], [(0.779636336970197, 0), (0.14494298030029326, 1), (0.0754206827295091, 2)], [(0.7101630090806663, 0), (0.18515188117472428, 1), (0.10468510974460926, 2)], [(0.6789788268871746, 0), (0.20084313848520732, 1), (0.12017803462761736, 2)], [(0.7262832077959535, 0), (0.17640557129332532, 1), (0.0973112209107207, 2)], [(0.7363551685765576, 0), (0.1696809261744515, 1), (0.0939639052489906, 2)], [(0.5538451256026228, 0), (0.2692779602391119, 1), (0.17687691415826487, 2)], [(0.7042027164892329, 0), (0.18658505267113334, 1), (0.10921223083963391, 2)], [(0.5963555624638872, 0), (0.25075085344725584, 1), (0.15289358408885717, 2)], [(0.65439734662364, 0), (0.2198127742453153, 1), (0.12578987913104406, 2)], [(0.6495106395351926, 0), (0.21917003546172703, 1), (0.1313193250030794, 2)], [(0.7571100536851544, 0), (0.15851940194145597, 1), (0.08437054437338924, 2)], [(0.6969127760751604, 0), (0.1916827744135674, 1), (0.11140444951127176, 2)], [(0.7671631061752553, 0), (0.1526253742597591, 1), (0.08021151956498587, 2)], [(0.7254443501785298, 0), (0.1762902432553972, 1), (0.09826540656607288, 2)], [(0.05616303783205017, 0), (0.5114112651916364, 1), 
(0.432425696976313, 2)], [(0.05045622209773331, 0), (0.5019052377441757, 1), (0.4476385401580907, 2)], [(0.039404121537601915, 0), (0.5063650737015377, 1), (0.4542308047608592, 2)], [(0.04710522699818281, 0), (0.48480988067469, 1), (0.46808489232712763, 2)], [(0.03801644816953021, 0), (0.49775392400502627, 1), (0.46422962782544475, 2)], [(0.045561446486157016, 0), (0.4925073573450104, 1), (0.4619311961688331, 2)], [(0.039662446204841086, 0), (0.49873204417078004, 1), (0.46160550962437963, 2)], [(0.10701788982710146, 0), (0.47535669064402614, 1), (0.41762541952887167, 2)], [(0.053665098673951646, 0), (0.505300715070129, 1), (0.4410341862559204, 2)], [(0.0525135327993029, 0), (0.48336561688204754, 1), (0.46412085031864936, 2)], [(0.07633356381663149, 0), (0.479485268421569, 1), (0.4441811677617995, 2)], [(0.050562916172520865, 0), (0.49342004113700544, 1), (0.45601704269047194, 2)], [(0.07291747327601361, 0), (0.49453525978952834, 1), (0.43254726693445994, 2)], [(0.03955223347843718, 0), (0.4959408839306025, 1), (0.46450688259096, 2)], [(0.08936248960426366, 0), (0.48762276421125, 1), (0.4230147461844858, 2)], [(0.06195975911054385, 0), (0.5061292622360939, 1), (0.431910978653363, 2)], [(0.03831622365371633, 0), (0.4882594641563942, 1), (0.4734243121898909, 2)], [(0.08635079776773631, 0), (0.49450781017113565, 1), (0.4191413920611266, 2)], [(0.027093026736842187, 0), (0.48373491476096137, 1), (0.48917205850219836, 2)], [(0.07519686280624735, 0), (0.4902817831829304, 1), (0.4345213540108229, 2)], [(0.024409954485511174, 0), (0.4828392042793499, 1), (0.492750841235139, 2)], [(0.07042285903603575, 0), (0.4964293611155268, 1), (0.4331477798484379, 2)], [(0.02455795536824762, 0), (0.48683709246033974, 1), (0.4886049521714139, 2)], [(0.04880578029883432, 0), (0.4998434674757908, 1), (0.45135075222537513, 2)], [(0.0638913231350886, 0), (0.5020577350429162, 1), (0.4340509418219941, 2)], [(0.05742052492537856, 0), (0.5041933482237854, 1), (0.43838612685083517, 2)], 
[(0.03993759354894706, 0), (0.504010990963511, 1), (0.4560514154875418, 2)], [(0.02528131305415432, 0), (0.4937813226658602, 1), (0.4809373642799863, 2)], [(0.03918748210864577, 0), (0.49255695819412515, 1), (0.4682555596972289, 2)], [(0.12061289562641707, 0), (0.48274378627696585, 1), (0.39664331809661796, 2)], [(0.07482169849815479, 0), (0.48808052342521985, 1), (0.4370977780766249, 2)], [(0.09117508375957138, 0), (0.48671672075844624, 1), (0.42210819548198314, 2)], [(0.07675253014451815, 0), (0.49259018574599794, 1), (0.43065728410948345, 2)], [(0.019521752496472658, 0), (0.47919807833320655, 1), (0.5012801691703213, 2)], [(0.03679860023973873, 0), (0.48532661590429554, 1), (0.4778747838559651, 2)], [(0.04575050525515768, 0), (0.4966545638828356, 1), (0.45759493086200687, 2)], [(0.04377223745963666, 0), (0.5043534382070538, 1), (0.4518743243333101, 2)], [(0.041377864559958115, 0), (0.49450702243148886, 1), (0.4641151130085546, 2)], [(0.06664472235797042, 0), (0.4927032207407448, 1), (0.4406520569012841, 2)], [(0.0528380132202798, 0), (0.4871102029137462, 1), (0.460051783865975, 2)], [(0.047940170349888546, 0), (0.49001280154973126, 1), (0.4620470281003815, 2)], [(0.04505918048213829, 0), (0.49770388865629844, 1), (0.45723693086156225, 2)], [(0.06766160412311599, 0), (0.4929541163343243, 1), (0.43938427954256065, 2)], [(0.10330850300509363, 0), (0.476587539562611, 1), (0.42010395743229584, 2)], [(0.052348028868769506, 0), (0.4904703285165612, 1), (0.45718164261466987, 2)], [(0.07213390841723287, 0), (0.4951634224930377, 1), (0.43270266908972993, 2)], [(0.05983129688259453, 0), (0.49357344812806564, 1), (0.44659525498934033, 2)], [(0.06145081121761718, 0), (0.4997746691613533, 1), (0.4387745196210307, 2)], [(0.12643795001228872, 0), (0.47099626175067344, 1), (0.40256578823703776, 2)], [(0.06070342294682647, 0), (0.49247820304725287, 1), (0.4468183740059213, 2)], [(0.004295498135605839, 0), (0.43723717013987085, 1), (0.5584673317245246, 2)], [(0.012341506012939497, 
0), (0.46132786049878993, 1), (0.5263306334882699, 2)], [(0.008041961751626116, 0), (0.4680415056267554, 1), (0.5239165326216197, 2)], [(0.012203741968758995, 0), (0.4719261984222097, 1), (0.5158700596090297, 2)], [(0.006638717235494328, 0), (0.45347586782275284, 1), (0.5398854149417517, 2)], [(0.005263351295626021, 0), (0.4644832412791946, 1), (0.5302534074251788, 2)], [(0.018779628968296866, 0), (0.45982004366157864, 1), (0.5214003273701255, 2)], [(0.00888680054755727, 0), (0.47816163997455585, 1), (0.5129515594778852, 2)], [(0.008973878798549955, 0), (0.4660746833425122, 1), (0.5249514378589395, 2)], [(0.0057884114937193755, 0), (0.45971014127911347, 1), (0.534501447227168, 2)], [(0.016774571271002955, 0), (0.4801428349049056, 1), (0.5030825938240929, 2)], [(0.012044056966578119, 0), (0.4687383638137518, 1), (0.5192175792196708, 2)], [(0.010208690383078136, 0), (0.4695191063788682, 1), (0.5202722032380547, 2)], [(0.010040795166916689, 0), (0.4512755375694701, 1), (0.5386836672636148, 2)], [(0.0064969600914313004, 0), (0.43748007849739584, 1), (0.5560229614111728, 2)], [(0.009304270226104616, 0), (0.4601504705473684, 1), (0.530545259226528, 2)], [(0.014563517615477736, 0), (0.47933745208881223, 1), (0.5060990302957091, 2)], [(0.0070262746504037524, 0), (0.47944937445766783, 1), (0.5135243508919285, 2)], [(0.0025301248297814735, 0), (0.4388473043566187, 1), (0.5586225708136002, 2)], [(0.01793057634785105, 0), (0.4743301503029546, 1), (0.5077392733491937, 2)], [(0.007644068684806274, 0), (0.4622175416837792, 1), (0.5301383896314145, 2)], [(0.01270343330775703, 0), (0.4583206744735115, 1), (0.5289758922187313, 2)], [(0.005081534005270547, 0), (0.46485081415813695, 1), (0.5300676518365941, 2)], [(0.01829020671024577, 0), (0.47785801303090053, 1), (0.5038517802588545, 2)], [(0.010300836830757908, 0), (0.4716408754308902, 1), (0.5180582877383503, 2)], [(0.013056433889026029, 0), (0.48854866614356096, 1), (0.49839489996741193, 2)], [(0.02048499287484417, 0), 
(0.4796366294972459, 1), (0.4998783776279093, 2)], [(0.02097735435316386, 0), (0.48087632026222776, 1), (0.4981463253846089, 2)], [(0.007710125240385033, 0), (0.4556132126926747, 1), (0.5366766620669403, 2)], [(0.01775038522782329, 0), (0.4965751879587703, 1), (0.4856744268134073, 2)], [(0.008638855104322298, 0), (0.4756764562227914, 1), (0.5156846886728871, 2)], [(0.012134615048701844, 0), (0.49867869309551527, 1), (0.4891866918557844, 2)], [(0.006698546808664106, 0), (0.45021394937102427, 1), (0.5430875038203126, 2)], [(0.02528409377638384, 0), (0.4904865676449243, 1), (0.4842293385786926, 2)], [(0.017052353961584325, 0), (0.48108163514744395, 1), (0.5018660108909714, 2)], [(0.0059266080768430335, 0), (0.46457056945157915, 1), (0.5295028224715764, 2)], [(0.007121249136349586, 0), (0.45269328584265495, 1), (0.5401854650209942, 2)], [(0.015142475318899725, 0), (0.4798754856996685, 1), (0.5049820389814322, 2)], [(0.02213425927551211, 0), (0.480457241275485, 1), (0.49740849944900195, 2)], [(0.011936502731266969, 0), (0.4753590492087327, 1), (0.5127044480600016, 2)], [(0.006463038736312238, 0), (0.4524258122544322, 1), (0.5411111490092549, 2)], [(0.011315743506039701, 0), (0.4700603993589541, 1), (0.5186238571350079, 2)], [(0.012341506012939497, 0), (0.46132786049878993, 1), (0.5263306334882699, 2)], [(0.006432772071935396, 0), (0.4567810928523039, 1), (0.5367861350757602, 2)], [(0.00587581213132872, 0), (0.45021579023480335, 1), (0.5439083976338676, 2)], [(0.009473580472389227, 0), (0.46237113309649774, 1), (0.5281552864311121, 2)], [(0.013103723329522001, 0), (0.4670468076720037, 1), (0.5198494689984746, 2)], [(0.0138096923928804, 0), (0.47411641181074315, 1), (0.5120738957963767, 2)], [(0.009332865067928468, 0), (0.4597334606118675, 1), (0.5309336743202042, 2)], [(0.017337387678188924, 0), (0.474812201112126, 1), (0.5078504112096852, 2)]]

As the documentation says:

Returns the probability of the samples for each class in the model, where classes are ordered arithmetically.
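To turn these per-class probabilities into a proper n-best list per sample, one common approach (not in the original post) is to argsort each row and index back into `classes_`; a sketch:

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB
from sklearn import datasets

iris = datasets.load_iris()
mnb = MultinomialNB().fit(iris.data, iris.target)
proba = mnb.predict_proba(iris.data)

n = 2  # how many top labels to keep per sample
# argsort ascending, keep the last n columns, reverse to descending order
top_n_idx = np.argsort(proba, axis=1)[:, -n:][:, ::-1]
top_n_labels = mnb.classes_[top_n_idx]                     # (n_samples, n)
top_n_probs = np.take_along_axis(proba, top_n_idx, axis=1)  # matching probabilities

# e.g. the best and second-best label for the first sample:
print(top_n_labels[0], top_n_probs[0])
```

Each row of `top_n_labels` lists the n most probable class labels for that sample, best first, with `top_n_probs` giving the corresponding probabilities.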
