可以在浏览器中缓存 tensorflow 模型,或者监控下载进度?

it is possible to cache tensorflow models in browser, or monitor downloading progress?

我在浏览器中使用 tensorflow.js(使用 posenet 模型)。我注意到库在初始化时会下载大量文件。我想知道是否有某种方法可以更改加载首选项,例如从本地 indexedDB 加载它们,或者至少有一种方法可以使用回调函数或其他方法来向用户显示下载进度? 现在,在慢速网络中,它需要花费很多分钟,甚至显示一个旋转的加载程序,用户往往认为页面被冻结,将其关闭。

(浏览器下载了哪些文件?难道下载的不仅仅是模型和权重?)

缓存:

模型通常会自动缓存。如果资源未更改(HTTP 状态 304),将使用缓存的模型。

然而,客户端第一次连接到应用程序时,它必须下载所有内容。

要减少等待时间,您可以尝试以下方法:

  1. 通过调整模型中的 quantization bytes 来减小尺寸。
  2. 更改您正在使用的模型。
  3. 更改批量大小。

只有在使用重新训练的模型时才能使用 2 和 3。


进度回调:

如果您使用的是其中一种 tfjs loading functions,您可以使用 onProgress 函数指定一个配置对象:

const model = await tf.loadGraphModel(modelUrl, {onProgress: p => console.log(p)})

由于您使用的是 posenet,我想您使用的是内置加载功能:

async function load(config: ModelConfig);

不幸的是,此函数是 tf.loadGraphModel(url) 的包装器,因此它不会将额外选项(包括 onProgress)传递给 loadGraphModel 函数。

您可以做的是重写 posenet loading function 以使用 onProgress 回调调用 tf.loadGraphModel()