如何从 TFRecordDataset 获取张量的形状

How to get shape of tensor from TFRecordDataset

我在训练 TFRecord 中写入了以下功能:

feature = {'label': _int64_feature(gt),
           'image': _bytes_feature(tf.compat.as_bytes(im.tostring())),
           'height': _int64_feature(h),
           'width': _int64_feature(w)}

我读起来像:

train_dataset = tf.data.TFRecordDataset(train_file)
train_dataset = train_dataset.map(parse_func)
train_dataset = train_dataset.shuffle(buffer_size=1)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.prefetch(batch_size)

而我的 parse_func 看起来像这样:

def parse_func(ex):
    feature = {'image': tf.FixedLenFeature([], tf.string),
               'label': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
               'height': tf.FixedLenFeature([], tf.int64),
               'width': tf.FixedLenFeature([], tf.int64)}
    features = tf.parse_single_example(ex, features=feature)
    image = tf.decode_raw(features['image'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    im_shape = tf.stack([width, height])
    image = tf.reshape(image, im_shape)
    label = tf.cast(features['label'], tf.int32)
    return image, label

现在,我想获得 imagelabel 的形状,例如:

image.get_shape().as_list()

打印
[None, None, None]
而不是
[None, 224, 224](图片大小(batch, width, height))

有什么函数可以给出这些张量的大小吗?

由于您的地图函数 "parse_func" 作为操作是图形的一部分,并且它不知道您输入的固定大小和先验标签,因此使用 get_shape() 不会return 预期的固定形状。

如果您的图像、标签形状是固定的,作为一种 hack,您可以尝试使用已知尺寸重塑图像、标签(这实际上不会做任何事情,但会明确设置尺寸输出张量)。

例如。 图片= tf.reshape(图片,[224,224])

有了这个,您应该能够按预期获得 get_shape() 结果。

另一个解决方案是存储编码图像而不是解码原始字节,这样你只需要在读取 tfrecords 时使用 tensorflow 将图像解码回来,这也将帮助你节省存储空间,这样你就可以从张量中得到图像形状。

    # Load your image as you would normally do then do:

    # Convert the image to raw bytes.
    img_bytes = tf.io.encode_jpeg(img).numpy()

    # Create a dict with the data we want to save in the
    # TFRecords file. You can add more relevant data here.
    data = \
    {'image': wrap_bytes(img_bytes),
     'label': wrap_int64(label)}

    # Wrap the data as TensorFlow Features.
    feature = tf.train.Features(feature=data)

    # Wrap again as a TensorFlow Example.
    example = tf.train.Example(features=feature)

    # Serialize the data.
    serialized = example.SerializeToString()
            
    # Write the serialized data to the TFRecords file.
    writer.write(serialized) 

然后阅读你可以使用:

    features = \
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)            
        }

    # Parse the serialized data so we get a dict with our data.
    parsed_example = tf.io.parse_single_example(serialized=serialized,
                                             features=features)

    # Get the image as raw bytes.
    image_raw = parsed_example['image']

    # Decode the raw bytes so it becomes a tensor with type.
    image = tf.io.decode_jpeg(image_raw)
    
    image = tf.cast(image, tf.float32) # optional
    
    # Get the label associated with the image.
    label = parsed_example['label']
    
    # The image and label are now correct TensorFlow types.
    return image, label