读取 tfrecord 文件时没有加载数据
No data is being loaded when reading tfrecord files
我正在尝试使用同一脚本中定义的函数 train 来训练 here 中定义的模型。我尝试使用作者在此链接中提供的 tfrecord 文件(为此,我所做的唯一修改是设置 init_path=None,并在 txt 文件 train_list=train_tfs.txt 中写入 tfrecord 文件的路径,二者均在 shift_params.py 的 shift_v1 中定义)。
但是,即使尝试这个简单的测试,我也会收到以下错误:
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch_join/random_shuffle_queue' is closed and has insufficient elements (requested 15, current size 0)
[[Node: shuffle_batch_join = QueueDequeueManyV2[component_types=[DT_UINT8, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_join/random_shuffle_queue, shuffle_batch_join/n)]]
据我了解,这个错误意味着没有数据被加载。所以我猜测问题出在 shift_dset.py 中的函数 read_example,代码如下所示(我只保留了加载图像 ims 和音频 sound 的部分,完整代码请查看 here):
def read_example(rec_queue, pr, input_types):
    """Parse one training example and cut out an aligned video/audio clip.

    Reads a serialized example from ``rec_queue``. Its features are two JPEG
    "filmstrips" (``im_0``, ``im_1``), each holding ``pr.total_frames // 2``
    frames of ``pr.full_im_dim`` x ``pr.full_im_dim`` pixels stacked
    vertically, and raw int16 audio (``sound``).

    A clip of ``pr.sampled_frames`` frames is cropped to ``pr.crop_im_dim``
    pixels. The audio window starting at the same frame is sliced. When
    ``pr.do_shift`` is set, a second window is also sliced, starting at a
    frame whose clip overlaps the first by at most ``pr.max_intersection``
    frames.

    Args:
        rec_queue: filename queue consumed by ``tf.TFRecordReader``.
        pr: parameter object. Reads ``full_im_dim``, ``total_frames``,
            ``variable_frame_count``, ``full_samples_len``,
            ``sampled_frames``, ``samples_per_frame``, ``do_shift``,
            ``fix_frame``, ``max_intersection``, ``skip_notfound``,
            ``augment_ims``, ``crop_im_dim`` and, optionally,
            ``use_first_frame`` / ``resize_dims``.
        input_types: collection of requested inputs. Audio is sliced only
            if it contains ``'samples'``.

    Returns:
        ``(ims, None, samples_exs, None, None, None)``.
        ``ims`` has shape ``(sampled_frames, crop_im_dim, crop_im_dim, 3)``.
        ``samples_exs`` stacks the shifted audio window and then the
        ground-truth window along a new leading axis (assuming ``ed`` is
        ``tf.expand_dims`` -- TODO confirm). The ``None`` entries are
        placeholders for outputs not reproduced in this excerpt; restore
        them before feeding a batching queue that expects six tensors.

    Raises:
        NotImplementedError: if ``pr.variable_frame_count`` is set.
        ValueError: if ``pr.do_shift`` is set and either a start frame has
            no valid partner (unless ``pr.skip_notfound``) or no valid frame
            pair exists at all.
    """
    # The frame bookkeeping below needs total_frames as a static Python int,
    # which a per-record 'num_frames' tensor cannot provide. Reject this mode
    # up front instead of via an assert (asserts vanish under `python -O`).
    if pr.variable_frame_count:
        raise NotImplementedError('variable_frame_count is not supported')

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(rec_queue)

    full = pr.full_im_dim
    total_frames = pr.total_frames
    feats = {
        'im_0': tf.FixedLenFeature([], dtype=tf.string),
        'im_1': tf.FixedLenFeature([], dtype=tf.string),
        'sound': tf.FixedLenFeature([], dtype=tf.string),
    }
    example = tf.parse_single_example(serialized_example, features=feats)

    def to_frames(strip):
        # A decoded strip is one tall image; split it into individual frames.
        strip.set_shape((full * total_frames // 2, full, 3))
        return tf.reshape(strip, (total_frames // 2, full, full, 3))

    im_parts = [
        to_frames(tf.image.decode_jpeg(example['im_0'], channels=3, name='decode_im1')),
        to_frames(tf.image.decode_jpeg(example['im_1'], channels=3, name='decode_im2')),
    ]

    # Raw int16 audio reshaped to (full_samples_len, 2), presumably
    # interleaved stereo, then scaled by the int16 maximum.
    samples = tf.decode_raw(example['sound'], tf.int16)
    samples.set_shape((pr.full_samples_len * 2,))
    samples = tf.reshape(samples, (pr.full_samples_len, 2))
    samples = tf.cast(samples, 'float32') / np.iinfo(np.dtype(np.int16)).max

    num_slice_frames = pr.sampled_frames
    num_samples = int(pr.samples_per_frame * float(num_slice_frames))

    if pr.do_shift:
        # Enumerate (start frame, shifted start frame) pairs whose clips
        # share at most pr.max_intersection frames.
        choices = []
        max_frame = total_frames - num_slice_frames
        frames1 = [0] if pr.fix_frame else range(max_frame)
        for frame1 in frames1:
            window1 = set(range(frame1, frame1 + num_slice_frames))
            found = False
            for frame2 in reversed(range(max_frame)):
                window2 = range(frame2, frame2 + num_slice_frames)
                if len(window1.intersection(window2)) <= pr.max_intersection:
                    found = True
                    choices.append([frame1, frame2])
                    if pr.fix_frame:
                        break
            if not found and not pr.skip_notfound:
                raise ValueError(
                    'no shifted frame found for start frame %d' % frame1)
        print('Number of frame choices: %d' % len(choices))
        if not choices:
            # An empty table would make choices[idx, 0] fail obscurely.
            raise ValueError('no valid (frame, shifted frame) pairs; '
                             'check sampled_frames / max_intersection')
        num_choices = len(choices)
        choices = tf.constant(np.array(choices), dtype=tf.int32)
        idx = tf.random_uniform([1], 0, num_choices, dtype=tf.int64)[0]
        start_frame_gt = choices[idx, 0]
        shift_frame = choices[idx, 1]
    elif ut.hastrue(pr, 'use_first_frame'):
        shift_frame = start_frame_gt = tf.constant(0, dtype=tf.int32)
    else:
        shift_frame = start_frame_gt = tf.random_uniform(
            [1], 0, total_frames - num_slice_frames, dtype=tf.int32)[0]

    # Spatial crop origin: random when augmenting, otherwise centred.
    d = pr.crop_im_dim
    if pr.augment_ims:
        print('Augment: %s' % (pr.augment_ims,))
        r = tf.random_uniform([2], 0, pr.full_im_dim - d, dtype=tf.int32)
        x, y = r[0], r[1]
    elif hasattr(pr, 'resize_dims'):
        y = pr.resize_dims[0] // 2 - d // 2
        x = pr.resize_dims[1] // 2 - d // 2
    else:
        y = x = pr.full_im_dim // 2 - d // 2

    # The requested clip may straddle the two filmstrips. Each part
    # contributes its (possibly empty) overlap with
    # [start_frame_gt, start_frame_gt + num_slice_frames).
    num_frames_in_part = total_frames // len(im_parts)
    slice_parts = []
    for j, part in enumerate(im_parts):
        part_start = j * num_frames_in_part
        frame_offset = tf.maximum(
            0, tf.minimum(start_frame_gt - part_start, num_frames_in_part))
        end_offset = tf.maximum(
            0, tf.minimum(start_frame_gt + num_slice_frames - part_start,
                          num_frames_in_part))
        part_len = tf.maximum(0, end_offset - frame_offset)
        slice_parts.append(
            tf.slice(part, [frame_offset, y, x, 0], [part_len, d, d, 3]))
    ims = tf.concat(slice_parts, 0)
    ims.set_shape([num_slice_frames, d, d, 3])

    if pr.augment_ims:
        unflipped = ims
        flip = tf.cast(tf.random_uniform([1], 0, 2, dtype=tf.int64)[0], tf.bool)
        ims = tf.cond(flip,
                      lambda: tf.map_fn(tf.image.flip_left_right, unflipped),
                      lambda: unflipped)

    def slice_samples(frame):
        # Audio window of num_samples samples aligned with video frame `frame`.
        start = round_int(pr.samples_per_frame * cast_float(frame))
        r = tf.slice(samples, [start, 0], [num_samples, 2], name='slice_sample')
        r.set_shape([num_samples] + list(shape(r)[1:]))
        return r

    if 'samples' in input_types:
        samples_gt = slice_samples(start_frame_gt)
        samples_shift = slice_samples(shift_frame)
    else:
        # Placeholder audio; float32 to match the dtype of the sliced audio.
        samples_gt = samples_shift = tf.zeros((1, 1), dtype=tf.float32)
    samples_exs = tf.concat([ed(samples_shift, 0), ed(samples_gt, 0)], 0)

    # NOTE(review): four outputs were omitted from this excerpt; None marks them.
    return ims, None, samples_exs, None, None, None
我也尝试过修改此函数来加载我自己的 tfrecords,但遇到了类似的错误。此外,当没有提供 tfrecord 路径时(即 train_list=train_tfs.txt 中不包含任何 tfrecord 路径,因此不可能加载任何数据),我也会得到同样的错误,这表明确实没有加载任何数据。
提前感谢您的帮助。
如何重现相同的测试
- 克隆 here 中的代码
- 下载 here 中的数据样本
- 修改 shift_params.py 中 shift_v1 的参数:init_path=None,以及 train_tfs.txt 的内容(应包含所下载的 tfrecord 文件的路径)
- 在 src 文件夹中运行以下命令:
python -c "import shift_params, shift_net; shift_net.train(shift_params.shift_v1(num_gpus=3), [0, 1, 2], restore = False)"
环境详细信息
我在 anaconda 环境中使用以下包 (Linux):
- Python 2.7.18
- tensorflow、tensorflow-gpu 和 tensorflow-base 1.9.0
- numpy、matplotlib、pillow 和 scipy
我已经解决了这个问题,因此在这里发布我找到的解决方案,以防其他人遇到同样的问题。
我没能用所提供的 tfrecord 运行数据加载代码(即问题中所示的代码),所以我编写了自己的数据加载函数来替代 read_example,贴在此处供参考:
def new_read_example(rec_queue, pr, input_types, total_frames=100):
    """Parse one example produced by the custom TFRecord writer.

    Each record must contain:
      * 'video': one JPEG in which ``total_frames`` frames of
        ``pr.full_im_dim`` x ``pr.full_im_dim`` pixels are stacked vertically;
      * 'sync_audio' / 'shift_audio': raw float32 bytes of mono audio,
        ``pr.full_samples_len`` samples each;
      * 'label' and 'num_frames': int64 values stored as length-1 lists.

    Args:
        rec_queue: filename queue consumed by ``tf.TFRecordReader``.
        pr: parameter object; reads ``full_im_dim`` and ``full_samples_len``.
        input_types: unused; kept so the signature matches ``read_example``.
        total_frames: frames per video, as a static Python int. It cannot
            come from the per-record 'num_frames' tensor because
            ``set_shape`` needs a Python value. Defaults to 100, the value
            the writer used.

    Returns:
        ``(ims, samples_exs)``.
        ``ims`` is the decoded video, shape
        ``(total_frames, full_im_dim, full_im_dim, 3)``.
        ``samples_exs`` stacks shift_audio then sync_audio along a new
        leading axis (assuming ``ed`` is ``tf.expand_dims`` -- TODO
        confirm), each of shape ``(full_samples_len, 1)``.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(rec_queue)
    full = pr.full_im_dim  # full (uncropped) frame side length

    feature_description = {
        'video': tf.FixedLenFeature([], tf.string),
        'sync_audio': tf.FixedLenFeature([], tf.string),
        'shift_audio': tf.FixedLenFeature([], tf.string),
        # Parsed but not returned. With no default_value, records lacking
        # these keys are rejected by the parser.
        'label': tf.FixedLenFeature([1], tf.int64),
        'num_frames': tf.FixedLenFeature([1], tf.int64),
    }
    example = tf.parse_single_example(serialized_example, feature_description)

    # The video is one tall JPEG; split it into individual frames.
    ims = tf.image.decode_jpeg(example['video'], channels=3)
    ims.set_shape((full * total_frames, full, 3))
    ims = tf.reshape(ims, (total_frames, full, full, 3))

    def decode_mono(key):
        # The writer serialized the audio with numpy ``tostring()``. Decoding
        # as float32 assumes that array was float32; a float64 array would
        # yield twice as many values and make the reshape fail at run time.
        audio = tf.decode_raw(example[key], tf.float32)
        audio.set_shape((pr.full_samples_len,))
        # Mono audio: add a trailing channel axis of size 1.
        return tf.reshape(audio, (pr.full_samples_len, 1))

    sync_audio = decode_mono('sync_audio')
    shift_audio = decode_mono('shift_audio')
    samples_exs = tf.concat([ed(shift_audio, 0), ed(sync_audio, 0)], 0)
    return ims, samples_exs
我由此得出的结论是:如果不知道 tfrecord 的确切创建方式(最好能拿到生成这些 tfrecord 的代码),几乎不可能正确读取它;在我的情况下,自己编写一对 tfrecord 写入器和读取器要容易得多。
我正在尝试使用同一脚本中定义的函数 train 来训练 here 中定义的模型。我尝试使用作者在此链接中提供的 tfrecord 文件(为此,我所做的唯一修改是设置 init_path=None,并在 txt 文件 train_list=train_tfs.txt 中写入 tfrecord 文件的路径,二者均在 shift_params.py 的 shift_v1 中定义)。
但是,即使尝试这个简单的测试,我也会收到以下错误:
OutOfRangeError (see above for traceback): RandomShuffleQueue '_1_shuffle_batch_join/random_shuffle_queue' is closed and has insufficient elements (requested 15, current size 0)
[[Node: shuffle_batch_join = QueueDequeueManyV2[component_types=[DT_UINT8, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/device:CPU:0"](shuffle_batch_join/random_shuffle_queue, shuffle_batch_join/n)]]
据我了解,这个错误意味着没有数据被加载。所以我猜测问题出在 shift_dset.py 中的函数 read_example,代码如下所示(我只保留了加载图像 ims 和音频 sound 的部分,完整代码请查看 here):
def read_example(rec_queue, pr, input_types):
    """Parse one training example and cut out an aligned video/audio clip.

    Reads a serialized example from ``rec_queue``. Its features are two JPEG
    "filmstrips" (``im_0``, ``im_1``), each holding ``pr.total_frames // 2``
    frames of ``pr.full_im_dim`` x ``pr.full_im_dim`` pixels stacked
    vertically, and raw int16 audio (``sound``).

    A clip of ``pr.sampled_frames`` frames is cropped to ``pr.crop_im_dim``
    pixels. The audio window starting at the same frame is sliced. When
    ``pr.do_shift`` is set, a second window is also sliced, starting at a
    frame whose clip overlaps the first by at most ``pr.max_intersection``
    frames.

    Args:
        rec_queue: filename queue consumed by ``tf.TFRecordReader``.
        pr: parameter object. Reads ``full_im_dim``, ``total_frames``,
            ``variable_frame_count``, ``full_samples_len``,
            ``sampled_frames``, ``samples_per_frame``, ``do_shift``,
            ``fix_frame``, ``max_intersection``, ``skip_notfound``,
            ``augment_ims``, ``crop_im_dim`` and, optionally,
            ``use_first_frame`` / ``resize_dims``.
        input_types: collection of requested inputs. Audio is sliced only
            if it contains ``'samples'``.

    Returns:
        ``(ims, None, samples_exs, None, None, None)``.
        ``ims`` has shape ``(sampled_frames, crop_im_dim, crop_im_dim, 3)``.
        ``samples_exs`` stacks the shifted audio window and then the
        ground-truth window along a new leading axis (assuming ``ed`` is
        ``tf.expand_dims`` -- TODO confirm). The ``None`` entries are
        placeholders for outputs not reproduced in this excerpt; restore
        them before feeding a batching queue that expects six tensors.

    Raises:
        NotImplementedError: if ``pr.variable_frame_count`` is set.
        ValueError: if ``pr.do_shift`` is set and either a start frame has
            no valid partner (unless ``pr.skip_notfound``) or no valid frame
            pair exists at all.
    """
    # The frame bookkeeping below needs total_frames as a static Python int,
    # which a per-record 'num_frames' tensor cannot provide. Reject this mode
    # up front instead of via an assert (asserts vanish under `python -O`).
    if pr.variable_frame_count:
        raise NotImplementedError('variable_frame_count is not supported')

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(rec_queue)

    full = pr.full_im_dim
    total_frames = pr.total_frames
    feats = {
        'im_0': tf.FixedLenFeature([], dtype=tf.string),
        'im_1': tf.FixedLenFeature([], dtype=tf.string),
        'sound': tf.FixedLenFeature([], dtype=tf.string),
    }
    example = tf.parse_single_example(serialized_example, features=feats)

    def to_frames(strip):
        # A decoded strip is one tall image; split it into individual frames.
        strip.set_shape((full * total_frames // 2, full, 3))
        return tf.reshape(strip, (total_frames // 2, full, full, 3))

    im_parts = [
        to_frames(tf.image.decode_jpeg(example['im_0'], channels=3, name='decode_im1')),
        to_frames(tf.image.decode_jpeg(example['im_1'], channels=3, name='decode_im2')),
    ]

    # Raw int16 audio reshaped to (full_samples_len, 2), presumably
    # interleaved stereo, then scaled by the int16 maximum.
    samples = tf.decode_raw(example['sound'], tf.int16)
    samples.set_shape((pr.full_samples_len * 2,))
    samples = tf.reshape(samples, (pr.full_samples_len, 2))
    samples = tf.cast(samples, 'float32') / np.iinfo(np.dtype(np.int16)).max

    num_slice_frames = pr.sampled_frames
    num_samples = int(pr.samples_per_frame * float(num_slice_frames))

    if pr.do_shift:
        # Enumerate (start frame, shifted start frame) pairs whose clips
        # share at most pr.max_intersection frames.
        choices = []
        max_frame = total_frames - num_slice_frames
        frames1 = [0] if pr.fix_frame else range(max_frame)
        for frame1 in frames1:
            window1 = set(range(frame1, frame1 + num_slice_frames))
            found = False
            for frame2 in reversed(range(max_frame)):
                window2 = range(frame2, frame2 + num_slice_frames)
                if len(window1.intersection(window2)) <= pr.max_intersection:
                    found = True
                    choices.append([frame1, frame2])
                    if pr.fix_frame:
                        break
            if not found and not pr.skip_notfound:
                raise ValueError(
                    'no shifted frame found for start frame %d' % frame1)
        print('Number of frame choices: %d' % len(choices))
        if not choices:
            # An empty table would make choices[idx, 0] fail obscurely.
            raise ValueError('no valid (frame, shifted frame) pairs; '
                             'check sampled_frames / max_intersection')
        num_choices = len(choices)
        choices = tf.constant(np.array(choices), dtype=tf.int32)
        idx = tf.random_uniform([1], 0, num_choices, dtype=tf.int64)[0]
        start_frame_gt = choices[idx, 0]
        shift_frame = choices[idx, 1]
    elif ut.hastrue(pr, 'use_first_frame'):
        shift_frame = start_frame_gt = tf.constant(0, dtype=tf.int32)
    else:
        shift_frame = start_frame_gt = tf.random_uniform(
            [1], 0, total_frames - num_slice_frames, dtype=tf.int32)[0]

    # Spatial crop origin: random when augmenting, otherwise centred.
    d = pr.crop_im_dim
    if pr.augment_ims:
        print('Augment: %s' % (pr.augment_ims,))
        r = tf.random_uniform([2], 0, pr.full_im_dim - d, dtype=tf.int32)
        x, y = r[0], r[1]
    elif hasattr(pr, 'resize_dims'):
        y = pr.resize_dims[0] // 2 - d // 2
        x = pr.resize_dims[1] // 2 - d // 2
    else:
        y = x = pr.full_im_dim // 2 - d // 2

    # The requested clip may straddle the two filmstrips. Each part
    # contributes its (possibly empty) overlap with
    # [start_frame_gt, start_frame_gt + num_slice_frames).
    num_frames_in_part = total_frames // len(im_parts)
    slice_parts = []
    for j, part in enumerate(im_parts):
        part_start = j * num_frames_in_part
        frame_offset = tf.maximum(
            0, tf.minimum(start_frame_gt - part_start, num_frames_in_part))
        end_offset = tf.maximum(
            0, tf.minimum(start_frame_gt + num_slice_frames - part_start,
                          num_frames_in_part))
        part_len = tf.maximum(0, end_offset - frame_offset)
        slice_parts.append(
            tf.slice(part, [frame_offset, y, x, 0], [part_len, d, d, 3]))
    ims = tf.concat(slice_parts, 0)
    ims.set_shape([num_slice_frames, d, d, 3])

    if pr.augment_ims:
        unflipped = ims
        flip = tf.cast(tf.random_uniform([1], 0, 2, dtype=tf.int64)[0], tf.bool)
        ims = tf.cond(flip,
                      lambda: tf.map_fn(tf.image.flip_left_right, unflipped),
                      lambda: unflipped)

    def slice_samples(frame):
        # Audio window of num_samples samples aligned with video frame `frame`.
        start = round_int(pr.samples_per_frame * cast_float(frame))
        r = tf.slice(samples, [start, 0], [num_samples, 2], name='slice_sample')
        r.set_shape([num_samples] + list(shape(r)[1:]))
        return r

    if 'samples' in input_types:
        samples_gt = slice_samples(start_frame_gt)
        samples_shift = slice_samples(shift_frame)
    else:
        # Placeholder audio; float32 to match the dtype of the sliced audio.
        samples_gt = samples_shift = tf.zeros((1, 1), dtype=tf.float32)
    samples_exs = tf.concat([ed(samples_shift, 0), ed(samples_gt, 0)], 0)

    # NOTE(review): four outputs were omitted from this excerpt; None marks them.
    return ims, None, samples_exs, None, None, None
我也尝试过修改此函数来加载我自己的 tfrecords,但遇到了类似的错误。此外,当没有提供 tfrecord 路径时(即 train_list=train_tfs.txt 中不包含任何 tfrecord 路径,因此不可能加载任何数据),我也会得到同样的错误,这表明确实没有加载任何数据。
提前感谢您的帮助。
如何重现相同的测试
- 克隆 here 中的代码
- 下载 here 中的数据样本
- 修改 shift_params.py 中 shift_v1 的参数:init_path=None,以及 train_tfs.txt 的内容(应包含所下载的 tfrecord 文件的路径)
- 在 src 文件夹中运行以下命令:python -c "import shift_params, shift_net; shift_net.train(shift_params.shift_v1(num_gpus=3), [0, 1, 2], restore = False)"
环境详细信息
我在 anaconda 环境中使用以下包 (Linux):
- Python 2.7.18
- tensorflow、tensorflow-gpu 和 tensorflow-base 1.9.0
- numpy、matplotlib、pillow 和 scipy
我已经解决了这个问题,因此在这里发布我找到的解决方案,以防其他人遇到同样的问题。
我没能用所提供的 tfrecord 运行数据加载代码(即问题中所示的代码),所以我编写了自己的数据加载函数来替代 read_example,贴在此处供参考:
def new_read_example(rec_queue, pr, input_types, total_frames=100):
    """Parse one example produced by the custom TFRecord writer.

    Each record must contain:
      * 'video': one JPEG in which ``total_frames`` frames of
        ``pr.full_im_dim`` x ``pr.full_im_dim`` pixels are stacked vertically;
      * 'sync_audio' / 'shift_audio': raw float32 bytes of mono audio,
        ``pr.full_samples_len`` samples each;
      * 'label' and 'num_frames': int64 values stored as length-1 lists.

    Args:
        rec_queue: filename queue consumed by ``tf.TFRecordReader``.
        pr: parameter object; reads ``full_im_dim`` and ``full_samples_len``.
        input_types: unused; kept so the signature matches ``read_example``.
        total_frames: frames per video, as a static Python int. It cannot
            come from the per-record 'num_frames' tensor because
            ``set_shape`` needs a Python value. Defaults to 100, the value
            the writer used.

    Returns:
        ``(ims, samples_exs)``.
        ``ims`` is the decoded video, shape
        ``(total_frames, full_im_dim, full_im_dim, 3)``.
        ``samples_exs`` stacks shift_audio then sync_audio along a new
        leading axis (assuming ``ed`` is ``tf.expand_dims`` -- TODO
        confirm), each of shape ``(full_samples_len, 1)``.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(rec_queue)
    full = pr.full_im_dim  # full (uncropped) frame side length

    feature_description = {
        'video': tf.FixedLenFeature([], tf.string),
        'sync_audio': tf.FixedLenFeature([], tf.string),
        'shift_audio': tf.FixedLenFeature([], tf.string),
        # Parsed but not returned. With no default_value, records lacking
        # these keys are rejected by the parser.
        'label': tf.FixedLenFeature([1], tf.int64),
        'num_frames': tf.FixedLenFeature([1], tf.int64),
    }
    example = tf.parse_single_example(serialized_example, feature_description)

    # The video is one tall JPEG; split it into individual frames.
    ims = tf.image.decode_jpeg(example['video'], channels=3)
    ims.set_shape((full * total_frames, full, 3))
    ims = tf.reshape(ims, (total_frames, full, full, 3))

    def decode_mono(key):
        # The writer serialized the audio with numpy ``tostring()``. Decoding
        # as float32 assumes that array was float32; a float64 array would
        # yield twice as many values and make the reshape fail at run time.
        audio = tf.decode_raw(example[key], tf.float32)
        audio.set_shape((pr.full_samples_len,))
        # Mono audio: add a trailing channel axis of size 1.
        return tf.reshape(audio, (pr.full_samples_len, 1))

    sync_audio = decode_mono('sync_audio')
    shift_audio = decode_mono('shift_audio')
    samples_exs = tf.concat([ed(shift_audio, 0), ed(sync_audio, 0)], 0)
    return ims, samples_exs
我由此得出的结论是:如果不知道 tfrecord 的确切创建方式(最好能拿到生成这些 tfrecord 的代码),几乎不可能正确读取它;在我的情况下,自己编写一对 tfrecord 写入器和读取器要容易得多。