"LookupError: Function `__class__` does not exist." When using tf.function

"LookupError: Function `__class__` does not exist." When using tf.function

我正在尝试使用在 WebAssembly 二进制文件中抽象的自定义损失函数来覆盖 keras/tf2.0 损失函数。这是相关代码。

@tf.function
def custom_loss(y_true, y_pred):
    return tf.constant(instance.exports.loss(y_true, y_pred))

我就是这样用的

model.compile(optimizer_obj, loss=custom_loss)
# and then model.fit(...)

我不完全确定 tf2.0 eager execution 是如何工作的,所以任何关于它的见解都会很有用。

我认为 instance.exports.loss 函数与错误无关,但是如果您确定其他一切都没有问题,请告诉我,我将分享更多详细信息。

这是堆栈跟踪和实际错误: https://pastebin.com/6YtH75ay

首先,您不需要使用@tf.function来定义自定义损失。

我们可以愉快地(如果毫无意义地)做这样的事情:

def custom_loss(y_true, y_pred):
  return tf.reduce_mean(y_pred)

model.compile(optimizer='adam',
              loss=custom_loss,
              metrics=['accuracy'])

只要我们在custom_loss中使用的所有操作都可以被tensorflow

区分

所以你可以删除 @tf.function 装饰器,但我怀疑你会 运行 变成这样的错误消息:

[Some trace info] - Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

因为 tensorflow 无法在 webassembly 二进制文件中找到函数的梯度。损失函数中的所有内容都必须是 tensorflow 可以理解和计算梯度的东西,否则它将无法针对较低的损失值进行优化。

也许最好的方法是使用 tensorflow 可以计算梯度的操作来复制 instance.exports.loss 中的功能,而不是尝试直接引用它?