Tensorflow.js:expandDims() + cast() 问题

Tensorflow.js: expandDims() + cast() issue

我需要为我的模型提供一个 float32 张量。我需要使用 expandDims(tensor, axis=0) 将其形状从 [240, 320, 3] 更改为 [1, 240, 320, 3]。但是,expandDims() 操作似乎将我的张量转换为 int32.

当我对这个张量执行 cast(tensor, "float32") 时,似乎 cast 操作将我的张量压缩回 [240, 320, 3]

image_v = (tf.cast(image_raw, "float32") / 255.0) - 0.5;
image_v = tf.expandDims(image_raw, 0); 
console.log(image_v.shape) // shape: [1, 240, 320, 3]
console.log(image_v.dtype) // dtype: "int32"

image_v = tf.cast(image_raw, "float32") 
console.log(image_v.shape) // shape: [240, 320, 3]
console.log(image_v.dtype) // dtype: "float32"

我正在寻找一种方法来扩展 tensorflow.jsfloat32 张量的 dims,并使张量的 dtype 保持 float32。任何帮助将不胜感激!

tfjs不能用JS对tensor进行运算,得用tf.div()tf.sub().

image_v = (tf.cast(image_raw, "float32") / 255.0) - 0.5;

image_v 现在是 NaN,因为 ({}/255)-0.5 === NaN

image_v = tf.expandDims(image_raw, 0);

现在您展开未修改的原始图像。

image_v = tf.cast(image_raw, "float32")

您重复使用了原始 image_raw,因为在 tf 操作中不修改张量。他们总是创造一个新的。

并且我建议不要重复使用变量或在 tf.tidy() 之外工作,因为您很容易忘记 .dispose(),从而造成内存泄漏。

const image_raw = tf.zeros([240, 320, 3]);

const modified = tf.tidy(() => {
  const image_v_casted = tf.cast(image_raw, "float32").div(255).sub(0.5);
  const image_v_expanded = tf.expandDims(image_v_casted, 0);

  return image_v_expanded;
});


console.log('shape', modified.shape);
console.log('dtype', modified.dtype);
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>

你似乎认为,像 tf.cast 这样的操作对原始张量有效。但事实并非如此。引用文档:

A tf.Tensor object represents an immutable, multidimensional array of numbers that has a shape and a data type.

这意味着,每当您调用 tf.casttf.expandDims 之类的函数时,都会创建一个新的张量。因此,您的 tf.cast(image_raw, "float32") 调用不会更改原始张量,而是创建一个新张量。

解决方案

要转换张量,您应该使用 image_v 而不是 image_raw,因为后者的张量从未更改过。

image_v = tf.cast(image_raw, "float32");

此外,您的脚本开头因其他原因无法正常工作(尽管这与您面临的问题无关)。正常的 JavaScript 操作不适用于张量。请查看 tf.div and tf.sub

所以,总而言之,您的代码应该如下所示:

image_v = tf.cast(image_raw, "float32").div(255).sub(0.5);
image_v = tf.expandDims(image_v, 0);
console.log(image_v.shape) // should now be: [1, 240, 320, 3]
console.log(image_v.dtype) // and this should be: dtype: "int32"

就像已经建议的其他答案一样,您还应该查看 tf.tidy 以防止内存泄漏。