训练时关于损失函数的 Tensorflow 维度大小/等级问题
Tensorflow dimension size / rank issue on loss function when training
我有一个自定义损失函数似乎是有效的(当运行这两行时没有错误,如下所示)
alpha_cost = 2
cost = tf.reduce_mean(tf.where(tf.less(Y * out, 0), tf.squeeze((alpha_cost*out)**2 - tf.sign(Y)*out + tf.abs(Y)), tf.squeeze(tf.abs(Y - out))))
然而,当我实际继续进行模型训练(批量大小 10,000)时,使用此损失函数会产生以下错误...
InvalidArgumentError (see above for traceback): Inputs to operation
Select_4 of type Select must have the same size and shape. Input 0:
[1,10000] != input 1: [10000] [[Node: Select_4 = Select[T=DT_FLOAT,
_device="/job:localhost/replica:0/task:0/device:GPU:0"](Less_7, Squeeze_6, Squeeze_7)]]
可能是只需要添加另一个 tf.squeeze 来压缩尺寸的情况...但我已经尝试了几次但没有完全配合,所以这种方法可能存在更大的问题...感谢您的帮助!
使用 numpy 模块,
import numpy as np
和定义
input0 = tf.where(tf.less(Y * out, 0)
input1 = tf.squeeze((alpha_cost*out)**2 - tf.sign(Y)*out + tf.abs(Y))
尝试将 input1 从 rank-1 数组重塑为 row-vector:
input1 = np.reshape(input1, (1, 10000))
赋予它与 input0 相同的形状。您可以断言 input1 具有正确的形状:
assert(input1.shape == (1,10000))
并检查 input0 和 input1 是否具有相同的维度:
assert(input0.shape == input1.shape)
我有一个自定义损失函数似乎是有效的(当运行这两行时没有错误,如下所示)
alpha_cost = 2
cost = tf.reduce_mean(tf.where(tf.less(Y * out, 0), tf.squeeze((alpha_cost*out)**2 - tf.sign(Y)*out + tf.abs(Y)), tf.squeeze(tf.abs(Y - out))))
然而,当我实际继续进行模型训练(批量大小 10,000)时,使用此损失函数会产生以下错误...
InvalidArgumentError (see above for traceback): Inputs to operation Select_4 of type Select must have the same size and shape. Input 0: [1,10000] != input 1: [10000] [[Node: Select_4 = Select[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](Less_7, Squeeze_6, Squeeze_7)]]
可能是只需要添加另一个 tf.squeeze 来压缩尺寸的情况...但我已经尝试了几次但没有完全配合,所以这种方法可能存在更大的问题...感谢您的帮助!
使用 numpy 模块,
import numpy as np
和定义
input0 = tf.where(tf.less(Y * out, 0)
input1 = tf.squeeze((alpha_cost*out)**2 - tf.sign(Y)*out + tf.abs(Y))
尝试将 input1 从 rank-1 数组重塑为 row-vector:
input1 = np.reshape(input1, (1, 10000))
赋予它与 input0 相同的形状。您可以断言 input1 具有正确的形状:
assert(input1.shape == (1,10000))
并检查 input0 和 input1 是否具有相同的维度:
assert(input0.shape == input1.shape)