Keras 中带有附加变量输入的自定义 loss/objective 函数
Custom loss/objective function with additional variable input in Keras
我正在尝试在 Keras(tensorflow 后端)中创建自定义 objective 函数,其中包含一个附加参数,该参数的值取决于正在训练的批次。
例如:
def myLoss(self, stateValues):
def sparse_loss(y_true, y_pred):
foo = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
return tf.reduce_mean(foo * stateValues)
return sparse_loss
self.model.compile(loss=self.myLoss(stateValue = self.stateValue),
optimizer=Adam(lr=self.alpha))
我的train函数如下
for batch in batches:
self.stateValue = computeStateValueVectorForCurrentBatch(batch)
model.fit(xVals, yVals, batch_size=<num>)
然而,损失函数中的stateValue没有被更新。它只是使用 stateValue 在 model.compile 步骤中的值。
我想这可以通过为 stateValue 使用 placeHolder 来解决,但我不知道该怎么做。有人可以帮忙吗?
你的损失函数没有得到更新,因为 keras 没有在每批之后编译模型,因此没有使用更新的损失函数。
您可以定义一个自定义回调,它会在每批次后更新损失值。像这样:
from keras.callbacks import Callback
class UpdateLoss(Callback):
def on_batch_end(self, batch, logs={}):
# I am not sure what is the type of the argument you are passing for computing stateValue ??
stateValue = computeStateValueVectorForCurrentBatch(batch)
self.model.loss = myLoss(stateValue)
我正在尝试在 Keras(tensorflow 后端)中创建自定义 objective 函数,其中包含一个附加参数,该参数的值取决于正在训练的批次。
例如:
def myLoss(self, stateValues):
def sparse_loss(y_true, y_pred):
foo = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
return tf.reduce_mean(foo * stateValues)
return sparse_loss
self.model.compile(loss=self.myLoss(stateValue = self.stateValue),
optimizer=Adam(lr=self.alpha))
我的train函数如下
for batch in batches:
self.stateValue = computeStateValueVectorForCurrentBatch(batch)
model.fit(xVals, yVals, batch_size=<num>)
然而,损失函数中的stateValue没有被更新。它只是使用 stateValue 在 model.compile 步骤中的值。
我想这可以通过为 stateValue 使用 placeHolder 来解决,但我不知道该怎么做。有人可以帮忙吗?
你的损失函数没有得到更新,因为 keras 没有在每批之后编译模型,因此没有使用更新的损失函数。
您可以定义一个自定义回调,它会在每批次后更新损失值。像这样:
from keras.callbacks import Callback
class UpdateLoss(Callback):
def on_batch_end(self, batch, logs={}):
# I am not sure what is the type of the argument you are passing for computing stateValue ??
stateValue = computeStateValueVectorForCurrentBatch(batch)
self.model.loss = myLoss(stateValue)