BatchDataset: get image arrays and labels
Here is the batched dataset I created earlier to fit the model:
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode='categorical',  # used for multiclass classification; labels are one-hot encoded per class
    validation_split=0.2,      # fraction of the dataset reserved for validation
    subset="training",         # this subset is used for training
    seed=1337,                 # seed is set so that the same results are reproduced
    image_size=img_size,       # shape of input images
    batch_size=batch_size,     # this should match the model batch size
)
valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode='categorical',
    validation_split=0.2,
    subset="validation",       # this subset is used for validation
    seed=1337,
    image_size=img_size,
    batch_size=batch_size,
)
If I run a for loop, I can access the image arrays and labels:
for images, labels in train_ds:
    print(labels)
But if I try to access them like this:
Attempt 1:
images, labels = train_ds
I get the following error: ValueError: too many values to unpack (expected 2)
Attempt 2:
If I try to unpack it like this:
images = train_ds[:,0] # get the 0th column of all rows
labels = train_ds[:,1] # get the 1st column of all rows
I get the following error: TypeError: 'BatchDataset' object is not subscriptable
Is there a way to extract the labels and images without going through a for loop?
In your specific case, train_ds will be a tensor object in which each element is a tuple: (image, label).
You could try something like:
# train_ds = [(image,label) …]
images = train_ds[:,0] # get the 0th column of all rows
labels = train_ds[:,1] # get the 1st column of all rows
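A note on the two errors quoted in the question: a tf.data.Dataset is an iterable of batches rather than an indexable array, so slicing it raises the TypeError, and tuple-unpacking it only works if it happens to contain exactly two batches. If a single batch is enough, it can be pulled without a for loop. A minimal sketch, assuming a small synthetic dataset (demo_ds, fake_images and fake_labels are hypothetical stand-ins for the asker's train_ds and data):

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for train_ds: 10 tiny "images" with one-hot labels.
fake_images = np.zeros((10, 4, 4, 3), dtype=np.float32)
fake_labels = np.eye(5, dtype=np.float32)[np.arange(10) % 5]
demo_ds = tf.data.Dataset.from_tensor_slices((fake_images, fake_labels)).batch(4)

# Each element of the dataset is one (images, labels) batch, so take the first.
first_images, first_labels = next(iter(demo_ds))
print(first_images.shape)  # (4, 4, 4, 3)
print(first_labels.shape)  # (4, 5)
```

This returns only the first batch (here 4 of the 10 samples), not the whole dataset.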
Just unbatch your dataset and convert the data to lists:
import tensorflow as tf
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
batch_size = 32
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, validation_split=0.2, subset="training",
    seed=123, batch_size=batch_size)
train_ds = train_ds.unbatch()
images = list(train_ds.map(lambda x, y: x))
labels = list(train_ds.map(lambda x, y: y))
print(len(labels))
print(len(images))
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
2936
2936
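One caveat that may be worth checking: image_dataset_from_directory shuffles by default (shuffle=True), so two separate passes over the dataset may visit the samples in different orders, and lists built by two independent map calls might then no longer line up. A minimal sketch of collecting both in a single pass, assuming a small synthetic dataset that reshuffles on every iteration (demo_ds, fake_images and fake_labels are hypothetical stand-ins, not part of the original answer):

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for train_ds: image i is filled with the value i and
# has label i, so any misalignment between images and labels is easy to spot.
fake_images = (np.arange(10, dtype=np.float32).reshape(10, 1, 1, 1)
               * np.ones((1, 2, 2, 3), dtype=np.float32))
fake_labels = np.arange(10)
demo_ds = (tf.data.Dataset.from_tensor_slices((fake_images, fake_labels))
           .shuffle(10, reshuffle_each_iteration=True)
           .batch(4))

# A single pass over the unbatched dataset keeps every image with its label.
image_tuple, label_tuple = zip(*demo_ds.unbatch().as_numpy_iterator())
images = np.stack(image_tuple)   # shape (10, 2, 2, 3)
labels = np.array(label_tuple)   # shape (10,)

print(images.shape, labels.shape)
```

Like the list-based version, this holds the whole dataset in memory, so it is only practical when the data fits in RAM.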