我可以使用 TensorFlow 中的自动微分访问特定节点的错误信号吗?
Can I access the error signals at particular nodes using automatic differentiation in TensorFlow?
我已经想到我可以使用 compute_gradients
获得变量和梯度的压缩列表,我想它对应于 ∂E/∂θ
,其中 θ
是一个变量条目。还有一种方法可以访问特定节点的错误信号,例如它们通常如何定义为 ∂E/∂a
,其中 a
是激活,例如输入的仿射变换 Wx + b
,或者在这种情况下我是否需要实现自己的反向传播算法?
您可以通过在 tf.gradients
的 xs
参数中包含激活来获得与激活有关的错误。浏览器
tf.reset_default_graph()
x = tf.placeholder(dtype=tf.float32)
a = x*x
E = 2*a
(dEda, dEdx) = tf.gradients(E, xs=[a, x])
sess = tf.Session()
sess.run([dEda, dEdx], feed_dict={x: 1})
你应该看到
[2.0, 4.0]
我已经想到我可以使用 compute_gradients
获得变量和梯度的压缩列表,我想它对应于 ∂E/∂θ
,其中 θ
是一个变量条目。还有一种方法可以访问特定节点的错误信号,例如它们通常如何定义为 ∂E/∂a
,其中 a
是激活,例如输入的仿射变换 Wx + b
,或者在这种情况下我是否需要实现自己的反向传播算法?
您可以通过在 tf.gradients
的 xs
参数中包含激活来获得与激活有关的错误。浏览器
tf.reset_default_graph()
x = tf.placeholder(dtype=tf.float32)
a = x*x
E = 2*a
(dEda, dEdx) = tf.gradients(E, xs=[a, x])
sess = tf.Session()
sess.run([dEda, dEdx], feed_dict={x: 1})
你应该看到
[2.0, 4.0]