TFJS 使用 headers 将模型保存到 http
TFJS save model to http with headers
我正在尝试按照 https://www.tensorflow.org/js/guide/save_load 上的指南,上传并保存一个附带额外 headers(用于传递类别名称)的 tfjs 模型,
后端代码复制自 https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864。
但是按照指南操作后,结果并不像指南中所预期和描述的那样。
我哪里出错了?谢谢。
我的浏览器代码是:
// NOTE(review): in current TF.js versions the second argument of tf.io.http()
// is a LoadOptions object, and fetch options such as `method` and `headers`
// must be nested under its `requestInit` field. Placed at the top level as
// below, they are not used, so the 'class' header is never sent.
const saveResult = await model.save(tf.io.http('http://localhost:5000/upload', {method: 'POST', headers: {'class': 'Dog'}}));
服务器代码是:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
from flask import Flask, Response, request
from flask_cors import CORS, cross_origin
import tensorflow as tf
import tensorflowjs as tfjs
import werkzeug.formparser
class ModelReceiver(object):
    """Collects the model artifacts uploaded by tf.io.http() as multipart form data.

    ``stream_factory`` is meant to be passed to
    ``werkzeug.formparser.parse_form_data`` so that the ``model.json`` and
    ``model.weights.bin`` file parts are written into in-memory buffers. The
    ``model`` and ``user_defined_metadata`` properties read them back.

    Note: this example code is *not* thread-safe; one instance only keeps the
    artifacts of the most recent upload.
    """

    def __init__(self):
        # In-memory buffers for the two expected file parts; None until received.
        self._model_json_bytes = None
        self._weight_bytes = None

    def _model_json_content(self):
        """Return the raw bytes of the received model.json part.

        Raises:
          ValueError: if no model.json part has been received.
        """
        if self._model_json_bytes is None:
            raise ValueError('No model.json part has been received.')
        # getvalue() returns the whole buffer regardless of the stream position.
        return self._model_json_bytes.getvalue()

    def _weights_content(self):
        """Return the raw bytes of the received model.weights.bin part.

        Raises:
          ValueError: if no model.weights.bin part has been received.
        """
        if self._weight_bytes is None:
            raise ValueError('No model.weights.bin part has been received.')
        return self._weight_bytes.getvalue()

    @property
    def model(self):
        """Deserialize the most recently uploaded artifacts into a Keras model.

        Raises:
          ValueError: if model.json or model.weights.bin has not been received.
        """
        json_content = self._model_json_content()
        weights_content = self._weights_content()
        return tfjs.converters.deserialize_keras_model(
            json_content,
            weight_data=[weights_content],
            use_unique_name_scope=True)

    @property
    def user_defined_metadata(self):
        """Return the 'userDefinedMetadata' field of the uploaded model.json.

        This is where metadata set on the client with
        ``model.setUserDefinedMetadata(...)`` is expected to appear.
        Returns None when model.json has no such field.

        Raises:
          ValueError: if no model.json part has been received.
        """
        import json  # local import: json is not among the file's top-level imports
        content = json.loads(self._model_json_content().decode('utf-8'))
        return content.get('userDefinedMetadata')

    def stream_factory(self,
                       total_content_length,
                       content_type,
                       filename,
                       content_length=None):
        """Return a writable in-memory stream for one uploaded file part.

        The signature matches what werkzeug passes to a ``stream_factory``.
        The buffers for 'model.json' and 'model.weights.bin' are remembered.
        Any other part still gets a throwaway buffer, because werkzeug writes
        into whatever is returned and would fail on None.
        """
        buffer = io.BytesIO()
        if filename == 'model.json':
            self._model_json_bytes = buffer
        elif filename == 'model.weights.bin':
            self._weight_bytes = buffer
        return buffer
def main():
    """Run a Flask server on localhost:5000 that accepts model uploads on POST /upload."""
    app = Flask('model-server')
    # flask-cors allows any request header by default. The browser's CORS
    # preflight needs this for custom headers such as 'class'. flask-cors
    # has no 'CORS_HEADER' config option, so no extra config is set here.
    CORS(app)
    model_receiver = ModelReceiver()

    @app.route('/upload', methods=['POST'])
    @cross_origin()
    def upload():
        print('headers are:')
        print(request.headers)
        # Custom header sent by the client via
        # tf.io.http(url, {requestInit: {headers: {'class': ...}}}).
        print('class header:', request.headers.get('class'))
        print('Handling request...')
        werkzeug.formparser.parse_form_data(
            request.environ, stream_factory=model_receiver.stream_factory)
        print('Received model:')
        # NOTE(review): tf.Session is the TensorFlow 1.x API; under TF 2.x this
        # would need tf.compat.v1.Session -- confirm the installed version.
        with tf.Graph().as_default(), tf.Session():
            model = model_receiver.model
            model.summary()
            # You can perform `model.predict()`, `model.fit()`,
            # `model.evaluate()` etc. here.
        return Response(status=200)

    app.run('localhost', 5000)


