keras 自定义损失 - 忽略零标签
keras custom loss - ignore zero labels
我正在尝试训练序列标记模型 (LSTM),其中序列标签是 1
(第一个 class)、2
(第二个 class ) 或 0
(无关紧要)。
我尝试编写自己的忽略零的损失函数:
import keras.backend as K
def my_loss(y_true, y_pred):
"""(sum([(t-p)**2 for t,p in zip(y_true, y_pred)])/n_nonzero)**0.5"""
return K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_true), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
本质上只计算非零的均方误差。
但是,我在训练模型时得到 loss=nan
。
我做错了什么?
在训练过程中忽略某些标签的标准方法是什么?
它不起作用的原因是它必须是:
K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_pred), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
而不是:
K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_true), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
因为您用 y_true
减去 y_true
而不是 y_pred
。
通过删除 axis=-1 的参数对我有用:
K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_true), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
当 y_true
大于 0 时,y_true
减去 y_pred
。特别是,当 y_true
等于零时,项 y_pred*K.cast(y_true>0, "float32") - y_true)
也是等于零,因为它用 0 减去 0,在 y_true>0
的情况下,术语 K.cast(y_true>0, "float32")
取值 1,否则它得到 0。
我正在尝试训练序列标记模型 (LSTM),其中序列标签是 1
(第一个 class)、2
(第二个 class ) 或 0
(无关紧要)。
我尝试编写自己的忽略零的损失函数:
import keras.backend as K
def my_loss(y_true, y_pred):
"""(sum([(t-p)**2 for t,p in zip(y_true, y_pred)])/n_nonzero)**0.5"""
return K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_true), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
本质上只计算非零的均方误差。
但是,我在训练模型时得到 loss=nan
。
我做错了什么?
在训练过程中忽略某些标签的标准方法是什么?
它不起作用的原因是它必须是:
K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_pred), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
而不是:
K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_true), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
因为您用 y_true
减去 y_true
而不是 y_pred
。
通过删除 axis=-1 的参数对我有用:
K.sqrt(K.sum(K.square(y_pred*K.cast(y_true>0, "float32") - y_true), axis=-1) / K.sum(K.cast(y_true>0, "float32") ))
当 y_true
大于 0 时,y_true
减去 y_pred
。特别是,当 y_true
等于零时,项 y_pred*K.cast(y_true>0, "float32") - y_true)
也是等于零,因为它用 0 减去 0,在 y_true>0
的情况下,术语 K.cast(y_true>0, "float32")
取值 1,否则它得到 0。