1.14 版中使用预制 Estimator 进行推理的输入函数

input function for inference with pre-made Estimator in version 1.14

我有这个输入功能可以正常工作

def classify_input_fn(image_filename, command):
    file_contents = tf.io.read_file(image_filename)
    image_decoded = preprocess_image(file_contents)
    dataset = tf.data.Dataset.from_tensors((image_decoded, command))
    dataset = dataset.batch(1)
    iterator = dataset.make_one_shot_iterator()
    image, command = iterator.get_next()
    return {"image":image, "command":command}

command 是一个整数。

但是 tf 1.14 给出了

的警告

W0722 11:37:39.224976 10956 deprecation.py:323] ... DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated ...

警告建议直接返回数据集但失败了,因为它没有特征列的字典。使简单的输入函数对单个示例进行推理的正确方法是什么?

我试过简单地返回一个像

这样的字典
def classify_input_fn(image_filename, command):
    file_contents = tf.io.read_file(image_filename)
    image_decoded = preprocess_image(file_contents)
    return {"image":image_decoded, "command":command}

但是失败了

ValueError: Feature (key: command) cannot have rank 0.

使数据集包含张量字典而不是张量元组。然后你可以直接 return 来自输入函数的数据集,而不是使用已弃用的 dataset.make_one_shot_iterator().

def predict_input_fn(image_filename, command):
    file_contents = tf.io.read_file(image_filename)
    image_decoded = preprocess_image(file_contents)
    dataset = tf.data.Dataset.from_tensors({"image":image_decoded, "command":command})
    dataset = dataset.batch(1)
    return dataset