Tensorflow 2.0:如何从 MapDataset(从 TFRecord 读取后)转换为可以输入到 model.fit 的某种结构
Tensorflow 2.0: how to transform from MapDataset (after reading from TFRecord) to some structure that can be input to model.fit
我将我的训练和验证数据存储在两个单独的 TFRecord 文件中,我在其中存储了 4 个值:信号 A(float32 形状(150,)),信号 B(float32 形状(150,)),标签(标量 int64),id(字符串)。我的阅读解析函数是:
def _parse_data_function(sample_proto):
raw_signal_description = {
'label': tf.io.FixedLenFeature([], tf.int64),
'id': tf.io.FixedLenFeature([], tf.string),
}
for key, item in SIGNALS.items():
raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)
# Parse the input tf.Example proto using the dictionary above.
return tf.io.parse_single_example(sample_proto, raw_signal_description)
其中SIGNALS
是字典映射信号名称->信号形状。然后,我阅读了原始数据集:
training_raw = tf.data.TFRecordDataset(<path to training>), compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>), compression_type='GZIP')
并使用 map 解析值:
training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)
显示 training_data
或 val_data
的 header,我得到:
<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>
这与预期的差不多。我还检查了一些值的一致性,它们似乎是正确的。
现在,关于我的问题:我如何从具有类似字典结构的 MapDataset 获取可以作为模型输入的内容?
我的模型的输入是一对(信号 A,标签),尽管将来我也会使用信号 B。
对我来说最简单的方法似乎是在我想要的元素上创建一个生成器。类似于:
def data_generator(mapdataset):
for sample in mapdataset:
yield (sample['Signal A'], sample['label'])
但是,使用这种方法我失去了数据集的一些便利性,例如批处理,而且也不清楚如何对 model.fit
的 validation_data
参数使用相同的方法。理想情况下,我只会在地图表示和数据集表示之间进行转换,它会在信号 A 张量和标签对上进行迭代。
编辑:我的最终产品应该是 header 类似于:
<TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)>
但不一定TensorSliceDataset
您可以在解析函数中简单地执行此操作。例如:
def _parse_data_function(sample_proto):
raw_signal_description = {
'label': tf.io.FixedLenFeature([], tf.int64),
'id': tf.io.FixedLenFeature([], tf.string),
}
for key, item in SIGNALS.items():
raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)
# Parse the input tf.Example proto using the dictionary above.
parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)
return parsed['Signal A'], parsed['label']
如果你 map
在 TFRecordDataset
上使用这个函数,你将得到一个 元组 (signal_a, label)
的数据集,而不是字典的数据集.您应该可以直接将其放入 model.fit
。
我将我的训练和验证数据存储在两个单独的 TFRecord 文件中,我在其中存储了 4 个值:信号 A(float32 形状(150,)),信号 B(float32 形状(150,)),标签(标量 int64),id(字符串)。我的阅读解析函数是:
def _parse_data_function(sample_proto):
raw_signal_description = {
'label': tf.io.FixedLenFeature([], tf.int64),
'id': tf.io.FixedLenFeature([], tf.string),
}
for key, item in SIGNALS.items():
raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)
# Parse the input tf.Example proto using the dictionary above.
return tf.io.parse_single_example(sample_proto, raw_signal_description)
其中SIGNALS
是字典映射信号名称->信号形状。然后,我阅读了原始数据集:
training_raw = tf.data.TFRecordDataset(<path to training>), compression_type='GZIP')
val_raw = tf.data.TFRecordDataset(<path to validation>), compression_type='GZIP')
并使用 map 解析值:
training_data = training_raw.map(_parse_data_function)
val_data = val_raw.map(_parse_data_function)
显示 training_data
或 val_data
的 header,我得到:
<MapDataset shapes: {Signal A: (150,), Signal B: (150,), id: (), label: ()}, types: {Signal A: tf.float32, Signal B: tf.float32, id: tf.string, label: tf.int64}>
这与预期的差不多。我还检查了一些值的一致性,它们似乎是正确的。
现在,关于我的问题:我如何从具有类似字典结构的 MapDataset 获取可以作为模型输入的内容?
我的模型的输入是一对(信号 A,标签),尽管将来我也会使用信号 B。
对我来说最简单的方法似乎是在我想要的元素上创建一个生成器。类似于:
def data_generator(mapdataset):
for sample in mapdataset:
yield (sample['Signal A'], sample['label'])
但是,使用这种方法我失去了数据集的一些便利性,例如批处理,而且也不清楚如何对 model.fit
的 validation_data
参数使用相同的方法。理想情况下,我只会在地图表示和数据集表示之间进行转换,它会在信号 A 张量和标签对上进行迭代。
编辑:我的最终产品应该是 header 类似于:
<TensorSliceDataset shapes: ((150,), ()), types: (tf.float32, tf.int64)>
但不一定TensorSliceDataset
您可以在解析函数中简单地执行此操作。例如:
def _parse_data_function(sample_proto):
raw_signal_description = {
'label': tf.io.FixedLenFeature([], tf.int64),
'id': tf.io.FixedLenFeature([], tf.string),
}
for key, item in SIGNALS.items():
raw_signal_description[key] = tf.io.FixedLenFeature(item, tf.float32)
# Parse the input tf.Example proto using the dictionary above.
parsed = tf.io.parse_single_example(sample_proto, raw_signal_description)
return parsed['Signal A'], parsed['label']
如果你 map
在 TFRecordDataset
上使用这个函数,你将得到一个 元组 (signal_a, label)
的数据集,而不是字典的数据集.您应该可以直接将其放入 model.fit
。