Tensorflow:尝试迁移学习时出错:无效的 JPEG 数据或裁剪窗口(Invalid JPEG data or crop window)

Tensorflow: Error when trying transfer learning: Invalid JPEG data or crop window

我正在按照他们的教程(here),尝试把我自己的自定义图像数据集处理成 TensorFlow 上预训练 MobileNet 模型所需的输入形状。我的代码如下:

# Training hyperparameters.
batch_size = 256
epochs = 15

# Root folders of the training and validation images.
traindir = pathlib.Path('/train')
valdir = pathlib.Path('/validation')

# File-path datasets: every entry matching <root>/<sub-folder>/<file>.
list_ds = tf.data.Dataset.list_files(str(traindir / '*' / '*'))
val_list_ds = tf.data.Dataset.list_files(str(valdir / '*' / '*'))

# Names of the entries directly under the validation root, skipping the
# license file.
CLASS_NAMES = np.array(
    [entry.name for entry in valdir.glob('*') if entry.name != "LICENSE.txt"]
)
def get_label(file_path):
  """Compare the file's parent-directory name against CLASS_NAMES.

  Args:
    file_path: path of one image file (a string tensor when mapped over the
      file-path dataset).

  Returns:
    The result of `parent_dir == CLASS_NAMES`.
  """
  # Break the path into its components on the platform separator.
  path_components = tf.strings.split(file_path, os.path.sep)
  # The component just before the file name is the class directory.
  class_dir = path_components[-2]
  return class_dir == CLASS_NAMES

def decode_img(img):
  """Decode JPEG bytes and resize the image to IMG_HEIGHT x IMG_WIDTH.

  Args:
    img: the raw, still-compressed file contents.

  Returns:
    The decoded 3-channel image as float32, resized to
    [IMG_HEIGHT, IMG_WIDTH].
  """
  # Turn the compressed string into a 3D uint8 tensor with 3 channels.
  decoded = tf.image.decode_jpeg(img, channels=3)
  # `convert_image_dtype` converts to floats in the [0,1] range.
  as_float = tf.image.convert_image_dtype(decoded, tf.float32)
  # Resize to the size the model expects.
  target_size = [IMG_HEIGHT, IMG_WIDTH]
  return tf.image.resize(as_float, target_size)

def process_path(file_path):
  """Build one (image, label) pair from an image file path.

  Args:
    file_path: path of the image file to load.

  Returns:
    A tuple `(img, label)`: the decoded image from `decode_img` and the
    label from `get_label`.
  """
  label = get_label(file_path)
  # Read the raw file contents as a string, then decode them.
  raw_bytes = tf.io.read_file(file_path)
  return decode_img(raw_bytes), label
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=5)
labeled_val_ds = val_list_ds.map(process_path, num_parallel_calls=5)
# NOTE(review): SHUFFLE_BUFFER_SIZE and BATCH_SIZE are not defined in the
# visible snippet (only lowercase `batch_size` is) - confirm they are set
# elsewhere.
# Only the training pipeline gets a shuffle stage.
train_batches = labeled_ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = labeled_val_ds.batch(BATCH_SIZE)
# Pull a single batch so `image_batch` / `label_batch` stay bound afterwards;
# the loop needs a body to be valid Python.
for image_batch, label_batch in train_batches.take(1):
  pass


之后我继续按照 TF 的迁移学习(transfer learning)教程 here 操作。但是,我遇到了下面这个问题,我怀疑是 JPEG 图像已损坏,或者是迭代器有缺失/出了问题:

Epoch 1/10
 21/330 [>.............................] - ETA: 14:02 - loss: 3.9893 - accuracy: 0.0326
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-87-11afdc6d5aef> in <module>
      1 history = model.fit(train_batches,
      2                     epochs=initial_epochs,
----> 3                     validation_data=validation_batches)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     68     # Running inside `run_distribute_coordinator` already.

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    846                 batch_size=batch_size):
    847               callbacks.on_train_batch_begin(step)
--> 848               tmp_logs = train_function(iterator)
    849               # Catch OutOfRangeError for Datasets of unknown size.
    850               # This blocks until the batch has finished executing.

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    582     if tracing_count == self._get_tracing_count():

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    609       # In this case we have created variables on the first call, so we run the
    610       # defunned version which is guaranteed to never create variables.
