Loading saved_model causes "Failed to compile fragment shader" for gather op
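I am using TensorFlow in Python to define and train a model, which I then save as a saved_model and load on a website with TensorFlow.js. I made a minimal working example, shown below, and isolated the problem to the gather op. I am on Windows 7 x64 with Python 3.5.2, numpy 1.15.1 (forced by installing tensorflowjs; I had 15.2 before) and the latest version of tensorflowjs, installed with nothing more than pip install tensorflowjs. The browser is Chrome, and everything is served on localhost using Wamp.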
Step 1: Python
# coding=utf-8
import tensorflow as tf
import sys
import os, shutil

x = tf.placeholder(tf.int32, shape=[5], name='x')
#pred = tf.transpose(x, name="y") # THIS WORKS
pred = tf.gather(x, 1, axis=0, name="y") # THIS DOES NOT - thus must be related to the gather op

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(pred, feed_dict={x: [0,1,2,3,4]})) # just to check that it works in Python
    tf.saved_model.simple_save(sess, "./export", inputs={"x": x}, outputs={"y": pred})
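For reference, the value I expect from the gather call can be reproduced without TensorFlow. This is only a plain-Python sketch: gather_axis0 is a hypothetical helper written for illustration, not a TensorFlow API, and it only covers a 1-D input with either a scalar index or a list of indices.

```python
def gather_axis0(params, indices):
    """Mimic tf.gather(params, indices, axis=0) for a 1-D list of params."""
    if isinstance(indices, int):
        # Scalar index: the result is a single element (rank 0).
        return params[indices]
    # List of indices: the result is a list of the same length (rank 1).
    return [params[i] for i in indices]


print(gather_axis0([0, 1, 2, 3, 4], 1))    # scalar index, as in the model above
print(gather_axis0([0, 1, 2, 3, 4], [1]))  # one-element index list
```

With the input used in script.js, [2, 3, 4, 1, 1], the expected result for index 1 is 3. Note that a scalar index gives a rank-0 result while a one-element list gives a rank-1 result; I have not confirmed whether that distinction matters to the TensorFlow.js gather kernel.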
Step 2: tfjs-converter command
tensorflowjs_converter --input_format=tf_saved_model --output_node_names="y" --saved_model_tags=serve export C:/wamp/www/avatar/web_model
Step 3: Serve the website:
<!-- index.html -->
<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml">
  <head>
    <meta http-equiv="Content-Type" content="text/html" charset="utf-8" />
    <title>test</title>
    <script src="./ext/jquery-3.3.1.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
  </head>
  <body>
    <script src="./script.js"></script>
  </body>
</html>
// script.js
async function learnLinear() {
  // Load the converted frozen model and its weights manifest from localhost.
  const model = await tf.loadFrozenModel('http://localhost/avatar/web_model/tensorflowjs_model.pb', 'http://localhost/avatar/web_model/weights_manifest.json');
  const x = tf.tensor1d([2, 3, 4, 1, 1], "int32");
  model.execute({
    "x": x
  }); // "x: The input data, as an Tensor, or an Array of tf.Tensors if the model has multiple inputs."
}
learnLinear();
The resulting error is:
Uncaught (in promise) Error: Failed to compile fragment shader.
at createFragmentShader (tfjs:2)
at e.createProgram (tfjs:2)
at compileProgram (tfjs:2)
at tfjs:2
at e.getAndSaveBinary (tfjs:2)
at e.compileAndRun (tfjs:2)
at e.gather (tfjs:2)
at ENV.engine.runKernel.$x (tfjs:2)
at tfjs:2
at e.scopedRun (tfjs:2)
Just above it in the console is:
ERROR: 0:157: 'getIndices' : no matching overloaded function found
157 setOutput(getA(int(getIndices(resRC))));
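So getIndices appears to be missing: as far as I can tell it is not declared in the imported tfjs script, only called...
Any help is greatly appreciated!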