http://scikit-learn.org/stable/modules/naive_bayes.html 中的示例使用多项式朴素贝叶斯分类器输出最佳目标标签。
如何从sklearn
classifier.fit()
函数中获得 nbest 结果及其相应的概率?
我已经尝试过这个,它只给出了最好的目标标签:
from sklearn.naive_bayes import MultinomialNB
from sklearn import datasets
iris = datasets.load_iris()
mnb = MultinomialNB()
y_pred = mnb.fit(iris.data, iris.target).predict(iris.data)
print y_pred
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()
我试过这个:
mnb.fit(iris.data, iris.target).predict_proba(iris.data)
但它输出的东西看起来像 nbest,但标签是否按顺序排列,如第一列为 0,第二列为 1,第三列为 2?如果是这样,我的iris.data
和iris.target
是否需要在安装前进行分类?
[[ 0.75203199 0.16090571 0.08706229]
[ 0.68449076 0.19961428 0.11589496]
[ 0.71655395 0.18031248 0.10313357]
[ 0.66789673 0.20853796 0.12356531]
[ 0.75923862 0.15641199 0.08434939]
[ 0.71232842 0.18631933 0.10135225]
[ 0.69283061 0.19413936 0.11303003]
[ 0.72230231 0.1784813 0.09921639]
[ 0.64908034 0.21795677 0.13296289]
[ 0.71016301 0.18515188 0.10468511]
[ 0.7706877 0.15071277 0.07859954]
[ 0.69872772 0.19198923 0.10928305]
[ 0.70928234 0.18501507 0.10570259]
[ 0.73423744 0.16829862 0.09746394]
[ 0.84434639 0.10536809 0.05028552]
[ 0.80368141 0.13133624 0.06498235]
[ 0.76946455 0.15105144 0.07948401]
[ 0.7255851 0.17657409 0.09784082]
[ 0.73970806 0.17070375 0.08958819]
[ 0.74545506 0.16522966 0.08931528]
[ 0.70811072 0.18854972 0.10333956]
[ 0.70648697 0.18818499 0.10532804]
[ 0.7969896 0.13199963 0.07101077]
[ 0.58777843 0.25618981 0.15603175]
[ 0.64951166 0.22132827 0.12916007]
[ 0.65586031 0.21701444 0.12712524]
[ 0.647116 0.22224047 0.13064353]
[ 0.74167829 0.16758828 0.09073344]
[ 0.74468171 0.16548141 0.08983688]
[ 0.66886376 0.20873516 0.12240109]
[ 0.66014838 0.21396507 0.12588655]
[ 0.68161377 0.20339924 0.11498699]
[ 0.82411012 0.11766438 0.0582255 ]
[ 0.83205074 0.11304195 0.05490731]
[ 0.71016301 0.18515188 0.10468511]
[ 0.74308517 0.1652924 0.09162243]
[ 0.77963634 0.14494298 0.07542068]
[ 0.71016301 0.18515188 0.10468511]
[ 0.67897883 0.20084314 0.12017803]
[ 0.72628321 0.17640557 0.09731122]
[ 0.73635517 0.16968093 0.09396391]
[ 0.55384513 0.26927796 0.17687691]
[ 0.70420272 0.18658505 0.10921223]
[ 0.59635556 0.25075085 0.15289358]
[ 0.65439735 0.21981277 0.12578988]
[ 0.64951064 0.21917004 0.13131933]
[ 0.75711005 0.1585194 0.08437054]
[ 0.69691278 0.19168277 0.11140445]
[ 0.76716311 0.15262537 0.08021152]
[ 0.72544435 0.17629024 0.09826541]
[ 0.05616304 0.51141127 0.4324257 ]
[ 0.05045622 0.50190524 0.44763854]
[ 0.03940412 0.50636507 0.4542308 ]
[ 0.04710523 0.48480988 0.46808489]
[ 0.03801645 0.49775392 0.46422963]
[ 0.04556145 0.49250736 0.4619312 ]
[ 0.03966245 0.49873204 0.46160551]
[ 0.10701789 0.47535669 0.41762542]
[ 0.0536651 0.50530072 0.44103419]
[ 0.05251353 0.48336562 0.46412085]
[ 0.07633356 0.47948527 0.44418117]
[ 0.05056292 0.49342004 0.45601704]
[ 0.07291747 0.49453526 0.43254727]
[ 0.03955223 0.49594088 0.46450688]
[ 0.08936249 0.48762276 0.42301475]
[ 0.06195976 0.50612926 0.43191098]
[ 0.03831622 0.48825946 0.47342431]
[ 0.0863508 0.49450781 0.41914139]
[ 0.02709303 0.48373491 0.48917206]
[ 0.07519686 0.49028178 0.43452135]
[ 0.02440995 0.4828392 0.49275084]
[ 0.07042286 0.49642936 0.43314778]
[ 0.02455796 0.48683709 0.48860495]
[ 0.04880578 0.49984347 0.45135075]
[ 0.06389132 0.50205774 0.43405094]
[ 0.05742052 0.50419335 0.43838613]
[ 0.03993759 0.50401099 0.45605142]
[ 0.02528131 0.49378132 0.48093736]
[ 0.03918748 0.49255696 0.46825556]
[ 0.1206129 0.48274379 0.39664332]
[ 0.0748217 0.48808052 0.43709778]
[ 0.09117508 0.48671672 0.4221082 ]
[ 0.07675253 0.49259019 0.43065728]
[ 0.01952175 0.47919808 0.50128017]
[ 0.0367986 0.48532662 0.47787478]
[ 0.04575051 0.49665456 0.45759493]
[ 0.04377224 0.50435344 0.45187432]
[ 0.04137786 0.49450702 0.46411511]
[ 0.06664472 0.49270322 0.44065206]
[ 0.05283801 0.4871102 0.46005178]
[ 0.04794017 0.4900128 0.46204703]
[ 0.04505918 0.49770389 0.45723693]
[ 0.0676616 0.49295412 0.43938428]
[ 0.1033085 0.47658754 0.42010396]
[ 0.05234803 0.49047033 0.45718164]
[ 0.07213391 0.49516342 0.43270267]
[ 0.0598313 0.49357345 0.44659525]
[ 0.06145081 0.49977467 0.43877452]
[ 0.12643795 0.47099626 0.40256579]
[ 0.06070342 0.4924782 0.44681837]
[ 0.0042955 0.43723717 0.55846733]
[ 0.01234151 0.46132786 0.52633063]
[ 0.00804196 0.46804151 0.52391653]
[ 0.01220374 0.4719262 0.51587006]
[ 0.00663872 0.45347587 0.53988541]
[ 0.00526335 0.46448324 0.53025341]
[ 0.01877963 0.45982004 0.52140033]
[ 0.0088868 0.47816164 0.51295156]
[ 0.00897388 0.46607468 0.52495144]
[ 0.00578841 0.45971014 0.53450145]
[ 0.01677457 0.48014283 0.50308259]
[ 0.01204406 0.46873836 0.51921758]
[ 0.01020869 0.46951911 0.5202722 ]
[ 0.0100408 0.45127554 0.53868367]
[ 0.00649696 0.43748008 0.55602296]
[ 0.00930427 0.46015047 0.53054526]
[ 0.01456352 0.47933745 0.50609903]
[ 0.00702627 0.47944937 0.51352435]
[ 0.00253012 0.4388473 0.55862257]
[ 0.01793058 0.47433015 0.50773927]
[ 0.00764407 0.46221754 0.53013839]
[ 0.01270343 0.45832067 0.52897589]
[ 0.00508153 0.46485081 0.53006765]
[ 0.01829021 0.47785801 0.50385178]
[ 0.01030084 0.47164088 0.51805829]
[ 0.01305643 0.48854867 0.4983949 ]
[ 0.02048499 0.47963663 0.49987838]
[ 0.02097735 0.48087632 0.49814633]
[ 0.00771013 0.45561321 0.53667666]
[ 0.01775039 0.49657519 0.48567443]
[ 0.00863886 0.47567646 0.51568469]
[ 0.01213462 0.49867869 0.48918669]
[ 0.00669855 0.45021395 0.5430875 ]
[ 0.02528409 0.49048657 0.48422934]
[ 0.01705235 0.48108164 0.50186601]
[ 0.00592661 0.46457057 0.52950282]
[ 0.00712125 0.45269329 0.54018547]
[ 0.01514248 0.47987549 0.50498204]
[ 0.02213426 0.48045724 0.4974085 ]
[ 0.0119365 0.47535905 0.51270445]
[ 0.00646304 0.45242581 0.54111115]
[ 0.01131574 0.4700604 0.51862386]
[ 0.01234151 0.46132786 0.52633063]
[ 0.00643277 0.45678109 0.53678614]
[ 0.00587581 0.45021579 0.5439084 ]
[ 0.00947358 0.46237113 0.52815529]
[ 0.01310372 0.46704681 0.51984947]
[ 0.01380969 0.47411641 0.5120739 ]
[ 0.00933287 0.45973346 0.53093367]
[ 0.01733739 0.4748122 0.50785041]]
使用 predict_proba()
:
from sklearn.naive_bayes import MultinomialNB
from sklearn import datasets
iris = datasets.load_iris()
mnb = MultinomialNB()
y_pred = mnb.fit(iris.data, iris.target).predict_proba(iris.data)
tags = list(set(iris.target.tolist()))
probs = y_pred.tolist()
print [zip(i, tags) for i in probs]
[输出]:
[[(0.7520319948772209, 0), (0.160905710773243, 1), (0.08706229434953643, 2)], [(0.6844907609326716, 0), (0.1996142797317376, 1), (0.11589495933559005, 2)], [(0.7165539514487157, 0), (0.18031248114566703, 1), (0.10313356740561741, 2)], [(0.6678967279059345, 0), (0.20853796442897996, 1), (0.12356530766508479, 2)], [(0.7592386233982199, 0), (0.1564119915930661, 1), (0.08434938500871363, 2)], [(0.7123284219054676, 0), (0.18631932889818908, 1), (0.10135224919634335, 2)], [(0.692830614508067, 0), (0.1941393579908561, 1), (0.11303002750107637, 2)], [(0.7223023141841948, 0), (0.17848129978229782, 1), (0.09921638603350694, 2)], [(0.6490803351236442, 0), (0.2179567749596649, 1), (0.1329628899166902, 2)], [(0.7101630090806663, 0), (0.18515188117472428, 1), (0.10468510974460926, 2)], [(0.7706876980918721, 0), (0.15071276584157142, 1), (0.0785995360665561, 2)], [(0.6987277167692307, 0), (0.19198923174718688, 1), (0.10928305148358229, 2)], [(0.7092823355477573, 0), (0.18501507079143967, 1), (0.1057025936608035, 2)], [(0.7342374424798592, 0), (0.16829861767010512, 1), (0.0974639398500348, 2)], [(0.8443463905803837, 0), (0.10536809401960547, 1), (0.05028551540001017, 2)], [(0.8036814104223939, 0), (0.1313362382603278, 1), (0.06498235131727866, 2)], [(0.769464547835359, 0), (0.15105143900207532, 1), (0.07948401316256604, 2)], [(0.7255850952734539, 0), (0.17657408576774905, 1), (0.097840818958797, 2)], [(0.7397080632341082, 0), (0.1707037515928118, 1), (0.08958818517308087, 2)], [(0.7454550625206028, 0), (0.16522965912480264, 1), (0.08931527835459417, 2)], [(0.7081107158186839, 0), (0.1885497211367833, 1), (0.10333956304453229, 2)], [(0.7064869683171973, 0), (0.18818498727469904, 1), (0.10532804440810306, 2)], [(0.796989599495376, 0), (0.13199962567205156, 1), (0.0710107748325723, 2)], [(0.587778432243368, 0), (0.2561898147126928, 1), (0.1560317530439388, 2)], [(0.6495116627735238, 0), (0.22132827197437152, 1), (0.12916006525210533, 2)], [(0.6558603115752442, 0), (0.21701444369684725, 1), (0.1271252447279083, 2)], [(0.6471160032245035, 0), (0.22224046981900822, 1), (0.1306435269564876, 2)], [(0.7416782873328798, 0), (0.1675882755811794, 1), (0.09073343708594127, 2)], [(0.7446817125214586, 0), (0.16548141143388176, 1), (0.08983687604466006, 2)], [(0.6688637590831867, 0), (0.2087351551980403, 1), (0.12240108571877327, 2)], [(0.6601483794765286, 0), (0.21396507128778453, 1), (0.12588654923568643, 2)], [(0.6816137660789651, 0), (0.2033992437242013, 1), (0.11498699019683437, 2)], [(0.8241101192684347, 0), (0.11766437654228883, 1), (0.058225504189276536, 2)], [(0.83205074128301, 0), (0.11304194543139681, 1), (0.05490731328559234, 2)], [(0.7101630090806663, 0), (0.18515188117472428, 1), (0.10468510974460926, 2)], [(0.7430851694455511, 0), (0.16529239865700102, 1), (0.0916224318974473, 2)], [(0.779636336970197, 0), (0.14494298030029326, 1), (0.0754206827295091, 2)], [(0.7101630090806663, 0), (0.18515188117472428, 1), (0.10468510974460926, 2)], [(0.6789788268871746, 0), (0.20084313848520732, 1), (0.12017803462761736, 2)], [(0.7262832077959535, 0), (0.17640557129332532, 1), (0.0973112209107207, 2)], [(0.7363551685765576, 0), (0.1696809261744515, 1), (0.0939639052489906, 2)], [(0.5538451256026228, 0), (0.2692779602391119, 1), (0.17687691415826487, 2)], [(0.7042027164892329, 0), (0.18658505267113334, 1), (0.10921223083963391, 2)], [(0.5963555624638872, 0), (0.25075085344725584, 1), (0.15289358408885717, 2)], [(0.65439734662364, 0), (0.2198127742453153, 1), (0.12578987913104406, 2)], [(0.6495106395351926, 0), (0.21917003546172703, 1), (0.1313193250030794, 2)], [(0.7571100536851544, 0), (0.15851940194145597, 1), (0.08437054437338924, 2)], [(0.6969127760751604, 0), (0.1916827744135674, 1), (0.11140444951127176, 2)], [(0.7671631061752553, 0), (0.1526253742597591, 1), (0.08021151956498587, 2)], [(0.7254443501785298, 0), (0.1762902432553972, 1), (0.09826540656607288, 2)], [(0.05616303783205017, 0), (0.5114112651916364, 1), (0.432425696976313, 2)], [(0.05045622209773331, 0), (0.5019052377441757, 1), (0.4476385401580907, 2)], [(0.039404121537601915, 0), (0.5063650737015377, 1), (0.4542308047608592, 2)], [(0.04710522699818281, 0), (0.48480988067469, 1), (0.46808489232712763, 2)], [(0.03801644816953021, 0), (0.49775392400502627, 1), (0.46422962782544475, 2)], [(0.045561446486157016, 0), (0.4925073573450104, 1), (0.4619311961688331, 2)], [(0.039662446204841086, 0), (0.49873204417078004, 1), (0.46160550962437963, 2)], [(0.10701788982710146, 0), (0.47535669064402614, 1), (0.41762541952887167, 2)], [(0.053665098673951646, 0), (0.505300715070129, 1), (0.4410341862559204, 2)], [(0.0525135327993029, 0), (0.48336561688204754, 1), (0.46412085031864936, 2)], [(0.07633356381663149, 0), (0.479485268421569, 1), (0.4441811677617995, 2)], [(0.050562916172520865, 0), (0.49342004113700544, 1), (0.45601704269047194, 2)], [(0.07291747327601361, 0), (0.49453525978952834, 1), (0.43254726693445994, 2)], [(0.03955223347843718, 0), (0.4959408839306025, 1), (0.46450688259096, 2)], [(0.08936248960426366, 0), (0.48762276421125, 1), (0.4230147461844858, 2)], [(0.06195975911054385, 0), (0.5061292622360939, 1), (0.431910978653363, 2)], [(0.03831622365371633, 0), (0.4882594641563942, 1), (0.4734243121898909, 2)], [(0.08635079776773631, 0), (0.49450781017113565, 1), (0.4191413920611266, 2)], [(0.027093026736842187, 0), (0.48373491476096137, 1), (0.48917205850219836, 2)], [(0.07519686280624735, 0), (0.4902817831829304, 1), (0.4345213540108229, 2)], [(0.024409954485511174, 0), (0.4828392042793499, 1), (0.492750841235139, 2)], [(0.07042285903603575, 0), (0.4964293611155268, 1), (0.4331477798484379, 2)], [(0.02455795536824762, 0), (0.48683709246033974, 1), (0.4886049521714139, 2)], [(0.04880578029883432, 0), (0.4998434674757908, 1), (0.45135075222537513, 2)], [(0.0638913231350886, 0), (0.5020577350429162, 1), (0.4340509418219941, 2)], [(0.05742052492537856, 0), (0.5041933482237854, 1), (0.43838612685083517, 2)], [(0.03993759354894706, 0), (0.504010990963511, 1), (0.4560514154875418, 2)], [(0.02528131305415432, 0), (0.4937813226658602, 1), (0.4809373642799863, 2)], [(0.03918748210864577, 0), (0.49255695819412515, 1), (0.4682555596972289, 2)], [(0.12061289562641707, 0), (0.48274378627696585, 1), (0.39664331809661796, 2)], [(0.07482169849815479, 0), (0.48808052342521985, 1), (0.4370977780766249, 2)], [(0.09117508375957138, 0), (0.48671672075844624, 1), (0.42210819548198314, 2)], [(0.07675253014451815, 0), (0.49259018574599794, 1), (0.43065728410948345, 2)], [(0.019521752496472658, 0), (0.47919807833320655, 1), (0.5012801691703213, 2)], [(0.03679860023973873, 0), (0.48532661590429554, 1), (0.4778747838559651, 2)], [(0.04575050525515768, 0), (0.4966545638828356, 1), (0.45759493086200687, 2)], [(0.04377223745963666, 0), (0.5043534382070538, 1), (0.4518743243333101, 2)], [(0.041377864559958115, 0), (0.49450702243148886, 1), (0.4641151130085546, 2)], [(0.06664472235797042, 0), (0.4927032207407448, 1), (0.4406520569012841, 2)], [(0.0528380132202798, 0), (0.4871102029137462, 1), (0.460051783865975, 2)], [(0.047940170349888546, 0), (0.49001280154973126, 1), (0.4620470281003815, 2)], [(0.04505918048213829, 0), (0.49770388865629844, 1), (0.45723693086156225, 2)], [(0.06766160412311599, 0), (0.4929541163343243, 1), (0.43938427954256065, 2)], [(0.10330850300509363, 0), (0.476587539562611, 1), (0.42010395743229584, 2)], [(0.052348028868769506, 0), (0.4904703285165612, 1), (0.45718164261466987, 2)], [(0.07213390841723287, 0), (0.4951634224930377, 1), (0.43270266908972993, 2)], [(0.05983129688259453, 0), (0.49357344812806564, 1), (0.44659525498934033, 2)], [(0.06145081121761718, 0), (0.4997746691613533, 1), (0.4387745196210307, 2)], [(0.12643795001228872, 0), (0.47099626175067344, 1), (0.40256578823703776, 2)], [(0.06070342294682647, 0), (0.49247820304725287, 1), (0.4468183740059213, 2)], [(0.004295498135605839, 0), (0.43723717013987085, 1), (0.5584673317245246, 2)], [(0.012341506012939497, 0), (0.46132786049878993, 1), (0.5263306334882699, 2)], [(0.008041961751626116, 0), (0.4680415056267554, 1), (0.5239165326216197, 2)], [(0.012203741968758995, 0), (0.4719261984222097, 1), (0.5158700596090297, 2)], [(0.006638717235494328, 0), (0.45347586782275284, 1), (0.5398854149417517, 2)], [(0.005263351295626021, 0), (0.4644832412791946, 1), (0.5302534074251788, 2)], [(0.018779628968296866, 0), (0.45982004366157864, 1), (0.5214003273701255, 2)], [(0.00888680054755727, 0), (0.47816163997455585, 1), (0.5129515594778852, 2)], [(0.008973878798549955, 0), (0.4660746833425122, 1), (0.5249514378589395, 2)], [(0.0057884114937193755, 0), (0.45971014127911347, 1), (0.534501447227168, 2)], [(0.016774571271002955, 0), (0.4801428349049056, 1), (0.5030825938240929, 2)], [(0.012044056966578119, 0), (0.4687383638137518, 1), (0.5192175792196708, 2)], [(0.010208690383078136, 0), (0.4695191063788682, 1), (0.5202722032380547, 2)], [(0.010040795166916689, 0), (0.4512755375694701, 1), (0.5386836672636148, 2)], [(0.0064969600914313004, 0), (0.43748007849739584, 1), (0.5560229614111728, 2)], [(0.009304270226104616, 0), (0.4601504705473684, 1), (0.530545259226528, 2)], [(0.014563517615477736, 0), (0.47933745208881223, 1), (0.5060990302957091, 2)], [(0.0070262746504037524, 0), (0.47944937445766783, 1), (0.5135243508919285, 2)], [(0.0025301248297814735, 0), (0.4388473043566187, 1), (0.5586225708136002, 2)], [(0.01793057634785105, 0), (0.4743301503029546, 1), (0.5077392733491937, 2)], [(0.007644068684806274, 0), (0.4622175416837792, 1), (0.5301383896314145, 2)], [(0.01270343330775703, 0), (0.4583206744735115, 1), (0.5289758922187313, 2)], [(0.005081534005270547, 0), (0.46485081415813695, 1), (0.5300676518365941, 2)], [(0.01829020671024577, 0), (0.47785801303090053, 1), (0.5038517802588545, 2)], [(0.010300836830757908, 0), (0.4716408754308902, 1), (0.5180582877383503, 2)], [(0.013056433889026029, 0), (0.48854866614356096, 1), (0.49839489996741193, 2)], [(0.02048499287484417, 0), (0.4796366294972459, 1), (0.4998783776279093, 2)], [(0.02097735435316386, 0), (0.48087632026222776, 1), (0.4981463253846089, 2)], [(0.007710125240385033, 0), (0.4556132126926747, 1), (0.5366766620669403, 2)], [(0.01775038522782329, 0), (0.4965751879587703, 1), (0.4856744268134073, 2)], [(0.008638855104322298, 0), (0.4756764562227914, 1), (0.5156846886728871, 2)], [(0.012134615048701844, 0), (0.49867869309551527, 1), (0.4891866918557844, 2)], [(0.006698546808664106, 0), (0.45021394937102427, 1), (0.5430875038203126, 2)], [(0.02528409377638384, 0), (0.4904865676449243, 1), (0.4842293385786926, 2)], [(0.017052353961584325, 0), (0.48108163514744395, 1), (0.5018660108909714, 2)], [(0.0059266080768430335, 0), (0.46457056945157915, 1), (0.5295028224715764, 2)], [(0.007121249136349586, 0), (0.45269328584265495, 1), (0.5401854650209942, 2)], [(0.015142475318899725, 0), (0.4798754856996685, 1), (0.5049820389814322, 2)], [(0.02213425927551211, 0), (0.480457241275485, 1), (0.49740849944900195, 2)], [(0.011936502731266969, 0), (0.4753590492087327, 1), (0.5127044480600016, 2)], [(0.006463038736312238, 0), (0.4524258122544322, 1), (0.5411111490092549, 2)], [(0.011315743506039701, 0), (0.4700603993589541, 1), (0.5186238571350079, 2)], [(0.012341506012939497, 0), (0.46132786049878993, 1), (0.5263306334882699, 2)], [(0.006432772071935396, 0), (0.4567810928523039, 1), (0.5367861350757602, 2)], [(0.00587581213132872, 0), (0.45021579023480335, 1), (0.5439083976338676, 2)], [(0.009473580472389227, 0), (0.46237113309649774, 1), (0.5281552864311121, 2)], [(0.013103723329522001, 0), (0.4670468076720037, 1), (0.5198494689984746, 2)], [(0.0138096923928804, 0), (0.47411641181074315, 1), (0.5120738957963767, 2)], [(0.009332865067928468, 0), (0.4597334606118675, 1), (0.5309336743202042, 2)], [(0.017337387678188924, 0), (0.474812201112126, 1), (0.5078504112096852, 2)]]
正如文档所说,
返回模型中每个类的样本概率,其中类按算术排序。