如何使用 tf.saved_model API 恢复 tf.data.Dataset() 中的悬挂 tf.py_func?
How to restore dangling tf.py_func within the tf.data.Dataset() with tf.saved_model API?
在研究了使用 saved_model API 时恢复 tf.py_func()
的方法后,除了 tensorflow 中的记录外,我找不到其他信息:
The operation must run in the same address space as the Python program that calls tf.py_func()
. If you are using distributed TensorFlow, you must run a tf.train.Server
in the same process as the program that calls tf.py_func()
and you must pin the created operation to a device in that server (e.g. using with tf.device()
:)
两个 save/load 片段有助于说明情况。
保存部分:
def wrapper(x, y):
with tf.name_scope('wrapper'):
return tf.py_func(Copy, [x, y], [tf.float32, tf.float32])
def Copy(x, y):
return x, y
x_ph = tf.placeholder(tf.float32, [None], 'x_ph')
y_ph = tf.placeholder(tf.float32, [None], 'y_ph')
with tf.name_scope('input'):
ds = tf.data.Dataset.from_tensor_slices((x_ph, y_ph))
ds = ds.map(wrapper)
ds = ds.batch(1)
it = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
it_init_op = it.make_initializer(ds, name='it_init_op')
x_it, y_it = it.get_next()
# Simple operation
with tf.name_scope('add'):
res = tf.add(x_it, y_it)
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(), it_init_op], feed_dict={y_ph: [10] * 10, x_ph: [i for i in range(10)]})
sess.run([res])
tf.saved_model.simple_save(sess, './dummy/test', {'x_ph': x_ph, 'y_ph': y_ph}, {'res': res})
加载部分:
graph = tf.Graph()
graph.as_default()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './dummy/test')
res = graph.get_tensor_by_name('add/Add:0')
it_init_op = graph.get_operation_by_name('input/it_init_op')
x_ph = graph.get_tensor_by_name('x_ph:0')
y_ph = graph.get_tensor_by_name('y_ph:0')
sess.run([it_init_op], feed_dict={x_ph: [5] * 5, y_ph: [i for i in range(5)]})
for _ in range(5):
sess.run([res])
错误:
ValueError: callback pyfunc_0 is not found
众所周知,tf.py_func()
包装的函数不会与模型一起保存。有没有人有办法通过使用 tf 文档应用 tf.train.Server
给出的小提示来恢复它
只要没有答案,我会建议我的,which contour the pb 而不是解决它。苦苦挣扎了半天,终于通过修剪把它给忽略了。然后用占位符更简单的方式将新的 input/ouput 嫁接给它。此外,此 py_func 在 TF2.0.
中已弃用
在研究了使用 saved_model API 时恢复 tf.py_func()
的方法后,除了 tensorflow 中的记录外,我找不到其他信息:
The operation must run in the same address space as the Python program that calls
tf.py_func()
. If you are using distributed TensorFlow, you must run atf.train.Server
in the same process as the program that callstf.py_func()
and you must pin the created operation to a device in that server (e.g. using withtf.device()
:)
两个 save/load 片段有助于说明情况。
保存部分:
def wrapper(x, y):
with tf.name_scope('wrapper'):
return tf.py_func(Copy, [x, y], [tf.float32, tf.float32])
def Copy(x, y):
return x, y
x_ph = tf.placeholder(tf.float32, [None], 'x_ph')
y_ph = tf.placeholder(tf.float32, [None], 'y_ph')
with tf.name_scope('input'):
ds = tf.data.Dataset.from_tensor_slices((x_ph, y_ph))
ds = ds.map(wrapper)
ds = ds.batch(1)
it = tf.data.Iterator.from_structure(ds.output_types, ds.output_shapes)
it_init_op = it.make_initializer(ds, name='it_init_op')
x_it, y_it = it.get_next()
# Simple operation
with tf.name_scope('add'):
res = tf.add(x_it, y_it)
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(), it_init_op], feed_dict={y_ph: [10] * 10, x_ph: [i for i in range(10)]})
sess.run([res])
tf.saved_model.simple_save(sess, './dummy/test', {'x_ph': x_ph, 'y_ph': y_ph}, {'res': res})
加载部分:
graph = tf.Graph()
graph.as_default()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], './dummy/test')
res = graph.get_tensor_by_name('add/Add:0')
it_init_op = graph.get_operation_by_name('input/it_init_op')
x_ph = graph.get_tensor_by_name('x_ph:0')
y_ph = graph.get_tensor_by_name('y_ph:0')
sess.run([it_init_op], feed_dict={x_ph: [5] * 5, y_ph: [i for i in range(5)]})
for _ in range(5):
sess.run([res])
错误:
ValueError: callback pyfunc_0 is not found
众所周知,tf.py_func()
包装的函数不会与模型一起保存。有没有人有办法通过使用 tf 文档应用 tf.train.Server
只要没有答案,我会建议我的,which contour the pb 而不是解决它。苦苦挣扎了半天,终于通过修剪把它给忽略了。然后用占位符更简单的方式将新的 input/ouput 嫁接给它。此外,此 py_func 在 TF2.0.
中已弃用