如何为 TensorFlow 数据集获取每个 class 的样本

How to get samples per class for TensorFlow Dataset

我正在使用来自 TensorFlow 数据集的数据集。 是否有一种简单的方法来访问数据集中每个 class 的样本数量? 我正在搜索 keras api,但没有找到任何可用的样本函数。

最终我想绘制一个条形图,其中 Y 轴为样本数,而 int 表示 class id 在 X 轴。目标是显示数据在 class 中的分布有多均匀。

使用 np.fromiter 您可以从可迭代对象创建一维数组。

import tensorflow_datasets as tfds
import numpy as np
import seaborn as sns

dataset = tfds.load('cifar10', split='train', as_supervised=True)

labels, counts = np.unique(np.fromiter(dataset.map(lambda x, y: y), np.int32), 
                       return_counts=True)

plt.ylabel('Counts')
plt.xlabel('Labels')
sns.barplot(x = labels, y = counts) 


更新:您还可以计算如下标签:

labels = []
for x, y in dataset:
  # Not one hot encoded
  labels.append(y.numpy())

  # If one hot encoded, then apply argmax
  # labels.append(np.argmax(y, axis = -1))
labels = np.concatenate(labels, axis = 0) # Assuming dataset was batched.

然后您可以使用 labels 数组绘制它们。