tensorflow tf.py_func 加载泡菜迭代器抛出错误,形状未知

tensorflow tf.py_func loading pickle Iterator throwing error, unknown shapes

我正在编写 tf.data 管道,以便稍后输入 keras。问题是我的数据是泡菜文件的形式。我有一个传递给 tf 数据的文件名列表,我将使用其中的自定义 tf.py_func 调用 pickle 来加载它。

当我尝试从数据集构建迭代器时出现问题,给出错误

"Cannot convert value , ), types: (tf.float32, tf.float32)> to a TensorFlow DType."

我相信这是因为 tensorflow 无法推断加载的 pickle 数据的形状。我对如何进行有点迷茫,或者这在 tf 数据中是否可行。

dataset = tf.data.Dataset.from_tensor_slices(dataset_filepath_list)

def parse_input_data_function(filename):
    # pickle file is a tuple, (data, label)
    histogram_data, label = pickle.load(open(filename, 'rb'))
    histogram_data = historgram_data.transpose(1, 0)
    histogram_data = historgram_data.reshape([-1, 8, 32])
    return histogram_data.astype('float32'), float(label)

dataset = dataset.map(
    lambda filename : tuple(tf.py_func(
        parse_input_data_function, [filename], [tf.float32, 
tf.float32])))

dataset = dataset.shuffle(len(dataset_filename_list))
    .batch(batch_size).repeat()

# this line is where the error occurs
training_iterator = tf.data.Iterator.from_structure(dataset, 
dataset.output_shapes)

您的问题是您向 tf.data.Iterator.from_structure 传递了错误的参数。它应该采用 (output_types, output_shapes),但您提供的是数据集及其形状。 试试这个:

training_iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)

然后,将迭代器与相应的数据集一起使用:

input_data, output_data = training_iterator.get_next()
train_init = training_iterator.make_initializer(dataset)