使用Web Worker执行tensorflowjs的model.predict - Angular 11

Use Web Worker to execute model.predict of tensorflowjs - Angular 11

我正在尝试在 web worker 中执行 model.predict 函数,但到处都找不到在 web worker 中导入 TensorFlow.js 的方法。

我可以使用 importScripts('cdn') 但我如何引用 tensorflow 来使用它的功能?

这是到目前为止的代码:

worker.js

/// <reference lib="webworker" />
// Loads TensorFlow.js into the worker; afterwards the library is reachable as the
// global `tf` (TypeScript additionally needs a `declare` for that global).
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js');

// Question code: expects the main thread to post { model, tensor } and replies
// with the prediction.
// NOTE(review): this cannot work as written. Messages are structured-cloned, so a
// tf model or tf.Tensor is either rejected by the clone or arrives as a plain
// object without its methods (`model.predict` is then not a function); the
// returned `pred` tensor has the same problem on the way back. Load the model
// inside the worker and exchange only TypedArray data instead.
addEventListener('message', async ({ data }) => {
  const model = data.model;
  const pred = await model.predict(data.tensor);
  postMessage(pred);
});

service.ts

// Question code: asks a Web Worker to run `this._model` on `tensor` and logs the
// prediction it sends back.
predict() {
   if (typeof Worker !== 'undefined') {
  // Create a new module worker from '../workers/model.worker'.
  const worker = new Worker('../workers/model.worker', { type: 'module' });
  // NOTE(review): the reply is structured-cloned, so a tf.Tensor posted by the
  // worker does not arrive with its methods; `data.dataSync()` will fail here.
  // Have the worker post plain data (e.g. a TypedArray or array) instead.
  worker.onmessage = ({ data }) => {
    console.log(Array.from(data.dataSync()));
  };
  // NOTE(review): `tensor` is not declared in this snippet — presumably defined
  // elsewhere. A model and a tensor cannot be usefully posted to a worker (the
  // structured clone does not carry their methods); send the raw TypedArray
  // data and load the model inside the worker instead.
  worker.postMessage({tensor, model: this._model});
  } else {
    // Web Workers are not supported in this environment.
    // You should add a fallback so that your program still executes correctly.
  }
}

你无法通过 postMessage 传递自定义对象的方法,所以需要在 Worker 内部初始化 model 和 tensor,这样它们才拥有各自的全部方法。

为此,如果你的张量确实是从 DOM 元素生成的,那么首先需要在主线程中通过 tf.browser.fromPixels() 生成它,然后把张量的数据提取为 TypedArray 并发送给 Worker。之后在 Worker 中,就可以用这个 TypedArray 创建一个新的 tensor 实例。

下面是用 Worker 改写的 mobilenet 示例(预测结果可能需要一些时间才会出现)。

// Main-thread entry point: classifies the page's <img> with MobileNet inside a
// Web Worker. Models and tensors cannot be posted to a worker, so only the raw
// pixel data (a TypedArray) is transferred; the worker rebuilds the tensor.
onload = async () => {
  const worker = new Worker( getWorkerURL() );
  // The worker posts one "ready" message once its model is loaded and its own
  // message handler is installed. Start listening *before* the first await
  // below: a message event dispatched while no handler is attached is lost,
  // which would leave this page waiting forever.
  const workerReady = new Promise((resolve) => {
    worker.onmessage = () => resolve();
  });

  const imgElement = document.querySelector('img');
  // fromPixels reads from a DOM element, so the tensor is built here on the
  // main thread; it is disposed as soon as its data has been copied out.
  const img = tf.browser.fromPixels(imgElement);
  const pixels = await img.data();
  img.dispose();

  await workerReady;
  // Every later message from the worker is the list of top-K predictions.
  worker.onmessage = ({ data: classes }) => showResults(imgElement, classes);
  // Transfer (rather than copy) the underlying buffer; `pixels` is detached
  // and must not be used after this call.
  worker.postMessage(pixels, [pixels.buffer]);
};
/**
 * Renders one prediction result at the top of the page: the image on the left
 * and a two-column table (class, probability) on the right.
 * Note that `imgElement` is moved into the new container.
 * @param {HTMLElement} imgElement - the image that was classified.
 * @param {Array<{className: *, probability: number}>} classes - predictions
 *   to list, in display order; probabilities are shown with 3 decimals.
 */
