How do I set the label shape in an albumentations pipeline so it works with the tensorflow image_dataset_from_directory function?

I am running the code below [https://pastebin.com/LK8tKZtN] and get the following error:

File "C:\Users\Admin\PycharmProjects\BugsClassfications\main2.py", line 45, in set_shapes
    label.set_shape([])

ValueError: Shapes must be equal rank, but are 1 and 0

How does set_shape work with image_dataset_from_directory?

Here is my code:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from albumentations import (Compose, HorizontalFlip,Rotate)
 
AUTOTUNE = tf.data.experimental.AUTOTUNE
 
def process_image(image, label, img_size):
  # cast and normalize image
  image = tf.image.convert_image_dtype(image, tf.float32)
  # apply simple augmentations
  image = tf.image.random_flip_left_right(image)
  image = tf.image.resize(image,[img_size, img_size])
  return image, label
 
transforms = Compose([
Rotate(limit=40),
HorizontalFlip()
])
 
 
def aug_fn(image, img_size):
  data = {"image":image}
  aug_data = transforms(**data)
  aug_img = aug_data["image"]
  aug_img = tf.cast(aug_img/255.0, tf.float32)
  aug_img = tf.image.resize(aug_img, size=[img_size, img_size])
  return aug_img
 
 
def process_data(image, label, img_size):
  aug_img = tf.numpy_function(func=aug_fn, inp=[image, img_size], Tout=tf.float32)
  return aug_img, label
 
 
def set_shapes(img, label, img_shape=(128,128,3)):
  img.set_shape(img_shape)
  label.set_shape([])
  return img, label
 
 
def view_image(ds):
  image, label = next(iter(ds))  # extract 1 batch from the dataset
  image = image.numpy()
  label = label.numpy()
  
  fig = plt.figure(figsize=(22, 22))
  for i in range(20):
    ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])
    ax.imshow(image[i].astype(dtype=np.uint8))
    ax.set_title(f"Label: {label[i]}")
  plt.show()
 
 
train_dir = './dataset/train'
img_size = 128
data = tf.keras.utils.image_dataset_from_directory(train_dir, image_size=(img_size, img_size))
print(data)
 
#augmentation
ds_alb = data.map(partial(process_data, img_size = 128), num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)
#resize
ds_alb = ds_alb.map(set_shapes, num_parallel_calls=AUTOTUNE).batch(32)
 
print(ds_alb)

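image_dataset_from_directory batches its output by default, so each label tensor it yields has rank 1 rather than rank 0. If you change the shape of the label to match, it should work: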

def set_shapes(img, label, img_shape=(128,128,3)):
  img.set_shape(img_shape)
  label.set_shape([1,])
  return img, label

However, you should ask yourself why you are setting the shapes of your data explicitly. Check this.
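
To illustrate where the rank mismatch comes from, here is a minimal sketch that does not read from disk. It assumes a small in-memory dataset (`images`, `labels`) as a hypothetical stand-in for the output of image_dataset_from_directory, and a plain NumPy flip (`flip_np`) in place of the albumentations transforms; neither is part of the original code. It suggests one possible alternative to changing the label shape: unbatch first, so that every label really is a scalar and `label.set_shape([])` is valid.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the output of image_dataset_from_directory:
# 8 blank images with integer labels, already grouped into batches of 4.
images = tf.zeros([8, 128, 128, 3], dtype=tf.float32)
labels = tf.range(8, dtype=tf.int32)
batched = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Each element is a whole batch, so the label tensor has rank 1.
print("batched label shape:", batched.element_spec[1].shape)

# unbatch() yields one sample per element; the label becomes a scalar.
unbatched = batched.unbatch()
print("unbatched label shape:", unbatched.element_spec[1].shape)


def flip_np(image):
    # Plain NumPy horizontal flip, standing in for the albumentations transforms.
    return np.ascontiguousarray(np.fliplr(image))


def augment(image, label):
    # tf.numpy_function returns a tensor without static shape information.
    aug = tf.numpy_function(func=flip_np, inp=[image], Tout=tf.float32)
    return aug, label


def set_shapes(img, label, img_shape=(128, 128, 3)):
    # Restore the static shapes lost in numpy_function; per sample, the label is rank 0.
    img.set_shape(img_shape)
    label.set_shape([])
    return img, label


augmented = unbatched.map(augment)
ds = augmented.map(set_shapes).batch(4)
print(ds.element_spec)
```

This also hints at why the shapes are set explicitly at all: the tensor returned by tf.numpy_function carries no static shape, so set_shape puts that information back. In the original pipeline the same effect could probably be achieved by calling `data.unbatch()` before the augmentation map, or, if your TensorFlow version supports it, by passing `batch_size=None` to image_dataset_from_directory.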