Tensorflow Object Detection API `indices[3] = 3 is not in [0, 3)` error

I'm retraining mobilenet(v1)-SSD with the TF Object Detection API, but I get an error at the training step:

INFO:tensorflow:Starting Session.
INFO:tensorflow:Saving checkpoint to path xxxx/model.ckpt
INFO:tensorflow:Starting Queues.
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, indices[3] = 3 is not in [0, 3)
         [[Node: cond_2/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1 = Gather[Tindices=DT_INT64, Tparams=DT_INT64, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](cond_2/Switch_3:1, cond_2/RandomCropImage/PruneCompleteleyOutsideWindow/Reshape)]]
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Caught OutOfRangeError. Stopping Training.
INFO:tensorflow:Finished training! Saving model to disk.
Traceback (most recent call last):
  File "object_detection/train.py", line 168, in <module>
    tf.app.run()
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 124, in run
    _sys.exit(main(argv))
  File "object_detection/train.py", line 165, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "xxxx/research/object_detection/trainer.py", line 361, in train
    saver=saver)
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/tensorflow/contrib/slim/python/slim/learning.py", line 782, in train
    ignore_live_threads=ignore_live_threads)
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/tensorflow/python/training/supervisor.py", line 826, in stop
    ignore_live_threads=ignore_live_threads)
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/tensorflow/python/training/coordinator.py", line 387, in join
    six.reraise(*self._exc_info_to_raise)
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/six.py", line 693, in reraise
    raise value
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/tensorflow/python/training/queue_runner_impl.py", line 250, in _run
    enqueue_callable()
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1251, in _single_operation_run
    self._session, None, {}, [], target_list, status, None)
  File "/home/khatta/.virtualenvs/dl/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[3] = 3 is not in [0, 3)
         [[Node: cond_2/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1 = Gather[Tindices=DT_INT64, Tparams=DT_INT64, validate_indices=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](cond_2/Switch_3:1, cond_2/RandomCropImage/PruneCompleteleyOutsideWindow/Reshape)]]

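This error appeared when I started preparing TFRecord files from a larger dataset (about 16K images); training fails right away.
When I used a small dataset (about 1K images), the same error occurred after about 100 training steps. The error message has the same structure.
My TFRecord creation script is structured as shown below. I tile the large images so that the annotations don't become too small during SSD's 300x300 resize step, which I thought would give better results: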

import hashlib

import cv2
import pandas as pd
import tensorflow as tf

def _tiling(image_array, labels, tile_size=(300, 300)):
    '''tile image according to the tile_size argument'''
    # <do stuff>: the tiling logic is omitted in the original post
    yield tiled_image_array, tiled_label

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _int64_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _bytes_list_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def _float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _make_tfexample(tiled_image_array, tiled_label):
    img_str = cv2.imencode('.jpg', tiled_image_array)[1].tobytes()
    height, width, _ = tiled_image_array.shape
    # tiled_label's contents:
    # ['tilename', ['object_name', 'object_name', ...], 
    #  [xmin, xmin, ...], [ymin, ymin, ...], 
    #  [xmax, xmax, ...], [ymax, ymax, ...]]
    tile_name, object_names, xmins, ymins, xmaxs, ymaxs = tiled_label
    filename = bytes(tile_name, 'utf-8')
    image_format = b'jpeg'
    key = hashlib.sha256(img_str).hexdigest()
    xmins = [xmin/width for xmin in xmins]
    ymins = [ymin/height for ymin in ymins]
    xmaxs = [xmax/width for xmax in xmaxs]
    ymaxs = [ymax/height for ymax in ymaxs]
    classes_text = [bytes(obj, 'utf-8') for obj in object_names]
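    # category => {'object_name': #id, ...} (a name-to-id dict defined outside this function)
    # Build classes from object_names so it stays parallel to xmins/ymins/xmaxs/ymaxs.
    classes = [category[obj] for obj in object_names]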
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/filename': _bytes_feature(filename),
        'image/source_id': _bytes_feature(filename),
        'image/key/sha256': _bytes_feature(key.encode('utf-8')),
        'image/encoded': _bytes_feature(img_str),
        'image/format': _bytes_feature(image_format),
        'image/object/bbox/xmin': _float_list_feature(xmins),
        'image/object/bbox/ymin': _float_list_feature(ymins),
        'image/object/bbox/xmax': _float_list_feature(xmaxs),
        'image/object/bbox/ymax': _float_list_feature(ymaxs),
        'image/object/class/text': _bytes_list_feature(classes_text),
        'image/object/class/label': _int64_list_feature(classes)
    }))
    return tf_example

def make_tfrecord(image_path, csv_path, tfrecord_path):
    '''convert image and labels into tfrecord file'''
    csv = pd.read_csv(csv_path)
    with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
        for row in csv.itertuples():
            # Keep OpenCV's BGR channel order: cv2.imencode in _make_tfexample
            # expects BGR, so converting to RGB here would swap red and blue
            # in the encoded JPEG.
            img_array = cv2.imread(image_path + row.filename)
            tile_generator = _tiling(img_array, row.label)
            for tiled_image_array, tiled_label in tile_generator:
                tf_example = _make_tfexample(tiled_image_array, tiled_label)
                writer.write(tf_example.SerializeToString())
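The body of `_tiling` is not shown in the post. As a hedged illustration only (this is not the author's code, and `tile_annotations` / `split_columns` are hypothetical names), a tiler that clips boxes to each tile and drops boxes that fall outside it is the place where parallel lists most easily get out of sync. Keeping each object as a single `(name, xmin, ymin, xmax, ymax)` tuple until the very end means a box can never be dropped without its name:

```python
# Hypothetical tiling helper, not the elided `_tiling` from the post.
# It only computes per-tile annotations; cropping the pixels is left to the
# caller, e.g. image[y0:y0 + th, x0:x0 + tw] for a NumPy image array.

def tile_annotations(img_w, img_h, objects, tile_size=(300, 300)):
    """Yield (x0, y0, tile_objects) for each tile.

    objects: list of (name, xmin, ymin, xmax, ymax) tuples in pixels.
    Each object stays one tuple, so names and boxes are kept or dropped together.
    """
    tw, th = tile_size
    for y0 in range(0, img_h, th):
        for x0 in range(0, img_w, tw):
            # Tiles on the right/bottom edges may be smaller than tile_size.
            x1, y1 = min(x0 + tw, img_w), min(y0 + th, img_h)
            tile_objects = []
            for name, xmin, ymin, xmax, ymax in objects:
                # Clip the box to the tile; skip it if nothing is left.
                cx0, cy0 = max(xmin, x0), max(ymin, y0)
                cx1, cy1 = min(xmax, x1), min(ymax, y1)
                if cx1 <= cx0 or cy1 <= cy0:
                    continue
                # Store coordinates relative to the tile's top-left corner.
                tile_objects.append((name, cx0 - x0, cy0 - y0, cx1 - x0, cy1 - y0))
            yield x0, y0, tile_objects


def split_columns(tile_objects):
    """Turn [(name, xmin, ymin, xmax, ymax), ...] into five parallel lists."""
    if not tile_objects:
        return ([], [], [], [], [])
    return tuple(list(col) for col in zip(*tile_objects))
```

The five lists returned by `split_columns` line up with the `object_names, xmins, ymins, xmaxs, ymaxs` that `_make_tfexample` unpacks (the tile name would be prepended separately). Empty tiles are probably best skipped or handled explicitly, and whether to keep thin slivers of clipped boxes is a judgment call.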

Any suggestions as to why this error occurs are welcome. Thanks in advance!

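This was caused by the length of the obj_names list not matching the lengths of the other lists (xmins, ymins, xmaxs, ymaxs, classes).
The cause was a bug in my code, but I'm posting this FYI in case you run into a similar error and need some hints for debugging it.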
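Why the message reads `indices[3] = 3 is not in [0, 3)`: roughly, the random-crop augmentation (the `RandomCropImage/PruneCompleteleyOutsideWindow` node in the log) decides which boxes to keep from the box coordinates, then gathers the same indices from the other per-box fields such as the class labels. With four boxes but only three labels, index 3 is valid for the boxes but out of range for the labels. The pure-Python model below is a simplified sketch of that idea, not the Object Detection API's actual code; `prune_outside_window` is a hypothetical name:

```python
# Simplified, hypothetical model of the crop-time pruning step. It is NOT the
# TF Object Detection API code; it only mimics the idea that kept indices are
# computed from the boxes and then reused to index the labels.

def prune_outside_window(boxes, labels, window):
    """Keep boxes that overlap `window` and gather labels with the same indices.

    boxes:  list of (ymin, xmin, ymax, xmax) tuples, normalized to [0, 1]
    labels: list of class ids, expected to be parallel to `boxes`
    window: (ymin, xmin, ymax, xmax) of the crop
    """
    wy0, wx0, wy1, wx1 = window
    keep = [
        i for i, (y0, x0, y1, x1) in enumerate(boxes)
        if not (y0 >= wy1 or x0 >= wx1 or y1 <= wy0 or x1 <= wx0)
    ]
    kept_boxes = [boxes[i] for i in keep]
    # If labels is shorter than boxes, this raises IndexError: the plain-Python
    # analogue of "indices[3] = 3 is not in [0, 3)".
    kept_labels = [labels[i] for i in keep]
    return kept_boxes, kept_labels


boxes = [(0.1, 0.1, 0.2, 0.2), (0.3, 0.3, 0.4, 0.4),
         (0.5, 0.5, 0.6, 0.6), (0.7, 0.7, 0.8, 0.8)]  # 4 boxes
labels = [1, 2, 3]  # only 3 labels: the lists are not parallel

try:
    prune_outside_window(boxes, labels, window=(0.0, 0.0, 1.0, 1.0))
except IndexError as err:
    print("IndexError:", err)
```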

In short, in the _make_tfexample function above you need

xmins = [a_xmin, b_xmin, c_xmin]
ymins = [a_ymin, b_ymin, c_ymin]
xmaxs = [a_xmax, b_xmax, c_xmax]
ymaxs = [a_ymax, b_ymax, c_ymax]
classes_text = [a_class, b_class, c_class]
classes = [a_classid, b_classid, c_classid]

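so that the list indices correspond to one another (index i in every list describes the same object). The error occurs when the list lengths differ for some reason.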
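One way to catch this before training is to check the per-object lists right before building the `tf.train.Example`, e.g. at the top of the feature dict construction in `_make_tfexample`. A minimal sketch, assuming a hypothetical helper named `check_parallel_lengths` (not part of TensorFlow); it avoids f-strings so it also runs on the Python 3.5 environment from the traceback:

```python
# Hypothetical pre-write sanity check; not a TensorFlow API.

def check_parallel_lengths(**per_object_lists):
    """Return the common length, or raise ValueError if the lists differ."""
    lengths = {name: len(values) for name, values in per_object_lists.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError("per-object lists have different lengths: {}".format(lengths))
    # No lists passed means zero objects.
    return next(iter(lengths.values()), 0)


xmins = [0.1, 0.3, 0.5, 0.7]
classes = [1, 2, 3]  # one class id missing
try:
    check_parallel_lengths(xmins=xmins, classes=classes)
except ValueError as err:
    print(err)
```

In `_make_tfexample` this would be called with `xmins`, `ymins`, `xmaxs`, `ymaxs`, `classes_text` and `classes`, so a bad tile fails while the TFRecord is being written instead of deep inside the input pipeline.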

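I ran into the same error as well and searched page after page for an answer. Unfortunately, the shape of the data and labels was not what caused the error in my case. I found the same question in several places on Stack Overflow, so check this to see whether it solves your problem.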