--> 611       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    612     elif self._stateful_fn is not None:
    613       # Release the lock early so that multiple threads can perform the call

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2418     with self._lock:
   2419       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2420     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2422   @property

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs)
   1663          if isinstance(t, (ops.Tensor,
   1664                            resource_variable_ops.BaseResourceVariable))),
-> 1665         self.captured_inputs)
   1667   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1744       # No tape is watching; skip to running the function.
   1745       return self._build_call_outputs(self._inference_function.call(
-> 1746           ctx, args, cancellation_manager=cancellation_manager))
   1747     forward_backward = self._select_forward_and_backward_functions(
   1748         args,

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    596               inputs=args,
    597               attrs=attrs,
--> 598               ctx=ctx)
    599         else:
    600           outputs = execute.execute_with_cancellation(

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Invalid JPEG data or crop window, data size 34228
     [[{{node DecodeJpeg}}]]
  (1) Invalid argument:  Invalid JPEG data or crop window, data size 34228
     [[{{node DecodeJpeg}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_30787]

Function call stack:
train_function -> train_function

感谢您的宝贵时间! 编辑:在重新运行代码几次之后,它似乎会产生相同的错误,只是数据大小不同,例如 16384……

编辑: 是的,问题在于一些 .jpeg 实际上是伪装的 .png,或者它们只是简单地损坏了。我强烈建议在使用数据训练任何模型之前检查数据完整性。

我遇到了类似的问题。你的一些训练数据有问题。您可以使用下面的代码来检查哪个 jpeg 图像已损坏并将其删除。

from struct import unpack
from tqdm import tqdm
import os

# JPEG segment markers (2-byte values) mapped to human-readable names.
marker_mapping = {
    0xffd8: "Start of Image",
    0xffe0: "Application Default Header",
    0xffdb: "Quantization Table",
    0xffc0: "Start of Frame",
    0xffc4: "Define Huffman Table",
    0xffda: "Start of Scan",
    0xffd9: "End of Image",
}

class JPEG:
    """Minimal JPEG marker walker used to spot files that are not valid JPEGs.

    The file is read fully into memory at construction time; ``decode`` then
    steps through the segment markers and raises when the structure cannot
    be followed.
    """

    def __init__(self, image_file):
        """Read the entire contents of *image_file* as bytes."""
        with open(image_file, 'rb') as f:
            self.img_data = f.read()

    def decode(self):
        """Walk the segment markers stored in ``self.img_data``.

        Returns:
            None, either when the End of Image marker (0xffd9) is reached or
            when the data runs out while skipping a segment.

        Raises:
            struct.error: when fewer than two bytes are available where a
                marker or a segment length is expected (for example an empty
                file, or data after Start of Scan that does not end with the
                End of Image marker).
        """
        data = self.img_data
        while True:
            # Every marker is a 2-byte big-endian value.
            marker, = unpack(">H", data[0:2])
            # Debug aid: print(marker_mapping.get(marker))
            if marker == 0xffd8:
                # Start of Image carries no payload: skip the marker itself.
                data = data[2:]
            elif marker == 0xffd9:
                # End of Image: the structure was followed to the end.
                return
            elif marker == 0xffda:
                # Start of Scan: jump straight to the final two bytes, which
                # the next iteration checks for the End of Image marker.
                data = data[-2:]
            else:
                # Any other segment: the two bytes after the marker hold the
                # segment length, so skip marker + segment.
                lenchunk, = unpack(">H", data[2:4])
                data = data[2+lenchunk:]
            if len(data) == 0:
                break

# Names of the images whose marker structure could not be walked.
bads = []

# NOTE(review): `images` (file names) and `root_img` (their folder) are not
# defined in the visible snippet - confirm they are set before this runs.
for img in tqdm(images):
  image_path = os.path.join(root_img, img)
  image = JPEG(image_path)
  try:
    image.decode()
  except Exception:
    # Any failure while walking the markers marks the file as bad.
    bads.append(img)

# Delete every file flagged above. This is destructive: inspect `bads` first.
for name in bads:
  os.remove(os.path.join(root_img, name))

我使用 yasoob 脚本来解码 jpeg 图像。