逐批喂 tf.estimator.Estimator.predict

Feed tf.estimator.Estimator.predict batch by batch

我有一个训练有素的估计器模型,我需要获取无法放入内存的非常大的数据集的预测向量,处理这些预测向量并保存它们。到目前为止,我的代码看起来像这样:

def hist(predictions):
    ...
    return histograms

def input_fn(feat, batch_size=100):
    dataset = tf.data.Dataset.from_tensor_slices((feat))
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda x:...)
    return dataset

super_batch = 100
splits = data.shape[0]//super_batch

for s in range(splits):
    pred = list(classifier.predict(lambda: input_fn(data[s*super_batch:(s+1)*super_batch])))
    pred_cls = [p["classes"] for p in pred]
    hist_vec = hist(pred_cls)
    save hist_vec

我知道这不是正确的方法,因为它会使 GPU 闲置很长时间,而且由于每次调用 classifier.predict 时加载模型需要很长时间至 运行。有什么方法可以使用带有估算器的 feed 函数来加快这个过程吗?

我假设问题出在 tf.data.Dataset.from_tensor_slices()


如果您在禁用急切模式的 Tensorflow 1.0 中将 tf.data.Dataset.from_tensor_slices() 与 numpy 数组一起使用,它会将值作为一个或多个 tf.constant 操作嵌入到图表中。你的数据集越大,图形就越大。这是非常低效的,您可能会遇到 ValueError: GraphDef cannot be larger than 2GB 个错误。

对于 tf.estimator,您有 2 个解决方案:

  • 使用tf.data.Dataset.from_generator。只需将 feat 转换为 Python 生成器即可。性能比 from_tensor_slices 差一点,因为 tf.data 图的速度受到 Python 运行时间的限制。

  • tf.data.Dataset.from_tensor_slices() 与 Tensorflow 占位符一起使用。这个比较复杂,但是效率最高。请参阅 my answer here 了解更多信息。它的要点是你需要创建一个特定的钩子来在估计器内部创建会话后初始化占位符。