如果非对角线条目中全为零,则删除行和列
Removing rows and columns if all zeros in non-diagonal entries
我正在生成一个 confusion matrix
来了解我的 text-classifier
的 prediction
与 ground-truth
。目的是了解哪些 intent
被预测为另一个 intent
。但是问题是我的类太多了(多于160
),所以矩阵是sparse
,其中大部分字段都是zeros
。显然,对角线元素很可能是非零的,因为它基本上是正确预测的指示。
既然如此,我想生成一个更简单的版本,因为我们只关心 non-zero
个元素,如果它们是 non-diagonal
,因此,我想删除 row
s 和 column
s,其中所有元素均为零(忽略 diagonal
条目),这样图形变得更小并且易于查看。怎么做?
以下是我到目前为止完成的代码片段,它将为所有意图生成映射,即 (#intent, #intent)
维度图。
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
import seaborn as sns
%matplotlib inline
sns.set(rc={'figure.figsize':(64,64)})
confusion_matrix = pd.crosstab(df['ground_truth_intent_name'], df['predicted_intent_name'])
variables = sorted(list(set(df['ground_truth_intent_name'])))
temp = DataFrame(confusion_matrix, index=variables, columns=variables)
sns.heatmap(temp, annot=True)
TL;DR
这里temp
是一个pandas dataframe
。我需要删除所有元素为零的所有行和列(忽略对角线元素,即使它们不为零)。
你可以用any
来比较,但首先你需要用0
填充对角线:
# also consider using
# a = np.isclose(confusion_matrix.to_numpy(), 0)
a = confusion_matrix.to_numpy() != 0
# fill diagonal
np.fill_diagonal(a, False)
# columns with at least one non-zero
cols = a.any(axis=0)
# rows with at least one non-zero
rows = a.any(axis=1)
# boolean indexing
confusion_matrix.loc[rows, cols]
举个例子:
# random data
np.random.seed(1)
# this would agree with the above
a = np.random.randint(0,2, (5,5))
a[2] = 0
a[:-1,-1] = 0
confusion_matrix = pd.DataFrame(a)
因此数据将是:
0 1 2 3 4
0 1 1 0 0 0
1 1 1 1 1 0
2 0 0 0 0 0
3 0 0 1 0 0
4 0 1 0 0 1
和代码输出(注意第 2 行和第 4 列消失了):
0 1 2 3
0 1 1 0 0
1 1 1 1 1
3 0 0 1 0
4 0 1 0 0
我正在生成一个 confusion matrix
来了解我的 text-classifier
的 prediction
与 ground-truth
。目的是了解哪些 intent
被预测为另一个 intent
。但是问题是我的类太多了(多于160
),所以矩阵是sparse
,其中大部分字段都是zeros
。显然,对角线元素很可能是非零的,因为它基本上是正确预测的指示。
既然如此,我想生成一个更简单的版本,因为我们只关心 non-zero
个元素,如果它们是 non-diagonal
,因此,我想删除 row
s 和 column
s,其中所有元素均为零(忽略 diagonal
条目),这样图形变得更小并且易于查看。怎么做?
以下是我到目前为止完成的代码片段,它将为所有意图生成映射,即 (#intent, #intent)
维度图。
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
import seaborn as sns
%matplotlib inline
sns.set(rc={'figure.figsize':(64,64)})
confusion_matrix = pd.crosstab(df['ground_truth_intent_name'], df['predicted_intent_name'])
variables = sorted(list(set(df['ground_truth_intent_name'])))
temp = DataFrame(confusion_matrix, index=variables, columns=variables)
sns.heatmap(temp, annot=True)
TL;DR
这里temp
是一个pandas dataframe
。我需要删除所有元素为零的所有行和列(忽略对角线元素,即使它们不为零)。
你可以用any
来比较,但首先你需要用0
填充对角线:
# also consider using
# a = np.isclose(confusion_matrix.to_numpy(), 0)
a = confusion_matrix.to_numpy() != 0
# fill diagonal
np.fill_diagonal(a, False)
# columns with at least one non-zero
cols = a.any(axis=0)
# rows with at least one non-zero
rows = a.any(axis=1)
# boolean indexing
confusion_matrix.loc[rows, cols]
举个例子:
# random data
np.random.seed(1)
# this would agree with the above
a = np.random.randint(0,2, (5,5))
a[2] = 0
a[:-1,-1] = 0
confusion_matrix = pd.DataFrame(a)
因此数据将是:
0 1 2 3 4
0 1 1 0 0 0
1 1 1 1 1 0
2 0 0 0 0 0
3 0 0 1 0 0
4 0 1 0 0 1
和代码输出(注意第 2 行和第 4 列消失了):
0 1 2 3
0 1 1 0 0
1 1 1 1 1
3 0 0 1 0
4 0 1 0 0