使用 tensorflow io 在 train/test 个子集中拆分自定义二进制数据集
spliting custom binary dataset in train/test subsets using tensorflow io
我正在尝试使用本地二进制数据来训练网络执行 regression inference。
每个本地二进制数据具有以下布局:
整个数据由多个 *.bin 文件组成,布局如上。每个文件都有可变数量的 403*4 字节序列。我能够使用以下代码读取其中一个文件:
import tensorflow as tf
RAW_N = 2 + 20*20 + 1
def convert_binary_to_float_array(register):
return tf.io.decode_raw(register, out_type=tf.float32)
raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['mydata.bin'],record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(map_func=convert_binary_to_float_array)
现在,我需要创建 4 个数据集 train_data
、train_labels
、test_data
、test_labels
,如下所示:
train_data, train_labels, test_data, test_labels = prepare_ds(raw_dataset, 0.8)
并使用它们来训练和评估:
model = build_model()
history = model.fit(train_data, train_labels, ...)
loss, mse = model.evaluate(test_data, test_labels)
我的问题是:如何实现功能prepare_ds(dataset, frac)
?
def prepare_ds(dataset, frac):
...
我尝试使用 tf.shape
、tf.reshape
、tf.slice
、订阅 [:],但没有成功。我意识到这些功能无法正常工作,因为在 map()
调用之后 raw_dataset
是 MapDataset
(由于急切的执行问题)。
如果元数据被认为是您输入的一部分,我假设是这样,您可以尝试这样的事情:
import random
import struct
import tensorflow as tf
import numpy as np
RAW_N = 2 + 20*20 + 1
bytess = random.sample(range(1, 5000), RAW_N*4)
with open('mydata.bin', 'wb') as f:
f.write(struct.pack('1612i', *bytess))
def decode_and_prepare(register):
register = tf.io.decode_raw(register, out_type=tf.float32)
inputs = register[:402]
label = register[402:]
return inputs, label
total_data_entries = 8
raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['/content/mydata.bin', '/content/mydata.bin'], record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(decode_and_prepare)
raw_dataset = raw_dataset.shuffle(buffer_size=total_data_entries)
train_ds_size = int(0.8 * total_data_entries)
test_ds_size = int(0.2 * total_data_entries)
train_ds = raw_dataset.take(train_ds_size)
remaining_data = raw_dataset.skip(train_ds_size)
test_ds = remaining_data.take(test_ds_size)
请注意,出于演示目的,我使用了同一个 bin 文件两次。在 运行 该代码段之后,您可以像这样将数据集提供给您的模型:
model = build_model()
history = model.fit(train_ds, ...)
loss, mse = model.evaluate(test_ds)
因为每个数据集都包含输入和相应的标签。
我正在尝试使用本地二进制数据来训练网络执行 regression inference。
每个本地二进制数据具有以下布局:
整个数据由多个 *.bin 文件组成,布局如上。每个文件都有可变数量的 403*4 字节序列。我能够使用以下代码读取其中一个文件:
import tensorflow as tf
RAW_N = 2 + 20*20 + 1
def convert_binary_to_float_array(register):
return tf.io.decode_raw(register, out_type=tf.float32)
raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['mydata.bin'],record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(map_func=convert_binary_to_float_array)
现在,我需要创建 4 个数据集 train_data
、train_labels
、test_data
、test_labels
,如下所示:
train_data, train_labels, test_data, test_labels = prepare_ds(raw_dataset, 0.8)
并使用它们来训练和评估:
model = build_model()
history = model.fit(train_data, train_labels, ...)
loss, mse = model.evaluate(test_data, test_labels)
我的问题是:如何实现功能prepare_ds(dataset, frac)
?
def prepare_ds(dataset, frac):
...
我尝试使用 tf.shape
、tf.reshape
、tf.slice
、订阅 [:],但没有成功。我意识到这些功能无法正常工作,因为在 map()
调用之后 raw_dataset
是 MapDataset
(由于急切的执行问题)。
如果元数据被认为是您输入的一部分,我假设是这样,您可以尝试这样的事情:
import random
import struct
import tensorflow as tf
import numpy as np
RAW_N = 2 + 20*20 + 1
bytess = random.sample(range(1, 5000), RAW_N*4)
with open('mydata.bin', 'wb') as f:
f.write(struct.pack('1612i', *bytess))
def decode_and_prepare(register):
register = tf.io.decode_raw(register, out_type=tf.float32)
inputs = register[:402]
label = register[402:]
return inputs, label
total_data_entries = 8
raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['/content/mydata.bin', '/content/mydata.bin'], record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(decode_and_prepare)
raw_dataset = raw_dataset.shuffle(buffer_size=total_data_entries)
train_ds_size = int(0.8 * total_data_entries)
test_ds_size = int(0.2 * total_data_entries)
train_ds = raw_dataset.take(train_ds_size)
remaining_data = raw_dataset.skip(train_ds_size)
test_ds = remaining_data.take(test_ds_size)
请注意,出于演示目的,我使用了同一个 bin 文件两次。在 运行 该代码段之后,您可以像这样将数据集提供给您的模型:
model = build_model()
history = model.fit(train_ds, ...)
loss, mse = model.evaluate(test_ds)
因为每个数据集都包含输入和相应的标签。