如何在 fastai 中改变混淆矩阵的大小?
How to change size of confusion matrix in fastai?
我正在使用以下代码在 fastai
中绘制混淆矩阵:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
但我最终得到了一个超小的矩阵,因为我有大约 20 个类别:
我找到了 sklearns 的相关问题,但不知道如何将其应用于 fastai(因为我们不直接使用 pyplot
。
如果你检查函数 ClassificationInterpretation.plot_confusion_matrix
的代码(在文件 fastai / interpret.py),你看到的是:
def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap="Blues", norm_dec=2,
plot_txt=True, **kwargs):
"Plot the confusion matrix, with `title` and using `cmap`."
# This function is mainly copied from the sklearn docs
cm = self.confusion_matrix()
if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig = plt.figure(**kwargs)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
tick_marks = np.arange(len(self.vocab))
plt.xticks(tick_marks, self.vocab, rotation=90)
plt.yticks(tick_marks, self.vocab, rotation=0)
这里的关键是行 fig = plt.figure(**kwargs)
,所以基本上,函数 plot_confusion_matrix
会将其参数传播到绘图函数。
因此您可以使用其中之一:
dpi=xxx
(例如dpi=200
)
figsize=(xxx, yyy)
在 Whosebug 上查看 post 关于他们之间的关系:
所以在你的情况下,你可以这样做:
interp.plot_confusion_matrix(figsize=(12, 12))
混淆矩阵看起来像:
仅供参考:这也适用于其他绘图函数,例如
interp.plot_top_losses(20, nrows=5, figsize=(25, 25))
我正在使用以下代码在 fastai
中绘制混淆矩阵:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
但我最终得到了一个超小的矩阵,因为我有大约 20 个类别:
我找到了 sklearns 的相关问题,但不知道如何将其应用于 fastai(因为我们不直接使用 pyplot
。
如果你检查函数 ClassificationInterpretation.plot_confusion_matrix
的代码(在文件 fastai / interpret.py),你看到的是:
def plot_confusion_matrix(self, normalize=False, title='Confusion matrix', cmap="Blues", norm_dec=2,
plot_txt=True, **kwargs):
"Plot the confusion matrix, with `title` and using `cmap`."
# This function is mainly copied from the sklearn docs
cm = self.confusion_matrix()
if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig = plt.figure(**kwargs)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
tick_marks = np.arange(len(self.vocab))
plt.xticks(tick_marks, self.vocab, rotation=90)
plt.yticks(tick_marks, self.vocab, rotation=0)
这里的关键是行 fig = plt.figure(**kwargs)
,所以基本上,函数 plot_confusion_matrix
会将其参数传播到绘图函数。
因此您可以使用其中之一:
dpi=xxx
(例如dpi=200
)figsize=(xxx, yyy)
在 Whosebug 上查看 post 关于他们之间的关系:
所以在你的情况下,你可以这样做:
interp.plot_confusion_matrix(figsize=(12, 12))
混淆矩阵看起来像:
仅供参考:这也适用于其他绘图函数,例如
interp.plot_top_losses(20, nrows=5, figsize=(25, 25))