if __name__ == '__main__':
    main()
问题出在客户端:headers 应该像下面这样传入:
// tf.io.http(url, loadOptions): fetch options (method, headers, ...) must be
// nested under loadOptions.requestInit.
const myInit = {
  method: 'POST',
  headers: {
    // Header values are strings. Join the list explicitly; an array would be
    // implicitly coerced to the same "class1,class2" string.
    'classes': ['class1', 'class2'].join(','),
  },
};
// Pass the options object directly. Writing `loadOptions={...}` inside the
// call would assign to an undeclared variable, which throws a ReferenceError
// in strict mode / ES modules.
const saveResult = await model.save(tf.io.http(
    'http://localhost:5000/upload', {requestInit: myInit}));
tf.io.http()
的参数为 url
和配置对象(LoadOptions)。该配置中的 requestInit 会传给后台使用的 fetch/Request API,
因此 method、headers 等选项都要放在 requestInit 里。
调用应如下所示:
// model.save()'s second argument is a SaveConfig (e.g. includeOptimizer) and
// does not read requestInit. Wrap the URL in tf.io.http() and pass the fetch
// options through its LoadOptions.requestInit instead.
await model.save(tf.io.http(url, {requestInit: {method: 'POST', headers: {'class': 'Dog'}}}));
如果你的目标是随模型一起存储一些辅助信息(比如类别标签),TensorFlow.js 的 tf.LayersModel
有一个相对鲜为人知的特性可以让事情更轻松,而且比使用 header
更简单:
setUserDefinedMetadata()
和 getUserDefinedMetadata()
方法。
在 JavaScript 端这样做:
// The argument to setUserDefinedMetadata() can be any serializable JSON
// object of a reasonable size. Set it on the same model object that is saved
// below; the original snippet mixed `myModel` and `model`.
model.setUserDefinedMetadata({outputClassLabels: ['Cat', 'Dog', 'Turtle']});
// The user metadata is stored with the model itself. No need to specify
// additional headers.
await model.save('http://localhost:5000/upload');
接收模型文件的服务器只需解析请求中 model.json 部分的 JSON,检查其中的 'userDefinedMetadata' 字段即可。
我正在尝试按照 https://www.tensorflow.org/js/guide/save_load 上的指南,上传并保存一个附带额外 headers(用于传递类别名称)的 tfjs 模型,后端代码复制自 https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864。但是按照指南操作后,结果并不像指南中所预期和描述的那样。我哪里出错了?谢谢。
我的浏览器代码是:
// NOTE(review): in current TF.js versions the second argument of tf.io.http()
// is a LoadOptions object, and fetch options such as `method` and `headers`
// must be nested under its `requestInit` field. Placed at the top level as
// below, they are not used, so the 'class' header is never sent.
const saveResult = await model.save(tf.io.http('http://localhost:5000/upload', {method: 'POST', headers: {'class': 'Dog'}}));
服务器代码是:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
from flask import Flask, Response, request
from flask_cors import CORS, cross_origin
import tensorflow as tf
import tensorflowjs as tfjs
import werkzeug.formparser
class ModelReceiver(object):
    """Collects the model artifacts uploaded by tf.io.http() as multipart form data.

    ``stream_factory`` is meant to be passed to
    ``werkzeug.formparser.parse_form_data`` so that the ``model.json`` and
    ``model.weights.bin`` file parts are written into in-memory buffers. The
    ``model`` and ``user_defined_metadata`` properties read them back.

    Note: this example code is *not* thread-safe; one instance only keeps the
    artifacts of the most recent upload.
    """

    def __init__(self):
        # In-memory buffers for the two expected file parts; None until received.
        self._model_json_bytes = None
        self._weight_bytes = None

    def _model_json_content(self):
        """Return the raw bytes of the received model.json part.

        Raises:
          ValueError: if no model.json part has been received.
        """
        if self._model_json_bytes is None:
            raise ValueError('No model.json part has been received.')
        # getvalue() returns the whole buffer regardless of the stream position.
        return self._model_json_bytes.getvalue()

    def _weights_content(self):
        """Return the raw bytes of the received model.weights.bin part.

        Raises:
          ValueError: if no model.weights.bin part has been received.
        """
        if self._weight_bytes is None:
            raise ValueError('No model.weights.bin part has been received.')
        return self._weight_bytes.getvalue()

    @property
    def model(self):
        """Deserialize the most recently uploaded artifacts into a Keras model.

        Raises:
          ValueError: if model.json or model.weights.bin has not been received.
        """
        json_content = self._model_json_content()
        weights_content = self._weights_content()
        return tfjs.converters.deserialize_keras_model(
            json_content,
            weight_data=[weights_content],
            use_unique_name_scope=True)

    @property
    def user_defined_metadata(self):
        """Return the 'userDefinedMetadata' field of the uploaded model.json.

        This is where metadata set on the client with
        ``model.setUserDefinedMetadata(...)`` is expected to appear.
        Returns None when model.json has no such field.

        Raises:
          ValueError: if no model.json part has been received.
        """
        import json  # local import: json is not among the file's top-level imports
        content = json.loads(self._model_json_content().decode('utf-8'))
        return content.get('userDefinedMetadata')

    def stream_factory(self,
                       total_content_length,
                       content_type,
                       filename,
                       content_length=None):
        """Return a writable in-memory stream for one uploaded file part.

        The signature matches what werkzeug passes to a ``stream_factory``.
        The buffers for 'model.json' and 'model.weights.bin' are remembered.
        Any other part still gets a throwaway buffer, because werkzeug writes
        into whatever is returned and would fail on None.
        """
        buffer = io.BytesIO()
        if filename == 'model.json':
            self._model_json_bytes = buffer
        elif filename == 'model.weights.bin':
            self._weight_bytes = buffer
        return buffer
