如何在 TensorFlow 2 中获取 Keras 张量的值?

How to get value of a Keras tensor in TensorFlow 2?

TF1 有 sess.run().eval() 来获取张量的值——而 Keras 有 K.get_value();现在,它们的工作方式都不一样了(之前的两个完全不同)。

K.eager(K.get_value)(tensor) 似乎通过退出它在 Keras 图形内部工作,并且 K.get_value(tensor) 在图形之外 - 两者都带有 TF2 的默认设置(即 off在前)。但是,如果 tensorKeras 后端操作:

,则会失败
import keras.backend as K
def tensor_info(x):
    print(x)
    print("Type: %s" % type(x))
    try:        
        x_value = K.get_value(x)
    except:
        try:    x_value = K.eager(K.get_value)(x)
        except: x_value = x.numpy()
    print("Value: %s" % x_value)  # three methods

ones = K.ones(1)
ones_sqrt = K.sqrt(ones)

tensor_info(ones); print()
tensor_info(ones_sqrt)
<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([1.], dtype=float32)>
Type: <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>
Value: [1.]

Tensor("Sqrt:0", shape=(1,), dtype=float32)
Type: <class 'tensorflow.python.framework.ops.Tensor'>
# third print fails w/ below
AttributeError: 'Tensor' object has no attribute 'numpy' 


这在 TF < 2.0 中不是问题。 Github一直沉默。我知道重写代码的方法作为解决方法,但它会消除 Keras 的后端中立性并类似于 tf.keras。有没有办法在保持后端中立性的同时在 TensorFlow 2.0 中获取 Keras 2.3 张量值?

我想你想要 K.eval:

>>> v = K.ones(1)
>>> K.eval(v)
array([1.], dtype=float32)
>>> K.eval(K.sqrt(v))
array([1.], dtype=float32)

请注意,K.get_value 保留用于变量(例如此处的 v),而 K.eval 适用于任何张量。

根据我的 PR,这是更可靠(但不能保证)的解决方法:

def K_eval(x):
    try:
        return K.get_value(K.to_dense(x))
    except:
        eval_fn = K.function([], [x])
        return eval_fn([])[0]

更新:注意分布上下文Tensor将在其中进行评估;在 TF2.2 中,在 tf.python.distribute.distribution_strategy_context.in_replica_context() == True 下创建的 tf.Variabletf.Tensor 将失败任何 K.eval-etc 尝试。看起来张量根本不应该在那里进行评估。

我想你要找的是 tf.keras.backend.get_value API.

print(x)
>>tf.Tensor([1.], shape=(1,), dtype=float32)
print(tf.keras.backend.get_value(x))
>>[1.]

在我的例子中,tensorflow 2.0 在打印损失时有效:

 import tensorflow as tf
 from tensorflow import keras

...

  print(loss_value)
  print(float(loss_value) )

输出:

    tf.Tensor(2.3782592, shape=(), dtype=float32)
    2.3782591819763184