Keras 中带有附加动态参数的自定义自适应损失函数
Custom adaptive loss function with additional dynamic argument in Keras
我必须在 keras 中使用自适应自定义损失函数,该函数采用额外的动态参数 (eps
)。参数 eps
是一个 标量,但从一个样本变为另一个 :因此在训练期间应调整损失函数。我使用生成器,我可以在训练期间通过生成器的每次调用传递此参数 (generator_train[2]
)。根据对类似问题的回答,我尝试编写以下包装:
def custom_loss(eps):
def square_err(y_true, y_pred):
nom = K.sum(K.square(y_pred - y_true), axis=-1)
denom = eps**2
loss = nom/denom
return loss
return square_err
但是由于 eps
是一个动态变量,所以我很难实现它:我不知道在训练过程中我应该如何将这个参数传递给损失函数 (model.fit
)。这是我的模型的一个简单版本:
model = keras.Sequential()
model.add(layers.LSTM(units=32, input_shape=(32, 4))
model.add(layers.Dense(units=1))
model.add_loss(custom_loss)
opt = keras.optimizers.Adam()
model.compile(optimizer=opt)
history = model.fit(x=generator_train[0], y=generator_train[1],
steps_per_epoch=100
epochs=50,
validation_data=gen_vl,
validation_steps=n_vl)
非常感谢您的帮助。
只需传递“样本权重”,即每个样本 1/(eps**2)
。
你的生成器应该只输出 x, y, sample_weights
就这样。
您的损失可以是:
def loss(y_true, y_pred):
return K.sum(K.square(y_pred - y_true), axis=-1)
在fit
中,你不能在生成器中使用索引,你将只传递generator_train
,不传递x
,不传递y
,只传递generator_train
.
我必须在 keras 中使用自适应自定义损失函数,该函数采用额外的动态参数 (eps
)。参数 eps
是一个 标量,但从一个样本变为另一个 :因此在训练期间应调整损失函数。我使用生成器,我可以在训练期间通过生成器的每次调用传递此参数 (generator_train[2]
)。根据对类似问题的回答,我尝试编写以下包装:
def custom_loss(eps):
def square_err(y_true, y_pred):
nom = K.sum(K.square(y_pred - y_true), axis=-1)
denom = eps**2
loss = nom/denom
return loss
return square_err
但是由于 eps
是一个动态变量,所以我很难实现它:我不知道在训练过程中我应该如何将这个参数传递给损失函数 (model.fit
)。这是我的模型的一个简单版本:
model = keras.Sequential()
model.add(layers.LSTM(units=32, input_shape=(32, 4))
model.add(layers.Dense(units=1))
model.add_loss(custom_loss)
opt = keras.optimizers.Adam()
model.compile(optimizer=opt)
history = model.fit(x=generator_train[0], y=generator_train[1],
steps_per_epoch=100
epochs=50,
validation_data=gen_vl,
validation_steps=n_vl)
非常感谢您的帮助。
只需传递“样本权重”,即每个样本 1/(eps**2)
。
你的生成器应该只输出 x, y, sample_weights
就这样。
您的损失可以是:
def loss(y_true, y_pred):
return K.sum(K.square(y_pred - y_true), axis=-1)
在fit
中,你不能在生成器中使用索引,你将只传递generator_train
,不传递x
,不传递y
,只传递generator_train
.