Keras 自定义损失函数 - 尽管返回与分类交叉熵相同的形状,但形状不匹配
Keras custom loss function - shape mismatch despite returning same shape as categorical crossentropy
我创建了一个基于余弦的自定义损失函数:
def cos_loss(y_true, y_pred):
norm_pred = tf.math.l2_normalize(y_pred)
dprod = tf.tensordot(
a=y_true,
b=norm_pred,
axes=1
)
return 1 - dprod
但是,使用此自定义损失训练模型会导致错误 In[0] mismatch In[1] shape: 2 vs. 8: [8,2] [8,2] 0 0
。如果我使用像分类交叉熵这样的内置损失函数,模型训练没有问题。
尽管我的自定义损失和分类交叉熵返回值的类型和形状完全相同。例如,我通过两者创建测试 y_true
和 y_pred
以及 运行 :
test_true = np.asarray([1.0, 0.0])
test_pred = np.asarray([0.9, 0.2])
print(cos_loss(test_true, test_pred))
print(tf.keras.losses.categorical_crossentropy(test_true, test_pred))
哪个returns:
> tf.Tensor(0.023812939816047263, shape=(), dtype=float64)
tf.Tensor(0.20067069546215124, shape=(), dtype=float64)
所以两者都给出了具有单个 float-64 值且没有形状的 TF 张量。那么,如果形状输出相同,为什么我会在一个而不是另一个上收到形状不匹配错误?谢谢。
您的损失函数应该能够接受一批预测和基本事实以及 return 一批损失值。目前,情况并非如此,因为 tensordot
和 axis=1
是矩阵乘法,当你开始引入批处理维度时,你会遇到维度冲突。
您或许可以改用以下内容:
def cos_loss(y_true, y_pred):
norm_pred = tf.math.l2_normalize(y_pred)
dprod = tf.reduce_sum(y_true*norm_pred, axis=-1)
return 1 - dprod
我创建了一个基于余弦的自定义损失函数:
def cos_loss(y_true, y_pred):
norm_pred = tf.math.l2_normalize(y_pred)
dprod = tf.tensordot(
a=y_true,
b=norm_pred,
axes=1
)
return 1 - dprod
但是,使用此自定义损失训练模型会导致错误 In[0] mismatch In[1] shape: 2 vs. 8: [8,2] [8,2] 0 0
。如果我使用像分类交叉熵这样的内置损失函数,模型训练没有问题。
尽管我的自定义损失和分类交叉熵返回值的类型和形状完全相同。例如,我通过两者创建测试 y_true
和 y_pred
以及 运行 :
test_true = np.asarray([1.0, 0.0])
test_pred = np.asarray([0.9, 0.2])
print(cos_loss(test_true, test_pred))
print(tf.keras.losses.categorical_crossentropy(test_true, test_pred))
哪个returns:
> tf.Tensor(0.023812939816047263, shape=(), dtype=float64)
tf.Tensor(0.20067069546215124, shape=(), dtype=float64)
所以两者都给出了具有单个 float-64 值且没有形状的 TF 张量。那么,如果形状输出相同,为什么我会在一个而不是另一个上收到形状不匹配错误?谢谢。
您的损失函数应该能够接受一批预测和基本事实以及 return 一批损失值。目前,情况并非如此,因为 tensordot
和 axis=1
是矩阵乘法,当你开始引入批处理维度时,你会遇到维度冲突。
您或许可以改用以下内容:
def cos_loss(y_true, y_pred):
norm_pred = tf.math.l2_normalize(y_pred)
dprod = tf.reduce_sum(y_true*norm_pred, axis=-1)
return 1 - dprod