Getting error while calculating AUC ROC for Keras model predictions
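I have an array of patient data called dat and an array of labels called labl (0 = no disease, 1 = disease). I ran my model, stored its predictions in an array called pre, and now want to compute and plot the AUC ROC. When I try, I get this error:
TypeError: Singleton array array(0., dtype=float32) cannot be considered a valid collection.
This is the record of only one patient. When I run the model on more patients, I can compute the AUC ROC without problems, but here I want it for a single patient.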
>>> dat
array([[[114.6 , 93.1 , 37.17, 118.3 , 64.3 , 22. , 45. , 0. ],
[110. , 94.5 , 37.3 , 136. , 59. , 17.5 , 45. , 0. ],
[104. , 95. , 37.17, 154. , 74. , 26. , 45. , 0. ],
[106. , 94. , 37.17, 124. , 64. , 17. , 45. , 0. ],
[110. , 92.5 , 37.17, 133. , 62. , 17. , 45. , 0. ],
[114. , 92.5 , 36.7 , 127. , 62. , 21. , 45. , 0. ],
[106. , 95. , 37.17, 124. , 64. , 19. , 45. , 0. ],
[110. , 93. , 37.17, 138. , 70. , 17. , 45. , 0. ],
[114. , 90. , 37.17, 134. , 66. , 16. , 45. , 0. ],
[114. , 89. , 37.17, 116. , 60. , 20. , 45. , 0. ],
[120. , 91. , 37.17, 140. , 80. , 15. , 45. , 0. ],
[120. , 90. , 37.17, 122. , 72. , 15. , 45. , 0. ],
[120. , 92. , 37.17, 106. , 64. , 16. , 45. , 0. ],
[ 64. , 93. , 37.17, 100. , 53. , 20. , 45. , 0. ],
[128. , 95. , 37.17, 194. , 86. , 15. , 45. , 0. ],
[126. , 93. , 37.17, 34. , 30. , 27. , 45. , 0. ],
[124. , 94.5 , 37.17, 80. , 59. , 35. , 45. , 0. ],
[127. , 97. , 37.5 , 102. , 69. , 35. , 45. , 0. ],
[130. , 97. , 37.17, 94. , 66. , 35. , 45. , 0. ],
[130. , 90. , 37.17, 90. , 62. , 35. , 45. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]]],
dtype=float32)
>>> labl
array([[[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[1.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.],
[0.]]], dtype=float32)
>>> pre
array([[[0.24694729],
[0.42795685],
[0.5010372 ],
[0.52086353],
[0.52870005],
[0.5377407 ],
[0.5345124 ],
[0.5310055 ],
[0.531648 ],
[0.5410067 ],
[0.5446999 ],
[0.5466636 ],
[0.5504297 ],
[0.5236943 ],
[0.5244271 ],
[0.5483868 ],
[0.5533212 ],
[0.5523378 ],
[0.5553032 ],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267],
[0.55902267]]], dtype=float32)
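Using the code below, I plotted the predicted mortality at each time step, but computing the AUC ROC fails.
import random

import matplotlib.pyplot as plt
import numpy as np

# Figure out how many encounters we have
numencnt = dat.shape[0]
# Choose a random patient encounter to plot
ix = random.randint(0, numencnt - 1)
# Create two axes, stacked vertically
f, (ax1, ax2) = plt.subplots(2, 1)
# Plot the observation chart for the random patient encounter
ax1.pcolor(np.transpose(dat[ix, 1:72, :]))
ax1.set_ylim(0, 8)
# Label the prediction axes (plt.ylabel/xlabel applied to ax2, the current axes)
ax2.set_ylabel("mortality")
ax2.set_xlabel("time/observation")
# Plot the patient survivability prediction
ax2.plot(pre[ix, 1:72])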
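(Plot: observation chart and predicted mortality over time for the selected patient; image not included.)
This is where I get the error: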
from sklearn.metrics import roc_curve, auc
# get 0/1 binary label for each patient encounter
label = labl[:, 0, :].squeeze();
# get the last prediction in [0,1] for the patient
prediction = pre[:, -1, :].squeeze()
# compute ROC curve for predictions
rnn_roc = roc_curve(label,prediction)
# compute the area under the curve of prediction ROC
rnn_auc = auc(rnn_roc[0], rnn_roc[1])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_129/3666067037.py in <module>
      8 
      9 # compute ROC curve for predictions
---> 10 rnn_roc = roc_curve(label,prediction)
     11 
     12 # compute the area under the curve of prediction ROC

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/metrics/_ranking.py in roc_curve(y_true, y_score, pos_label, sample_weight, drop_intermediate)
    960 
    961     """
--> 962     fps, tps, thresholds = _binary_clf_curve(
    963         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
    964     )

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    731         raise ValueError("{0} format is not supported".format(y_type))
    732 
--> 733     check_consistent_length(y_true, y_score, sample_weight)
    734     y_true = column_or_1d(y_true)
    735     y_score = column_or_1d(y_score)

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/utils/validation.py in check_consistent_length(*arrays)
    327     """
    328 
--> 329     lengths = [_num_samples(X) for X in arrays if X is not None]
    330     uniques = np.unique(lengths)
    331     if len(uniques) > 1:

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/utils/validation.py in <listcomp>(.0)
    327     """
    328 
--> 329     lengths = [_num_samples(X) for X in arrays if X is not None]
    330     uniques = np.unique(lengths)
    331     if len(uniques) > 1:

~/.conda/envs/default/lib/python3.9/site-packages/sklearn/utils/validation.py in _num_samples(x)
    267     if hasattr(x, "shape") and x.shape is not None:
    268         if len(x.shape) == 0:
--> 269             raise TypeError(
    270                 "Singleton array %r cannot be considered a valid collection." % x
    271             )

TypeError: Singleton array array(0., dtype=float32) cannot be considered a valid collection.
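A minimal reproduction, assuming zero-filled stand-in arrays with the same shapes as the question's labl and pre (the values are placeholders, not the question's data). It shows that indexing one time step and then squeezing leaves a 0-d array, which roc_curve rejects. Depending on the scikit-learn version, the rejection may surface as the TypeError above or as a ValueError, so the sketch catches both.

```python
import numpy as np
from sklearn.metrics import roc_curve

# Zero-filled stand-ins with the question's shapes: (patients, time steps, 1)
labl = np.zeros((1, 72, 1), dtype=np.float32)
pre = np.zeros((1, 72, 1), dtype=np.float32)

label = labl[:, 0, :].squeeze()        # (1, 1) -> (): a 0-d "singleton" array
prediction = pre[:, -1, :].squeeze()   # (1, 1) -> ()
print(label.shape, prediction.shape)

try:
    roc_curve(label, prediction)
    error = None
except (TypeError, ValueError) as exc:  # exact type depends on the sklearn version
    error = exc
print(type(error).__name__, error)
```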
import matplotlib.pyplot as plt

# plot the ROC curve and display the AUC in the legend
plt.figure(figsize=(7, 5))
line_kwargs = {'linewidth': 4, 'alpha': 0.8}
plt.plot(rnn_roc[0], rnn_roc[1], label='LSTM: %0.3f' % rnn_auc, color='#6AA84F', **line_kwargs)
plt.legend(loc='lower right', fontsize=20)
plt.xlim((-0.05, 1.05))
plt.ylim((-0.05, 1.05))
plt.xticks([0, 0.25, 0.5, 0.75, 1.0], fontsize=14)
plt.yticks([0, 0.25, 0.5, 0.75, 1.0], fontsize=14)
plt.xlabel("False Positive Rate", fontsize=18)
plt.ylabel("True Positive Rate", fontsize=18)
plt.title("ROC Curve", fontsize=24)
plt.grid(alpha=0.25)
plt.tight_layout()
plt.show()
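The problem is the indexing you do before squeeze; you don't need to index at all. squeeze removes every axis of length 1, so labl.squeeze() turns the (1, 72, 1) array into a 1-D array of 72 labels. labl[:, 0, :] instead keeps only the first time step, giving shape (1, 1), and squeeze then collapses that to a 0-d array. That is the "singleton array" in the error, which roc_curve cannot treat as a collection of samples; pre[:, -1, :].squeeze() has the same problem. With several patients the indexing still leaves one value per patient, which is why it worked there. For a single patient, a per-patient AUC ROC is not defined anyway, because it needs at least one positive and one negative sample. Squeezing the full arrays instead computes the ROC over the 72 time steps of this patient, treating each step as a sample.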
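A short sketch of the shapes involved, assuming a zero-filled stand-in with the question's labl shape. It also shows that squeeze(axis=-1), which removes only the trailing axis, avoids the 0-d array but cannot rescue the per-patient AUC: one patient is one sample of one class. In the scikit-learn versions I know, roc_auc_score raises ValueError for that; the sketch also tolerates a NaN result in case a version returns one instead.

```python
import warnings

import numpy as np
from sklearn.metrics import roc_auc_score

# Zero-filled stand-in with the question's shape (patients, time steps, 1)
labl = np.zeros((1, 72, 1), dtype=np.float32)

full = labl.squeeze()                           # every length-1 axis removed
first_step = labl[:, 0, :].squeeze()            # patient axis removed too
first_step_1d = labl[:, 0, :].squeeze(axis=-1)  # only the trailing axis removed
print(full.shape, first_step.shape, first_step_1d.shape)

# One patient -> one sample -> only one class present in y_true
per_patient_score = np.array([0.5])
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    try:
        result = roc_auc_score(first_step_1d, per_patient_score)
    except ValueError as exc:
        result = exc
print(result)
```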
Simply squeeze the full arrays instead:
# 0/1 label at every time step: (1, 72, 1) -> (72,)
label = labl.squeeze()
# predicted probability at every time step: (1, 72, 1) -> (72,)
prediction = pre.squeeze()
rnn_roc = roc_curve(label, prediction)
rnn_auc = auc(rnn_roc[0], rnn_roc[1])
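A sketch of the per-time-step ROC for this patient, with one extra step that is not part of the answer above: dropping zero-padded time steps. In the question, rows 20 to 71 of dat are all zeros and pre is constant there, so those steps enter the ROC as 52 extra negatives. The labl and pre arrays below reproduce the question's values; the dat array is a hypothetical stand-in that only mimics its padding pattern, and treating an all-zero feature row as padding is an assumption about how the data was prepared.

```python
import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve

# Predictions for the 20 observed steps, copied from the question's `pre`
observed_pred = [
    0.24694729, 0.42795685, 0.5010372, 0.52086353, 0.52870005,
    0.5377407, 0.5345124, 0.5310055, 0.531648, 0.5410067,
    0.5446999, 0.5466636, 0.5504297, 0.5236943, 0.5244271,
    0.5483868, 0.5533212, 0.5523378, 0.5553032, 0.55902267,
]
# The remaining 52 steps all carry 0.55902267, as in the question
pre = np.full((1, 72, 1), 0.55902267, dtype=np.float32)
pre[0, :20, 0] = observed_pred

labl = np.zeros((1, 72, 1), dtype=np.float32)
labl[0, 10:20, 0] = 1.0  # steps 10-19 are labelled 1, as in the question

# Hypothetical stand-in for `dat`: nonzero features on the 20 observed steps,
# all-zero rows afterwards (the padding pattern visible in the question)
dat = np.zeros((1, 72, 8), dtype=np.float32)
dat[0, :20, :] = 1.0

label = labl.squeeze()      # (72,)
prediction = pre.squeeze()  # (72,)

# Assumption: a time step is padding if every feature in its row is zero
observed = np.any(dat[0] != 0, axis=-1)  # boolean mask, shape (72,)

fpr, tpr, _ = roc_curve(label[observed], prediction[observed])
auc_masked = auc(fpr, tpr)
auc_unmasked = roc_auc_score(label, prediction)

print(f"AUC over observed steps only: {auc_masked:.3f}")
print(f"AUC including padded steps:   {auc_unmasked:.3f}")
```

With these arrays the AUC over the 20 observed steps is 0.88, while including the padded steps drops it to about 0.18: the 52 padded negatives outrank nine of the ten positives and tie with the last one.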