逐批喂 tf.estimator.Estimator.predict
Feed tf.estimator.Estimator.predict batch by batch
我有一个训练有素的估计器模型,我需要获取无法放入内存的非常大的数据集的预测向量,处理这些预测向量并保存它们。到目前为止,我的代码看起来像这样:
def hist(predictions):
...
return histograms
def input_fn(feat, batch_size=100):
dataset = tf.data.Dataset.from_tensor_slices((feat))
dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda x:...)
return dataset
super_batch = 100
splits = data.shape[0]//super_batch
for s in range(splits):
pred = list(classifier.predict(lambda: input_fn(data[s*super_batch:(s+1)*super_batch])))
pred_cls = [p["classes"] for p in pred]
hist_vec = hist(pred_cls)
save hist_vec
我知道这不是正确的方法,因为它会使 GPU 闲置很长时间,而且由于每次调用 classifier.predict 时加载模型需要很长时间至 运行。有什么方法可以使用带有估算器的 feed 函数来加快这个过程吗?
我假设问题出在 tf.data.Dataset.from_tensor_slices()
如果您在禁用急切模式的 Tensorflow 1.0 中将 tf.data.Dataset.from_tensor_slices()
与 numpy 数组一起使用,它会将值作为一个或多个 tf.constant
操作嵌入到图表中。你的数据集越大,图形就越大。这是非常低效的,您可能会遇到 ValueError: GraphDef cannot be larger than 2GB
个错误。
对于 tf.estimator
,您有 2 个解决方案:
使用tf.data.Dataset.from_generator
。只需将 feat
转换为 Python 生成器即可。性能比 from_tensor_slices
差一点,因为 tf.data
图的速度受到 Python 运行时间的限制。
将 tf.data.Dataset.from_tensor_slices()
与 Tensorflow 占位符一起使用。这个比较复杂,但是效率最高。请参阅 my answer here 了解更多信息。它的要点是你需要创建一个特定的钩子来在估计器内部创建会话后初始化占位符。
我有一个训练有素的估计器模型,我需要获取无法放入内存的非常大的数据集的预测向量,处理这些预测向量并保存它们。到目前为止,我的代码看起来像这样:
def hist(predictions):
...
return histograms
def input_fn(feat, batch_size=100):
dataset = tf.data.Dataset.from_tensor_slices((feat))
dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda x:...)
return dataset
super_batch = 100
splits = data.shape[0]//super_batch
for s in range(splits):
pred = list(classifier.predict(lambda: input_fn(data[s*super_batch:(s+1)*super_batch])))
pred_cls = [p["classes"] for p in pred]
hist_vec = hist(pred_cls)
save hist_vec
我知道这不是正确的方法,因为它会使 GPU 闲置很长时间,而且由于每次调用 classifier.predict 时加载模型需要很长时间至 运行。有什么方法可以使用带有估算器的 feed 函数来加快这个过程吗?
我假设问题出在 tf.data.Dataset.from_tensor_slices()
如果您在禁用急切模式的 Tensorflow 1.0 中将 tf.data.Dataset.from_tensor_slices()
与 numpy 数组一起使用,它会将值作为一个或多个 tf.constant
操作嵌入到图表中。你的数据集越大,图形就越大。这是非常低效的,您可能会遇到 ValueError: GraphDef cannot be larger than 2GB
个错误。
对于 tf.estimator
,您有 2 个解决方案:
使用
tf.data.Dataset.from_generator
。只需将feat
转换为 Python 生成器即可。性能比from_tensor_slices
差一点,因为tf.data
图的速度受到 Python 运行时间的限制。将
tf.data.Dataset.from_tensor_slices()
与 Tensorflow 占位符一起使用。这个比较复杂,但是效率最高。请参阅 my answer here 了解更多信息。它的要点是你需要创建一个特定的钩子来在估计器内部创建会话后初始化占位符。