function showResults(imgElement, classes) {
  // Creates a <div>, tagging it with a CSS class when one is given.
  const makeDiv = (cssClass) => {
    const div = document.createElement('div');
    if (cssClass) {
      div.className = cssClass;
    }
    return div;
  };

  const wrapper = makeDiv('pred-container');

  const imageBox = makeDiv();
  imageBox.appendChild(imgElement);
  wrapper.appendChild(imageBox);

  const table = makeDiv();
  for (const { className, probability } of classes) {
    const row = makeDiv('row');

    const nameCell = makeDiv('cell');
    nameCell.innerText = className;
    row.appendChild(nameCell);

    const probCell = makeDiv('cell');
    probCell.innerText = probability.toFixed(3);
    row.appendChild(probCell);

    table.appendChild(row);
  }
  wrapper.appendChild(table);

  document.body.prepend(wrapper);
}
/**
 * Turns the inline `<script type="worker-script">` element into a `blob:` URL
 * that can be passed to `new Worker(...)`.
 * @returns {string} object URL of a "text/javascript" Blob holding the script text.
 */
function getWorkerURL() {
  const { textContent: workerSource } = document.querySelector("[type='worker-script']");
  const workerBlob = new Blob([workerSource], { type: "text/javascript" });
  return URL.createObjectURL(workerBlob);
}
/* One result block created by showResults(): image box + predictions box. */
.pred-container {
  margin-bottom: 20px;
}

/* Place the image box and the predictions box side by side, top-aligned. */
.pred-container > div {
  display: inline-block;
  margin-right: 20px;
  vertical-align: top;
}
/* Each prediction is laid out as a table row of two cells (class, probability). */
.row {
  display: table-row;
}
.cell {
  display: table-cell;
  padding-right: 20px;
}
<!-- ### worker.js ### -->
<script type="worker-script">
// Worker-side script. Models and tensors cannot be posted between threads, so
// the model is loaded here and the input tensor is rebuilt from the pixel
// TypedArray that the main thread transfers in.
// Load TensorFlow.js; `importScripts` makes `tf` available globally in this worker.
importScripts("https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js");
(async ()=> {
  const MOBILENET_MODEL_PATH =
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';

  const IMAGE_SIZE = 224;
  const TOPK_PREDICTIONS = 10;

  // Load the model once; every request below reuses it.
  const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);

  self.onmessage = async ({ data }) => {
    // Everything created inside tidy() (the input tensor included) is freed
    // when it returns; only the returned output tensor survives, and it is
    // disposed in the `finally` below so no tensors accumulate per request.
    const logits = tf.tidy(() => {
      // `data` is the flat pixel TypedArray; the reshape below requires it to
      // hold exactly IMAGE_SIZE * IMAGE_SIZE * 3 values and throws otherwise.
      const img = tf.tensor(data);
      const offset = tf.scalar(127.5);
      // Normalize the image from [0, 255] to [-1, 1].
      const normalized = img.sub(offset).div(offset);

      // Reshape to a single-element batch so we can pass it to predict.
      const batched = normalized.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);
      return mobilenet.predict(batched);
    });

    try {
      const classes = await getTopKClasses(logits, TOPK_PREDICTIONS);
      postMessage(classes);
    } finally {
      logits.dispose();
    }
  };

  // The handler above is installed only after the asynchronous model load, and
  // a message that arrives while no handler is attached is dropped. Tell the
  // main thread it is now safe to send image data.
  postMessage("ready");

  /**
   * Returns the `topK` highest-scoring output entries, best first.
   * @param {tf.Tensor} logits - model output; its values are read with `data()`.
   * @param {number} topK - number of entries to return; fewer are returned when
   *   the output holds fewer values.
   * @returns {Promise<Array<{className: number, probability: number}>>}
   *   `className` is the output index (imagenet_classes.js from tfjs-examples
   *   would be too big to import here), `probability` the output value there.
   *   NOTE(review): no softmax is applied here; this assumes the model's last
   *   layer already outputs probabilities — confirm.
   */
  async function getTopKClasses(logits, topK) {
    const values = await logits.data();
    return Array.from(values, (value, index) => ({ value, index }))
      .sort((a, b) => b.value - a.value)
      // slice() clamps to the array length, so a large topK cannot throw.
      .slice(0, topK)
      .map(({ value, index }) => ({ className: index, probability: value }));
  }
})();
</script>

<!-- #### index.html ### -->
<!-- we need  to load tensorflow here too -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
<img crossorigin src="https://upload.wikimedia.org/wikipedia/commons/thumb/a/ae/Katri.jpg/577px-Katri.jpg" width="224" height="224">

主线程与 Worker 之间交换的数据必须是可序列化的(能够被结构化克隆)。因此,你既不能传递模型本身,也不能传递 tf.tensor。不过,你可以只传递数据,然后在 Worker 中构建张量。

为了让 TypeScript 编译器知道 importScripts 引入了一个全局变量,你需要声明 tf:

// Ambient declaration for the TypeScript compiler only (emits no code): tells it
// that a global `tf` exists at runtime, created by importScripts() in the worker.
// `any` disables type checking; if @tensorflow/tfjs is installed as a dependency,
// `typeof import('@tensorflow/tfjs')` would give a precise type instead.
declare let tf: any // or find better type