tf.data API 读取 TFRecord 文件
tf.data API read the TFRecord files
我正在尝试使用 tf.data API 读取 TFRecord 文件。
import tensorflow as tf
from PIL import Image
import numpy as np
import os
def train_input_fn():
filenames = ["mytrain.tfrecords"]
dataset = tf.data.TFRecordDataset(filenames)
def parser(record):
keys_to_features = {
"image_data": tf.FixedLenFeature((), tf.string, default_value=""),
"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
"label": tf.FixedLenFeature((), tf.int64,
default_value=tf.zeros([], dtype=tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
image = tf.decode_jpeg(parsed["image_data"])
image = tf.reshape(image, [128, 128, 3])
label = tf.cast(parsed["label"], tf.int32)
return {"image_data": image, "date_time": parsed["date_time"]}, label
dataset = dataset.map(parser)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(1)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
output = train_input_fn()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord = coord)
for i in range(230):
image, label = sess.run(output)
img = Image.fromarray(image, 'RGB')
img.save(cwd+str(i) + '_''Label_'+str(l)+'.jpg')
print(image, label)
coord.request_stop()
coord.join(threads)
回溯(最后一次调用):
文件 "E:/Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py",第 34 行,位于
输出 = train_input_fn()
文件 "E:/Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py",第 25 行,在 train_input_fn 中
TypeError:应为 int64,得到的是 'str' 类型的 ''。
请注意您的错误日志中的 TypeError: Expected int64, got '' of type 'str' instead
。您的代码中有错误。
错误
在下一行中:
"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
tf.int64
类型变量的默认值指定为字符串 ""
。
修复
假设您的预期默认值为 0,那么您应该将行更改为:
"date_time": tf.FixedLenFeature((), tf.int64, default_value=0),
希望对您有所帮助。
我正在尝试使用 tf.data API 读取 TFRecord 文件。
import tensorflow as tf
from PIL import Image
import numpy as np
import os
def train_input_fn():
filenames = ["mytrain.tfrecords"]
dataset = tf.data.TFRecordDataset(filenames)
def parser(record):
keys_to_features = {
"image_data": tf.FixedLenFeature((), tf.string, default_value=""),
"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
"label": tf.FixedLenFeature((), tf.int64,
default_value=tf.zeros([], dtype=tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
image = tf.decode_jpeg(parsed["image_data"])
image = tf.reshape(image, [128, 128, 3])
label = tf.cast(parsed["label"], tf.int32)
return {"image_data": image, "date_time": parsed["date_time"]}, label
dataset = dataset.map(parser)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(1)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
output = train_input_fn()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord = coord)
for i in range(230):
image, label = sess.run(output)
img = Image.fromarray(image, 'RGB')
img.save(cwd+str(i) + '_''Label_'+str(l)+'.jpg')
print(image, label)
coord.request_stop()
coord.join(threads)
回溯(最后一次调用): 文件 "E:/Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py",第 34 行,位于 输出 = train_input_fn() 文件 "E:/Tensorflow/Wenshan_Cai_Nanoletters/tf_data.py",第 25 行,在 train_input_fn 中 TypeError:应为 int64,得到的是 'str' 类型的 ''。
请注意您的错误日志中的 TypeError: Expected int64, got '' of type 'str' instead
。您的代码中有错误。
错误
在下一行中:
"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
tf.int64
类型变量的默认值指定为字符串 ""
。
修复
假设您的预期默认值为 0,那么您应该将行更改为:
"date_time": tf.FixedLenFeature((), tf.int64, default_value=0),
希望对您有所帮助。