pycaret 上的概率与预测标签不匹配

Probabilities on pycaret do not match the predicted label

我正在使用以下代码使用 pycaret 开发分类模型:

sample = pd.DataFrame(sample)
exp_clf = setup(sample, target = 'match',fix_imbalance = True)
clf_model = create_model('lightgbm')
tuned_clf_model = tune_model(clf_model, optimize = 'Recall')
tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, data = sample)

现在问题就来了,因为标签1和0的分数重叠了:

这是我正在使用的数据,它可以像上面代码的第一行一样转换为字典读取并转换为数据帧。

sample =  {'same_add_number': {1521: False,
  1756: False,
  2456: False,
  589: False,
  51: False,
  668: False,
  3030: False,
  864: True,
  681: False,
  372: False,
  2768: False,
  3519: False,
  2212: True,
  2424: False,
  672: False,
  1802: False,
  3910: False,
  1174: False,
  1556: False,
  922: False,
  3416: False,
  719: False,
  641: False,
  1364: False,
  3153: False,
  775: False,
  967: False,
  4054: False,
  518: False,
  121: False,
  1027: False,
  4447: True,
  257: False,
  706: True,
  3219: False,
  3009: True,
  3980: False,
  483: False,
  3154: False,
  4399: True,
  2085: False,
  373: False,
  1469: False,
  768: False,
  1491: True,
  2734: False,
  2623: False,
  746: True,
  1647: False,
  3806: False,
  4351: False,
  925: False,
  602: False,
  992: False,
  2041: False,
  1911: False,
  615: False,
  759: False,
  835: False,
  2139: False,
  56: False,
  1980: False,
  995: True,
  1696: False,
  166: False,
  114: True,
  275: False,
  2973: False,
  1313: False,
  1039: False,
  1573: False,
  771: False,
  3193: False,
  2292: False,
  2597: False,
  1747: False,
  1939: False,
  2598: False,
  1998: False,
  3288: False,
  528: False,
  829: False,
  3591: False,
  973: False,
  4383: False,
  1689: False,
  1286: False,
  4388: False,
  491: False,
  3920: False,
  449: False,
  2840: False,
  1324: False,
  2801: False,
  1605: False,
  1355: False,
  1444: False,
  941: False,
  4109: False,
  1767: False,
  839: False,
  188: False,
  3939: False,
  1186: False,
  540: False,
  1456: False,
  3925: True,
  1782: False,
  1733: False,
  64: True,
  2710: False,
  893: False,
  1434: False,
  1244: False,
  503: False,
  3044: False,
  1617: False,
  2878: False,
  913: False,
  799: False,
  2202: False,
  3503: False,
  4063: False,
  3756: False,
  659: False,
  1287: False,
  3843: False,
  2026: True,
  1224: False,
  705: False,
  900: False,
  500: False,
  614: False,
  2766: False,
  8: False,
  981: False,
  1919: False,
  2790: False,
  1098: False,
  1442: False,
  2634: False,
  3346: False,
  652: True,
  2324: False,
  972: False,
  287: False,
  2481: False,
  2486: False,
  4272: False,
  4011: False,
  4: False,
  1645: False,
  863: False,
  688: False,
  2365: False,
  3522: False,
  13: False,
  3251: False,
  1410: False,
  2306: False,
  443: False,
  221: False,
  632: True,
  2549: False,
  783: False,
  3221: False,
  3183: False,
  410: False,
  1289: False,
  1691: False,
  2015: False,
  1022: True,
  455: False,
  572: False,
  2747: False,
  3670: False,
  4441: False,
  2559: False,
  159: False,
  91: False,
  263: False,
  3012: False,
  1234: False,
  4040: False,
  288: False,
  89: False,
  1029: False,
  1180: False,
  1083: False,
  3970: False,
  4201: False,
  709: False,
  2401: False,
  1071: False,
  2954: True,
  29: True},
 'same_add_name': {1521: False,
  1756: False,
  2456: False,
  589: False,
  51: False,
  668: False,
  3030: False,
  864: False,
  681: False,
  372: False,
  2768: False,
  3519: False,
  2212: False,
  2424: False,
  672: False,
  1802: False,
  3910: False,
  1174: False,
  1556: False,
  922: False,
  3416: False,
  719: False,
  641: False,
  1364: False,
  3153: False,
  775: False,
  967: False,
  4054: False,
  518: False,
  121: False,
  1027: False,
  4447: False,
  257: False,
  706: False,
  3219: False,
  3009: False,
  3980: False,
  483: False,
  3154: False,
  4399: False,
  2085: False,
  373: False,
  1469: False,
  768: False,
  1491: False,
  2734: False,
  2623: False,
  746: False,
  1647: False,
  3806: False,
  4351: False,
  925: False,
  602: False,
  992: False,
  2041: False,
  1911: False,
  615: True,
  759: False,
  835: False,
  2139: False,
  56: False,
  1980: False,
  995: False,
  1696: False,
  166: False,
  114: False,
  275: False,
  2973: False,
  1313: False,
  1039: False,
  1573: False,
  771: False,
  3193: False,
  2292: False,
  2597: False,
  1747: False,
  1939: False,
  2598: False,
  1998: False,
  3288: False,
  528: False,
  829: False,
  3591: False,
  973: False,
  4383: False,
  1689: False,
  1286: False,
  4388: False,
  491: False,
  3920: False,
  449: False,
  2840: False,
  1324: False,
  2801: False,
  1605: False,
  1355: False,
  1444: False,
  941: False,
  4109: False,
  1767: False,
  839: False,
  188: False,
  3939: False,
  1186: False,
  540: False,
  1456: False,
  3925: False,
  1782: False,
  1733: False,
  64: False,
  2710: False,
  893: False,
  1434: False,
  1244: False,
  503: False,
  3044: False,
  1617: False,
  2878: False,
  913: False,
  799: False,
  2202: False,
  3503: False,
  4063: False,
  3756: False,
  659: False,
  1287: False,
  3843: False,
  2026: False,
  1224: False,
  705: False,
  900: False,
  500: False,
  614: False,
  2766: False,
  8: False,
  981: False,
  1919: False,
  2790: False,
  1098: False,
  1442: False,
  2634: False,
  3346: False,
  652: False,
  2324: False,
  972: False,
  287: False,
  2481: False,
  2486: False,
  4272: False,
  4011: False,
  4: False,
  1645: False,
  863: False,
  688: False,
  2365: False,
  3522: False,
  13: False,
  3251: False,
  1410: False,
  2306: False,
  443: False,
  221: False,
  632: False,
  2549: False,
  783: False,
  3221: False,
  3183: False,
  410: False,
  1289: False,
  1691: False,
  2015: False,
  1022: False,
  455: False,
  572: False,
  2747: False,
  3670: False,
  4441: False,
  2559: False,
  159: False,
  91: False,
  263: True,
  3012: False,
  1234: False,
  4040: False,
  288: False,
  89: False,
  1029: False,
  1180: False,
  1083: False,
  3970: False,
  4201: False,
  709: False,
  2401: False,
  1071: False,
  2954: False,
  29: False},
 'name_score_fuzzy': {1521: 78,
  1756: 71,
  2456: 73,
  589: 38,
  51: 71,
  668: 49,
  3030: 75,
  864: 47,
  681: 75,
  372: 72,
  2768: 73,
  3519: 85,
  2212: 100,
  2424: 85,
  672: 74,
  1802: 46,
  3910: 73,
  1174: 47,
  1556: 80,
  922: 73,
  3416: 71,
  719: 55,
  641: 71,
  1364: 79,
  3153: 74,
  775: 54,
  967: 73,
  4054: 100,
  518: 72,
  121: 49,
  1027: 38,
  4447: 100,
  257: 74,
  706: 40,
  3219: 71,
  3009: 93,
  3980: 72,
  483: 46,
  3154: 68,
  4399: 100,
  2085: 80,
  373: 77,
  1469: 23,
  768: 50,
  1491: 100,
  2734: 79,
  2623: 79,
  746: 88,
  1647: 73,
  3806: 79,
  4351: 72,
  925: 65,
  602: 83,
  992: 46,
  2041: 78,
  1911: 77,
  615: 45,
  759: 52,
  835: 77,
  2139: 77,
  56: 81,
  1980: 71,
  995: 59,
  1696: 83,
  166: 71,
  114: 50,
  275: 47,
  2973: 80,
  1313: 73,
  1039: 75,
  1573: 70,
  771: 53,
  3193: 100,
  2292: 79,
  2597: 71,
  1747: 78,
  1939: 84,
  2598: 71,
  1998: 77,
  3288: 85,
  528: 44,
  829: 72,
  3591: 80,
  973: 47,
  4383: 80,
  1689: 85,
  1286: 41,
  4388: 75,
  491: 77,
  3920: 70,
  449: 73,
  2840: 79,
  1324: 81,
  2801: 73,
  1605: 47,
  1355: 72,
  1444: 72,
  941: 62,
  4109: 79,
  1767: 34,
  839: 35,
  188: 63,
  3939: 75,
  1186: 49,
  540: 44,
  1456: 41,
  3925: 91,
  1782: 43,
  1733: 74,
  64: 21,
  2710: 71,
  893: 57,
  1434: 75,
  1244: 77,
  503: 75,
  3044: 71,
  1617: 73,
  2878: 71,
  913: 63,
  799: 78,
  2202: 71,
  3503: 77,
  4063: 75,
  3756: 77,
  659: 51,
  1287: 76,
  3843: 73,
  2026: 100,
  1224: 71,
  705: 81,
  900: 65,
  500: 42,
  614: 81,
  2766: 76,
  8: 71,
  981: 73,
  1919: 73,
  2790: 71,
  1098: 76,
  1442: 73,
  2634: 73,
  3346: 81,
  652: 100,
  2324: 84,
  972: 73,
  287: 63,
  2481: 76,
  2486: 76,
  4272: 64,
  4011: 73,
  4: 74,
  1645: 17,
  863: 46,
  688: 71,
  2365: 76,
  3522: 73,
  13: 52,
  3251: 74,
  1410: 80,
  2306: 71,
  443: 71,
  221: 73,
  632: 65,
  2549: 80,
  783: 53,
  3221: 71,
  3183: 75,
  410: 53,
  1289: 71,
  1691: 85,
  2015: 71,
  1022: 67,
  455: 100,
  572: 100,
  2747: 77,
  3670: 74,
  4441: 81,
  2559: 84,
  159: 22,
  91: 79,
  263: 41,
  3012: 76,
  1234: 77,
  4040: 73,
  288: 82,
  89: 71,
  1029: 82,
  1180: 78,
  1083: 77,
  3970: 75,
  4201: 76,
  709: 46,
  2401: 76,
  1071: 83,
  2954: 93,
  29: 52},
 'name_score_cos': {1521: 0.805341232815891,
  1756: 1.0000000156276607,
  2456: 0.7146280288550899,
  589: 0.4944973860854622,
  51: 0.16448994174134138,
  668: 0.6680419517655739,
  3030: 0.5178230596082453,
  864: 0.34284966537760764,
  681: 0.8220122172271629,
  372: 0.7372570578072887,
  2768: 1.0000000748631144,
  3519: 0.6544869126589294,
  2212: 1.0,
  2424: 0.9999999107799844,
  672: 0.8006864625973021,
  1802: 0.008748746635272902,
  3910: 0.6029157847994123,
  1174: 0.43891392720221256,
  1556: 0.4592255006317409,
  922: 0.602017340163112,
  3416: 0.7887549792307141,
  719: 0.13458379717430374,
  641: 0.8221775985370106,
  1364: 0.8349841579827227,
  3153: 0.6395051509895127,
  775: 0.4861694445439952,
  967: 0.6240594839420581,
  4054: 1.0,
  518: 0.8274708074953143,
  121: 0.4156175285346006,
  1027: 0.4172238782731538,
  4447: 1.0,
  257: 0.7144798398523643,
  706: 0.2914152988288179,
  3219: 0.4892006725361837,
  3009: 0.8732375138387463,
  3980: 0.5371502775293667,
  483: 0.6532926383429954,
  3154: 0.7500245353516992,
  4399: 1.0,
  2085: 0.6994934983150074,
  373: 0.0,
  1469: 0.13834207989466868,
  768: 0.0,
  1491: 1.0,
  2734: 0.5744607478435466,
  2623: 0.521054474126365,
  746: 0.900627520280279,
  1647: 0.46841195036889005,
  3806: 0.5245533025793365,
  4351: 0.7190153036645236,
  925: 0.602017340163112,
  602: 0.8180017827481202,
  992: 0.6552306767756036,
  2041: 0.8416265969822513,
  1911: 0.5760342064839252,
  615: 0.3142721314062845,
  759: 0.29937879126297773,
  835: 0.4814135508437952,
  2139: 0.8103994874531241,
  56: 0.4777649573427413,
  1980: 0.4501770315717141,
  995: 0.3185447219204094,
  1696: 0.9999999289827698,
  166: 0.0,
  114: 0.0,
  275: -0.059108179802214694,
  2973: 0.0,
  1313: 0.4103695338595878,
  1039: 0.4158014949799697,
  1573: 0.7687119146546476,
  771: -0.038431693364239676,
  3193: 1.0,
  2292: 0.9999999289827698,
  2597: 0.7014107947566588,
  1747: 0.613680567239729,
  1939: 0.8930406720693059,
  2598: 1.0000000156276607,
  1998: 0.9999999107799844,
  3288: 0.6015149463851227,
  528: 0.48037545624105144,
  829: 0.3520640350139409,
  3591: 0.5123337954949542,
  973: 0.29920325457748886,
  4383: 0.605345098540998,
  1689: 0.699458791765087,
  1286: 0.26151465192863704,
  4388: 0.5996518099075245,
  491: 0.8274708074953143,
  3920: 0.5561721737068668,
  449: 0.5309349410096579,
  2840: 0.6964415538329863,
  1324: 0.8352363777690135,
  2801: 0.0,
  1605: 0.3992469760734788,
  1355: 0.5092696449238323,
  1444: 0.7013725048779127,
  941: 0.0,
  4109: 0.7371134488841004,
  1767: 0.32686654729234066,
  839: 0.28650412696593686,
  188: 0.11578000694274473,
  3939: 0.5182830082849388,
  1186: 0.5399906358163992,
  540: 0.23601516039791495,
  1456: 0.4462820528772964,
  3925: 0.39035408504387764,
  1782: 0.17470256029413367,
  1733: 0.9999999289827698,
  64: 0.47240949440644947,
  2710: 0.21737616101123375,
  893: 0.3889650515319831,
  1434: 0.3144768136655605,
  1244: 0.8456850404860974,
  503: 0.8274708074953143,
  3044: 0.5604645740029809,
  1617: 0.8343403856383358,
  2878: 0.6624314741881498,
  913: 0.3665973835032023,
  799: 0.5785308541963937,
  2202: 0.584334176199583,
  3503: 0.7330193052968511,
  4063: 0.633698984756138,
  3756: 0.588157437279164,
  659: 0.8040106952622528,
  1287: 0.6826384100268522,
  3843: 0.7287410320020241,
  2026: 1.0,
  1224: 0.0,
  705: 0.7278133754982946,
  900: 0.592942126263229,
  500: 0.5038847249789867,
  614: 0.6417445279680914,
  2766: 0.9999999574199627,
  8: 0.722455004886235,
  981: 0.6168699100990872,
  1919: 0.6551439293796956,
  2790: 0.0,
  1098: 0.5890947178422432,
  1442: 0.39311307805458195,
  2634: 0.5434702892550847,
  3346: 0.5956843029692919,
  652: 1.0,
  2324: 0.7619312086149606,
  972: 0.5067710204705025,
  287: 0.6569573257912408,
  2481: 0.5829629588847571,
  2486: 0.436286219251023,
  4272: 0.5408064181796995,
  4011: 0.9999999289827698,
  4: 0.7647923556190919,
  1645: 0.4139532701675873,
  863: 0.40369910836161105,
  688: 0.0,
  2365: 0.7371134488841004,
  3522: 0.6205927634025437,
  13: 0.6688829431116972,
  3251: 0.7114075759658299,
  1410: 0.3589092268079449,
  2306: 1.0000000396582405,
  443: 0.6808489866836555,
  221: 0.5811068730506951,
  632: 0.5470606107366598,
  2549: 0.7123831914993078,
  783: 0.46296630135808603,
  3221: 0.5883753355908442,
  3183: 0.7371134488841004,
  410: 0.7604057492722187,
  1289: 0.5855230248645426,
  1691: 0.727210015672603,
  2015: 0.9999999107799844,
  1022: 0.0,
  455: 1.0,
  572: 1.0,
  2747: 0.7761666318621021,
  3670: 0.5560044398288135,
  4441: 0.7697792208927854,
  2559: 0.5788817989918374,
  159: 0.27027908726745226,
  91: 0.5462872872864122,
  263: 0.3015316394560223,
  3012: 0.6611230100784922,
  1234: 0.6639184765411582,
  4040: 0.9999999768133089,
  288: 0.7681366994965638,
  89: 0.7030570621995992,
  1029: 0.5322036652128525,
  1180: 0.3590668280085605,
  1083: 0.7805410171946893,
  3970: 0.47446565960369524,
  4201: 0.813152589308668,
  709: 0.37964467582959255,
  2401: 0.6551620258724654,
  1071: 0.21475894870778542,
  2954: 0.8452728458129916,
  29: 0.5138088947304236},
 'match': {1521: 0,
  1756: 0,
  2456: 0,
  589: 0,
  51: 0,
  668: 0,
  3030: 0,
  864: 1,
  681: 0,
  372: 0,
  2768: 0,
  3519: 0,
  2212: 1,
  2424: 0,
  672: 0,
  1802: 0,
  3910: 0,
  1174: 0,
  1556: 0,
  922: 0,
  3416: 0,
  719: 0,
  641: 0,
  1364: 0,
  3153: 0,
  775: 0,
  967: 0,
  4054: 1,
  518: 0,
  121: 0,
  1027: 0,
  4447: 1,
  257: 0,
  706: 0,
  3219: 0,
  3009: 0,
  3980: 0,
  483: 0,
  3154: 0,
  4399: 1,
  2085: 0,
  373: 0,
  1469: 0,
  768: 0,
  1491: 1,
  2734: 0,
  2623: 0,
  746: 1,
  1647: 0,
  3806: 0,
  4351: 0,
  925: 0,
  602: 0,
  992: 0,
  2041: 0,
  1911: 0,
  615: 0,
  759: 0,
  835: 0,
  2139: 0,
  56: 0,
  1980: 0,
  995: 1,
  1696: 0,
  166: 0,
  114: 1,
  275: 0,
  2973: 0,
  1313: 0,
  1039: 0,
  1573: 0,
  771: 0,
  3193: 0,
  2292: 0,
  2597: 0,
  1747: 0,
  1939: 0,
  2598: 0,
  1998: 0,
  3288: 0,
  528: 0,
  829: 0,
  3591: 0,
  973: 0,
  4383: 0,
  1689: 0,
  1286: 0,
  4388: 0,
  491: 0,
  3920: 0,
  449: 0,
  2840: 0,
  1324: 0,
  2801: 0,
  1605: 0,
  1355: 0,
  1444: 0,
  941: 0,
  4109: 0,
  1767: 0,
  839: 0,
  188: 0,
  3939: 0,
  1186: 0,
  540: 0,
  1456: 0,
  3925: 1,
  1782: 0,
  1733: 0,
  64: 0,
  2710: 0,
  893: 0,
  1434: 0,
  1244: 0,
  503: 0,
  3044: 0,
  1617: 0,
  2878: 0,
  913: 0,
  799: 0,
  2202: 0,
  3503: 0,
  4063: 0,
  3756: 0,
  659: 0,
  1287: 0,
  3843: 0,
  2026: 1,
  1224: 0,
  705: 0,
  900: 0,
  500: 0,
  614: 0,
  2766: 0,
  8: 0,
  981: 0,
  1919: 0,
  2790: 0,
  1098: 0,
  1442: 0,
  2634: 0,
  3346: 0,
  652: 1,
  2324: 0,
  972: 0,
  287: 0,
  2481: 0,
  2486: 0,
  4272: 0,
  4011: 0,
  4: 0,
  1645: 0,
  863: 0,
  688: 0,
  2365: 0,
  3522: 0,
  13: 0,
  3251: 0,
  1410: 0,
  2306: 0,
  443: 0,
  221: 0,
  632: 0,
  2549: 0,
  783: 0,
  3221: 0,
  3183: 0,
  410: 0,
  1289: 0,
  1691: 0,
  2015: 0,
  1022: 1,
  455: 1,
  572: 1,
  2747: 0,
  3670: 0,
  4441: 0,
  2559: 0,
  159: 0,
  91: 0,
  263: 0,
  3012: 0,
  1234: 0,
  4040: 0,
  288: 0,
  89: 0,
  1029: 0,
  1180: 0,
  1083: 0,
  3970: 0,
  4201: 0,
  709: 0,
  2401: 0,
  1071: 0,
  2954: 0,
  29: 1}}

如果Score大于0.5则获取Label的值为1。如果不能正常工作,你可以创建一个新列“my_label”并设置你自己的边界,当Label获取该值时1.

奇怪的是,Score设置为标签的概率。换句话说,如果模型的原始输出为 0.01,则数据帧将显示为 Label = 0 | Score = 0.99。如果模型的原始输出为 0.99,则数据框将显示为 Label = 1 | Score = 0.99。我认为当你做的不仅仅是二元分类时,这可能更有意义。

如果你对我的话不满意(我不会怪你),你可以通过将预测线更改为

来获得原始分数
tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, raw_score=True, data = sample)

注意 raw_score=True。然后你的数据框将有两个分数列(Score_0Score_1)。从那里,你可以通过

得到你想要的直方图
tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==0].Score_1.hist()
tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==1].Score_1.hist()