使用 tensorflow io 在 train/test 个子集中拆分自定义二进制数据集

spliting custom binary dataset in train/test subsets using tensorflow io

我正在尝试使用本地二进制数据来训练网络执行 regression inference

每个本地二进制数据具有以下布局:

整个数据由多个 *.bin 文件组成,布局如上。每个文件都有可变数量的 403*4 字节序列。我能够使用以下代码读取其中一个文件:

import tensorflow as tf

RAW_N = 2 + 20*20 + 1

def convert_binary_to_float_array(register):
     return tf.io.decode_raw(register, out_type=tf.float32)

raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['mydata.bin'],record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(map_func=convert_binary_to_float_array)

现在,我需要创建 4 个数据集 train_datatrain_labelstest_datatest_labels,如下所示:

train_data, train_labels, test_data, test_labels = prepare_ds(raw_dataset, 0.8)

并使用它们来训练和评估:

model = build_model()

history = model.fit(train_data, train_labels, ...)

loss, mse = model.evaluate(test_data, test_labels)

我的问题是:如何实现功能prepare_ds(dataset, frac)

def prepare_ds(dataset, frac):
    ...

我尝试使用 tf.shapetf.reshapetf.slice、订阅 [:],但没有成功。我意识到这些功能无法正常工作,因为在 map() 调用之后 raw_datasetMapDataset(由于急切的执行问题)。

如果元数据被认为是您输入的一部分,我假设是这样,您可以尝试这样的事情:

import random
import struct
import tensorflow as tf
import numpy as np

RAW_N = 2 + 20*20 + 1

bytess = random.sample(range(1, 5000), RAW_N*4)
with open('mydata.bin', 'wb') as f:
  f.write(struct.pack('1612i', *bytess))

def decode_and_prepare(register):
  register = tf.io.decode_raw(register, out_type=tf.float32)
  inputs = register[:402]
  label = register[402:]
  return inputs, label

total_data_entries = 8
raw_dataset = tf.data.FixedLengthRecordDataset(filenames=['/content/mydata.bin', '/content/mydata.bin'], record_bytes=RAW_N*4)
raw_dataset = raw_dataset.map(decode_and_prepare)
raw_dataset = raw_dataset.shuffle(buffer_size=total_data_entries)

train_ds_size = int(0.8 * total_data_entries)
test_ds_size = int(0.2 * total_data_entries)

train_ds = raw_dataset.take(train_ds_size)
remaining_data = raw_dataset.skip(train_ds_size)  
test_ds = remaining_data.take(test_ds_size)

请注意,出于演示目的,我使用了同一个 bin 文件两次。在 运行 该代码段之后,您可以像这样将数据集提供给您的模型:

model = build_model()

history = model.fit(train_ds, ...)

loss, mse = model.evaluate(test_ds)

因为每个数据集都包含输入和相应的标签。