Tensorflow 数据集 API - .from_tensor_slices() / .from_tensor() - 无法创建内容大于 2gb 的张量原型

Tensorflow Dataset API - .from_tensor_slices() / .from_tensor() - cannot create a tensor proto whose content is larger than 2gb

所以我想使用数据集 API 对我的大型数据集 (~8GB) 进行批处理,因为我在使用 GPU 时遇到大量空闲时间,因为我正在将数据从 python 传递到 Tensorflow使用 feed_dict.

当我按照此处提到的教程进行操作时:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/5_DataManagement/tensorflow_dataset_api.py

当运行我的简单代码:

one_hot_dataset = np.load("one_hot_dataset.npy")
dataset = tf.data.Dataset.from_tensor_slices(one_hot_dataset)

我收到 TensorFlow 1.8 和 Python 3.5 的错误消息:

Traceback (most recent call last):

  File "<ipython-input-17-412a606c772f>", line 1, in <module>
    dataset = tf.data.Dataset.from_tensor_slices((one_hot_dataset))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 235, in from_tensor_slices
    return TensorSliceDataset(tensors)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in __init__
    for i, t in enumerate(nest.flatten(tensors))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in <listcomp>
    for i, t in enumerate(nest.flatten(tensors))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1014, in convert_to_tensor
    as_ref=False)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))

  File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 496, in make_tensor_proto
    "Cannot create a tensor proto whose content is larger than 2GB.")

ValueError: Cannot create a tensor proto whose content is larger than 2GB.

我该如何解决这个问题?我认为原因很明显,但是 tf 开发人员将输入数据限制为 2GB 时是怎么想的?!?我真的无法理解这种合理性,在处理更大的数据集时有什么解决方法?

我用谷歌搜索了很多,但找不到任何类似的错误消息。当我使用 numpy 数据集的 FITFH 时,上述步骤没有任何问题。

我需要以某种方式告诉 TensorFlow 我实际上将逐批加载数据并且可能想要预取几批以使我的 GPU 忙碌。但它似乎试图一次加载整个 numpy 数据集。那么使用数据集 API 有什么好处,因为我可以通过简单地尝试将我的 numpy 数据集作为 tf.constant 加载到 TensorFlow 图中来重现此错误,这显然不适合我收到 OOM 错误。

感谢提示和故障排除提示!

此问题已在 tf.data 用户指南 (https://www.tensorflow.org/guide/datasets) 的 "Consuming NumPy arrays" 部分解决。

基本上,创建一个 dataset.make_initializable_iterator() 迭代器并在运行时提供数据。

如果由于某种原因这不起作用,您可以将数据写入文件或从 Python 生成器 (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator) 创建数据集,您可以在其中放置任意 Python包括切片 numpy 数组和生成切片的代码。