TensorFlow:"Cannot capture a stateful node by value" 在 tf.contrib.data API

TensorFlow: "Cannot capture a stateful node by value" in tf.contrib.data API

对于迁移学习,人们通常使用网络作为特征提取器来创建特征数据集,在该数据集上训练另一个分类器(例如 SVM)。

我想使用数据集 API (tf.contrib.data) and dataset.map():

# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
    dataset = inputs(...)  # This creates a dataset of (image, label) pairs

    def map_example(image, label):
        features = feature_extractor(image, trainable=False)
        #  Leaving out initialization from a checkpoint here... 
        return features, label

    dataset = dataset.map(map_example)

    return dataset

为数据集创建迭代器时这样做失败。

ValueError: Cannot capture a stateful node by value.

这是真的,网络的内核和偏差是变量,因此是有状态的。对于这个特定的例子,他们不必是。

有没有办法让 Ops,特别是 tf.Variable objects 无状态?

因为我使用的是 tf.layers,所以我不能简单地将它们创建为常量,设置 trainable=False 也不会创建常量,只是不会将变量添加到 GraphKeys.TRAINABLE_VARIABLES collection.

不幸的是,tf.Variable 本质上是有状态的。但是,仅当您使用 Dataset.make_one_shot_iterator() 创建迭代器时才会出现此错误。* 为避免此问题,您可以改为使用 Dataset.make_initializable_iterator(),但需要注意的是您还必须 运行 iterator.initializer 在返回的迭代器上 运行 为输入管道中使用的 tf.Variable 对象设置初始值设定项。


* 此限制的原因是 Dataset.make_one_shot_iterator() 的实现细节以及用于封装数据集定义的正在进行的 TensorFlow 函数 (Defun) 支持。由于使用查找表和变量等有状态资源比我们最初想象的更受欢迎,我们正在研究放宽此限制的方法。