tf.keras.losses.CategoricalCrossentropy gives different values than plain implementation

Does anyone know why a plain implementation of the categorical cross-entropy function gives such different values from the tf.keras API function?

import tensorflow as tf
import numpy as np
tf.enable_eager_execution()

y_true = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
y_pred = np.array([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])

ce = tf.keras.losses.CategoricalCrossentropy()
res = ce(y_true, y_pred).numpy()
print("use api:")
print(res)

print()
print("implementation:")
step1 = -y_true * np.log(y_pred )
step2 = np.sum(step1, axis=1)

print("step1.shape:", step1.shape)
print(step1)
print("sum step1:", np.sum(step1))
print("mean step1", np.mean(step1))

print()
print("step2.shape:", step2.shape)
print(step2)
print("sum step2:", np.sum(step2))
print("mean step2", np.mean(step2))

The above gives:

use api:
0.3239681124687195

implementation:
step1.shape: (3, 3)
[[0.10536052 0.         0.        ]
 [0.         0.11653382 0.        ]
 [0.         0.         0.0618754 ]]
sum step1: 0.2837697356318653
mean step1 0.031529970625762814

step2.shape: (3,)
[0.10536052 0.11653382 0.0618754 ]
sum step2: 0.2837697356318653
mean step2 0.09458991187728844

And with a different y_true and y_pred:

y_true = np.array([[0, 1]])
y_pred = np.array([[0.99999999999, 0.00000000001]])

it gives:

use api:
16.11809539794922

implementation:
step1.shape: (1, 2)
[[-0.         25.32843602]]
sum step1: 25.328436022934504
mean step1 12.664218011467252

step2.shape: (1,)
[25.32843602]
sum step2: 25.328436022934504
mean step2 25.328436022934504

The difference comes from these values: [.5, .89, .6], because their sum is not equal to 1. I think you made a typo there and meant [.05, .89, .06].
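As a quick sanity check (not in the original post), you can sum both rows to see which one is a valid probability distribution:

```python
import numpy as np

# The disputed row sums to 1.99, not 1.0; the suggested fix sums to 1.0.
print(np.sum([.5, .89, .6]))    # ≈ 1.99
print(np.sum([.05, .89, .06]))  # ≈ 1.0
```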

If the values you supply sum to 1, both formulas give the same result:

import tensorflow as tf
import numpy as np

y_true = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
y_pred = np.array([[.9, .05, .05], [.05, .89, .06], [.05, .01, .94]])

print(tf.keras.losses.categorical_crossentropy(y_true, y_pred).numpy())
print(np.sum(-y_true * np.log(y_pred), axis=1))

#output
#[0.10536052 0.11653382 0.0618754 ]
#[0.10536052 0.11653382 0.0618754 ]

But let's explore what happens when the y_pred tensor is not scaled (i.e. its values do not sum to 1). If you look at the source code of categorical cross-entropy here, you will see that it scales y_pred so that the class probabilities of each sample sum to 1:

if not from_logits:
    # scale preds so that the class probas of each sample sum to 1
    output /= tf.reduce_sum(output,
                            reduction_indices=len(output.get_shape()) - 1,
                            keep_dims=True)

Since we passed a prediction whose probabilities do not sum to 1, let's see how this operation changes our tensor [.5, .89, .6]:

output = tf.constant([.5, .89, .6])
output /= tf.reduce_sum(output,
                        axis=len(output.get_shape()) - 1,
                        keepdims=True)
print(output.numpy())

# array([0.2512563 , 0.44723618, 0.30150756], dtype=float32)
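The same normalization can be sketched in plain NumPy (this is an illustration of the row-wise division, not the keras source itself):

```python
import numpy as np

y_pred = np.array([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])
# Divide each row by its sum so the class probabilities of each sample sum to 1.
y_pred_scaled = y_pred / y_pred.sum(axis=-1, keepdims=True)
print(y_pred_scaled[1])  # ≈ [0.25125628, 0.44723618, 0.30150754]
```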

So if we take the output of the operation above (the scaled y_pred) and pass it to your own implementation of categorical cross-entropy, while passing the unscaled y_pred to the tensorflow implementation, the results should be equal:

y_true = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])

#unscaled y_pred
y_pred = np.array([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])  
print(tf.keras.losses.categorical_crossentropy(y_true, y_pred).numpy())

#scaled y_pred (categorical_crossentropy scales above tensor to this internally)
y_pred = np.array([[.9, .05, .05], [0.2512563 , 0.44723618, 0.30150756], [.05, .01, .94]])  
print(np.sum(-y_true * np.log(y_pred), axis=1))

Output:

[0.10536052 0.80466845 0.0618754 ]
[0.10536052 0.80466846 0.0618754 ]

Now, let's examine the result of your second example. Why does it show a different output? If you check the source code again, you will see this line:

output = tf.clip_by_value(output, epsilon, 1. - epsilon)

which clips values outside a threshold. Your input [0.99999999999, 0.00000000001] is converted to [0.9999999, 0.0000001] by this line, so it gives you a different result:

y_true = np.array([[0, 1]])
y_pred = np.array([[0.99999999999, 0.00000000001]])

print(tf.keras.losses.categorical_crossentropy(y_true, y_pred).numpy())
print(np.sum(-y_true * np.log(y_pred), axis=1))

#now let's first clip the values less than epsilon, then compare loss
epsilon=1e-7
y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
print(tf.keras.losses.categorical_crossentropy(y_true, y_pred).numpy())
print(np.sum(-y_true * np.log(y_pred), axis=1))

Output:

#results without clipping values
[16.11809565]
[25.32843602]

#results after clipping values if there is a value less than epsilon (1e-7)
[16.11809565]
[16.11809565]
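Putting both steps together, here is a plain-NumPy sketch of what the loss object effectively computes for probability inputs (the function name keras_style_cce is mine, not TensorFlow's): scale each row, clip, take the per-sample cross-entropy, then average over the batch:

```python
import numpy as np

def keras_style_cce(y_true, y_pred, epsilon=1e-7):
    """Sketch of what tf.keras.losses.CategoricalCrossentropy does for
    probability inputs: scale rows to sum to 1, clip, then average the
    per-sample cross-entropies over the batch."""
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)  # scale
    y_pred = np.clip(y_pred, epsilon, 1. - epsilon)       # clip
    per_sample = -np.sum(y_true * np.log(y_pred), axis=-1)
    return per_sample.mean()

y_true = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
y_pred = np.array([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])
print(keras_style_cce(y_true, y_pred))  # ≈ 0.3239681, matching the API
```

The API value 0.3239681 from the first example is exactly the mean of the three per-sample losses [0.10536052, 0.80466845, 0.0618754] computed on the scaled predictions.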