pycaret 上的概率与预测标签不匹配
Probabilities on pycaret do not match the predicted label
我正在使用以下代码使用 pycaret 开发分类模型:
sample = pd.DataFrame(sample)
exp_clf = setup(sample, target = 'match',fix_imbalance = True)
clf_model = create_model('lightgbm')
tuned_clf_model = tune_model(clf_model, optimize = 'Recall')
tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, data = sample)
现在问题就来了,因为标签1和0的分数重叠了:
这是我正在使用的数据,它可以像上面代码的第一行一样转换为字典读取并转换为数据帧。
sample = {'same_add_number': {1521: False,
1756: False,
2456: False,
589: False,
51: False,
668: False,
3030: False,
864: True,
681: False,
372: False,
2768: False,
3519: False,
2212: True,
2424: False,
672: False,
1802: False,
3910: False,
1174: False,
1556: False,
922: False,
3416: False,
719: False,
641: False,
1364: False,
3153: False,
775: False,
967: False,
4054: False,
518: False,
121: False,
1027: False,
4447: True,
257: False,
706: True,
3219: False,
3009: True,
3980: False,
483: False,
3154: False,
4399: True,
2085: False,
373: False,
1469: False,
768: False,
1491: True,
2734: False,
2623: False,
746: True,
1647: False,
3806: False,
4351: False,
925: False,
602: False,
992: False,
2041: False,
1911: False,
615: False,
759: False,
835: False,
2139: False,
56: False,
1980: False,
995: True,
1696: False,
166: False,
114: True,
275: False,
2973: False,
1313: False,
1039: False,
1573: False,
771: False,
3193: False,
2292: False,
2597: False,
1747: False,
1939: False,
2598: False,
1998: False,
3288: False,
528: False,
829: False,
3591: False,
973: False,
4383: False,
1689: False,
1286: False,
4388: False,
491: False,
3920: False,
449: False,
2840: False,
1324: False,
2801: False,
1605: False,
1355: False,
1444: False,
941: False,
4109: False,
1767: False,
839: False,
188: False,
3939: False,
1186: False,
540: False,
1456: False,
3925: True,
1782: False,
1733: False,
64: True,
2710: False,
893: False,
1434: False,
1244: False,
503: False,
3044: False,
1617: False,
2878: False,
913: False,
799: False,
2202: False,
3503: False,
4063: False,
3756: False,
659: False,
1287: False,
3843: False,
2026: True,
1224: False,
705: False,
900: False,
500: False,
614: False,
2766: False,
8: False,
981: False,
1919: False,
2790: False,
1098: False,
1442: False,
2634: False,
3346: False,
652: True,
2324: False,
972: False,
287: False,
2481: False,
2486: False,
4272: False,
4011: False,
4: False,
1645: False,
863: False,
688: False,
2365: False,
3522: False,
13: False,
3251: False,
1410: False,
2306: False,
443: False,
221: False,
632: True,
2549: False,
783: False,
3221: False,
3183: False,
410: False,
1289: False,
1691: False,
2015: False,
1022: True,
455: False,
572: False,
2747: False,
3670: False,
4441: False,
2559: False,
159: False,
91: False,
263: False,
3012: False,
1234: False,
4040: False,
288: False,
89: False,
1029: False,
1180: False,
1083: False,
3970: False,
4201: False,
709: False,
2401: False,
1071: False,
2954: True,
29: True},
'same_add_name': {1521: False,
1756: False,
2456: False,
589: False,
51: False,
668: False,
3030: False,
864: False,
681: False,
372: False,
2768: False,
3519: False,
2212: False,
2424: False,
672: False,
1802: False,
3910: False,
1174: False,
1556: False,
922: False,
3416: False,
719: False,
641: False,
1364: False,
3153: False,
775: False,
967: False,
4054: False,
518: False,
121: False,
1027: False,
4447: False,
257: False,
706: False,
3219: False,
3009: False,
3980: False,
483: False,
3154: False,
4399: False,
2085: False,
373: False,
1469: False,
768: False,
1491: False,
2734: False,
2623: False,
746: False,
1647: False,
3806: False,
4351: False,
925: False,
602: False,
992: False,
2041: False,
1911: False,
615: True,
759: False,
835: False,
2139: False,
56: False,
1980: False,
995: False,
1696: False,
166: False,
114: False,
275: False,
2973: False,
1313: False,
1039: False,
1573: False,
771: False,
3193: False,
2292: False,
2597: False,
1747: False,
1939: False,
2598: False,
1998: False,
3288: False,
528: False,
829: False,
3591: False,
973: False,
4383: False,
1689: False,
1286: False,
4388: False,
491: False,
3920: False,
449: False,
2840: False,
1324: False,
2801: False,
1605: False,
1355: False,
1444: False,
941: False,
4109: False,
1767: False,
839: False,
188: False,
3939: False,
1186: False,
540: False,
1456: False,
3925: False,
1782: False,
1733: False,
64: False,
2710: False,
893: False,
1434: False,
1244: False,
503: False,
3044: False,
1617: False,
2878: False,
913: False,
799: False,
2202: False,
3503: False,
4063: False,
3756: False,
659: False,
1287: False,
3843: False,
2026: False,
1224: False,
705: False,
900: False,
500: False,
614: False,
2766: False,
8: False,
981: False,
1919: False,
2790: False,
1098: False,
1442: False,
2634: False,
3346: False,
652: False,
2324: False,
972: False,
287: False,
2481: False,
2486: False,
4272: False,
4011: False,
4: False,
1645: False,
863: False,
688: False,
2365: False,
3522: False,
13: False,
3251: False,
1410: False,
2306: False,
443: False,
221: False,
632: False,
2549: False,
783: False,
3221: False,
3183: False,
410: False,
1289: False,
1691: False,
2015: False,
1022: False,
455: False,
572: False,
2747: False,
3670: False,
4441: False,
2559: False,
159: False,
91: False,
263: True,
3012: False,
1234: False,
4040: False,
288: False,
89: False,
1029: False,
1180: False,
1083: False,
3970: False,
4201: False,
709: False,
2401: False,
1071: False,
2954: False,
29: False},
'name_score_fuzzy': {1521: 78,
1756: 71,
2456: 73,
589: 38,
51: 71,
668: 49,
3030: 75,
864: 47,
681: 75,
372: 72,
2768: 73,
3519: 85,
2212: 100,
2424: 85,
672: 74,
1802: 46,
3910: 73,
1174: 47,
1556: 80,
922: 73,
3416: 71,
719: 55,
641: 71,
1364: 79,
3153: 74,
775: 54,
967: 73,
4054: 100,
518: 72,
121: 49,
1027: 38,
4447: 100,
257: 74,
706: 40,
3219: 71,
3009: 93,
3980: 72,
483: 46,
3154: 68,
4399: 100,
2085: 80,
373: 77,
1469: 23,
768: 50,
1491: 100,
2734: 79,
2623: 79,
746: 88,
1647: 73,
3806: 79,
4351: 72,
925: 65,
602: 83,
992: 46,
2041: 78,
1911: 77,
615: 45,
759: 52,
835: 77,
2139: 77,
56: 81,
1980: 71,
995: 59,
1696: 83,
166: 71,
114: 50,
275: 47,
2973: 80,
1313: 73,
1039: 75,
1573: 70,
771: 53,
3193: 100,
2292: 79,
2597: 71,
1747: 78,
1939: 84,
2598: 71,
1998: 77,
3288: 85,
528: 44,
829: 72,
3591: 80,
973: 47,
4383: 80,
1689: 85,
1286: 41,
4388: 75,
491: 77,
3920: 70,
449: 73,
2840: 79,
1324: 81,
2801: 73,
1605: 47,
1355: 72,
1444: 72,
941: 62,
4109: 79,
1767: 34,
839: 35,
188: 63,
3939: 75,
1186: 49,
540: 44,
1456: 41,
3925: 91,
1782: 43,
1733: 74,
64: 21,
2710: 71,
893: 57,
1434: 75,
1244: 77,
503: 75,
3044: 71,
1617: 73,
2878: 71,
913: 63,
799: 78,
2202: 71,
3503: 77,
4063: 75,
3756: 77,
659: 51,
1287: 76,
3843: 73,
2026: 100,
1224: 71,
705: 81,
900: 65,
500: 42,
614: 81,
2766: 76,
8: 71,
981: 73,
1919: 73,
2790: 71,
1098: 76,
1442: 73,
2634: 73,
3346: 81,
652: 100,
2324: 84,
972: 73,
287: 63,
2481: 76,
2486: 76,
4272: 64,
4011: 73,
4: 74,
1645: 17,
863: 46,
688: 71,
2365: 76,
3522: 73,
13: 52,
3251: 74,
1410: 80,
2306: 71,
443: 71,
221: 73,
632: 65,
2549: 80,
783: 53,
3221: 71,
3183: 75,
410: 53,
1289: 71,
1691: 85,
2015: 71,
1022: 67,
455: 100,
572: 100,
2747: 77,
3670: 74,
4441: 81,
2559: 84,
159: 22,
91: 79,
263: 41,
3012: 76,
1234: 77,
4040: 73,
288: 82,
89: 71,
1029: 82,
1180: 78,
1083: 77,
3970: 75,
4201: 76,
709: 46,
2401: 76,
1071: 83,
2954: 93,
29: 52},
'name_score_cos': {1521: 0.805341232815891,
1756: 1.0000000156276607,
2456: 0.7146280288550899,
589: 0.4944973860854622,
51: 0.16448994174134138,
668: 0.6680419517655739,
3030: 0.5178230596082453,
864: 0.34284966537760764,
681: 0.8220122172271629,
372: 0.7372570578072887,
2768: 1.0000000748631144,
3519: 0.6544869126589294,
2212: 1.0,
2424: 0.9999999107799844,
672: 0.8006864625973021,
1802: 0.008748746635272902,
3910: 0.6029157847994123,
1174: 0.43891392720221256,
1556: 0.4592255006317409,
922: 0.602017340163112,
3416: 0.7887549792307141,
719: 0.13458379717430374,
641: 0.8221775985370106,
1364: 0.8349841579827227,
3153: 0.6395051509895127,
775: 0.4861694445439952,
967: 0.6240594839420581,
4054: 1.0,
518: 0.8274708074953143,
121: 0.4156175285346006,
1027: 0.4172238782731538,
4447: 1.0,
257: 0.7144798398523643,
706: 0.2914152988288179,
3219: 0.4892006725361837,
3009: 0.8732375138387463,
3980: 0.5371502775293667,
483: 0.6532926383429954,
3154: 0.7500245353516992,
4399: 1.0,
2085: 0.6994934983150074,
373: 0.0,
1469: 0.13834207989466868,
768: 0.0,
1491: 1.0,
2734: 0.5744607478435466,
2623: 0.521054474126365,
746: 0.900627520280279,
1647: 0.46841195036889005,
3806: 0.5245533025793365,
4351: 0.7190153036645236,
925: 0.602017340163112,
602: 0.8180017827481202,
992: 0.6552306767756036,
2041: 0.8416265969822513,
1911: 0.5760342064839252,
615: 0.3142721314062845,
759: 0.29937879126297773,
835: 0.4814135508437952,
2139: 0.8103994874531241,
56: 0.4777649573427413,
1980: 0.4501770315717141,
995: 0.3185447219204094,
1696: 0.9999999289827698,
166: 0.0,
114: 0.0,
275: -0.059108179802214694,
2973: 0.0,
1313: 0.4103695338595878,
1039: 0.4158014949799697,
1573: 0.7687119146546476,
771: -0.038431693364239676,
3193: 1.0,
2292: 0.9999999289827698,
2597: 0.7014107947566588,
1747: 0.613680567239729,
1939: 0.8930406720693059,
2598: 1.0000000156276607,
1998: 0.9999999107799844,
3288: 0.6015149463851227,
528: 0.48037545624105144,
829: 0.3520640350139409,
3591: 0.5123337954949542,
973: 0.29920325457748886,
4383: 0.605345098540998,
1689: 0.699458791765087,
1286: 0.26151465192863704,
4388: 0.5996518099075245,
491: 0.8274708074953143,
3920: 0.5561721737068668,
449: 0.5309349410096579,
2840: 0.6964415538329863,
1324: 0.8352363777690135,
2801: 0.0,
1605: 0.3992469760734788,
1355: 0.5092696449238323,
1444: 0.7013725048779127,
941: 0.0,
4109: 0.7371134488841004,
1767: 0.32686654729234066,
839: 0.28650412696593686,
188: 0.11578000694274473,
3939: 0.5182830082849388,
1186: 0.5399906358163992,
540: 0.23601516039791495,
1456: 0.4462820528772964,
3925: 0.39035408504387764,
1782: 0.17470256029413367,
1733: 0.9999999289827698,
64: 0.47240949440644947,
2710: 0.21737616101123375,
893: 0.3889650515319831,
1434: 0.3144768136655605,
1244: 0.8456850404860974,
503: 0.8274708074953143,
3044: 0.5604645740029809,
1617: 0.8343403856383358,
2878: 0.6624314741881498,
913: 0.3665973835032023,
799: 0.5785308541963937,
2202: 0.584334176199583,
3503: 0.7330193052968511,
4063: 0.633698984756138,
3756: 0.588157437279164,
659: 0.8040106952622528,
1287: 0.6826384100268522,
3843: 0.7287410320020241,
2026: 1.0,
1224: 0.0,
705: 0.7278133754982946,
900: 0.592942126263229,
500: 0.5038847249789867,
614: 0.6417445279680914,
2766: 0.9999999574199627,
8: 0.722455004886235,
981: 0.6168699100990872,
1919: 0.6551439293796956,
2790: 0.0,
1098: 0.5890947178422432,
1442: 0.39311307805458195,
2634: 0.5434702892550847,
3346: 0.5956843029692919,
652: 1.0,
2324: 0.7619312086149606,
972: 0.5067710204705025,
287: 0.6569573257912408,
2481: 0.5829629588847571,
2486: 0.436286219251023,
4272: 0.5408064181796995,
4011: 0.9999999289827698,
4: 0.7647923556190919,
1645: 0.4139532701675873,
863: 0.40369910836161105,
688: 0.0,
2365: 0.7371134488841004,
3522: 0.6205927634025437,
13: 0.6688829431116972,
3251: 0.7114075759658299,
1410: 0.3589092268079449,
2306: 1.0000000396582405,
443: 0.6808489866836555,
221: 0.5811068730506951,
632: 0.5470606107366598,
2549: 0.7123831914993078,
783: 0.46296630135808603,
3221: 0.5883753355908442,
3183: 0.7371134488841004,
410: 0.7604057492722187,
1289: 0.5855230248645426,
1691: 0.727210015672603,
2015: 0.9999999107799844,
1022: 0.0,
455: 1.0,
572: 1.0,
2747: 0.7761666318621021,
3670: 0.5560044398288135,
4441: 0.7697792208927854,
2559: 0.5788817989918374,
159: 0.27027908726745226,
91: 0.5462872872864122,
263: 0.3015316394560223,
3012: 0.6611230100784922,
1234: 0.6639184765411582,
4040: 0.9999999768133089,
288: 0.7681366994965638,
89: 0.7030570621995992,
1029: 0.5322036652128525,
1180: 0.3590668280085605,
1083: 0.7805410171946893,
3970: 0.47446565960369524,
4201: 0.813152589308668,
709: 0.37964467582959255,
2401: 0.6551620258724654,
1071: 0.21475894870778542,
2954: 0.8452728458129916,
29: 0.5138088947304236},
'match': {1521: 0,
1756: 0,
2456: 0,
589: 0,
51: 0,
668: 0,
3030: 0,
864: 1,
681: 0,
372: 0,
2768: 0,
3519: 0,
2212: 1,
2424: 0,
672: 0,
1802: 0,
3910: 0,
1174: 0,
1556: 0,
922: 0,
3416: 0,
719: 0,
641: 0,
1364: 0,
3153: 0,
775: 0,
967: 0,
4054: 1,
518: 0,
121: 0,
1027: 0,
4447: 1,
257: 0,
706: 0,
3219: 0,
3009: 0,
3980: 0,
483: 0,
3154: 0,
4399: 1,
2085: 0,
373: 0,
1469: 0,
768: 0,
1491: 1,
2734: 0,
2623: 0,
746: 1,
1647: 0,
3806: 0,
4351: 0,
925: 0,
602: 0,
992: 0,
2041: 0,
1911: 0,
615: 0,
759: 0,
835: 0,
2139: 0,
56: 0,
1980: 0,
995: 1,
1696: 0,
166: 0,
114: 1,
275: 0,
2973: 0,
1313: 0,
1039: 0,
1573: 0,
771: 0,
3193: 0,
2292: 0,
2597: 0,
1747: 0,
1939: 0,
2598: 0,
1998: 0,
3288: 0,
528: 0,
829: 0,
3591: 0,
973: 0,
4383: 0,
1689: 0,
1286: 0,
4388: 0,
491: 0,
3920: 0,
449: 0,
2840: 0,
1324: 0,
2801: 0,
1605: 0,
1355: 0,
1444: 0,
941: 0,
4109: 0,
1767: 0,
839: 0,
188: 0,
3939: 0,
1186: 0,
540: 0,
1456: 0,
3925: 1,
1782: 0,
1733: 0,
64: 0,
2710: 0,
893: 0,
1434: 0,
1244: 0,
503: 0,
3044: 0,
1617: 0,
2878: 0,
913: 0,
799: 0,
2202: 0,
3503: 0,
4063: 0,
3756: 0,
659: 0,
1287: 0,
3843: 0,
2026: 1,
1224: 0,
705: 0,
900: 0,
500: 0,
614: 0,
2766: 0,
8: 0,
981: 0,
1919: 0,
2790: 0,
1098: 0,
1442: 0,
2634: 0,
3346: 0,
652: 1,
2324: 0,
972: 0,
287: 0,
2481: 0,
2486: 0,
4272: 0,
4011: 0,
4: 0,
1645: 0,
863: 0,
688: 0,
2365: 0,
3522: 0,
13: 0,
3251: 0,
1410: 0,
2306: 0,
443: 0,
221: 0,
632: 0,
2549: 0,
783: 0,
3221: 0,
3183: 0,
410: 0,
1289: 0,
1691: 0,
2015: 0,
1022: 1,
455: 1,
572: 1,
2747: 0,
3670: 0,
4441: 0,
2559: 0,
159: 0,
91: 0,
263: 0,
3012: 0,
1234: 0,
4040: 0,
288: 0,
89: 0,
1029: 0,
1180: 0,
1083: 0,
3970: 0,
4201: 0,
709: 0,
2401: 0,
1071: 0,
2954: 0,
29: 1}}
如果Score大于0.5则获取Label的值为1。如果不能正常工作,你可以创建一个新列“my_label”并设置你自己的边界,当Label获取该值时1.
奇怪的是,Score设置为标签的概率。换句话说,如果模型的原始输出为 0.01,则数据帧将显示为 Label = 0 | Score = 0.99
。如果模型的原始输出为 0.99,则数据框将显示为 Label = 1 | Score = 0.99
。我认为当你做的不仅仅是二元分类时,这可能更有意义。
如果你对我的话不满意(我不会怪你),你可以通过将预测线更改为
来获得原始分数
tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, raw_score=True, data = sample)
注意 raw_score=True
。然后你的数据框将有两个分数列(Score_0
和 Score_1
)。从那里,你可以通过
得到你想要的直方图
tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==0].Score_1.hist()
tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==1].Score_1.hist()
我正在使用以下代码使用 pycaret 开发分类模型:
sample = pd.DataFrame(sample)
exp_clf = setup(sample, target = 'match',fix_imbalance = True)
clf_model = create_model('lightgbm')
tuned_clf_model = tune_model(clf_model, optimize = 'Recall')
tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, data = sample)
现在问题就来了,因为标签1和0的分数重叠了:
这是我正在使用的数据,它可以像上面代码的第一行一样转换为字典读取并转换为数据帧。
sample = {'same_add_number': {1521: False,
1756: False,
2456: False,
589: False,
51: False,
668: False,
3030: False,
864: True,
681: False,
372: False,
2768: False,
3519: False,
2212: True,
2424: False,
672: False,
1802: False,
3910: False,
1174: False,
1556: False,
922: False,
3416: False,
719: False,
641: False,
1364: False,
3153: False,
775: False,
967: False,
4054: False,
518: False,
121: False,
1027: False,
4447: True,
257: False,
706: True,
3219: False,
3009: True,
3980: False,
483: False,
3154: False,
4399: True,
2085: False,
373: False,
1469: False,
768: False,
1491: True,
2734: False,
2623: False,
746: True,
1647: False,
3806: False,
4351: False,
925: False,
602: False,
992: False,
2041: False,
1911: False,
615: False,
759: False,
835: False,
2139: False,
56: False,
1980: False,
995: True,
1696: False,
166: False,
114: True,
275: False,
2973: False,
1313: False,
1039: False,
1573: False,
771: False,
3193: False,
2292: False,
2597: False,
1747: False,
1939: False,
2598: False,
1998: False,
3288: False,
528: False,
829: False,
3591: False,
973: False,
4383: False,
1689: False,
1286: False,
4388: False,
491: False,
3920: False,
449: False,
2840: False,
1324: False,
2801: False,
1605: False,
1355: False,
1444: False,
941: False,
4109: False,
1767: False,
839: False,
188: False,
3939: False,
1186: False,
540: False,
1456: False,
3925: True,
1782: False,
1733: False,
64: True,
2710: False,
893: False,
1434: False,
1244: False,
503: False,
3044: False,
1617: False,
2878: False,
913: False,
799: False,
2202: False,
3503: False,
4063: False,
3756: False,
659: False,
1287: False,
3843: False,
2026: True,
1224: False,
705: False,
900: False,
500: False,
614: False,
2766: False,
8: False,
981: False,
1919: False,
2790: False,
1098: False,
1442: False,
2634: False,
3346: False,
652: True,
2324: False,
972: False,
287: False,
2481: False,
2486: False,
4272: False,
4011: False,
4: False,
1645: False,
863: False,
688: False,
2365: False,
3522: False,
13: False,
3251: False,
1410: False,
2306: False,
443: False,
221: False,
632: True,
2549: False,
783: False,
3221: False,
3183: False,
410: False,
1289: False,
1691: False,
2015: False,
1022: True,
455: False,
572: False,
2747: False,
3670: False,
4441: False,
2559: False,
159: False,
91: False,
263: False,
3012: False,
1234: False,
4040: False,
288: False,
89: False,
1029: False,
1180: False,
1083: False,
3970: False,
4201: False,
709: False,
2401: False,
1071: False,
2954: True,
29: True},
'same_add_name': {1521: False,
1756: False,
2456: False,
589: False,
51: False,
668: False,
3030: False,
864: False,
681: False,
372: False,
2768: False,
3519: False,
2212: False,
2424: False,
672: False,
1802: False,
3910: False,
1174: False,
1556: False,
922: False,
3416: False,
719: False,
641: False,
1364: False,
3153: False,
775: False,
967: False,
4054: False,
518: False,
121: False,
1027: False,
4447: False,
257: False,
706: False,
3219: False,
3009: False,
3980: False,
483: False,
3154: False,
4399: False,
2085: False,
373: False,
1469: False,
768: False,
1491: False,
2734: False,
2623: False,
746: False,
1647: False,
3806: False,
4351: False,
925: False,
602: False,
992: False,
2041: False,
1911: False,
615: True,
759: False,
835: False,
2139: False,
56: False,
1980: False,
995: False,
1696: False,
166: False,
114: False,
275: False,
2973: False,
1313: False,
1039: False,
1573: False,
771: False,
3193: False,
2292: False,
2597: False,
1747: False,
1939: False,
2598: False,
1998: False,
3288: False,
528: False,
829: False,
3591: False,
973: False,
4383: False,
1689: False,
1286: False,
4388: False,
491: False,
3920: False,
449: False,
2840: False,
1324: False,
2801: False,
1605: False,
1355: False,
1444: False,
941: False,
4109: False,
1767: False,
839: False,
188: False,
3939: False,
1186: False,
540: False,
1456: False,
3925: False,
1782: False,
1733: False,
64: False,
2710: False,
893: False,
1434: False,
1244: False,
503: False,
3044: False,
1617: False,
2878: False,
913: False,
799: False,
2202: False,
3503: False,
4063: False,
3756: False,
659: False,
1287: False,
3843: False,
2026: False,
1224: False,
705: False,
900: False,
500: False,
614: False,
2766: False,
8: False,
981: False,
1919: False,
2790: False,
1098: False,
1442: False,
2634: False,
3346: False,
652: False,
2324: False,
972: False,
287: False,
2481: False,
2486: False,
4272: False,
4011: False,
4: False,
1645: False,
863: False,
688: False,
2365: False,
3522: False,
13: False,
3251: False,
1410: False,
2306: False,
443: False,
221: False,
632: False,
2549: False,
783: False,
3221: False,
3183: False,
410: False,
1289: False,
1691: False,
2015: False,
1022: False,
455: False,
572: False,
2747: False,
3670: False,
4441: False,
2559: False,
159: False,
91: False,
263: True,
3012: False,
1234: False,
4040: False,
288: False,
89: False,
1029: False,
1180: False,
1083: False,
3970: False,
4201: False,
709: False,
2401: False,
1071: False,
2954: False,
29: False},
'name_score_fuzzy': {1521: 78,
1756: 71,
2456: 73,
589: 38,
51: 71,
668: 49,
3030: 75,
864: 47,
681: 75,
372: 72,
2768: 73,
3519: 85,
2212: 100,
2424: 85,
672: 74,
1802: 46,
3910: 73,
1174: 47,
1556: 80,
922: 73,
3416: 71,
719: 55,
641: 71,
1364: 79,
3153: 74,
775: 54,
967: 73,
4054: 100,
518: 72,
121: 49,
1027: 38,
4447: 100,
257: 74,
706: 40,
3219: 71,
3009: 93,
3980: 72,
483: 46,
3154: 68,
4399: 100,
2085: 80,
373: 77,
1469: 23,
768: 50,
1491: 100,
2734: 79,
2623: 79,
746: 88,
1647: 73,
3806: 79,
4351: 72,
925: 65,
602: 83,
992: 46,
2041: 78,
1911: 77,
615: 45,
759: 52,
835: 77,
2139: 77,
56: 81,
1980: 71,
995: 59,
1696: 83,
166: 71,
114: 50,
275: 47,
2973: 80,
1313: 73,
1039: 75,
1573: 70,
771: 53,
3193: 100,
2292: 79,
2597: 71,
1747: 78,
1939: 84,
2598: 71,
1998: 77,
3288: 85,
528: 44,
829: 72,
3591: 80,
973: 47,
4383: 80,
1689: 85,
1286: 41,
4388: 75,
491: 77,
3920: 70,
449: 73,
2840: 79,
1324: 81,
2801: 73,
1605: 47,
1355: 72,
1444: 72,
941: 62,
4109: 79,
1767: 34,
839: 35,
188: 63,
3939: 75,
1186: 49,
540: 44,
1456: 41,
3925: 91,
1782: 43,
1733: 74,
64: 21,
2710: 71,
893: 57,
1434: 75,
1244: 77,
503: 75,
3044: 71,
1617: 73,
2878: 71,
913: 63,
799: 78,
2202: 71,
3503: 77,
4063: 75,
3756: 77,
659: 51,
1287: 76,
3843: 73,
2026: 100,
1224: 71,
705: 81,
900: 65,
500: 42,
614: 81,
2766: 76,
8: 71,
981: 73,
1919: 73,
2790: 71,
1098: 76,
1442: 73,
2634: 73,
3346: 81,
652: 100,
2324: 84,
972: 73,
287: 63,
2481: 76,
2486: 76,
4272: 64,
4011: 73,
4: 74,
1645: 17,
863: 46,
688: 71,
2365: 76,
3522: 73,
13: 52,
3251: 74,
1410: 80,
2306: 71,
443: 71,
221: 73,
632: 65,
2549: 80,
783: 53,
3221: 71,
3183: 75,
410: 53,
1289: 71,
1691: 85,
2015: 71,
1022: 67,
455: 100,
572: 100,
2747: 77,
3670: 74,
4441: 81,
2559: 84,
159: 22,
91: 79,
263: 41,
3012: 76,
1234: 77,
4040: 73,
288: 82,
89: 71,
1029: 82,
1180: 78,
1083: 77,
3970: 75,
4201: 76,
709: 46,
2401: 76,
1071: 83,
2954: 93,
29: 52},
'name_score_cos': {1521: 0.805341232815891,
1756: 1.0000000156276607,
2456: 0.7146280288550899,
589: 0.4944973860854622,
51: 0.16448994174134138,
668: 0.6680419517655739,
3030: 0.5178230596082453,
864: 0.34284966537760764,
681: 0.8220122172271629,
372: 0.7372570578072887,
2768: 1.0000000748631144,
3519: 0.6544869126589294,
2212: 1.0,
2424: 0.9999999107799844,
672: 0.8006864625973021,
1802: 0.008748746635272902,
3910: 0.6029157847994123,
1174: 0.43891392720221256,
1556: 0.4592255006317409,
922: 0.602017340163112,
3416: 0.7887549792307141,
719: 0.13458379717430374,
641: 0.8221775985370106,
1364: 0.8349841579827227,
3153: 0.6395051509895127,
775: 0.4861694445439952,
967: 0.6240594839420581,
4054: 1.0,
518: 0.8274708074953143,
121: 0.4156175285346006,
1027: 0.4172238782731538,
4447: 1.0,
257: 0.7144798398523643,
706: 0.2914152988288179,
3219: 0.4892006725361837,
3009: 0.8732375138387463,
3980: 0.5371502775293667,
483: 0.6532926383429954,
3154: 0.7500245353516992,
4399: 1.0,
2085: 0.6994934983150074,
373: 0.0,
1469: 0.13834207989466868,
768: 0.0,
1491: 1.0,
2734: 0.5744607478435466,
2623: 0.521054474126365,
746: 0.900627520280279,
1647: 0.46841195036889005,
3806: 0.5245533025793365,
4351: 0.7190153036645236,
925: 0.602017340163112,
602: 0.8180017827481202,
992: 0.6552306767756036,
2041: 0.8416265969822513,
1911: 0.5760342064839252,
615: 0.3142721314062845,
759: 0.29937879126297773,
835: 0.4814135508437952,
2139: 0.8103994874531241,
56: 0.4777649573427413,
1980: 0.4501770315717141,
995: 0.3185447219204094,
1696: 0.9999999289827698,
166: 0.0,
114: 0.0,
275: -0.059108179802214694,
2973: 0.0,
1313: 0.4103695338595878,
1039: 0.4158014949799697,
1573: 0.7687119146546476,
771: -0.038431693364239676,
3193: 1.0,
2292: 0.9999999289827698,
2597: 0.7014107947566588,
1747: 0.613680567239729,
1939: 0.8930406720693059,
2598: 1.0000000156276607,
1998: 0.9999999107799844,
3288: 0.6015149463851227,
528: 0.48037545624105144,
829: 0.3520640350139409,
3591: 0.5123337954949542,
973: 0.29920325457748886,
4383: 0.605345098540998,
1689: 0.699458791765087,
1286: 0.26151465192863704,
4388: 0.5996518099075245,
491: 0.8274708074953143,
3920: 0.5561721737068668,
449: 0.5309349410096579,
2840: 0.6964415538329863,
1324: 0.8352363777690135,
2801: 0.0,
1605: 0.3992469760734788,
1355: 0.5092696449238323,
1444: 0.7013725048779127,
941: 0.0,
4109: 0.7371134488841004,
1767: 0.32686654729234066,
839: 0.28650412696593686,
188: 0.11578000694274473,
3939: 0.5182830082849388,
1186: 0.5399906358163992,
540: 0.23601516039791495,
1456: 0.4462820528772964,
3925: 0.39035408504387764,
1782: 0.17470256029413367,
1733: 0.9999999289827698,
64: 0.47240949440644947,
2710: 0.21737616101123375,
893: 0.3889650515319831,
1434: 0.3144768136655605,
1244: 0.8456850404860974,
503: 0.8274708074953143,
3044: 0.5604645740029809,
1617: 0.8343403856383358,
2878: 0.6624314741881498,
913: 0.3665973835032023,
799: 0.5785308541963937,
2202: 0.584334176199583,
3503: 0.7330193052968511,
4063: 0.633698984756138,
3756: 0.588157437279164,
659: 0.8040106952622528,
1287: 0.6826384100268522,
3843: 0.7287410320020241,
2026: 1.0,
1224: 0.0,
705: 0.7278133754982946,
900: 0.592942126263229,
500: 0.5038847249789867,
614: 0.6417445279680914,
2766: 0.9999999574199627,
8: 0.722455004886235,
981: 0.6168699100990872,
1919: 0.6551439293796956,
2790: 0.0,
1098: 0.5890947178422432,
1442: 0.39311307805458195,
2634: 0.5434702892550847,
3346: 0.5956843029692919,
652: 1.0,
2324: 0.7619312086149606,
972: 0.5067710204705025,
287: 0.6569573257912408,
2481: 0.5829629588847571,
2486: 0.436286219251023,
4272: 0.5408064181796995,
4011: 0.9999999289827698,
4: 0.7647923556190919,
1645: 0.4139532701675873,
863: 0.40369910836161105,
688: 0.0,
2365: 0.7371134488841004,
3522: 0.6205927634025437,
13: 0.6688829431116972,
3251: 0.7114075759658299,
1410: 0.3589092268079449,
2306: 1.0000000396582405,
443: 0.6808489866836555,
221: 0.5811068730506951,
632: 0.5470606107366598,
2549: 0.7123831914993078,
783: 0.46296630135808603,
3221: 0.5883753355908442,
3183: 0.7371134488841004,
410: 0.7604057492722187,
1289: 0.5855230248645426,
1691: 0.727210015672603,
2015: 0.9999999107799844,
1022: 0.0,
455: 1.0,
572: 1.0,
2747: 0.7761666318621021,
3670: 0.5560044398288135,
4441: 0.7697792208927854,
2559: 0.5788817989918374,
159: 0.27027908726745226,
91: 0.5462872872864122,
263: 0.3015316394560223,
3012: 0.6611230100784922,
1234: 0.6639184765411582,
4040: 0.9999999768133089,
288: 0.7681366994965638,
89: 0.7030570621995992,
1029: 0.5322036652128525,
1180: 0.3590668280085605,
1083: 0.7805410171946893,
3970: 0.47446565960369524,
4201: 0.813152589308668,
709: 0.37964467582959255,
2401: 0.6551620258724654,
1071: 0.21475894870778542,
2954: 0.8452728458129916,
29: 0.5138088947304236},
'match': {1521: 0,
1756: 0,
2456: 0,
589: 0,
51: 0,
668: 0,
3030: 0,
864: 1,
681: 0,
372: 0,
2768: 0,
3519: 0,
2212: 1,
2424: 0,
672: 0,
1802: 0,
3910: 0,
1174: 0,
1556: 0,
922: 0,
3416: 0,
719: 0,
641: 0,
1364: 0,
3153: 0,
775: 0,
967: 0,
4054: 1,
518: 0,
121: 0,
1027: 0,
4447: 1,
257: 0,
706: 0,
3219: 0,
3009: 0,
3980: 0,
483: 0,
3154: 0,
4399: 1,
2085: 0,
373: 0,
1469: 0,
768: 0,
1491: 1,
2734: 0,
2623: 0,
746: 1,
1647: 0,
3806: 0,
4351: 0,
925: 0,
602: 0,
992: 0,
2041: 0,
1911: 0,
615: 0,
759: 0,
835: 0,
2139: 0,
56: 0,
1980: 0,
995: 1,
1696: 0,
166: 0,
114: 1,
275: 0,
2973: 0,
1313: 0,
1039: 0,
1573: 0,
771: 0,
3193: 0,
2292: 0,
2597: 0,
1747: 0,
1939: 0,
2598: 0,
1998: 0,
3288: 0,
528: 0,
829: 0,
3591: 0,
973: 0,
4383: 0,
1689: 0,
1286: 0,
4388: 0,
491: 0,
3920: 0,
449: 0,
2840: 0,
1324: 0,
2801: 0,
1605: 0,
1355: 0,
1444: 0,
941: 0,
4109: 0,
1767: 0,
839: 0,
188: 0,
3939: 0,
1186: 0,
540: 0,
1456: 0,
3925: 1,
1782: 0,
1733: 0,
64: 0,
2710: 0,
893: 0,
1434: 0,
1244: 0,
503: 0,
3044: 0,
1617: 0,
2878: 0,
913: 0,
799: 0,
2202: 0,
3503: 0,
4063: 0,
3756: 0,
659: 0,
1287: 0,
3843: 0,
2026: 1,
1224: 0,
705: 0,
900: 0,
500: 0,
614: 0,
2766: 0,
8: 0,
981: 0,
1919: 0,
2790: 0,
1098: 0,
1442: 0,
2634: 0,
3346: 0,
652: 1,
2324: 0,
972: 0,
287: 0,
2481: 0,
2486: 0,
4272: 0,
4011: 0,
4: 0,
1645: 0,
863: 0,
688: 0,
2365: 0,
3522: 0,
13: 0,
3251: 0,
1410: 0,
2306: 0,
443: 0,
221: 0,
632: 0,
2549: 0,
783: 0,
3221: 0,
3183: 0,
410: 0,
1289: 0,
1691: 0,
2015: 0,
1022: 1,
455: 1,
572: 1,
2747: 0,
3670: 0,
4441: 0,
2559: 0,
159: 0,
91: 0,
263: 0,
3012: 0,
1234: 0,
4040: 0,
288: 0,
89: 0,
1029: 0,
1180: 0,
1083: 0,
3970: 0,
4201: 0,
709: 0,
2401: 0,
1071: 0,
2954: 0,
29: 1}}
如果Score大于0.5则获取Label的值为1。如果不能正常工作,你可以创建一个新列“my_label”并设置你自己的边界,当Label获取该值时1.
奇怪的是,Score设置为标签的概率。换句话说,如果模型的原始输出为 0.01,则数据帧将显示为 Label = 0 | Score = 0.99
。如果模型的原始输出为 0.99,则数据框将显示为 Label = 1 | Score = 0.99
。我认为当你做的不仅仅是二元分类时,这可能更有意义。
如果你对我的话不满意(我不会怪你),你可以通过将预测线更改为
来获得原始分数tuned_tuned_clf_model_pred = predict_model(tuned_clf_model, raw_score=True, data = sample)
注意 raw_score=True
。然后你的数据框将有两个分数列(Score_0
和 Score_1
)。从那里,你可以通过
tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==0].Score_1.hist()
tuned_tuned_clf_model_pred[tuned_tuned_clf_model_pred["Label"]==1].Score_1.hist()