def main():
    """Run a Flask server on localhost:5000 that accepts model uploads on POST /upload."""
    app = Flask('model-server')
    # flask-cors allows any request header by default. The browser's CORS
    # preflight needs this for custom headers such as 'class'. flask-cors
    # has no 'CORS_HEADER' config option, so no extra config is set here.
    CORS(app)
    model_receiver = ModelReceiver()

    @app.route('/upload', methods=['POST'])
    @cross_origin()
    def upload():
        print('headers are:')
        print(request.headers)
        # Custom header sent by the client via
        # tf.io.http(url, {requestInit: {headers: {'class': ...}}}).
        print('class header:', request.headers.get('class'))
        print('Handling request...')
        werkzeug.formparser.parse_form_data(
            request.environ, stream_factory=model_receiver.stream_factory)
        print('Received model:')
        # NOTE(review): tf.Session is the TensorFlow 1.x API; under TF 2.x this
        # would need tf.compat.v1.Session -- confirm the installed version.
        with tf.Graph().as_default(), tf.Session():
            model = model_receiver.model
            model.summary()
            # You can perform `model.predict()`, `model.fit()`,
            # `model.evaluate()` etc. here.
        return Response(status=200)

    app.run('localhost', 5000)


if __name__ == '__main__':
    main()
问题出在客户端:headers 应该像下面这样传入:
// tf.io.http(url, loadOptions): fetch options (method, headers, ...) must be
// nested under loadOptions.requestInit.
const myInit = {
  method: 'POST',
  headers: {
    // Header values are strings. Join the list explicitly; an array would be
    // implicitly coerced to the same "class1,class2" string.
    'classes': ['class1', 'class2'].join(','),
  },
};
// Pass the options object directly. Writing `loadOptions={...}` inside the
// call would assign to an undeclared variable, which throws a ReferenceError
// in strict mode / ES modules.
const saveResult = await model.save(tf.io.http(
    'http://localhost:5000/upload', {requestInit: myInit}));
tf.io.http()
的参数为 url
和配置对象(LoadOptions)。该配置中的 requestInit 会传给后台使用的 fetch/Request API,
因此 method、headers 等选项都要放在 requestInit 里。
调用应如下所示:
// model.save()'s second argument is a SaveConfig (e.g. includeOptimizer) and
// does not read requestInit. Wrap the URL in tf.io.http() and pass the fetch
// options through its LoadOptions.requestInit instead.
await model.save(tf.io.http(url, {requestInit: {method: 'POST', headers: {'class': 'Dog'}}}));
如果你的目标是随模型一起存储一些辅助信息(比如类别标签),TensorFlow.js 的 tf.LayersModel
有一个相对鲜为人知的特性可以让事情更轻松,而且比使用 header 更简单:
setUserDefinedMetadata()
和 getUserDefinedMetadata()
方法。
在 JavaScript 端这样做:
// The argument to setUserDefinedMetadata() can be any serializable JSON
// object of a reasonable size. Set it on the same model object that is saved
// below; the original snippet mixed `myModel` and `model`.
model.setUserDefinedMetadata({outputClassLabels: ['Cat', 'Dog', 'Turtle']});
// The user metadata is stored with the model itself. No need to specify
// additional headers.
await model.save('http://localhost:5000/upload');
接收模型文件的服务器只需解析请求中 model.json 部分的 JSON,检查其中的 'userDefinedMetadata' 字段即可。