What's the correct way to use tf.while_loop in a custom loss function?

I intend to use the following function as the loss for my training:

import tensorflow as tf

def wrap(dist): 
    return tf.while_loop(
        cond=lambda X: tf.math.abs(X) > 0.5,
        body=lambda X: tf.math.subtract(X, 1.0),
        loop_vars=(dist))


# PBC-aware MSE, period = 1.0 ([0, 1.0])
def custom_loss(y_true, y_pred):
    diff = tf.math.abs(y_true - y_pred)
    diff = tf.nest.flatten(diff)
    diff = tf.vectorized_map(wrap, diff)
    return tf.math.reduce_mean(tf.math.square(diff))

# ...other code for loading data and defining the model

model.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.1),
              loss=custom_loss)
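For comparison, tf.while_loop can also run once over the whole difference tensor, which avoids tf.vectorized_map entirely. The condition is reduced to the scalar boolean that tf.while_loop expects, and tf.where shifts only the elements that are still out of range. This is a minimal sketch assuming TF 2.x. The names wrap_loop and custom_loss_loop are hypothetical, and the input is assumed non-negative (as it is after tf.math.abs), because the loop only ever subtracts the period.

```python
import numpy as np
import tensorflow as tf


def wrap_loop(diff, period=1.0):
    """Wrap each non-negative element of diff into [-period/2, period/2].

    Hypothetical helper: one tf.while_loop runs over the whole tensor and
    shifts only the elements that are still out of range.
    """
    half = 0.5 * period

    def cond(d):
        # Scalar boolean: keep looping while any element exceeds period / 2.
        return tf.reduce_any(d > half)

    def body(d):
        # Subtract one period from the out-of-range elements only.
        return (tf.where(d > half, d - period, d),)

    # loop_vars is a 1-tuple, so the result is a 1-tuple as well.
    (wrapped,) = tf.while_loop(cond, body, loop_vars=(diff,))
    return wrapped


def custom_loss_loop(y_true, y_pred):
    # Non-negative differences, as in the question's custom_loss.
    diff = tf.math.abs(y_true - y_pred)
    return tf.math.reduce_mean(tf.math.square(wrap_loop(diff)))


y_true = np.array([[0.1, 0.9, 0.0]])
y_pred = np.array([[0.9, 0.1, 2.3]])
print(custom_loss_loop(y_true, y_pred).numpy())
```

Here the differences 0.8, 0.8 and 2.3 wrap to about -0.2, -0.2 and 0.3, so the loss is about (0.04 + 0.04 + 0.09) / 3.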

But I got a bunch of error messages. Since the log is too long, I put it in a gist: https://gist.github.com/HanatoK/f75fddd82372f499c37279f1128cad7a

The equivalent NumPy version of the code above would be:

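import numpy as np

def wrap_diff2(x, y, period=1.0):
    # Squared distance between x and y under periodic boundary conditions.
    diff = np.abs(x - y)
    while diff > 0.5 * period:
        diff -= period
    return diff * diff

def custom_loss_numpy(y_true, y_pred):
    # np.vectorize applies the scalar function element-wise.
    diff2 = np.vectorize(wrap_diff2)(y_true, y_pred)
    return np.mean(diff2)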
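The NumPy loop subtracts the period until the difference is at most period / 2. For a squared error, the same value can be obtained without any loop by using the minimum-image convention, diff - period * round(diff / period). This is a different technique from the tf.while_loop in the question, shown as a sketch assuming TF 2.x. The function name custom_loss_closed_form is hypothetical. At exact ties (for example a difference of 1.5) the sign of the wrapped value can differ from the loop version, but its square is the same.

```python
import numpy as np
import tensorflow as tf


def custom_loss_closed_form(y_true, y_pred, period=1.0):
    """Hypothetical PBC-aware MSE without any loop (minimum-image convention)."""
    diff = tf.math.abs(y_true - y_pred)
    # diff - period * round(diff / period) lands in [-period/2, period/2].
    # tf.math.round rounds half to even, so a tie such as 1.5 comes out as
    # -0.5 where the loop gives 0.5; the squared value is the same.
    wrapped = diff - period * tf.math.round(diff / period)
    return tf.math.reduce_mean(tf.math.square(wrapped))


y_true = np.array([[0.1, 0.9, 0.0]])
y_pred = np.array([[0.9, 0.1, 2.3]])
print(custom_loss_closed_form(y_true, y_pred).numpy())
```

Because the rounded term is piecewise constant, its derivative is zero almost everywhere, so the gradient of the wrapped difference is just the gradient of diff.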

Any ideas? The full code example is shared on Google Colab: https://colab.research.google.com/drive/1ExVHgyKHQfGcpXvo5ZsuBBmzmHzxUekC?usp=sharing

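Try this. The key change is likely the flattening step: tf.nest.flatten does not flatten a tensor's elements. It returns a one-element Python list holding the unchanged tensor, so for a 2-D batch tf.vectorized_map hands wrap whole rows, and the loop condition becomes a vector instead of a scalar boolean. Reshaping with tf.reshape(diff, [-1]) first means each call to wrap works on a single scalar: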

import tensorflow as tf
import numpy as np

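def wrap(dist):
    # dist is a one-element list [x] (see the tf.vectorized_map call below),
    # so loop_vars is a list and cond/body receive the scalar x.
    return tf.while_loop(
        cond=lambda X: tf.math.abs(X) > 0.5,
        body=lambda X: tf.math.subtract(X, 1.0),
        loop_vars=(dist))

def custom_loss(y_true, y_pred):
    diff = tf.math.abs(y_true - y_pred)
    # Flatten to 1-D so each element handed to wrap is a scalar.
    diff = tf.reshape(diff, [-1])
    diff = tf.vectorized_map(wrap, [diff])
    return tf.math.reduce_mean(tf.math.square(diff))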

y_true = np.array([[0., 1., 1.0], [0., 0., 0.]])
y_pred = np.array([[1., 1., 1.0], [1., 0., 1.]])
custom_loss(y_true, y_pred).numpy()
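To pass a loss like this to model.compile, as the question does, for a period other than 1.0, one option is a small factory that closes over the period. This is a sketch assuming TF 2.x with tf.keras. The helper make_pbc_mse, the toy Dense model and the random data are hypothetical stand-ins, because the question's model and data loading are not shown. The loss body uses the loop-free minimum-image wrap rather than tf.while_loop.

```python
import numpy as np
import tensorflow as tf


def make_pbc_mse(period=1.0):
    """Hypothetical factory returning a Keras-compatible PBC-aware MSE."""
    def pbc_mse(y_true, y_pred):
        # Keras may pass y_true with a different dtype than y_pred; align them.
        y_true = tf.cast(y_true, y_pred.dtype)
        diff = y_true - y_pred
        # Minimum-image wrap into [-period/2, period/2]; no loop needed.
        wrapped = diff - period * tf.math.round(diff / period)
        return tf.math.reduce_mean(tf.math.square(wrapped))
    return pbc_mse


# Toy stand-in for the question's model and data (not shown in the original).
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.random((64, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.1),
              loss=make_pbc_mse(period=1.0))
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
print(history.history["loss"])
```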