tensorflow Dataset.from_generator 使用生成张量的生成器
tensorflow Dataset.from_generator using an generator that yield tensors
我正在尝试将一些代码转换为新数据集 API 以便我可以使用分发策略。以下是我正在尝试做的事情。
def dataset_generator():
while True:
features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
yield features, labels
def get_ssf_input_fn():
def input_fn():
return tf.data.Dataset.from_generator(dataset_generator,
(tf.float32, tf.float32), ([None, config.image_height, config.image_width, config.image_channels], [None, 256]))
return input_fn
问题是 ex_lib.get_image_batch
和 ex_lib.get_feature_batch
给了我张量而不是 numpy 数组,我无法更改 ex_lib 中的代码。此外,我无法在此处将张量转换为 numpy 数组,因为我无法在此处访问 sess
。使用此代码,它将抛出
`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)
有没有办法让我的 input_fn return 成为数据集?
我可以通过以下技巧解决这个问题。效率还可以。
tf.data.Dataset.from_tensors(0).repeat().map(lambda _: dataset_generator())
我正在尝试将一些代码转换为新数据集 API 以便我可以使用分发策略。以下是我正在尝试做的事情。
def dataset_generator():
while True:
features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
yield features, labels
def get_ssf_input_fn():
def input_fn():
return tf.data.Dataset.from_generator(dataset_generator,
(tf.float32, tf.float32), ([None, config.image_height, config.image_width, config.image_channels], [None, 256]))
return input_fn
问题是 ex_lib.get_image_batch
和 ex_lib.get_feature_batch
给了我张量而不是 numpy 数组,我无法更改 ex_lib 中的代码。此外,我无法在此处将张量转换为 numpy 数组,因为我无法在此处访问 sess
。使用此代码,它将抛出
`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)
有没有办法让我的 input_fn return 成为数据集?
我可以通过以下技巧解决这个问题。效率还可以。
tf.data.Dataset.from_tensors(0).repeat().map(lambda _: dataset_generator())