我可以覆盖 tensorflow 服务方法吗?

can i overriding tesorflow serving method?

我只想接收文本输入并尝试 return 仅预测结果中的标签值。

例如。 curl -d '{"inputs":{"test": ["我今天很伤心"]}}'
-X POST http://{location}:predict

我想获得 return 值“sad”

所以我看到了this并尝试了。

保存模型的时候,是用decorate保存的tf.function

self.tf_model_wrapper = TFModel(model)
tf.saved_model.save(self.tf_model_wrapper.model, f'classifier/saved_models/{int(time.time())}',
                          signatures={'serving_default': self.tf_model_wrapper.prediction})

并且该函数只是接收文本并将其标记化,然后尝试将预测结果值 return 到标签名称。

@tf.function(input_signature=[tf.TensorSpec(shape=(1, ), dtype=tf.string)])
def prediction(self, text: str):
    input_ids, input_attention, input_token_type = self.tokenizer(text)
    input_encoding = (input_ids, input_attention, input_token_type)
    result = self.convert_label(self.model(input_encoding))
    return result

但我遇到了这个错误

TypeError: tf__prediction() missing 2 required positional arguments: 'input2' and 'input3'

我以为是因为我的模型接收了3个输入,所以我就这样修改了,好像可以了。

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.int32, 
                              name="input_ids"),tf.TensorSpec(shape=None, dtype=tf.int32, 
                              name="attention_mask"),tf.TensorSpec(shape=None, dtype=tf.int32, name="token_type_ids")])
def prediction(self, input1, input2, input3):
    input = (input1, input2, input3)
    return self.model(input)

然而,这与最初的目的不同,似乎无法仅接收文本并return预测结果。

有什么办法可以做到吗?

通过保存的模型提供的 Tensorflow 服务似乎只提供推理。因此,我将不得不通过构建服务器和 REST API.

来单独配置逻辑