How to fix 'Invalid argument: Got 3 frames, but animated gifs can only be decoded by tf.image.decode_gif or tf.image.decode_image'

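I am using the deeplab framework (GitHub: https://github.com/tensorflow/models/tree/master/research/deeplab, TensorFlow 1.14.0) to classify images that have 4+ channels of information.

My idea was to put the separate channels into a .gif file and read them with a modified version of build_voc_2012.py, together with a modified data_generator.py. Everything else is unchanged from the repo.

Generating the shards and running train.py both seem to work fine. The problem appears when running eval.py, which throws the error below.

Here is the code that generates the shards: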


"""Contains common utility functions and classes for building dataset.

This script contains utility functions and classes to convert a dataset to
TFRecord file format with Example protos.

The Example proto contains the following fields:

  image/encoded: encoded image content.
  image/filename: image filename.
  image/format: image file format.
  image/height: image height.
  image/width: image width.
  image/channels: image channels.
  image/segmentation/class/encoded: encoded semantic segmentation content.
  image/segmentation/class/format: semantic segmentation file format.
"""
import collections
import six
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_enum('image_format', 'png', ['jpg', 'jpeg', 'png', 'gif'],
                         'Image format.')

tf.app.flags.DEFINE_enum('label_format', 'png', ['png'],
                         'Segmentation label format.')

# A map from image format to expected data format.
_IMAGE_FORMAT_MAP = {
    'jpg': 'jpeg',
    'jpeg': 'jpeg',
    'png': 'png',
    'gif': 'gif'
}


class ImageReader(object):
    """Helper class that provides TensorFlow image coding utilities."""

    def __init__(self, image_format='jpeg', channels=3):
        """Class constructor.

        Args:
          image_format: Image format. Only 'jpeg', 'jpg', 'png', or 'gif' are
            supported.
          channels: Image channels. Ignored for 'gif', since
            tf.image.decode_gif always returns RGB frames.
        """
        with tf.Graph().as_default():
            self._decode_data = tf.placeholder(dtype=tf.string)
            self._image_format = image_format
            self._session = tf.Session()
            self.channels = channels
            if self._image_format in ('jpeg', 'jpg'):
                self._decode = tf.image.decode_jpeg(self._decode_data, channels)
            elif self._image_format == 'png':
                self._decode = tf.image.decode_png(self._decode_data, channels)
            elif self._image_format == 'gif':
                # Returns a 4-D [num_frames, height, width, 3] tensor.
                self._decode = tf.image.decode_gif(self._decode_data)
            else:
                raise ValueError('Unsupported image format: %s' % image_format)


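    def read_image_dims_gif(self, gif_data):
        """Reads the dimensions of a (possibly animated) GIF.

        Args:
            gif_data: string of encoded GIF data.

        Returns:
            Tuple (num_frames, image_height, image_width, channels); channels
            is always 3 because tf.image.decode_gif returns RGB frames.
        """

        image = self.decode_gif(gif_data)
        return image.shape[:4]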

    def decode_gif(self, image_data):
        """Decodes the image data string with the decoder built in __init__.

        Args:
          image_data: string of encoded image data.

        Returns:
          Decoded image data as a numpy array. For GIF input the shape is
          [num_frames, height, width, 3].
        """
        image = self._session.run(self._decode,
                                  feed_dict={self._decode_data: image_data})

        return image


def _float64_list_feature(values):
    """Returns a TF-Feature of float_list.

    Args:
      values: A scalar or list of values.

    Returns:
      A TF-Feature.
    """

    if not isinstance(values, collections.Iterable):
        values = [values]

    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _int64_list_feature(values):
    """Returns a TF-Feature of int64_list.

    Args:
      values: A scalar or list of values.

    Returns:
      A TF-Feature.
    """
    if not isinstance(values, collections.Iterable):
        values = [values]

    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _bytes_list_feature(values):
    """Returns a TF-Feature of bytes.

    Args:
      values: A string.

    Returns:
      A TF-Feature.
    """

    def norm2bytes(value):
        return value.encode() if isinstance(value, str) and six.PY3 else value

    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[norm2bytes(values)]))


def image_seg_to_tfexample_gif(image_data, filename, height, width, seg_data, channels, frames):
    """Converts one image/segmentation pair to tf example.

    Args:
     image_data: encoded image data.
     filename: image filename.
     height: image height.
     width: image width.
     seg_data: string of semantic segmentation data.
     channels: int, number of image channels.
     frames: number of frames in the gif (currently not written to the example).
    Returns:
     tf example of one image/segmentation pair.
    """
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': _bytes_list_feature(image_data),
        'image/filename': _bytes_list_feature(filename),
        'image/format': _bytes_list_feature(
            _IMAGE_FORMAT_MAP[FLAGS.image_format]),
        'image/height': _int64_list_feature(height),
        'image/width': _int64_list_feature(width),
        'image/channels': _int64_list_feature(channels),
        'image/segmentation/class/encoded': (
            _bytes_list_feature(seg_data)),
        'image/segmentation/class/format': _bytes_list_feature(
            FLAGS.label_format),
        }))
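The 3 in "Got 3 frames" is the number of images stored in the encoded GIF: the error is raised by the DecodePng node, which evidently recognized the bytes as a GIF with more than one frame. To check what a shard actually contains, the frame count and size can be read straight from the GIF block structure without TensorFlow. The sketch below is a hypothetical helper (gif_frames_and_size is not part of deeplab or TensorFlow) that uses only the standard library; it walks the blocks defined by the GIF89a specification and skips the compressed pixel data. Its result should match the first three entries returned by read_image_dims_gif.

```python
def gif_frames_and_size(gif_bytes):
    """Return (num_frames, height, width) read from raw GIF bytes.

    Walks the GIF block structure (header, logical screen descriptor,
    optional global colour table, then extension and image blocks) and
    counts image descriptors. The LZW-compressed pixel data is skipped.
    """
    data = bytearray(gif_bytes)  # indexing yields ints on Python 2 and 3
    if data[:6] not in (b"GIF87a", b"GIF89a"):
        raise ValueError("not a GIF file")
    # Logical screen descriptor: little-endian uint16 width and height, then flags.
    width = data[6] | (data[7] << 8)
    height = data[8] | (data[9] << 8)
    flags = data[10]
    pos = 13  # 6-byte header + 7-byte logical screen descriptor
    if flags & 0x80:  # a global colour table follows
        pos += 3 * (2 ** ((flags & 0x07) + 1))

    def skip_sub_blocks(p):
        # Sub-blocks are a length byte plus that many bytes; length 0 ends the chain.
        while data[p] != 0:
            p += data[p] + 1
        return p + 1

    frames = 0
    while pos < len(data):
        block = data[pos]
        if block == 0x3B:  # trailer: end of file
            break
        elif block == 0x21:  # extension: introducer, label byte, sub-blocks
            pos = skip_sub_blocks(pos + 2)
        elif block == 0x2C:  # image descriptor (10 bytes) starts a new frame
            frame_flags = data[pos + 9]
            pos += 10
            if frame_flags & 0x80:  # a local colour table follows
                pos += 3 * (2 ** ((frame_flags & 0x07) + 1))
            pos = skip_sub_blocks(pos + 1)  # skip the LZW minimum code size byte
            frames += 1
        else:
            raise ValueError("unexpected block 0x%02x at offset %d" % (block, pos))
    return frames, height, width
```

Running it over the raw bytes of one of the .gif inputs (for example `gif_frames_and_size(open(path, 'rb').read())`) shows how many frames each image holds, which is the number that shows up in the error message.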

In eval.py, this piece of code seems to produce the error:

tf.contrib.training.evaluate_repeatedly(
        master=FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        eval_ops=[update_op],
        max_number_of_evaluations=num_eval_iters,
        hooks=hooks,
        eval_interval_secs=FLAGS.eval_interval_secs)

The error message is as follows:

Traceback (most recent call last):
  File "/home/user/models-master/research/deeplab/eval.py", line 188, in <module>
    tf.app.run()
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/user/.local/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/user/.local/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/user/models-master/research/deeplab/eval.py", line 181, in main
    eval_interval_secs=FLAGS.eval_interval_secs)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/contrib/training/python/training/evaluation.py", line 453, in evaluate_repeatedly
    session.run(eval_ops, feed_dict)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 754, in run
    run_metadata=run_metadata)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1252, in run
    run_metadata=run_metadata)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1353, in run
    raise six.reraise(*original_exc_info)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1338, in run
    return self._sess.run(*args, **kwargs)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1411, in run
    run_metadata=run_metadata)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/training/monitored_session.py", line 1169, in run
    return self._sess.run(*args, **kwargs)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "/home/user/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument: Got 3 frames, but animated gifs can only be decoded by tf.image.decode_gif or tf.image.decode_image
     [[{{node cond/else/_1/DecodePng}}]]
     [[IteratorGetNext]]
  (1) Invalid argument: Got 3 frames, but animated gifs can only be decoded by tf.image.decode_gif or tf.image.decode_image
     [[{{node cond/else/_1/DecodePng}}]]
     [[IteratorGetNext]]
     [[mean_iou/confusion_matrix/assert_less_1/Assert/AssertGuard/Assert/data_1/_2007]]
0 successful operations.
0 derived errors ignored.

Here is a link to my repo: https://github.com/michael-ross-scott/DeeplabV3

Long story short, I managed to solve the eval.py problem. In data_generator.py I found the following lines, which try to decode my .gif-encoded shards:


  def _parse_function(self, example_proto):
    """Function to parse the example proto.

    Args:
      example_proto: Proto in the format of tf.Example.

    Returns:
      A dictionary with parsed image, label, height, width and image name.

    Raises:
      ValueError: Label is of wrong shape.
    """
    def _decode_image(content, channels):
        return tf.cond(
          tf.image.is_jpeg(content),
          lambda: tf.image.decode_jpeg(content, channels),
          lambda: tf.image.decode_png(content, channels))

    features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=0),
        'channels':
            tf.FixedLenFeature((), tf.int64, default_value=3),
        'image/segmentation/class/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/format':
            tf.FixedLenFeature((), tf.string, default_value='png'),
    }

    parsed_features = tf.parse_single_example(example_proto, features)

    image = _decode_image(parsed_features['image/encoded'], channels=3)

    label = None
    if self.split_name != common.TEST_SET:
      label = _decode_image(
          parsed_features['image/segmentation/class/encoded'], channels=1)
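The tf.cond in _decode_image only asks whether the bytes are JPEG; anything else, including the GIF shards, goes to decode_png, which is the DecodePng node named in the traceback. The standard-library sketch below reproduces that two-way decision and a three-way alternative based on magic numbers (JPEG starts with FF D8 FF, PNG with the 8-byte signature 89 50 4E 47 0D 0A 1A 0A, GIF with GIF87a or GIF89a). All function names are hypothetical. Inside the graph, the same GIF test could presumably be written by comparing tf.strings.substr(content, 0, 3) with b'GIF' and feeding that into a nested tf.cond, but that is an untested assumption. The error message also suggests tf.image.decode_image, which handles all three formats; note that it returns a 4-D tensor for GIF input and a 3-D tensor otherwise.

```python
JPEG_MAGIC = b"\xff\xd8\xff"
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
GIF_MAGICS = (b"GIF87a", b"GIF89a")


def sniff_image_format(content):
    """Guess the format of encoded image bytes from their leading magic number."""
    if content[:3] == JPEG_MAGIC:
        return "jpeg"
    if content[:8] == PNG_MAGIC:
        return "png"
    if content[:6] in GIF_MAGICS:
        return "gif"
    return "unknown"


def two_way_decoder(content):
    # Same decision as the original _decode_image: JPEG, or else treat it as PNG.
    return "decode_jpeg" if sniff_image_format(content) == "jpeg" else "decode_png"


def three_way_decoder(content):
    # Hypothetical fix: give GIF bytes their own branch, default to PNG otherwise.
    fmt = sniff_image_format(content)
    return {"jpeg": "decode_jpeg", "gif": "decode_gif"}.get(fmt, "decode_png")
```

With GIF bytes, `two_way_decoder` picks `decode_png`, which matches the failing DecodePng node, while `three_way_decoder` picks `decode_gif`.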

I decided to split _decode_image into two functions, _decode_image and _decode_label, and only changed the image and label variables to call the appropriate function:

    def _decode_image(content, channels):
        # decode_gif takes no channels argument and always returns a 4-D
        # [num_frames, height, width, 3] tensor, unlike decode_jpeg (3-D).
        return tf.cond(
            tf.image.is_jpeg(content),
            lambda: tf.image.decode_jpeg(content, channels),
            lambda: tf.image.decode_gif(content))

    def _decode_label(content, channels):
        return tf.cond(
            tf.image.is_jpeg(content),
            lambda: tf.image.decode_jpeg(content, channels),
            lambda: tf.image.decode_png(content, channels))

    image = _decode_image(parsed_features['image/encoded'], channels=3)

    label = None
    if self.split_name != common.TEST_SET:
      label = _decode_label(
          parsed_features['image/segmentation/class/encoded'], channels=1)

Alas, another error then appeared in core/utils.py:
(0) Invalid argument: input must be 4-dimensional[1,1,512,512,3]
     [[node ResizeBilinear (defined at /models-master/research/deeplab/core/utils.py:34) ]]
     [[IteratorGetNext]]
     [[Mean_49/_6009]]
  (1) Invalid argument: input must be 4-dimensional[1,1,512,512,3]
     [[node ResizeBilinear (defined at /models-master/research/deeplab/core/utils.py:34) ]]
     [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored.
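ResizeBilinear expects a 4-D [batch, height, width, channels] tensor, but the reported input [1,1,512,512,3] has one axis too many. That is consistent with the leading num_frames axis that tf.image.decode_gif adds (decode_jpeg and decode_png return [height, width, channels]), which batching then turns into a 5-D tensor. One possible workaround is to fold the frames axis into the channel axis right after decoding, so the image is 3-D again. The NumPy sketch below shows the index bookkeeping; frames_to_channels is a hypothetical helper, and in the graph the same step would presumably be tf.transpose(image, [1, 2, 0, 3]) followed by tf.reshape to [height, width, -1], untested here inside deeplab. This only fixes the rank: deeplab's preprocessing and pretrained backbones are built around 3-channel input, and GIF frames are palette images with at most 256 colours each, so the original channel values may already have been quantized when the .gif files were written.

```python
import numpy as np


def frames_to_channels(frames):
    """Fold a decoded GIF stack [F, H, W, 3] into one image [H, W, 3 * F].

    Output channel k * 3 + c holds colour channel c of frame k.
    """
    if frames.ndim != 4 or frames.shape[-1] != 3:
        raise ValueError("expected shape [frames, height, width, 3], got %s" % (frames.shape,))
    num_frames, height, width, _ = frames.shape
    # Move the frame axis next to the colour axis, then merge the two axes.
    return np.transpose(frames, (1, 2, 0, 3)).reshape(height, width, num_frames * 3)
```

For a stack of shape (2, 4, 5, 3) the result has shape (4, 5, 6); channels 0 to 2 come from the first frame and 3 to 5 from the second.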

TL;DR: it may not be possible to adapt deeplabv3 to read .gif files without substantial modifications to the codebase.