如何使用从 TFRecords 读取的值作为 tf.reshape 的参数?

How can I use values read from TFRecords as arguments to tf.reshape?

def read_and_decode(filename_queue):
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
          'height': tf.FixedLenFeature([], tf.int64),
          'width': tf.FixedLenFeature([], tf.int64),
          'depth': tf.FixedLenFeature([], tf.int64)
      })
  # height = tf.cast(features['height'],tf.int32)
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image = tf.reshape(image,[32, 32, 3])
  image = tf.cast(image,tf.float32)
  label = tf.cast(features['label'], tf.int32)
  return image, label

我正在使用 TFRecord 来存储我的所有数据。函数 read_and_decode 来自 TensorFlow 提供的 TFRecords 示例。目前我通过预定义值重塑:

image = tf.reshape(image,[32, 32, 3])

但是,我现在要使用的数据是不同维度的。例如,我可以有一个 [40, 30, 3] 的图像(缩放这不是一个选项,因为我不希望它被扭曲)。我想读入不同维度的数据,在数据扩充阶段使用random_crop来规避这个问题。我需要的是类似下面的内容。

height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image = tf.reshape(image,[height, width, 3])

但是,我似乎找不到执行此操作的方法。感谢您的帮助!

编辑:

ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(None), Dimension(None)]), TensorShape([])]

image = tf.reshape(image, tf.pack([height, width, 3]))
image = tf.reshape(image, [32,32,3])

问题肯定出在这两行上。硬编码变量有效,但带有 tf.pack().

的变量无效

您即将找到可行的解决方案!现在没有 自动 方法来给 TensorFlow 一个由张量和数字组成的列表并从中生成一个张量,tf.reshape() is expecting. The answer is to use tf.stack(),它明确地采用 N 维列表张量(或可转换为张量的事物)并将它们打包成 (N+1) 维张量。

这意味着你可以写:

features = ...  # Parse from an example proto.
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)

image = tf.reshape(image, tf.stack([height, width, 3]))

我遇到了同样的问题。根据Tensorflow documentation,如果你尝试使用shuffle_batch,你会遇到这种情况 读取所需数据后操作。

像这个例子,如果不使用shuffle_batch处理,可以加载动态维度文件。

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
             'clip_height': tf.FixedLenFeature([], tf.int64),
             'clip_width': tf.FixedLenFeature([], tf.int64),
             'clip_raw': tf.FixedLenFeature([], tf.string),
             'clip_label_raw': tf.FixedLenFeature([], tf.int64)
        })
    image = tf.decode_raw(features['clip_raw'], tf.float64)
    label = tf.cast(features['clip_label_raw'], tf.int32)
    height = tf.cast(features['clip_height'], tf.int32)
    width = tf.cast(features['clip_width'], tf.int32)
    im_shape = tf.stack([height, width, -1])
    new_image = tf.reshape(image, im_shape )

但是如果要使用shuffle批处理,就不能使用tf.stack。您必须静态地定义与此类似的维度。

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
             'clip_height': tf.FixedLenFeature([], tf.int64),
             'clip_width': tf.FixedLenFeature([], tf.int64),
             'clip_raw': tf.FixedLenFeature([], tf.string),
             'clip_label_raw': tf.FixedLenFeature([], tf.int64)
        })
    image = tf.decode_raw(features['clip_raw'], tf.float64)
    label = tf.cast(features['clip_label_raw'], tf.int32)
    height = tf.cast(features['clip_height'], tf.int32)
    width = tf.cast(features['clip_width'], tf.int32)
    image = tf.reshape(image, [1, 512, 1])
    images, sparse_labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size,
            min_after_dequeue=100)

@mrry 如有错误请指正