在从 tensorflow_datasets 加载的 CIFAR-100 中访问 'coarse_label'

Access 'coarse_label' in CIFAR-100 loaded from tensorflow_datasets

我正在使用 tensorflow_datasets (tfds doc)

加载 CIFAR-100
train, test = tfds.load(name="cifar100:3.*.*", split=["train", "test"], as_supervised=True)

CIFAR-100 有一个标签 (100 类) 和一个 coarse_label (20 类),如上面链接的文档所示。很容易访问标签,例如:

for image, label in train:
     # ... the label here is the actual label, not the coarse_label

但是,我打算基于 coarse_label 进行操作,例如,根据它进行过滤或将其用作 Keras 分类器中的标签。

如何访问 coarse_label?

我找到了解决办法。如果我没有按受监督方式加载,即如果我删除 as_supervised=True

train, test = tfds.load(name="cifar100:3.*.*", split=["train", "test"])

,我可以从字典中得到coarse_labels,例如

for item in train:
   print(item['coarse_label'])

像这样,我将能够重构数据集s.t。 coarse_labels 可用于分类。但是,即使我对标签感兴趣,我仍然必须加载 as_supervised=False ,这对我来说仍然很不自然。如果有人有更好的解决方案,我很乐意接受这个答案。