获取 keras 中所有已知 类 vgg-16 的列表

Getting a list of all known classes of vgg-16 in keras

我使用来自 Keras 的预训练 VGG-16 模型。

到目前为止我的工作源代码是这样的:

from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions

model = VGG16()

print(model.summary())

image = load_img('./pictures/door.jpg', target_size=(224, 224))
image = img_to_array(image)  #output Numpy-array

image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

image = preprocess_input(image)
yhat = model.predict(image)

label = decode_predictions(yhat)
label = label[0][0]

print('%s (%.2f%%)' % (label[1], label[2]*100))

我发现模型是在 1000 类 上训练的。是否有可能获得 类 训练该模型的列表?打印出所有预测标签不是一种选择,因为只返回了 5 个。

提前致谢

您可以使用 decode_predictions 并在 top=1000 参数中传递 class 的总数(只有它的默认值为 5)。

或者您可以看看 Keras 是如何在内部执行此操作的:它下载文件 imagenet_class_index.json(并且通常将其缓存在 ~/.keras/models/ 中)。这是一个包含所有 class 标签的简单 json 文件。

我想如果你这样做:

vgg16 = keras.applications.vgg16.VGG16(include_top=True,
                               weights='imagenet',
                               input_tensor=None,
                               input_shape=None,
                               pooling=None,
                               classes=1000)

vgg16.decode_predictions(np.arange(1000), top=1000)

用您的预测数组替换 np.arange(1000)。到目前为止未经测试的代码。

Link 我想在这里训练标签:http://image-net.org/challenges/LSVRC/2014/browse-synsets

如果您稍微编辑一下代码,您可以获得您提供的示例的所有最佳预测的列表。 Tensorflow decode_predictions returns 列表 class 预测元组的列表。因此,首先,将 top=1000 参数添加为 @YSelf 推荐给 label = decode_predictions(yhat, top=1000) 然后将 label = label[0][0] 更改为 label = label[0][:] 到 select 所有预测。标签看起来像这样:

[('n04252225', 'snowplow', 0.4144803),
('n03796401', 'moving_van', 0.09205707),
('n04461696', 'tow_truck', 0.08912289),
('n03930630', 'pickup', 0.07173037),
('n04467665', 'trailer_truck', 0.048759833),
('n02930766', 'cab', 0.043586567),
('n04037443', 'racer', 0.036957625),....)]

从这里开始,您需要对元组进行解包。如果您只想获得 1000 个 classes 的列表,您只需调用 [y for (x,y,z) in label] 即可获得所有 1000 个 classes 的列表。输出将如下所示:

['snowplow',
'moving_van',
'tow_truck',
'pickup',
'trailer_truck',
'cab',
'racer',....]

这一行将打印出所有 class 名称和索引: decode_predictions(np.expand_dims(np.arange(1000), 0), top=1000)