How do I add or change the component names of an existing TensorFlow Dataset object?

From the TensorFlow Datasets guide:

It is often convenient to give names to each component of an element, for example if they represent different features of a training example. In addition to tuples, you can use collections.namedtuple or a dictionary mapping strings to tensors to represent a single element of a Dataset.

dataset = tf.data.Dataset.from_tensor_slices(
   {"a": tf.random_uniform([4]),
    "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

https://www.tensorflow.org/guide/datasets

This is very useful in Keras. If you pass a dataset object to model.fit, the component names can be used to match the inputs of the Keras model. Example:

image_input = keras.Input(shape=(32, 32, 3), name='img_input')
timeseries_input = keras.Input(shape=(None, 10), name='ts_input')

x1 = layers.Conv2D(3, 3)(image_input)
x1 = layers.GlobalMaxPooling2D()(x1)

x2 = layers.Conv1D(3, 3)(timeseries_input)
x2 = layers.GlobalMaxPooling1D()(x2)

x = layers.concatenate([x1, x2])

score_output = layers.Dense(1, name='score_output')(x)
class_output = layers.Dense(5, activation='softmax', name='class_output')(x)

model = keras.Model(inputs=[image_input, timeseries_input],
                    outputs=[score_output, class_output])

train_dataset = tf.data.Dataset.from_tensor_slices(
    ({'img_input': img_data, 'ts_input': ts_data},
     {'score_output': score_targets, 'class_output': class_targets}))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model.fit(train_dataset, epochs=3)

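So it would be useful to be able to look up, add, and change the names of components in a tf Dataset object. What is the best way to accomplish these tasks?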

You can use map to modify your dataset, if that is what you are looking for. For example, to convert a plain tuple output into a dict with meaningful names:

import tensorflow as tf

# dummy example
ds_ori = tf.data.Dataset.zip((tf.data.Dataset.range(0, 10), tf.data.Dataset.range(10, 20)))
ds_renamed = ds_ori.map(lambda x, y: {'input': x, 'output': y})

batch_ori = ds_ori.make_one_shot_iterator().get_next()
batch_renamed = ds_renamed.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  print(sess.run(batch_ori))
  print(sess.run(batch_renamed))
  # (0, 10)
  # {'input': 0, 'output': 10}
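
The snippet above relies on the TF 1.x session API (make_one_shot_iterator, tf.Session). The following is a minimal sketch of the same idea, assuming TensorFlow 2.x with eager execution. It also shows how to look up the current names through element_spec and how to rename the keys of a dataset that already yields dicts. The names 'features' and 'label' are made up for illustration.

```python
import tensorflow as tf

# Same dummy dataset as above: each element is a plain (x, y) tuple.
ds_ori = tf.data.Dataset.zip(
    (tf.data.Dataset.range(0, 10), tf.data.Dataset.range(10, 20)))

# Tuple -> dict: give each component a name.
ds_named = ds_ori.map(lambda x, y: {'input': x, 'output': y})

# Dict -> dict: change existing names. A dict element is passed to the
# map function as a single argument. 'features' and 'label' are
# hypothetical names chosen for this example.
ds_renamed = ds_named.map(
    lambda d: {'features': d['input'], 'label': d['output']})

# element_spec describes the element structure, including component names.
print(ds_renamed.element_spec)

# In eager mode, a dataset can be iterated directly.
first = next(iter(ds_renamed))
print({name: int(value) for name, value in first.items()})
```

Because map builds a new dataset, ds_ori and ds_named stay unchanged and can still be used under their original structure.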

While the accepted answer works for changing the names of (existing) components, it does not cover adding new ones. That can be done as follows:

y_dataset = x_dataset.map(fn1)

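You can define fn1 however you need:

@tf.function
def fn1(x):
    # Use x to derive the additional columns you want, and set their shapes as well.
    y = {}
    y.update(x)  # keep every existing component
    y['new1'] = new1  # placeholder: a tensor you derive from x
    y['new2'] = new2  # placeholder: a tensor you derive from x
    return y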
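
To make the pattern concrete, here is a small runnable sketch, assuming TensorFlow 2.x. The input dataset, the function name add_components, and the component names 'a_squared' and 'b_as_float' are all invented for illustration. The only point is that the function passed to map copies the incoming dict and adds new keys derived from the existing ones.

```python
import tensorflow as tf

# Hypothetical input dataset whose elements are dicts with keys 'a' and 'b'.
x_dataset = tf.data.Dataset.from_tensor_slices(
    {'a': tf.constant([1.0, 2.0, 3.0]),
     'b': tf.constant([10, 20, 30])})


def add_components(x):
    """Return a copy of x with two extra, derived components."""
    y = dict(x)  # shallow copy, keeps all existing components
    y['a_squared'] = x['a'] * x['a']               # derived from 'a'
    y['b_as_float'] = tf.cast(x['b'], tf.float32)  # derived from 'b'
    return y


y_dataset = x_dataset.map(add_components)

# The new names now appear next to the original ones.
print(sorted(y_dataset.element_spec.keys()))

first = next(iter(y_dataset))
print({name: value.numpy() for name, value in first.items()})
```

The @tf.function decorator is left out here because map already traces the function it is given into a graph.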