ValueError using tensorflow.metrics.Recall(class_id=1)

I am using Python 3.8.3 and TensorFlow 2.4.1.

I want to use the class_id parameter of the metrics in tensorflow.metrics, for example Recall (see the documentation).

Here is a minimal code snippet that reproduces the problem. The code below crashes because of class_id=1:

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.layers import SimpleRNN
from sklearn.model_selection import train_test_split
from tensorflow.keras import metrics
import numpy as np
#generate data
max_length  = 200
width = 3
n_samples = 100
data = np.random.rand(n_samples, max_length, width)
label = np.random.randint(0, high =2, size = n_samples)
train_size = 0.8
x_train, x_test, y_train, y_test = train_test_split(data, label, train_size = train_size)

#create a model
rnn_size = 16
sequence_input = Input(shape=(max_length,width,), dtype='float32')
x = SimpleRNN(rnn_size)(sequence_input)
preds = Dense(1, activation='sigmoid')(x)
model = Model(sequence_input, preds)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[metrics.Recall(class_id=1)])
#fit
BATCH_SIZE = 32
history = model.fit(x_train, y_train, epochs=1, batch_size=BATCH_SIZE)

It throws a ValueError:

ValueError: slice index 1 of dimension 1 out of bounds. for '{{node strided_slice_1}} = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=1, end_mask=0, new_axis_mask=0, shrink_axis_mask=2](Cast_1, strided_slice_1/stack, strided_slice_1/stack_1, strided_slice_1/stack_2)' with input shapes: [?,1], [2], [2], [2] and with computed input tensors: input[1] = <0 1>, input[2] = <0 2>, input[3] = <1 1>.

However, it works with metrics.Recall(class_id=0).

metrics.Precision(class_id=1) fails with the same error, and the other metrics that accept class_id probably do too (I have not tried them all).

I cannot work out what the error message means, and I could not find anything online that answers my question.

The documentation states:

class_id (Optional): Integer class ID for which we want binary metrics. This must be in the half-open interval [0, num_classes), where num_classes is the last dimension of predictions.

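Because you are using a single sigmoid unit, each prediction has shape (1,), so the last dimension of the predictions has size 1 and the only valid class_id is 0. That is what causes this error. When a network is set up for binary classification this way, its output is the sigmoid probability of class 1.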

So in the binary classification case you get Precision and Recall for class 1 by default. If you want them for class 0, you need to define your own metric. An example can be found here.
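
One way to obtain class-0 numbers from a single sigmoid output is to threshold the probability and then treat class 0 as the positive class. The following is a minimal NumPy sketch of that counting, not a Keras metric. The helper name binary_precision_recall, the 0.5 threshold, and the sample data are assumptions made for illustration; the same logic could be wrapped in a custom Keras metric.

```python
import numpy as np

def binary_precision_recall(y_true, y_prob, positive_class=1, threshold=0.5):
    """Precision and recall for either class of a single-sigmoid model.

    Hypothetical helper: y_prob holds the sigmoid probability of class 1.
    """
    y_true = np.asarray(y_true).reshape(-1)
    # Threshold the class-1 probability to get hard 0/1 predictions.
    y_hat = (np.asarray(y_prob).reshape(-1) > threshold).astype(int)
    tp = np.sum((y_hat == positive_class) & (y_true == positive_class))
    fp = np.sum((y_hat == positive_class) & (y_true != positive_class))
    fn = np.sum((y_hat != positive_class) & (y_true == positive_class))
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return float(precision), float(recall)

# Made-up labels and sigmoid outputs of shape (n_samples, 1).
y_true = [0, 0, 0, 1, 1]
y_prob = [[0.1], [0.6], [0.3], [0.8], [0.4]]

print("class 1:", binary_precision_recall(y_true, y_prob, positive_class=1))
print("class 0:", binary_precision_recall(y_true, y_prob, positive_class=0))
```

Passing positive_class=0 simply swaps which label counts as a hit, which is all that a class-0 metric needs in the binary case.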

The error itself comes from this part of the source code:

if class_id is not None:
  y_true = y_true[..., class_id]
  y_pred = y_pred[..., class_id]
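
The same slice can be reproduced with plain NumPy arrays to see why it fails. This is only an illustrative sketch: NumPy raises an IndexError where TensorFlow raises the ValueError from the question, and the helper name slice_class and the array values are made up for the example.

```python
import numpy as np

def slice_class(y_pred, class_id):
    """Mimic the y_pred[..., class_id] slice from the metric (hypothetical helper)."""
    return y_pred[..., class_id]

# Shape produced by Dense(1, activation='sigmoid'): one column per sample.
sigmoid_out = np.array([[0.2], [0.7], [0.9], [0.4]])
# Shape produced by Dense(2, activation='softmax'): one column per class.
softmax_out = np.array([[0.8, 0.2], [0.3, 0.7], [0.1, 0.9], [0.6, 0.4]])

# Last dimension has size 1, so index 0 is the only valid class_id.
print(slice_class(sigmoid_out, 0).shape)
# Last dimension has size 2, so index 1 is valid.
print(slice_class(softmax_out, 1).shape)

try:
    slice_class(sigmoid_out, 1)  # index 1 does not exist in a size-1 dimension
    failed = False
except IndexError as err:
    failed = True
    print("IndexError:", err)
```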

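For class_id=1 to be valid in your example, the labels should be one-hot encoded and the output layer should have one unit per class: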

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.layers import SimpleRNN
from sklearn.model_selection import train_test_split
from tensorflow.keras import metrics
from tensorflow.keras.utils import to_categorical
import numpy as np
#generate data
max_length  = 200
width = 3
n_samples = 100
data = np.random.rand(n_samples, max_length, width)
label = np.random.randint(0, high =2, size = n_samples)
label = to_categorical(label, 2)
train_size = 0.8
x_train, x_test, y_train, y_test = train_test_split(data, label, train_size = train_size)

#create a model
rnn_size = 16
sequence_input = Input(shape=(max_length,width), dtype='float32')
x = SimpleRNN(rnn_size)(sequence_input)
preds = Dense(2, activation='softmax')(x)
model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=[metrics.Precision(class_id=1),
                                                                          metrics.Recall(class_id=1)])
#fit
BATCH_SIZE = 32
history = model.fit(x_train, y_train, epochs=16, batch_size=BATCH_SIZE,
                    validation_data = (x_test, y_test))

Epoch 16/16
3/3 [==============================] - 0s 86ms/step - loss: 0.6771 - precision: 0.5676 - recall: 0.5250 - val_loss: 0.6419 - val_precision: 0.2222 - val_recall: 0.6667

Verifying the results with sklearn:

from sklearn.metrics import classification_report
print(classification_report(np.argmax(y_test, axis = -1), 
                            np.argmax(model.predict(x_test, batch_size = 1), 
                                      axis= -1), digits = 4))

              precision    recall  f1-score   support

           0     0.9091    0.5882    0.7143        17
           1     0.2222    0.6667    0.3333         3

    accuracy                         0.6000        20
   macro avg     0.5657    0.6275    0.5238        20
weighted avg     0.8061    0.6000    0.6571        20
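
The class-1 row of this report matches val_precision and val_recall from the Keras log. As a cross-check, the report can be reproduced from a 2x2 confusion matrix. The matrix below is inferred from the support and recall values in the report, so treat it as an assumption rather than output from the original run; the helper name per_class_scores is also made up.

```python
import numpy as np

# Inferred confusion matrix (assumption): rows are true classes,
# columns are predicted classes. Row sums match the supports 17 and 3.
cm = np.array([[10, 7],
               [1, 2]])

def per_class_scores(cm, class_id):
    """Return (precision, recall) for one class of a confusion matrix."""
    tp = cm[class_id, class_id]
    fp = cm[:, class_id].sum() - tp  # predicted as class_id, but another class
    fn = cm[class_id, :].sum() - tp  # truly class_id, but predicted otherwise
    return float(tp / (tp + fp)), float(tp / (tp + fn))

for class_id in (0, 1):
    precision, recall = per_class_scores(cm, class_id)
    print(class_id, round(precision, 4), round(recall, 4))

accuracy = float(np.trace(cm) / cm.sum())
print("accuracy", round(accuracy, 4))
```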

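If you change class_id to 0 in the training example above (the one with the two-unit softmax output), it computes the metrics for class 0 instead.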