tf.train.GradientDescentOptimizer() 的无效类型 tf.complex64
Invalid type tf.complex64 for tf.train.GradientDescentOptimizer()
我正在处理复杂的神经网络。我创建了一个可以正常工作的网络:
[...]
gradients = tf.gradients(mse, [weights])[0]
training_op = tf.assign(weights, weights - learning_rate * gradients)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
sess.run(training_op)
现在当我尝试使用时:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(mse)
对于第 training_op = optimizer.minimize(mse)
行,我得到以下内容:
ValueError: Invalid type tf.complex64 for weights:0, expected: [tf.float32, tf.float64, tf.float16, tf.bfloat16].
复杂真的不支持吗?还是我做错了什么?我对实值网络进行了同样的尝试,它工作正常,所以我相信我的结构是正确的。
新见解:
根据this。
最小化分为两部分:
compute_gradients
apply_gradients
如果我们单独测试它们,错误发生在compute_gradients
方法上。
所以 运行 tf.gradients
有效但 运行 optimizer.compute_gradients
无效?这越来越奇怪了。有人知道原因吗?
为了社区的利益,从评论部分提供答案。
由于优化器目前不支持复杂类型(即使支持复杂梯度,如您所述),您可能只对实部和虚部使用单独的变量。您也可以考虑编写自己的优化器并重写 valid_dtypes
方法,如下例所示。
def complex_mul_real( c, r ):
return tf.complex(tf.real(c)*r, tf.imag(c)*r)
此外,可以在 Github issue.
上找到有关该主题的详细讨论
我正在处理复杂的神经网络。我创建了一个可以正常工作的网络:
[...]
gradients = tf.gradients(mse, [weights])[0]
training_op = tf.assign(weights, weights - learning_rate * gradients)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
sess.run(training_op)
现在当我尝试使用时:
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(mse)
对于第 training_op = optimizer.minimize(mse)
行,我得到以下内容:
ValueError: Invalid type tf.complex64 for weights:0, expected: [tf.float32, tf.float64, tf.float16, tf.bfloat16].
复杂真的不支持吗?还是我做错了什么?我对实值网络进行了同样的尝试,它工作正常,所以我相信我的结构是正确的。
新见解:
根据this。 最小化分为两部分:
compute_gradients
apply_gradients
如果我们单独测试它们,错误发生在compute_gradients
方法上。
所以 运行 tf.gradients
有效但 运行 optimizer.compute_gradients
无效?这越来越奇怪了。有人知道原因吗?
为了社区的利益,从评论部分提供答案。
由于优化器目前不支持复杂类型(即使支持复杂梯度,如您所述),您可能只对实部和虚部使用单独的变量。您也可以考虑编写自己的优化器并重写 valid_dtypes
方法,如下例所示。
def complex_mul_real( c, r ):
return tf.complex(tf.real(c)*r, tf.imag(c)*r)
此外,可以在 Github issue.
上找到有关该主题的详细讨论