Tensorflow 理解 tf.train.shuffle_batch

Tensorflow understanding tf.train.shuffle_batch

我有一个训练数据文件,大约 10 万行,并且我 运行 在每个训练步骤中都是一个简单的 tf.train.GradientDescentOptimizer。该设置基本上直接取自 Tensorflow 的 MNIST 示例。代码转载如下:

x = tf.placeholder(tf.float32, [None, 21])
W = tf.Variable(tf.zeros([21, 2]))
b = tf.Variable(tf.zeros([2]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 2])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

鉴于我正在从文件中读取训练数据,我使用 tf.train.string_input_producertf.decode_csv 从 csv 中读取行,然后 tf.train.shuffle_batch 创建批次然后我继续训练。

我对 tf.train.shuffle_batch 的参数应该是什么感到困惑。我阅读了 Tensorflow 的文档,但我仍然不确定 "optimal" batch_size、容量和 min_after_dequeue 值是什么。任何人都可以帮助阐明我如何为这些参数选择合适的值,或者 link 我可以找到可以了解更多信息的资源吗?谢谢--

这是 API link:https://www.tensorflow.org/versions/r0.9/api_docs/python/io_ops.html#shuffle_batch

使用的线程数有点问题

https://www.tensorflow.org/versions/r0.9/how_tos/reading_data/index.html#batching

遗憾的是,我认为批量大小没有简单的答案。 网络的有效批量大小取决于很多细节 关于网络。实际上,如果您关心最佳性能 你将需要做大量的试验和错误(也许开始 来自类似网络使用的值)。