Tensorflow @tf.function - Cannot get session inside Tensorflow graph function

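I am trying to use the @tf.function decorator with the Keras API to build a TF graph for the training step of a simple neural network. I am using TensorFlow v2.1.0, installed with Python 3.7. However, I get the runtime error in the title, and I would appreciate any hints about its cause.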

The code is as follows:

import tensorflow as tf
import numpy as np

# import the CIFAR10 dataset and normalise the feature distributions
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
max_pixel = np.max(train_images)  # compute once, before train_images is rescaled
train_images = train_images / max_pixel
test_images  = test_images / max_pixel

# convert the datasets to tf.data, batching the data
train_data = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).batch(128)
test_data  = tf.data.Dataset.from_tensor_slices((test_images,  test_labels)).batch(128)

# make a model with a single dense layer
# note that the Flatten layer is needed to convert each 32x32x3 image into a flat vector
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units = 10))  # no activation: the loss expects raw logits (from_logits = True)

# compile the model
model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001),
              loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),
              metrics = ["accuracy"])

# training step
@tf.function
def train(model, train_data, test_data):
    model.fit(x = train_data,
              validation_data = test_data,
              epochs = 10)

    return

# train the model
train(model = model, train_data = train_data, test_data = test_data)

The error I get at runtime is the following:

2020-04-01 11:33:27.084545: W tensorflow/core/framework/cpu_allocator_impl.cc:81] Allocation of 1228800000 exceeds 10% of system memory.
Traceback (most recent call last):
  File "report.py", line 41, in <module>
    train(model = model, train_data = train_data, test_data = test_data)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
RuntimeError: in converted code:

    report.py:34 train  *
        model.fit(x = train_data,
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:819 fit
        use_multiprocessing=use_multiprocessing)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py:648 fit
        shuffle=shuffle)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2346 _standardize_user_data
        all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py:2523 _build_model_with_inputs
        inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py:1677 extract_tensors_from_dataset
        iterator = get_iterator(dataset)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py:1658 get_iterator
        initialize_iterator(iterator)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py:1665 initialize_iterator
        K.get_session((init_op,)).run(init_op)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:493 get_session
        session = _get_session(op_input_list)
    /home/alessio/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py:453 _get_session
        raise RuntimeError('Cannot get session inside Tensorflow graph function.')

    RuntimeError: Cannot get session inside Tensorflow graph function.

Note that the same code runs fine without the @tf.function decorator. On the other hand, I get the same error with different datasets and different models.

Thanks in advance.

Looking at the documentation https://www.tensorflow.org/guide/function it isn't clear to me that the function you have defined could be compiled into a graph. I think it is meant to be used on functions that get used in a Lambda layer https://www.tensorflow.org/api_docs/python/tf/keras/layers/Lambda or similar.

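You have already called compile on the model, and in TensorFlow 2.x fit then runs the training step as a graph by itself, so there is no need to do anything more.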
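As a minimal sketch of that point, the model can simply be trained by calling fit directly, with no @tf.function wrapper. The random arrays below are an assumption standing in for CIFAR-10 (so it runs without a download), and the Dense layer has no activation because the loss is configured with from_logits=True.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for CIFAR-10: 64 random 32x32 RGB images, labels shaped (N, 1)
rng = np.random.default_rng(0)
images = rng.random((64, 32, 32, 3)).astype("float32")
labels = rng.integers(0, 10, size=(64, 1))
train_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(16)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=10),  # raw logits, matching from_logits=True
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])

# Plain eager call: fit is NOT wrapped in @tf.function
history = model.fit(train_data, validation_data=train_data, epochs=2, verbose=0)
print(history.history["loss"])
```

If you ever need to debug the training step line by line, compile accepts run_eagerly=True, which does the opposite and keeps everything in eager mode.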

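My guess is that it throws the error because it doesn't know how to build a graph from the model.fit call, but the error message is very confusing. The traceback is consistent with this: inside the graph function, fit goes down the graph-mode path in training_arrays.py, which calls K.get_session() to initialise the dataset iterator, and _get_session refuses to create a session while a tf.function is being traced.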
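If the goal is specifically to compile the training step into a graph yourself, a common pattern in the TensorFlow guides is to decorate only a single optimisation step (written with tf.GradientTape) and call it from an ordinary eager Python loop, instead of wrapping fit. This is a sketch under assumptions: the data is synthetic, and train_step is a hypothetical name.

```python
import numpy as np
import tensorflow as tf

# Hypothetical synthetic data in place of CIFAR-10
rng = np.random.default_rng(0)
images = rng.random((64, 32, 32, 3)).astype("float32")
labels = rng.integers(0, 10, size=(64,)).astype("int64")
train_data = tf.data.Dataset.from_tensor_slices((images, labels)).batch(16)

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=10),  # outputs logits
])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

@tf.function  # only one optimisation step is compiled into a graph
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# The epoch and batch loops stay in ordinary eager Python,
# so the dataset is iterated eagerly and no session is ever needed
losses = []
for epoch in range(2):
    for x, y in train_data:
        losses.append(float(train_step(x, y)))
print(len(losses))
```

All batches here have the same shape (64 / 16 = 4 full batches), so train_step is traced once and then reused for every call.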

If you try a simple arithmetic function like this:

import tensorflow as tf

@tf.function
def add(x, y):
    return x + y

add(1, 2)

This outputs a tensor:

<tf.Tensor: shape=(), dtype=int32, numpy=3>
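To see what tf.function actually does with such a function, the sketch below records every time Python traces it; trace_log is a hypothetical helper added only for illustration. Python side effects like the append run only while the graph is being built, so with default settings the function should be traced once for scalar int32 tensors and once more when the input shape changes.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: one entry per trace of the Python function

@tf.function
def add(x, y):
    # This Python statement runs only during tracing, not on every call
    trace_log.append(x.shape)
    return x + y

a = add(tf.constant(1), tf.constant(2))            # first call: traced
b = add(tf.constant(3), tf.constant(4))            # same dtype and shape: graph reused
c = add(tf.constant([1, 2]), tf.constant([3, 4]))  # new shape: traced again

print(a.numpy(), b.numpy(), c.numpy())  # 3 7 [4 6]
print(len(trace_log))
```

This is also why wrapping fit fails: everything inside the decorated function, including Keras' own dataset handling, is executed once in graph-building mode rather than eagerly.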

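TensorFlow is fast. I wouldn't worry about performance until you really understand what is going on in the library and know that there is actually a problem.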