tf.estimator 想要 label_data 和 batch_size 用于预测 Tensorflow

tf.estimator wants label_data and batch_size for prediction Tensorflow

我使用高级 tf API 创建了一个网络,例如 tf.estimator。

训练和评估工作正常并产生输出。但是,在对新数据进行预测时,get_inputs() 需要 label_databatch_size.

错误是:TypeError: get_inputs() missing 2 required positional arguments: 'label_data' and 'batch_size'

我该如何解决这个问题以便做出预测?

这是我的代码:

predictTest = [0.34, 0.65, 0.88]

predictTest只是一个测试,不会是我真正的预测数据。

get_inputs(),这里是报错的地方

def get_inputs(feature_data, label_data, batch_size, n_epochs=None, shuffle=True):
        dataset = tf.data.Dataset.from_tensor_slices(
            (feature_data, label_data))

    dataset = dataset.repeat(n_epochs)
    if shuffle:
        dataset = dataset.shuffle(len(feature_data))
    dataset = dataset.batch(batch_size)
    features, labels = dataset.make_one_shot_iterator().get_next()
    return features, labels

预测输入:

def predict_input_fn():
    return get_inputs(
    predictTest,
    n_epochs=1,
    shuffle=False
    )

预测:

predict = estimator.predict(predict_input_fn)
print("Prediction: {}".format(list(predict)))

任何模型的测试都有两种类型。 1)你想要准确性,召回率等你需要为测试数据提供标签。如果你不提供标签,它会给你一个错误。 2)你只想测试你的模型而不计算准确度而不需要标签但是这里的预测会有所不同。

我发现我必须为预测创建一个新的 get_inputs() 函数。

如果我使用 get_inputs() 训练和评估使用,它期望它不会获得数据。

get_inputs:

def get_inputs(feature_data, label_data, batch_size, n_epochs=None, shuffle=True):
    dataset = tf.data.Dataset.from_tensor_slices( #from_tensor_slices
        (feature_data, label_data))

    dataset = dataset.repeat(n_epochs)
    if shuffle:
        dataset = dataset.shuffle(len(feature_data))
    dataset = dataset.batch(batch_size)
    features, labels = dataset.make_one_shot_iterator().get_next()
    return features, labels

创建一个名为 pred_get_inputs 的新函数,它不需要 label_databatch_size:

def get_pred_inputs(feature_data,n_epochs=None, shuffle=False):
    dataset = tf.data.Dataset.from_tensor_slices( #from_tensor_slices
        (feature_data))

    dataset = dataset.repeat(n_epochs)
    if shuffle:
        dataset = dataset.shuffle(len(feature_data))
    dataset = dataset.batch(1)
    features = dataset
    return features