Tensorflow - 有没有办法按标签分隔 tf.data.Dataset?

Tensorflow - Is there a way to separate tf.data.Dataset by label?

我知道在将数据加载到我的网络之前,我可以通过标签将它们分开。假设有 3 个 类,标签为 0,1,2。我可以通过:

dataset1 = tf.data.TextLineDataset(train_csv_file1).map(_parse_csv_train)
dataset2 = tf.data.TextLineDataset(train_csv_file2).map(_parse_csv_train)
dataset3 = tf.data.TextLineDataset(train_csv_file3).map(_parse_csv_train)

我只是对以下内容感到好奇:

假设我们有数据集:

dataset = tf.data.TextLineDataset(train_csv_file).map(_parse_csv_train)

其中包含来自 3 类、

的所有数据

有没有办法调用像dataset.selectDataByLabel(label=="2")这样的函数[这是一个虚构的函数]这样我就可以划分数据集根据标签分为 3 个部分?

所以最后我选择了用csvs分隔文件,即生成一个csvs,每个csvs只包含一个class的数据。当 classes 太多时,这可能不是一个完美的解决方案,但在我的情况下只有 5 classes 所以没关系。