TensorFlow image_dataset_from_directory: get specific pictures by a list of indices

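I have a folder called train_ds (don't be confused by the name, it's just a folder with pictures) that contains 5 subfolders of pictures, each one a different class.

I run 5 different trained models on this train_ds folder to get inferences. What I want is to find exactly the pictures that none of the models classify correctly.

So far, I have indexed all the images in the train_ds folder, and out of all of them I have the indices of the images whose class was predicted wrongly by every model.

The problem now is: how do I get the picture associated with each of those indices from the dataset created by image_dataset_from_directory?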

Functions:

import numpy as np
import tensorflow as tf


def inferences_target_list(model, data):
    '''
    returns 2 lists: inferences list, real labels
    '''
    # over train set fold1
    y_pred_float = model.predict(data)
    y_pred = np.argmax(y_pred_float, axis=1)

    # get real labels; this iterates over data a second time, so the order
    # only matches y_pred if the dataset was created with shuffle=False
    y_target = tf.concat([y for x, y in data], axis=0)
    print("length inferences and real labels: ", len(y_pred), len(y_target))
    return y_pred, y_target


def get_missclassified(y_pred, y_target):
  '''
  returns a list with the indexes of real labels that were misclassified
  '''
  missclassified = []
  for i, (pred, target) in enumerate(zip(y_pred, y_target.numpy().tolist())):
    if pred != target:
      missclassified.append(i)
  print("total misclassified: ", len(missclassified))
  return missclassified
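The loop in get_missclassified can also be written as a single NumPy comparison. This is a minimal sketch; the name misclassified_indices is hypothetical, and it assumes y_pred and y_target are 1-D integer arrays of the same length (a tf.Tensor can be converted with .numpy() first).

```python
import numpy as np


def misclassified_indices(y_pred, y_target):
    """Return the positions where the predicted class differs from the true class.

    Hypothetical vectorized equivalent of get_missclassified; expects 1-D
    integer arrays of the same length (call .numpy() on a tf.Tensor first).
    """
    y_pred = np.asarray(y_pred)
    y_target = np.asarray(y_target)
    if y_pred.shape != y_target.shape:
        raise ValueError("y_pred and y_target must have the same shape")
    # flatnonzero returns the indices of the True entries of the boolean mask
    return np.flatnonzero(y_pred != y_target).tolist()


print(misclassified_indices([0, 1, 2, 3], [0, 2, 2, 1]))  # [1, 3]
```

The result is a plain Python list, so it can be stored and intersected exactly like the output of get_missclassified.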

Main loop:

from tensorflow.keras.preprocessing import image_dataset_from_directory

missclassified_train_folders=[]

for f in folders: # at the moment just 1 folder 
  print(f)
  for nn in models_dict: # dictionary of trained models
    print(nn)

    # -- train dataset for each folder
    train_path = reg_input+f+"/"+'train_ds/'
    # print("\n train dataset:", "\n", train_path)
    train_ds = image_dataset_from_directory(
        train_path,
        class_names=["Bedroom","Bathroom","Dinning","Livingroom","Kitchen"],
        seed=None,
        validation_split=None, 
        subset=None,
        image_size= image_size,
        batch_size= batch_size,
        color_mode='rgb',
        shuffle=False 
        )
    
    # inferences and real values
    y_pred, y_target = inferences_target_list(models_dict[nn], train_ds)
    
    # missclassified ones
    missclassified = get_missclassified(y_pred, y_target)
    print("elements misclassified in {} for model {}: ".format(f, nn), len(missclassified))
    missclassified_train_folders.append(missclassified)

I have the list of indices, but I don't know how to use it to get the pictures.

Thanks in advance! | (• ◡•)| (❍ᴥ❍ʋ)

image_dataset_from_directory uses the index_directory function under the hood to index the directory. Basically, it sorts the subdirectories with Python's sorted and loops over them with a ThreadPool.
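To make that ordering concrete without importing Keras internals, here is a minimal pure-Python sketch that approximates the unshuffled output of index_directory for local folders. The helper index_like_keras is hypothetical: it only mimics the behaviour described above (sorted subfolders, sorted files, extension filter, labels taken from the position in class_names) and uses os.walk where Keras uses tf.io.gfile.

```python
import os

ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')


def index_like_keras(directory, class_names=None, formats=ALLOWLIST_FORMATS):
    """Approximate the unshuffled (file_paths, labels, class_names) of index_directory.

    Hypothetical helper: subfolders are visited in sorted order, each one is
    walked in sorted order, files are filtered by extension, and each label is
    the position of the subfolder name in class_names.
    """
    subdirs = sorted(
        d for d in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, d))
    )
    if class_names is None:
        class_names = subdirs  # default numbering: alphabetical subfolder order
    class_indices = {name: i for i, name in enumerate(class_names)}

    file_paths, labels = [], []
    for subdir in subdirs:
        subdir_path = os.path.join(directory, subdir)
        # sort the walk by folder path, then sort file names inside each folder
        for root, _, files in sorted(os.walk(subdir_path), key=lambda entry: entry[0]):
            for fname in sorted(files):
                if fname.lower().endswith(formats):
                    file_paths.append(os.path.join(root, fname))
                    labels.append(class_indices[subdir])
    return file_paths, labels, class_names


# e.g. file_paths, labels, class_names = index_like_keras("/path/to/train_ds", class_names=[...])
```

Note that the file order always follows the sorted folder names, while the integer labels follow class_names when it is given.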

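You can import it directly and use it to return the file paths, labels and class names.

See: https://github.com/keras-team/keras/blob/d8fcb9d4d4dad45080ecfdd575483653028f8eda/keras/preprocessing/dataset_utils.py#L26

You can use something like this to get the dataset in indexed form: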

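from keras.preprocessing.dataset_utils import index_directory

ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
# index_directory shuffles by default: pass shuffle=False and the same class_names as train_ds
file_paths, labels, class_names = index_directory(directory="/path/to/train_ds", labels="inferred", formats=ALLOWLIST_FORMATS,
                                                  class_names=["Bedroom", "Bathroom", "Dinning", "Livingroom", "Kitchen"], shuffle=False)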

Also, keep shuffle set to False.

Another solution is to read the file paths directly from the train_ds object with train_ds.file_paths, because image_dataset_from_directory sets a file_paths attribute on the dataset object it returns. See here: https://github.com/keras-team/keras/blob/d8fcb9d4d4dad45080ecfdd575483653028f8eda/keras/preprocessing/image_dataset.py#L234
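Assuming train_ds.file_paths is available (it is set in the Keras version linked above; older versions may not have it) and the dataset was built with shuffle=False, the per-model index lists can be mapped straight to file names. A minimal sketch; common_misclassified_paths is a hypothetical helper and the example paths are made up.

```python
def common_misclassified_paths(file_paths, misclassified_per_model):
    """Return the paths whose index appears in every model's misclassified list.

    file_paths: paths in dataset order, e.g. train_ds.file_paths (shuffle=False).
    misclassified_per_model: one list of indices per model.
    """
    if not misclassified_per_model:
        return []
    # keep only the indices that every model got wrong
    common = set.intersection(*map(set, misclassified_per_model))
    # sort so the result follows the dataset order
    return [file_paths[i] for i in sorted(common)]


paths = ["Bathroom/1.jpg", "Bathroom/2.jpg", "Bedroom/3.jpg", "Kitchen/4.jpg"]
per_model = [[0, 2, 3], [2, 3], [1, 2, 3]]
print(common_misclassified_paths(paths, per_model))  # ['Bedroom/3.jpg', 'Kitchen/4.jpg']
```

With the question's variables this would be called as common_misclassified_paths(train_ds.file_paths, missclassified_train_folders).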

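@ma7555's answer is the simple solution I was looking for, but the list of labels that ma7555's method outputs differs from the one I get with tf.concat([y for x, y in train_ds], axis=0).

train_ds was created with the image_dataset_from_directory method and contains 5 subfolders (my classes). The clumsy solution I have so far is: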
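One likely cause of that mismatch: when class_names is not passed, index_directory numbers the classes by the sorted subfolder names, while train_ds here was built with an explicit class_names list, so the same folder gets a different integer label (and shuffle=True, the index_directory default in the linked version, also reorders the list). The sketch below, with a hypothetical helper label_remap, builds the mapping between the two numberings so the label lists can be compared.

```python
def label_remap(inferred_names, explicit_names):
    """Map each label from the inferred (sorted subfolder) numbering to the explicit class_names numbering."""
    position = {name: i for i, name in enumerate(explicit_names)}
    return {i: position[name] for i, name in enumerate(inferred_names)}


class_names = ["Bedroom", "Bathroom", "Dinning", "Livingroom", "Kitchen"]  # as passed to train_ds
inferred_names = sorted(class_names)  # numbering used when class_names is not given
mapping = label_remap(inferred_names, class_names)

labels_inferred = [0, 1, 3]  # Bathroom, Bedroom, Kitchen in the sorted numbering
print([mapping[label] for label in labels_inferred])  # [1, 0, 4]
```

Passing the same class_names (and shuffle=False) to index_directory avoids the remapping altogether.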

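  • Use inferences_target_list to get the list of inferred labels and the list of real labels.
  • Compare the 2 lists with get_missclassified, which stores the indices where the labels differ.
  • Get the list of files in the folder with get_list_of_files. These should be the same paths as in ma7555's answer, but I haven't checked yet whether the order is the same.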
import os

import numpy as np
import tensorflow as tf


def inferences_target_list(model, data):
    '''
    returns 2 lists: inferences list, real labels
    '''
    # over train set fold1
    y_pred_float = model.predict(data)
    y_pred = np.argmax(y_pred_float, axis=1)

    # get real labels; this iterates over data a second time, so the order
    # only matches y_pred if the dataset was created with shuffle=False
    y_target = tf.concat([y for x, y in data], axis=0)
    print("length inferences and real labels: ", len(y_pred), len(y_target))
    return y_pred, y_target


def get_missclassified(y_pred, y_target):
  '''
  returns a list with the indexes of real labels that were misclassified
  '''
  missclassified = []
  for i, (pred, target) in enumerate(zip(y_pred, y_target.numpy().tolist())):
    if pred != target:
      missclassified.append(i)
  print("total misclassified: ", len(missclassified))
  return missclassified

def get_list_of_files(dirName):
    '''
    create a list of the file paths in the given directory and its subdirectories
    found here => https://thispointer.com/python-how-to-get-list-of-files-in-directory-and-sub-directories/
    '''
    # os.listdir returns entries in arbitrary order; sorting them follows the
    # alphabetical order image_dataset_from_directory uses for flat class folders.
    # Non-image files (e.g. .DS_Store) are listed here but skipped by Keras.
    listOfFile = sorted(os.listdir(dirName))
    allFiles = list()
    # Iterate over all the entries
    for entry in listOfFile:
        # Create full path
        fullPath = os.path.join(dirName, entry)
        # If entry is a directory then get the list of files in this directory
        if os.path.isdir(fullPath):
            allFiles = allFiles + get_list_of_files(fullPath)
        else:
            allFiles.append(fullPath)

    return allFiles
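To check whether pic_list and the Keras file order really agree (the open question in the list above), the two path lists can be compared element by element after normalizing the paths. A minimal sketch; first_order_mismatch is a hypothetical helper, and it assumes both lists contain only image files.

```python
import os


def first_order_mismatch(paths_a, paths_b):
    """Return the first index where the two path lists disagree, or None if they match.

    Paths are normalized with os.path.normpath so 'd//a.jpg' and 'd/a.jpg' compare equal.
    A length difference is reported at the index where the shorter list ends.
    """
    for i, (a, b) in enumerate(zip(paths_a, paths_b)):
        if os.path.normpath(a) != os.path.normpath(b):
            return i
    if len(paths_a) != len(paths_b):
        return min(len(paths_a), len(paths_b))
    return None


# e.g. first_order_mismatch(pic_list, train_ds.file_paths) should return None
print(first_order_mismatch(["d/a.jpg", "d/b.jpg"], ["d//a.jpg", "d/b.jpg"]))  # None
print(first_order_mismatch(["d/a.jpg", "d/b.jpg"], ["d/b.jpg", "d/a.jpg"]))  # 0
```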

Main loop:

from tensorflow.keras.preprocessing import image_dataset_from_directory

misclassified_train_folders=[]

for f in folders:
  print(f)

  # -- train dataset for each folder; it does not depend on the model,
  # so it is built once per folder instead of once per model
  train_path = reg_input+f+"/"+'train_ds/'
  train_ds = image_dataset_from_directory(
      train_path,
      class_names=["Bedroom","Bathroom","Dinning","Livingroom","Kitchen"],
      seed=None,
      validation_split=None,
      subset=None,
      image_size= image_size,
      batch_size= batch_size,
      color_mode='rgb',
      shuffle=False  # fixed order, so index i is the same picture for every model
      )

  # list of paths for analysed images
  pic_list = get_list_of_files(train_path)

  for nn in models_dict:
    # inferences and real values
    y_pred, y_target = inferences_target_list(models_dict[nn], train_ds)

    # misclassified ones
    misclassified = get_missclassified(y_pred, y_target)
    print("elements misclassified in {} for model {}: ".format(f, nn), len(misclassified))
    misclassified_train_folders.append(misclassified)

  • Now I have a list of 5 lists, each holding all the misclassified elements of one model for my first folder. To get the pictures that are always misclassified:
common_misclassified = sorted(set.intersection(*map(set, misclassified_train_folders)))
# these are the indexes of those images, in dataset order
print(len(common_misclassified), "\n", common_misclassified)
  • Get the paths of these pictures:
pic_list_misclassified = [pic_list[i] for i in common_misclassified]

# number of pictures misclassified by all models
print(len(pic_list_misclassified))
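To actually look at those pictures, one option is to copy them into a separate folder. A minimal standard-library sketch; copy_for_review and the destination folder name are hypothetical, and it assumes the paths in pic_list_misclassified exist on disk. The class subfolder is prefixed to the file name because different classes may contain files with the same name.

```python
import os
import shutil


def copy_for_review(paths, dest_dir):
    """Copy each image into dest_dir as '<class folder>_<file name>' and return the new paths."""
    os.makedirs(dest_dir, exist_ok=True)
    copied = []
    for path in paths:
        # the parent folder name is the class, e.g. .../train_ds/Kitchen/12.jpg -> Kitchen
        class_folder = os.path.basename(os.path.dirname(path))
        target = os.path.join(dest_dir, f"{class_folder}_{os.path.basename(path)}")
        shutil.copy2(path, target)  # copy2 also keeps the file's timestamps
        copied.append(target)
    return copied


# e.g. copy_for_review(pic_list_misclassified, "misclassified_review")
```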