TensorFlow:dataset.train.next_batch 是如何定义的?
TensorFlow: how is dataset.train.next_batch defined?
我正在尝试学习 TensorFlow 并研究示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb
然后我在下面的代码中有一些问题:
for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1),
"cost=", "{:.9f}".format(c))
既然mnist只是一个数据集,那么mnist.train.next_batch
到底是什么意思呢? dataset.train.next_batch
是如何定义的?
谢谢!
mnist
对象从 read_data_sets()
function defined in the tf.contrib.learn
module. The mnist.train.next_batch(batch_size)
method is implemented here 返回,它 returns 两个数组的元组,其中第一个代表一批 batch_size
MNIST 图像,第二个表示与这些图像对应的一批 batch-size
个标签。
图像以大小为 [batch_size, 784]
的二维 NumPy 数组形式返回(因为 MNIST 图像中有 784 个像素),标签以大小为一维 NumPy 数组的形式返回[batch_size]
(如果使用 one_hot=False
调用了 read_data_sets()
)或大小为 [batch_size, 10]
的二维 NumPy 数组(如果使用 one_hot=True
调用了 read_data_sets()
).
我正在尝试学习 TensorFlow 并研究示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb
然后我在下面的代码中有一些问题:
for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1),
"cost=", "{:.9f}".format(c))
既然mnist只是一个数据集,那么mnist.train.next_batch
到底是什么意思呢? dataset.train.next_batch
是如何定义的?
谢谢!
mnist
对象从 read_data_sets()
function defined in the tf.contrib.learn
module. The mnist.train.next_batch(batch_size)
method is implemented here 返回,它 returns 两个数组的元组,其中第一个代表一批 batch_size
MNIST 图像,第二个表示与这些图像对应的一批 batch-size
个标签。
图像以大小为 [batch_size, 784]
的二维 NumPy 数组形式返回(因为 MNIST 图像中有 784 个像素),标签以大小为一维 NumPy 数组的形式返回[batch_size]
(如果使用 one_hot=False
调用了 read_data_sets()
)或大小为 [batch_size, 10]
的二维 NumPy 数组(如果使用 one_hot=True
调用了 read_data_sets()
).