读取 Tensorflow 记录文件,在第一个 运行 之后不起作用

Reading Tensorflow Record Files, does not work after first run

我有一小段代码可以从一些 TFRecord 文件中读取数据。如果我 运行 来自 ipython 笔记本的代码,它在我第一次执行该块时工作正常。但是,如果我尝试在不重新启动内核的情况下第二次执行同一代码块,则会产生错误(错误:StatusNotOK:未找到:FetchOutputs 节点 DecodeRaw_2:0:未找到)。代码如下所示。我是否需要 close/clear/reinitialize 一些东西才能使代码 运行 正确多次?

filename_queue = tf.train.string_input_producer(filename_list)
init = tf.initialize_all_variables()
image = []
label = []
with tf.Session() as sess:
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    tf_image, tf_label = read_and_decode(filename_queue)
    for i in range(len(filename_list)):
        image.append(sess.run(tf_image))
        label.append(sess.run(tf_label))

    coord.request_stop()
    coord.join(threads)

请注意 read_and_decode() 取自 here

默认 tf.命令使用新名称附加到默认图形。您可以在 运行 代码段之前使用 tf.reset_default_graph() 第二次清除默认图表。

问题中的代码有一些问题。

  1. 第一个, is that all of the ops are added to the same graph. The means that when you call tf.train.start_queue_runners() (or run the tf.initialize_all_variables() op) the session will be doing work that is proportional to the number of times that you've run this code snippet. You can call tf.reset_default_graph() 在调用此代码之间,但更清晰的隔离方法可能是每次都声明一个单独的图:

    with tf.Graph().as_default():  # Declares a new graph for the life of the block.
        filename_queue = tf.train.string_input_producer(filename_list)
        init = tf.initialize_all_variables()
        image = []
        label = []
        with tf.Session() as sess:
            # ...
            coord.join(threads)
    
  2. 第二个问题是单独调用sess.run(tf_image)sess.run(tf_label)意味着图像和标签之间的关联丢失了。当您调用 sess.run(tf_image) 时,您使用了图像 来自 reader 的标签,但丢弃了标签(对于 sess.run(tf_label) 反之亦然。正确的解决方案是在同一步骤中获取它们:

    image_val, label_val = sess.run([tf_image, tf_label])
    image.append(image_val)
    label.append(label_val)
    
  3. 最后一个问题(即使您重置图形也可能导致问题)是代码在调用 tf.train.start_queue_runners() 后向图形添加节点。 TensorFlow 图上存在数据竞争的可能性,因为 read_and_decode() 向图中添加节点,而并行队列运行器并发读取它,并且 tf.Graph 不是线程安全的。

    处理此问题的最佳方法是在启动队列运行器之前定义所有图表:

    with tf.Graph().as_default():
        filename_queue = tf.train.string_input_producer(filename_list)
        image = []
        label = []
        tf_image, tf_label = read_and_decode(filename_queue)
    
        with tf.Session() as sess:
           coord = tf.train.Coordinator()
           threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
           for i in range(len(filename_list)):
               image_val, label_val = sess.run([tf_image, tf_label])
               image.append(image_val)
               label.append(label_val)
    
           coord.request_stop()
           coord.join(threads)