如何 access/process TensorFlow 数据集中的内容?

How can I access/process content within TensorFlow Datasets?

我正在使用 cnn_dailymail dataset which is part of the TensorFlow Datasets。 我按如下方式访问它:

import tensorflow_datasets as tfds
data, info = tfds.load('cnn_dailymail', with_info=True)
train_data, test_data = data['train'], data['test']

要从我使用的数据集中提取单个示例:

cnn_ex, = train_data.take(1)
cnn_ex['highlights'].numpy()

这将 return 类似于这样的字符串:"emma monaghan, 27, from glasgow, used to weigh 18st 5lbs ."。我想对该数据集应用一些预处理步骤,以便我可以将其用作深度学习算法的输入。上面的例子在预处理后应该是这样的:"<start> emma monaghan, 27, from glasgow, used to weigh 18st 5lbs . <end>".

有没有办法一次访问和预处理所有文本(在 train_data 内),而不必多次应用 take() 函数次?例如,将 TensorFlow 数据集转换为 numpy 数组就已经有所帮助。谢谢!

这取决于您的具体 objective。也许 tfds.as_numpy() 就是您要找的。您可以将其应用于 train_data 以获得 generator_object。您可以直接对其进行迭代,或应用任何映射函数

train_data = train_data.map(map_func)
for i in tfds.as_numpy(train_data):
    print(i)
    ...

您可以使用 dataset.map() 对您的数据应用转换。例如:

import tensorflow as tf
import tensorflow_datasets as tfds

data, info = tfds.load('cnn_dailymail', with_info=True)
dataset_train, dataset_test = data['train'], data['test']

def map_fn(x, start=tf.constant('<start>'), end=tf.constant('<end>')):
    strings = [start, x['highlights'], end]
    x['highlights'] = tf.strings.join(strings, separator=' ')
    return x

dataset_train = dataset_train.map(map_fn) # <-- apply transformation for the whole data
elem,  = dataset_train.take(1)
print(elem['highlights'].numpy())
# b'<start> arthur potts dawson: british ... <end>'