恢复已保存的模型并从 java 程序中应用 dropout
Restoring saved model with dropout applied from java program
我有一个应用了 dropout 的预训练模型,我想从 java 程序中恢复它。
在我的应用程序中,在推理步骤中,我需要打开 dropout 并多次重复向模型提供输入,并获得一系列预测。
我做了什么:
- 加载模型并初始化会话
model = SavedModelBundle.load ("path_to_model", "serve");
sess = model.session();
- 喂给模型(重复多次,例如 3 次)
for (i = 0; i < 3; i++)
t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);
让我们假设:
- 第一次:获取数组A1 = [y1, y2, y3]
- 第二次:获取数组A2 =[z1, z2, z3]
...
我想要相同的推理,但 A2 与 A1 不同。
我知道 dropout mask 会随着时间的推移而改变。
我想我需要 "seed" 变量,因为我们在 python API 中有。但是我找不到任何参考。
我尝试了什么:
要获得相同的预测列表,我需要多次加载模型和初始化会话。
for (i = 0; i < 3; i++)
model = SavedModelBundle.load ("path_to_model", "serve");
sess = model.session();
t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);
但这不是最佳选择,因为加载模型需要时间并且可能导致与内存相关的问题。
我该如何解决这个问题?
提前致谢!
终于,我解决了问题。
我认为重新打开会话时会话重新初始化是错误的:Session s = modelBundle.session();
它被重新初始化,涉及到一个图表。
byte[] metaGraph = Files.readAllBytes(Paths.get(save_path));
Graph g = new Graph();
Session sess = new Session(g);
但它导致错误:
"Attempting to use uninitialized value "
我通过更改 python 中保存模型的方式修复了错误。
之前,我使用:
builder = tf.saved_model.builder.SavedModelBuilder(save_folder)
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING])
save_path = builder.save()
好像没有保存种子(初始化为局部变量)然后导致模型没有保存状态。
我改成:
with tf.gfile.GFile(save_path, 'wb') as f:
f.write(out_graph_def.SerializeToString())
效果很好^^。
我有一个应用了 dropout 的预训练模型,我想从 java 程序中恢复它。
在我的应用程序中,在推理步骤中,我需要打开 dropout 并多次重复向模型提供输入,并获得一系列预测。
我做了什么:
- 加载模型并初始化会话
model = SavedModelBundle.load ("path_to_model", "serve");
sess = model.session();
- 喂给模型(重复多次,例如 3 次)
for (i = 0; i < 3; i++)
t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);
让我们假设:
- 第一次:获取数组A1 = [y1, y2, y3]
- 第二次:获取数组A2 =[z1, z2, z3]
...
我想要相同的推理,但 A2 与 A1 不同。 我知道 dropout mask 会随着时间的推移而改变。 我想我需要 "seed" 变量,因为我们在 python API 中有。但是我找不到任何参考。
我尝试了什么: 要获得相同的预测列表,我需要多次加载模型和初始化会话。
for (i = 0; i < 3; i++)
model = SavedModelBundle.load ("path_to_model", "serve");
sess = model.session();
t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);
但这不是最佳选择,因为加载模型需要时间并且可能导致与内存相关的问题。
我该如何解决这个问题?
提前致谢!
终于,我解决了问题。
我认为重新打开会话时会话重新初始化是错误的:Session s = modelBundle.session();
它被重新初始化,涉及到一个图表。
byte[] metaGraph = Files.readAllBytes(Paths.get(save_path));
Graph g = new Graph();
Session sess = new Session(g);
但它导致错误:
"Attempting to use uninitialized value "
我通过更改 python 中保存模型的方式修复了错误。
之前,我使用:
builder = tf.saved_model.builder.SavedModelBuilder(save_folder)
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING])
save_path = builder.save()
好像没有保存种子(初始化为局部变量)然后导致模型没有保存状态。
我改成:
with tf.gfile.GFile(save_path, 'wb') as f:
f.write(out_graph_def.SerializeToString())
效果很好^^。