tf.data API 读取 TFRecord 文件

tf.data API read the TFRecord files

我正在尝试使用 tf.data API 读取 TFRecord 文件。

import tensorflow as tf
from PIL import Image
import numpy as np
import os

def train_input_fn():
    filenames = ["mytrain.tfrecords"]
    dataset = tf.data.TFRecordDataset(filenames)

    def parser(record):
        keys_to_features = {
            "image_data": tf.FixedLenFeature((), tf.string, default_value=""),
            "date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
            "label": tf.FixedLenFeature((), tf.int64,
                                        default_value=tf.zeros([], dtype=tf.int64)),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        image = tf.decode_jpeg(parsed["image_data"])
        image = tf.reshape(image, [128, 128, 3])
        label = tf.cast(parsed["label"], tf.int32)

        return {"image_data": image, "date_time": parsed["date_time"]}, label

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(32)
    dataset = dataset.repeat(1)
    iterator = dataset.make_one_shot_iterator()

    features, labels = iterator.get_next()
    return features, labels

output = train_input_fn()

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord = coord)
    for i in range(230):
        image, label = sess.run(output)
        img = Image.fromarray(image, 'RGB')
        img.save(cwd+str(i) + '_''Label_'+str(l)+'.jpg')
        print(image, label)
    coord.request_stop()
    coord.join(threads)

回溯(最后一次调用): 文件 "E:/Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py",第 34 行,位于 输出 = train_input_fn() 文件 "E:/Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py",第 25 行,在 train_input_fn 中 TypeError:应为 int64,得到的是 'str' 类型的 ''。

请注意您的错误日志中的 TypeError: Expected int64, got '' of type 'str' instead。您的代码中有错误。

错误

在下一行中:

"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),

tf.int64 类型变量的默认值指定为字符串 ""

修复

假设您的预期默认值为 0,那么您应该将行更改为:

"date_time": tf.FixedLenFeature((), tf.int64, default_value=0),

希望对您有所帮助。