tensorflow Dataset.from_generator 使用生成张量的生成器

tensorflow Dataset.from_generator using an generator that yield tensors

我正在尝试将一些代码转换为新数据集 API 以便我可以使用分发策略。以下是我正在尝试做的事情。

def dataset_generator():
    while True:
        features, labels = ex_lib.get_image_batch(), ex_lib.get_feature_batch()
        yield features, labels

def get_ssf_input_fn():
    def input_fn():
        return tf.data.Dataset.from_generator(dataset_generator,
                                              (tf.float32, tf.float32), ([None, config.image_height, config.image_width, config.image_channels], [None, 256]))

    return input_fn

问题是 ex_lib.get_image_batchex_lib.get_feature_batch 给了我张量而不是 numpy 数组,我无法更改 ex_lib 中的代码。此外,我无法在此处将张量转换为 numpy 数组,因为我无法在此处访问 sess。使用此代码,它将抛出

`generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was Tensor("GetImageBatch:0", dtype=uint8)

有没有办法让我的 input_fn return 成为数据集?

我可以通过以下技巧解决这个问题。效率还可以。

tf.data.Dataset.from_tensors(0).repeat().map(lambda _: dataset_generator())