Tensorflow 数据集的使用
Tensorflow datasets usage
我正在尝试创建一个简单的字符串标签对数据集,但无法让 tensorflow 正确连接这些对
我正在尝试使用 Dataset.from_tensor_slices 初始化器和 dataset.make_one_shot_iterator 迭代器:
import tensorflow as tf
strings = [
'aaaa',
'asdf'
]
labels = [1,0]
sess = tf.Session()
tf.global_variables_initializer()
dataset = tf.data.Dataset.from_tensor_slices((strings, labels))
dataset = dataset.repeat()
dataset = dataset.shuffle(512)
iterator = dataset.make_one_shot_iterator()
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
最后,我希望 'aaaa' 的输出为“1”,'asdf' 的输出为“0”,但反复得到一些随机的东西:
aaaa 0
asdf 0
aaaa 1
asdf 1
aaaa 1
aaaa 0
asdf 1
aaaa 1
请指出我的代码中可能存在的问题
顺便说一句,如果我去掉洗牌,我将无法得到另一个字符串,迭代器只会输出:
aaaa 0
aaaa 0
aaaa 0
...
标签错误...有人知道原因吗?
我是这样用的
next = iterator.get_next()
# print(next)
with tf.Session() as sess:
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'aaaa', 1)
我正在尝试创建一个简单的字符串标签对数据集,但无法让 tensorflow 正确连接这些对
我正在尝试使用 Dataset.from_tensor_slices 初始化器和 dataset.make_one_shot_iterator 迭代器:
import tensorflow as tf
strings = [
'aaaa',
'asdf'
]
labels = [1,0]
sess = tf.Session()
tf.global_variables_initializer()
dataset = tf.data.Dataset.from_tensor_slices((strings, labels))
dataset = dataset.repeat()
dataset = dataset.shuffle(512)
iterator = dataset.make_one_shot_iterator()
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
x_next, label_next = iterator.get_next()
print sess.run(x_next), sess.run(label_next)
最后,我希望 'aaaa' 的输出为“1”,'asdf' 的输出为“0”,但反复得到一些随机的东西:
aaaa 0
asdf 0
aaaa 1
asdf 1
aaaa 1
aaaa 0
asdf 1
aaaa 1
请指出我的代码中可能存在的问题
顺便说一句,如果我去掉洗牌,我将无法得到另一个字符串,迭代器只会输出:
aaaa 0
aaaa 0
aaaa 0
...
标签错误...有人知道原因吗?
我是这样用的
next = iterator.get_next()
# print(next)
with tf.Session() as sess:
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
print(sess.run(next))
(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'asdf', 0)
(b'aaaa', 1)
(b'aaaa', 1)
(b'aaaa', 1)