使用 Tf.math.is_nan 函数训练期间的 Tensorflow NaN 损失

Tensorflow NaN loss during training with Tf.math.is_nan function

我编写了一个自定义损失函数,当地面实况标签(6d 向量)为 NaN 时 return 损失为 0,否则 return 为均方误差。标签中的所有 6 个特征都是 NaN,或者没有 NaN。

我的损失函数看起来像:

tf.reduce_mean(tf.where(tf.math.is_nan(true_labels), tf.zeros_like(true_labels),
tf.square(tf.subtract(true_labels, predicted_labels))))

其中 true_labels 和 predicted_labels 的形状为 (batch_size, 6),并且只有任一矩阵的整行可以为 NaN。在这种情况下,我得到 NaN 损失值,即使当基本事实为 NaN 时我应该 returning 0 作为损失。我还测试了这个问题,方法是在预处理期间将所有 NaN 值替换为一个大的负数(-1e4,这超出了我的数据范围),然后通过使用 [= 在我的损失函数中测试 NaNs 16=]

tf.where(tf.math.less(true_labels, -9999), tf.zeros_like(true_labels),
tf.square(tf.subtract(true_labels, predicted_labels)))

这是一个彻头彻尾的 hack,但仍然有效 none。因此,我认为问题出在 tf.math.is_nan 函数上,但我不知道为什么它会给我带来 NaN 损失。此外,我已经在我人工制作的一些标签上测试了训练模式之外的损失函数,然后它不会 return NaNs。感谢任何帮助。

下面是我的模型。它 return 是一个 (batch_size, 6) 形状的 Tensor。第一列是 sigmoid 激活位于 [0,1] 并被送入二元交叉熵损失函数(我没有包括在这里,但确认 NaN 不是来自二元损失)。其余 5 列被送入上面定义的自定义损失函数。

def custom_activation(tensor):
    first_node_sigmoid = tf.nn.sigmoid(tensor[:, :1])
    return tf.concat([first_node_sigmoid, tensor[:, 1:]], axis = 1)


def gen_model():
    IMAGE_SIZE = 200
    CONV_PARAMS = {"kernel_size": 3, "use_bias": False, "padding": "same"}
    CONV_PARAMS2 = {"kernel_size": 5, "use_bias": False, "padding": "same"}

    model = Sequential()
    model.add(
        Reshape((IMAGE_SIZE, IMAGE_SIZE, 1), input_shape=(IMAGE_SIZE, IMAGE_SIZE))
    )
    model.add(Conv2D(16, **CONV_PARAMS))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPool2D())
    model.add(Conv2D(32, **CONV_PARAMS))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPool2D())
    model.add(Conv2D(64, **CONV_PARAMS))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(64, **CONV_PARAMS))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Conv2D(64, **CONV_PARAMS2))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPool2D())
    model.add(Conv2D(128, **CONV_PARAMS2))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPool2D())
    model.add(Conv2D(128, **CONV_PARAMS2))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(MaxPool2D())
    model.add(Flatten())
    model.add(Dense(64))
    model.add(Dense(6))
    model.add(tf.keras.layers.Lambda(custom_activation, name = "final_activation_layer"))
    return model

这是第一个特征为真 (1) 时真实标签的示例:

 [  1.         106.         189.           2.64826314  19.
   26.44962941]

当第一个特征为False(0)时,标签为

[0, nan, nan, nan, nan, nan]

更新:

在使用 tf.print 语句进行一些调试后,我发现我的 'predicted_labels' 输出为所有 NaN 值。当我使用上述 'hack' 时不会出现此问题,因此我认为这不是我的数据的问题。当用作网络输入时,我还检查了 none 我的图像在预处理后是否包含任何 NaN。不知何故,使用上述损失函数,我在预测值中得到了 NaN,但我不知道为什么。我试过降低学习率和批量大小,但这没有帮助。

也许像下面这样的东西对你有用。所有 nan 元素首先转换为 0,而其余元素保持不变。例如,[0, np.nan, np.nan, np.nan, np.nan, np.nan] 导致 [0, 0, 0, 0, 0, 0][1., 106., 189., 2.64826314, 19., 26.44962941] 保持不变。之后,您的损失仅针对非零值计算。如果 true_labels 为零,那么你只需 return 0.

import tensorflow as tf
import numpy as np

def custom_loss(true_labels, predicted_labels):

   true_labels = tf.where(tf.math.is_nan(true_labels), tf.zeros_like(true_labels), true_labels)
   loss = tf.reduce_mean(
       tf.where(tf.equal(true_labels, 0.0), true_labels,
       tf.square(tf.subtract(true_labels, predicted_labels))))
   return loss

def custom_activation(tensor):
    first_node_sigmoid = tf.nn.sigmoid(tensor[:, :1])
    return tf.concat([first_node_sigmoid, tensor[:, 1:]], axis = 1)


def gen_model():
    IMAGE_SIZE = 200
    CONV_PARAMS = {"kernel_size": 3, "use_bias": False, "padding": "same"}
    CONV_PARAMS2 = {"kernel_size": 5, "use_bias": False, "padding": "same"}

    model = tf.keras.Sequential()
    model.add(
        tf.keras.layers.Reshape((IMAGE_SIZE, IMAGE_SIZE, 1), input_shape=(IMAGE_SIZE, IMAGE_SIZE))
    )
    model.add(tf.keras.layers.Conv2D(16, **CONV_PARAMS))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPool2D())
    model.add(tf.keras.layers.Conv2D(32, **CONV_PARAMS))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPool2D())
    model.add(tf.keras.layers.Conv2D(64, **CONV_PARAMS))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Conv2D(64, **CONV_PARAMS))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Conv2D(64, **CONV_PARAMS2))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPool2D())
    model.add(tf.keras.layers.Conv2D(128, **CONV_PARAMS2))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPool2D())
    model.add(tf.keras.layers.Conv2D(128, **CONV_PARAMS2))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.MaxPool2D())
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(64))
    model.add(tf.keras.layers.Dense(6))
    model.add(tf.keras.layers.Lambda(custom_activation, name = "final_activation_layer"))
    return model

Y_train = tf.constant([[1., 106., 189., 2.64826314, 19., 26.44962941], 
                       [0, np.nan, np.nan, np.nan, np.nan, np.nan]])
model = gen_model()
model.compile(loss=custom_loss, optimizer=tf.keras.optimizers.Adam())
model.fit(tf.random.normal((2, 200, 200)), Y_train, epochs=4)
Epoch 1/4
1/1 [==============================] - 1s 1s/step - loss: 4112.9380
Epoch 2/4
1/1 [==============================] - 0s 30ms/step - loss: 947.3030
Epoch 3/4
1/1 [==============================] - 0s 25ms/step - loss: 25.8993
Epoch 4/4
1/1 [==============================] - 0s 24ms/step - loss: 217.2151
<keras.callbacks.History at 0x7f8490b8db90>