tensorflow如何更改数据集
tensorflow how to change dataset
我有一个数据集 API 技巧,它是我的张量流图的一部分。当我想使用不同的数据时,如何换出它?
dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
variable = tf.Variable(3, dtype=tf.int64)
model = variable*next_element
#pretend like this is me training my model, or something
with tf.Session() as sess:
sess.run(variable.initializer)
try:
while True:
print(sess.run(model)) # (0,3,6)
except:
pass
dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
### HOW TO DO THIS THING?
with tf.Session() as sess:
sess.run(variable.initializer) #This would be a saver restore operation, normally...
try:
while True:
print(sess.run(model)) # (0,3)... hopefully
except:
pass
我不相信这是可能的。您要求更改计算图本身,这在 tensorflow 中是不允许的。我没有自己解释,而是发现这个 post 中接受的答案在解释这一点时特别清楚
现在,话虽如此,我认为有一种相当 simple/clean 的方法可以实现您的目标。本质上,您想要重置图形并重建 Dataset
部分。当然你想重用代码的 model
部分。因此只需将该模型放入 class 或函数中以允许重用。一个基于您的代码的简单示例:
# the part of the graph you want to reuse
def get_model(next_element):
variable = tf.Variable(3,dtype=tf.int64)
return variable*next_element
# the first graph you want to build
tf.reset_default_graph()
# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# reusable part
model = get_model(next_element)
#pretend like this is me training my model, or something
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
while True:
print(sess.run(model)) # (0,3,6)
except:
pass
# now the second graph
tf.reset_default_graph()
# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# reusable part
model = get_model(next_element)
### HOW TO DO THIS THING?
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
while True:
print(sess.run(model)) # (0,3)... hopefully
except:
pass
最后说明:您还会在这里和那里看到一些对 tf.contrib.graph_editor
docs here 的引用。他们明确表示,您无法使用 graph_editor 完成您想要的(参见 link: "Here is an example of what you cannot do"; 但您可以非常接近)。即使如此,这也不是好的做法;他们有充分的理由使图表仅附加,我认为我建议的上述方法是完成您所寻求的更简洁的方法。
我建议的一种方法是使用 place_holders
,然后使用 tf.data.dataset
,但这会使事情变得更慢。因此,您将拥有以下内容:
train_data = tf.placeholder(dtype=tf.float32, shape=[None, None, 1]) # just an example
# Then add the tf.data.dataset here
train_data = tf.data.Dataset.from_tensor_slices(train_data).shuffle(10000).batch(batch_size)
现在,当 运行 会话中的图形时,您必须使用占位符输入数据。所以你随便喂什么...
希望对您有所帮助!!
我有一个数据集 API 技巧,它是我的张量流图的一部分。当我想使用不同的数据时,如何换出它?
dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
variable = tf.Variable(3, dtype=tf.int64)
model = variable*next_element
#pretend like this is me training my model, or something
with tf.Session() as sess:
sess.run(variable.initializer)
try:
while True:
print(sess.run(model)) # (0,3,6)
except:
pass
dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
### HOW TO DO THIS THING?
with tf.Session() as sess:
sess.run(variable.initializer) #This would be a saver restore operation, normally...
try:
while True:
print(sess.run(model)) # (0,3)... hopefully
except:
pass
我不相信这是可能的。您要求更改计算图本身,这在 tensorflow 中是不允许的。我没有自己解释,而是发现这个 post 中接受的答案在解释这一点时特别清楚
现在,话虽如此,我认为有一种相当 simple/clean 的方法可以实现您的目标。本质上,您想要重置图形并重建 Dataset
部分。当然你想重用代码的 model
部分。因此只需将该模型放入 class 或函数中以允许重用。一个基于您的代码的简单示例:
# the part of the graph you want to reuse
def get_model(next_element):
variable = tf.Variable(3,dtype=tf.int64)
return variable*next_element
# the first graph you want to build
tf.reset_default_graph()
# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# reusable part
model = get_model(next_element)
#pretend like this is me training my model, or something
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
while True:
print(sess.run(model)) # (0,3,6)
except:
pass
# now the second graph
tf.reset_default_graph()
# the part of the graph you don't want to reuse
dataset = tf.data.Dataset.range(2)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# reusable part
model = get_model(next_element)
### HOW TO DO THIS THING?
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
while True:
print(sess.run(model)) # (0,3)... hopefully
except:
pass
最后说明:您还会在这里和那里看到一些对 tf.contrib.graph_editor
docs here 的引用。他们明确表示,您无法使用 graph_editor 完成您想要的(参见 link: "Here is an example of what you cannot do"; 但您可以非常接近)。即使如此,这也不是好的做法;他们有充分的理由使图表仅附加,我认为我建议的上述方法是完成您所寻求的更简洁的方法。
我建议的一种方法是使用 place_holders
,然后使用 tf.data.dataset
,但这会使事情变得更慢。因此,您将拥有以下内容:
train_data = tf.placeholder(dtype=tf.float32, shape=[None, None, 1]) # just an example
# Then add the tf.data.dataset here
train_data = tf.data.Dataset.from_tensor_slices(train_data).shuffle(10000).batch(batch_size)
现在,当 运行 会话中的图形时,您必须使用占位符输入数据。所以你随便喂什么...
希望对您有所帮助!!